khalednabawi11 commited on
Commit
cccd587
·
verified ·
1 Parent(s): 2505e32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -89
app.py CHANGED
@@ -67,122 +67,170 @@
67
  # demo.launch()
68
 
69
 
70
- import gradio as gr
71
- from langdetect import detect
72
- from transformers import pipeline
73
- from qdrant_client import QdrantClient
74
- from qdrant_client.models import VectorParams, Distance
75
- from langchain.llms import HuggingFacePipeline
76
- from langchain.chains import RetrievalQA
77
- from langchain.vectorstores import Qdrant
78
- from transformers import GenerationConfig, AutoTokenizer, AutoModelForCausalLM
79
- from langchain.embeddings import HuggingFaceEmbeddings
80
- import os
81
 
82
- QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
83
- QDRANT_URL = os.getenv("QDRANT_URL")
84
 
85
 
 
 
86
 
 
 
 
87
 
88
- # Define model path
89
- model_name = "FreedomIntelligence/Apollo-7B"
90
 
91
- # Load model directly
92
- tokenizer = AutoTokenizer.from_pretrained(model_name)
93
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
94
 
95
- # Enable padding token if missing
96
- tokenizer.pad_token = tokenizer.eos_token
 
 
 
97
 
98
- # Set up Qdrant vector store
99
- qdrant_client = QdrantClient(url=QDRANT_URL, api_key = QDRANT_API_KEY)
100
- vector_size = 768
101
- embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
 
 
 
 
 
102
 
103
- qdrant_vectorstore = Qdrant(
104
- client=qdrant_client,
105
- collection_name="arabic_rag_collection",
106
- embeddings=embedding
107
- )
 
 
108
 
109
- # Generation config
110
- generation_config = GenerationConfig(
111
- max_new_tokens=150,
112
- temperature=0.2,
113
- top_k=20,
114
- do_sample=True,
115
- top_p=0.7,
116
- repetition_penalty=1.3,
117
- )
118
 
119
- # Set up HuggingFace Pipeline
120
- llm_pipeline = pipeline(
121
- model=model,
122
- tokenizer=tokenizer,
123
- task="text-generation",
124
- generation_config=generation_config,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  )
126
 
127
- llm = HuggingFacePipeline(pipeline=llm_pipeline)
128
 
129
- # Set up QA Chain
130
- qa_chain = RetrievalQA.from_chain_type(
131
- llm=llm,
132
- retriever=qdrant_vectorstore.as_retriever(search_kwargs={"k": 3}),
133
- chain_type="stuff"
134
  )
135
 
136
- # Generate prompt based on language
137
- def generate_prompt(question):
138
- lang = detect(question)
139
  if lang == "ar":
140
- return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
141
  وتأكد من ان:
142
  - عدم تكرار أي نقطة أو عبارة أو كلمة
143
  - وضوح وسلاسة كل نقطة
144
- - تجنب الحشو والعبارات الزائدة-
145
 
146
- السؤال: {question}
147
- الإجابة:
148
- """
149
  else:
150
- return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
151
- Question: {question}
152
  Answer:"""
153
 
154
- # Define Gradio interface function
155
- def medical_chatbot(question):
156
- formatted_question = generate_prompt(question)
157
- answer = qa_chain.run(formatted_question)
 
158
  return answer
159
 
160
- # Set up Gradio interface
161
- iface = gr.Interface(
162
- fn=medical_chatbot,
163
- inputs=gr.Textbox(label="Ask a Medical Question", placeholder="Type your question here..."),
164
- outputs=gr.Textbox(label="Answer", interactive=False),
165
- title="Medical Chatbot",
166
- description="Ask medical questions and get detailed answers in Arabic or English.",
 
 
 
 
167
  theme="compact"
168
  )
169
 
170
- # demo = gr.ChatInterface(
171
- # respond,
172
- # additional_inputs=[
173
- # gr.Textbox(value="You are a Medical Chatbot.", label="System message"),
174
- # gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
175
- # gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
176
- # gr.Slider(
177
- # minimum=0.1,
178
- # maximum=1.0,
179
- # value=0.95,
180
- # step=0.05,
181
- # label="Top-p (nucleus sampling)",
182
- # ),
183
- # ],
184
- # )
185
-
186
- # Launch Gradio interface
187
  if __name__ == "__main__":
188
- iface.launch()
 
67
  # demo.launch()
68
 
69
 
70
+ # import gradio as gr
71
+ # from langdetect import detect
72
+ # from transformers import pipeline
73
+ # from qdrant_client import QdrantClient
74
+ # from qdrant_client.models import VectorParams, Distance
75
+ # from langchain.llms import HuggingFacePipeline
76
+ # from langchain.chains import RetrievalQA
77
+ # from langchain.vectorstores import Qdrant
78
+ # from transformers import GenerationConfig, AutoTokenizer, AutoModelForCausalLM
79
+ # from langchain.embeddings import HuggingFaceEmbeddings
80
+ # import os
81
 
82
+ # QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
83
+ # QDRANT_URL = os.getenv("QDRANT_URL")
84
 
85
 
86
+ # # Define model path
87
+ # model_name = "FreedomIntelligence/Apollo-7B"
88
 
89
+ # # Load model directly
90
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
91
+ # model = AutoModelForCausalLM.from_pretrained(model_name)
92
 
