jameszokah committed on
Commit
68b189e
·
1 Parent(s): be055a2

Refactor database management to use MongoDB with async support; update audiobook routes for MongoDB integration and improve error handling.

Browse files
Files changed (4) hide show
  1. app/api/audiobook_routes.py +151 -164
  2. app/db.py +77 -22
  3. app/main.py +11 -1
  4. requirements.txt +3 -2
app/api/audiobook_routes.py CHANGED
@@ -13,202 +13,189 @@ from fastapi.responses import FileResponse, JSONResponse
13
  from sqlalchemy.orm import Session
14
  from app.db_models.database import Audiobook, AudiobookStatus, AudiobookChunk, TextChunk
15
  from app.services.storage import storage
16
- from app.db import get_db
 
 
17
  import torchaudio
 
 
 
18
 
19
  # Set up logging
20
  logger = logging.getLogger(__name__)
21
  router = APIRouter(prefix="/audiobook", tags=["Audiobook"])
22
 
23
- async def process_audiobook(
24
- request: Request,
25
- book_id: str,
26
- text_content: str,
27
- voice_id: int,
28
- db: Session
29
- ):
30
- """Process audiobook in the background."""
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  try:
32
- # Get the book from database
33
- book = db.query(Audiobook).filter(Audiobook.id == book_id).first()
34
- if not book:
35
- logger.error(f"Book {book_id} not found")
36
- return False
37
-
38
  # Update status to processing
39
- book.status = AudiobookStatus.PROCESSING
40
- db.commit()
41
-
42
- logger.info(f"Starting processing for audiobook {book_id}")
43
-
44
- # Get the generator from app state
45
- generator = request.app.state.generator
46
- if generator is None:
47
- raise Exception("TTS model not available")
48
-
49
- # Get voice info
50
- voice_info = request.app.state.get_voice_info(voice_id)
51
- if not voice_info:
52
- raise Exception(f"Voice ID {voice_id} not found")
53
-
54
- # Generate audio for the entire text
55
- logger.info(f"Generating audio for entire text of book {book_id}")
56
- audio = generator.generate(
57
- text=text_content,
58
- speaker=voice_info["speaker_id"],
59
- max_audio_length_ms=min(300000, len(text_content) * 80) # Big text = big audio
60
  )
61
 
62
- if audio is None:
63
- raise Exception("Failed to generate audio")
 
 
64
 
65
- # Save the audio using storage service
66
- audio_to_save = audio.unsqueeze(0).cpu() if len(audio.shape) == 1 else audio.cpu()
67
- audio_bytes = audio_to_save.numpy().tobytes()
68
- audio_path = await storage.save_audio_file(book_id, audio_bytes)
69
 
70
- # Update book status in database
71
- book.status = AudiobookStatus.COMPLETED
72
- book.audio_file_path = audio_path
73
- db.commit()
74
-
75
- logger.info(f"Successfully created audiobook {book_id}")
76
- return True
 
 
 
 
77
 
78
  except Exception as e:
79
- logger.error(f"Error processing audiobook {book_id}: {e}")
80
-
81
- # Update status to failed in database
82
- book = db.query(Audiobook).filter(Audiobook.id == book_id).first()
83
- if book:
84
- book.status = AudiobookStatus.FAILED
85
- book.error_message = str(e)
86
- db.commit()
87
-
88
- return False
89
-
90
- @router.post("/")
 
91
  async def create_audiobook(
92
- request: Request,
93
  background_tasks: BackgroundTasks,
94
- title: str = Form(...),
95
- author: str = Form(...),
96
- voice_id: int = Form(0),
97
  text_file: Optional[UploadFile] = File(None),
98
- text_content: Optional[str] = Form(None),
99
- db: Session = Depends(get_db)
100
  ):
101
- """Create a new audiobook from text."""
102
- try:
103
- # Validate input
104
- if not text_file and not text_content:
105
- raise HTTPException(status_code=400, detail="Either text_file or text_content is required")
106
-
107
- # Generate unique ID
108
- book_id = str(uuid.uuid4())
109
-
110
- # Handle text content
111
- if text_file:
112
- text_file_path = await storage.save_text_file(book_id, text_file)
113
- with open(text_file_path, "r", encoding="utf-8") as f:
114
- text_content = f.read()
115
- else:
116
- text_file_path = await storage.save_text_content(book_id, text_content)
117
-
118
- # Create book in database
119
- book = Audiobook(
120
- id=book_id,
121
- title=title,
122
- author=author,
123
- voice_id=voice_id,
124
- status=AudiobookStatus.PENDING,
125
- text_file_path=text_file_path,
126
- text_content=text_content if len(text_content) <= 10000 else None # Store small texts directly
127
  )
