Spaces:
Running
Running
import glob | |
import os | |
import pathlib | |
import shutil | |
from typing import Union | |
from fastapi import BackgroundTasks, Depends, Path, Request, UploadFile | |
from fastapi.params import File | |
from fastapi.responses import FileResponse, StreamingResponse | |
from loguru import logger | |
from app.config import config | |
from app.controllers import base | |
from app.controllers.manager.memory_manager import InMemoryTaskManager | |
from app.controllers.manager.redis_manager import RedisTaskManager | |
from app.controllers.v1.base import new_router | |
from app.models.exception import HttpException | |
from app.models.schema import ( | |
AudioRequest, | |
BgmRetrieveResponse, | |
BgmUploadResponse, | |
SubtitleRequest, | |
TaskDeletionResponse, | |
TaskQueryRequest, | |
TaskQueryResponse, | |
TaskResponse, | |
TaskVideoRequest, | |
) | |
from app.services import state as sm | |
from app.services import task as tm | |
from app.utils import utils | |
# 认证依赖项 | |
# router = new_router(dependencies=[Depends(base.verify_token)]) | |
router = new_router() | |
_enable_redis = config.app.get("enable_redis", False) | |
_redis_host = config.app.get("redis_host", "localhost") | |
_redis_port = config.app.get("redis_port", 6379) | |
_redis_db = config.app.get("redis_db", 0) | |
_redis_password = config.app.get("redis_password", None) | |
_max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5) | |
redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}" | |
# 根据配置选择合适的任务管理器 | |
if _enable_redis: | |
task_manager = RedisTaskManager( | |
max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url | |
) | |
else: | |
task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks) | |
def create_video( | |
background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest | |
): | |
return create_task(request, body, stop_at="video") | |
def create_subtitle( | |
background_tasks: BackgroundTasks, request: Request, body: SubtitleRequest | |
): | |
return create_task(request, body, stop_at="subtitle") | |
def create_audio( | |
background_tasks: BackgroundTasks, request: Request, body: AudioRequest | |
): | |
return create_task(request, body, stop_at="audio") | |
def create_task( | |
request: Request, | |
body: Union[TaskVideoRequest, SubtitleRequest, AudioRequest], | |
stop_at: str, | |
): | |
task_id = utils.get_uuid() | |
request_id = base.get_task_id(request) | |
try: | |
task = { | |
"task_id": task_id, | |
"request_id": request_id, | |
"params": body.model_dump(), | |
} | |
sm.state.update_task(task_id) | |
task_manager.add_task(tm.start, task_id=task_id, params=body, stop_at=stop_at) | |
logger.success(f"Task created: {utils.to_json(task)}") | |
return utils.get_response(200, task) | |
except ValueError as e: | |
raise HttpException( | |
task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}" | |
) | |
def get_task( | |
request: Request, | |
task_id: str = Path(..., description="Task ID"), | |
query: TaskQueryRequest = Depends(), | |
): | |
endpoint = config.app.get("endpoint", "") | |
if not endpoint: | |
endpoint = str(request.base_url) | |
endpoint = endpoint.rstrip("/") | |
request_id = base.get_task_id(request) | |
task = sm.state.get_task(task_id) | |
if task: | |
task_dir = utils.task_dir() | |
def file_to_uri(file): | |
if not file.startswith(endpoint): | |
_uri_path = v.replace(task_dir, "tasks").replace("\\", "/") | |
_uri_path = f"{endpoint}/{_uri_path}" | |
else: | |
_uri_path = file | |
return _uri_path | |
if "videos" in task: | |
videos = task["videos"] | |
urls = [] | |
for v in videos: | |
urls.append(file_to_uri(v)) | |
task["videos"] = urls | |
if "combined_videos" in task: | |
combined_videos = task["combined_videos"] | |
urls = [] | |
for v in combined_videos: | |
urls.append(file_to_uri(v)) | |
task["combined_videos"] = urls | |
return utils.get_response(200, task) | |
raise HttpException( | |
task_id=task_id, status_code=404, message=f"{request_id}: task not found" | |
) | |
def delete_video(request: Request, task_id: str = Path(..., description="Task ID")): | |
request_id = base.get_task_id(request) | |
task = sm.state.get_task(task_id) | |
if task: | |
tasks_dir = utils.task_dir() | |
current_task_dir = os.path.join(tasks_dir, task_id) | |
if os.path.exists(current_task_dir): | |
shutil.rmtree(current_task_dir) | |
sm.state.delete_task(task_id) | |
logger.success(f"video deleted: {utils.to_json(task)}") | |
return utils.get_response(200) | |
raise HttpException( | |
task_id=task_id, status_code=404, message=f"{request_id}: task not found" | |
) | |
# @router.get( | |
# "/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files" | |
# ) | |
# def get_bgm_list(request: Request): | |
# suffix = "*.mp3" | |
# song_dir = utils.song_dir() | |
# files = glob.glob(os.path.join(song_dir, suffix)) | |
# bgm_list = [] | |
# for file in files: | |
# bgm_list.append( | |
# { | |
# "name": os.path.basename(file), | |
# "size": os.path.getsize(file), | |
# "file": file, | |
# } | |
# ) | |
# response = {"files": bgm_list} | |
# return utils.get_response(200, response) | |
# | |
# @router.post( | |
# "/musics", | |
# response_model=BgmUploadResponse, | |
# summary="Upload the BGM file to the songs directory", | |
# ) | |
# def upload_bgm_file(request: Request, file: UploadFile = File(...)): | |
# request_id = base.get_task_id(request) | |
# # check file ext | |
# if file.filename.endswith("mp3"): | |
# song_dir = utils.song_dir() | |
# save_path = os.path.join(song_dir, file.filename) | |
# # save file | |
# with open(save_path, "wb+") as buffer: | |
# # If the file already exists, it will be overwritten | |
# file.file.seek(0) | |
# buffer.write(file.file.read()) | |
# response = {"file": save_path} | |
# return utils.get_response(200, response) | |
# | |
# raise HttpException( | |
# "", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded" | |
# ) | |
# | |
# | |
# @router.get("/stream/{file_path:path}") | |
# async def stream_video(request: Request, file_path: str): | |
# tasks_dir = utils.task_dir() | |
# video_path = os.path.join(tasks_dir, file_path) | |
# range_header = request.headers.get("Range") | |
# video_size = os.path.getsize(video_path) | |
# start, end = 0, video_size - 1 | |
# | |
# length = video_size | |
# if range_header: | |
# range_ = range_header.split("bytes=")[1] | |
# start, end = [int(part) if part else None for part in range_.split("-")] | |
# if start is None: | |
# start = video_size - end | |
# end = video_size - 1 | |
# if end is None: | |
# end = video_size - 1 | |
# length = end - start + 1 | |
# | |
# def file_iterator(file_path, offset=0, bytes_to_read=None): | |
# with open(file_path, "rb") as f: | |
# f.seek(offset, os.SEEK_SET) | |
# remaining = bytes_to_read or video_size | |
# while remaining > 0: | |
# bytes_to_read = min(4096, remaining) | |
# data = f.read(bytes_to_read) | |
# if not data: | |
# break | |
# remaining -= len(data) | |
# yield data | |
# | |
# response = StreamingResponse( | |
# file_iterator(video_path, start, length), media_type="video/mp4" | |
# ) | |
# response.headers["Content-Range"] = f"bytes {start}-{end}/{video_size}" | |
# response.headers["Accept-Ranges"] = "bytes" | |
# response.headers["Content-Length"] = str(length) | |
# response.status_code = 206 # Partial Content | |
# | |
# return response | |
# | |
# | |
# @router.get("/download/{file_path:path}") | |
# async def download_video(_: Request, file_path: str): | |
# """ | |
# download video | |
# :param _: Request request | |
# :param file_path: video file path, eg: /cd1727ed-3473-42a2-a7da-4faafafec72b/final-1.mp4 | |
# :return: video file | |
# """ | |
# tasks_dir = utils.task_dir() | |
# video_path = os.path.join(tasks_dir, file_path) | |
# file_path = pathlib.Path(video_path) | |
# filename = file_path.stem | |
# extension = file_path.suffix | |
# headers = {"Content-Disposition": f"attachment; filename={filename}{extension}"} | |
# return FileResponse( | |
# path=video_path, | |
# headers=headers, | |
# filename=f"{filename}{extension}", | |
# media_type=f"video/{extension[1:]}", | |
# ) | |