import os
import gc
import json
import random
import torch
import asyncio
import threading
import time
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, HTMLResponse
from pydantic import BaseModel, Field
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import uvicorn
from duckduckgo_search import DDGS
from concurrent.futures import ThreadPoolExecutor
# Model name and global variables
MODEL_NAME = "lilmeaty/my_xdd"
global_model = None
global_tokenizer = None
global_tokens = {}
# Executor for running tasks in parallel
executor = ThreadPoolExecutor(max_workers=4)
async def cleanup_memory(device: str):
    """Free Python and CUDA memory after a generation run."""
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
class GenerateRequest(BaseModel):
    input_text: str = ""
    max_new_tokens: int = 2
    temperature: float = Field(default_factory=lambda: round(random.uniform(0.5, 0.8), 2))
    top_p: float = Field(default_factory=lambda: round(random.uniform(0.75, 0.95), 2))
    top_k: int = Field(default_factory=lambda: random.randint(20, 60))
    repetition_penalty: float = Field(default_factory=lambda: round(random.uniform(1.1, 1.8), 2))
    frequency_penalty: float = Field(default_factory=lambda: round(random.uniform(0.2, 0.7), 2))
    presence_penalty: float = Field(default_factory=lambda: round(random.uniform(0.2, 0.7), 2))
    seed: int = Field(default_factory=lambda: random.randint(0, 1000))
    do_sample: bool = True
    stream: bool = True
    chunk_token_limit: int = 2
    stop_sequences: list[str] = []
    include_duckasgo: bool = False
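# Illustrative only: a minimal JSON body for /generate, assuming the defaults above
# are acceptable. Every field other than input_text can be omitted, in which case
# the randomized defaults defined in GenerateRequest are used. The values shown
# here are example choices, not part of the original configuration.
#
#   {
#     "input_text": "Once upon a time",
#     "max_new_tokens": 50,
#     "stream": false,
#     "chunk_token_limit": 10,
#     "stop_sequences": ["\n\n"],
#     "include_duckasgo": false
#   }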
class DuckasgoRequest(BaseModel):
    query: str
app = FastAPI()
@app.on_event("startup")
async def load_global_model():
    """
    Load the global model and tokenizer when the application starts.
    """
    global global_model, global_tokenizer, global_tokens
    config = AutoConfig.from_pretrained(MODEL_NAME)
    global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
    global_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, config=config, torch_dtype=torch.float16
    )
    if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
        global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
    device = "cuda" if torch.cuda.is_available() else "cpu"
    global_model.to(device)
    global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
    global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
    print(f"Model {MODEL_NAME} loaded successfully on {device}.")
@app.get("/", response_class=HTMLResponse)
async def index():
"""
Endpoint raíz que devuelve una página HTML simple.
"""
html_content = """
<html>
<head>
<title>Generación de Texto</title>
</head>
<body>
<h1>Bienvenido al Generador de Texto</h1>
<p>Prueba los endpoints <code>/generate</code> o <code>/duckasgo</code>.</p>
</body>
</html>
"""
return HTMLResponse(content=html_content, status_code=200)
@app.get("/health")
async def health():
"""
Endpoint de salud para verificar el estado del servidor.
"""
return {"status": "ok"}
async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
    """
    Run a DuckDuckGo search and return a summary of the results.
    """
    try:
        with DDGS() as ddgs:
            results = ddgs.text(query, max_results=max_results)
    except Exception as e:
        return f"DuckDuckGo search error: {e}"
    if not results:
        result_text = "No DuckDuckGo results were found."
    else:
        result_text = "\nSearch results (DuckDuckGo):\n"
        for idx, res in enumerate(results, start=1):
            title = res.get("title", "No title")
            url = res.get("href", "No URL")
            snippet = res.get("body", "")
            result_text += f"{idx}. {title}\n   URL: {url}\n   {snippet}\n"
    return result_text
def generate_next_token(input_ids, past_key_values, gen_config, device):
    """
    Synchronous helper that generates the next token with the model.
    Note: only temperature scaling and top-k filtering are applied here;
    the other penalties from the request are not used in this manual loop.
    """
    with torch.no_grad():
        outputs = global_model(
            input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True
        )
        logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values
        if gen_config.do_sample:
            # Temperature scaling followed by top-k filtering of the logits
            logits = logits / gen_config.temperature
            if gen_config.top_k and gen_config.top_k > 0:
                topk_values, _ = torch.topk(logits, k=gen_config.top_k)
                logits[logits < topk_values[:, [-1]]] = -float('Inf')
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            token_prob = probs[0, next_token.item()]
            token_logprob = torch.log(token_prob)
        else:
            # Greedy decoding: pick the highest-probability token
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            token_prob = probs[0, next_token.item()]
            token_logprob = torch.log(token_prob)
    return next_token, past_key_values, token_logprob.item()
async def stream_text(request: GenerateRequest, device: str):
    """
    Generate text as a stream, yielding each chunk with only the generated
    content (what previously appeared inside the "generated_text" field),
    without any other data.
    """
    global global_model, global_tokenizer, global_tokens
    # Prepare the input and configure generation
    encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
    input_ids = encoded_input.input_ids
    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        frequency_penalty=request.frequency_penalty,
        presence_penalty=request.presence_penalty,
        do_sample=request.do_sample,
    )
    torch.manual_seed(request.seed)
    current_chunk = ""
    chunk_token_count = 0
    generated_tokens = 0
    past_key_values = None
    # Generate up to max_new_tokens tokens, stopping early at EOS
    while generated_tokens < request.max_new_tokens:
        next_token, past_key_values, _ = await asyncio.to_thread(
            generate_next_token, input_ids, past_key_values, gen_config, device
        )
        token_id = next_token.item()
        token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
        current_chunk += token_text
        chunk_token_count += 1
        generated_tokens += 1
        if chunk_token_count >= request.chunk_token_limit:
            # Emit only the accumulated text once the chunk limit is reached
            yield current_chunk
            current_chunk = ""
            chunk_token_count = 0
        # Feed only the new token back in; the KV cache holds the earlier context
        input_ids = next_token
        if token_id == global_tokens["eos_token_id"]:
            break
    if current_chunk:
        yield current_chunk
    if request.include_duckasgo:
        search_summary = await perform_duckasgo_search(request.input_text)
        yield "\n" + search_summary
    await cleanup_memory(device)
def synchronous_generation(encoded_input, gen_config, device):
    """
    Synchronous helper for full (non-streaming) generation.
    """
    with torch.no_grad():
        output = global_model.generate(
            **encoded_input,
            generation_config=gen_config,
            return_dict_in_generate=True,
            output_scores=True,
            return_legacy_cache=True
        )
    return output
@app.post("/generate")
async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
"""
Endpoint para la generación de texto.
Tanto en modo streaming como en modo no streaming se devuelve únicamente el contenido extraído.
"""
global global_model, global_tokenizer, global_tokens
if global_model is None or global_tokenizer is None:
raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")
device = "cuda" if torch.cuda.is_available() else "cpu"
gen_config = GenerationConfig(
temperature=request.temperature,
max_new_tokens=request.max_new_tokens,
top_p=request.top_p,
top_k=request.top_k,
repetition_penalty=request.repetition_penalty,
frequency_penalty=request.frequency_penalty,
presence_penalty=request.presence_penalty,
do_sample=request.do_sample,
)
torch.manual_seed(request.seed)
try:
if request.stream:
generator = stream_text(request, device)
# Se envía la respuesta en streaming únicamente con el contenido extraído
return StreamingResponse(generator, media_type="text/plain")
else:
encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
output = await asyncio.to_thread(synchronous_generation, encoded_input, gen_config, device)
input_length = encoded_input["input_ids"].shape[-1]
full_generated_text = global_tokenizer.decode(
output.sequences[0][input_length:], skip_special_tokens=True
)
if request.stop_sequences:
for stop_seq in request.stop_sequences:
if stop_seq in full_generated_text:
full_generated_text = full_generated_text.split(stop_seq)[0]
break
token_ids = global_tokenizer.encode(full_generated_text, add_special_tokens=False)
chunks = []
for i in range(0, len(token_ids), request.chunk_token_limit):
chunk_ids = token_ids[i:i+request.chunk_token_limit]
chunk_text = global_tokenizer.decode(chunk_ids, skip_special_tokens=True)
chunks.append(chunk_text)
final_text = "".join(chunks)
if request.include_duckasgo:
search_summary = await perform_duckasgo_search(request.input_text)
final_text += "\n" + search_summary
await cleanup_memory(device)
background_tasks.add_task(lambda: print("Generación completada."))
# Se devuelve únicamente el contenido extraído
return PlainTextResponse(final_text)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error durante la generación: {e}")
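# Example usage (illustrative sketch, not part of the app): calling /generate with
# curl once the server is running. The host, port, and payload values are assumptions
# for demonstration; adjust them to your deployment.
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Hello", "max_new_tokens": 30, "stream": false}'
#
# With "stream": true the same call returns plain-text chunks as they are generated
# instead of a single response body.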
@app.post("/duckasgo")
async def duckasgo_search(request: DuckasgoRequest):
"""
Endpoint para búsquedas en DuckDuckGo.
"""
try:
with DDGS() as ddgs:
results = ddgs.text(request.query, max_results=10)
if not results:
results = []
return JSONResponse(content={"query": request.query, "results": results})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error en la búsqueda: {e}")
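# Example usage (illustrative sketch, assuming the same local host and port as above):
#
#   curl -X POST http://localhost:7860/duckasgo \
#        -H "Content-Type: application/json" \
#        -d '{"query": "FastAPI streaming responses"}'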
def run_server():
    uvicorn.run(app, host="0.0.0.0", port=7860)
if __name__ == "__main__":
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()
    # Keep the main thread alive without busy-waiting
    while True:
        time.sleep(1)
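# Minimal client sketch (assumption: the server is reachable at http://localhost:7860
# and the `requests` package is installed). It consumes the streaming /generate
# endpoint chunk by chunk; names and values here are illustrative only.
#
#   import requests
#
#   payload = {"input_text": "Tell me a story", "max_new_tokens": 100, "stream": True}
#   with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)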