128
- db.add(book)
129
- db.commit()
130
 
131
- # Process in background
132
- background_tasks.add_task(process_audiobook, request, book_id, text_content, voice_id, db)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- return JSONResponse(content={"message": "Audiobook creation started", "book_id": book_id})
135
- except Exception as e:
136
- raise HTTPException(status_code=500, detail=f"Error creating audiobook: {str(e)}")
137
 
138
- @router.get("/{book_id}")
139
- async def get_audiobook(book_id: str, db: Session = Depends(get_db)):
 
 
 
 
 
140
  """Get audiobook information."""
141
- book = db.query(Audiobook).filter(Audiobook.id == book_id).first()
142
- if not book:
 
143
  raise HTTPException(status_code=404, detail="Audiobook not found")
144
-
145
- return {
146
- "id": book.id,
147
- "title": book.title,
148
- "author": book.author,
149
- "voice_id": book.voice_id,
150
- "status": book.status.value,
151
- "created_at": book.created_at.isoformat(),
152
- "updated_at": book.updated_at.isoformat(),
153
- "error_message": book.error_message
154
- }
155
 
156
  @router.get("/{book_id}/audio")
157
- async def get_audiobook_audio(book_id: str, db: Session = Depends(get_db)):
158
- """Get the audiobook audio file."""
159
- book = db.query(Audiobook).filter(Audiobook.id == book_id).first()
160
- if not book:
 
 
161
  raise HTTPException(status_code=404, detail="Audiobook not found")
162
-
163
- if book.status != AudiobookStatus.COMPLETED or not book.audio_file_path:
164
- raise HTTPException(status_code=400, detail="Audiobook is not yet completed")
165
-
166
- audio_path = await storage.get_audio_file(book_id)
167
- if not audio_path:
 
 
 
168
  raise HTTPException(status_code=404, detail="Audio file not found")
169
-
170
  return FileResponse(
171
- str(audio_path),
172
- media_type="audio/wav",
173
- filename=f"{book.title}.wav"
174
  )
175
 
176
- @router.get("/")
177
- async def get_audiobooks(db: Session = Depends(get_db)):
178
- """Get all audiobooks."""
179
- books = db.query(Audiobook).order_by(Audiobook.created_at.desc()).all()
180
- return {
181
- "audiobooks": [
182
- {
183
- "id": book.id,
184
- "title": book.title,
185
- "author": book.author,
186
- "voice_id": book.voice_id,
187
- "status": book.status.value,
188
- "created_at": book.created_at.isoformat(),
189
- "updated_at": book.updated_at.isoformat(),
190
- "error_message": book.error_message
191
- }
192
- for book in books
193
- ]
194
- }
195
 
196
  @router.delete("/{book_id}")
197
- async def delete_audiobook(book_id: str, db: Session = Depends(get_db)):
198
  """Delete an audiobook."""
199
- book = db.query(Audiobook).filter(Audiobook.id == book_id).first()
200
- if not book:
 
 
201
  raise HTTPException(status_code=404, detail="Audiobook not found")
202
-
203
- try:
204
- # Delete associated files
205
- await storage.delete_book_files(book_id)
206
-
207
- # Delete from database
208
- db.delete(book)
209
- db.commit()
210
-
211
- return {"message": "Audiobook deleted successfully"}
212
- except Exception as e:
213
- db.rollback()
214
- raise HTTPException(status_code=500, detail=f"Error deleting audiobook: {str(e)}")
 
13
  from sqlalchemy.orm import Session
14
  from app.db_models.database import Audiobook, AudiobookStatus, AudiobookChunk, TextChunk
15
  from app.services.storage import storage
16
+ from app.db import get_db, AUDIOBOOKS_COLLECTION
17
+ from app.config import AUDIO_DIR, TEXT_DIR, TEMP_DIR
18
from pydantic import BaseModel, Field
19
  import torchaudio
