import asyncio
import gc
import io
import os
import threading
import time
import wave
from concurrent.futures import ThreadPoolExecutor

import torch
import uvicorn
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
from duckduckgo_search import DDGS
from fastapi import FastAPI, HTTPException, BackgroundTasks, Response
from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, HTMLResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    MusicgenForConditionalGeneration,
)

# -----------------------
# Configuration and optimizations
# -----------------------
MODEL_NAME = "jnjj/gemma-3-4b-it-1layer-actual"
# Context limit written into the text config/tokenizer; also the chunk size
# used when pre-filling long prompts into the KV cache.
MAX_CONTEXT_LEN = 1024
MUSICGEN_MAX_TOKENS = 256

# Populated once by load_global_models() at startup.
global_model = None
global_tokenizer = None
global_tokens = {}
motion_adapter = None
anim_pipe = None
music_processor = None
music_model = None

# Worker pool for the blocking animation / music pipelines.
executor = ThreadPoolExecutor(max_workers=4)

# -----------------------
# FastAPI App
# -----------------------
app = FastAPI()


@app.on_event("startup")
def load_global_models():
    """Load the text, animation and audio models into the module globals.

    Runs once at application startup and fills ``global_model``,
    ``global_tokenizer``, ``global_tokens``, ``motion_adapter``,
    ``anim_pipe``, ``music_processor`` and ``music_model``.
    """
    global global_model, global_tokenizer, global_tokens
    global motion_adapter, anim_pipe
    global music_processor, music_model

    # --- Text model ---
    text_config = AutoConfig.from_pretrained(MODEL_NAME)
    # NOTE(review): for multimodal Gemma-3 configs the text limit may live in a
    # nested text config; confirm this top-level attribute is the one honored.
    text_config.max_position_embeddings = MAX_CONTEXT_LEN
    global_tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME, config=text_config, use_fast=True
    )
    global_tokenizer.model_max_length = MAX_CONTEXT_LEN
    # SECURITY: trust_remote_code=True executes Python code shipped with the
    # model repository; only keep it for repositories you trust.
    global_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=text_config,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    global_model = torch.compile(global_model, backend="inductor")

    # Reuse EOS as the padding token when the tokenizer defines none.
    if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
        global_tokenizer.pad_token_id = global_tokenizer.eos_token_id
    global_tokens.update({
        "eos_token_id": global_tokenizer.eos_token_id,
        "pad_token_id": global_tokenizer.pad_token_id,
    })

    # --- Animation model ---
    motion_adapter = MotionAdapter.from_pretrained(
        "wangfuyun/AnimateLCM", torch_dtype=torch.float16
    )
    anim_pipe = AnimateDiffPipeline.from_pretrained(
        "emilianJR/epiCRealism", motion_adapter=motion_adapter, torch_dtype=torch.float16
    )
    anim_pipe.scheduler = LCMScheduler.from_config(
        anim_pipe.scheduler.config, beta_schedule="linear"
    )
    anim_pipe.load_lora_weights(
        "wangfuyun/AnimateLCM",
        weight_name="AnimateLCM_sd15_t2v_lora.safetensors",
        adapter_name="lcm-lora",
    )
    anim_pipe.set_adapters(["lcm-lora"], [0.8])
    anim_pipe.enable_vae_slicing()
    anim_pipe.enable_model_cpu_offload()

    # --- MusicGen model ---
    music_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    music_model = MusicgenForConditionalGeneration.from_pretrained(
        "facebook/musicgen-small"
    )
    music_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    print("Modelos de texto, animación y audio cargados sin bitsandbytes.")


@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the static landing page."""
    return HTMLResponse(
        content=""" Generador Ultra-Rápido

Servicio de Generación Multimedia

