"""FastAPI service exposing text (causal LM), animation (AnimateDiff/LCM) and music (MusicGen) generation endpoints."""

import gc
import io
import asyncio
from concurrent.futures import ThreadPoolExecutor

import torch
import scipy.io.wavfile  # added dependency: used to encode the MusicGen output as a WAV file
import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks, Response
from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    MusicgenForConditionalGeneration,
)
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter
from duckduckgo_search import DDGS


MODEL_NAME = "jnjj/gemma-3-4b-it-1layer-actual"
MAX_CONTEXT_LEN = 1024
MUSICGEN_MAX_TOKENS = 256

global_model = None
global_tokenizer = None
global_tokens = {}
motion_adapter = None
anim_pipe = None
music_processor = None
music_model = None
executor = ThreadPoolExecutor(max_workers=4)

app = FastAPI()


@app.on_event("startup")
def load_global_models():
    """Load the text, animation and audio models with optimizations."""
    global global_model, global_tokenizer, global_tokens
    global motion_adapter, anim_pipe
    global music_processor, music_model

    text_config = AutoConfig.from_pretrained(MODEL_NAME)
    text_config.max_position_embeddings = MAX_CONTEXT_LEN
    global_tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        config=text_config,
        use_fast=True
    )
    global_tokenizer.model_max_length = MAX_CONTEXT_LEN

    global_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=text_config,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    global_model = torch.compile(global_model, backend="inductor")

    if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
        global_tokenizer.pad_token_id = global_tokenizer.eos_token_id
    global_tokens.update({
        "eos_token_id": global_tokenizer.eos_token_id,
        "pad_token_id": global_tokenizer.pad_token_id
    })

    motion_adapter = MotionAdapter.from_pretrained(
        "wangfuyun/AnimateLCM", torch_dtype=torch.float16
    )
    anim_pipe = AnimateDiffPipeline.from_pretrained(
        "emilianJR/epiCRealism",
        motion_adapter=motion_adapter,
        torch_dtype=torch.float16
    )
    anim_pipe.scheduler = LCMScheduler.from_config(
        anim_pipe.scheduler.config, beta_schedule="linear"
    )
    anim_pipe.load_lora_weights(
        "wangfuyun/AnimateLCM",
        weight_name="AnimateLCM_sd15_t2v_lora.safetensors",
        adapter_name="lcm-lora"
    )
    anim_pipe.set_adapters(["lcm-lora"], [0.8])
    anim_pipe.enable_vae_slicing()
    anim_pipe.enable_model_cpu_offload()

    music_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    music_model = MusicgenForConditionalGeneration.from_pretrained(
        "facebook/musicgen-small"
    )
    music_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    print("Text, animation and audio models loaded without bitsandbytes.")


@app.get("/", response_class=HTMLResponse)
def index():
    return HTMLResponse(
        content="""
        <html><head><title>Ultra-Fast Generator</title></head>
        <body>
        <h1>Multimedia Generation Service</h1>
        <ul>
            <li>Text: FP16, offload, torch.compile.</li>
            <li>Animation: AnimateDiffPipeline with LoRA and CPU offload.</li>
            <li>Audio: MusicGen small, max tokens 256.</li>
        </ul>
        </body></html>
        """,
        status_code=200
    )


@app.get("/health")
def health():
    return {"status": "ok"}


class GenerateRequest(BaseModel):
    input_text: str = ""
    max_new_tokens: int = 2
    temperature: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.3 + 0.5, 2))
    top_p: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.2 + 0.75, 2))
    top_k: int = Field(default_factory=lambda: int(torch.randint(20, 61, (1,)).item()))
    repetition_penalty: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.7 + 1.1, 2))
    frequency_penalty: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.5 + 0.2, 2))
    presence_penalty: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.5 + 0.2, 2))
    seed: int = Field(default_factory=lambda: int(torch.randint(0, 1001, (1,)).item()))
    do_sample: bool = True
    stream: bool = True
    chunk_token_limit: int = 2
    stop_sequences: list[str] = []
    include_duckasgo: bool = False

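
# Example /generate payload (illustrative values only; any sampling field that is omitted
# falls back to the randomized default_factory defined on GenerateRequest above):
#
#   {
#       "input_text": "Write a haiku about the sea",
#       "max_new_tokens": 32,
#       "temperature": 0.7,
#       "top_p": 0.9,
#       "chunk_token_limit": 4,
#       "include_duckasgo": false
#   }
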

class DuckasgoRequest(BaseModel):
    query: str


class AnimateRequest(BaseModel):
    prompt: str
    negative_prompt: str = ""
    num_frames: int = 16
    guidance_scale: float = 2.0
    num_inference_steps: int = 6
    seed: int = Field(default_factory=lambda: int(torch.randint(0, 1001, (1,)).item()))


class MusicRequest(BaseModel):
    texts: list[str]
    max_new_tokens: int = MUSICGEN_MAX_TOKENS


async def cleanup_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
    try:
        with DDGS() as ddgs:
            results = ddgs.text(query, max_results=max_results)
    except Exception as e:
        return f"DuckDuckGo error: {e}"
    if not results:
        return "No results found."
    text = "\nDuckDuckGo results:\n"
    for i, r in enumerate(results, 1):
        text += f"{i}. {r.get('title','')}\n   URL: {r.get('href','')}\n   {r.get('body','')}\n"
    return text


def generate_next_token(input_ids, past_key_values, gen_config, device):
    # Plain (non-async) function: it is executed in a worker thread via asyncio.to_thread.
    with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.float16):
        outputs = global_model(
            input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True
        )
    logits = outputs.logits[:, -1, :]
    past_key_values = outputs.past_key_values
    if gen_config.do_sample:
        logits = logits / gen_config.temperature
        if gen_config.top_k > 0:
            topk_vals, _ = torch.topk(logits, k=gen_config.top_k)
            logits[logits < topk_vals[..., -1]] = -float('Inf')
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
    return next_token, past_key_values


async def stream_text(request: GenerateRequest, device: str):
    encoded = global_tokenizer(request.input_text, return_tensors="pt", truncation=False)
    all_ids = encoded.input_ids.to(device)
    # Split the prompt into MAX_CONTEXT_LEN-sized segments and run everything except the
    # final prompt token through the model to build the KV cache.
    segments = [all_ids[:, i:i + MAX_CONTEXT_LEN] for i in range(0, all_ids.size(1), MAX_CONTEXT_LEN)]
    past_key_values = None
    for seg in segments[:-1]:
        with torch.no_grad():
            out = global_model(seg, past_key_values=past_key_values, use_cache=True, return_dict=True)
        past_key_values = out.past_key_values
    last_seg = segments[-1]
    if last_seg.size(1) > 1:
        with torch.no_grad():
            out = global_model(last_seg[:, :-1], past_key_values=past_key_values, use_cache=True, return_dict=True)
        past_key_values = out.past_key_values
    input_ids = last_seg[:, -1:]
    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        frequency_penalty=request.frequency_penalty,
        presence_penalty=request.presence_penalty,
        do_sample=request.do_sample
    )
    torch.manual_seed(request.seed)
    buffer = ""
    count = 0
    generated = 0
    generated_text = ""
    # Stop on max_new_tokens, EOS, or any requested stop sequence.
    while generated < request.max_new_tokens:
        next_token, past_key_values = await asyncio.to_thread(
            generate_next_token, input_ids, past_key_values, gen_config, device
        )
        tid = next_token.item()
        txt = global_tokenizer.decode([tid], skip_special_tokens=True)
        buffer += txt
        generated_text += txt
        count += 1
        generated += 1
        input_ids = next_token  # already shaped (batch, 1)
        if count >= request.chunk_token_limit:
            yield buffer
            buffer = ""
            count = 0
        if tid == global_tokens["eos_token_id"]:
            break
        if request.stop_sequences and any(s in generated_text for s in request.stop_sequences):
            break
    if buffer:
        yield buffer
    if request.include_duckasgo:
        yield "\n" + await perform_duckasgo_search(request.input_text)
    await cleanup_memory()


@app.post("/generate")
async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
    if global_model is None:
        raise HTTPException(status_code=500, detail="Text model not loaded.")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return StreamingResponse(stream_text(request, device), media_type="text/plain")


@app.post("/duckasgo")
def duckasgo(request: DuckasgoRequest):
    try:
        with DDGS() as ddgs:
            results = ddgs.text(request.query, max_results=10)
        return JSONResponse(content={"query": request.query, "results": results or []})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/animate")
async def animate(request: AnimateRequest):
    if anim_pipe is None:
        raise HTTPException(status_code=500, detail="Animation pipeline not loaded.")

    def run_pipeline():
        return anim_pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_frames=request.num_frames,
            guidance_scale=request.guidance_scale,
            num_inference_steps=request.num_inference_steps,
            generator=torch.Generator("cpu").manual_seed(request.seed)
        )

    output = await asyncio.get_event_loop().run_in_executor(executor, run_pipeline)
    frames = output.frames[0]
    buf = io.BytesIO()
    # diffusers' export_to_gif expects a filesystem path, so the PIL frames are encoded
    # into the in-memory buffer directly instead.
    frames[0].save(buf, format="GIF", save_all=True, append_images=frames[1:], loop=0, duration=100)
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/gif")


@app.post("/music")
async def generate_music(request: MusicRequest):
    if music_model is None or music_processor is None:
        raise HTTPException(status_code=500, detail="Audio model not loaded.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = music_processor(
        text=request.texts,
        padding=True,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        audio = music_model.generate(**inputs, max_new_tokens=request.max_new_tokens)
    # MusicGen returns a raw waveform tensor of shape (batch, channels, samples); encode the
    # first sample as a WAV file at the audio encoder's sampling rate.
    sampling_rate = music_model.config.audio_encoder.sampling_rate
    buf = io.BytesIO()
    scipy.io.wavfile.write(buf, rate=sampling_rate, data=audio[0, 0].cpu().numpy())
    buf.seek(0)
    return Response(buf.read(), media_type="audio/wav")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
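
# Example clients (sketch): these assume the server is running locally on port 7860 and that
# the `requests` package is available; neither is imported or required by this module.
#
#   import requests
#
#   # Stream plain text from /generate
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"input_text": "Hello", "max_new_tokens": 32},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)
#
#   # Fetch a GIF from /animate and a WAV clip from /music
#   gif = requests.post("http://localhost:7860/animate", json={"prompt": "a rocket at sunset"})
#   open("out.gif", "wb").write(gif.content)
#   wav = requests.post("http://localhost:7860/music", json={"texts": ["lo-fi chill beat"]})
#   open("out.wav", "wb").write(wav.content)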