Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import hf_hub_download, list_repo_files | |
import faiss | |
import pandas as pd | |
import os | |
import json | |
from llama_index.core import Document, VectorStoreIndex, Settings | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
from llama_index.llms.google_genai import GoogleGenAI | |
from llama_index.llms.openai import OpenAI | |
from llama_index.core.query_engine import RetrieverQueryEngine | |
from llama_index.core.retrievers import VectorIndexRetriever | |
from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode | |
from llama_index.core.prompts import PromptTemplate | |
from llama_index.retrievers.bm25 import BM25Retriever | |
from sentence_transformers import CrossEncoder | |
from llama_index.core.retrievers import QueryFusionRetriever | |
import time | |
import sys | |
import logging | |
from config import * | |
REPO_ID = "MrSimple01/AIEXP_RAG_FILES" | |
faiss_index_filename = "cleaned_faiss_index.index" | |
chunks_filename = "processed_chunks.csv" | |
table_data_dir = "Табличные данные_JSON" | |
image_data_dir = "Изображения" | |
download_dir = "rag_files" | |
HF_TOKEN = os.getenv('HF_TOKEN') | |
# Global variables | |
query_engine = None | |
chunks_df = None | |
reranker = None | |
vector_index = None | |
current_model = DEFAULT_MODEL | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
def log_message(message): | |
logger.info(message) | |
print(message, flush=True) | |
sys.stdout.flush() | |
def get_llm_model(model_name): | |
"""Get LLM model instance based on model name""" | |
try: | |
model_config = AVAILABLE_MODELS.get(model_name) | |
if not model_config: | |
log_message(f"Модель {model_name} не найдена, использую модель по умолчанию") | |
model_config = AVAILABLE_MODELS[DEFAULT_MODEL] | |
if not model_config.get("api_key"): | |
raise Exception(f"API ключ не найден для модели {model_name}") | |
if model_config["provider"] == "google": | |
return GoogleGenAI( | |
model=model_config["model_name"], | |
api_key=model_config["api_key"] | |
) | |
elif model_config["provider"] == "openai": | |
return OpenAI( | |
model=model_config["model_name"], | |
api_key=model_config["api_key"] | |
) | |
else: | |
raise Exception(f"Неподдерживаемый провайдер: {model_config['provider']}") | |
except Exception as e: | |
log_message(f"Ошибка создания модели {model_name}: {str(e)}") | |
# Fallback to default Google model | |
return GoogleGenAI(model="gemini-2.0-flash", api_key=GOOGLE_API_KEY) | |
def switch_model(model_name): | |
"""Switch to a different LLM model""" | |
global query_engine, current_model | |
try: | |
log_message(f"Переключение на модель: {model_name}") | |
# Create new LLM instance | |
new_llm = get_llm_model(model_name) | |
Settings.llm = new_llm | |
# Recreate query engine with new model | |
if vector_index is not None: | |
recreate_query_engine() | |
current_model = model_name | |
log_message(f"Модель успешно переключена на: {model_name}") | |
return f"✅ Модель переключена на: {model_name}" | |
else: | |
return "❌ Ошибка: система не инициализирована" | |
except Exception as e: | |
error_msg = f"Ошибка переключения модели: {str(e)}" | |
log_message(error_msg) | |
return f"❌ {error_msg}" | |
def recreate_query_engine(): | |
"""Recreate query engine with current settings""" | |
global query_engine | |
try: | |
# Create BM25 retriever | |
bm25_retriever = BM25Retriever.from_defaults( | |
docstore=vector_index.docstore, | |
similarity_top_k=15 | |
) | |
# Create vector retriever | |
vector_retriever = VectorIndexRetriever( | |
index=vector_index, | |
similarity_top_k=20, | |
similarity_cutoff=0.5 | |
) | |
# Create hybrid retriever | |
hybrid_retriever = QueryFusionRetriever( | |
[vector_retriever, bm25_retriever], | |
similarity_top_k=30, | |
num_queries=1 | |
) | |
# Create response synthesizer | |
custom_prompt_template = PromptTemplate(CUSTOM_PROMPT) | |
response_synthesizer = get_response_synthesizer( | |
response_mode=ResponseMode.TREE_SUMMARIZE, | |
text_qa_template=custom_prompt_template | |
) | |
# Create new query engine | |
query_engine = RetrieverQueryEngine( | |
retriever=hybrid_retriever, | |
response_synthesizer=response_synthesizer | |
) | |
log_message("Query engine успешно пересоздан") | |
except Exception as e: | |
log_message(f"Ошибка пересоздания query engine: {str(e)}") | |
raise | |
def table_to_document(table_data, document_id=None): | |
content = "" | |
if isinstance(table_data, dict): | |
doc_id = document_id or table_data.get('document_id', table_data.get('document', 'Неизвестно')) | |
table_num = table_data.get('table_number', 'Неизвестно') | |
table_title = table_data.get('table_title', 'Неизвестно') | |
section = table_data.get('section', 'Неизвестно') | |
content += f"Таблица: {table_num}\n" | |
content += f"Название: {table_title}\n" | |
content += f"Документ: {doc_id}\n" | |
content += f"Раздел: {section}\n" | |
if 'data' in table_data and isinstance(table_data['data'], list): | |
for row in table_data['data']: | |
if isinstance(row, dict): | |
row_text = " | ".join([f"{k}: {v}" for k, v in row.items()]) | |
content += f"{row_text}\n" | |
return Document( | |
text=content, | |
metadata={ | |
"type": "table", | |
"table_number": table_data.get('table_number', 'unknown'), | |
"table_title": table_data.get('table_title', 'unknown'), | |
"document_id": doc_id or table_data.get('document_id', table_data.get('document', 'unknown')), | |
"section": table_data.get('section', 'unknown') | |
} | |
) | |
def download_table_data(): | |
log_message("Начинаю загрузку табличных данных") | |
table_files = [] | |
try: | |
files = list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN) | |
for file in files: | |
if file.startswith(table_data_dir) and file.endswith('.json'): | |
table_files.append(file) | |
log_message(f"Найдено {len(table_files)} JSON файлов с таблицами") | |
table_documents = [] | |
for file_path in table_files: | |
try: | |
log_message(f"Обрабатываю файл: {file_path}") | |
local_path = hf_hub_download( | |
repo_id=REPO_ID, | |
filename=file_path, | |
local_dir='', | |
repo_type="dataset", | |
token=HF_TOKEN | |
) | |
with open(local_path, 'r', encoding='utf-8') as f: | |
table_data = json.load(f) | |
if isinstance(table_data, dict): | |
document_id = table_data.get('document', 'unknown') | |
if 'sheets' in table_data: | |
for sheet in table_data['sheets']: | |
sheet['document'] = document_id | |
doc = table_to_document(sheet, document_id) | |
table_documents.append(doc) | |
else: | |
doc = table_to_document(table_data, document_id) | |
table_documents.append(doc) | |
elif isinstance(table_data, list): | |
for table_json in table_data: | |
doc = table_to_document(table_json) | |
table_documents.append(doc) | |
except Exception as e: | |
log_message(f"Ошибка обработки файла {file_path}: {str(e)}") | |
continue | |
log_message(f"Создано {len(table_documents)} документов из таблиц") | |
return table_documents | |
except Exception as e: | |
log_message(f"Ошибка загрузки табличных данных: {str(e)}") | |
return [] | |
def download_image_data(): | |
log_message("Начинаю загрузку данных изображений") | |
image_files = [] | |
try: | |
files = list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN) | |
for file in files: | |
if file.startswith(image_data_dir) and file.endswith('.csv'): | |
image_files.append(file) | |
log_message(f"Найдено {len(image_files)} CSV файлов с изображениями") | |
image_documents = [] | |
for file_path in image_files: | |
try: | |
log_message(f"Обрабатываю файл изображений: {file_path}") | |
local_path = hf_hub_download( | |
repo_id=REPO_ID, | |
filename=file_path, | |
local_dir='', | |
repo_type="dataset", | |
token=HF_TOKEN | |
) | |
df = pd.read_csv(local_path) | |
log_message(f"Загружено {len(df)} записей изображений из файла {file_path}") | |
for _, row in df.iterrows(): | |
content = f"Изображение: {row.get('№ Изображения', 'Неизвестно')}\n" | |
content += f"Название: {row.get('Название изображения', 'Неизвестно')}\n" | |
content += f"Описание: {row.get('Описание изображение', 'Неизвестно')}\n" | |
content += f"Документ: {row.get('Обозначение документа', 'Неизвестно')}\n" | |
content += f"Раздел: {row.get('Раздел документа', 'Неизвестно')}\n" | |
content += f"Файл: {row.get('Файл изображения', 'Неизвестно')}\n" | |
doc = Document( | |
text=content, | |
metadata={ | |
"type": "image", | |
"image_number": row.get('№ Изображения', 'unknown'), | |
"document_id": row.get('Обозначение документа', 'unknown'), | |
"file_path": row.get('Файл изображения', 'unknown'), | |
"section": row.get('Раздел документа', 'unknown') | |
} | |
) | |
image_documents.append(doc) | |
except Exception as e: | |
log_message(f"Ошибка обработки файла {file_path}: {str(e)}") | |
continue | |
log_message(f"Создано {len(image_documents)} документов из изображений") | |
return image_documents | |
except Exception as e: | |
log_message(f"Ошибка загрузки данных изображений: {str(e)}") | |
return [] | |
def initialize_models(): | |
global query_engine, chunks_df, reranker, vector_index, current_model | |
try: | |
log_message("Инициализация системы") | |
os.makedirs(download_dir, exist_ok=True) | |
log_message("Загружаю основные файлы") | |
chunks_csv_path = hf_hub_download( | |
repo_id=REPO_ID, | |
filename=chunks_filename, | |
local_dir=download_dir, | |
repo_type="dataset", | |
token=HF_TOKEN | |
) | |
log_message("Загружаю данные чанков") | |
chunks_df = pd.read_csv(chunks_csv_path) | |
log_message(f"Загружено {len(chunks_df)} чанков") | |
log_message("Инициализирую модели") | |
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") | |
llm = get_llm_model(current_model) | |
log_message("Инициализирую переранкер") | |
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2') | |
Settings.embed_model = embed_model | |
Settings.llm = llm | |
text_column = None | |
for col in chunks_df.columns: | |
if 'text' in col.lower() or 'content' in col.lower() or 'chunk' in col.lower(): | |
text_column = col | |
break | |
if text_column is None: | |
text_column = chunks_df.columns[0] | |
log_message(f"Использую колонку: {text_column}") | |
log_message("Создаю документы из чанков") | |
documents = [] | |
for i, (_, row) in enumerate(chunks_df.iterrows()): | |
doc = Document( | |
text=str(row[text_column]), | |
metadata={ | |
"chunk_id": row.get('chunk_id', i), | |
"document_id": row.get('document_id', 'unknown'), | |
"type": "text" | |
} | |
) | |
documents.append(doc) | |
log_message(f"Создано {len(documents)} текстовых документов") | |
log_message("Добавляю табличные данные") | |
table_documents = download_table_data() | |
documents.extend(table_documents) | |
log_message("Добавляю данные изображений") | |
image_documents = download_image_data() | |
documents.extend(image_documents) | |
log_message(f"Всего документов: {len(documents)}") | |
log_message("Строю векторный индекс") | |
vector_index = VectorStoreIndex.from_documents(documents) | |
# Create query engine | |
recreate_query_engine() | |
log_message(f"Система успешно инициализирована с моделью: {current_model}") | |
return True | |
except Exception as e: | |
log_message(f"Ошибка инициализации: {str(e)}") | |
return False | |
def rerank_nodes(query, nodes, top_k=10): | |
if not nodes or not reranker: | |
return nodes[:top_k] | |
try: | |
log_message(f"Переранжирую {len(nodes)} узлов") | |
pairs = [] | |
for node in nodes: | |
pairs.append([query, node.text]) | |
scores = reranker.predict(pairs) | |
scored_nodes = list(zip(nodes, scores)) | |
scored_nodes.sort(key=lambda x: x[1], reverse=True) | |
reranked_nodes = [node for node, score in scored_nodes[:top_k]] | |
log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов") | |
return reranked_nodes | |
except Exception as e: | |
log_message(f"Ошибка переранжировки: {str(e)}") | |
return nodes[:top_k] | |
def answer_question(question): | |
global query_engine, chunks_df, current_model | |
if query_engine is None: | |
return "<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Система не инициализирована</div>", "" | |
try: | |
log_message(f"Получен вопрос: {question}") | |
log_message(f"Используется модель: {current_model}") | |
start_time = time.time() | |
log_message("Извлекаю релевантные узлы") | |
retrieved_nodes = query_engine.retriever.retrieve(question) | |
log_message(f"Извлечено {len(retrieved_nodes)} узлов") | |
log_message("Применяю переранжировку") | |
reranked_nodes = rerank_nodes(question, retrieved_nodes, top_k=10) | |
log_message(f"Отправляю запрос в LLM с {len(reranked_nodes)} узлами") | |
response = query_engine.query(question) | |
end_time = time.time() | |
processing_time = end_time - start_time | |
log_message(f"Обработка завершена за {processing_time:.2f} секунд") | |
sources_html = generate_sources_html(reranked_nodes) | |
answer_with_time = f"""<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; margin-bottom: 10px;'> | |
<h3 style='color: #63b3ed; margin-top: 0;'>Ответ (Модель: {current_model}):</h3> | |
<div style='line-height: 1.6; font-size: 16px;'>{response.response}</div> | |
<div style='margin-top: 15px; padding-top: 10px; border-top: 1px solid #4a5568; font-size: 14px; color: #a0aec0;'> | |
Время обработки: {processing_time:.2f} секунд | |
</div> | |
</div>""" | |
return answer_with_time, sources_html | |
except Exception as e: | |
log_message(f"Ошибка обработки вопроса: {str(e)}") | |
error_msg = f"<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Ошибка обработки вопроса: {str(e)}</div>" | |
return error_msg, "" | |
def generate_sources_html(nodes): | |
html = "<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; max-height: 400px; overflow-y: auto;'>" | |
html += "<h3 style='color: #63b3ed; margin-top: 0;'>Источники:</h3>" | |
for i, node in enumerate(nodes): | |
metadata = node.metadata if hasattr(node, 'metadata') else {} | |
doc_type = metadata.get('type', 'text') | |
doc_id = metadata.get('document_id', 'unknown') | |
html += f"<div style='margin-bottom: 15px; padding: 15px; border: 1px solid #4a5568; border-radius: 8px; background-color: #1a202c;'>" | |
if doc_type == 'text': | |
html += f"<h4 style='margin: 0 0 10px 0; color: #63b3ed;'>📄 {doc_id}</h4>" | |
elif doc_type == 'table': | |
table_num = metadata.get('table_number', 'unknown') | |
if table_num and table_num != 'unknown': | |
if not table_num.startswith('№'): | |
table_num = f"№{table_num}" | |
html += f"<h4 style='margin: 0 0 10px 0; color: #68d391;'>📊 Таблица {table_num} - {doc_id}</h4>" | |
else: | |
html += f"<h4 style='margin: 0 0 10px 0; color: #68d391;'>📊 Таблица - {doc_id}</h4>" | |
elif doc_type == 'image': | |
image_num = metadata.get('image_number', 'unknown') | |
section = metadata.get('section', '') | |
if image_num and image_num != 'unknown': | |
if not str(image_num).startswith('№'): | |
image_num = f"№{image_num}" | |
html += f"<h4 style='margin: 0 0 10px 0; color: #fbb6ce;'>🖼️ Изображение {image_num} - {doc_id} ({section})</h4>" | |
else: | |
html += f"<h4 style='margin: 0 0 10px 0; color: #fbb6ce;'>🖼️ Изображение - {doc_id} ({section})</h4>" | |
if chunks_df is not None and 'file_link' in chunks_df.columns and doc_type == 'text': | |
doc_rows = chunks_df[chunks_df['document_id'] == doc_id] | |
if not doc_rows.empty: | |
file_link = doc_rows.iloc[0]['file_link'] | |
html += f"<a href='{file_link}' target='_blank' style='color: #68d391; text-decoration: none; font-size: 14px; display: inline-block; margin-top: 10px;'>🔗 Ссылка на документ</a><br>" | |
html += "</div>" | |
html += "</div>" | |
return html | |
def create_demo_interface(): | |
with gr.Blocks(title="AIEXP - AI Expert для нормативной документации", theme=gr.themes.Soft()) as demo: | |
gr.Markdown(""" | |
# AIEXP - Artificial Intelligence Expert | |
## Инструмент для работы с нормативной документацией | |
""") | |
with gr.Tab("🏠 Поиск по нормативным документам"): | |
gr.Markdown("### Задайте вопрос по нормативной документации") | |
# Model selection section | |
with gr.Row(): | |
with gr.Column(scale=2): | |
model_dropdown = gr.Dropdown( | |
choices=list(AVAILABLE_MODELS.keys()), | |
value=current_model, | |
label="🤖 Выберите языковую модель", | |
info="Выберите модель для генерации ответов" | |
) | |
with gr.Column(scale=1): | |
switch_btn = gr.Button("🔄 Переключить модель", variant="secondary") | |
model_status = gr.Textbox( | |
value=f"Текущая модель: {current_model}", | |
label="Статус модели", | |
interactive=False | |
) | |
with gr.Row(): | |
with gr.Column(scale=3): | |
question_input = gr.Textbox( | |
label="Ваш вопрос к базе знаний", | |
placeholder="Введите вопрос по нормативным документам...", | |
lines=3 | |
) | |
ask_btn = gr.Button("🔍 Найти ответ", variant="primary", size="lg") | |
gr.Examples( | |
examples=[ | |
"О чем этот рисунок: ГОСТ Р 50.04.07-2022 Приложение Л. Л.1.5 Рисунок Л.2", | |
"Л.9 Формула в ГОСТ Р 50.04.07 - 2022 что и о чем там?", | |
"Какой стандарт устанавливает порядок признания протоколов испытаний продукции в области использования атомной энергии?", | |
"Кто несет ответственность за организацию и проведение признания протоколов испытаний продукции?", | |
"В каких случаях могут быть признаны протоколы испытаний, проведенные лабораториями?", | |
], | |
inputs=question_input | |
) | |
with gr.Row(): | |
with gr.Column(scale=2): | |
answer_output = gr.HTML( | |
label="", | |
value=f"<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появится ответ на ваш вопрос...<br><small>Текущая модель: {current_model}</small></div>", | |
) | |
with gr.Column(scale=1): | |
sources_output = gr.HTML( | |
label="", | |
value="<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появятся источники...</div>", | |
) | |
# Event handlers | |
def update_model_status(new_model): | |
result = switch_model(new_model) | |
return result | |
switch_btn.click( | |
fn=update_model_status, | |
inputs=[model_dropdown], | |
outputs=[model_status] | |
) | |
ask_btn.click( | |
fn=answer_question, | |
inputs=[question_input], | |
outputs=[answer_output, sources_output] | |
) | |
question_input.submit( | |
fn=answer_question, | |
inputs=[question_input], | |
outputs=[answer_output, sources_output] | |
) | |
return demo | |
if __name__ == "__main__": | |
log_message("Запуск AIEXP - AI Expert для нормативной документации") | |
if initialize_models(): | |
log_message("Запуск веб-интерфейса") | |
demo = create_demo_interface() | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=True, | |
debug=False | |
) | |
else: | |
log_message("Невозможно запустить приложение из-за ошибки инициализации") | |
sys.exit(1) |