# NOTE(review): the lines originally above the imports were Hugging Face Spaces
# page residue (build status, commit hashes, file size, a line-number gutter),
# not Python source; they have been removed so the file parses.
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
import os
from langchain.schema import AIMessage, HumanMessage
import gradio as gr
from transformers import pipeline, AutoTokenizer, TFAutoModelForSeq2SeqLM
import subprocess
import torch
import tempfile
from langdetect import detect
from transformers import MarianMTModel, MarianTokenizer
import boto3
# Additional imports for loading PDF documents and QA chain.
from langchain_community.document_loaders import PyPDFLoader
# Additional imports for loading Wikipedia content and QA chain
from langchain_community.document_loaders import WikipediaLoader
from langchain.chains.question_answering import load_qa_chain
# Import RegEx for translate function, to split sentences in avoiding token limits
import re
#Get keys #########################################################################################
# Load API credentials from a local .env file into the process environment.
load_dotenv()
# Set the model name for our LLMs.
OPENAI_MODEL = "gpt-3.5-turbo"
# Store the API key in a variable.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
#Define variables for AWS Polly Access#############################################################
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
aws_default_region = os.getenv('AWS_DEFAULT_REGION')
#Define language variables ###################################################################################
#Define voice map
# Maps a language code to the Amazon Polly neural voice used for that language.
# NOTE(review): "jap" and "trk" are not the ISO 639-1 codes that langdetect
# returns ("ja" / "tr"), so Polly synthesis for detected Japanese/Turkish text
# will not find a voice here — confirm intended codes against the detect() output.
voice_map = {
"ar": "Hala",
"en": "Gregory",
"es": "Mia",
"fr": "Liam",
"de": "Vicki",
"it": "Bianca",
"zh": "Hiujin",
"hi": "Kajal",
"jap": "Tomoko",
"trk": "Burcu"
}
#Define language map from full names to ISO codes
# Dropdown display name -> language code used for voice_map / Helsinki model ids.
language_map = {
"Arabic (Gulf)": "ar",
"Chinese (Cantonese)": "zh",
"English": "en",
"French": "fr",
"German": "de",
"Hindi": "hi",
"Italian": "it",
"Japanese": "jap",
"Spanish": "es",
"Turkish": "trk"
}
# list of languages and their codes for dropdown
# Gradio component built at module level; presumably wired into the UI layout
# further down the file (not visible in this chunk).
languages = gr.Dropdown(
label="Click in the middle of the dropdown bar to select translation language",
choices=list(language_map.keys()))
#Define default language
default_language = "English"
#Setting the Chatbot Model #################################################################################
#Instantiating the llm we'll use and the arguments to pass
#This is done at a global level, and not within the definition of a function to improve
#the speed and efficiency of the app. Thus, the model will not be instantiated every time
#a new question is submitted. Similar setup is created for all of the models called. This
#was part of our optimization process to help the app be more efficient and effective.
# temperature=0.0 keeps answers deterministic for the QA use case.
llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name=OPENAI_MODEL, temperature=0.0)
# Define the wikipedia topic as a string.
wiki_topic = "diabetes"
# Load the wikipedia results as documents, using a max of 2.
#included error handling- unable to load documents
# On failure, documents stays an empty list; handle_query checks for this and
# returns a friendly message instead of querying the chain.
try:
    documents = WikipediaLoader(query=wiki_topic, load_max_docs=2, load_all_available_meta=True).load()
except Exception as e:
    print("Failed to load documents:", str(e))
    documents = []
# Create the QA chain using the LLM.
chain = load_qa_chain(llm)
##############################################################################################################
#Define the function to call the OpenAI chat LLM
def handle_query(user_query):
    """Answer a user question against the preloaded Wikipedia documents.

    Returns the QA chain's answer text, or a fallback message when the
    documents failed to load, the answer is empty, or the chain errors.
    """
    # No source material means the Wikipedia load failed at startup.
    if not documents:
        return "Source not loading info; please try again later."
    if user_query.lower() == 'quit':
        return "Goodbye!"
    try:
        # Feed the loaded documents plus the question through the QA chain.
        answer = chain.invoke({"input_documents": documents, "question": user_query})
        text = answer["output_text"]
        if text.strip():
            return text
        return "No answer found, try a different question."
    except Exception as e:
        return "An error occurred while searching for the answer: " + str(e)
