import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
import pandas as pd
import warnings
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

warnings.filterwarnings('ignore')

class ArabicProfanityTester:
    def __init__(self, model_name='Speccco/arabic_profanity_filter'):
        """Initialize the tester with model from Hugging Face Hub"""
        print(f"🔄 Loading model from Hugging Face Hub: {model_name}...")
        
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.model.eval()
            
            print("✅ Model loaded successfully from Hugging Face Hub!")
            print(f"📊 Model configuration:")
            print(f"   - Model type: {type(self.model).__name__}")
            print(f"   - Number of labels: {self.model.config.num_labels}")
            print(f"   - Max position embeddings: {self.model.config.max_position_embeddings}")
            
        except Exception as e:
            print(f"❌ Failed to load model from Hub: {e}")
            print("🔄 Falling back to base AraBERT model...")
            
            # Fallback to base model
            base_model = "aubmindlab/bert-base-arabertv02"
            self.tokenizer = AutoTokenizer.from_pretrained(base_model)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                base_model, 
                num_labels=2
            )
            self.model.eval()
            print("⚠️  Using base AraBERT model (not fine-tuned)")
        
    def preprocess_text(self, text):
        """Simple text preprocessing"""
        if pd.isna(text):
            return ""
        
        text = str(text)
        # Remove URLs, mentions, hashtags
        text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
        text = re.sub(r'@\w+|#\w+', '', text)
        # Remove emoji and other Unicode symbols
        emoji_pattern = re.compile("["
            u"\U0001F600-\U0001F64F"  # emoticons
            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
            u"\U0001F680-\U0001F6FF"  # transport & map symbols
            u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
            u"\U00002702-\U000027B0"  # dingbats
            u"\U000024C2-\U0001F251"  # enclosed characters
            u"\U0001F900-\U0001F9FF"  # supplemental symbols
            u"\U0001FA00-\U0001FAFF"  # extended symbols
            u"\u2600-\u26FF"          # miscellaneous symbols
            u"\u2700-\u27BF"          # dingbats
            u"\uFE00-\uFE0F"          # variation selectors
            u"\u200D"                 # zero width joiner
            "]+", flags=re.UNICODE)
        text = emoji_pattern.sub(r'', text)
        # Remove English (Latin) letters
        text = re.sub(r'[a-zA-Z]', '', text)
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
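        # Illustrative example (assumed input): "شاهد http://t.co/x 😀 @user"
        # would come out as "شاهد" after the steps above.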
        
        return text
    
    def check_bad_words(self, text):
        """Check if text contains explicit bad Arabic/Egyptian words"""
        bad_words = [
            'شرموطة', 'خرا', 'زفت', 'أمك', 'يلعن دينك', 'متناك', 
            'منيك', 'نايك', 'طيز', 'عرص', 'قواد', 'وسخة', 'كسك', 
            'يا دين أمي', 'ابن وسخة'
        ]
        
        text_lower = text.lower()
        found_words = []
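        # Note: the loop below uses plain substring matching, so longer words
        # that merely contain one of these strings will also be flagged.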
        
        for bad_word in bad_words:
            if bad_word.lower() in text_lower:
                found_words.append(bad_word)
        
        return len(found_words) > 0, found_words
    
    def predict(self, text):
        """Predict whether the text is offensive, applying the explicit bad-word override"""
        # Preprocess text
        processed_text = self.preprocess_text(text)
        
        # Check for explicit bad words first
        has_bad_words, found_bad_words = self.check_bad_words(text)
        
        # Tokenize
        inputs = self.tokenizer(
            processed_text, 
            return_tensors='pt', 
            truncation=True, 
            max_length=256,
            padding=True
        )
        
        # Get model prediction
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=-1)
            model_predicted_class = torch.argmax(probabilities, dim=-1).item()
            model_confidence = probabilities[0][model_predicted_class].item()
        
        # Final decision: bad words override model prediction
        if has_bad_words:
            final_prediction = "Bad"
            final_class = 1  # Offensive
            override_reason = f"Contains explicit bad words: {', '.join(found_bad_words)}"
        else:
            final_prediction = "Good" if model_predicted_class == 0 else "Bad"
            final_class = model_predicted_class
            override_reason = None
        
        # Prepare result
        result = {
            'original_text': text,
            'processed_text': processed_text,
            'model_prediction': 'Offensive' if model_predicted_class == 1 else 'Non-Offensive',
            'model_confidence': model_confidence,
            'final_prediction': final_prediction,
            'final_class': final_class,
            'has_bad_words': has_bad_words,
            'found_bad_words': found_bad_words,
            'override_reason': override_reason,
            'probabilities': {
                'non_offensive': probabilities[0][0].item(),
                'offensive': probabilities[0][1].item()
            }
        }
        
        return result
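
# A minimal usage sketch for the tester class on its own (assumes the Hub
# model downloads successfully; run from a Python shell):
#
#   tester = ArabicProfanityTester()
#   result = tester.predict("النص المراد فحصه")
#   print(result['final_prediction'], result['probabilities'])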

class ProfanityRequest(BaseModel):
    text: str

class BatchProfanityRequest(BaseModel):
    texts: list[str]

app = FastAPI(
    title="Arabic Profanity Filter API",
    description="An API to detect profanity in Arabic text using a fine-tuned AraBERT model with rule-based override.",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Global tester instance, created at startup
tester = None

@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup"""
    global tester
    try:
        tester = ArabicProfanityTester()
        print("🚀 Arabic Profanity Filter API is ready!")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        raise

@app.get("/", tags=["General"])
def read_root():
    return {
        "message": "Welcome to the Arabic Profanity Filter API",
        "description": "Detects profanity in Arabic text using AraBERT model with rule-based override",
        "endpoints": {
            "predict": "/predict - Single text prediction",
            "batch": "/batch - Batch text prediction",
            "health": "/health - Health check",
            "docs": "/docs - API documentation"
        }
    }

@app.get("/health", tags=["General"])
def health_check():
    """Health check endpoint"""
    if tester is None:
        return {"status": "unhealthy", "message": "Model not loaded"}
    return {"status": "healthy", "message": "API is running"}

@app.post("/predict", tags=["Prediction"])
async def predict_profanity(request: ProfanityRequest):
    """
    Predicts if the given Arabic text contains profanity.
    
    - **text**: The Arabic text to analyze.
    
    Returns:
    - original_text: The input text
    - processed_text: Text after preprocessing
    - model_prediction: Model's prediction (Offensive/Non-Offensive)
    - model_confidence: Model's confidence score
    - final_prediction: Final result (Good/Bad) after rule-based override
    - final_class: Final class index (0 = Good, 1 = Bad)
    - has_bad_words: Whether explicit bad words were found
    - found_bad_words: List of bad words found
    - override_reason: Reason for the override, if one was applied
    - probabilities: Detailed probability scores
    """
    if tester is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    try:
        result = tester.predict(request.text)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")

@app.post("/batch", tags=["Prediction"])
async def predict_batch_profanity(request: BatchProfanityRequest):
    """
    Predicts profanity for multiple Arabic texts.
    
    - **texts**: List of Arabic texts to analyze.
    
    Returns list of prediction results for each text.
    """
    if tester is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    try:
        results = []
        for text in request.texts:
            result = tester.predict(text)
            results.append(result)
        
        return {
            "predictions": results,
            "summary": {
                "total": len(results),
                "bad_count": sum(1 for r in results if r['final_prediction'] == 'Bad'),
                "good_count": sum(1 for r in results if r['final_prediction'] == 'Good'),
                "explicit_bad_words_count": sum(1 for r in results if r['has_bad_words'])
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
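
# Example requests once the server is running (assumes the default port 7860;
# adjust host and port to match your deployment):
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "نص للفحص"}'
#
#   curl -X POST http://localhost:7860/batch \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["النص الأول", "النص الثاني"]}'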