import gradio as gr
import fitz
import torch
import numpy as np
import os
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
from sentence_transformers import SentenceTransformer
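
# Note: "fitz" is provided by the PyMuPDF package. The quantized load path in
# load_models() additionally assumes that bitsandbytes and accelerate are
# installed; without them the plain fallback load is used.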

# Bind the Gradio server to all interfaces on port 7860
os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"
os.environ["GRADIO_SERVER_PORT"] = "7860"

# Document and retrieval configuration
PDF_PATH = "reg_2024.pdf"
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
QA_MODEL = "mrm8488/bert-italian-finedtuned-squadv1-it-alfa"
CHUNK_SIZE = 800      # words per chunk
OVERLAP = 150         # words shared between consecutive chunks
TOP_K_CHUNKS = 3      # chunks retrieved per question
MIN_SCORE = 0.1       # minimum QA confidence required to accept an answer
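
# With these values the chunking stride is CHUNK_SIZE - OVERLAP = 650 words,
# so consecutive chunks share 150 words of context.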


def load_models():
    """Load the models with memory-optimized settings."""
    try:
        # Prefer a memory-friendly load: automatic device placement, half
        # precision, and 4-bit quantization when a GPU is available.
        model = AutoModelForQuestionAnswering.from_pretrained(
            QA_MODEL,
            device_map="auto",
            load_in_4bit=torch.cuda.is_available(),
            torch_dtype=torch.float16
        )
    except ImportError:
        # bitsandbytes / accelerate not installed: fall back to a plain load
        model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL)

    tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
    embedder = SentenceTransformer(EMBEDDING_MODEL)

    return model, tokenizer, embedder


def process_pdf():
    """Process the PDF and create the embeddings."""
    text = ""
    with fitz.open(PDF_PATH) as doc:
        for page in doc:
            text += page.get_text().replace("\n", " ") + " "

    # Split the text into overlapping word-based chunks
    words = text.split()
    chunks = [
        ' '.join(words[i:i + CHUNK_SIZE])
        for i in range(0, len(words), CHUNK_SIZE - OVERLAP)
    ]

    # Encode every chunk once; the resulting tensor is reused for all queries
    embeddings = model_embed.encode(chunks, convert_to_tensor=True)

    return chunks, embeddings
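
# For a document of W words this produces ceil(W / 650) chunks and an
# embeddings tensor of shape (num_chunks, 768), the output dimension of this
# mpnet-based embedder.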


def semantic_search(query, chunks, embeddings):
    """Semantic search for the most relevant chunks."""
    query_embed = model_embed.encode(query, convert_to_tensor=True)
    # query_embed has shape (dim,), embeddings has shape (num_chunks, dim);
    # the query is broadcast, so scores holds one cosine similarity per chunk
    scores = torch.nn.functional.cosine_similarity(query_embed, embeddings)
    top_indices = torch.topk(scores, k=TOP_K_CHUNKS).indices.cpu().numpy()
    return [chunks[i] for i in top_indices]
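
# An equivalent retrieval step could use the sentence-transformers helper
# directly (a sketch, not wired in here):
#
#     from sentence_transformers import util
#     hits = util.semantic_search(query_embed, embeddings, top_k=TOP_K_CHUNKS)[0]
#     top_chunks = [chunks[hit["corpus_id"]] for hit in hits]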


def answer_question(question):
    """Complete question-answering pipeline for a single question."""
    try:
        # Retrieve the most relevant passages for the question
        relevant_chunks = semantic_search(question, doc_chunks, doc_embeddings)

        # Extractive QA pipeline (rebuilt on every call)
        qa_pipe = pipeline(
            "question-answering",
            model=model_qa,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1
        )

        best_answer = {"answer": "Nessuna risposta trovata", "score": 0}

        # Keep the highest-scoring answer across the retrieved chunks
        for chunk in relevant_chunks:
            try:
                result = qa_pipe(
                    question=question,
                    context=chunk,
                    max_answer_len=100,
                    handle_impossible_answer=True
                )

                if result["score"] > best_answer["score"]:
                    best_answer = result
            except Exception:
                continue

        if best_answer["score"] > MIN_SCORE:
            return best_answer["answer"]
        return "Nessuna risposta sufficientemente certa trovata nel documento"

    except Exception as e:
        return f"Errore durante l'elaborazione: {str(e)}"


# Load the models and index the document once at startup
print("Loading models...")
model_qa, tokenizer, model_embed = load_models()
print("Processing document...")
doc_chunks, doc_embeddings = process_pdf()
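
# Note: on the first run this downloads the models from the Hugging Face Hub
# and embeds the entire PDF before the interface becomes available.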


with gr.Blocks(title="AI Esperto Regolamento Calcio") as demo:
    gr.Markdown("# ⚽ Assistente Virtuale Regolamento FIFA")
    gr.Markdown("Poni domande sul regolamento ufficiale del calcio")

    with gr.Row():
        question = gr.Textbox(
            label="La tua domanda",
            placeholder="Es: Quando si assegna un calcio di rigore?",
            max_lines=2
        )
        answer = gr.Textbox(label="Risposta ufficiale", interactive=False)

    examples = gr.Examples(
        examples=[
            ["Quanti cambi sono permessi a partita?"],
            ["Cosa costituisce un fallo da cartellino rosso diretto?"],
            ["Quali sono le dimensioni minime del campo?"]
        ],
        inputs=[question],
        outputs=[answer],
        fn=answer_question,
        cache_examples=True
    )
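
    # cache_examples=True runs answer_question on every example when the app
    # starts and stores the results, so clicking an example shows the cached
    # answer instead of recomputing it.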

    question.submit(fn=answer_question, inputs=[question], outputs=[answer])


if __name__ == "__main__":
    demo.launch(show_error=True, share=True)