#Language models and functions ############################################################################
#Setup cache mechanism to initialize translation model at module level to improve app speed.
#Define global variables for tokenizer and model
helsinki_model_cache = {}
def get_helsinki_model_and_tokenizer(src_lang, target_lang):
    """Return a cached (tokenizer, model) pair for the opus-mt src->target model.

    Downloads the Helsinki-NLP MarianMT weights on first use only; later calls
    for the same language pair hit the module-level cache.
    """
    model_id = f"Helsinki-NLP/opus-mt-{src_lang}-{target_lang}"
    try:
        return helsinki_model_cache[model_id]
    except KeyError:
        pair = (MarianTokenizer.from_pretrained(model_id),
                MarianMTModel.from_pretrained(model_id))
        helsinki_model_cache[model_id] = pair
        return pair
#Define function to translate text into the specified language
def translate(transcribed_text, target_lang="es"):
    """Translate text into target_lang with a Helsinki-NLP MarianMT model.

    The source language is auto-detected with langdetect. Text is translated
    sentence-by-sentence because earlier iterations of the app hit the model's
    token limit and truncated the output.

    Returns the translated text, or the original text when the source language
    already matches target_lang, or an error string on any failure.
    """
    try:
        src_lang = detect(transcribed_text)
        # Fix: no "opus-mt-xx-xx" model exists for identical source/target, so
        # the original code fell into the error path; return the text as-is.
        if src_lang == target_lang:
            return transcribed_text.strip()
        tokenizer, model = get_helsinki_model_and_tokenizer(src_lang, target_lang)
        max_length = tokenizer.model_max_length
        # Split on sentence-ending punctuation to keep each segment small.
        sentences = re.split(r'(?<=[.!?]) +', transcribed_text)
        segments = []
        for sentence in sentences:
            # truncation=True already caps every segment at max_length tokens,
            # so no extra length check is needed here.
            tokens = tokenizer.encode(sentence, return_tensors="pt",
                                      truncation=True, max_length=max_length)
            translated_tokens = model.generate(tokens)
            segments.append(tokenizer.decode(translated_tokens[0], skip_special_tokens=True))
        return " ".join(segments).strip()
    except Exception as e:
        print(f"An error occurred: {e}")
        return "Error in transcription or translation"
#Initialize Whisper model at the module level to be used across different calls
transcription_pipeline = None
def initialize_transcription_model():
    """Lazily build the Whisper ASR pipeline into the module-level global.

    Idempotent: does nothing when the pipeline has already been created.
    """
    global transcription_pipeline
    if transcription_pipeline is not None:
        return
    transcription_pipeline = pipeline("automatic-speech-recognition",
                                      model="openai/whisper-large")
#Define function to transcribe audio to text using Whisper in the original language it was spoken
def transcribe_audio_original(audio_filepath):
    """Transcribe an audio file with Whisper, keeping the spoken language.

    Returns the transcribed text, or an error string on failure.
    """
    try:
        # Build the shared Whisper pipeline on first use.
        if transcription_pipeline is None:
            initialize_transcription_model()
        result = transcription_pipeline(audio_filepath)
        return result['text']
    except Exception as e:
        print(f"an error occured: {e}")
        return "Error in transcription"
#Initialize Polly client at module level
# One shared client, created once with credentials from the environment, so a
# new connection is not established on every text-to-speech request.
polly_client = boto3.client(
    'polly',
    region_name=aws_default_region,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key
)
# Define text-to-speech function using Amazon Polly
def polly_text_to_speech(text, lang_code):
    """Synthesize speech for *text* with Amazon Polly's neural engine.

    Parameters:
        text: the text to speak.
        lang_code: language code used to pick a voice from voice_map.

    Returns the path to a temporary .mp3 file, or None on any failure
    (unknown language, Polly error, or no audio stream in the response).
    """
    # Fix: voice_map[lang_code] raised an uncaught KeyError for any detected
    # code not in the map (langdetect can return e.g. "ja", "tr", "pt"); the
    # surrounding except only caught Boto3Error. Fail soft instead.
    voice_id = voice_map.get(lang_code)
    if voice_id is None:
        print(f"No Polly voice configured for language code: {lang_code}")
        return None
    try:
        #request speech synthesis
        response = polly_client.synthesize_speech(
            Engine='neural',
            Text=text,
            OutputFormat='mp3',
            VoiceId=voice_id
        )
        # Save the audio to a temporary file and return its path
        if "AudioStream" in response:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as audio_file:
                audio_file.write(response['AudioStream'].read())
                return audio_file.name
        return None  # explicit: response carried no audio stream
    except boto3.exceptions.Boto3Error as e:
        print(f"Error accessing Polly: {e}")
        return None  # Return None if there was an error