20
+ import json
21
+ import shutil
22
+ from motor.motor_asyncio import AsyncIOMotorDatabase
23
 
24
  # Set up logging
25
  logger = logging.getLogger(__name__)
26
  router = APIRouter(prefix="/audiobook", tags=["Audiobook"])
27
 
28
class AudiobookBase(BaseModel):
    """Fields shared by every audiobook representation.

    ``created_at`` and ``updated_at`` use ``default_factory`` so that each
    instance is stamped when it is constructed. A plain
    ``datetime.utcnow()`` default is evaluated once, when the class body
    runs, and every later instance would silently share that import-time
    value.
    """

    title: str
    author: str
    voice_id: str
    # Lifecycle value; this module writes "pending", "processing",
    # "completed" and "failed".
    status: str = "pending"
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
35
+
36
class Audiobook(AudiobookBase):
    """Audiobook record as exposed by the routes in this module.

    Used as the ``response_model`` for the create, get and list endpoints.
    """

    # NOTE(review): ``Audiobook`` is also imported from
    # app.db_models.database at the top of this module; this class rebinds
    # the name, so the imported one is unreachable below this point.

    # Application-level identifier (a UUID string generated on creation).
    id: str
    # Location of the generated audio; written when processing completes.
    file_path: Optional[str] = None
    # Location of the saved source text; written when the book is created.
    text_path: Optional[str] = None
    # Failure description; written when processing fails.
    error: Optional[str] = None
41
+
42
class TextChunk(BaseModel):
    """A span of text together with its start and end position in time."""

    # NOTE(review): ``TextChunk`` is also imported from
    # app.db_models.database at the top of this module; this class rebinds
    # the name.
    text: str
    # Presumably offsets into the generated audio, in seconds -- TODO confirm
    # the unit with whichever code produces these values.
    start_time: float
    end_time: float
46
+
47
async def process_audiobook(book_id: str, db: AsyncIOMotorDatabase):
    """Run the background job for one audiobook and record its outcome.

    The document is moved to ``"processing"``, then to ``"completed"``
    (with ``file_path`` set) or, if anything raises, to ``"failed"`` (with
    ``error`` set). Every write also refreshes ``updated_at``.

    The function never raises: it is scheduled through ``BackgroundTasks``,
    so there is no caller or HTTP response left to receive an exception.
    Failures are logged with their traceback instead.

    Args:
        book_id: Value of the ``id`` field of the audiobook document.
        db: Database handle holding the audiobooks collection.
    """
    collection = db[AUDIOBOOKS_COLLECTION]

    async def _set_fields(fields: dict) -> None:
        """Apply ``fields`` plus a fresh ``updated_at`` to this book's document."""
        await collection.update_one(
            {"id": book_id},
            {"$set": {**fields, "updated_at": datetime.utcnow()}},
        )

    try:
        # Update status to processing
        await _set_fields({"status": "processing"})

        # Get the audiobook data
        audiobook = await collection.find_one({"id": book_id})
        if not audiobook:
            # A plain exception is used rather than HTTPException: this code
            # runs after the response has been sent, so an HTTP status code
            # has no meaning here.
            raise LookupError("Audiobook not found")

        # TODO: Implement TTS processing logic here
        # For now, we'll just simulate processing
        logger.info(f"Processing audiobook {book_id}")

        # Update status to completed
        # NOTE(review): nothing above writes this file yet, so the recorded
        # path does not exist until the TTS step is implemented.
        await _set_fields(
            {
                "status": "completed",
                "file_path": f"{AUDIO_DIR}/{book_id}.mp3",
            }
        )

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error would drop.
        logger.exception(f"Error processing audiobook {book_id}: {str(e)}")
        try:
            await _set_fields({"status": "failed", "error": str(e)})
        except Exception:
            # The database itself may be what failed; do not let the
            # bookkeeping write escape from the background task.
            logger.exception(
                f"Could not record failure for audiobook {book_id}"
            )
