import asyncio
import json
import operator
import os
import pickle
import time
import traceback
from dataclasses import dataclass, fields
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, NotRequired, Optional, TypedDict

import streamlit as st
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from scipy.spatial.distance import cosine
from tavily import TavilyClient, AsyncTavilyClient

from langchain.chat_models import init_chat_model
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import ArxivRetriever, BM25Retriever
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.vectorstores import VectorStore
from langgraph.constants import Send
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command

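# ---------------------------------------------------------------------------
# Startup diagnostics
# These helpers only print information: the working directory and a probe of
# common container paths, so missing data files can be diagnosed from the
# deployment logs.
# ---------------------------------------------------------------------------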
def debug_startup_info():
    """Print debug information at startup to help identify file locations."""
    print("=" * 50)
    print("DEBUG STARTUP INFO")
    print("=" * 50)

    print(f"Current working directory: {os.getcwd()}")

    print("\nChecking for data directory:")
    if os.path.exists("data"):
        print("Found 'data' directory in current directory")
        print(f"Contents: {os.listdir('data')}")
        if os.path.exists("data/processed_data"):
            print(f"Contents of data/processed_data: {os.listdir('data/processed_data')}")

    common_paths = [
        "/data",
        "/repository",
        "/app",
        "/app/data",
        "/repository/data",
        "/app/repository",
        "AB_AI_RAG_Agent/data",
    ]
    print("\nChecking common paths:")
    for path in common_paths:
        if os.path.exists(path):
            print(f"Found path: {path}")
            print(f"Contents: {os.listdir(path)}")

            processed_path = os.path.join(path, "processed_data")
            if os.path.exists(processed_path):
                print(f"Found processed_data at: {processed_path}")
                print(f"Contents: {os.listdir(processed_path)}")
    print("=" * 50)


debug_startup_info()

DEBUG_FILE_PATHS = True


def debug_paths():
    """Lighter-weight path probe; not called anywhere, kept for manual debugging."""
    if DEBUG_FILE_PATHS:
        print("Current working directory:", os.getcwd())
        print("Files in /data:", os.listdir("/data") if os.path.exists("/data") else "Not found")
        print("Files in /data/processed_data:", os.listdir("/data/processed_data") if os.path.exists("/data/processed_data") else "Not found")
        for path in ["/repository", "/app", "/app/data"]:
            if os.path.exists(path):
                print(f"Files in {path}:", os.listdir(path))


load_dotenv()

required_keys = ["COHERE_API_KEY", "ANTHROPIC_API_KEY", "TAVILY_API_KEY"]
missing_keys = [key for key in required_keys if not os.environ.get(key)]
if missing_keys:
    st.error(f"Missing required API keys: {', '.join(missing_keys)}. Please set them as environment variables.")
    st.stop()

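# ---------------------------------------------------------------------------
# In-memory vector store
# Chunks are embedded offline and shipped as pickles; this class implements
# just enough of the VectorStore interface (similarity_search + from_texts)
# for cosine-similarity retrieval over those precomputed embeddings.
# ---------------------------------------------------------------------------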
class CustomVectorStore(VectorStore):
    def __init__(self, embedded_docs, embedding_model):
        self.embedded_docs = embedded_docs
        self.embedding_model = embedding_model

    def similarity_search_with_score(self, query, k=5):
        query_embedding = self.embedding_model.embed_query(query)

        # Score every stored chunk by cosine similarity to the query.
        results = []
        for doc in self.embedded_docs:
            similarity = 1 - cosine(query_embedding, doc["embedding"])
            results.append((doc, similarity))

        results.sort(key=lambda x: x[1], reverse=True)

        documents_with_scores = []
        for doc, score in results[:k]:
            document = Document(
                page_content=doc["text"],
                metadata=doc["metadata"]
            )
            documents_with_scores.append((document, score))
        return documents_with_scores

    def similarity_search(self, query, k=5):
        docs_with_scores = self.similarity_search_with_score(query, k)
        return [doc for doc, _ in docs_with_scores]

    @classmethod
    def from_texts(cls, texts, embedding, metadatas=None, **kwargs):
        """Implement the required abstract method from the VectorStore base class."""
        embeddings = embedding.embed_documents(texts)

        embedded_docs = []
        for i, (text, embedding_vector) in enumerate(zip(texts, embeddings)):
            metadata = metadatas[i] if metadatas else {}
            embedded_docs.append({
                "text": text,
                "embedding": embedding_vector,
                "metadata": metadata
            })

        return cls(embedded_docs, embedding)

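# Usage sketch with hypothetical data (the model name is only an example):
#
#   emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   store = CustomVectorStore.from_texts(
#       ["chunk one", "chunk two"], emb,
#       metadatas=[{"source": "a.pdf"}, {"source": "b.pdf"}],
#   )
#   store.similarity_search("example query", k=1)  # -> [Document(...)]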
def find_processed_data():
    """Find the processed_data directory path."""
    possible_paths = [
        "data/processed_data",
        "../data/processed_data",
        "/data/processed_data",
        "/app/data/processed_data",
        "./data/processed_data",
        "/repository/data/processed_data",
        "AB_AI_RAG_Agent/data/processed_data",
    ]
    for path in possible_paths:
        if os.path.exists(path):
            required_files = ["chunks.pkl", "bm25_retriever.pkl", "embedding_info.json", "embedded_docs.pkl"]
            if all(os.path.exists(os.path.join(path, f)) for f in required_files):
                print(f"Found processed_data at: {path}")
                return path
    raise FileNotFoundError("Could not find processed_data directory with required files")


@st.cache_resource
def initialize_vectorstore():
    """Initialize the hybrid retriever system with Cohere reranking."""
    try:
        processed_data_path = find_processed_data()

        with open(os.path.join(processed_data_path, "chunks.pkl"), "rb") as f:
            documents = pickle.load(f)

        with open(os.path.join(processed_data_path, "bm25_retriever.pkl"), "rb") as f:
            bm25_retriever = pickle.load(f)
        bm25_retriever.k = 5

        with open(os.path.join(processed_data_path, "embedding_info.json"), "r") as f:
            embedding_info = json.load(f)

        with open(os.path.join(processed_data_path, "embedded_docs.pkl"), "rb") as f:
            embedded_docs = pickle.load(f)

        # Must be the same model that embedded the documents offline.
        embedding_model = HuggingFaceEmbeddings(
            model_name=embedding_info["model_name"]
        )

        vectorstore = CustomVectorStore(embedded_docs, embedding_model)
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

        # Dense (cosine) and sparse (BM25) retrieval, weighted equally.
        hybrid_retriever = EnsembleRetriever(
            retrievers=[vector_retriever, bm25_retriever],
            weights=[0.5, 0.5],
        )

        cohere_rerank = CohereRerank(
            model="rerank-english-v3.0",
            top_n=5,
        )

        # Rerank the hybrid candidates so only the 5 best chunks survive.
        reranker = ContextualCompressionRetriever(
            base_compressor=cohere_rerank,
            base_retriever=hybrid_retriever
        )

        print("Successfully initialized retriever system")
        return reranker, documents
    except Exception as e:
        st.error(f"Error initializing retrievers: {str(e)}")
        st.stop()

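# Usage sketch (hypothetical query; the reranker behaves like any retriever):
#
#   reranker, documents = initialize_vectorstore()
#   top_chunks = reranker.invoke("minimum detectable effect")  # -> list[Document]


# ---------------------------------------------------------------------------
# Prompt templates
# ---------------------------------------------------------------------------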

report_planner_query_writer_instructions = """You are performing research for a report.

<Report topic>
{topic}
</Report topic>

<Report organization>
{report_organization}
</Report organization>

<Task>
Your goal is to generate {number_of_queries} web search queries that will help gather information for planning the report sections.

The queries should:

1. Be related to the Report topic
2. Help satisfy the requirements specified in the report organization

Make the queries specific enough to find high-quality, relevant sources while covering the breadth needed for the report structure.
</Task>
"""

report_planner_instructions = """I want a plan for a report that is concise and focused.

<Report topic>
The topic of the report is:
{topic}
</Report topic>

<Report organization>
The report should follow this organization:
{report_organization}
</Report organization>

<Context>
Here is context to use to plan the sections of the report:
{context}
</Context>

<Task>
Generate a list of sections for the report. Your plan should be tight and focused with NO overlapping sections or unnecessary filler.

For example, a good report structure might look like:
1/ intro
2/ overview of topic A
3/ overview of topic B
4/ comparison between A and B
5/ conclusion

Each section should have the fields:

- Name - Name for this section of the report.
- Description - Brief overview of the main topics covered in this section.
- Research - Whether to perform web research for this section of the report.
- Content - The content of the section, which you will leave blank for now.

Integration guidelines:
- Include examples and implementation details within main topic sections, not as separate sections
- Ensure each section has a distinct purpose with no content overlap
- Combine related concepts rather than separating them

Before submitting, review your structure to ensure it has no redundant sections and follows a logical flow.
</Task>
"""

query_writer_instructions = """You are an expert technical writer crafting targeted web search queries that will gather comprehensive information for writing a technical report section.

<Report topic>
{topic}
</Report topic>

<Section topic>
{section_topic}
</Section topic>

<Task>
Your goal is to generate {number_of_queries} search queries that will help gather comprehensive information about the section topic.

The queries should:

1. Be related to the topic
2. Examine different aspects of the topic

Make the queries specific enough to find high-quality, relevant sources.
</Task>
"""

section_writer_instructions = """You are an expert technical writer crafting one section of a technical report.

<Report topic>
{topic}
</Report topic>

<Section name>
{section_name}
</Section name>

<Section topic>
{section_topic}
</Section topic>

<Existing section content (if populated)>
{section_content}
</Existing section content>

<Source material>
{context}
</Source material>

<Guidelines for writing>
1. If the existing section content is not populated, write a new section from scratch.
2. If the existing section content is populated, write a new section that synthesizes the existing section content with the Source material. If there is a discrepancy between the existing section content and the Source material, use the existing section content as the primary source. The purpose of the Source material is to provide additional information and context to help fill the gaps in the existing section content.
</Guidelines for writing>

<Length and style>
- Strict 150-200 word limit
- No marketing language
- Technical focus
- Write in simple, clear language
- Start with your most important insight in **bold**
- Use short paragraphs (2-3 sentences max)
- Use ## for section title (Markdown format)
- Only use ONE structural element IF it helps clarify your point:
  * Either a focused table comparing 2-3 key items (using Markdown table syntax)
  * Or a short list (3-5 items) using proper Markdown list syntax:
    - Use `*` or `-` for unordered lists
    - Use `1.` for ordered lists
    - Ensure proper indentation and spacing
</Length and style>

<Quality checks>
- Exactly 150-200 words (excluding title and sources)
- Careful use of only ONE structural element (table or list) and only if it helps clarify your point
- One specific example / case study
- Starts with bold insight
- No preamble prior to creating the section content
- If there is a discrepancy between the existing section content and the Source material, use the existing section content as the primary source. The purpose of the Source material is to provide additional information and context to help fill the gaps in the existing section content.
</Quality checks>
"""

section_grader_instructions = """Review a report section relative to the specified topic:

<Report topic>
{topic}
</Report topic>

<section topic>
{section_topic}
</section topic>

<section content>
{section}
</section content>

<search type>
{current_iteration}
</search type>

<task>
Evaluate whether the section content adequately addresses the section topic.

If the section content does not adequately address the section topic, generate {number_of_follow_up_queries} follow-up search queries to gather missing information. Note that if search type is 1, your follow-up search queries will be used to search Arxiv for academic papers. If search type is 2 or more, your follow-up search queries will be used to search Tavily for general web search.
</task>

<format>
grade: Literal["pass","fail"] = Field(
    description="Evaluation result indicating whether the response meets requirements ('pass') or needs revision ('fail')."
)
follow_up_queries: List[SearchQuery] = Field(
    description="List of follow-up search queries.",
)
</format>
"""

final_section_writer_instructions = """You are an expert technical writer crafting a section that synthesizes information from the rest of the report.

<Report topic>
{topic}
</Report topic>

<Section name>
{section_name}
</Section name>

<Section topic>
{section_topic}
</Section topic>

<Available report content>
{context}
</Available report content>

<Task>
1. Section-Specific Approach:

For Introduction:
- Use # for report title (Markdown format)
- 50-100 word limit
- Write in simple and clear language
- Focus on the core motivation for the report in 1-2 paragraphs
- Use a clear narrative arc to introduce the report
- Include NO structural elements (no lists or tables)
- No sources section needed

For Conclusion/Summary:
- Use ## for section title (Markdown format)
- 100-150 word limit
- For comparative reports:
  * Must include a focused comparison table using Markdown table syntax
  * Table should distill insights from the report
  * Keep table entries clear and concise
- For non-comparative reports:
  * Only use ONE structural element IF it helps distill the points made in the report:
  * Either a focused table comparing items present in the report (using Markdown table syntax)
  * Or a short list using proper Markdown list syntax:
    - Use `*` or `-` for unordered lists
    - Use `1.` for ordered lists
    - Ensure proper indentation and spacing
- End with specific next steps or implications
- No sources section needed

2. Writing Approach:
- Use concrete details over general statements
- Make every word count
- Focus on your single most important point
</Task>

<Quality Checks>
- For introduction: 50-100 word limit, # for report title, no structural elements, no sources section
- For conclusion: 100-150 word limit, ## for section title, only ONE structural element at most, no sources section
- Markdown format
- Do not include word count or any preamble in your response
</Quality Checks>"""

initial_AB_topic_check_instructions = """You are checking if a given topic is related to A/B testing (even vaguely, e.g. statistics, A/B testing, experimentation, etc.).

<Topic>
{topic}
</Topic>

<Task>
Check if the topic is related to A/B testing (even vaguely, e.g. statistics, A/B testing, experimentation, etc.).

If the topic is related to A/B testing (even vaguely), return 'true'.
If the topic is not related to A/B testing, return 'false'.
</Task>
"""

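# ---------------------------------------------------------------------------
# Structured-output schemas and graph state
# The Pydantic models constrain LLM responses (via with_structured_output);
# the TypedDicts define the LangGraph state that flows between nodes.
# ---------------------------------------------------------------------------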
class Section(BaseModel):
    name: str = Field(
        description="Name for this section of the report.",
    )
    description: str = Field(
        description="Brief overview of the main topics and concepts to be covered in this section.",
    )
    research: bool = Field(
        description="Whether to perform web research for this section of the report."
    )
    content: str = Field(
        description="The content of the section."
    )
    sources: str = Field(
        default="",
        description="All sources used for this section"
    )


class Sections(BaseModel):
    sections: List[Section] = Field(
        description="Sections of the report.",
    )


class SearchQuery(BaseModel):
    search_query: str = Field(None, description="Query for web search.")


class Queries(BaseModel):
    queries: List[SearchQuery] = Field(
        description="List of search queries.",
    )


class Feedback(BaseModel):
    grade: Literal["pass", "fail"] = Field(
        description="Evaluation result indicating whether the response meets requirements ('pass') or needs revision ('fail')."
    )
    follow_up_queries: List[SearchQuery] = Field(
        description="List of follow-up search queries.",
    )


class ReportStateInput(TypedDict):
    topic: str


class ReportStateOutput(TypedDict):
    final_report: str


class ReportState(TypedDict):
    topic: str
    sections: list[Section]
    # operator.add lets parallel section-builder branches append their results
    # instead of overwriting each other.
    completed_sections: Annotated[list, operator.add]
    report_sections_from_research: str
    final_report: str
    ab_testing_check: NotRequired[bool]


class SectionState(TypedDict):
    topic: str
    section: Section
    search_iterations: int
    search_queries: list[SearchQuery]
    source_str: str        # sources from the most recent search iteration
    source_str_all: str    # accumulated sources across all iterations
    report_sections_from_research: str
    completed_sections: list[Section]


class SectionOutputState(TypedDict):
    completed_sections: list[Section]

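# ---------------------------------------------------------------------------
# Report-generation system
# Everything below is defined inside initialize_report_system() so that the
# compiled LangGraph closes over the cached reranker and is itself cached
# by Streamlit.
# ---------------------------------------------------------------------------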
@st.cache_resource
def initialize_report_system(_reranker):
    """Initialize the A/B testing report system."""
    reranker = _reranker

    tavily_client = TavilyClient()
    tavily_async_client = AsyncTavilyClient()

    def get_config_value(value):
        """Handle both string and enum cases of configuration values."""
        return value if isinstance(value, str) else value.value

    def get_search_params(search_api: str, search_api_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Filters the search_api_config dictionary to include only parameters accepted by the specified search API.

        Args:
            search_api (str): The search API identifier (e.g., "tavily").
            search_api_config (Optional[Dict[str, Any]]): The configuration dictionary for the search API.

        Returns:
            Dict[str, Any]: A dictionary of parameters to pass to the search function.
        """
        SEARCH_API_PARAMS = {
            "rag": [],
            "arxiv": ["load_max_docs", "get_full_documents", "load_all_available_meta"],
            "tavily": [],
        }

        accepted_params = SEARCH_API_PARAMS.get(search_api, [])

        if not search_api_config:
            return {}

        return {k: v for k, v in search_api_config.items() if k in accepted_params}

    def get_next_search_type(search_iterations):
        if search_iterations == 0:
            return "RAG search (internal A/B testing knowledge base)"
        elif search_iterations == 1:
            return "ArXiv web search (search academic papers on arXiv)"
        else:
            return "Tavily web search (general web sources)"

    def deduplicate_and_format_sources(search_response, max_tokens_per_source, include_raw_content=True, search_iterations=None, return_has_sources=False):
        """
        Takes a list of search responses and formats them into a readable string.
        Limits the raw_content to approximately max_tokens_per_source tokens.

        Args:
            search_response: List of search response dicts, each containing:
                - query: str
                - results: List of dicts with fields:
                    - title: str
                    - url: str
                    - content: str
                    - raw_content: str|None
                    - score: float
            max_tokens_per_source: int
            include_raw_content: bool
            search_iterations: int, optional
                If 0, deduplicate by title (for RAG results) and show only title.
                Otherwise, deduplicate by URL (for web/arxiv results) and show title + URL.
            return_has_sources: bool, optional
                If True, returns (formatted_string, has_sources_bool).
                If False, returns just formatted_string.

        Returns:
            str OR tuple:
                - If return_has_sources=False: formatted string
                - If return_has_sources=True: (formatted_string, has_sources_bool)
        """
        sources_list = []
        for response in search_response:
            sources_list.extend(response['results'])

        if not sources_list:
            empty_result = ""
            return (empty_result, False) if return_has_sources else empty_result

        # RAG chunks have no URL, so deduplicate them by title instead.
        if search_iterations == 0:
            unique_sources = {source['title']: source for source in sources_list}
        else:
            unique_sources = {source['url']: source for source in sources_list}

        has_unique_sources = bool(unique_sources)

        if not unique_sources:
            empty_result = ""
            return (empty_result, False) if return_has_sources else empty_result

        formatted_text = ""
        for i, source in enumerate(unique_sources.values(), 1):
            formatted_text += f"#### {source['title']}\n\n"

            if search_iterations != 0:
                formatted_text += f"#### URL: {source['url']}\n\n"

            if include_raw_content:
                # Rough heuristic: one token is about four characters.
                char_limit = max_tokens_per_source * 4

                raw_content = source.get('raw_content', '')
                if raw_content is None:
                    raw_content = ''
                    print(f"Warning: No raw_content found for source {source['url']}")
                if len(raw_content) > char_limit:
                    raw_content = raw_content[:char_limit] + "... [truncated]"
                formatted_text += f"#### Full source content limited to {max_tokens_per_source} tokens: {raw_content}\n\n"

        final_result = formatted_text.strip()
        return (final_result, has_unique_sources) if return_has_sources else final_result

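    # Illustrative output for one web source (all values hypothetical):
    #
    #   #### Some Paper Title
    #
    #   #### URL: https://example.org/paper
    #
    #   #### Full source content limited to 1500 tokens: <up to ~6000 chars>... [truncated]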
    def format_sections(sections: list[Section]) -> str:
        """Format a list of sections into a string."""
        formatted_str = ""
        for idx, section in enumerate(sections, 1):
            # The '=' rows render as divider lines in the output.
            formatted_str += f"""
{'='*60}
Section {idx}: {section.name}
{'='*60}
Description:
{section.description}
Requires Research:
{section.research}

Content:
{section.content if section.content else '[Not yet written]'}

"""
        return formatted_str

    async def tavily_search_async(search_queries):
        """
        Performs concurrent web searches using the Tavily API.

        Args:
            search_queries (List[SearchQuery]): List of search queries to process

        Returns:
            List[dict]: List of search responses from Tavily API, one per query. Each response has format:
                {
                    'query': str,                  # The original search query
                    'follow_up_questions': None,
                    'answer': None,
                    'images': list,
                    'results': [                   # List of search results
                        {
                            'title': str,          # Title of the webpage
                            'url': str,            # URL of the result
                            'content': str,        # Summary/snippet of content
                            'score': float,        # Relevance score
                            'raw_content': str|None  # Full page content if available
                        },
                        ...
                    ]
                }
        """
        search_tasks = []
        for query in search_queries:
            search_tasks.append(
                tavily_async_client.search(
                    query,
                    max_results=5,
                    include_raw_content=True,
                    topic="general"
                )
            )

        # Run all queries concurrently.
        search_docs = await asyncio.gather(*search_tasks)

        return search_docs

    async def arxiv_search_async(search_queries, load_max_docs=5, get_full_documents=False, load_all_available_meta=True):
        """
        Performs searches on arXiv using the ArxivRetriever, one query at a time to respect rate limits.

        Args:
            search_queries (List[str]): List of search queries or article IDs
            load_max_docs (int, optional): Maximum number of documents to return per query. Default is 5.
            get_full_documents (bool, optional): Whether to fetch full text of documents. Default is False.
            load_all_available_meta (bool, optional): Whether to load all available metadata. Default is True.

        Returns:
            List[dict]: List of search responses from arXiv, one per query. Each response has format:
                {
                    'query': str,                  # The original search query
                    'follow_up_questions': None,
                    'answer': None,
                    'images': [],
                    'results': [                   # List of search results
                        {
                            'title': str,          # Title of the paper
                            'url': str,            # URL (Entry ID) of the paper
                            'content': str,        # Formatted summary with metadata
                            'score': float,        # Relevance score (approximated)
                            'raw_content': str|None  # Full paper content if available
                        },
                        ...
                    ]
                }
        """
        print(f"[DEBUG] Starting ArXiv search with {len(search_queries)} queries: {[str(q) for q in search_queries]}")

        async def process_single_query(query):
            print(f"[DEBUG] Processing ArXiv query: {query}")
            try:
                print(f"[DEBUG] Creating ArxivRetriever with params: load_max_docs={load_max_docs}, get_full_documents={get_full_documents}, load_all_available_meta={load_all_available_meta}")

                retriever = ArxivRetriever(
                    load_max_docs=load_max_docs,
                    get_full_documents=get_full_documents,
                    load_all_available_meta=load_all_available_meta
                )

                print("[DEBUG] ArxivRetriever created successfully")

                # ArxivRetriever is synchronous, so run it in a thread executor.
                loop = asyncio.get_running_loop()
                print(f"[DEBUG] About to invoke retriever for query: {query}")
                docs = await loop.run_in_executor(None, lambda: retriever.invoke(query))

                print(f"[DEBUG] ArXiv query '{query}' returned {len(docs)} documents")

                if docs:
                    print(f"[DEBUG] First document metadata keys: {list(docs[0].metadata.keys())}")
                    print(f"[DEBUG] First document has page_content: {bool(docs[0].page_content)}")
                else:
                    print(f"[DEBUG] No documents returned for query: {query}")

                results = []

                # Assign linearly decreasing scores to approximate arXiv's ranking.
                base_score = 1.0
                score_decrement = 1.0 / (len(docs) + 1) if docs else 0

                for i, doc in enumerate(docs):
                    normalized_metadata = {k.lower().replace(' ', '_'): v for k, v in doc.metadata.items()}

                    print(f"[DEBUG] Processing doc {i+1}: {normalized_metadata.get('title', 'No title')}")

                    url = normalized_metadata.get('entry_id', '')
                    title = normalized_metadata.get('title', '')
                    authors = normalized_metadata.get('authors', '')
                    published = normalized_metadata.get('published')

                    summary = normalized_metadata.get('summary', '')
                    if not summary and doc.page_content:
                        summary = doc.page_content.strip()

                    content_parts = []
                    if summary:
                        content_parts.append(f"Summary: {summary}")
                    if authors:
                        content_parts.append(f"Authors: {authors}")

                    if published:
                        published_str = published.isoformat() if hasattr(published, 'isoformat') else str(published)
                        content_parts.append(f"Published: {published_str}")

                    primary_category = normalized_metadata.get('primary_category', '')
                    if primary_category:
                        content_parts.append(f"Primary Category: {primary_category}")

                    categories = normalized_metadata.get('categories', [])
                    if categories:
                        if isinstance(categories, list):
                            content_parts.append(f"Categories: {', '.join(categories)}")
                        else:
                            content_parts.append(f"Categories: {categories}")

                    comment = normalized_metadata.get('comment', '')
                    if comment:
                        content_parts.append(f"Comment: {comment}")

                    journal_ref = normalized_metadata.get('journal_ref', '')
                    if journal_ref:
                        content_parts.append(f"Journal Reference: {journal_ref}")

                    doi = normalized_metadata.get('doi', '')
                    if doi:
                        content_parts.append(f"DOI: {doi}")

                    links = normalized_metadata.get('links', [])
                    if links:
                        for link in links:
                            if 'pdf' in str(link).lower():
                                content_parts.append(f"PDF: {link}")
                                break

                    content = "\n".join(content_parts)

                    result = {
                        'title': title,
                        'url': url,
                        'content': content,
                        'score': base_score - (i * score_decrement),
                        'raw_content': doc.page_content if get_full_documents else None
                    }
                    results.append(result)

                print(f"[DEBUG] Query '{query}' processed successfully, returning {len(results)} results")

                return {
                    'query': query,
                    'follow_up_questions': None,
                    'answer': None,
                    'images': [],
                    'results': results
                }
            except Exception as e:
                print(f"[DEBUG ERROR] Error processing arXiv query '{query}': {str(e)}")
                print(f"[DEBUG ERROR] Exception type: {type(e).__name__}")
                print(f"[DEBUG ERROR] Full traceback: {traceback.format_exc()}")
                return {
                    'query': query,
                    'follow_up_questions': None,
                    'answer': None,
                    'images': [],
                    'results': [],
                    'error': str(e)
                }

        # Process queries sequentially with delays to respect arXiv's rate limits.
        search_docs = []
        for i, query in enumerate(search_queries):
            try:
                if i > 0:
                    print(f"[DEBUG] Adding 4-second delay before processing query {i+1}")
                    await asyncio.sleep(4.0)

                result = await process_single_query(query)
                search_docs.append(result)
                print(f"[DEBUG] Completed processing query {i+1}/{len(search_queries)}")
            except Exception as e:
                print(f"[DEBUG ERROR] Error processing arXiv query '{query}': {str(e)}")
                search_docs.append({
                    'query': query,
                    'follow_up_questions': None,
                    'answer': None,
                    'images': [],
                    'results': [],
                    'error': str(e)
                })

                if "429" in str(e) or "Too Many Requests" in str(e):
                    print("[DEBUG] ArXiv rate limit exceeded. Adding additional delay...")
                    await asyncio.sleep(7.0)

        print(f"[DEBUG] ArXiv search completed. Total results across all queries: {sum(len(doc.get('results', [])) for doc in search_docs)}")
        return search_docs

    async def rag_search_async(search_queries):
        """
        Performs concurrent RAG searches of the A/B testing collection using the reranker.

        Args:
            search_queries (List[SearchQuery]): List of search queries to process

        Returns:
            List[dict]: List of search responses from RAG, one per query. Each response has format:
                {
                    'query': str,                  # The original search query
                    'follow_up_questions': None,
                    'answer': None,
                    'images': list,
                    'results': [                   # List of search results
                        {
                            'title': str,          # Title in format "Kohavi: {title}, Section: {section}"
                            'url': str,            # None for RAG results
                            'content': str,        # None for RAG results
                            'score': float,        # None for RAG results
                            'raw_content': str|None  # Chunk's page_content
                        },
                        ...
                    ]
                }
        """

        async def single_rag_search(query):
            # Reverse so the most relevant chunk ends up last (closest to the
            # question when the sources are pasted into the prompt).
            docs_descending = reranker.invoke(query)
            docs = docs_descending[::-1]

            results = []
            for doc in docs:
                source_path = doc.metadata.get("source", "")
                filename = source_path.split("/")[-1] if "/" in source_path else source_path

                if filename.endswith('.pdf'):
                    filename = filename[:-4]

                section = doc.metadata.get("section_title", "unknown")

                title = f"Kohavi: {filename}, Section: {section}"

                results.append({
                    'title': title,
                    'url': None,
                    'content': None,
                    'score': None,
                    'raw_content': doc.page_content
                })

            return {
                'query': query,
                'follow_up_questions': None,
                'answer': None,
                'images': [],
                'results': results
            }

        search_tasks = [single_rag_search(query) for query in search_queries]

        search_responses = await asyncio.gather(*search_tasks)

        return search_responses

    DEFAULT_REPORT_STRUCTURE = """Use this structure to create a report on the user-provided topic:

1. Introduction (no research needed - REQUIRED)
   - Brief overview of the topic area
   - Set research=false for this section

2. Main Body Sections:
   - Each section should focus on a sub-topic of the user-provided topic
   - These sections require research

3. Conclusion (no research needed - REQUIRED)
   - Aim for 1 structural element (either a list or table) that distills the main body sections
   - Provide a concise summary of the report
   - Set research=false for this section

IMPORTANT: Always include at least one Introduction section and one Conclusion section with research=false."""

    class SearchAPI(Enum):
        TAVILY = "tavily"
        ARXIV = "arxiv"
        RAG = "rag"

    class PlannerProvider(Enum):
        ANTHROPIC = "anthropic"
        OPENAI = "openai"

    class WriterProvider(Enum):
        ANTHROPIC = "anthropic"
        OPENAI = "openai"

    @dataclass(kw_only=True)
    class Configuration:
        """The configurable fields for the report system."""
        report_structure: str = DEFAULT_REPORT_STRUCTURE

        number_of_queries: int = 1   # queries generated per search iteration
        max_search_depth: int = 3    # max search/reflection iterations per section

        planner_provider: PlannerProvider = PlannerProvider.ANTHROPIC
        planner_model: str = "claude-opus-4-20250514"
        writer_provider: WriterProvider = WriterProvider.ANTHROPIC
        writer_model: str = "claude-sonnet-4-20250514"

        search_api: SearchAPI = SearchAPI.TAVILY
        search_api_config: Optional[Dict[str, Any]] = None

        @classmethod
        def from_runnable_config(
            cls, config: Optional[RunnableConfig] = None
        ) -> "Configuration":
            """Create a Configuration instance from a RunnableConfig, letting
            environment variables override per-invocation values."""
            configurable = (
                config["configurable"] if config and "configurable" in config else {}
            )
            values: dict[str, Any] = {
                f.name: os.environ.get(f.name.upper(), configurable.get(f.name))
                for f in fields(cls)
                if f.init
            }
            return cls(**{k: v for k, v in values.items() if v})

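    # Usage sketch (hypothetical overrides; any Configuration field name works
    # under the "configurable" key, or as an upper-cased environment variable):
    #
    #   config = {"configurable": {"number_of_queries": 2, "max_search_depth": 2}}
    #   result = await report_system.ainvoke({"topic": "..."}, config)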

    async def generate_report_plan(state: ReportState, config: RunnableConfig):
        """Generate the report plan."""
        topic = state["topic"]

        configurable = Configuration.from_runnable_config(config)
        report_structure = configurable.report_structure
        number_of_queries = configurable.number_of_queries

        # Planning always searches the general web; the RAG -> arXiv -> Tavily
        # escalation only applies to individual sections.
        search_api = "tavily"

        if isinstance(report_structure, dict):
            report_structure = str(report_structure)

        writer_provider = get_config_value(configurable.writer_provider)
        writer_model_name = get_config_value(configurable.writer_model)
        writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)

        structured_llm = writer_model.with_structured_output(Queries)

        system_instructions_query = report_planner_query_writer_instructions.format(topic=topic, report_organization=report_structure, number_of_queries=number_of_queries)

        results = structured_llm.invoke([SystemMessage(content=system_instructions_query),
                                         HumanMessage(content="Generate search queries that will help with planning the sections of the report.")])

        query_list = [query.search_query for query in results.queries]

        search_api_config = configurable.search_api_config or {}
        params_to_pass = get_search_params(search_api, search_api_config)

        if search_api == "tavily":
            search_results = await tavily_search_async(query_list, **params_to_pass)
            source_str = deduplicate_and_format_sources(search_results, max_tokens_per_source=1500, include_raw_content=False)
        elif search_api == "arxiv":
            search_results = await arxiv_search_async(query_list, **params_to_pass)
            source_str = deduplicate_and_format_sources(search_results, max_tokens_per_source=1500, include_raw_content=False)
        else:
            raise ValueError(f"Unsupported search API: {search_api}")

        system_instructions_sections = report_planner_instructions.format(topic=topic, report_organization=report_structure, context=source_str)

        planner_provider = get_config_value(configurable.planner_provider)
        planner_model = get_config_value(configurable.planner_model)

        planner_message = """Generate the sections of the report. Your response must include a 'sections' field containing a list of sections.
Each section must have: name, description, research, and content fields."""

        planner_llm = init_chat_model(
            model=planner_model,
            model_provider=planner_provider,
            max_tokens=32_000,
            thinking={"type": "enabled", "budget_tokens": 24_000}
        )

        structured_llm = planner_llm.with_structured_output(Sections)
        report_sections = structured_llm.invoke([SystemMessage(content=system_instructions_sections),
                                                 HumanMessage(content=planner_message)])

        sections = report_sections.sections

        # Fan out: each research section gets its own run of the section-builder
        # subgraph; results merge back via the completed_sections reducer.
        return Command(goto=[Send("build_section_with_web_research", {"topic": topic, "section": s, "search_iterations": 0}) for s in sections if s.research], update={"sections": sections})
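    # --- Section-builder nodes ---------------------------------------------
    # generate_queries -> search_rag_and_web -> write_section form the inner
    # subgraph. write_section either accepts the section or loops back to
    # search_rag_and_web with follow-up queries, escalating the search source
    # on each pass (RAG -> arXiv -> Tavily).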
    def generate_queries(state: SectionState, config: RunnableConfig):
        """Generate search queries for a report section to query the A/B testing RAG collection."""
        topic = state["topic"]
        section = state["section"]

        configurable = Configuration.from_runnable_config(config)
        number_of_queries = configurable.number_of_queries

        writer_provider = get_config_value(configurable.writer_provider)
        writer_model_name = get_config_value(configurable.writer_model)
        writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
        structured_llm = writer_model.with_structured_output(Queries)

        system_instructions = query_writer_instructions.format(topic=topic,
                                                               section_topic=section.description,
                                                               number_of_queries=number_of_queries)

        queries = structured_llm.invoke([SystemMessage(content=system_instructions),
                                         HumanMessage(content="Generate search queries on the provided topic.")])

        return {"search_queries": queries.queries}

    async def search_rag_and_web(state: SectionState, config: RunnableConfig):
        """Search the A/B testing RAG collection and the web, with dual source tracking."""
        search_queries = state["search_queries"]
        search_iterations = state["search_iterations"]
        existing_source_str_all = state.get("source_str_all", "")

        configurable = Configuration.from_runnable_config(config)

        # Escalate the search source on each iteration: RAG first, then arXiv,
        # then the general web.
        if search_iterations == 0:
            search_api = "rag"
        elif search_iterations == 1:
            search_api = "arxiv"
        else:
            search_api = "tavily"

        query_list = [query.search_query for query in search_queries]
        search_api_config = configurable.search_api_config or {}
        params_to_pass = get_search_params(search_api, search_api_config)

        if search_api == "rag":
            search_results = await rag_search_async(query_list)
        elif search_api == "arxiv":
            search_results = await arxiv_search_async(query_list, **params_to_pass)
        elif search_api == "tavily":
            search_results = await tavily_search_async(query_list)
        else:
            raise ValueError(f"Unsupported search API: {search_api}")

        current_source_str, has_sources = deduplicate_and_format_sources(
            search_results,
            max_tokens_per_source=1500,
            include_raw_content=True,
            search_iterations=search_iterations,
            return_has_sources=True
        )

        # Track both the current iteration's sources (for the writer) and the
        # accumulated sources across iterations (for the final report).
        if has_sources:
            iteration_header = f"{'='*80}\n### SEARCH ITERATION {search_iterations + 1} - {search_api.upper()} RESULTS\n{'='*80}\n\n"

            if existing_source_str_all:
                accumulated_source_str = existing_source_str_all + "\n\n" + iteration_header + current_source_str
            else:
                accumulated_source_str = iteration_header + current_source_str
        else:
            accumulated_source_str = existing_source_str_all
            current_source_str = ""

        return {
            "source_str": current_source_str,
            "source_str_all": accumulated_source_str,
            "search_iterations": search_iterations + 1
        }

    def write_section(state: SectionState, config: RunnableConfig) -> Command[Literal[END, "search_rag_and_web"]]:
        """Write a section of the report, then grade it and decide whether to search again."""
        topic = state["topic"]
        section = state["section"]
        source_str = state["source_str"]
        search_iterations = state["search_iterations"]

        configurable = Configuration.from_runnable_config(config)

        system_instructions = section_writer_instructions.format(topic=topic,
                                                                 section_name=section.name,
                                                                 section_topic=section.description,
                                                                 context=source_str,
                                                                 section_content=section.content)

        writer_provider = get_config_value(configurable.writer_provider)
        writer_model_name = get_config_value(configurable.writer_model)
        writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
        section_content = writer_model.invoke([SystemMessage(content=system_instructions),
                                               HumanMessage(content="Generate a report section based on the existing section content (if any) and the provided sources.")])

        section.content = section_content.content

        # Grade the section and, if it fails, produce follow-up queries.
        section_grader_message = """Grade the report and consider follow-up questions for missing information.
If the grade is 'pass', return empty strings for all follow-up queries.
If the grade is 'fail', provide specific search queries to gather missing information."""

        section_grader_instructions_formatted = section_grader_instructions.format(topic=topic,
                                                                                   section_topic=section.description,
                                                                                   section=section.content,
                                                                                   number_of_follow_up_queries=configurable.number_of_queries,
                                                                                   current_iteration=search_iterations)

        planner_provider = get_config_value(configurable.planner_provider)
        planner_model = get_config_value(configurable.planner_model)

        reflection_llm = init_chat_model(
            model=planner_model,
            model_provider=planner_provider,
            max_tokens=32_000,
            thinking={"type": "enabled", "budget_tokens": 24_000}
        )

        reflection_model = reflection_llm.with_structured_output(Feedback)
        feedback = reflection_model.invoke([SystemMessage(content=section_grader_instructions_formatted),
                                            HumanMessage(content=section_grader_message)])

        # Accept the section on a passing grade or once the search budget is spent.
        if feedback.grade == "pass" or state["search_iterations"] >= configurable.max_search_depth:
            section.sources = state.get("source_str_all", "")

            return Command(
                update={
                    "completed_sections": [section]
                },
                goto=END
            )
        else:
            return Command(
                update={"search_queries": feedback.follow_up_queries, "section": section},
                goto="search_rag_and_web"
            )

    def write_final_sections(state: SectionState, config: RunnableConfig):
        """Write final sections of the report, which do not require RAG or web search and use the completed sections as context."""
        configurable = Configuration.from_runnable_config(config)

        topic = state["topic"]
        section = state["section"]
        completed_report_sections = state["report_sections_from_research"]

        system_instructions = final_section_writer_instructions.format(topic=topic, section_name=section.name, section_topic=section.description, context=completed_report_sections)

        writer_provider = get_config_value(configurable.writer_provider)
        writer_model_name = get_config_value(configurable.writer_model)
        writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
        section_content = writer_model.invoke([SystemMessage(content=system_instructions),
                                               HumanMessage(content="Generate a report section based on the provided sources.")])

        section.content = section_content.content

        return {"completed_sections": [section]}

    def gather_completed_sections(state: ReportState):
        """Gather completed sections from research and format them as context for writing the final sections."""
        original_sections = state["sections"]
        completed_sections = state["completed_sections"]

        completed_by_name = {s.name: s for s in completed_sections}

        # Restore the planned section order (parallel branches finish out of order).
        ordered_completed_sections = []
        for original_section in original_sections:
            if original_section.name in completed_by_name:
                ordered_completed_sections.append(completed_by_name[original_section.name])

        # Strip sources so they don't bloat the context for the final sections.
        sections_without_sources = []
        for section in ordered_completed_sections:
            temp_section = Section(
                name=section.name,
                description=section.description,
                research=section.research,
                content=section.content,
                sources=""
            )
            sections_without_sources.append(temp_section)

        completed_report_sections = format_sections(sections_without_sources)

        return {"report_sections_from_research": completed_report_sections}

    def initiate_final_section_writing(state: ReportState):
        """Write any final sections using the Send API to parallelize the process."""
        return Command(goto=[Send("write_final_sections", {"topic": state["topic"], "section": s, "report_sections_from_research": state["report_sections_from_research"]}) for s in state["sections"] if not s.research])

    def compile_final_report(state: ReportState):
        """Compile the final report, with section-grouped sources only for research sections."""
        sections = state["sections"]
        completed_sections = {s.name: s.content for s in state["completed_sections"]}

        for section in sections:
            section.content = completed_sections[section.name]

        main_report = "\n\n".join([s.content for s in sections])

        research_sections_with_sources = [s for s in state["completed_sections"] if s.research and s.sources]

        if research_sections_with_sources:
            sources_section = "\n\n## Sources Used\n\n"

            for section in sections:
                if section.research:
                    completed_section = next((s for s in state["completed_sections"] if s.name == section.name), None)
                    if completed_section and completed_section.sources:
                        sources_section += f"### Sources for Section: {section.name}\n\n"
                        sources_section += completed_section.sources + "\n\n"

            final_report_with_sources = main_report + sources_section
        else:
            final_report_with_sources = main_report

        return {"final_report": final_report_with_sources}

    def initial_AB_topic_check(state: ReportState, config):
        """Check whether the topic is related to A/B testing."""
        topic = state["topic"]

        configurable = Configuration.from_runnable_config(config)

        system_instructions = initial_AB_topic_check_instructions.format(topic=topic)

        initial_AB_topic_check_message = """Check if the topic is related to A/B testing (even vaguely, e.g. statistics, A/B testing, experimentation, etc.). If the topic is related to A/B testing (even vaguely), return 'true'. If the topic is not related to A/B testing, return 'false'."""

        planner_provider = get_config_value(configurable.planner_provider)
        planner_model = get_config_value(configurable.planner_model)

        reflection_model = init_chat_model(
            model=planner_model,
            model_provider=planner_provider,
            max_tokens=32_000,
            thinking={"type": "enabled", "budget_tokens": 24_000}
        )

        feedback = reflection_model.invoke([SystemMessage(content=system_instructions),
                                            HumanMessage(content=initial_AB_topic_check_message)])

        # Treat the topic as off-limits only if the model explicitly says 'false'.
        response_content = str(feedback.content).lower().strip()
        is_explicitly_not_ab_testing = "false" in response_content

        if is_explicitly_not_ab_testing:
            return {
                "ab_testing_check": False,
                "final_report": "I'm trained to only generate reports related to A/B testing. Thus, unfortunately, I can't make this report."
            }
        else:
            return {
                "ab_testing_check": True
            }

    def route_after_ab_check(state: ReportState):
        """Route to either generate_report_plan or END based on the A/B testing check."""
        if state.get("ab_testing_check", True):
            return "generate_report_plan"
        else:
            return END

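    # --- Graph assembly -----------------------------------------------------
    # Inner graph: build one research section (queries -> search -> write,
    # looping until the grader passes or max_search_depth is reached).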
    section_builder = StateGraph(SectionState, output=SectionOutputState)
    section_builder.add_node("generate_queries", generate_queries)
    section_builder.add_node("search_rag_and_web", search_rag_and_web)
    section_builder.add_node("write_section", write_section)

    section_builder.add_edge(START, "generate_queries")
    section_builder.add_edge("generate_queries", "search_rag_and_web")
    section_builder.add_edge("search_rag_and_web", "write_section")

    # Outer graph: topic check -> plan -> parallel section research ->
    # final sections -> compiled report.
    builder = StateGraph(ReportState, input=ReportStateInput, output=ReportStateOutput, config_schema=Configuration)
    builder.add_node("initial_AB_topic_check", initial_AB_topic_check)
    builder.add_node("generate_report_plan", generate_report_plan)
    builder.add_node("build_section_with_web_research", section_builder.compile())
    builder.add_node("gather_completed_sections", gather_completed_sections)
    builder.add_node("write_final_sections", write_final_sections)
    builder.add_node("compile_final_report", compile_final_report)
    builder.add_node("initiate_final_section_writing", initiate_final_section_writing)

    builder.add_edge(START, "initial_AB_topic_check")
    builder.add_conditional_edges("initial_AB_topic_check", route_after_ab_check, ["generate_report_plan", END])
    builder.add_edge("build_section_with_web_research", "gather_completed_sections")
    builder.add_edge("gather_completed_sections", "initiate_final_section_writing")
    builder.add_edge("write_final_sections", "compile_final_report")
    builder.add_edge("compile_final_report", END)

    return builder.compile()

def start_new_report(topic, report_placeholder):
    """Start a new report generation process."""
    with st.spinner("Generating comprehensive report... This may take about 3-7 minutes."):
        input_state = {"topic": topic}

        try:
            config = {}

            result = asyncio.run(run_graph_to_completion(input_state, config))

            if result.get("ab_testing_check") is False:
                # Off-topic: show the refusal message produced by the topic check.
                response = result.get("final_report", "This topic is not related to A/B testing.")
                report_placeholder.markdown(response)
                return response
            else:
                final_report = result.get("final_report", "")
                if final_report:
                    final_content = f"## π Final Report\n\n{final_report}"
                    report_placeholder.markdown(final_content)
                    return final_content
                else:
                    error_msg = "No report was generated."
                    report_placeholder.error(error_msg)
                    return None

        except Exception as e:
            error_msg = f"Error generating report: {str(e)}"
            report_placeholder.error(error_msg)
            return None


async def run_graph_to_completion(input_state, config):
    """Run the graph to completion."""
    result = await report_system.ainvoke(input_state, config)
    return result

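# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------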
st.markdown(
    "<h1>π A/B<sub><span style='color:green;'>AI</span></sub></h1>",
    unsafe_allow_html=True
)
st.markdown("""
A/B<sub><span style='color:green;'>AI</span></sub> is a specialized agent that generates comprehensive reports on your provided A/B testing topics using a thorough collection of Ron Kohavi's work, including his book, papers, and LinkedIn posts. For each section of the report, if A/B<sub><span style='color:green;'>AI</span></sub> can't answer your questions using this collection, it will then search arXiv. If that's not enough, it will finally search the web. It provides ALL sources, section by section. It has been trained to only write based on the sources it retrieves. Let's begin!
""", unsafe_allow_html=True)


try:
    # Brief animated "Loading..." indicator while the cached resources warm up.
    loading_placeholder = st.empty()
    with loading_placeholder.container():
        for dots in [".", "..", "..."]:
            loading_placeholder.text(f"Loading{dots}")
            time.sleep(0.2)

    reranker, chunks = initialize_vectorstore()
    report_system = initialize_report_system(reranker)

    loading_placeholder.empty()
except Exception as e:
    st.error(f"Error initializing the system: {str(e)}")
    st.stop()

if "messages" not in st.session_state: |
|
st.session_state.messages = [] |
|
|
|
|
|
|
|
for i, message in enumerate(st.session_state.messages): |
|
if message["role"] == "user": |
|
st.chat_message("user").write(message["content"]) |
|
else: |
|
with st.chat_message("assistant"): |
|
st.write(message["content"]) |
|
|
|
|
|
|
|
query = st.chat_input("Please give me a topic on anything regarding A/B Testing...") |
|
|
|
if query: |
|
|
|
st.chat_message("user").write(query) |
|
st.session_state.messages.append({"role": "user", "content": query}) |
|
|
|
|
|
with st.chat_message("assistant"): |
|
report_placeholder = st.empty() |
|
|
|
|
|
final_content = start_new_report(query, report_placeholder) |
|
|
|
|
|
if final_content: |
|
st.session_state.messages.append({ |
|
"role": "assistant", |
|
"content": final_content |
|
}) |
|
|
|
|
|
|
|
|
|
|
|
|