#Define function to submit query to Wikipedia to feed into Gradio app
def submit_question (audio_filepath=None, typed_text=None, target_lang=default_language):
    """Route a typed or spoken question through the QA chain and voice the answer.

    Typed text takes precedence when both inputs are supplied. Returns
    (answer_text, answer_audio_path) where the audio slot holds the string
    "No audio available" when speech synthesis fails.
    """
    # Require at least one input method.
    if not audio_filepath and not typed_text:
        return "Please provide input by typing or speaking", None
    # Resolve the question text: direct input, or background transcription.
    if typed_text:
        query_text = typed_text
    else:
        query_text = transcribe_audio_original(audio_filepath)
    # Answer in the language the question was asked in.
    detected_lang_code = detect(query_text)
    response_text = handle_query(query_text)
    response_speech = polly_text_to_speech(response_text, detected_lang_code)
    if not response_speech:
        response_speech = "No audio available"
    return response_text, response_speech
#Define function to transcribe audio and provide output in text and speech
def transcribe_and_speech(audio_filepath=None, typed_text=None, target_lang=default_language):
    """Echo the user's input back as text and speech in its original language.

    Exactly one input method may be used. Returns (text, audio_path):
    typed input yields (None, speech); audio input yields (transcript, speech);
    no input yields a prompt message. target_lang is accepted for interface
    parity with the other handlers but is not used here.
    """
    #Determine source of text: audio transcription or direct text input
    if audio_filepath and typed_text:
        return "Please use only one input method at a time", None
    if typed_text:
        #convert typed text to speech in its detected language
        detected_lang_code = detect(typed_text)
        return None, polly_text_to_speech(typed_text, detected_lang_code)
    if audio_filepath:
        #transcribe audio to text, then speak the transcript back
        query_text = transcribe_audio_original(audio_filepath)
        detected_lang_code = detect(query_text)
        return query_text, polly_text_to_speech(query_text, detected_lang_code)
    # Fix: the original trailing target-language mapping code after this point
    # was unreachable (every live path returns above) and contained a latent
    # IndexError lookup; it has been removed with no behavior change.
    return "Please provide input by typing or speaking.", None
#Define function to translate query into target language in text and audio
def translate_and_speech(response_text=None, target_lang=default_language):
    """Translate *response_text* into target_lang and synthesize it as speech.

    Returns (translated_text, audio_path_or_None). When the detected language
    already equals target_lang the text passes through untranslated.
    """
    # Fix: the default response_text=None crashed inside detect(); guard first.
    if not response_text:
        return "", None
    #Detect language of input text
    detected_lang_code = detect(response_text)
    # Map the detected ISO code back to a display name; None when unsupported.
    # Fix: the original `[k for ...][0]` raised IndexError for any code not in
    # language_map's values (e.g. langdetect's "ja" vs the map's "jap").
    detected_lang = next(
        (name for name, code in language_map.items() if code == detected_lang_code),
        None,
    )
    #Check if the language is specified. Default to English if not.
    target_lang_code = language_map.get(target_lang, "en")
    #Skip translation when source and target languages already match
    if detected_lang == target_lang:
        translated_response = response_text
    else:
        translated_response = translate(response_text, target_lang_code)
    #convert to speech
    translated_speech = polly_text_to_speech(translated_response, target_lang_code)
    return translated_response, translated_speech
# Function to clear out all inputs
def clear_inputs():
    """Return six Nones — one per Gradio component reset by the clear button."""
    return (None,) * 6