from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import numpy as np
import logging

# Guard against NumPy 2.x, which is incompatible with this torch/transformers stack
assert np.__version__.startswith('1.'), f"Incompatible NumPy version: {np.__version__}"

app = FastAPI()

class RequestData(BaseModel):
    prompt: str
    max_tokens: int = 50

#MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
#MODEL_NAME = "ai-forever/rugpt3small_based_on_gpt2"
MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
logging.basicConfig(level=logging.INFO)

try:
    # Load the model with an explicit device_map
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True
    )
    # Create the pipeline without a device argument:
    # device_map="auto" has already placed the model
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
    )
except Exception as e:
    logging.error(f"Model loading failed: {e}")
    generator = None

@app.on_event("startup")
async def startup_event():
    routes = [route.path for route in app.routes]
    logging.info(f"Registered routes: {routes}")

@app.get("/")
async def root_health_check():
return {"status": "ok"}
@app.post("/generate")
async def generate_text(request: RequestData):
    if not generator:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    try:
        output = generator(
            request.prompt,
            max_new_tokens=min(request.max_tokens, 100),  # cap output length
            do_sample=False,  # greedy decoding; a temperature setting would be ignored here
            num_beams=1,
        )
        return {"response": output[0]["generated_text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
return {"status": "ok" if generator else "unavailable"} |