Spaces:
Sleeping
Sleeping
from typing import List, Tuple

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
# FastAPI application object that the handlers in this module belong to.
app = FastAPI()

# Fine-tuned Longformer sentiment classifier. It is built once at module
# import, and every handler below calls this same module-level instance.
sentiment_pipeline = pipeline(
    task="text-classification",
    model="spacesedan/reddit-sentiment-analysis-longformer",
)
# Request models | |
class SentimentRequest(BaseModel): | |
content_id: str | |
text: str | |
class BatchSentimentRequest(BaseModel):
    # Posts to classify; each entry gets its own result, in the same order.
    posts: List[SentimentRequest]
# ---- Response model ----
class SentimentResponse(BaseModel):
    # Same content_id as the request entry this result belongs to.
    content_id: str
    # Numeric sentiment taken from LABEL_MAP (from -1.0 to 1.0).
    sentiment_score: float
    # Normalized (lower-cased) label, or "neutral" after the low-confidence fallback.
    sentiment_label: str
    # The classifier's score for its predicted label, rounded to 3 decimals.
    confidence: float
# Maps the model's lower-cased 5-point labels to a numeric score in [-1.0, 1.0].
# NOTE(review): assumes the model config emits these human-readable label
# names; generic names such as "LABEL_0" would not match and would score 0.0.
LABEL_MAP = {
    "very negative": -1.0,
    "negative": -0.7,
    "neutral": 0.0,
    "positive": 0.7,
    "very positive": 1.0,
}


def normalize_prediction(
    label: str, confidence: float, min_confidence: float = 0.6
) -> Tuple[float, str]:
    """Convert a raw classifier prediction into a ``(score, label)`` pair.

    Args:
        label: Label returned by the pipeline. It is matched
            case-insensitively against LABEL_MAP.
        confidence: The pipeline's score for ``label``.
        min_confidence: Below this confidence, a mild prediction
            ("negative", "neutral", "positive", or an unknown label) is
            reported as neutral. The strong "very ..." labels are always
            kept. Defaults to 0.6.

    Returns:
        ``(sentiment_score, sentiment_label)``. An unknown label scores 0.0.
        It keeps its lower-cased name unless the low-confidence fallback
        turns it into "neutral".
    """
    label = label.lower()
    score = LABEL_MAP.get(label, 0.0)
    # Confidence-based fallback to neutral. The bounds are inclusive so that
    # "negative"/"positive" (exactly -0.7/0.7) are covered. With strict bounds,
    # only scores that were already 0.0 could ever take this branch.
    if confidence < min_confidence and -0.7 <= score <= 0.7:
        return 0.0, "neutral"
    return score, label
def analyze_sentiment(request: SentimentRequest):
    """Classify one text and return its normalized sentiment.

    Args:
        request: The content ID and text to classify.

    Returns:
        A SentimentResponse carrying ``request.content_id``.

    Raises:
        HTTPException: Status 500 with the underlying error message if
            classification fails.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm
    # how it is registered on `app`.
    try:
        # truncation=True clips inputs longer than the model's maximum sequence
        # length. Without it, the pipeline raises on long posts.
        result = sentiment_pipeline(request.text, truncation=True)[0]
        raw_confidence = result["score"]
        # Apply the threshold to the raw score; rounding is only for display.
        sentiment_score, sentiment_label = normalize_prediction(
            result["label"], raw_confidence
        )
        return SentimentResponse(
            content_id=request.content_id,
            sentiment_score=sentiment_score,
            sentiment_label=sentiment_label,
            confidence=round(raw_confidence, 3),
        )
    except Exception as e:
        # Chain the cause so the original exception stays attached as __cause__.
        raise HTTPException(status_code=500, detail=str(e)) from e
def analyze_sentiment_batch(request: BatchSentimentRequest):
    """Classify every post in a batch, preserving input order.

    Args:
        request: A batch whose ``posts`` each carry a content ID and text.

    Returns:
        A list of SentimentResponse, one per post, in the same order as
        ``request.posts``. An empty batch yields an empty list.

    Raises:
        HTTPException: Status 500 with the underlying error message if
            classifying any post fails. No partial results are returned.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm
    # how it is registered on `app`.
    posts = request.posts
    if not posts:
        # Nothing to classify; skip the pipeline call entirely.
        return []
    try:
        # One pipeline call over all texts. For a list input, the
        # text-classification pipeline returns one {"label", "score"} dict per
        # text, in order. truncation=True clips texts longer than the model's
        # maximum sequence length instead of raising.
        results = sentiment_pipeline(
            [post.text for post in posts], truncation=True
        )
        responses = []
        for post, result in zip(posts, results):
            raw_confidence = result["score"]
            # Apply the threshold to the raw score; rounding is only for display.
            sentiment_score, sentiment_label = normalize_prediction(
                result["label"], raw_confidence
            )
            responses.append(
                SentimentResponse(
                    content_id=post.content_id,
                    sentiment_score=sentiment_score,
                    sentiment_label=sentiment_label,
                    confidence=round(raw_confidence, 3),
                )
            )
        return responses
    except Exception as e:
        # Chain the cause so the original exception stays attached as __cause__.
        raise HTTPException(status_code=500, detail=str(e)) from e
def root():
    # Returns a fixed status message saying the API is running.
    # NOTE(review): no route decorator is visible here; confirm how this is
    # registered on `app`.
    status_message = "Reddit Sentiment Analysis API (Longformer 5-point) is running!"
    return dict(message=status_message)