|
import glob
import os
import pathlib
import re
import shutil
from typing import Optional, Tuple, Union
from urllib.parse import quote

from fastapi import BackgroundTasks, Depends, Path, Request, UploadFile
from fastapi.params import File
from fastapi.responses import FileResponse, Response, StreamingResponse
from loguru import logger

from app.config import config
from app.controllers import base
from app.controllers.manager.memory_manager import InMemoryTaskManager
from app.controllers.manager.redis_manager import RedisTaskManager
from app.controllers.v1.base import new_router
from app.models.exception import HttpException
from app.models.schema import (
    AudioRequest,
    BgmRetrieveResponse,
    BgmUploadResponse,
    SubtitleRequest,
    TaskDeletionResponse,
    TaskQueryRequest,
    TaskQueryResponse,
    TaskResponse,
    TaskVideoRequest,
)
from app.services import state as sm
from app.services import task as tm
from app.utils import utils
|
|
|
|
|
|
|
router = new_router()

# Task-queue settings read from the [app] section of the config.
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)
_max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)

# Only embed credentials when a password is actually configured. Interpolating
# None directly yields "redis://:None@...", which a URL parser reads as the
# literal password "None". The password is percent-encoded so characters such
# as "@", ":" or "/" cannot corrupt the URL (assumes the consumer decodes it,
# as redis-py's from_url does).
if _redis_password:
    _redis_auth = f":{quote(str(_redis_password), safe='')}@"
else:
    _redis_auth = ""
redis_url = f"redis://{_redis_auth}{_redis_host}:{_redis_port}/{_redis_db}"

# Choose the task-manager backend: Redis when enabled, otherwise in-process.
if _enable_redis:
    task_manager = RedisTaskManager(
        max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url
    )
else:
    task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)
|
|
|
|
|
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
def create_video(
    background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest
):
    """Queue a task that runs the full pipeline and produces a video."""
    # `background_tasks` is part of the signature but unused; all work is
    # delegated to create_task with the "video" stop stage.
    return create_task(request=request, body=body, stop_at="video")
|
|
|
|
|
@router.post("/subtitle", response_model=TaskResponse, summary="Generate subtitle only")
def create_subtitle(
    background_tasks: BackgroundTasks, request: Request, body: SubtitleRequest
):
    """Queue a task that stops once subtitles have been generated."""
    # `background_tasks` is accepted but unused; create_task does the work.
    return create_task(request=request, body=body, stop_at="subtitle")
|
|
|
|
|
@router.post("/audio", response_model=TaskResponse, summary="Generate audio only")
def create_audio(
    background_tasks: BackgroundTasks, request: Request, body: AudioRequest
):
    """Queue a task that stops once the audio has been generated."""
    # `background_tasks` is accepted but unused; create_task does the work.
    return create_task(request=request, body=body, stop_at="audio")
|
|
|
|
|
def create_task(
    request: Request,
    body: Union[TaskVideoRequest, SubtitleRequest, AudioRequest],
    stop_at: str,
):
    """Register a new task and submit it to the task manager.

    Args:
        request: Incoming request, used to derive the request id.
        body: Validated task parameters.
        stop_at: Stage name forwarded to ``tm.start`` through the task manager.

    Returns:
        ``utils.get_response(200, ...)`` wrapping a dict with ``task_id``,
        ``request_id`` and the dumped ``params``.

    Raises:
        HttpException: Status 400 if any step inside the try raises ValueError.
    """
    new_task_id = utils.get_uuid()
    req_id = base.get_task_id(request)
    try:
        payload = dict(
            task_id=new_task_id,
            request_id=req_id,
            params=body.model_dump(),
        )
        # Record the task in state before queueing the actual work.
        sm.state.update_task(new_task_id)
        task_manager.add_task(
            tm.start, task_id=new_task_id, params=body, stop_at=stop_at
        )
        logger.success(f"Task created: {utils.to_json(payload)}")
        return utils.get_response(200, payload)
    except ValueError as err:
        raise HttpException(
            task_id=new_task_id, status_code=400, message=f"{req_id}: {str(err)}"
        )
|
|
|
from fastapi import Query |
|
|
|
@router.get("/tasks", response_model=TaskQueryResponse, summary="Get all tasks")
def get_all_tasks(
    request: Request, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1)
):
    """Return one page of tasks together with the total task count."""
    # The request id is computed but not used in this handler.
    base.get_task_id(request)
    tasks, total = sm.state.get_all_tasks(page, page_size)
    return utils.get_response(
        200,
        {"tasks": tasks, "total": total, "page": page, "page_size": page_size},
    )
