import os import torch import uvicorn from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel from transformers import pipeline # Utwórz instancję FastAPI app = FastAPI( title="Emotions PL API", description="API do oznaczaniem tagami emocji go-emotions-polish-gpt2-small-v0.0.1", version="1.0.0" ) # Ścieżka do modelu - Hugging Face automatycznie pobierze model MODEL_NAME = "nie3e/go-emotions-polish-gpt2-small-v0.0.1" generator = None # Zostanie załadowany później # Model wejściowy dla POST request class PredictRequest(BaseModel): prompt: str @app.on_event("startup") async def startup_event(): global generator if torch.cuda.is_available(): print("device: GPU") else: print("device: CPU") print(f"Ładowanie modelu: {MODEL_NAME}...") try: # Możesz dostosować device=0 (GPU) lub device=-1 (CPU) w zależności od wybranej maszyny Space # Free tier spaces usually run on CPU, unless you explicitly select a GPU. # It's safer to not specify device if you want it to auto-detect or default to CPU. generator = pipeline( "text-classification", model=MODEL_NAME, top_k=-1, # device=0 if torch.cuda.is_available() else -1 # Odkomentuj dla detekcji GPU ) print("Model załadowany pomyślnie!") except Exception as e: print(f"Błąd ładowania modelu: {e}") # Możesz zdecydować, czy aplikacja ma zakończyć działanie, czy kontynuować bez modelu # W przypadku błędu ładowania modelu, endpoint generacji tekstu będzie zwracał błąd generator = None # Ustaw na None, aby sygnalizować problem @app.get("/") async def root(): return {"message": "Polish emotions API is running!"} @app.post("/predict") async def predict(request: PredictRequest): if generator is None: raise HTTPException(status_code=503, detail="Model nie został załadowany lub wystąpił błąd.") try: generated_text = generator( request.prompt ) # Pipeline zwraca listę słowników, bierzemy pierwszy wynik response_data = generated_text[0] return JSONResponse( content=response_data, media_type="application/json; charset=utf-8" ) # return {"generated_text": generated_text[0]["generated_text"]} except Exception as e: raise HTTPException(status_code=500, detail=f"Błąd podczas generowania tekstu: {e}") # Uruchamianie serwera Uvicorn bezpośrednio (dla Dockera) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))