""" | |
File managing the text analysis functionality for the application. | |
""" | |
import torch
import torch.nn.functional as F
from datasets import Dataset

from models import load_model, device


def create_chunked_dataset(dataset, tokenizer, max_length=512, stride=256):
    """
    Build a new dataset of fixed-length chunks from the original dataset.

    Long texts are split into overlapping windows of `max_length` tokens,
    advancing `stride` tokens at a time; each chunk keeps the label of the
    example it came from.
    """
    all_chunks = {
        'input_ids': [],
        'attention_mask': [],
        'chunk_id': [],
        'example_id': [],
        'labels': []
    }
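    # Walk the dataset once: short texts become a single padded chunk, long
    # texts are split into overlapping windows (see the else branch below).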
    for idx, example in enumerate(dataset):
        text = example['text']
        label = int(example['label'])
        tokenized = tokenizer(text, truncation=False, padding=False)
        input_ids = tokenized['input_ids']
        if len(input_ids) <= max_length:
            tokenized = tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=max_length
            )
            all_chunks['input_ids'].append(tokenized['input_ids'])
            all_chunks['attention_mask'].append(tokenized['attention_mask'])
            all_chunks['chunk_id'].append(0)
            all_chunks['example_id'].append(idx)
            all_chunks['labels'].append(label)
        else:
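            # Slide a window of max_length tokens, advancing `stride` tokens
            # at a time; trailing chunks shorter than half a window are dropped.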
            chunk_id = 0
            for i in range(0, len(input_ids), stride):
                chunk = input_ids[i:i + max_length]
                if len(chunk) < max_length // 2:
                    continue
                # Pad the chunk to max_length
                attention_mask = [1] * len(chunk)
                if len(chunk) < max_length:
                    padding_length = max_length - len(chunk)
                    chunk = chunk + [tokenizer.pad_token_id] * padding_length
                    attention_mask = attention_mask + [0] * padding_length
                all_chunks['input_ids'].append(chunk)
                all_chunks['attention_mask'].append(attention_mask)
                all_chunks['chunk_id'].append(chunk_id)
                all_chunks['example_id'].append(idx)
                all_chunks['labels'].append(label)
                chunk_id += 1
    return Dataset.from_dict(all_chunks)


def analyze_text(text: str, model_name: str):
    """
    Analyze the text for bias or neutrality using a selected classification model.

    Args:
        text (str): Text to analyze.
        model_name (str): Name of the model to use for analysis.

    Returns:
        tuple (confidence_map, message): Confidence map and analysis message.
    """
    if not text.strip():
        return {"Empty text": 1.0}, "Please enter text to analyze."
    try:
        print("[Checkpoint] Starting classification...")
        model, tokenizer = load_model(model_name)
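        # Wrap the input in a one-row dataset so it can reuse the chunking
        # helper; the label value is a dummy and is ignored at inference time.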
        mini_dataset = Dataset.from_dict({"text": [text], "label": [0]})
        chunked_dataset = create_chunked_dataset(mini_dataset, tokenizer)
        print("[Checkpoint] Tokenization complete. Running model...")
        model.eval()
        model.to(device)
        all_logits = []
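        # Score each chunk independently, with gradients disabled for inference.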
        for i in range(len(chunked_dataset)):
            chunk = chunked_dataset[i]
            inputs = {
                'input_ids': torch.tensor([chunk['input_ids']]).to(device),
                'attention_mask': torch.tensor([chunk['attention_mask']]).to(device),
            }
            with torch.no_grad():
                outputs = model(**inputs)
            all_logits.append(outputs.logits[0].cpu())
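        # Aggregate chunk predictions by averaging their logits (mean pooling).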
        stacked_logits = torch.stack(all_logits)
        averaged_logits = torch.mean(stacked_logits, dim=0)
        probs = F.softmax(averaged_logits, dim=0)
        predicted_class = torch.argmax(averaged_logits).item()
        confidence = probs[predicted_class].item()
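        # Label convention used by this app: index 1 = neutral, index 0 = biased.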
status = "neutral" if predicted_class == 1 else "biased" | |
tag = "✅" if predicted_class == 1 else "⚠️" | |
message = f"{tag} The text is classified as {status} with a confidence of {confidence:.2%}." | |
confidence_map = {"Neutral": probs[1].item(), "Biased": probs[0].item()} | |
print(f"[Checkpoint] Classification complete. Predicted answer: {status}") | |
return confidence_map, message | |
    except ValueError as e:
        return {"Error": 1.0}, f"Configuration error: {str(e)}"
    except RuntimeError as e:
        return {"Error": 1.0}, f"Model error: {str(e)}"
    except Exception as e:
        return {"Error": 1.0}, f"Error analyzing text: {str(e)}"