from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
import os
import gradio as gr
from transformers import pipeline, MarianMTModel, MarianTokenizer
import tempfile
from langdetect import detect
import boto3
# Additional imports for loading Wikipedia content and the QA chain
from langchain_community.document_loaders import WikipediaLoader
from langchain.chains.question_answering import load_qa_chain
# Import re for the translate function, to split text into sentences and avoid hitting token limits
import re

#Get keys #########################################################################################
load_dotenv()

# Set the model name for our LLMs.
OPENAI_MODEL = "gpt-3.5-turbo"
# Store the API key in a variable.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

#Define variables for AWS Polly Access#############################################################
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
aws_default_region = os.getenv('AWS_DEFAULT_REGION')

#Define language variables ###################################################################################

#Define voice map from language code to an Amazon Polly neural voice.
#Note: "jap" and "trk" follow the codes used elsewhere in this app (see
#language_map below) rather than ISO 639-1 ("ja", "tr"), so output from
#langdetect may not match these keys directly.
voice_map = {
    "ar": "Hala",
    "en": "Gregory",
    "es": "Mia",
    "fr": "Liam",
    "de": "Vicki",
    "it": "Bianca",
    "zh": "Hiujin",
    "hi": "Kajal",
    "jap": "Tomoko",
    "trk": "Burcu"
}

#Define language map from full names to ISO-style codes
language_map = {
    "Arabic (Gulf)": "ar",
    "Chinese (Cantonese)": "zh",
    "English": "en",
    "French": "fr",
    "German": "de",
    "Hindi": "hi",
    "Italian": "it",
    "Japanese": "jap",
    "Spanish": "es",
    "Turkish": "trk"
}

# Dropdown listing the languages available for translation
languages = gr.Dropdown(
    label="Click in the middle of the dropdown bar to select translation language", 
    choices=list(language_map.keys()))

#Define default language
default_language = "English"

#Setting the Chatbot Model #################################################################################

#Instantiate the LLM we'll use and the arguments to pass.
#This is done at module level, not inside a function definition, to improve the
#speed and efficiency of the app: the model is not re-instantiated every time a
#new question is submitted. The same setup is used for all of the models below,
#as part of our optimization work to make the app more efficient and effective.
llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name=OPENAI_MODEL, temperature=0.0)

# Define the wikipedia topic as a string.
wiki_topic = "diabetes"

# Load the Wikipedia results as documents, using a max of 2.
# Error handling is included in case the documents fail to load.
try:
    documents = WikipediaLoader(query=wiki_topic, load_max_docs=2, load_all_available_meta=True).load()
except Exception as e:
    print("Failed to load documents:", str(e))
    documents = []
# Create the QA chain using the LLM.
chain = load_qa_chain(llm)

##############################################################################################################

#Define the function to call the OpenAI chat LLM
def handle_query(user_query):
    if not documents:
        return "Source documents could not be loaded; please try again later."

    if user_query.lower() == 'quit':
        return "Goodbye!"

    try:
        # Pass the documents and the user's query to the chain, and return the result.
        result = chain.invoke({"input_documents": documents, "question": user_query})
        return result["output_text"] if result["output_text"].strip() else "No answer found, try a different question."
        
    except Exception as e:
        return "An error occurred while searching for the answer: " + str(e)
    
#Language models and functions ############################################################################

#Set up a cache so translation models are initialized once at module level, improving app speed.

#Global cache mapping a Helsinki model name to its (tokenizer, model) pair
helsinki_model_cache = {}

def get_helsinki_model_and_tokenizer(src_lang, target_lang):
    helsinki_model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{target_lang}"
    if helsinki_model_name not in helsinki_model_cache:
        tokenizer = MarianTokenizer.from_pretrained(helsinki_model_name)
        model = MarianMTModel.from_pretrained(helsinki_model_name)
        helsinki_model_cache[helsinki_model_name] = (tokenizer, model)
    return helsinki_model_cache[helsinki_model_name]
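
# Usage sketch (illustrative): the first call below loads and caches the en->es
# Marian pair; repeated calls with the same pair are served from
# helsinki_model_cache without reloading the weights:
#     tokenizer, model = get_helsinki_model_and_tokenizer("en", "es")
#     tokenizer_again, model_again = get_helsinki_model_and_tokenizer("en", "es")
#     assert model is model_again  # same cached object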

#Define function to translate text into the specified target language
def translate(transcribed_text, target_lang="es"):
    try:
        #Define the model and tokenizer
        src_lang = detect(transcribed_text)
        tokenizer, model = get_helsinki_model_and_tokenizer(src_lang, target_lang)
        max_length = tokenizer.model_max_length

        # Split text based on sentence endings to better manage translation segments
        # This is done because in previous iterations of the app, some translations hit 
        # the max number of tokens and the output was truncated. This is part of our 
        # evaluation and optimization process
        sentences = re.split(r'(?<=[.!?]) +', transcribed_text)
        full_translation = ""

        # Process each sentence individually. truncation=True caps any single
        # sentence at the model's max length, so no segment exceeds the token limit.
        for sentence in sentences:
            tokens = tokenizer.encode(sentence, return_tensors="pt", truncation=True, max_length=max_length)
            translated_tokens = model.generate(tokens)
            segment_translation = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
            full_translation += segment_translation + " "

        return full_translation.strip()

    except Exception as e:
        print(f"An error occurred: {e}")
        return "Error in transcription or translation"
        

#Initialize Whisper model at the module level to be used across different calls
transcription_pipeline = None

def initialize_transcription_model():
    global transcription_pipeline
    if transcription_pipeline is None:
        transcription_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-large")

#Define function to transcribe audio to text using Whisper, in the language it was spoken
def transcribe_audio_original(audio_filepath):
    try:
        if transcription_pipeline is None:
            initialize_transcription_model()
        transcription_result = transcription_pipeline(audio_filepath)
        transcribed_text = transcription_result['text']
        return transcribed_text
    except Exception as e:
        print(f"an error occured: {e}")
        return "Error in transcription"


#Initialize Polly client at module level
polly_client = boto3.client(
    'polly',
    region_name=aws_default_region,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key
)

# Define text-to-speech function using Amazon Polly
def polly_text_to_speech(text, lang_code):
    try:
        # Get the appropriate voice ID from the mapping; fall back to the
        # English voice if the detected code has no mapped Polly voice
        # (the fallback choice is an assumption added to avoid a KeyError)
        voice_id = voice_map.get(lang_code, "Gregory")

        # Request speech synthesis
        response = polly_client.synthesize_speech(
            Engine='neural',
            Text=text,
            OutputFormat='mp3',
            VoiceId=voice_id
        )

        # Save the audio to a temporary file and return its path
        if "AudioStream" in response:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as audio_file:
                audio_file.write(response['AudioStream'].read())
                return audio_file.name
    except boto3.exceptions.Boto3Error as e:
        print(f"Error accessing Polly: {e}")
    return None  # Return None if there was an error
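
# Usage sketch (illustrative): synthesizing a short English reply returns the
# path of a temporary .mp3 that a Gradio Audio component can play:
#     audio_path = polly_text_to_speech("Hello, how can I help?", "en")
# A None return signals a Polly error and is handled by the callers below.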



#Define function to submit a query to the Wikipedia-backed chain and feed the Gradio app
def submit_question(audio_filepath=None, typed_text=None, target_lang=default_language):

    # Determine the source of the text: audio transcription or direct text input
    if not audio_filepath and not typed_text:
        return "Please provide input by typing or speaking", None

    # Typed text is used as-is; audio is first transcribed in the background.
    # When both are supplied, the typed text takes precedence.
    query_text = typed_text if typed_text else transcribe_audio_original(audio_filepath)

    detected_lang_code = detect(query_text)
    response_text = handle_query(query_text)
    response_speech = polly_text_to_speech(response_text, detected_lang_code)

    if not response_speech:
        response_speech = "No audio available"

    return response_text, response_speech

#Define function to transcribe audio and provide output in text and speech
def transcribe_and_speech(audio_filepath=None, typed_text=None, target_lang=default_language):

    # Determine the source of the text: audio transcription or direct text input
    if audio_filepath and typed_text:
        return "Please use only one input method at a time", None

    if typed_text:
        # Convert typed text directly to speech
        detected_lang_code = detect(typed_text)
        original_speech = polly_text_to_speech(typed_text, detected_lang_code)
        return None, original_speech

    elif audio_filepath:
        # Transcribe audio to text, then synthesize the text back to speech
        query_text = transcribe_audio_original(audio_filepath)
        detected_lang_code = detect(query_text)
        original_speech = polly_text_to_speech(query_text, detected_lang_code)
        return query_text, original_speech

    # Neither input was provided
    return "Please provide input by typing or speaking.", None


#Define function to translate the response into the target language, in text and audio
def translate_and_speech(response_text=None, target_lang=default_language):

    # Detect the language of the input text
    detected_lang_code = detect(response_text)

    # Check if a target language is specified; default to English if not
    target_lang_code = language_map.get(target_lang, "en")

    # Translate only if the detected and target language codes differ
    # (comparing codes directly avoids a crash when the detected code has
    # no entry in language_map)
    if detected_lang_code == target_lang_code:
        translated_response = response_text
    else:
        translated_response = translate(response_text, target_lang_code)

    # Convert the translation to speech
    translated_speech = polly_text_to_speech(translated_response, target_lang_code)

    return translated_response, translated_speech


# Function to clear out all inputs
def clear_inputs():
    return None, None, None, None, None, None
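
# Minimal UI sketch (an assumption, not the original app's wiring): one plausible
# way to connect the handlers above to a Gradio Blocks layout. The component
# names and layout are illustrative; the actual interface may differ.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        audio_in = gr.Audio(type="filepath", label="Speak your question")
        text_in = gr.Textbox(label="Or type your question")
        target = gr.Dropdown(
            label="Translation language",
            choices=list(language_map.keys()),
            value=default_language)
        answer_text = gr.Textbox(label="Answer")
        answer_audio = gr.Audio(label="Answer (speech)")
        translated_text = gr.Textbox(label="Translated answer")
        translated_audio = gr.Audio(label="Translated answer (speech)")

        ask_btn = gr.Button("Ask")
        translate_btn = gr.Button("Translate answer")
        clear_btn = gr.Button("Clear")

        # Ask: transcribe/accept the question, query the chain, speak the answer
        ask_btn.click(
            submit_question,
            inputs=[audio_in, text_in, target],
            outputs=[answer_text, answer_audio])
        # Translate: render the answer in the selected language, text and audio
        translate_btn.click(
            translate_and_speech,
            inputs=[answer_text, target],
            outputs=[translated_text, translated_audio])
        # Clear: reset all six components (matches clear_inputs' six Nones)
        clear_btn.click(
            clear_inputs,
            outputs=[audio_in, text_in, answer_text, answer_audio,
                     translated_text, translated_audio])

    demo.launch()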