|
import logging
import threading

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
|
|
# Refuse to start unless NumPy 1.x is installed; any other major version is
# reported as incompatible. Presumably the torch/transformers builds in use
# were compiled against NumPy 1.x -- TODO confirm.
# An explicit raise is used instead of ``assert`` because asserts are stripped
# when Python runs with ``-O``, which would silently disable this guard.
if not np.__version__.startswith("1."):
    raise RuntimeError(f"Несовместимая версия NumPy: {np.__version__}")

app = FastAPI()
|
|
|
class RequestData(BaseModel):
    """JSON request body for the text-generation endpoint."""

    # Text the model should continue.
    prompt: str
    # Requested number of new tokens. Must be at least 1, so nonsensical values
    # are rejected up front with a 422 validation error instead of reaching the
    # model and failing there; the generate endpoint applies its own upper cap.
    max_tokens: int = Field(default=50, ge=1)
|
|
|
|
|
|
|
MODEL_NAME = "TinyLlama/TinyLlama_v1.1"

# float16 halves memory on a GPU, but on CPU many half-precision kernels are
# missing or very slow, so fall back to float32 when no CUDA device exists.
# NOTE(review): Apple MPS devices also support float16; extend if needed.
_MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Load tokenizer, model and pipeline once, at import time. On any failure the
# service still starts, with ``generator`` set to None so the endpoints can
# report that the model is unavailable.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=_MODEL_DTYPE,
        # Delegate weight placement to transformers (needs the accelerate package).
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    # Deliberate catch-all: a broken model must not keep the API from booting.
    # logger.exception also records the traceback for diagnosis.
    logging.getLogger(__name__).exception("Ошибка загрузки модели: %s", e)
    generator = None
|
|
|
|
|
# Send INFO-and-above records from every logger to stderr via the root logger.
# The level is given by name; the logging module resolves "INFO" to logging.INFO.
logging.basicConfig(level="INFO")
|
|
|
@app.on_event("startup")
async def startup_event():
    """Log the path of every route registered on the app when it starts.

    NOTE(review): newer FastAPI releases deprecate ``on_event`` in favour of
    lifespan handlers; migrating requires changing the ``FastAPI()`` call.
    """
    routes = [route.path for route in app.routes]
    # Go through the logging setup used by the rest of the module, not print().
    logging.getLogger(__name__).info("Registered routes: %s", routes)
|
|
|
@app.get("/")
async def root_health_check():
    """Answer on the root path with a constant status payload.

    This only shows that the process is serving HTTP; it does not inspect
    whether the text-generation model is loaded.
    """
    payload = {"status": "ok"}
    return payload
|
|
|
# Generation runs in worker threads (see generate_text); this lock allows only
# one generation at a time on the shared model, bounding memory/compute use.
_generation_lock = threading.Lock()

# Hard upper bound on new tokens per request, whatever the client asks for.
_MAX_NEW_TOKENS_CAP = 100


def _run_generation(prompt: str, max_new_tokens: int) -> str:
    """Run the blocking text-generation pipeline and return the generated text.

    Meant to run in a worker thread. Decoding is greedy (``do_sample=False``,
    a single beam), so sampling-only parameters such as ``temperature`` would
    be ignored and are therefore not passed.
    """
    with _generation_lock:
        output = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
        )
    # Text of the first generated sequence.
    return output[0]["generated_text"]


@app.post("/generate")
async def generate_text(request: RequestData):
    """Generate a continuation of ``request.prompt``.

    At most ``min(request.max_tokens, 100)`` new tokens are produced. The
    blocking model call is offloaded to a thread pool so the event loop keeps
    serving other requests (e.g. health checks) while generation runs.

    Returns:
        ``{"response": text}`` where ``text`` is the pipeline's
        ``generated_text`` for the first sequence.

    Raises:
        HTTPException: 503 if the model is not loaded (``generator`` is None);
            500 if generation raises, with the error text as ``detail``.
    """
    if generator is None:
        raise HTTPException(status_code=503, detail="Модель не загружена")

    max_new_tokens = min(request.max_tokens, _MAX_NEW_TOKENS_CAP)
    try:
        text = await run_in_threadpool(_run_generation, request.prompt, max_new_tokens)
    except Exception as e:
        # Keep the traceback in the server log; the client only gets the message.
        logging.getLogger(__name__).exception("Text generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": text}
|
|
|
@app.get("/health")
async def health_check():
    """Report whether the text-generation pipeline is available.

    Returns ``{"status": "ok"}`` when the model loaded and
    ``{"status": "unavailable"}`` otherwise; the HTTP status is 200 either way.
    """
    # Compare against None explicitly rather than relying on the truthiness of
    # the pipeline object.
    return {"status": "ok" if generator is not None else "unavailable"}