# aibot/main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import numpy as np
import logging
# NumPy version check: this stack expects NumPy 1.x
assert np.__version__.startswith('1.'), f"Incompatible NumPy version: {np.__version__}"
app = FastAPI()


class RequestData(BaseModel):
    prompt: str
    max_tokens: int = 50

# Previously tried models:
# MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# MODEL_NAME = "ai-forever/rugpt3small_based_on_gpt2"
MODEL_NAME = "TinyLlama/TinyLlama_v1.1"

logging.basicConfig(level=logging.INFO)

try:
    # Load the model with an explicit device_map
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    # Build the pipeline without passing a device: device_map="auto" has
    # already placed the model, and an explicit device would conflict with it
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )
except Exception as e:
    logging.error(f"Failed to load model: {e}")
    generator = None

@app.on_event("startup")
async def startup_event():
    routes = [route.path for route in app.routes]
    logging.info(f"Registered routes: {routes}")


@app.get("/")
async def root_health_check():
    return {"status": "ok"}

@app.post("/generate")
async def generate_text(request: RequestData):
    if not generator:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    try:
        output = generator(
            request.prompt,
            max_new_tokens=min(request.max_tokens, 100),  # cap generation length
            do_sample=False,  # greedy decoding; a temperature setting would be ignored here
            num_beams=1,
        )
        return {"response": output[0]["generated_text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "ok" if generator else "unavailable"}