# ernani
# removing warnings and dotenv
# 3da7218
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, List
import operator
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ChatMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from tavily import TavilyClient
import os
import sqlite3
class AgentState(TypedDict):
    """Shared state passed between the nodes of the workout-table graph.

    Every node returns a partial dict containing a subset of these keys.
    ``count`` is annotated with ``operator.add`` as its reducer (the LangGraph
    convention), so the ``1`` each node returns accumulates instead of
    overwriting the previous value.
    """

    # --- request ---
    task: str  # the user's request; sent as the human message by several nodes
    # --- bookkeeping ---
    lnode: str  # name of the node that produced the latest update
    # --- node outputs ---
    plan: str  # outline written by the planner node
    draft: str  # current workout table written by the generate node
    feedback: str  # critique written by the reflect node
    content: List[str]  # research snippets gathered from web search results
    queries: List[str]  # search queries generated by the research_plan node
    # --- revision loop control ---
    revision_number: int  # incremented on every draft; compared to max_revisions
    max_revisions: int  # limit checked by should_continue
    count: Annotated[int, operator.add]  # total node executions (summed reducer)
# Structured-output schema used by the research nodes: the chat model is asked
# (via ``with_structured_output(Queries)``) to return its web-search queries in
# this shape.
# NOTE(review): intentionally no docstring -- LangChain typically forwards a
# schema's docstring to the model as the tool description, so adding one would
# change the prompt.
class Queries(BaseModel):
    queries: List[str]  # search strings; the research prompts ask for at most 3
# Graph node functions: each takes the shared AgentState and returns a partial state update
# Reference layout embedded in the planner's system prompt. Built once at import
# time instead of on every planner call.
# NOTE(review): the prompt calls this "the user's workout table", but it is a
# fixed example, not user input; it also lists four sections while the prompt
# asks for "three main headers" -- confirm the intended wording.
_WORKOUT_TABLE_EXAMPLE = """
Workout Table Sequence:
- Section 1: Warm-up
- Section 2: Strength Training
- Section 3: Cardio
- Section 4: Cool-down
Workout Table Example:
Workout Table Full Body (3 times per week):
Day 1:
Squat on Chair: 2-3 sets of 8-10 repetitions
Bench Press with Dumbbells: 2-3 sets of 8-10 repetitions
Row on the machine: 2-3 sets of 8-10 repetitions
Glute Machine or Glute Bridge: 2-3 sets of 10-12 repetitions
Development with dumbbells: 2-3 sets of 8-10 repetitions
Dumbbell curls: 2-3 sets of 10-12 repetitions
Day 2:
Leg Extension: 2-3 sets of 10-12 repetitions
Arms pushdowns (supporting on knees): 2-3 sets of 8-10 repetitions
Dumbbell row: 2-3 sets of 8-10 repetitions
Calf Raises (standing): 2-3 sets of 10-12 repetitions
Lateral Raises: 2-3 sets of 10-12 repetitions
Triceps Extension: 2-3 sets of 10-12 repetitions
Day 3:
Cardio (bike or treadmill): 30 minutes at moderate pace
Abdominal: 2-3 sets of 12-15 repetitions
Plank: 2-3 sets, holding for 20-30 seconds
Remember to always warm up before training and stretch after.
"""

# System prompt for the planner node.
_PLAN_PROMPT = (
    "You are an expert gym trainer tasked with writing a high level workout table. "
    "Write such an outline for the user provided workout. Give the three main headers of an outline of "
    "the workout table along with any relevant notes or instructions for the sections. "
    f"Here is the user's workout table: {_WORKOUT_TABLE_EXAMPLE}"
)


def plan_node(model, state: AgentState):
    """Ask the model for a high-level outline of the requested workout table.

    Args:
        model: chat model exposing ``invoke(messages)``.
        state: graph state; only ``state['task']`` is read.

    Returns:
        Partial state update with the outline under ``plan``, ``lnode`` set to
        ``"planner"`` and a ``count`` increment of 1.
    """
    reply = model.invoke([
        SystemMessage(content=_PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ])
    return {"plan": reply.content, "lnode": "planner", "count": 1}
def research_plan_node(model, tavily, state: AgentState):
    """Generate search queries for the task and collect web-search snippets.

    The model is asked, via structured output into ``Queries``, for search
    queries about ``state['task']``. The prompt requests at most three. Each
    query is sent to ``tavily.search`` with ``max_results=2``, and the
    ``content`` field of every hit is appended to the research gathered so far.

    Args:
        model: chat model supporting ``with_structured_output``.
        tavily: search client exposing ``search(query=..., max_results=...)``.
        state: graph state; reads ``task`` and (optionally) ``content``.

    Returns:
        Partial state update containing a NEW ``content`` list (previous
        snippets followed by the new ones), the generated ``queries``,
        ``lnode`` set to ``"research_plan"`` and a ``count`` increment of 1.
    """
    RESEARCH_PLAN_PROMPT = ("You are a researcher charged with providing information that can "
                            "be used when writing the following workout table. Generate a list of search "
                            "queries that will gather "
                            "any relevant information. Only generate 3 queries max.")
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ])
    # Copy instead of appending to state['content'] directly: `x or []` returns
    # the very list held in the graph state, and appending to it would mutate
    # that state object in place. `.get` also tolerates a missing key.
    content = list(state.get('content') or [])
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        content.extend(r['content'] for r in response['results'])
    return {"content": content,
            "queries": queries.queries,
            "lnode": "research_plan",
            "count": 1,
            }
