# test / app.py
# yummyu's picture
# Update app.py
# 5746d3f verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from typing import Optional
# Cache directory configuration: every Hugging Face cache variable points at
# the same directory. This runs at import time, i.e. before load_models()
# imports transformers.
cache_dir = "/app/.cache"
os.environ.update(
    {
        "HF_HOME": cache_dir,
        "TRANSFORMERS_CACHE": cache_dir,
        "HUGGINGFACE_HUB_CACHE": cache_dir,
    }
)

# Create the cache directory (no error if it already exists).
os.makedirs(cache_dir, exist_ok=True)

# The ASGI application; the route handlers below register themselves on it.
app = FastAPI(title="Lightweight Hugging Face API")
# Request body shared by the /sentiment and /generate endpoints.
class TextRequest(BaseModel):
    # Input text: classified by /sentiment, used as the prompt by /generate.
    text: str

    # Requested generation length; only /generate reads it, and it caps the
    # value at 100. Defaults to 50 when the field is omitted.
    max_length: Optional[int] = 50
# Response body returned by the /sentiment endpoint.
class SentimentResponse(BaseModel):
    # The text that was analyzed, echoed back from the request.
    text: str

    # Label of the first result returned by the classifier.
    sentiment: str

    # Score of that result, rounded to four decimal places by the endpoint.
    confidence: float

    # Identifier of the model that produced the prediction.
    model_name: str
# Response body returned by the /generate endpoint.
class GenerateResponse(BaseModel):
    # The prompt that was sent in the request.
    input_text: str

    # "generated_text" of the first sequence returned by the generator.
    generated_text: str

    # Identifier of the model that produced the text.
    model_name: str
# Global pipeline handles. Both start out as None; load_models() assigns them
# at application startup and the endpoints check for None before using them.
sentiment_classifier = text_generator = None
def _load_pipeline(task, model_name, start_message, done_message):
    """Build one transformers pipeline, or return None when loading fails.

    Args:
        task: Pipeline task name passed to ``transformers.pipeline``.
        model_name: Model identifier passed as ``model=``.
        start_message: Progress line printed before loading starts.
        done_message: Progress line printed after a successful load.

    Returns:
        The pipeline object, or None if importing transformers or building
        the pipeline raised. The error and traceback are printed in that case.
    """
    import traceback

    print(start_message)
    try:
        # Imported lazily so that a missing or broken transformers install is
        # reported here instead of preventing the module from being imported.
        from transformers import pipeline

        # No cache_dir argument is passed to the pipeline.
        loaded = pipeline(task, model=model_name)
    except Exception as e:
        # Best effort: the app still starts, and endpoints answer 503 for a
        # model whose handle stayed None.
        print(f"❌ モデル読み込みエラー: {e}")
        traceback.print_exc()
        return None
    print(done_message)
    return loaded


# NOTE(review): recent FastAPI releases document ``on_event`` as deprecated in
# favor of lifespan handlers - confirm against the installed version.
@app.on_event("startup")
async def load_models():
    """Load the lightweight models at application startup.

    Each model is loaded independently, so a failure on one of them no longer
    prevents the other from being attempted. The results are stored in the
    module-level ``sentiment_classifier`` and ``text_generator``; a model that
    failed to load leaves its handle as None.
    """
    global sentiment_classifier, text_generator
    print("🚀 軽量モデルのロード開始...")

    # Very small sentiment-analysis model.
    sentiment_classifier = _load_pipeline(
        "sentiment-analysis",
        "prajjwal1/bert-tiny",
        "📥 軽量感情分析モデルをロード中...",
        "✅ 感情分析モデル読み込み完了",
    )

    # Small text-generation model.
    text_generator = _load_pipeline(
        "text-generation",
        "sshleifer/tiny-gpt2",
        "📥 軽量テキスト生成モデルをロード中...",
        "✅ テキスト生成モデル読み込み完了",
    )

    # Only announce overall success when both handles were actually created.
    if sentiment_classifier is not None and text_generator is not None:
        print("✅ 全てのモデル読み込み完了")