|
|
|
|
|
|
|
@router.get(
    "/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status"
)
def get_task(
    request: Request,
    task_id: str = Path(..., description="Task ID"),
    query: TaskQueryRequest = Depends(),
):
    """Return the state of a task, with its output files exposed as URLs."""
    # Public base URL: configured endpoint, else the URL this request came in on.
    endpoint = config.app.get("endpoint", "") or str(request.base_url)
    endpoint = endpoint.rstrip("/")

    request_id = base.get_task_id(request)
    task = sm.state.get_task(task_id)
    if not task:
        raise HttpException(
            task_id=task_id, status_code=404, message=f"{request_id}: task not found"
        )

    task_dir = utils.task_dir()
    # Work on a shallow copy so the URL rewriting below can never leak back
    # into the stored task state (get_task may return a shared object).
    task = dict(task)

    def file_to_uri(file):
        """Map a local output path under task_dir to a URL below `endpoint`."""
        # Values that already are URLs for this endpoint are left untouched.
        if file.startswith(endpoint):
            return file
        uri_path = file.replace(task_dir, "tasks").replace("\\", "/")
        return f"{endpoint}/{uri_path}"

    for key in ("videos", "combined_videos"):
        files = task.get(key)
        if files:
            task[key] = [file_to_uri(f) for f in files]

    return utils.get_response(200, task)
|
|
|
|
|
@router.delete(
    "/tasks/{task_id}",
    response_model=TaskDeletionResponse,
    summary="Delete a generated short video task",
)
def delete_video(request: Request, task_id: str = Path(..., description="Task ID")):
    """Delete a task's output directory and remove it from the task state."""
    request_id = base.get_task_id(request)
    task = sm.state.get_task(task_id)
    if not task:
        raise HttpException(
            task_id=task_id, status_code=404, message=f"{request_id}: task not found"
        )

    task_output_dir = os.path.join(utils.task_dir(), task_id)
    if os.path.exists(task_output_dir):
        shutil.rmtree(task_output_dir)

    sm.state.delete_task(task_id)
    logger.success(f"video deleted: {utils.to_json(task)}")
    return utils.get_response(200)
|
|
|
|
|
@router.get(
    "/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files"
)
def get_bgm_list(request: Request):
    """List the MP3 files in the songs directory with their sizes."""
    pattern = os.path.join(utils.song_dir(), "*.mp3")
    bgm_list = [
        {
            "name": os.path.basename(path),
            "size": os.path.getsize(path),
            "file": path,
        }
        for path in glob.glob(pattern)
    ]
    return utils.get_response(200, {"files": bgm_list})
|
|
|
|
|
@router.post(
    "/musics",
    response_model=BgmUploadResponse,
    summary="Upload the BGM file to the songs directory",
)
def upload_bgm_file(request: Request, file: UploadFile = File(...)):
    """Save an uploaded .mp3 file into the songs directory."""
    request_id = base.get_task_id(request)

    # The filename is client-controlled: keep only its last path component so
    # values like "../../x.mp3" or "/etc/x.mp3" cannot escape the songs
    # directory. Backslashes are normalised first to cover Windows-style paths.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))

    # Case-sensitive ".mp3" check, matching the "*.mp3" glob that get_bgm_list
    # uses, so every accepted upload is also listed. A bare ".mp3" is rejected.
    if filename.endswith(".mp3") and len(filename) > len(".mp3"):
        save_path = os.path.join(utils.song_dir(), filename)
        file.file.seek(0)
        with open(save_path, "wb") as buffer:
            # Stream in chunks instead of reading the whole upload into memory.
            shutil.copyfileobj(file.file, buffer)
        return utils.get_response(200, {"file": save_path})

    raise HttpException(
        "", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded"
    )
|
|
|
|
|
# Matches exactly one "bytes=first-last" range; other units and multi-range
# requests ("bytes=0-1,5-6") deliberately do not match.
_BYTE_RANGE_RE = re.compile(r"bytes=(\d*)-(\d*)", re.ASCII | re.IGNORECASE)


