RAG_AIEXP_0 / app.py
MrSimple01's picture
Update app.py
242ac42 verified
import gradio as gr
from huggingface_hub import hf_hub_download, list_repo_files
import faiss
import pandas as pd
import os
import json
from llama_index.core import Document, VectorStoreIndex, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
from sentence_transformers import CrossEncoder
from llama_index.core.retrievers import QueryFusionRetriever
import time
import sys
import logging
from config import *
REPO_ID = "MrSimple01/AIEXP_RAG_FILES"
faiss_index_filename = "cleaned_faiss_index.index"
chunks_filename = "processed_chunks.csv"
table_data_dir = "Табличные данные_JSON"
image_data_dir = "Изображения"
download_dir = "rag_files"
HF_TOKEN = os.getenv('HF_TOKEN')
# Global variables
query_engine = None
chunks_df = None
reranker = None
vector_index = None
current_model = DEFAULT_MODEL
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def log_message(message):
logger.info(message)
print(message, flush=True)
sys.stdout.flush()
def get_llm_model(model_name):
"""Get LLM model instance based on model name"""
try:
model_config = AVAILABLE_MODELS.get(model_name)
if not model_config:
log_message(f"Модель {model_name} не найдена, использую модель по умолчанию")
model_config = AVAILABLE_MODELS[DEFAULT_MODEL]
if not model_config.get("api_key"):
raise Exception(f"API ключ не найден для модели {model_name}")
if model_config["provider"] == "google":
return GoogleGenAI(
model=model_config["model_name"],
api_key=model_config["api_key"]
)
elif model_config["provider"] == "openai":
return OpenAI(
model=model_config["model_name"],
api_key=model_config["api_key"]
)
else:
raise Exception(f"Неподдерживаемый провайдер: {model_config['provider']}")
except Exception as e:
log_message(f"Ошибка создания модели {model_name}: {str(e)}")
# Fallback to default Google model
return GoogleGenAI(model="gemini-2.0-flash", api_key=GOOGLE_API_KEY)
def switch_model(model_name):
"""Switch to a different LLM model"""
global query_engine, current_model
try:
log_message(f"Переключение на модель: {model_name}")
# Create new LLM instance
new_llm = get_llm_model(model_name)
Settings.llm = new_llm
# Recreate query engine with new model
if vector_index is not None:
recreate_query_engine()
current_model = model_name
log_message(f"Модель успешно переключена на: {model_name}")
return f"✅ Модель переключена на: {model_name}"
else:
return "❌ Ошибка: система не инициализирована"
except Exception as e:
error_msg = f"Ошибка переключения модели: {str(e)}"
log_message(error_msg)
return f"❌ {error_msg}"
def recreate_query_engine():
"""Recreate query engine with current settings"""
global query_engine
try:
# Create BM25 retriever
bm25_retriever = BM25Retriever.from_defaults(
docstore=vector_index.docstore,
similarity_top_k=15
)
# Create vector retriever
vector_retriever = VectorIndexRetriever(
index=vector_index,
similarity_top_k=20,
similarity_cutoff=0.5
)
# Create hybrid retriever
hybrid_retriever = QueryFusionRetriever(
[vector_retriever, bm25_retriever],
similarity_top_k=30,
num_queries=1
)
# Create response synthesizer
custom_prompt_template = PromptTemplate(CUSTOM_PROMPT)
response_synthesizer = get_response_synthesizer(
response_mode=ResponseMode.TREE_SUMMARIZE,
text_qa_template=custom_prompt_template
)
# Create new query engine
query_engine = RetrieverQueryEngine(
retriever=hybrid_retriever,
response_synthesizer=response_synthesizer
)
log_message("Query engine успешно пересоздан")
except Exception as e:
log_message(f"Ошибка пересоздания query engine: {str(e)}")
raise
def table_to_document(table_data, document_id=None):
content = ""
if isinstance(table_data, dict):
doc_id = document_id or table_data.get('document_id', table_data.get('document', 'Неизвестно'))
table_num = table_data.get('table_number', 'Неизвестно')
table_title = table_data.get('table_title', 'Неизвестно')
section = table_data.get('section', 'Неизвестно')
content += f"Таблица: {table_num}\n"
content += f"Название: {table_title}\n"
content += f"Документ: {doc_id}\n"
content += f"Раздел: {section}\n"
if 'data' in table_data and isinstance(table_data['data'], list):
for row in table_data['data']:
if isinstance(row, dict):
row_text = " | ".join([f"{k}: {v}" for k, v in row.items()])
content += f"{row_text}\n"
return Document(
text=content,
metadata={
"type": "table",
"table_number": table_data.get('table_number', 'unknown'),
"table_title": table_data.get('table_title', 'unknown'),
"document_id": doc_id or table_data.get('document_id', table_data.get('document', 'unknown')),
"section": table_data.get('section', 'unknown')
}
)
def download_table_data():
log_message("Начинаю загрузку табличных данных")
table_files = []
try:
files = list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN)
for file in files:
if file.startswith(table_data_dir) and file.endswith('.json'):
table_files.append(file)
log_message(f"Найдено {len(table_files)} JSON файлов с таблицами")
table_documents = []
for file_path in table_files:
try:
log_message(f"Обрабатываю файл: {file_path}")
local_path = hf_hub_download(
repo_id=REPO_ID,
filename=file_path,
local_dir='',
repo_type="dataset",
token=HF_TOKEN
)
with open(local_path, 'r', encoding='utf-8') as f:
table_data = json.load(f)
if isinstance(table_data, dict):
document_id = table_data.get('document', 'unknown')
if 'sheets' in table_data:
for sheet in table_data['sheets']:
sheet['document'] = document_id
doc = table_to_document(sheet, document_id)
table_documents.append(doc)
else:
doc = table_to_document(table_data, document_id)
table_documents.append(doc)
elif isinstance(table_data, list):
for table_json in table_data:
doc = table_to_document(table_json)
table_documents.append(doc)
except Exception as e:
log_message(f"Ошибка обработки файла {file_path}: {str(e)}")
continue
log_message(f"Создано {len(table_documents)} документов из таблиц")
return table_documents
except Exception as e:
log_message(f"Ошибка загрузки табличных данных: {str(e)}")
return []
def download_image_data():
log_message("Начинаю загрузку данных изображений")
image_files = []
try:
files = list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN)
for file in files:
if file.startswith(image_data_dir) and file.endswith('.csv'):
image_files.append(file)
log_message(f"Найдено {len(image_files)} CSV файлов с изображениями")
image_documents = []
for file_path in image_files:
try:
log_message(f"Обрабатываю файл изображений: {file_path}")
local_path = hf_hub_download(
repo_id=REPO_ID,
filename=file_path,
local_dir='',
repo_type="dataset",
token=HF_TOKEN
)
df = pd.read_csv(local_path)
log_message(f"Загружено {len(df)} записей изображений из файла {file_path}")
for _, row in df.iterrows():
content = f"Изображение: {row.get('№ Изображения', 'Неизвестно')}\n"
content += f"Название: {row.get('Название изображения', 'Неизвестно')}\n"
content += f"Описание: {row.get('Описание изображение', 'Неизвестно')}\n"
content += f"Документ: {row.get('Обозначение документа', 'Неизвестно')}\n"
content += f"Раздел: {row.get('Раздел документа', 'Неизвестно')}\n"
content += f"Файл: {row.get('Файл изображения', 'Неизвестно')}\n"
doc = Document(
text=content,
metadata={
"type": "image",
"image_number": row.get('№ Изображения', 'unknown'),
"document_id": row.get('Обозначение документа', 'unknown'),
"file_path": row.get('Файл изображения', 'unknown'),
"section": row.get('Раздел документа', 'unknown')
}
)
image_documents.append(doc)
except Exception as e:
log_message(f"Ошибка обработки файла {file_path}: {str(e)}")
continue
log_message(f"Создано {len(image_documents)} документов из изображений")
return image_documents
except Exception as e:
log_message(f"Ошибка загрузки данных изображений: {str(e)}")
return []
def initialize_models():
global query_engine, chunks_df, reranker, vector_index, current_model
try:
log_message("Инициализация системы")
os.makedirs(download_dir, exist_ok=True)
log_message("Загружаю основные файлы")
chunks_csv_path = hf_hub_download(
repo_id=REPO_ID,
filename=chunks_filename,
local_dir=download_dir,
repo_type="dataset",
token=HF_TOKEN
)
log_message("Загружаю данные чанков")
chunks_df = pd.read_csv(chunks_csv_path)
log_message(f"Загружено {len(chunks_df)} чанков")
log_message("Инициализирую модели")
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
llm = get_llm_model(current_model)
log_message("Инициализирую переранкер")
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
Settings.embed_model = embed_model
Settings.llm = llm
text_column = None
for col in chunks_df.columns:
if 'text' in col.lower() or 'content' in col.lower() or 'chunk' in col.lower():
text_column = col
break
if text_column is None:
text_column = chunks_df.columns[0]
log_message(f"Использую колонку: {text_column}")
log_message("Создаю документы из чанков")
documents = []
for i, (_, row) in enumerate(chunks_df.iterrows()):
doc = Document(
text=str(row[text_column]),
metadata={
"chunk_id": row.get('chunk_id', i),
"document_id": row.get('document_id', 'unknown'),
"type": "text"
}
)
documents.append(doc)
log_message(f"Создано {len(documents)} текстовых документов")
log_message("Добавляю табличные данные")
table_documents = download_table_data()
documents.extend(table_documents)
log_message("Добавляю данные изображений")
image_documents = download_image_data()
documents.extend(image_documents)
log_message(f"Всего документов: {len(documents)}")
log_message("Строю векторный индекс")
vector_index = VectorStoreIndex.from_documents(documents)
# Create query engine
recreate_query_engine()
log_message(f"Система успешно инициализирована с моделью: {current_model}")
return True
except Exception as e:
log_message(f"Ошибка инициализации: {str(e)}")
return False
def rerank_nodes(query, nodes, top_k=10):
if not nodes or not reranker:
return nodes[:top_k]
try:
log_message(f"Переранжирую {len(nodes)} узлов")
pairs = []
for node in nodes:
pairs.append([query, node.text])
scores = reranker.predict(pairs)
scored_nodes = list(zip(nodes, scores))
scored_nodes.sort(key=lambda x: x[1], reverse=True)
reranked_nodes = [node for node, score in scored_nodes[:top_k]]
log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов")
return reranked_nodes
except Exception as e:
log_message(f"Ошибка переранжировки: {str(e)}")
return nodes[:top_k]
def answer_question(question):
global query_engine, chunks_df, current_model
if query_engine is None:
return "<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Система не инициализирована</div>", ""
try:
log_message(f"Получен вопрос: {question}")
log_message(f"Используется модель: {current_model}")
start_time = time.time()
log_message("Извлекаю релевантные узлы")
retrieved_nodes = query_engine.retriever.retrieve(question)
log_message(f"Извлечено {len(retrieved_nodes)} узлов")
log_message("Применяю переранжировку")
reranked_nodes = rerank_nodes(question, retrieved_nodes, top_k=10)
log_message(f"Отправляю запрос в LLM с {len(reranked_nodes)} узлами")
response = query_engine.query(question)
end_time = time.time()
processing_time = end_time - start_time
log_message(f"Обработка завершена за {processing_time:.2f} секунд")
sources_html = generate_sources_html(reranked_nodes)
answer_with_time = f"""<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; margin-bottom: 10px;'>
<h3 style='color: #63b3ed; margin-top: 0;'>Ответ (Модель: {current_model}):</h3>
<div style='line-height: 1.6; font-size: 16px;'>{response.response}</div>
<div style='margin-top: 15px; padding-top: 10px; border-top: 1px solid #4a5568; font-size: 14px; color: #a0aec0;'>
Время обработки: {processing_time:.2f} секунд
</div>
</div>"""
return answer_with_time, sources_html
except Exception as e:
log_message(f"Ошибка обработки вопроса: {str(e)}")
error_msg = f"<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Ошибка обработки вопроса: {str(e)}</div>"
return error_msg, ""
def generate_sources_html(nodes):
html = "<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; max-height: 400px; overflow-y: auto;'>"
html += "<h3 style='color: #63b3ed; margin-top: 0;'>Источники:</h3>"
for i, node in enumerate(nodes):
metadata = node.metadata if hasattr(node, 'metadata') else {}
doc_type = metadata.get('type', 'text')
doc_id = metadata.get('document_id', 'unknown')
html += f"<div style='margin-bottom: 15px; padding: 15px; border: 1px solid #4a5568; border-radius: 8px; background-color: #1a202c;'>"
if doc_type == 'text':
html += f"<h4 style='margin: 0 0 10px 0; color: #63b3ed;'>📄 {doc_id}</h4>"
elif doc_type == 'table':
table_num = metadata.get('table_number', 'unknown')
if table_num and table_num != 'unknown':
if not table_num.startswith('№'):
table_num = f"№{table_num}"
html += f"<h4 style='margin: 0 0 10px 0; color: #68d391;'>📊 Таблица {table_num} - {doc_id}</h4>"
else:
html += f"<h4 style='margin: 0 0 10px 0; color: #68d391;'>📊 Таблица - {doc_id}</h4>"
elif doc_type == 'image':
image_num = metadata.get('image_number', 'unknown')
section = metadata.get('section', '')
if image_num and image_num != 'unknown':
if not str(image_num).startswith('№'):
image_num = f"№{image_num}"
html += f"<h4 style='margin: 0 0 10px 0; color: #fbb6ce;'>🖼️ Изображение {image_num} - {doc_id} ({section})</h4>"
else:
html += f"<h4 style='margin: 0 0 10px 0; color: #fbb6ce;'>🖼️ Изображение - {doc_id} ({section})</h4>"
if chunks_df is not None and 'file_link' in chunks_df.columns and doc_type == 'text':
doc_rows = chunks_df[chunks_df['document_id'] == doc_id]
if not doc_rows.empty:
file_link = doc_rows.iloc[0]['file_link']
html += f"<a href='{file_link}' target='_blank' style='color: #68d391; text-decoration: none; font-size: 14px; display: inline-block; margin-top: 10px;'>🔗 Ссылка на документ</a><br>"
html += "</div>"
html += "</div>"
return html
def create_demo_interface():
with gr.Blocks(title="AIEXP - AI Expert для нормативной документации", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# AIEXP - Artificial Intelligence Expert
## Инструмент для работы с нормативной документацией
""")
with gr.Tab("🏠 Поиск по нормативным документам"):
gr.Markdown("### Задайте вопрос по нормативной документации")
# Model selection section
with gr.Row():
with gr.Column(scale=2):
model_dropdown = gr.Dropdown(
choices=list(AVAILABLE_MODELS.keys()),
value=current_model,
label="🤖 Выберите языковую модель",
info="Выберите модель для генерации ответов"
)
with gr.Column(scale=1):
switch_btn = gr.Button("🔄 Переключить модель", variant="secondary")
model_status = gr.Textbox(
value=f"Текущая модель: {current_model}",
label="Статус модели",
interactive=False
)
with gr.Row():
with gr.Column(scale=3):
question_input = gr.Textbox(
label="Ваш вопрос к базе знаний",
placeholder="Введите вопрос по нормативным документам...",
lines=3
)
ask_btn = gr.Button("🔍 Найти ответ", variant="primary", size="lg")
gr.Examples(
examples=[
"О чем этот рисунок: ГОСТ Р 50.04.07-2022 Приложение Л. Л.1.5 Рисунок Л.2",
"Л.9 Формула в ГОСТ Р 50.04.07 - 2022 что и о чем там?",
"Какой стандарт устанавливает порядок признания протоколов испытаний продукции в области использования атомной энергии?",
"Кто несет ответственность за организацию и проведение признания протоколов испытаний продукции?",
"В каких случаях могут быть признаны протоколы испытаний, проведенные лабораториями?",
],
inputs=question_input
)
with gr.Row():
with gr.Column(scale=2):
answer_output = gr.HTML(
label="",
value=f"<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появится ответ на ваш вопрос...<br><small>Текущая модель: {current_model}</small></div>",
)
with gr.Column(scale=1):
sources_output = gr.HTML(
label="",
value="<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появятся источники...</div>",
)
# Event handlers
def update_model_status(new_model):
result = switch_model(new_model)
return result
switch_btn.click(
fn=update_model_status,
inputs=[model_dropdown],
outputs=[model_status]
)
ask_btn.click(
fn=answer_question,
inputs=[question_input],
outputs=[answer_output, sources_output]
)
question_input.submit(
fn=answer_question,
inputs=[question_input],
outputs=[answer_output, sources_output]
)
return demo
if __name__ == "__main__":
log_message("Запуск AIEXP - AI Expert для нормативной документации")
if initialize_models():
log_message("Запуск веб-интерфейса")
demo = create_demo_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
debug=False
)
else:
log_message("Невозможно запустить приложение из-за ошибки инициализации")
sys.exit(1)