# Commit d6a183d by spacesedan (verified): "update to use my trained model"
import logging
from typing import List, Tuple

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
# ASGI application object; the route decorators further down register on it.
app = FastAPI()

# Hugging Face model id of the fine-tuned Longformer sentiment classifier.
_MODEL_NAME = "spacesedan/reddit-sentiment-analysis-longformer"

# Built once at module import and reused by every request handler below.
sentiment_pipeline = pipeline(task="text-classification", model=_MODEL_NAME)
# Request models: JSON bodies accepted by the endpoints in this file.
class SentimentRequest(BaseModel):
    # Body of POST /analyze; also the element type of BatchSentimentRequest.posts.
    content_id: str  # caller-chosen identifier, echoed back in SentimentResponse.content_id
    text: str  # raw text handed to the sentiment pipeline
class BatchSentimentRequest(BaseModel):
    # Body of POST /analyze_batch. The response holds one SentimentResponse
    # per post, in the same order as this list.
    # NOTE(review): no upper bound on the list length — consider capping it so
    # a single request cannot tie up the model indefinitely.
    posts: List[SentimentRequest]
# Response model: per-item result returned by /analyze, and as a list by /analyze_batch.
class SentimentResponse(BaseModel):
    content_id: str  # copied from the corresponding request item
    sentiment_score: float  # LABEL_MAP score for the label; 0.0 for unmapped labels or the neutral fallback
    sentiment_label: str  # lower-cased model label, or "neutral" after the low-confidence fallback
    confidence: float  # pipeline probability for the predicted label, rounded to 3 decimals
# Numeric score for each label of the 5-point sentiment scale, listed from
# most negative to most positive. Keys are lower-case because labels are
# lower-cased before they are looked up here.
LABEL_MAP = dict(
    [
        ("very negative", -1.0),
        ("negative", -0.7),
        ("neutral", 0.0),
        ("positive", 0.7),
        ("very positive", 1.0),
    ]
)
def normalize_prediction(
    label: str, confidence: float, threshold: float = 0.6
) -> Tuple[float, str]:
    """Convert a raw classifier label into a ``(score, label)`` pair.

    The label is lower-cased and scored via ``LABEL_MAP``. The prediction
    falls back to ``(0.0, "neutral")`` when both of these hold:

    * the confidence is below ``threshold``;
    * the score is a mild one (between -0.7 and 0.7 inclusive, i.e. anything
      except the extreme labels).

    Args:
        label: Label string produced by the text-classification pipeline.
        confidence: Probability the pipeline assigned to ``label``.
        threshold: Confidence below which a mild prediction is replaced by
            neutral. Defaults to 0.6, the previously hard-coded value.

    Returns:
        ``(score, label)``. A label missing from ``LABEL_MAP`` scores 0.0. It
        is returned lower-cased unless the low-confidence fallback applies.
    """
    label = label.lower()
    # NOTE(review): labels absent from LABEL_MAP (e.g. generic "LABEL_0" names
    # if the model config has no id2label mapping) silently score 0.0 —
    # confirm the model emits exactly the LABEL_MAP label strings.
    score = LABEL_MAP.get(label, 0.0)
    # Inclusive bounds let "negative"/"positive" (+/-0.7) fall back. The old
    # strict bounds matched only score 0.0, which made the fallback a no-op
    # for every known non-neutral label.
    if confidence < threshold and -0.7 <= score <= 0.7:
        return 0.0, "neutral"
    return score, label
@app.post("/analyze", response_model=SentimentResponse)
def analyze_sentiment(request: SentimentRequest):
    """Classify one piece of content and return its sentiment.

    Raises:
        HTTPException: 500 if the pipeline or post-processing fails. The
            original exception message is used as the error detail.
    """
    try:
        # truncation=True has the tokenizer cut over-long text to the model's
        # maximum length instead of letting the forward pass fail on it.
        result = sentiment_pipeline(request.text, truncation=True)[0]
        raw_confidence = result["score"]
        # Apply the neutral fallback to the unrounded probability; round only
        # the value reported back to the client.
        sentiment_score, sentiment_label = normalize_prediction(
            result["label"], raw_confidence
        )
        return SentimentResponse(
            content_id=request.content_id,
            sentiment_score=sentiment_score,
            sentiment_label=sentiment_label,
            confidence=round(raw_confidence, 3),
        )
    except Exception as e:
        logging.getLogger(__name__).exception(
            "Sentiment analysis failed for content_id=%s", request.content_id
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/analyze_batch", response_model=List[SentimentResponse])
def analyze_sentiment_batch(request: BatchSentimentRequest):
    """Classify every post in the batch, returning results in request order.

    Raises:
        HTTPException: 500 if the pipeline or post-processing fails for any
            post. The original exception message is used as the error detail.
    """
    posts = request.posts
    if not posts:
        return []
    try:
        # Hand the whole list to the pipeline in one call so transformers
        # iterates the inputs itself rather than being re-invoked per post.
        # truncation=True cuts over-long text to the model's maximum length.
        results = sentiment_pipeline([post.text for post in posts], truncation=True)
        responses = []
        for post, result in zip(posts, results):
            raw_confidence = result["score"]
            # Apply the neutral fallback to the unrounded probability; round
            # only the reported value.
            sentiment_score, sentiment_label = normalize_prediction(
                result["label"], raw_confidence
            )
            responses.append(
                SentimentResponse(
                    content_id=post.content_id,
                    sentiment_score=sentiment_score,
                    sentiment_label=sentiment_label,
                    confidence=round(raw_confidence, 3),
                )
            )
        return responses
    except Exception as e:
        logging.getLogger(__name__).exception(
            "Batch sentiment analysis failed for %d posts", len(posts)
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
def root():
    """Return a fixed message indicating that the API is running."""
    status_text = "Reddit Sentiment Analysis API (Longformer 5-point) is running!"
    return dict(message=status_text)