import asyncio
import gc
import json
import os
import random
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import torch
import uvicorn
from duckduckgo_search import DDGS
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, HTMLResponse
from pydantic import BaseModel, Field
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig

MODEL_NAME = "lilmeaty/my_xdd"

global_model = None
global_tokenizer = None
global_tokens = {}

executor = ThreadPoolExecutor(max_workers=4)


async def cleanup_memory(device: str):
    """Free Python garbage and, when running on CUDA, the cached GPU memory."""
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()


class GenerateRequest(BaseModel):
    input_text: str = ""
    max_new_tokens: int = 2
    temperature: float = Field(default_factory=lambda: round(random.uniform(0.5, 0.8), 2))
    top_p: float = Field(default_factory=lambda: round(random.uniform(0.75, 0.95), 2))
    top_k: int = Field(default_factory=lambda: random.randint(20, 60))
    repetition_penalty: float = Field(default_factory=lambda: round(random.uniform(1.1, 1.8), 2))
    frequency_penalty: float = Field(default_factory=lambda: round(random.uniform(0.2, 0.7), 2))
    presence_penalty: float = Field(default_factory=lambda: round(random.uniform(0.2, 0.7), 2))
    seed: int = Field(default_factory=lambda: random.randint(0, 1000))
    do_sample: bool = True
    stream: bool = True
    chunk_token_limit: int = 2
    stop_sequences: list[str] = Field(default_factory=list)
    include_duckasgo: bool = False


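# A sample request body for POST /generate (hypothetical values; any sampling
# field left out is filled in by the randomized defaults above):
#   {"input_text": "Once upon a time", "max_new_tokens": 64, "temperature": 0.7,
#    "stream": false, "stop_sequences": ["\n\n"], "include_duckasgo": false}

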
class DuckasgoRequest(BaseModel):
    query: str


app = FastAPI()


@app.on_event("startup")
async def load_global_model():
    """
    Load the global model and tokenizer when the application starts.
    """
    global global_model, global_tokenizer, global_tokens
    config = AutoConfig.from_pretrained(MODEL_NAME)
    global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
    # float16 weights are intended for GPU inference; float16 can be slow or
    # unsupported on CPU.
    global_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, config=config, torch_dtype=torch.float16
    )
    if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
        global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
    device = "cuda" if torch.cuda.is_available() else "cpu"
    global_model.to(device)
    global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
    global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
    print(f"Model {MODEL_NAME} loaded successfully on {device}.")


@app.get("/", response_class=HTMLResponse) |
|
async def index(): |
|
""" |
|
Endpoint raíz que devuelve una página HTML simple. |
|
""" |
|
html_content = """ |
|
<html> |
|
<head> |
|
<title>Generación de Texto</title> |
|
</head> |
|
<body> |
|
<h1>Bienvenido al Generador de Texto</h1> |
|
<p>Prueba los endpoints <code>/generate</code> o <code>/duckasgo</code>.</p> |
|
</body> |
|
</html> |
|
""" |
|
return HTMLResponse(content=html_content, status_code=200) |
|
|
|
@app.get("/health") |
|
async def health(): |
|
""" |
|
Endpoint de salud para verificar el estado del servidor. |
|
""" |
|
return {"status": "ok"} |
|
|
|
async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
    """
    Run a DuckDuckGo search and return a plain-text summary of the results.
    """
    try:
        with DDGS() as ddgs:
            # Depending on the duckduckgo_search version, text() may return a
            # generator, so materialize it before checking for emptiness.
            results = list(ddgs.text(query, max_results=max_results))
    except Exception as e:
        return f"DuckDuckGo search error: {e}"
    if not results:
        result_text = "No results found on DuckDuckGo."
    else:
        result_text = "\nSearch results (DuckDuckGo):\n"
        for idx, res in enumerate(results, start=1):
            title = res.get("title", "No title")
            url = res.get("href", "No URL")
            snippet = res.get("body", "")
            result_text += f"{idx}. {title}\n URL: {url}\n {snippet}\n"
    return result_text


def generate_next_token(input_ids, past_key_values, gen_config, device):
    """
    Synchronous helper that generates the next token with the model.
    """
    with torch.no_grad():
        outputs = global_model(
            input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True
        )
        logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values
        if gen_config.do_sample:
            # Temperature scaling followed by top-k filtering.
            logits = logits / gen_config.temperature
            if gen_config.top_k and gen_config.top_k > 0:
                topk_values, _ = torch.topk(logits, k=gen_config.top_k)
                logits[logits < topk_values[:, [-1]]] = -float('Inf')
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            token_prob = probs[0, next_token.item()]
            token_logprob = torch.log(token_prob)
        else:
            # Greedy decoding: pick the most likely token.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            token_prob = probs[0, next_token.item()]
            token_logprob = torch.log(token_prob)
    return next_token, past_key_values, token_logprob.item()


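# Note: the manual decoding step above only honors the temperature, top_k and
# do_sample settings from the GenerationConfig; top_p and the repetition /
# frequency / presence penalties are accepted by the API but are not applied
# in this token-by-token loop.

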
async def stream_text(request: GenerateRequest, device: str):
    """
    Generate text as a stream, yielding plain-text chunks that contain only the
    generated content (no metadata). Note that stop_sequences are not applied
    on this streaming path.
    """
    global global_model, global_tokenizer, global_tokens

    encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
    input_ids = encoded_input.input_ids
    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        frequency_penalty=request.frequency_penalty,
        presence_penalty=request.presence_penalty,
        do_sample=request.do_sample,
    )
    torch.manual_seed(request.seed)
    current_chunk = ""
    chunk_token_count = 0
    generated_tokens = 0
    past_key_values = None

    while True:
        next_token, past_key_values, _ = await asyncio.to_thread(
            generate_next_token, input_ids, past_key_values, gen_config, device
        )
        token_id = next_token.item()
        token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
        current_chunk += token_text
        chunk_token_count += 1
        generated_tokens += 1

        if chunk_token_count >= request.chunk_token_limit:
            yield current_chunk
            current_chunk = ""
            chunk_token_count = 0

        # Feed only the new token back in; the KV cache carries the context.
        input_ids = next_token

        if token_id == global_tokens["eos_token_id"]:
            break
        if generated_tokens >= request.max_new_tokens:
            # Without this check the loop would only ever stop on EOS.
            break

    if current_chunk:
        yield current_chunk

    if request.include_duckasgo:
        search_summary = await perform_duckasgo_search(request.input_text)
        yield "\n" + search_summary

    await cleanup_memory(device)


def synchronous_generation(encoded_input, gen_config, device):
    """
    Synchronous helper for full (non-streaming) generation.
    """
    with torch.no_grad():
        output = global_model.generate(
            **encoded_input,
            generation_config=gen_config,
            return_dict_in_generate=True,
            output_scores=True,
            # return_legacy_cache may require a recent transformers version.
            return_legacy_cache=True
        )
    return output


@app.post("/generate") |
|
async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks): |
|
""" |
|
Endpoint para la generación de texto. |
|
Tanto en modo streaming como en modo no streaming se devuelve únicamente el contenido extraído. |
|
""" |
|
global global_model, global_tokenizer, global_tokens |
|
if global_model is None or global_tokenizer is None: |
|
raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.") |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
gen_config = GenerationConfig( |
|
temperature=request.temperature, |
|
max_new_tokens=request.max_new_tokens, |
|
top_p=request.top_p, |
|
top_k=request.top_k, |
|
repetition_penalty=request.repetition_penalty, |
|
frequency_penalty=request.frequency_penalty, |
|
presence_penalty=request.presence_penalty, |
|
do_sample=request.do_sample, |
|
) |
|
torch.manual_seed(request.seed) |
|
try: |
|
if request.stream: |
|
generator = stream_text(request, device) |
|
|
|
return StreamingResponse(generator, media_type="text/plain") |
|
else: |
|
encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device) |
|
output = await asyncio.to_thread(synchronous_generation, encoded_input, gen_config, device) |
|
input_length = encoded_input["input_ids"].shape[-1] |
|
full_generated_text = global_tokenizer.decode( |
|
output.sequences[0][input_length:], skip_special_tokens=True |
|
) |
|
if request.stop_sequences: |
|
for stop_seq in request.stop_sequences: |
|
if stop_seq in full_generated_text: |
|
full_generated_text = full_generated_text.split(stop_seq)[0] |
|
break |
|
|
|
token_ids = global_tokenizer.encode(full_generated_text, add_special_tokens=False) |
|
chunks = [] |
|
for i in range(0, len(token_ids), request.chunk_token_limit): |
|
chunk_ids = token_ids[i:i+request.chunk_token_limit] |
|
chunk_text = global_tokenizer.decode(chunk_ids, skip_special_tokens=True) |
|
chunks.append(chunk_text) |
|
final_text = "".join(chunks) |
|
|
|
if request.include_duckasgo: |
|
search_summary = await perform_duckasgo_search(request.input_text) |
|
final_text += "\n" + search_summary |
|
|
|
await cleanup_memory(device) |
|
background_tasks.add_task(lambda: print("Generación completada.")) |
|
|
|
return PlainTextResponse(final_text) |
|
except Exception as e: |
|
raise HTTPException(status_code=500, detail=f"Error durante la generación: {e}") |
|
|
|
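# Example usage (a sketch; the port matches run_server() below):
#   curl -N -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Hello", "max_new_tokens": 32, "stream": true}'
# With "stream": true the response arrives as plain-text chunks; with
# "stream": false a single plain-text body is returned.

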
@app.post("/duckasgo") |
|
async def duckasgo_search(request: DuckasgoRequest): |
|
""" |
|
Endpoint para búsquedas en DuckDuckGo. |
|
""" |
|
try: |
|
with DDGS() as ddgs: |
|
results = ddgs.text(request.query, max_results=10) |
|
if not results: |
|
results = [] |
|
return JSONResponse(content={"query": request.query, "results": results}) |
|
except Exception as e: |
|
raise HTTPException(status_code=500, detail=f"Error en la búsqueda: {e}") |
|
|
|
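# Example usage (sketch):
#   curl -X POST http://localhost:7860/duckasgo \
#        -H "Content-Type: application/json" \
#        -d '{"query": "fastapi streaming responses"}'

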
def run_server():
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    # Run uvicorn in a daemon thread and keep the main thread alive.
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()
    while True:
        # Sleep a full second per iteration; time.sleep(0) would busy-spin the CPU.
        time.sleep(1)