89
+
90
@router.post("/", response_model=Audiobook)
async def create_audiobook(
    background_tasks: BackgroundTasks,
    title: str,
    author: str,
    voice_id: str,
    text_file: Optional[UploadFile] = File(None),
    text_content: Optional[str] = None,
    request: Request = None
):
    """Create an audiobook record and schedule its processing.

    The source text, taken from ``text_file`` when supplied and otherwise
    from ``text_content``, is saved as ``{TEXT_DIR}/{book_id}.txt``. A
    document with status ``"pending"`` is inserted and
    ``process_audiobook`` is queued as a background task.

    NOTE(review): ``title``, ``author``, ``voice_id`` and ``text_content``
    are declared without ``Form(...)``, so FastAPI reads them as query
    parameters -- confirm that this is intended for long texts.

    Args:
        background_tasks: FastAPI task queue used to schedule processing.
        title: Title of the book.
        author: Author of the book.
        voice_id: Identifier of the voice to use.
        text_file: Optional uploaded file containing the text.
        text_content: Optional raw text, used when no file is uploaded.
        request: Unused; kept so the signature stays unchanged.

    Returns:
        The stored audiobook document.

    Raises:
        HTTPException: 400 if neither ``text_file`` nor ``text_content``
            is provided.
    """
    # Validate input before touching the filesystem or the database.
    if not text_file and not text_content:
        raise HTTPException(
            status_code=400,
            detail="Either text_file or text_content must be provided"
        )

    # get_db() is a plain synchronous function that returns the database
    # handle directly; awaiting its result raises TypeError.
    db = get_db()
    book_id = str(uuid.uuid4())

    # Handle text input: both sources end up at the same path.
    text_path = f"{TEXT_DIR}/{book_id}.txt"
    if text_file:
        with open(text_path, "wb") as f:
            shutil.copyfileobj(text_file.file, f)
    else:
        # Explicit encoding so the result does not depend on the host locale.
        with open(text_path, "w", encoding="utf-8") as f:
            f.write(text_content)

    # Create audiobook document; one timestamp so both fields agree exactly.
    now = datetime.utcnow()
    audiobook = {
        "id": book_id,
        "title": title,
        "author": author,
        "voice_id": voice_id,
        "status": "pending",
        "created_at": now,
        "updated_at": now,
        "text_path": text_path,
    }

    # Insert audiobook into database
    await db[AUDIOBOOKS_COLLECTION].insert_one(audiobook)
    # insert_one adds an ObjectId under "_id" to the dict it is given; drop
    # it so the returned payload holds only fields this API defines.
    audiobook.pop("_id", None)

    # Start background processing
    background_tasks.add_task(process_audiobook, book_id, db)

    return audiobook
141
+
142
@router.get("/{book_id}", response_model=Audiobook)
async def get_audiobook(book_id: str):
    """Get audiobook information.

    Args:
        book_id: Value of the ``id`` field of the audiobook document.

    Returns:
        The stored audiobook document.

    Raises:
        HTTPException: 404 if no document has this ``id``.
    """
    # get_db() is synchronous and returns the database handle directly;
    # awaiting its result raises TypeError.
    db = get_db()
    audiobook = await db[AUDIOBOOKS_COLLECTION].find_one({"id": book_id})
    if audiobook is None:
        raise HTTPException(status_code=404, detail="Audiobook not found")
    return audiobook
 
 
 
 
 
 
 
 
 
 
150
 
151
  @router.get("/{book_id}/audio")
152
async def get_audiobook_audio(book_id: str):
    """Get audiobook audio file.

    Args:
        book_id: Value of the ``id`` field of the audiobook document.

    Returns:
        A ``FileResponse`` streaming the file at the document's
        ``file_path`` as ``audio/mpeg``, named after the book title.

    Raises:
        HTTPException: 404 if the audiobook does not exist or its audio
            file is missing on disk; 400 if its status is not
            ``"completed"``.
    """
    # get_db() is synchronous and returns the database handle directly;
    # awaiting its result raises TypeError.
    db = get_db()
    audiobook = await db[AUDIOBOOKS_COLLECTION].find_one({"id": book_id})

    if audiobook is None:
        raise HTTPException(status_code=404, detail="Audiobook not found")

    if audiobook["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Audiobook is not ready (status: {audiobook['status']})"
        )

    # The document can say "completed" while the file is absent, so the
    # path is checked on disk before it is served.
    file_path = audiobook.get("file_path")
    if not file_path or not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")

    return FileResponse(
        file_path,
        media_type="audio/mpeg",
        filename=f"{audiobook['title']}.mp3"
    )
175
 