def generation_node(model, state: AgentState):
    """Write a new draft of the workout table.

    The system prompt embeds every research snippet from ``state['content']``
    (joined by blank lines). The user turn combines the original task with
    the planner's outline from ``state['plan']``.

    Args:
        model: chat model exposing ``invoke(messages)``.
        state: graph state; reads ``content``, ``task``, ``plan`` and
            ``revision_number`` (treated as 1 when absent).

    Returns:
        Partial state update with the new ``draft``, ``revision_number``
        incremented by one, ``lnode`` set to ``"generate"`` and a ``count``
        increment of 1.
    """
    research = "\n\n".join(state['content'] or [])
    system_text = (
        "You are an gym trainer assistant tasked with writing excellent workout tables. "
        "Generate the best workout table possible for the user's request and the initial outline. "
        "If the user provides feedback, respond with a revised version of your previous attempts. "
        "Utilize all the information below as needed: \n"
        "------\n"
        f"{research}"
    )
    request = f"{state['task']}\n\nHere is my workout table:\n\n{state['plan']}"
    reply = model.invoke([SystemMessage(content=system_text), HumanMessage(content=request)])
    next_revision = state.get("revision_number", 1) + 1
    return {
        "draft": reply.content,
        "revision_number": next_revision,
        "lnode": "generate",
        "count": 1,
    }
def reflection_node(model, state: AgentState):
    """Critique the current draft and produce improvement recommendations.

    Args:
        model: chat model exposing ``invoke(messages)``.
        state: graph state; only ``state['draft']`` is read.

    Returns:
        Partial state update with the critique under ``feedback``, ``lnode``
        set to ``"reflect"`` and a ``count`` increment of 1.
    """
    critique_instructions = (
        "You are an instructor personal grading an workout table submission. "
        "Generate feedback and recommendations for the user's submission. "
        "Provide detailed recommendations, including requests for objectives, level of intensity, health benefits, health conditions, etc."
    )
    reply = model.invoke([
        SystemMessage(content=critique_instructions),
        HumanMessage(content=state['draft']),
    ])
    return dict(feedback=reply.content, lnode="reflect", count=1)
def research_feedback_node(model, tavily, state: AgentState):
    """Search the web for material addressing the reviewer's feedback.

    The model is asked, via structured output into ``Queries``, for search
    queries derived from ``state['feedback']``. The prompt requests at most
    three. Each query is sent to ``tavily.search`` with ``max_results=2``, and
    the ``content`` field of every hit is appended to the research gathered
    so far.

    Previously the query-generation + search round ran twice. That doubled the
    LLM and search API calls, and when the incoming content was empty the
    first round's results were thrown away. It now runs exactly once.

    Args:
        model: chat model supporting ``with_structured_output``.
        tavily: search client exposing ``search(query=..., max_results=...)``.
        state: graph state; reads ``feedback`` and (optionally) ``content``.

    Returns:
        Partial state update containing a NEW ``content`` list (previous
        snippets followed by the new ones), ``lnode`` set to
        ``"research_feedback"`` and a ``count`` increment of 1.
    """
    RESEARCH_FEEDBACK_PROMPT = ("You are a researcher charged with providing information that can "
                                "be used when writing the following workout table. Generate a list of search "
                                "queries that will gather "
                                "any relevant information. Only generate 3 queries max.")
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_FEEDBACK_PROMPT),
        HumanMessage(content=state['feedback']),
    ])
    # Copy instead of appending to state['content'] directly, so the list held
    # in the graph state is not mutated in place.
    content = list(state.get('content') or [])
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        content.extend(r['content'] for r in response['results'])
    return {"content": content,
            "lnode": "research_feedback",
            "count": 1,
            }
def should_continue(state):
    """Route after the ``generate`` node.

    Returns ``END`` once ``revision_number`` is strictly greater than
    ``max_revisions``. Otherwise returns ``"reflect"`` so the draft gets
    another critique round. Raises ``KeyError`` if either key is missing.
    """
    budget_exhausted = state["revision_number"] > state["max_revisions"]
    return END if budget_exhausted else "reflect"
# Function to create the graph
def create_workout_table_graph():
    """Build and compile the workout-table agent graph.

    Flow: planner -> research_plan -> generate. After each draft,
    ``should_continue`` either ends the run or routes
    reflect -> research_feedback -> generate for another revision.

    Execution is interrupted after every node. Checkpoints are stored in an
    in-memory SQLite database (``":memory:"``), so they do not outlive the
    process.

    Raises:
        KeyError: if the ``TAVILY_API_KEY`` environment variable is not set.
    """
    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    search_client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])

    # Each node is a closure over the shared model / search client and takes
    # only the graph state.
    node_table = {
        "planner": lambda state: plan_node(llm, state),
        "research_plan": lambda state: research_plan_node(llm, search_client, state),
        "generate": lambda state: generation_node(llm, state),
        "reflect": lambda state: reflection_node(llm, state),
        "research_feedback": lambda state: research_feedback_node(llm, search_client, state),
    }

    workflow = StateGraph(AgentState)
    for node_name, node_fn in node_table.items():
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("planner")
    # After each draft: stop, or send it for another critique round.
    workflow.add_conditional_edges(
        "generate",
        should_continue,
        {END: END, "reflect": "reflect"},
    )
    fixed_edges = (
        ("planner", "research_plan"),
        ("research_plan", "generate"),
        ("reflect", "research_feedback"),
        ("research_feedback", "generate"),
    )
    for source, target in fixed_edges:
        workflow.add_edge(source, target)

    # check_same_thread=False lets the connection be used from a thread other
    # than the one that created it.
    checkpointer = SqliteSaver(
        conn=sqlite3.connect(":memory:", check_same_thread=False)
    )
    return workflow.compile(
        checkpointer=checkpointer,
        interrupt_after=['planner', 'generate', 'reflect', 'research_plan', 'research_feedback'],
    )