"""LangGraph agent that plans, researches, drafts, and iteratively revises a workout table."""

import operator
import os
import sqlite3
from typing import Annotated, List, TypedDict

from langchain_core.messages import AIMessage, AnyMessage, ChatMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from pydantic import BaseModel
from tavily import TavilyClient


class AgentState(TypedDict):
    """State shared by every node in the graph; each node returns a partial update."""

    task: str  # the user's request, sent as the human message to planner/research/writer
    lnode: str  # name of the node that produced the most recent update
    plan: str  # outline produced by plan_node
    draft: str  # workout table produced by generation_node
    feedback: str  # critique produced by reflection_node
    content: List[str]  # research snippets collected from Tavily search results
    queries: List[str]  # search queries generated by research_plan_node
    revision_number: int  # incremented by generation_node on every draft
    max_revisions: int  # should_continue stops once revision_number exceeds this
    count: Annotated[int, operator.add]  # operator.add reducer: sums the 1 every node returns


# Structured-output schema for the research nodes. Deliberately left without a docstring:
# with_structured_output may forward a class docstring to the model as the schema description.
class Queries(BaseModel):
    queries: List[str]


# Example layout embedded in the planner prompt.
WORKOUT_TABLE_TEMPLATE = """
Workout Table Sequence:
- Section 1: Warm-up
- Section 2: Strength Training
- Section 3: Cardio
- Section 4: Cool-down

Workout Table Example:
Workout Table Full Body (3 times per week):

Day 1:
Squat on Chair: 2-3 sets of 8-10 repetitions
Bench Press with Dumbbells: 2-3 sets of 8-10 repetitions
Row on the machine: 2-3 sets of 8-10 repetitions
Glute Machine or Glute Bridge: 2-3 sets of 10-12 repetitions
Development with dumbbells: 2-3 sets of 8-10 repetitions
Dumbbell curls: 2-3 sets of 10-12 repetitions

Day 2:
Leg Extension: 2-3 sets of 10-12 repetitions
Arms pushdowns (supporting on knees): 2-3 sets of 8-10 repetitions
Dumbbell row: 2-3 sets of 8-10 repetitions
Calf Raises (standing): 2-3 sets of 10-12 repetitions
Lateral Raises: 2-3 sets of 10-12 repetitions
Triceps Extension: 2-3 sets of 10-12 repetitions

Day 3:
Cardio (bike or treadmill): 30 minutes at moderate pace
Abdominal: 2-3 sets of 12-15 repetitions
Plank: 2-3 sets, holding for 20-30 seconds

Remember to always warm up before training and stretch after.
"""

# NOTE(review): the prompt asks for "three main headers" while the template lists four
# sections, and it calls the template "the user's workout table" although it is a fixed
# example. Confirm the intended wording before changing the prompt text.
PLAN_PROMPT = (
    "You are an expert gym trainer tasked with writing a high level workout table. "
    "Write such an outline for the user provided workout. Give the three main headers of an outline of "
    "the workout table along with any relevant notes or instructions for the sections. "
    f"Here is the user's workout table: {WORKOUT_TABLE_TEMPLATE}"
)

RESEARCH_PLAN_PROMPT = (
    "You are a researcher charged with providing information that can "
    "be used when writing the following workout table. Generate a list of search "
    "queries that will gather "
    "any relevant information. Only generate 3 queries max."
)

# Contains a literal "{content}" placeholder, filled via str.format in generation_node.
WRITER_PROMPT = (
    "You are an gym trainer assistant tasked with writing excellent workout tables. "
    "Generate the best workout table possible for the user's request and the initial outline. "
    "If the user provides feedback, respond with a revised version of your previous attempts. "
    "Utilize all the information below as needed: \n"
    "------\n"
    "{content}"
)

REFLECTION_PROMPT = (
    "You are an instructor personal grading an workout table submission. "
    "Generate feedback and recommendations for the user's submission. "
    "Provide detailed recommendations, including requests for objectives, level of intensity, "
    "health benefits, health conditions, etc."
)

RESEARCH_FEEDBACK_PROMPT = (
    "You are a researcher charged with providing information that can "
    "be used when writing the following workout table. Generate a list of search "
    "queries that will gather "
    "any relevant information. Only generate 3 queries max."
)


# Tool functions
def _search_content(model, tavily, prompt: str, text: str, existing):
    """Generate search queries for *text*, run them through Tavily, and collect snippets.

    Args:
        model: chat model supporting ``with_structured_output``.
        tavily: client exposing ``search(query=..., max_results=...)``.
        prompt: system prompt instructing the model how to write queries.
        text: the human message the queries are generated from.
        existing: previously collected snippets (may be ``None``).

    Returns:
        ``(queries, content)``: the generated query strings, and a NEW list holding
        ``existing`` followed by the ``content`` field of every search result.
    """
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=prompt),
        HumanMessage(content=text),
    ])
    # Copy instead of appending to the state's own list, so the node never mutates its input.
    content = list(existing or [])
    for query in queries.queries:
        response = tavily.search(query=query, max_results=2)
        content.extend(result["content"] for result in response["results"])
    return queries.queries, content