176
@router.get("/", response_model=List[Audiobook])
async def list_audiobooks():
    """List all audiobooks, newest first.

    Returns:
        Every document in the audiobooks collection, ordered by
        ``created_at`` descending.
    """
    # get_db() is synchronous and returns the database handle directly;
    # awaiting its result raises TypeError.
    db = get_db()
    # -1 requests descending order, so the most recent books come first.
    cursor = db[AUDIOBOOKS_COLLECTION].find().sort("created_at", -1)
    # NOTE(review): length=None loads the whole collection into memory;
    # consider pagination if the collection can grow large.
    return await cursor.to_list(length=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  @router.delete("/{book_id}")
184
async def delete_audiobook(book_id: str):
    """Delete an audiobook together with its audio and text files.

    Args:
        book_id: Value of the ``id`` field of the audiobook document.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 if no document has this ``id``.
    """
    # get_db() is synchronous and returns the database handle directly;
    # awaiting its result raises TypeError.
    db = get_db()
    audiobook = await db[AUDIOBOOKS_COLLECTION].find_one({"id": book_id})

    if audiobook is None:
        raise HTTPException(status_code=404, detail="Audiobook not found")

    # Delete associated files. Removing and tolerating "not found" avoids
    # the race between an existence check and the removal itself.
    for path_key in ("file_path", "text_path"):
        path = audiobook.get(path_key)
        if not path:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # Already gone; nothing left to clean up.

    # Delete from database
    await db[AUDIOBOOKS_COLLECTION].delete_one({"id": book_id})

    return {"message": "Audiobook deleted successfully"}
 
 
app/db.py CHANGED
@@ -1,29 +1,84 @@
1
- """Database connection and session management."""
2
  import os
3
- from sqlalchemy import create_engine
4
- from sqlalchemy.orm import sessionmaker
5
- from app.db_models.database import Base
 
 
6
 
7
- # Get database URL from environment or use SQLite as default
8
- DATABASE_URL = os.getenv(
9
- "DATABASE_URL",
10
- "sqlite:///app/storage/audiobooks.db"
11
- )
12
 
13
- # Create engine
14
- engine = create_engine(DATABASE_URL)
15
 
16
- # Create session factory
17
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 
18
 
19
- def init_db():
20
- """Initialize the database, creating all tables."""
21
- Base.metadata.create_all(bind=engine)
22
 
23
- def get_db():
24
- """Get a database session."""
25
- db = SessionLocal()
26
  try:
27
- yield db
28
- finally:
29
- db.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MongoDB database configuration."""
2
  import os
3
+ import logging
4
+ from typing import Optional
5
+ from motor.motor_asyncio import AsyncIOMotorClient
6
+ from pymongo.errors import ConnectionFailure
7
+ from dotenv import load_dotenv
8
 
9
# Pull settings from a .env file (if any) into the process environment
# before they are read below.
load_dotenv()

# Module-level logger.
logger = logging.getLogger(__name__)

# Connection settings, each with a local-development fallback.
MONGO_URI = os.environ.get("MONGO_URI", "mongodb://localhost:27017")
DB_NAME = os.environ.get("DB_NAME", "tts_api")

# Shared client; stays None until connect_to_mongo() assigns it.
client: Optional[AsyncIOMotorClient] = None
 
21
 
22
async def connect_to_mongo():
    """Connect to MongoDB and publish the client for ``get_db()``.

    The new client is pinged before it is stored in the module-level
    ``client``. If the ping fails, the client is closed and the global is
    left unchanged, so ``get_db()`` never hands out a client that failed
    its connection check.

    Raises:
        ConnectionFailure: If the server does not answer the ping.
    """
    global client
    candidate = AsyncIOMotorClient(MONGO_URI)
    try:
        # Verify the connection
        await candidate.admin.command('ping')
    except ConnectionFailure as e:
        logger.error(f"Could not connect to MongoDB: {e}")
        # Release the unusable client's resources instead of leaking them.
        candidate.close()
        raise
    client = candidate
    logger.info("Successfully connected to MongoDB")
33
+
34
async def close_mongo_connection():
    """Close the MongoDB connection, if one is open.

    The module-level ``client`` is reset to ``None`` afterwards so that
    ``get_db()`` reports "not initialized" instead of returning a handle
    backed by a closed client. Calling this when no client exists does
    nothing.
    """
    global client
    if client is not None:
        client.close()
        client = None
        logger.info("MongoDB connection closed")
40
+
41
def get_db():
    """Get database instance.

    This is a plain synchronous function: call it as ``get_db()`` and do
    not ``await`` the result.

    Returns:
        The ``DB_NAME`` database of the shared client.

    Raises:
        ConnectionError: If no client has been created yet.
    """
    # Compare with None explicitly rather than relying on the truthiness
    # of a driver object.
    if client is None:
        raise ConnectionError("MongoDB client not initialized")
    return client[DB_NAME]
46
+
47
# Names of the MongoDB collections.
AUDIOBOOKS_COLLECTION = "audiobooks"
VOICES_COLLECTION = "voices"
AUDIO_CACHE_COLLECTION = "audio_cache"

# Document shapes, written as field name -> Python type.
# NOTE(review): nothing visible here validates documents against these
# dicts; they appear to be descriptive only -- confirm before relying on them.

# NOTE(review): the audiobook routes store datetime objects in
# created_at/updated_at and also write a "text_path" field, neither of which
# matches this description -- confirm which side is authoritative.
AUDIOBOOK_SCHEMA = {
    "id": str,          # UUID string
    "title": str,
    "author": str,
    "voice_id": str,
    "status": str,      # pending, processing, completed, failed
    "created_at": str,  # ISO format datetime
    "updated_at": str,  # ISO format datetime
    "duration": float,
    "file_path": str,
    "error": str,
    "meta_data": dict,
}

VOICE_SCHEMA = {
    "id": str,          # UUID string
    "name": str,
    "type": str,        # standard, cloned
    "speaker_id": int,
    "created_at": str,  # ISO format datetime
    "is_active": bool,
    "meta_data": dict,
}

AUDIO_CACHE_SCHEMA = {
    "id": str,          # UUID string
    "hash": str,        # Hash of input parameters
    "format": str,      # Audio format (mp3, wav, etc.)
    "created_at": str,  # ISO format datetime
    "file_path": str,
    "meta_data": dict,
}
app/main.py CHANGED
@@ -18,7 +18,7 @@ from fastapi.middleware.cors import CORSMiddleware
18
  from fastapi.responses import RedirectResponse, FileResponse
19
  from fastapi.staticfiles import StaticFiles
20
  from app.api.routes import router as api_router
21
- from app.db_models.database import Base, get_db
22
 
23
  # Setup logging
24
  os.makedirs("logs", exist_ok=True)
@@ -618,6 +618,16 @@ async def root():
618
  logger.debug("Root endpoint accessed, redirecting to docs")
619
  return RedirectResponse(url="/docs")
620
 
 
 
 
 
 
 
 
 
 
 
621
  if __name__ == "__main__":
622
  # Get port from environment or use default
623
  port = int(os.environ.get("PORT", 7860))
 
18
  from fastapi.responses import RedirectResponse, FileResponse
19
  from fastapi.staticfiles import StaticFiles
20
  from app.api.routes import router as api_router
21
+ from app.db import connect_to_mongo, close_mongo_connection
22
 
23
  # Setup logging
24
  os.makedirs("logs", exist_ok=True)
 
618
  logger.debug("Root endpoint accessed, redirecting to docs")
619
  return RedirectResponse(url="/docs")
620
 
621
@app.on_event("startup")
async def startup_db_client():
    """Open the MongoDB connection when the application starts.

    Delegates to ``connect_to_mongo``; any exception it raises propagates
    out of this handler.
    """
    await connect_to_mongo()
625
+
626
@app.on_event("shutdown")
async def shutdown_db_client():
    """Close the MongoDB connection when the application shuts down.

    Delegates to ``close_mongo_connection``.
    """
    await close_mongo_connection()
630
+
631
  if __name__ == "__main__":
632
  # Get port from environment or use default
633
  port = int(os.environ.get("PORT", 7860))
requirements.txt CHANGED
@@ -22,5 +22,6 @@ yt-dlp>=2023.3.4
22
  openai-whisper>=20230314
23
  ffmpeg-python>=0.2.0
24
  accelerate>=0.20.0
25
- alembic>=1.12.0
26
- SQLAlchemy>=2.0.0
 
 
22
  openai-whisper>=20230314
23
  ffmpeg-python>=0.2.0
24
  accelerate>=0.20.0
25
+ pymongo>=4.6.1
26
+ motor>=3.3.2
27
+ python-dotenv>=1.0.1