def _parse_byte_range(range_header: str, size: int) -> Optional[Tuple[int, int]]:
    """Parse a ``Range`` header value against a resource of ``size`` bytes.

    Args:
        range_header: Raw value of the ``Range`` request header.
        size: Total size of the resource in bytes.

    Returns:
        The inclusive ``(start, end)`` byte range, clamped to the resource, or
        ``None`` when the header is malformed or uses a form not supported
        here (another unit, several ranges). The caller then ignores the
        header and serves the whole resource.

    Raises:
        ValueError: The range is well formed but cannot be satisfied, e.g. it
            starts at or beyond the end of the resource.
    """
    match = _BYTE_RANGE_RE.fullmatch(range_header.strip())
    if match is None:
        return None
    first, last = match.groups()
    if not first and not last:
        return None  # "bytes=-" names no range at all

    if not first:
        # Suffix form "bytes=-N": the final N bytes of the resource.
        suffix_length = int(last)
        if suffix_length == 0 or size == 0:
            raise ValueError("unsatisfiable suffix range")
        return max(size - suffix_length, 0), size - 1

    start = int(first)
    if last and int(last) < start:
        return None  # last < first is syntactically invalid: ignore the header
    if start >= size:
        raise ValueError("range starts beyond the end of the resource")
    end = min(int(last), size - 1) if last else size - 1
    return start, end


@router.get("/stream/{file_path:path}")
async def stream_video(request: Request, file_path: str):
    """Stream a generated video, supporting single HTTP Range requests."""
    tasks_dir = os.path.abspath(utils.task_dir())
    # file_path comes straight from the URL. Leading separators are stripped
    # so it is always joined under tasks_dir, and anything that ".." segments
    # resolve to outside tasks_dir is rejected.
    video_path = os.path.abspath(os.path.join(tasks_dir, file_path.lstrip("/\\")))
    within_tasks_dir = os.path.normcase(video_path).startswith(
        os.path.normcase(os.path.join(tasks_dir, ""))
    )
    if not within_tasks_dir or not os.path.isfile(video_path):
        request_id = base.get_task_id(request)
        raise HttpException(
            task_id="", status_code=404, message=f"{request_id}: file not found"
        )

    video_size = os.path.getsize(video_path)
    range_header = request.headers.get("Range")
    try:
        byte_range = (
            _parse_byte_range(range_header, video_size) if range_header else None
        )
    except ValueError:
        # A 416 should advertise the current length so the client can retry.
        return Response(
            status_code=416, headers={"Content-Range": f"bytes */{video_size}"}
        )

    headers = {"Accept-Ranges": "bytes"}
    if byte_range is None:
        # No (usable) Range header: send the full body with a plain 200.
        start, end, status_code = 0, video_size - 1, 200
    else:
        start, end = byte_range
        status_code = 206
        headers["Content-Range"] = f"bytes {start}-{end}/{video_size}"
    length = end - start + 1  # 0 for an empty file served in full
    headers["Content-Length"] = str(length)

    def file_iterator(
        path: str, offset: int, bytes_to_read: int, chunk_size: int = 64 * 1024
    ):
        """Yield at most ``bytes_to_read`` bytes of ``path`` from ``offset``."""
        with open(path, "rb") as f:
            f.seek(offset, os.SEEK_SET)
            remaining = bytes_to_read
            while remaining > 0:
                data = f.read(min(chunk_size, remaining))
                if not data:
                    break  # hit EOF early, e.g. the file shrank while streaming
                remaining -= len(data)
                yield data

    return StreamingResponse(
        file_iterator(video_path, start, length),
        status_code=status_code,
        media_type="video/mp4",
        headers=headers,
    )
|
|
|
|
|
@router.get("/download/{file_path:path}")
async def download_video(_: Request, file_path: str):
    """
    download video

    :param _: Request request
    :param file_path: video file path relative to the tasks directory,
        eg: cd1727ed-3473-42a2-a7da-4faafafec72b/final-1.mp4
    :return: video file as an attachment; 404 if it does not exist or lies
        outside the tasks directory
    """
    tasks_dir = os.path.abspath(utils.task_dir())
    # file_path comes straight from the URL. Leading separators are stripped
    # (so "/<task>/x.mp4" also works) and anything that ".." segments resolve
    # to outside tasks_dir is rejected.
    video_path = os.path.abspath(os.path.join(tasks_dir, file_path.lstrip("/\\")))
    within_tasks_dir = os.path.normcase(video_path).startswith(
        os.path.normcase(os.path.join(tasks_dir, ""))
    )
    if not within_tasks_dir or not os.path.isfile(video_path):
        request_id = base.get_task_id(_)
        raise HttpException(
            task_id="", status_code=404, message=f"{request_id}: file not found"
        )

    video_file = pathlib.Path(video_path)
    extension = video_file.suffix
    media_type = f"video/{extension[1:]}" if extension else "application/octet-stream"
    # Passing `filename` lets FileResponse build a properly quoted (and, for
    # non-ASCII names, RFC 5987-encoded) Content-Disposition attachment header.
    return FileResponse(
        path=video_path,
        filename=video_file.name,
        media_type=media_type,
    )
|
|