""" | |
Core data processing and analysis logic for the PharmaCircle AI Data Analyst. | |
This module orchestrates the main analysis workflow: | |
1. Takes a user's natural language query. | |
2. Uses the LLM to generate a structured analysis plan. | |
3. Executes parallel queries against Solr for quantitative and qualitative data. | |
4. Generates a data visualization using the LLM. | |
5. Synthesizes the findings into a comprehensive, user-facing report. | |
""" | |
import json
import re
import datetime
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import concurrent.futures
import copy
import google.generativeai as genai
import urllib.parse  # Explicit submodule import: `import urllib` alone does not reliably expose urllib.parse
import pysolr
import config  # Import the config module to access remote host details
import tiktoken
from llm_prompts import (
    get_analysis_plan_prompt,
    get_synthesis_report_prompt,
    get_visualization_code_prompt
)
from extract_results import get_search_list_params


def parse_suggestions_from_report(report_text):
    """Extracts numbered suggestions from the report's markdown text."""
    suggestions_match = re.search(
        r"### (?:Deeper Dive: Suggested Follow-up Analyses|Suggestions for Further Exploration)\s*\n(.*?)$",
        report_text, re.DOTALL | re.IGNORECASE)
    if not suggestions_match:
        return []
    suggestions_text = suggestions_match.group(1)
    suggestions = re.findall(r"^\s*\d+\.\s*(.*)", suggestions_text, re.MULTILINE)
    return [s.strip() for s in suggestions]
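
# Illustrative usage (the report text below is hypothetical, not from a real run):
#
#     report = (
#         "### Suggestions for Further Exploration\n"
#         "1. Break down approvals by therapeutic area.\n"
#         "2. Compare trends across the last five years.\n"
#     )
#     parse_suggestions_from_report(report)
#     # -> ['Break down approvals by therapeutic area.',
#     #     'Compare trends across the last five years.']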


def llm_generate_analysis_plan_with_history(llm_model, natural_language_query, chat_history):
    """
    Generates a complete analysis plan from a user query, considering chat history
    and dynamic field suggestions from an external API.
    """
    search_fields, search_name, field_mappings = [], "", {}
    intent = None
    try:
        intent, search_fields, search_name, field_mappings = get_search_list_params(natural_language_query)
        print(f"API returned intent: '{intent}', core: '{search_name}' with {len(search_fields)} fields and {len(field_mappings)} mappings.")
        if intent != 'search_list':
            print(f"API returned intent '{intent}' which is not 'search_list'. Aborting analysis.")
            return None, None, None, intent, None, None, None
    except Exception as e:
        print(f"Warning: Could not retrieve dynamic search fields. Aborting analysis. Error: {e}")
        return None, [], None, 'api_error', None, None, None
    core_name = search_name if search_name else 'news'
    mapped_search_fields = []
    if search_fields and field_mappings:
        for field in search_fields:
            original_name = field.get('field_name')
            mapped_field = field.copy()
            if original_name in field_mappings:
                mapped_field['field_name'] = field_mappings[original_name]
                print(f"Mapped field '{original_name}' to '{mapped_field['field_name']}'")
            mapped_search_fields.append(mapped_field)
    else:
        mapped_search_fields = search_fields
    prompt = get_analysis_plan_prompt(natural_language_query, chat_history, mapped_search_fields, core_name)
    try:
        response = llm_model.generate_content(prompt)
        # Token counts are approximations: tiktoken's gpt-4 encoding is used as a
        # stand-in tokenizer, since the underlying model here is a Gemini model.
        encoding = tiktoken.encoding_for_model("gpt-4")
        input_token_count = len(encoding.encode(prompt))
        output_token_count = len(encoding.encode(response.text))
        total_token_count = input_token_count + output_token_count
        cleaned_text = re.sub(r'```json\s*|\s*```', '', response.text, flags=re.MULTILINE | re.DOTALL).strip()
        plan = json.loads(cleaned_text)
        return plan, mapped_search_fields, core_name, intent, input_token_count, output_token_count, total_token_count
    except json.JSONDecodeError as e:
        raw_response_text = response.text if 'response' in locals() else 'N/A'
        print(f"Error decoding JSON from LLM response: {e}\nRaw Response:\n{raw_response_text}")
        return None, mapped_search_fields, core_name, intent, None, None, None
    except Exception as e:
        raw_response_text = response.text if 'response' in locals() else 'N/A'
        print(f"Error in llm_generate_analysis_plan_with_history: {e}\nRaw Response:\n{raw_response_text}")
        return None, mapped_search_fields, core_name, intent, None, None, None
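
# For reference, the executors below consume the plan via 'query_filter',
# 'quantitative_request' -> 'json.facet', and 'qualitative_request'. A minimal
# illustrative plan (field names and values are hypothetical, not from a real run):
#
#     {
#         "query_filter": "company_name:Acme AND year:2023",
#         "quantitative_request": {
#             "json.facet": {"by_phase": {"type": "terms", "field": "phase", "limit": 10}}
#         },
#         "qualitative_request": {"group": "true", "group.field": "phase"}
#     }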


def execute_quantitative_query(solr_client, plan):
    """Executes the facet query to get aggregate data."""
    if not plan or 'quantitative_request' not in plan or 'json.facet' not in plan.get('quantitative_request', {}):
        print("Skipping quantitative query due to incomplete plan.")
        return None, None
    try:
        params = {
            "q": plan.get('query_filter', '*:*'),  # fall back to Solr's match-all query
            "rows": 0,
            "json.facet": json.dumps(plan['quantitative_request']['json.facet'])
        }
        base_url = f"{solr_client.url}/select"
        query_string = urllib.parse.urlencode(params)
        full_url = f"{base_url}?{query_string}"
        # Create the public-facing URL for display
        public_url = full_url.replace(f'http://127.0.0.1:{config.LOCAL_BIND_PORT}', f'http://{config.REMOTE_SOLR_HOST}:{config.REMOTE_SOLR_PORT}')
        print(f"[DEBUG] Solr QUANTITATIVE query URL (PUBLIC): {public_url}")
        results = solr_client.search(**params)
        return results.raw_response.get("facets", {}), public_url
    except pysolr.SolrError as e:
        print(f"Solr Error in quantitative query on core {solr_client.url}: {e}")
        return None, None
    except Exception as e:
        print(f"Unexpected error in quantitative query: {e}")
        return None, None
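
# The "facets" block returned above follows Solr's JSON Facet API response
# shape, e.g. (facet name and bucket values hypothetical):
#
#     {
#         "count": 1234,
#         "by_phase": {
#             "buckets": [
#                 {"val": "Phase 2", "count": 321},
#                 {"val": "Phase 3", "count": 210}
#             ]
#         }
#     }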


def execute_qualitative_query(solr_client, plan):
    """Executes the grouping query to get the best example docs."""
    if not plan or 'qualitative_request' not in plan:
        print("Skipping qualitative query due to incomplete plan.")
        return None, None
    try:
        qual_request = copy.deepcopy(plan['qualitative_request'])
        params = {
            "q": plan.get('query_filter', '*:*'),  # fall back to Solr's match-all query
            "rows": 5,
            "fl": "*,score",
            **qual_request
        }
        base_url = f"{solr_client.url}/select"
        query_string = urllib.parse.urlencode(params)
        full_url = f"{base_url}?{query_string}"
        # Create the public-facing URL for display
        public_url = full_url.replace(f'http://127.0.0.1:{config.LOCAL_BIND_PORT}', f'http://{config.REMOTE_SOLR_HOST}:{config.REMOTE_SOLR_PORT}')
        print(f"[DEBUG] Solr QUALITATIVE query URL (PUBLIC): {public_url}")
        results = solr_client.search(**params)
        return results.grouped, public_url
    except pysolr.SolrError as e:
        print(f"Solr Error in qualitative query on core {solr_client.url}: {e}")
        return None, None
    except Exception as e:
        print(f"Unexpected error in qualitative query: {e}")
        return None, None
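

# A minimal sketch of the "parallel queries" step from the module docstring,
# using the concurrent.futures import at the top of this file. The helper name
# and return shape are assumptions; the real orchestration layer may differ.
def execute_queries_in_parallel(solr_client, plan):
    """Illustrative sketch (not part of the original workflow code): runs the
    quantitative and qualitative queries above concurrently."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        quant_future = pool.submit(execute_quantitative_query, solr_client, plan)
        qual_future = pool.submit(execute_qualitative_query, solr_client, plan)
        facets, quant_url = quant_future.result()
        grouped, qual_url = qual_future.result()
    return facets, quant_url, grouped, qual_url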


def llm_synthesize_enriched_report_stream(llm_model, query, quantitative_data, qualitative_data, plan):
    """
    Generates an enriched report by synthesizing quantitative aggregates
    and qualitative examples, and streams the result.
    """
    prompt = get_synthesis_report_prompt(query, quantitative_data, qualitative_data, plan)
    try:
        response_stream = llm_model.generate_content(prompt, stream=True)
        response_text = ""
        for chunk in response_stream:
            yield {"text": chunk.text, "tokens": None}
            response_text += chunk.text
        encoding = tiktoken.encoding_for_model("gpt-4")
        input_token_count = len(encoding.encode(prompt))
        output_token_count = len(encoding.encode(response_text))
        tokens = {
            "input": input_token_count,
            "output": output_token_count,
            "total": input_token_count + output_token_count,
        }
        yield {"text": None, "tokens": tokens}
    except Exception as e:
        print(f"Error in llm_synthesize_enriched_report_stream: {e}")
        yield {"text": "Sorry, an error occurred while generating the report. Please check the logs for details.", "tokens": None}


def llm_generate_visualization_code(llm_model, query_context, facet_data):
    """Generates Python code for visualization based on query and data."""
    prompt = get_visualization_code_prompt(query_context, facet_data)
    try:
        generation_config = genai.types.GenerationConfig(temperature=0)
        response = llm_model.generate_content(prompt, generation_config=generation_config)
        encoding = tiktoken.encoding_for_model("gpt-4")
        input_token_count = len(encoding.encode(prompt))
        output_token_count = len(encoding.encode(response.text))
        total_token_count = input_token_count + output_token_count
        code = re.sub(r'^```python\s*|```$', '', response.text, flags=re.MULTILINE)
        return code, input_token_count, output_token_count, total_token_count
    except Exception as e:
        raw_response_text = response.text if 'response' in locals() else 'N/A'
        print(f"Error in llm_generate_visualization_code: {e}\nRaw response: {raw_response_text}")
        # Return a 4-tuple so callers that unpack the result do not raise on error.
        return None, None, None, None


def execute_viz_code_and_get_path(viz_code, facet_data):
    """Executes visualization code and returns the path to the saved plot image."""
    if not viz_code:
        return None
    # --- SECURITY WARNING ---
    # The following call executes code generated by an LLM. This is a security
    # risk and should be handled with extreme care in a production environment.
    # Ideally, this code should be run in a sandboxed environment.
    print("\n--- WARNING: Executing LLM-generated code. ---")
    try:
        os.makedirs('/tmp/plots', exist_ok=True)
        plot_path = f"/tmp/plots/plot_{datetime.datetime.now().timestamp()}.png"
        # Create a restricted global environment for execution
        exec_globals = {'facet_data': facet_data, 'plt': plt, 'sns': sns, 'pd': pd}
        exec(viz_code, exec_globals)
        fig = exec_globals.get('fig')
        if fig:
            fig.savefig(plot_path, bbox_inches='tight')
            plt.close(fig)
            print("--- LLM-generated code executed successfully. ---")
            return plot_path
        else:
            print("--- LLM-generated code did not produce a 'fig' object. ---")
            return None
    except Exception as e:
        print("\n--- ERROR executing visualization code: ---")
        print(f"Error: {e}")
        print(f"--- Code ---\n{viz_code}")
        print("-----------------------------------------")
        return None
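
# For reference, code passed to the executor above must bind a Matplotlib figure
# to the name `fig`, and runs with facet_data, plt, sns, and pd injected as
# globals. A minimal hypothetical example of conforming code (facet name and
# bucket fields assumed, matching the JSON Facet response shape shown earlier):
#
#     buckets = facet_data["by_phase"]["buckets"]
#     df = pd.DataFrame(buckets)
#     fig, ax = plt.subplots(figsize=(8, 5))
#     sns.barplot(data=df, x="val", y="count", ax=ax)
#     ax.set_xlabel("Phase")
#     ax.set_ylabel("Document count")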