Nightwing11 committed on
Commit
852dd8d
·
1 Parent(s): 6961452

Resolved coreference between the query and the conversation history

Browse files
Files changed (3) hide show
  1. Rag/rag_pipeline.py +3 -2
  2. requirements.txt +4 -2
  3. utils/corefrence.py +14 -11
Rag/rag_pipeline.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import logging
7
  from Llm.llm_endpoints import get_llm_response
8
  from utils.get_link import get_source_link
9
- # from Rag.corefrence import resolve_coreference_in_query
10
  # Configuration
11
  API_KEY = os.getenv("GOOGLE_API_KEY")
12
  if API_KEY:
@@ -134,7 +134,8 @@ def main_workflow(transcripts_folder_path, collection):
134
  if query_text.lower() == "exit":
135
  print("Ending the conversation. Goodbye")
136
  break
137
- query_text_with_conversation_history = enhance_query_with_history(query_text, conversation_history)
 
138
  # resolved_query = resolve_coreference_in_query(query_text_with_conversation_history, conversation_history)
139
  retrived_docs, metadatas = query_database(collection, query_text_with_conversation_history)
140
  print("-" * 50)
 
6
  import logging
7
  from Llm.llm_endpoints import get_llm_response
8
  from utils.get_link import get_source_link
9
+ from utils.corefrence import resolve_corefrence
10
  # Configuration
11
  API_KEY = os.getenv("GOOGLE_API_KEY")
12
  if API_KEY:
 
134
  if query_text.lower() == "exit":
135
  print("Ending the conversation. Goodbye")
136
  break
137
+ resolved_query = resolve_corefrence(query_text, conversation_history)
138
+ query_text_with_conversation_history = enhance_query_with_history(resolved_query, conversation_history)
139
  # resolved_query = resolve_coreference_in_query(query_text_with_conversation_history, conversation_history)
140
  retrived_docs, metadatas = query_database(collection, query_text_with_conversation_history)
141
  print("-" * 50)
requirements.txt CHANGED
@@ -6,11 +6,13 @@ langchain
6
  langchain_openai
7
  langchain_chroma
8
  langchain_community
9
- chromadb
10
  pypdf
11
  flask
12
  flask_cors
13
  sentence_transformers
14
  tqdm
15
  torch
16
- transformers
 
 
 
6
  langchain_openai
7
  langchain_chroma
8
  langchain_community
9
+ chromadb==0.4.8
10
  pypdf
11
  flask
12
  flask_cors
13
  sentence_transformers
14
  tqdm
15
  torch
16
+ transformers
17
+ spacy==3.5.0
18
+ coreferee==1.4.1
utils/corefrence.py CHANGED
@@ -1,11 +1,14 @@
1
- from transformers import pipeline
2
-
3
- coref_pipeline = pipeline("coref-resolution", model="coref-roberta-large")
4
-
5
-
6
- def resolve_coreference_in_query(query_text, conversation_history):
7
- context = "\n".join([f"User: {turn['user']}\nBot: {turn['bot']}" for turn in conversation_history])
8
- full_text = f"{context}\nUser: {query_text}"
9
- resolved_text = coref_pipeline(full_text)
10
- resolved_query = resolved_text.split("User:")[-1].strip()
11
- return resolved_query
 
 
 
 
1
+ import spacy
2
+ nlp = spacy.load('en_core_web_sm')
3
+ nlp.add_pipe("coreferee")
4
def resolve_corefrence(query_text, conversation_history):
    """Return the latest user query with anaphors replaced by their referents.

    The earlier turns are prepended to the query so that the coreference
    component can link pronouns in the query ("it", "he", ...) to entities
    mentioned earlier in the conversation. Only the query portion of the
    analysed text is returned.

    Args:
        query_text: The raw text of the user's current question.
        conversation_history: Iterable of turn mappings. Each turn must have a
            'user' entry and a bot reply stored under 'bot' (the key read by
            the previous implementation) or 'Bot'.

    Returns:
        The query with every resolvable token substituted by the text of the
        mention(s) it refers to, stripped of surrounding whitespace. An empty
        or whitespace-only query yields "" without running the pipeline.

    Raises:
        KeyError: If a turn lacks 'user', or lacks both 'bot' and 'Bot'.
    """
    query = (query_text or "").strip()
    if not query:
        return ""

    history_lines = []
    for turn in conversation_history or []:
        # Accept both spellings of the reply key; a turn with neither still
        # fails loudly with KeyError instead of being silently skipped.
        bot_reply = turn["bot"] if "bot" in turn else turn["Bot"]
        history_lines.append(f"User: {turn['user']}")
        history_lines.append(f"Bot: {bot_reply}")

    # Everything up to and including the final "User: " label is context only.
    context = "".join(f"{line}\n" for line in history_lines) + "User: "
    query_start = len(context)

    # `nlp` is the module-level spaCy pipeline with the coreferee component.
    doc = nlp(context + query)

    # coreferee publishes its result as `doc._.coref_chains`; it does not
    # provide a ready-made "resolved text" attribute, so rebuild the query
    # token by token.
    chains = doc._.coref_chains
    pieces = []
    for token in doc:
        # Select the query by character offset rather than by splitting on
        # newlines, so multi-line queries and multi-line replies are safe.
        if token.idx < query_start:
            continue
        referents = chains.resolve(token) if chains is not None else None
        if referents:
            # A token may refer to several coordinated mentions.
            replacement = " and ".join(ref.text for ref in referents)
            pieces.append(replacement + token.whitespace_)
        else:
            pieces.append(token.text_with_ws)
    return "".join(pieces).strip()