import os
import gc
import io
import asyncio
import threading
import time
import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks, Response
from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, HTMLResponse
from pydantic import BaseModel, Field
from transformers import (
AutoConfig,
AutoTokenizer,
GenerationConfig,
AutoModelForCausalLM,
AutoProcessor,
MusicgenForConditionalGeneration
)
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
from concurrent.futures import ThreadPoolExecutor
import uvicorn
from duckduckgo_search import DDGS
import scipy.io.wavfile  # used to wrap the raw MusicGen waveform in a WAV container
# -----------------------
# Configuration and optimizations
# -----------------------
MODEL_NAME = "jnjj/gemma-3-4b-it-1layer-actual"
MAX_CONTEXT_LEN = 1024
MUSICGEN_MAX_TOKENS = 256
global_model = None
global_tokenizer = None
global_tokens = {}
motion_adapter = None
anim_pipe = None
music_processor = None
music_model = None
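# Thread pool for blocking model calls (e.g. the animation pipeline) so they
# do not stall the asyncio event loop.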
executor = ThreadPoolExecutor(max_workers=4)
# -----------------------
# FastAPI App
# -----------------------
app = FastAPI()
@app.on_event("startup")
def load_global_models():
"""Carga modelos de texto, animaci贸n y audio con optimizaciones."""
global global_model, global_tokenizer, global_tokens
global motion_adapter, anim_pipe
global music_processor, music_model
# --- Text model ---
text_config = AutoConfig.from_pretrained(MODEL_NAME)
text_config.max_position_embeddings = MAX_CONTEXT_LEN
global_tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
config=text_config,
use_fast=True
)
global_tokenizer.model_max_length = MAX_CONTEXT_LEN
global_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
config=text_config,
device_map="auto",
offload_folder="./offload",
offload_state_dict=True,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
trust_remote_code=True,
# use_cache=True
)
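    # torch.compile with the inductor backend optimizes the model graph lazily;
    # the first request pays the compilation cost.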
global_model = torch.compile(global_model, backend="inductor")
if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
global_tokenizer.pad_token_id = global_tokenizer.eos_token_id
global_tokens.update({
"eos_token_id": global_tokenizer.eos_token_id,
"pad_token_id": global_tokenizer.pad_token_id
})
# --- Animation model ---
motion_adapter = MotionAdapter.from_pretrained(
"wangfuyun/AnimateLCM", torch_dtype=torch.float16
)
anim_pipe = AnimateDiffPipeline.from_pretrained(
"emilianJR/epiCRealism",
motion_adapter=motion_adapter,
torch_dtype=torch.float16
)
anim_pipe.scheduler = LCMScheduler.from_config(
anim_pipe.scheduler.config, beta_schedule="linear"
)
anim_pipe.load_lora_weights(
"wangfuyun/AnimateLCM",
weight_name="AnimateLCM_sd15_t2v_lora.safetensors",
adapter_name="lcm-lora"
)
anim_pipe.set_adapters(["lcm-lora"], [0.8])
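    # VAE slicing and model CPU offload trade some latency for lower peak VRAM.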
anim_pipe.enable_vae_slicing()
anim_pipe.enable_model_cpu_offload()
# --- MusicGen model ---
music_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
music_model = MusicgenForConditionalGeneration.from_pretrained(
"facebook/musicgen-small"
)
music_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print("Modelos de texto, animaci贸n y audio cargados sin bitsandbytes.")
@app.get("/", response_class=HTMLResponse)
def index():
return HTMLResponse(
content="""
        <html><head><title>Ultra-Fast Generator</title></head>
        <body>
          <h1>Multimedia Generation Service</h1>
          <ul>
            <li>Text: FP16, offload, torch.compile.</li>
            <li>Animation: AnimateDiffPipeline with LoRA and CPU offload.</li>
            <li>Audio: MusicGen small, max 256 tokens.</li>
          </ul>
        </body></html>
""",
status_code=200
)
@app.get("/health")
def health():
return {"status": "ok"}
class GenerateRequest(BaseModel):
input_text: str = ""
max_new_tokens: int = 2
temperature: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.3 + 0.5, 2))
top_p: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.2 + 0.75, 2))
top_k: int = Field(default_factory=lambda: int(torch.randint(20, 61, (1,)).item()))
repetition_penalty: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.7 + 1.1, 2))
frequency_penalty: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.5 + 0.2, 2))
presence_penalty: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.5 + 0.2, 2))
seed: int = Field(default_factory=lambda: int(torch.randint(0, 1001, (1,)).item()))
do_sample: bool = True
stream: bool = True
chunk_token_limit: int = 2
stop_sequences: list[str] = []
include_duckasgo: bool = False
class DuckasgoRequest(BaseModel):
query: str
class AnimateRequest(BaseModel):
prompt: str
negative_prompt: str = ""
num_frames: int = 16
guidance_scale: float = 2.0
num_inference_steps: int = 6
seed: int = Field(default_factory=lambda: int(torch.randint(0, 1001, (1,)).item()))
class MusicRequest(BaseModel):
texts: list[str]
max_new_tokens: int = MUSICGEN_MAX_TOKENS
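# Illustrative request payloads (field values are examples only):
#   POST /generate  {"input_text": "Hello", "max_new_tokens": 32}
#   POST /animate   {"prompt": "a cat surfing", "num_frames": 16}
#   POST /music     {"texts": ["lo-fi hip hop beat"], "max_new_tokens": 256}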
async def cleanup_memory():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
try:
with DDGS() as ddgs:
results = ddgs.text(query, max_results=max_results)
except Exception as e:
return f"Error DuckDuckGo: {e}"
if not results:
return "No se encontraron resultados."
text = "\nResultados DuckDuckGo:\n"
for i, r in enumerate(results, 1):
text += f"{i}. {r.get('title','')}\n URL: {r.get('href','')}\n {r.get('body','')}\n"
return text
def generate_next_token(input_ids, past_key_values, gen_config, device):
    """Single decoding step; runs synchronously so it can be dispatched via asyncio.to_thread."""
    with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.float16):
outputs = global_model(
input_ids,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
logits = outputs.logits[:, -1, :]
past_key_values = outputs.past_key_values
if gen_config.do_sample:
logits = logits / gen_config.temperature
if gen_config.top_k > 0:
topk_vals, _ = torch.topk(logits, k=gen_config.top_k)
logits[logits < topk_vals[..., -1]] = -float('Inf')
probs = torch.nn.functional.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(logits, dim=-1, keepdim=True)
return next_token, past_key_values
async def stream_text(request: GenerateRequest, device: str):
encoded = global_tokenizer(request.input_text, return_tensors="pt", truncation=False)
all_ids = encoded.input_ids.to(device)
segments = [all_ids[:, i:i+MAX_CONTEXT_LEN] for i in range(0, all_ids.size(1), MAX_CONTEXT_LEN)]
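    # Prompts longer than MAX_CONTEXT_LEN are prefilled segment by segment,
    # carrying the KV cache (past_key_values) across segments.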
past_key_values = None
for seg in segments[:-1]:
with torch.no_grad():
out = global_model(seg, past_key_values=past_key_values, use_cache=True, return_dict=True)
past_key_values = out.past_key_values
    last_seg = segments[-1]
    # Prefill all but the final prompt token, then feed that final token to
    # start autoregressive decoding.
    if last_seg.size(1) > 1:
        with torch.no_grad():
            out = global_model(last_seg[:, :-1], past_key_values=past_key_values, use_cache=True, return_dict=True)
            past_key_values = out.past_key_values
    input_ids = last_seg[:, -1:]
gen_config = GenerationConfig(
temperature=request.temperature,
max_new_tokens=request.max_new_tokens,
top_p=request.top_p,
top_k=request.top_k,
repetition_penalty=request.repetition_penalty,
        frequency_penalty=request.frequency_penalty,
presence_penalty=request.presence_penalty,
do_sample=request.do_sample
)
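    # The manual loop below enforces temperature, top_k, do_sample and the request's
    # max_new_tokens; top_p and the penalty fields are carried in gen_config but not
    # applied by this custom sampler.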
torch.manual_seed(request.seed)
buffer = ""
    count = 0
    generated = 0
    # Stop after max_new_tokens even if no EOS token is produced.
    while generated < request.max_new_tokens:
next_token, past_key_values = await asyncio.to_thread(
generate_next_token, input_ids, past_key_values, gen_config, device
)
tid = next_token.item()
txt = global_tokenizer.decode([tid], skip_special_tokens=True)
buffer += txt
        count += 1
        generated += 1
        input_ids = next_token  # already shaped (1, 1) for the next forward pass
if count >= request.chunk_token_limit:
yield buffer
buffer = ""
count = 0
if tid == global_tokens["eos_token_id"]:
break
if buffer:
yield buffer
if request.include_duckasgo:
yield "\n" + await perform_duckasgo_search(request.input_text)
await cleanup_memory()
@app.post("/generate")
async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
if global_model is None:
raise HTTPException(status_code=500, detail="Modelo de texto no cargado.")
device = "cuda" if torch.cuda.is_available() else "cpu"
return StreamingResponse(stream_text(request, device), media_type="text/plain")
@app.post("/duckasgo")
def duckasgo(request: DuckasgoRequest):
try:
with DDGS() as ddgs:
results = ddgs.text(request.query, max_results=10)
return JSONResponse(content={"query": request.query, "results": results or []})
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/animate")
async def animate(request: AnimateRequest):
if anim_pipe is None:
raise HTTPException(status_code=500, detail="Pipeline de animaci贸n no cargado.")
def run_pipeline():
return anim_pipe(
prompt=request.prompt,
negative_prompt=request.negative_prompt,
num_frames=request.num_frames,
guidance_scale=request.guidance_scale,
num_inference_steps=request.num_inference_steps,
generator=torch.Generator("cpu").manual_seed(request.seed)
)
output = await asyncio.get_event_loop().run_in_executor(executor, run_pipeline)
frames = output.frames[0]
    # Encode the PIL frames as an in-memory GIF (export_to_gif expects a file
    # path, so write to a BytesIO buffer directly instead).
    buf = io.BytesIO()
    frames[0].save(buf, format="GIF", save_all=True, append_images=frames[1:], loop=0)
    buf.seek(0)
return StreamingResponse(buf, media_type="image/gif")
@app.post("/music")
async def generate_music(request: MusicRequest):
if music_model is None or music_processor is None:
raise HTTPException(status_code=500, detail="Modelo de audio no cargado.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = music_processor(
text=request.texts,
padding=True,
return_tensors="pt"
).to(device)
with torch.no_grad():
audio = music_model.generate(**inputs, max_new_tokens=request.max_new_tokens)
    # MusicGen returns raw waveforms of shape (batch, channels, samples); wrap the
    # first sample in a WAV container before returning it.
    sampling_rate = music_model.config.audio_encoder.sampling_rate
    buf = io.BytesIO()
    scipy.io.wavfile.write(buf, rate=sampling_rate, data=audio[0, 0].cpu().numpy())
    buf.seek(0)
    return Response(buf.getvalue(), media_type="audio/wav")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)
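# Minimal client sketch (assumes the server is reachable at http://localhost:7860;
# names and values below are illustrative only):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"input_text": "Hello", "max_new_tokens": 16},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)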