Spaces:
Paused
Paused
import gradio as gr | |
from pydub import AudioSegment | |
import json | |
import uuid | |
import edge_tts | |
import asyncio | |
import aiofiles | |
import os | |
import time | |
import mimetypes | |
from typing import List, Dict | |
# NEW – Hugging Face Transformers | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
# Hugging Face Hub id of the causal LM that PodcastGenerator loads.
MODEL_ID = "tabularisai/german-gemma-3-1b-it"

# Upload size limit, expressed both in megabytes (for messages) and bytes (for checks).
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024**2
class PodcastGenerator:
    """Generate two-speaker podcast scripts with a local Hugging Face causal LM.

    The tokenizer and model identified by ``MODEL_ID`` are loaded once in
    ``__init__`` and reused by every ``generate_script`` call.
    """

    # Maximum number of characters of a decoded text upload that are inlined
    # into the prompt, so a large upload cannot flood the model's context.
    MAX_FILE_PROMPT_CHARS = 20_000

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            # bfloat16 when a CUDA device is available, float32 otherwise.
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
        ).eval()

    async def generate_script(
        self,
        prompt: str,
        language: str,
        api_key: str,
        file_obj=None,
        progress=None,
    ) -> Dict:
        """Generate a podcast script and return it as parsed JSON.

        Args:
            prompt: Free-form user input describing the podcast topic.
            language: Target language name, or ``"Auto Detect"`` to make the
                model follow the language of the user input.
            api_key: Not used by this implementation.
            file_obj: Optional upload (a path ``str``/``os.PathLike`` or an
                object with ``name`` and optionally ``size``/``read``). Its
                size is always validated. If its content decodes as UTF-8
                text, up to ``MAX_FILE_PROMPT_CHARS`` characters are added to
                the prompt.
            progress: Optional callable invoked as
                ``progress(fraction, description)``.

        Returns:
            The object decoded from the first JSON object in the model output.

        Raises:
            Exception: If the upload exceeds the size limit or generation fails.
            ValueError: (``json.JSONDecodeError``) if the output holds no
                decodable JSON.
        """
        example = """
{
"topic": "AGI",
"podcast": [
{
"speaker": 2,
"line": "So, AGI, huh? Seems like everyone's talking about it these days."
},
{
"speaker": 1,
"line": "Yeah, it's definitely having a moment, isn't it?"
}
]
}
"""
        if language == "Auto Detect":
            language_instruction = (
                "- The podcast MUST be in the same language as the user input."
            )
        else:
            language_instruction = f"- The podcast MUST be in {language} language"
        system_prompt = f"""
You are a professional podcast generator. Your task is to generate a professional podcast script based on the user input.
{language_instruction}
- The podcast should have 2 speakers.
- The podcast should be long.
- Do not use names for the speakers.
- The podcast should be interesting, lively, and engaging, and hook the listener from the start.
- The input text might be disorganized or unformatted, originating from sources like PDFs or text files. Ignore any formatting inconsistencies or irrelevant details; your task is to distill the essential points, identify key definitions, and highlight intriguing facts that would be suitable for discussion in a podcast.
- The script must be in JSON format.
Follow this example structure:
{example}
"""

        file_text = ""
        if file_obj:
            # Non-text uploads (e.g. PDFs) fail UTF-8 decoding. They are only
            # size-checked, and their content is not passed to the model.
            file_bytes = await self._read_file_bytes(file_obj)
            file_text = self._decode_text(file_bytes, self.MAX_FILE_PROMPT_CHARS)

        if prompt and file_obj:
            user_prompt = (
                f"Please generate a podcast script based on the uploaded file and the following user input:\n{prompt}"
            )
        elif prompt:
            user_prompt = (
                f"Please generate a podcast script based on the following user input:\n{prompt}"
            )
        else:
            user_prompt = "Please generate a podcast script based on the uploaded file."
        if file_text:
            user_prompt += f"\n\nUploaded file content:\n{file_text}"

        if progress:
            progress(0.3, "Generating podcast script...")

        inputs = self._build_inputs(system_prompt, user_prompt)

        def _run_generation() -> str:
            output = self.model.generate(**inputs, max_new_tokens=2048, temperature=1.0)
            # generate() returns prompt + continuation; decode only the new tokens.
            prompt_len = inputs["input_ids"].shape[-1]
            return self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        try:
            # Run the blocking generate() in an executor so other coroutines
            # on the event loop are not stalled during generation.
            loop = asyncio.get_running_loop()
            response_text = await loop.run_in_executor(None, _run_generation)
        except Exception as e:
            raise Exception(f"Failed to generate podcast script: {e}") from e

        print(f"Generated podcast script:\n{response_text}")
        if progress:
            progress(0.4, "Script generated successfully!")
        return self._extract_json(response_text)

    def _build_inputs(self, system_prompt: str, user_prompt: str):
        """Tokenize the prompt and move the tensors to the model's device.

        Uses the tokenizer's chat template when it has one. Otherwise it falls
        back to the system and user text joined by a blank line.
        """
        combined = f"{system_prompt}\n\n{user_prompt}"
        if getattr(self.tokenizer, "chat_template", None):
            # System and user text share one user turn because not every chat
            # template accepts a separate "system" role.
            prompt_text = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": combined}],
                tokenize=False,
                add_generation_prompt=True,
            )
            # Chat templates typically render BOS/turn markers themselves, so
            # do not let the tokenizer add special tokens a second time.
            add_special_tokens = False
        else:
            prompt_text = combined
            add_special_tokens = True
        return self.tokenizer(
            prompt_text, return_tensors="pt", add_special_tokens=add_special_tokens
        ).to(self.model.device)

    @staticmethod
    def _decode_text(data, limit: int) -> str:
        """Return up to *limit* characters of *data* as text, or "" if it is not text.

        ``str`` input is used as-is. Bytes are decoded as UTF-8, and a leading
        BOM is dropped.
        """
        if isinstance(data, str):
            text = data
        else:
            try:
                text = bytes(data).decode("utf-8-sig")
            except UnicodeDecodeError:
                return ""
        # NUL characters essentially never occur in real text; treat as binary.
        if "\x00" in text:
            return ""
        return text[:limit]

    @staticmethod
    def _extract_json(text: str) -> Dict:
        """Decode the first JSON object embedded in *text*.

        Decoding starts at the first ``{``, so Markdown fences or prose before
        the object are skipped. Anything after the object ends is ignored.

        Raises:
            json.JSONDecodeError: If no JSON value can be decoded.
        """
        start = text.find("{")
        if start == -1:
            # No object at all: let json report the error against the full text.
            return json.loads(text)
        obj, _end = json.JSONDecoder().raw_decode(text[start:])
        return obj

    async def _read_file_bytes(self, file_obj) -> bytes:
        """Validate an upload's size and return its raw content.

        Args:
            file_obj: A path (``str``/``os.PathLike``) or an object exposing
                ``name`` and optionally ``size`` and ``read``.

        Returns:
            The result of ``file_obj.read()`` when available. Otherwise, the
            bytes of the file at its path.

        Raises:
            Exception: If the file is larger than ``MAX_FILE_SIZE_MB``.
        """
        if isinstance(file_obj, (str, os.PathLike)):
            path = os.fspath(file_obj)
        else:
            path = getattr(file_obj, "name", None)

        # Prefer a size reported by the upload object; fall back to the filesystem.
        file_size = getattr(file_obj, "size", None)
        if file_size is None:
            file_size = os.path.getsize(path)
        if file_size > MAX_FILE_SIZE_BYTES:
            raise Exception(
                f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file."
            )

        if hasattr(file_obj, "read"):
            return file_obj.read()
        async with aiofiles.open(path, "rb") as f:
            return await f.read()
def _get_mime_type(filename: str) -> str: | |
ext = os.path.splitext(filename)[1].lower() | |
if ext == ".pdf": | |
return "application/pdf" | |
elif ext == ".txt": | |
return "text/plain" | |
else: | |
mime_type, _ = mimetypes.guess_type(filename) | |
return mime_type or "application/octet-stream" | |