import logging
from typing import List, Tuple

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI()

# Load the fine-tuned Longformer model once at import time; every request
# reuses this single pipeline instance.
sentiment_pipeline = pipeline(
    "text-classification",
    model="spacesedan/reddit-sentiment-analysis-longformer",
)


# Request models
class SentimentRequest(BaseModel):
    """A single piece of content to classify."""

    content_id: str
    text: str


class BatchSentimentRequest(BaseModel):
    """Several pieces of content to classify in one request."""

    posts: List[SentimentRequest]


# Response model
class SentimentResponse(BaseModel):
    """Normalized sentiment result for one piece of content."""

    content_id: str
    sentiment_score: float  # value from LABEL_MAP, in [-1.0, 1.0]
    sentiment_label: str
    confidence: float  # model score rounded to 3 decimals


# Label-to-score mapping on a 5-point scale.
# NOTE(review): assumes the model emits these human-readable labels (case
# aside); if it emits generic ids such as "LABEL_0", every prediction maps to
# 0.0 — confirm against the model's id2label config.
LABEL_MAP = {
    "very negative": -1.0,
    "negative": -0.7,
    "neutral": 0.0,
    "positive": 0.7,
    "very positive": 1.0,
}

# Predictions with confidence below this value are eligible for the neutral fallback.
LOW_CONFIDENCE_THRESHOLD = 0.6
# Largest |score| still considered "mild" (i.e. not a "very ..." label).
MILD_SCORE_LIMIT = 0.7


def normalize_prediction(
    label: str,
    confidence: float,
    threshold: float = LOW_CONFIDENCE_THRESHOLD,
) -> Tuple[float, str]:
    """Map a raw model label to a (score, label) pair.

    Args:
        label: Label string returned by the pipeline; matched case-insensitively
            against LABEL_MAP. Unknown labels score 0.0.
        confidence: Model confidence for ``label``.
        threshold: Confidence below which mild or neutral predictions are
            reported as neutral.

    Returns:
        ``(score, label)``. The label is lower-cased, or replaced by
        ``"neutral"`` when the confidence fallback applies.
    """
    label = label.lower()
    score = LABEL_MAP.get(label, 0.0)
    # Inclusive bounds: the mild labels map to exactly ±0.7, so a strict
    # comparison would exclude them and the fallback would only ever match
    # scores of 0.0. Strong ("very ...") predictions are never downgraded.
    if confidence < threshold and -MILD_SCORE_LIMIT <= score <= MILD_SCORE_LIMIT:
        return 0.0, "neutral"
    return score, label


def _build_response(content_id: str, result: dict) -> SentimentResponse:
    """Convert one raw pipeline prediction ({"label", "score"}) into a response."""
    # Round before normalizing so the fallback decision agrees with the
    # confidence value reported to the client.
    confidence = round(result["score"], 3)
    sentiment_score, sentiment_label = normalize_prediction(result["label"], confidence)
    return SentimentResponse(
        content_id=content_id,
        sentiment_score=sentiment_score,
        sentiment_label=sentiment_label,
        confidence=confidence,
    )


@app.post("/analyze", response_model=SentimentResponse)
def analyze_sentiment(request: SentimentRequest):
    """Classify a single post.

    Raises:
        HTTPException: 500 if the model or normalization fails.
    """
    try:
        # truncation=True: inputs longer than the model's maximum sequence
        # length are cut instead of raising an error.
        result = sentiment_pipeline(request.text, truncation=True)[0]
        return _build_response(request.content_id, result)
    except Exception as e:
        # Top-level request boundary: keep the traceback in server logs.
        logger.exception("Sentiment analysis failed for content_id=%s", request.content_id)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/analyze_batch", response_model=List[SentimentResponse])
def analyze_sentiment_batch(request: BatchSentimentRequest):
    """Classify several posts, returning results in request order.

    Raises:
        HTTPException: 500 if any post fails; no partial results are returned.
    """
    if not request.posts:
        return []
    try:
        # A single list call lets the pipeline iterate internally instead of
        # paying per-call overhead; it returns one prediction dict per input,
        # in input order.
        results = sentiment_pipeline([post.text for post in request.posts], truncation=True)
        return [
            _build_response(post.content_id, result)
            for post, result in zip(request.posts, results)
        ]
    except Exception as e:
        logger.exception("Batch sentiment analysis failed (%d posts)", len(request.posts))
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
def root():
    """Health-check endpoint."""
    return {"message": "Reddit Sentiment Analysis API (Longformer 5-point) is running!"}