"""Formality classifier HTTP API.

Serves the ``s-nlp/roberta-base-formality-ranker`` sequence-classification
model behind a small FastAPI app whose ``/predict`` endpoint reports how
formal a piece of text is.
"""

import logging
import os
import warnings

import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import spaces
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Suppress all warnings.
warnings.filterwarnings("ignore")

logger = logging.getLogger(__name__)

MODEL_NAME = "s-nlp/roberta-base-formality-ranker"
# Directory used for downloaded model/tokenizer files.
CACHE_DIR = "/tmp"

# transformers reads TRANSFORMERS_CACHE when it is first imported, so setting
# it here (after the import above) does not by itself redirect downloads; the
# explicit ``cache_dir`` passed to ``from_pretrained`` below is what does.
# The variable is still exported for anything that reads it later.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR


# Initialize GPU for Hugging Face Spaces.
# NOTE(review): the ``@spaces.GPU`` (ZeroGPU) decorator is normally provided by
# the standalone ``spaces`` package (``import spaces``); confirm that
# ``from huggingface_hub import spaces`` resolves in the deployment environment.
@spaces.GPU
def init_gpu():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load the tokenizer and model once at import time; every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, cache_dir=CACHE_DIR
)

# Move the model to the selected device.
device = init_gpu()
model = model.to(device)

app = FastAPI(title="Formality Classifier API")


class TextInput(BaseModel):
    """Request body for ``/predict``."""

    text: str  # the text to classify


def calculate_formality_percentages(score):
    """Split a formality probability into whole formal/informal percentages.

    Args:
        score: probability that the text is formal (expected in [0, 1]).

    Returns:
        ``(formal_percent, informal_percent)`` as ints that sum to 100.
    """
    # round() rather than int(): int() truncates, so float error such as
    # 0.29 * 100 == 28.999999999999996 would report 28% instead of 29%.
    # Clamp so an out-of-range score can never produce negative percentages.
    formal_percent = min(100, max(0, round(score * 100)))
    informal_percent = 100 - formal_percent
    return formal_percent, informal_percent


@app.get("/")
async def home():
    """Root endpoint: confirm the API is running."""
    return {"message": "Formality Classifier API is running! Use /predict to classify text."}


@app.post("/predict")
def predict_formality(input_data: TextInput):
    """Classify how formal ``input_data.text`` is.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs the
    blocking tokenizer/model call in its threadpool instead of stalling the
    event loop for all other requests.

    Returns:
        dict with ``formality_score`` (rounded to 3 decimals),
        ``formal_percent``, ``informal_percent`` and a human-readable
        ``classification`` sentence.

    Raises:
        HTTPException: status 500 carrying the error message if tokenization,
            inference or post-processing fails.
    """
    try:
        # Tokenize input and move the tensors to the model's device.
        encoding = tokenizer(
            input_data.text, return_tensors="pt", truncation=True, padding=True
        )
        encoding = {k: v.to(device) for k, v in encoding.items()}

        # Predict formality score (inference only, no autograd bookkeeping).
        with torch.no_grad():
            logits = model(**encoding).logits
        # Probability of class index 1 for the single input text.
        # NOTE(review): assumes label 1 is "formal" for this checkpoint —
        # confirm against model.config.id2label.
        score = logits.softmax(dim=1)[:, 1].item()

        formal_percent, informal_percent = calculate_formality_percentages(score)

        return {
            "formality_score": round(score, 3),
            "formal_percent": formal_percent,
            "informal_percent": informal_percent,
            "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal.",
        }
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server logs,
        # then report the failure to the client as a 500.
        logger.exception("Formality prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)