from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
from typing import List, Tuple

# Initialize FastAPI
app = FastAPI()

# Load your fine-tuned Longformer model
sentiment_pipeline = pipeline(
    "text-classification", 
    model="spacesedan/reddit-sentiment-analysis-longformer"
)
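# Note: pipeline() runs on CPU by default; passing device=0 is one option for
# using the first GPU, if one is available on the serving machine.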

# Request models
class SentimentRequest(BaseModel):
    content_id: str
    text: str

class BatchSentimentRequest(BaseModel):
    posts: List[SentimentRequest]

# Response model
class SentimentResponse(BaseModel):
    content_id: str
    sentiment_score: float
    sentiment_label: str
    confidence: float

# Updated label-to-score mapping
LABEL_MAP = {
    "very negative": -1.0,
    "negative": -0.7,
    "neutral": 0.0,
    "positive": 0.7,
    "very positive": 1.0
}

def normalize_prediction(label: str, confidence: float) -> Tuple[float, str]:
    label = label.lower()
    score = LABEL_MAP.get(label, 0.0)

    # Confidence-based fallback: low-confidence, non-extreme predictions collapse
    # to neutral; "very negative" / "very positive" are kept regardless of confidence.
    if confidence < 0.6 and -0.7 <= score <= 0.7:
        return 0.0, "neutral"
    return score, label
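
# Illustrative mapping (given the fallback rule above): ("positive", 0.55) -> (0.0, "neutral"),
# while ("very negative", 0.55) -> (-1.0, "very negative"), since extreme labels are never remapped.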

@app.post("/analyze", response_model=SentimentResponse)
def analyze_sentiment(request: SentimentRequest):
    try:
        result = sentiment_pipeline(request.text)[0]
        confidence = round(result["score"], 3)
        sentiment_score, sentiment_label = normalize_prediction(result["label"], confidence)

        return SentimentResponse(
            content_id=request.content_id,
            sentiment_score=sentiment_score,
            sentiment_label=sentiment_label,
            confidence=confidence
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/analyze_batch", response_model=List[SentimentResponse])
def analyze_sentiment_batch(request: BatchSentimentRequest):
    try:
        responses = []
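        # Posts are scored one at a time here for simplicity; as an alternative,
        # sentiment_pipeline also accepts a list of texts, which would batch the
        # forward passes in a single call.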
        for post in request.posts:
            result = sentiment_pipeline(post.text)[0]
            confidence = round(result["score"], 3)
            sentiment_score, sentiment_label = normalize_prediction(result["label"], confidence)

            responses.append(SentimentResponse(
                content_id=post.content_id,
                sentiment_score=sentiment_score,
                sentiment_label=sentiment_label,
                confidence=confidence
            ))
        return responses
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
def root():
    return {"message": "Reddit Sentiment Analysis API (Longformer 5-point) is running!"}