from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
from typing import List
# Initialize FastAPI
app = FastAPI()
# Load your fine-tuned Longformer model
# NOTE(review): the pipeline is built at import time, so process startup
# blocks until the model is fetched/loaded from the Hugging Face hub —
# confirm this eager load is intended (vs. a startup event hook).
sentiment_pipeline = pipeline(
    "text-classification",
    model="spacesedan/reddit-sentiment-analysis-longformer"
)
# Request models
class SentimentRequest(BaseModel):
    """Request body for scoring a single piece of content."""

    content_id: str  # caller-supplied identifier, echoed back in the response
    text: str  # raw text to run through the sentiment pipeline
class BatchSentimentRequest(BaseModel):
    """Request body wrapping multiple posts for batch sentiment scoring."""

    posts: List[SentimentRequest]  # one entry per post to analyze
# Response model
class SentimentResponse(BaseModel):
    """Scored sentiment for a single piece of content."""

    content_id: str  # echoes the request's content_id
    sentiment_score: float  # signed score in [-1.0, 1.0]
    sentiment_label: str  # normalized label, e.g. "neutral", "very positive"
    confidence: float  # model confidence, rounded to 3 decimal places
# Updated label-to-score mapping
LABEL_MAP = {
"very negative": -1.0,
"negative": -0.7,
"neutral": 0.0,
"positive": 0.7,
"very positive": 1.0
}
def normalize_prediction(label: str, confidence: float) -> (float, str):
label = label.lower()
score = LABEL_MAP.get(label, 0.0)
# Confidence-based fallback to neutral
if confidence < 0.6 and -0.7 < score < 0.7:
return 0.0, "neutral"
return score, label
@app.post("/analyze", response_model=SentimentResponse)
def analyze_sentiment(request: SentimentRequest):
    """Score a single piece of content and return its normalized sentiment.

    Any pipeline failure surfaces as an HTTP 500 carrying the error text.
    """
    try:
        prediction = sentiment_pipeline(request.text)[0]
        rounded_confidence = round(prediction["score"], 3)
        score, label = normalize_prediction(prediction["label"], rounded_confidence)
        return SentimentResponse(
            content_id=request.content_id,
            sentiment_score=score,
            sentiment_label=label,
            confidence=rounded_confidence,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/analyze_batch", response_model=List[SentimentResponse])
def analyze_sentiment_batch(request: BatchSentimentRequest):
    """Score a batch of posts and return one SentimentResponse per post.

    Improvement: all texts are passed to the pipeline in a single call so
    transformers can batch the forward pass, instead of running one
    inference per post as the original loop did. Results come back in
    input order, so they are zipped with the posts. An empty batch
    short-circuits to an empty list.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        texts = [post.text for post in request.posts]
        # A list input yields one result dict per text, in the same order.
        results = sentiment_pipeline(texts) if texts else []
        responses = []
        for post, result in zip(request.posts, results):
            confidence = round(result["score"], 3)
            sentiment_score, sentiment_label = normalize_prediction(result["label"], confidence)
            responses.append(SentimentResponse(
                content_id=post.content_id,
                sentiment_score=sentiment_score,
                sentiment_label=sentiment_label,
                confidence=confidence,
            ))
        return responses
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up."""
    message = "Reddit Sentiment Analysis API (Longformer 5-point) is running!"
    return {"message": message}