def plan_node(model, state: AgentState) -> dict:
    """Ask the model for a high-level workout-table outline for ``state['task']``.

    Returns a partial state update containing the outline under ``plan``.
    """
    messages = [
        SystemMessage(content=PLAN_PROMPT),
        HumanMessage(content=state["task"]),
    ]
    response = model.invoke(messages)
    return {"plan": response.content, "lnode": "planner", "count": 1}


def research_plan_node(model, tavily, state: AgentState) -> dict:
    """Search the web for material relevant to ``state['task']``.

    Returns the extended ``content`` list and the generated ``queries``.
    """
    queries, content = _search_content(
        model, tavily, RESEARCH_PLAN_PROMPT, state["task"], state.get("content")
    )
    return {"content": content, "queries": queries, "lnode": "research_plan", "count": 1}


def generation_node(model, state: AgentState) -> dict:
    """Write (or rewrite) the workout table from the task, plan, and research content.

    Returns the new ``draft`` and ``revision_number`` incremented by one (starting from 1
    when the state does not carry a revision number yet).
    """
    # NOTE(review): WRITER_PROMPT mentions user feedback, but state['feedback'] is not sent
    # here; feedback only influences the draft through the research it triggers.
    content = "\n\n".join(state.get("content") or [])
    user_message = HumanMessage(
        content=f"{state['task']}\n\nHere is my workout table:\n\n{state['plan']}"
    )
    messages = [
        SystemMessage(content=WRITER_PROMPT.format(content=content)),
        user_message,
    ]
    response = model.invoke(messages)
    return {
        "draft": response.content,
        "revision_number": state.get("revision_number", 1) + 1,
        "lnode": "generate",
        "count": 1,
    }


def reflection_node(model, state: AgentState) -> dict:
    """Ask the model to critique ``state['draft']``; returns the critique as ``feedback``."""
    messages = [
        SystemMessage(content=REFLECTION_PROMPT),
        HumanMessage(content=state["draft"]),
    ]
    response = model.invoke(messages)
    return {"feedback": response.content, "lnode": "reflect", "count": 1}


def research_feedback_node(model, tavily, state: AgentState) -> dict:
    """Search the web for material that addresses ``state['feedback']``.

    Returns the extended ``content`` list.
    """
    # Run the query-generation + search step exactly once: repeating it doubled the LLM and
    # Tavily calls and appended a second, redundant batch of results.
    _, content = _search_content(
        model, tavily, RESEARCH_FEEDBACK_PROMPT, state["feedback"], state.get("content")
    )
    return {"content": content, "lnode": "research_feedback", "count": 1}


def should_continue(state: AgentState) -> str:
    """Route after ``generate``: END once revision_number exceeds max_revisions, else reflect."""
    if state["revision_number"] > state["max_revisions"]:
        return END
    return "reflect"


# Function to create the graph
def create_workout_table_graph():
    """Build and compile the workout-table agent graph.

    Flow: planner -> research_plan -> generate, then (while revisions remain)
    reflect -> research_feedback -> generate. The graph is checkpointed to an
    in-memory SQLite database (lost when the process exits) and interrupts after
    every node, so callers must resume it step by step.

    Raises:
        KeyError: if the ``TAVILY_API_KEY`` environment variable is not set.
    """
    model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])

    builder = StateGraph(AgentState)
    # Nodes receive only the state, so bind the shared clients via closures.
    builder.add_node("planner", lambda state: plan_node(model, state))
    builder.add_node("research_plan", lambda state: research_plan_node(model, tavily, state))
    builder.add_node("generate", lambda state: generation_node(model, state))
    builder.add_node("reflect", lambda state: reflection_node(model, state))
    builder.add_node(
        "research_feedback", lambda state: research_feedback_node(model, tavily, state)
    )

    builder.set_entry_point("planner")
    builder.add_conditional_edges(
        "generate",
        should_continue,
        {END: END, "reflect": "reflect"},
    )
    builder.add_edge("planner", "research_plan")
    builder.add_edge("research_plan", "generate")
    builder.add_edge("reflect", "research_feedback")
    builder.add_edge("research_feedback", "generate")

    # check_same_thread=False lets the connection be used from threads other than its creator.
    memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))
    graph = builder.compile(
        checkpointer=memory,
        interrupt_after=["planner", "generate", "reflect", "research_plan", "research_feedback"],
    )
    return graph