# Landing endpoint: returns a fixed status payload and touches no model.
@app.get("/")
async def root():
    status_payload = {}
    status_payload["message"] = "🤗 Lightweight Hugging Face API is running!"
    status_payload["status"] = "healthy"
    status_payload["models"] = "lightweight versions"
    return status_payload
@app.post("/sentiment", response_model=SentimentResponse)
async def analyze_sentiment(request: TextRequest):
    """Run lightweight sentiment analysis on ``request.text``.

    Returns a SentimentResponse with the input text, the label of the first
    classifier result, its score rounded to four decimals, and the model name.

    Raises:
        HTTPException: 503 when the sentiment model is not loaded, 500 when
            the analysis itself fails.
    """
    # This guard must sit outside the try block: HTTPException is an Exception
    # subclass, so raising it inside would be caught by the handler below and
    # reach the client as a 500 instead of the intended 503.
    if sentiment_classifier is None:
        raise HTTPException(status_code=503, detail="Sentiment model not loaded")

    try:
        result = sentiment_classifier(request.text)
        top_prediction = result[0]
        return SentimentResponse(
            text=request.text,
            sentiment=top_prediction["label"],
            confidence=round(top_prediction["score"], 4),
            model_name="prajjwal1/bert-tiny",
        )
    except Exception as e:
        # Any inference or response-building failure is reported as a 500;
        # the original exception is kept as the cause for server-side logs.
        raise HTTPException(
            status_code=500, detail=f"Analysis failed: {str(e)}"
        ) from e
@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: TextRequest):
    """Generate a continuation of ``request.text`` with the lightweight model.

    ``request.max_length`` is capped at 100. An explicit null falls back to
    50, the default declared on ``TextRequest.max_length``.

    Returns a GenerateResponse with the prompt, the generated text of the
    first returned sequence, and the model name.

    Raises:
        HTTPException: 503 when the generation model is not loaded, 500 when
            generation itself fails.
    """
    # This guard must sit outside the try block: HTTPException is an Exception
    # subclass, so raising it inside would be caught by the handler below and
    # reach the client as a 500 instead of the intended 503.
    if text_generator is None:
        raise HTTPException(status_code=503, detail="Generation model not loaded")

    try:
        # The field is Optional, so a client may send null; min(None, 100)
        # would raise TypeError, hence the fallback to the declared default.
        requested_length = request.max_length if request.max_length is not None else 50
        max_length = min(requested_length, 100)

        result = text_generator(
            request.text,
            max_length=max_length,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            # The tokenizer's EOS token id is passed as the pad token id.
            pad_token_id=text_generator.tokenizer.eos_token_id,
        )
        return GenerateResponse(
            input_text=request.text,
            generated_text=result[0]["generated_text"],
            model_name="sshleifer/tiny-gpt2",
        )
    except Exception as e:
        # Any generation or response-building failure is reported as a 500;
        # the original exception is kept as the cause for server-side logs.
        raise HTTPException(
            status_code=500, detail=f"Generation failed: {str(e)}"
        ) from e
# Reports, for each model, its identifier and whether its pipeline handle is
# currently set.
@app.get("/models")
async def get_models():
    def _status_of(pipeline_handle):
        # Truthiness test: a handle that is still None reports "not loaded".
        if pipeline_handle:
            return "loaded"
        return "not loaded"

    sentiment_info = {
        "model": "prajjwal1/bert-tiny",
        "status": _status_of(sentiment_classifier),
    }
    generation_info = {
        "model": "sshleifer/tiny-gpt2",
        "status": _status_of(text_generator),
    }
    return {
        "sentiment_analysis": sentiment_info,
        "text_generation": generation_info,
        "note": "Using lightweight models for Spaces compatibility",
    }
# Script entry point: serve the app with uvicorn on all interfaces.
if __name__ == "__main__":
    import uvicorn

    # Listen on PORT from the environment, or 7860 when it is not set.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)