93
+ # # Enable padding token if missing
94
+ # tokenizer.pad_token = tokenizer.eos_token
95
 
96
+ # # Set up Qdrant vector store
97
+ # qdrant_client = QdrantClient(url=QDRANT_URL, api_key = QDRANT_API_KEY)
98
+ # vector_size = 768
99
+ # embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
100
 
101
+ # qdrant_vectorstore = Qdrant(
102
+ # client=qdrant_client,
103
+ # collection_name="arabic_rag_collection",
104
+ # embeddings=embedding
105
+ # )
106
 
107
+ # # Generation config
108
+ # generation_config = GenerationConfig(
109
+ # max_new_tokens=150,
110
+ # temperature=0.2,
111
+ # top_k=20,
112
+ # do_sample=True,
113
+ # top_p=0.7,
114
+ # repetition_penalty=1.3,
115
+ # )
116
 
117
+ # # Set up HuggingFace Pipeline
118
+ # llm_pipeline = pipeline(
119
+ # model=model,
120
+ # tokenizer=tokenizer,
121
+ # task="text-generation",
122
+ # generation_config=generation_config,
123
+ # )
124
 
125
+ # llm = HuggingFacePipeline(pipeline=llm_pipeline)
 
 
 
 
 
 
 
 
126
 
127
+ # # Set up QA Chain
128
+ # qa_chain = RetrievalQA.from_chain_type(
129
+ # llm=llm,
130
+ # retriever=qdrant_vectorstore.as_retriever(search_kwargs={"k": 3}),
131
+ # chain_type="stuff"
132
+ # )
133
+
134
+ # # Generate prompt based on language
135
+ # def generate_prompt(question):
136
+ # lang = detect(question)
137
+ # if lang == "ar":
138
+ # return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
139
+ # وتأكد من ان:
140
+ # - عدم تكرار أي نقطة أو عبارة أو كلمة
141
+ # - وضوح وسلاسة كل نقطة
142
+ # - تجنب الحشو والعبارات الزائدة-
143
+
144
+ # السؤال: {question}
145
+ # الإجابة:
146
+ # """
147
+ # else:
148
+ # return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
149
+ # Question: {question}
150
+ # Answer:"""
151
+
152
+ # # Define Gradio interface function
153
+ # def medical_chatbot(question):
154
+ # formatted_question = generate_prompt(question)
155
+ # answer = qa_chain.run(formatted_question)
156
+ # return answer
157
+
158
+ # # Set up Gradio interface
159
+ # iface = gr.Interface(
160
+ # fn=medical_chatbot,
161
+ # inputs=gr.Textbox(label="Ask a Medical Question", placeholder="Type your question here..."),
162
+ # outputs=gr.Textbox(label="Answer", interactive=False),
163
+ # title="Medical Chatbot",
164
+ # description="Ask medical questions and get detailed answers in Arabic or English.",
165
+ # theme="compact"
166
+ # )
167
+
168
+ # # Launch Gradio interface
169
+ # if __name__ == "__main__":
170
+ # iface.launch()
171
+
172
+
173
+ import gradio as gr
174
+ from langdetect import detect
175
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
176
+ import torch
177
+
178
+ # Load model and tokenizer
179
+ model_name = "FreedomIntelligence/Apollo-7B"
180
+
181
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
182
+ model = AutoModelForCausalLM.from_pretrained(
183
+ model_name,
184
+ torch_dtype=torch.float16,
185
+ device_map="auto"
186
  )
187
 
188
+ tokenizer.pad_token = tokenizer.eos_token
189
 
190
+ # Create generation pipeline
191
+ pipe = TextGenerationPipeline(
192
+ model=model,
193
+ tokenizer=tokenizer,
194
+ device=model.device.index if torch.cuda.is_available() else -1
195
  )
196
 
197
+ # Prompt formatter based on language
198
+ def generate_prompt(message, history):
199
+ lang = detect(message)
200
  if lang == "ar":
201
+ return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
202
  وتأكد من ان:
203
  - عدم تكرار أي نقطة أو عبارة أو كلمة
204
  - وضوح وسلاسة كل نقطة
205
+ - تجنب الحشو والعبارات الزائدة
206
 
207
+ السؤال: {message}
208
+ الإجابة:"""
 
209
  else:
210
+ return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas or restate the question. If information is missing, rely on your prior medical knowledge:
211
+ Question: {message}
212
  Answer:"""
213
 
214
+ # Chat function
215
+ def chat_fn(message, history):
216
+ prompt = generate_prompt(message, history)
217
+ response = pipe(prompt, max_new_tokens=512, temperature=0.7, top_p=0.9)[0]['generated_text']
218
+ answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
219
  return answer
220
 
221
+ # Gradio ChatInterface
222
+ demo = gr.ChatInterface(
223
+ fn=chat_fn,
224
+ title="🩺 Apollo-7B Medical Chatbot",
225
+ description="Multilingual (Arabic & English) medical Q&A chatbot powered by Apollo-7B. No RAG, just fast model inference.",
226
+ examples=[
227
+ "ما هي أعراض ضغط الدم المرتفع؟",
228
+ "What are the side effects of paracetamol?",
229
+ "هل يمكن علاج مرض السكري؟",
230
+ "How does COVID-19 affect the lungs?"
231
+ ],
232
  theme="compact"
233
  )
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  if __name__ == "__main__":
236
+ demo.launch()