""",
        status_code=200,
    )


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


class GenerateRequest(BaseModel):
    # Sampling parameters without an explicit value get a random default.
    input_text: str = ""
    max_new_tokens: int = 2
    temperature: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.3 + 0.5, 2))
    top_p: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.2 + 0.75, 2))
    top_k: int = Field(default_factory=lambda: int(torch.randint(20, 61, (1,)).item()))
    repetition_penalty: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.7 + 1.1, 2))
    frequency_penalty: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.5 + 0.2, 2))
    presence_penalty: float = Field(default_factory=lambda: round(torch.rand(1).item() * 0.5 + 0.2, 2))
    seed: int = Field(default_factory=lambda: int(torch.randint(0, 1001, (1,)).item()))
    do_sample: bool = True
    stream: bool = True
    chunk_token_limit: int = 2
    stop_sequences: list[str] = []
    include_duckasgo: bool = False


class DuckasgoRequest(BaseModel):
    query: str


class AnimateRequest(BaseModel):
    prompt: str
    negative_prompt: str = ""
    num_frames: int = 16
    guidance_scale: float = 2.0
    num_inference_steps: int = 6
    seed: int = Field(default_factory=lambda: int(torch.randint(0, 1001, (1,)).item()))


class MusicRequest(BaseModel):
    texts: list[str]
    max_new_tokens: int = MUSICGEN_MAX_TOKENS


async def cleanup_memory():
    """Run the garbage collector and release cached CUDA memory if available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
    """Return a plain-text summary of the top DuckDuckGo results for ``query``.

    Search errors are reported as text rather than raised, because the result
    is appended to an already-started streaming response.
    """
    def _search():
        # Materialize inside the ``with`` so the session is still open.
        with DDGS() as ddgs:
            return list(ddgs.text(query, max_results=max_results) or [])

    try:
        # DDGS performs blocking network I/O; keep it off the event loop.
        results = await asyncio.to_thread(_search)
    except Exception as e:
        return f"Error DuckDuckGo: {e}"
    if not results:
        return "No se encontraron resultados."
    entries = [
        f"{i}. {r.get('title','')}\n URL: {r.get('href','')}\n {r.get('body','')}\n"
        for i, r in enumerate(results, 1)
    ]
    return "\nResultados DuckDuckGo:\n" + "".join(entries)


def _select_token(logits, gen_config, context_ids=None, generated_ids=None, generator=None):
    """Pick the next token id from last-position logits.

    Args:
        logits: ``(batch, vocab)`` float tensor; it is not modified.
        gen_config: object exposing ``do_sample``, ``temperature``, ``top_k``,
            ``top_p``, ``repetition_penalty`` and optionally
            ``frequency_penalty`` / ``presence_penalty``.
        context_ids: ``(batch, n)`` long tensor of prompt + generated ids, on
            the same device as ``logits``. Used for the repetition penalty.
        generated_ids: ``(batch, m)`` long tensor of generated ids only. Used
            for the frequency/presence penalties.
        generator: optional ``torch.Generator`` on the logits' device.

    Returns:
        ``(batch, 1)`` long tensor with the chosen token ids.
    """
    scores = logits.clone()

    # Repetition penalty (Hugging Face semantics) over the whole context.
    # Non-positive values are treated as disabled to avoid division by zero.
    penalty = float(getattr(gen_config, "repetition_penalty", None) or 1.0)
    if context_ids is not None and context_ids.numel() > 0 and penalty > 0 and penalty != 1.0:
        seen = torch.gather(scores, 1, context_ids)
        seen = torch.where(seen < 0, seen * penalty, seen / penalty)
        scores.scatter_(1, context_ids, seen)

    # Frequency / presence penalties (additive) over generated tokens only.
    freq = float(getattr(gen_config, "frequency_penalty", None) or 0.0)
    pres = float(getattr(gen_config, "presence_penalty", None) or 0.0)
    if generated_ids is not None and generated_ids.numel() > 0 and (freq or pres):
        counts = torch.zeros_like(scores)
        counts.scatter_add_(1, generated_ids, torch.ones_like(generated_ids, dtype=scores.dtype))
        scores = scores - freq * counts - pres * (counts > 0).to(scores.dtype)

    temperature = getattr(gen_config, "temperature", None)
    temperature = 1.0 if temperature is None else float(temperature)
    # Greedy decoding, also used as the safe fallback for temperature <= 0.
    if not gen_config.do_sample or temperature <= 0:
        return torch.argmax(scores, dim=-1, keepdim=True)

    scores = scores / temperature

    top_k = int(getattr(gen_config, "top_k", None) or 0)
    if 0 < top_k < scores.size(-1):
        kth_best = torch.topk(scores, top_k, dim=-1).values[..., -1, None]
        scores = scores.masked_fill(scores < kth_best, float("-inf"))

    top_p = getattr(gen_config, "top_p", None)
    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_scores, sorted_idx = torch.sort(scores, descending=False, dim=-1)
        cumulative = sorted_scores.softmax(dim=-1).cumsum(dim=-1)
        drop_sorted = cumulative <= (1.0 - top_p)
        drop_sorted[..., -1:] = False  # always keep the most likely token
        drop = drop_sorted.scatter(1, sorted_idx, drop_sorted)
        scores = scores.masked_fill(drop, float("-inf"))

    probs = torch.softmax(scores, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)


