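# Streamlit RAG chatbot over meeting-minutes JSON: passages are embedded with
# BAAI/bge-m3, indexed in FAISS, reranked with FlashRank, and answered with
# Gemini 1.5 Flash; responses can optionally be read aloud via ElevenLabs.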
import os
import json

import requests
import streamlit as st
import torch
import faiss
import google.generativeai as genai
from streamlit_chat import message
from transformers import AutoTokenizer, AutoModel
from flashrank.Ranker import Ranker, RerankRequest
from langchain.memory import ConversationBufferMemory
from pydantic import ConfigDict

# Configure the Gemini API key from the environment rather than hardcoding it.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", ""))
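
# ConversationBufferMemory subclass that relaxes pydantic validation so
# arbitrary (non-pydantic) types can be stored on the memory object.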
class CustomMemory(ConversationBufferMemory):
    model_config = ConfigDict(arbitrary_types_allowed=True)
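
# Parse the uploaded meeting-minutes JSON into passages of the form
# "Speaker: <name>. Text: <utterance>". Expected input is a list of objects
# with "speaker" and "text" keys, e.g.
# [{"speaker": "Alice", "text": "Let's review the action items."}, ...]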
def load_and_preprocess(uploaded_file):
    data = json.load(uploaded_file)
    passages = [f"Speaker: {item['speaker']}. Text: {item['text']}"
                for item in data if item["text"].strip()]
    return data, passages
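
# Convert a response string to speech via the ElevenLabs text-to-speech REST API;
# returns the MP3 bytes on success, otherwise None.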
def generate_text_to_speech(text):
    # Read the ElevenLabs API key from the environment rather than hardcoding it.
    API_KEY = os.environ.get("ELEVENLABS_API_KEY", "")
    VOICE_ID = 'TX3LPaxmHKxFdv7VOQHJ'
    url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}"
    headers = {
        "Content-Type": "application/json",
        "xi-api-key": API_KEY
    }
    payload = {
        "text": text,
        "model_id": "eleven_monolingual_v1",
        "voice_settings": {"stability": 0.5, "similarity_boost": 0.75}
    }
    response = requests.post(url, headers=headers, json=payload)
    return response.content if response.status_code == 200 else None
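
# Load the BAAI/bge-m3 embedding model and its tokenizer from Hugging Face.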
def load_model(model_name="BAAI/bge-m3"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model
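
# Embed passages in batches by mean-pooling the model's last hidden state;
# returns a (num_passages, hidden_dim) NumPy array.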
def generate_embeddings(passages, tokenizer, model, batch_size=10, device="cuda" if torch.cuda.is_available() else "cpu"):
    model.to(device)
    embeddings = []
    for i in range(0, len(passages), batch_size):
        batch = passages[i:i + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = model(**inputs).last_hidden_state.mean(dim=1)
        embeddings.append(outputs.cpu())
    embeddings = torch.cat(embeddings, dim=0)
    return embeddings.numpy()
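
# Build an exact (brute-force) L2 FAISS index over the passage embeddings.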
def store_in_faiss(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index
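
# Embed the query the same way as the passages and return the k nearest
# passages from the FAISS index.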
def retrieve_top_k(query, tokenizer, model, faiss_index, passages, k=5, device="cuda" if torch.cuda.is_available() else "cpu"):
    model.to(device)
    inputs = tokenizer([query], return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        query_embedding = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
    distances, indices = faiss_index.search(query_embedding, k)
    retrieved_passages = [passages[i] for i in indices[0]]
    return retrieved_passages
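
# Rerank the retrieved passages against the query with FlashRank's
# rank-T5-flan reranking model.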
def rerank_passages(query, passages):
    formatted_passages = [{"text": passage} for passage in passages]
    ranker = Ranker(model_name="rank-T5-flan", cache_dir="/app/.cache")  # Adjust cache directory as needed
    rerank_request = RerankRequest(query=query, passages=formatted_passages)
    results = ranker.rerank(rerank_request)
    return results
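
# Build a context-plus-question prompt and answer it with Gemini 1.5 Flash.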
def generate_response(context, query):
    input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
    model = genai.GenerativeModel("gemini-1.5-flash")
    response = model.generate_content(input_text)
    return response.text
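
# Full RAG pipeline for one question: retrieve top-k passages, rerank them,
# generate an answer from the combined context, and record the exchange in
# the conversation memory.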
def handle_userinput(user_question):
    top_k_passages = retrieve_top_k(user_question, st.session_state.tokenizer, st.session_state.model, st.session_state.faiss_index, st.session_state.passages)
    reranked_passages = rerank_passages(user_question, top_k_passages)
    context = " ".join([passage["text"] for passage in reranked_passages])
    response = generate_response(context, user_question)
    st.session_state.memory.chat_memory.add_user_message(user_question)
    st.session_state.memory.chat_memory.add_ai_message(response)
    return response
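
# Streamlit entry point: file upload and indexing, chat input, and rendering
# of the conversation with optional audio playback.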
def main():
    st.set_page_config(page_title="Chatbot with MoM Document Upload", layout="wide")
    st.title("📝 Chatbot for Minutes of Meeting")
| if "memory" not in st.session_state: | |
| st.session_state.memory = CustomMemory(memory_key='chat_history', return_messages=True) | |
| if "faiss_index" not in st.session_state: | |
| st.session_state.faiss_index = None | |
| if "passages" not in st.session_state: | |
| st.session_state.passages = None | |
| if "tokenizer" not in st.session_state or "model" not in st.session_state: | |
| st.session_state.tokenizer, st.session_state.model = load_model() | |
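    # Let the user upload a meeting-minutes JSON file, then embed and index it.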
    uploaded_file = st.file_uploader("Upload a JSON file for processing", type=["json"])
    if uploaded_file:
        st.write("Processing the file...")
        data, passages = load_and_preprocess(uploaded_file)
        st.session_state.passages = passages
        tokenizer, model = st.session_state.tokenizer, st.session_state.model
        embeddings = generate_embeddings(passages, tokenizer, model)
        st.session_state.faiss_index = store_in_faiss(embeddings)
        st.success("File processed and embeddings generated successfully!")
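    # The chat interface becomes available once a document has been indexed.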
    if st.session_state.faiss_index:
        st.header("Ask a Question")
        user_query = st.text_input("Type your question here:")
        play_audio = st.checkbox("🔊 Generate audio for the response")
        # Process each query only once: Streamlit reruns the whole script on every
        # widget interaction, which would otherwise duplicate chat entries and LLM calls.
        if user_query and user_query != st.session_state.get("last_query"):
            st.session_state.last_query = user_query
            response = handle_userinput(user_query)
            if "chat_history_ui" not in st.session_state:
                st.session_state.chat_history_ui = []
            st.session_state.chat_history_ui.append({"role": "user", "content": user_query})
            st.session_state.chat_history_ui.append({"role": "bot", "content": response})
            if play_audio:
                if "audio_cache" not in st.session_state:
                    st.session_state.audio_cache = {}
                if response not in st.session_state.audio_cache:
                    audio_bytes = generate_text_to_speech(response)
                    st.session_state.audio_cache[response] = audio_bytes
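        # Render the conversation so far; bot messages get an audio player when
        # audio is cached, otherwise a button to generate it on demand.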
| if "chat_history_ui" in st.session_state: | |
| for i,chat in enumerate(st.session_state.chat_history_ui): | |
| if chat["role"] == "user": | |
| message(chat["content"], is_user=True,key=f"user_{i}") | |
| else: | |
| message(chat["content"], is_user=False,key=f"bot_{i}") | |
| audio_bytes = st.session_state.get("audio_cache", {}).get(chat["content"]) | |
| if audio_bytes: | |
| st.audio(audio_bytes, format="audio/mpeg",start_time=0) | |
| else: | |
| if st.button(f"π Generate & Play Audio for Response {i}"): | |
| audio = generate_text_to_speech(chat["content"]) | |
| if "audio_cache" not in st.session_state: | |
| st.session_state.audio_cache = {} | |
| st.session_state.audio_cache[chat["content"]] = audio | |
| st.audio(audio, format="audio/mpeg", start_time=0) | |

if __name__ == "__main__":
    main()