from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import numpy as np
import logging

# Guard against NumPy 2.x, whose ABI change can break older compiled
# wheels (e.g. torch builds that expect NumPy 1.x).
assert np.__version__.startswith('1.'), f"Incompatible NumPy version: {np.__version__}"

app = FastAPI()

class RequestData(BaseModel):
    prompt: str
    max_tokens: int = 50
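
# Example request body (illustrative values only):
#   {"prompt": "Once upon a time", "max_tokens": 30}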

#MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
#MODEL_NAME = "ai-forever/rugpt3small_based_on_gpt2"
MODEL_NAME = "TinyLlama/TinyLlama_v1.1"

logging.basicConfig(level=logging.INFO)

try:
    # Load the model with an explicit device_map ("auto" requires the
    # `accelerate` package) and half precision to reduce memory use.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True
    )

    # Build the pipeline without a `device` argument: the model has already
    # been placed by device_map, and passing one here would conflict.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
    )
except Exception as e:
    logging.exception(f"Failed to load model: {e}")
    generator = None

@app.on_event("startup")
async def startup_event():
    routes = [route.path for route in app.routes]
    logging.info(f"Registered routes: {routes}")

@app.get("/")
async def root_health_check():
    return {"status": "ok"}

@app.post("/generate")
async def generate_text(request: RequestData):
    if not generator:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    try:
        output = generator(
            request.prompt,
            # Cap generation length server-side regardless of the client value.
            max_new_tokens=min(request.max_tokens, 100),
            # Greedy decoding: temperature is ignored (and triggers a
            # transformers warning) when do_sample=False, so it is omitted.
            do_sample=False,
            num_beams=1,
        )
        return {"response": output[0]["generated_text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "ok" if generator else "unavailable"}