def generate_next_token(input_ids, past_key_values, gen_config, device,
                        context_ids=None, generated_ids=None, generator=None):
    """Run one decoding step of ``global_model`` and choose the next token.

    This is blocking (it runs a forward pass); call it from async code through
    ``asyncio.to_thread``.

    Args:
        input_ids: ``(1, k)`` long tensor of new input tokens, on ``device``.
        past_key_values: KV cache from the previous call, or None.
        gen_config: sampling settings (see ``_select_token``).
        device: ``"cuda"`` or ``"cpu"``; autocast is only enabled for CUDA.
        context_ids / generated_ids / generator: CPU-side penalty history and
            RNG, forwarded to ``_select_token``.

    Returns:
        ``(next_token, past_key_values)``. ``next_token`` is a ``(1, 1)`` long
        tensor on ``input_ids.device``.
    """
    with torch.no_grad(), torch.autocast(
        device_type=device, dtype=torch.float16, enabled=(device == "cuda")
    ):
        outputs = global_model(
            input_ids, past_key_values=past_key_values, use_cache=True, return_dict=True
        )
    # Sample in float32 on the CPU so the per-request CPU generator applies no
    # matter which device ``device_map="auto"`` placed the output head on.
    last_logits = outputs.logits[:, -1, :].float().cpu()
    next_token = _select_token(last_logits, gen_config, context_ids, generated_ids, generator)
    return next_token.to(input_ids.device), outputs.past_key_values


def _prefill(prefix_ids):
    """Feed ``prefix_ids`` through the model in MAX_CONTEXT_LEN chunks.

    Returns the resulting KV cache, or None when ``prefix_ids`` is empty.
    Blocking; run it in a worker thread.
    """
    past = None
    with torch.no_grad():
        for start in range(0, prefix_ids.size(1), MAX_CONTEXT_LEN):
            out = global_model(
                prefix_ids[:, start:start + MAX_CONTEXT_LEN],
                past_key_values=past,
                use_cache=True,
                return_dict=True,
            )
            past = out.past_key_values
    return past


def _encode_prompt(text: str):
    """Tokenize ``text`` into a ``(1, n)`` tensor with n >= 1, or return None.

    An empty tokenization falls back to a single BOS (or EOS) token so that
    decoding has something to condition on.
    """
    ids = global_tokenizer(text, return_tensors="pt", truncation=False).input_ids
    if ids.size(1) > 0:
        return ids
    start_id = global_tokenizer.bos_token_id
    if start_id is None:
        start_id = global_tokens.get("eos_token_id")
    if start_id is None:
        return None
    return torch.tensor([[start_id]], dtype=torch.long)


def _find_stop(text: str, stops: list[str]):
    """Return the index of the earliest stop sequence found in ``text``, or None."""
    hits = [idx for idx in (text.find(s) for s in stops) if idx != -1]
    return min(hits) if hits else None


async def _generate_chunks(request: GenerateRequest, device: str):
    """Yield generated text in chunks of about ``chunk_token_limit`` tokens.

    Generation stops at EOS, at the first stop sequence (which is not
    emitted), or after ``max_new_tokens`` tokens. Model calls run in worker
    threads so the event loop stays responsive.
    """
    prompt_ids = _encode_prompt(request.input_text)
    if prompt_ids is None:
        return
    model_ids = prompt_ids.to(device)
    # Cache every prompt token except the last; that one is the first input
    # of the decoding loop.
    past_key_values = await asyncio.to_thread(_prefill, model_ids[:, :-1])
    input_ids = model_ids[:, -1:]

    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        frequency_penalty=request.frequency_penalty,
        presence_penalty=request.presence_penalty,
        do_sample=request.do_sample,
    )
    # Per-request RNG: reseeding the global RNG would affect concurrent
    # requests and the random Field defaults.
    generator = torch.Generator().manual_seed(request.seed)

    context_ids = prompt_ids.cpu()
    stops = [s for s in request.stop_sequences if s]
    # Hold back enough trailing characters that a stop sequence split across
    # chunks is never partially emitted.
    holdback = max((len(s) for s in stops), default=1) - 1
    chunk_limit = max(1, request.chunk_token_limit)
    eos_id = global_tokens.get("eos_token_id")

    generated: list[int] = []
    text = ""
    emitted = 0
    pending = 0
    for _ in range(max(0, request.max_new_tokens)):
        generated_ids = torch.tensor([generated], dtype=torch.long) if generated else None
        next_token, past_key_values = await asyncio.to_thread(
            generate_next_token,
            input_ids, past_key_values, gen_config, device,
            context_ids, generated_ids, generator,
        )
        tid = int(next_token.item())
        if eos_id is not None and tid == eos_id:
            break
        generated.append(tid)
        context_ids = torch.cat([context_ids, next_token.cpu()], dim=1)
        input_ids = next_token  # already (1, 1): a batch of one new token
        # Decode the whole continuation, not token by token, so multi-byte
        # characters and word spacing come out right.
        text = global_tokenizer.decode(generated, skip_special_tokens=True)
        cut = _find_stop(text, stops)
        if cut is not None:
            text = text[:cut]
            break
        pending += 1
        if pending >= chunk_limit:
            pending = 0
            # A trailing "\ufffd" usually means an incomplete multi-byte
            # character; wait for more tokens before emitting it.
            ready = len(text.rstrip("\ufffd")) - holdback
            if ready > emitted:
                yield text[emitted:ready]
                emitted = ready
    if len(text) > emitted:
        yield text[emitted:]


