symantyc_search / src /streamlit_app.py
Kapex13's picture
Update src/streamlit_app.py
f967e08 verified
import streamlit as st
import os
import pandas as pd
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from langchain_groq import ChatGroq
from langchain_core.messages import SystemMessage, HumanMessage
# === Кэшируемая загрузка датасета ===
@st.cache_resource(show_spinner=False)
def load_data():
# Загружаем уже сохранённый CSV, тот же, что в Colab
df = pd.read_csv("tvshows_processed2_cached.csv") # используем сохранённый файл
df["clean_text"] = df["clean_text"].astype(str).fillna("—")
return df
# === Инициализация эмбеддинговой модели ===
@st.cache_resource(show_spinner=True)
def init_embedder():
model_name = "sberbank-ai/sbert_large_nlu_ru"
embedder = SentenceTransformer(model_name)
return embedder
# === Загрузка готовых эмбеддингов и FAISS индекса из файлов ===
@st.cache_resource(show_spinner=True)
def load_embeddings_and_index():
# Проверка наличия файлов
if not os.path.exists("embeddings.npy") or not os.path.exists("faiss_index.index"):
st.error("Не найдены сохранённые файлы 'embeddings.npy' и/или 'faiss_index.index'")
st.stop()
# Загрузка
embeddings = np.load("embeddings.npy")
index = faiss.read_index("faiss_index.index")
return embeddings, index
# === Поиск близких результатов по семантическому запросу ===
def semantic_search(query, embedder, index, df, k=5):
query_embedding = embedder.encode([query])
faiss.normalize_L2(query_embedding)
distances, indices = index.search(query_embedding, k)
results = df.iloc[indices[0]].copy()
results["score"] = distances[0]
return results[["tvshow_title", "year", "genres", "description", "score"]]
# === Форматирование результатов для LLM ===
def format_docs_for_prompt(results_df):
formatted = []
for _, row in results_df.iterrows():
title = row["tvshow_title"]
year = row.get("year", "—")
genres = row.get("genres", "—")
description = row["description"][:300]
try:
year_disp = int(float(year))
except:
year_disp = "—"
formatted.append(f"""
{title} ({year_disp}) — {genres}
{description}...
""")
return "\n".join(formatted)
# === Кэшируемая инициализация Groq LLM клиента ===
@st.cache_resource(ttl=3600)
def init_groq_llm():
groq_key = st.secrets.get("GROQ_API_KEY") or st.text_input("🔐 Введите API-ключ Groq:", type="password")
if not groq_key:
st.warning("Введите ваш Groq API ключ для генерации ответов.")
return None
os.environ["GROQ_API_KEY"] = groq_key
llm = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0, max_tokens=2000)
return llm
# === Генерация ответа от LLM с учётом результатов ===
def generate_rag_response(user_query, search_results, llm):
context = format_docs_for_prompt(search_results)
messages = [
SystemMessage(content="Ты — эксперт по фильмам и сериалам. Помоги с рекомендациями."),
HumanMessage(content=(f"Запрос: {user_query}\n\n"
f"Найденные подходящие фильмы и сериалы:\n{context}\n\n"
"Объясни, почему они подходят, и порекомендуй похожие. Ответь на русском языке."))
]
response = llm.invoke(messages)
return response.content.strip()
# === Streamlit UI ===
def main():
st.set_page_config(page_title="Поиск сериалов", layout="wide")
st.title("📽️ Семантический поиск фильмов и сериалов + Groq AI")
df = load_data()
embedder = init_embedder()
embeddings, index = load_embeddings_and_index()
llm = init_groq_llm()
query = st.text_input("🔎 Что хотите посмотреть? Опишите интересующую тему, сюжет или жанр:")
if query:
if llm is None:
st.error("Требуется API-ключ Groq для генерации рекомендаций.")
return
with st.spinner("⏳ Выполняется семантический поиск..."):
results = semantic_search(query, embedder, index, df)
if results.empty:
st.warning("Ничего не найдено по данному запросу.")
return
st.success(f"Найдено подходящих результатов: {len(results)}")
for i, row in results.iterrows():
st.subheader(f"{row['tvshow_title']} ({row['year']}) — {row['genres']}")
st.write(row['description'][:300] + "...")
st.caption(f"Сходство: {row['score']:.2%}")
st.markdown("---")
if st.button("🧠 Объясни, почему эти фильмы подходят и что ещё посмотреть"):
with st.spinner("🧾 Генерация ответа от Groq LLM..."):
answer = generate_rag_response(query, results, llm)
st.markdown("### 🤖 Рекомендации от AI:")
st.write(answer)
if __name__ == "__main__":
main()