async def stream_text(request: GenerateRequest, device: str):
    """Stream the generated text, optionally followed by DuckDuckGo results.

    Memory cleanup runs even when the client disconnects mid-stream.
    """
    try:
        async for chunk in _generate_chunks(request, device):
            yield chunk
        if request.include_duckasgo:
            yield "\n" + await perform_duckasgo_search(request.input_text)
    finally:
        await cleanup_memory()


@app.post("/generate")
async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
    """Stream text generated by the global text model as ``text/plain``."""
    if global_model is None:
        raise HTTPException(status_code=500, detail="Modelo de texto no cargado.")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return StreamingResponse(stream_text(request, device), media_type="text/plain")


@app.post("/duckasgo")
def duckasgo(request: DuckasgoRequest):
    """Return up to 10 raw DuckDuckGo text results for ``request.query``."""
    try:
        # Materialize inside the ``with`` so the session is still open.
        with DDGS() as ddgs:
            results = list(ddgs.text(request.query, max_results=10) or [])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return JSONResponse(content={"query": request.query, "results": results})


@app.post("/animate")
async def animate(request: AnimateRequest):
    """Generate a short animation with AnimateDiff and return it as a GIF."""
    if anim_pipe is None:
        raise HTTPException(status_code=500, detail="Pipeline de animación no cargado.")

    def run_pipeline():
        return anim_pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_frames=request.num_frames,
            guidance_scale=request.guidance_scale,
            num_inference_steps=request.num_inference_steps,
            generator=torch.Generator("cpu").manual_seed(request.seed),
        )

    output = await asyncio.get_running_loop().run_in_executor(executor, run_pipeline)
    frames = output.frames[0]
    buf = io.BytesIO()
    # export_to_gif() expects a file path: PIL cannot infer the format of an
    # unnamed BytesIO. Encode in memory with the same settings instead
    # (100 ms per frame, loop forever, no optimization).
    frames[0].save(
        buf,
        format="GIF",
        save_all=True,
        append_images=list(frames[1:]),
        optimize=False,
        duration=100,
        loop=0,
    )
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/gif")


def _encode_wav(waveform, sampling_rate: int) -> bytes:
    """Encode a float waveform as 16-bit PCM WAV bytes.

    ``waveform`` is ``(samples,)`` or ``(channels, samples)``. Values outside
    [-1, 1] are clipped.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    channels = waveform.size(0)
    pcm = (waveform.float().clamp(-1.0, 1.0) * 32767.0).round().to(torch.int16)
    # WAV stores interleaved little-endian frames: (samples, channels).
    frame_bytes = pcm.t().contiguous().numpy().astype("<i2", copy=False).tobytes()
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(2)
        wav_file.setframerate(int(sampling_rate))
        wav_file.writeframes(frame_bytes)
    return buf.getvalue()


@app.post("/music")
async def generate_music(request: MusicRequest):
    """Generate audio with MusicGen for the first prompt and return a WAV file."""
    if music_model is None or music_processor is None:
        raise HTTPException(status_code=500, detail="Modelo de audio no cargado.")
    if not request.texts:
        raise HTTPException(status_code=400, detail="Se requiere al menos un texto.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def run_generation():
        inputs = music_processor(
            text=request.texts, padding=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            return music_model.generate(**inputs, max_new_tokens=request.max_new_tokens)

    # Generation is blocking; keep it off the event loop.
    audio = await asyncio.get_running_loop().run_in_executor(executor, run_generation)
    sampling_rate = music_model.config.audio_encoder.sampling_rate
    # Only the first item of the batch is returned.
    wav_bytes = _encode_wav(audio[0].cpu(), sampling_rate)
    return Response(wav_bytes, media_type="audio/wav")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)