File size: 7,863 Bytes
0955862 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, List
import operator
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ChatMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from tavily import TavilyClient
import os
import sqlite3
# Shared state carried between the workout-table graph nodes.
# ``count`` is Annotated with ``operator.add``: LangGraph combines the values
# returned by nodes with that reducer (each node here returns 1). Every other
# key is simply overwritten by the latest node that returns it.
AgentState = TypedDict(
    "AgentState",
    {
        "task": str,             # the user's workout request
        "lnode": str,            # name of the last node that ran (set by each node)
        "plan": str,             # outline produced by plan_node
        "draft": str,            # latest workout table from generation_node
        "feedback": str,         # critique produced by reflection_node
        "content": List[str],    # research snippets collected from Tavily
        "queries": List[str],    # search queries from research_plan_node
        "revision_number": int,  # incremented by every generation step
        "max_revisions": int,    # should_continue ends once revision_number exceeds this
        "count": Annotated[int, operator.add],  # summed across node updates
    },
)
# Structured-output schema for the research nodes: the model is invoked via
# ``model.with_structured_output(Queries)`` and must return a list of search
# strings, which are then passed one by one to ``tavily.search``.
# NOTE(review): deliberately no class docstring here — LangChain may use a
# pydantic model's docstring as the function/tool description sent to the
# LLM, so adding one would change the prompt.
class Queries(BaseModel):
    # Search queries to run; the prompts ask for at most 3, but nothing here
    # enforces that limit.
    queries: List[str]
# Graph node functions (each returns a partial AgentState update) and the routing function
def plan_node(model, state: AgentState):
    """Ask the model for a high-level outline of the requested workout table.

    The system prompt embeds a fixed example workout table; the user's
    request (``state['task']``) is sent as the human message.

    Args:
        model: Chat model exposing ``invoke(messages)``.
        state: Current graph state; only ``task`` is read.

    Returns:
        Partial state update: the outline under ``"plan"``, ``lnode`` set to
        ``"planner"`` and a ``count`` increment of 1.
    """
    # NOTE(review): the prompt asks for "three main headers" while the example
    # sequence below lists four sections — confirm which is intended.
    example_lines = (
        "Workout Table Sequence:",
        "- Section 1: Warm-up",
        "- Section 2: Strength Training",
        "- Section 3: Cardio",
        "- Section 4: Cool-down",
        "Workout Table Example:",
        "Workout Table Full Body (3 times per week):",
        "Day 1:",
        "Squat on Chair: 2-3 sets of 8-10 repetitions",
        "Bench Press with Dumbbells: 2-3 sets of 8-10 repetitions",
        "Row on the machine: 2-3 sets of 8-10 repetitions",
        "Glute Machine or Glute Bridge: 2-3 sets of 10-12 repetitions",
        "Development with dumbbells: 2-3 sets of 8-10 repetitions",
        "Dumbbell curls: 2-3 sets of 10-12 repetitions",
        "Day 2:",
        "Leg Extension: 2-3 sets of 10-12 repetitions",
        "Arms pushdowns (supporting on knees): 2-3 sets of 8-10 repetitions",
        "Dumbbell row: 2-3 sets of 8-10 repetitions",
        "Calf Raises (standing): 2-3 sets of 10-12 repetitions",
        "Lateral Raises: 2-3 sets of 10-12 repetitions",
        "Triceps Extension: 2-3 sets of 10-12 repetitions",
        "Day 3:",
        "Cardio (bike or treadmill): 30 minutes at moderate pace",
        "Abdominal: 2-3 sets of 12-15 repetitions",
        "Plank: 2-3 sets, holding for 20-30 seconds",
        "Remember to always warm up before training and stretch after.",
    )
    # Leading and trailing newline mirror a triple-quoted block literal.
    example_table = "\n" + "\n".join(example_lines) + "\n"

    system_text = (
        "You are an expert gym trainer tasked with writing a high level workout table. "
        "Write such an outline for the user provided workout. Give the three main headers of an outline of "
        "the workout table along with any relevant notes or instructions for the sections. "
        "Here is the user's workout table: "
        + example_table
    )

    reply = model.invoke([
        SystemMessage(content=system_text),
        HumanMessage(content=state['task']),
    ])
    return {"plan": reply.content, "lnode": "planner", "count": 1}
def research_plan_node(model, tavily, state: AgentState):
    """Generate search queries for the task and collect web snippets.

    Asks ``model`` (via structured output into ``Queries``) for search queries
    derived from ``state['task']`` — the prompt requests at most three, which
    is not enforced — runs each through ``tavily.search`` with
    ``max_results=2`` and appends every result's ``content`` text to a copy of
    the research notes already in the state.

    Args:
        model: Chat model supporting ``with_structured_output``.
        tavily: Tavily client used to run the searches.
        state: Current graph state; reads ``task`` and, if present, ``content``.

    Returns:
        Partial state update with the extended ``content`` list, the generated
        ``queries``, ``lnode`` set to ``"research_plan"`` and a ``count``
        increment of 1.
    """
    RESEARCH_PLAN_PROMPT = ("You are a researcher charged with providing information that can "
                            "be used when writing the following workout table. Generate a list of search "
                            "queries that will gather "
                            "any relevant information. Only generate 3 queries max.")
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ])
    # Copy so the list held in the graph state is never mutated in place;
    # a missing or None ``content`` key starts from an empty list.
    content = list(state.get('content') or [])
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        content.extend(r['content'] for r in response['results'])
    return {"content": content,
            "queries": queries.queries,
            "lnode": "research_plan",
            "count": 1,
            }
def generation_node(model, state: AgentState):
    """Produce a new workout-table draft from the task, outline and research.

    All collected research snippets (``state['content']``) are joined and
    embedded in the system prompt; the human message combines the original
    task with the planner's outline.

    Args:
        model: Chat model exposing ``invoke(messages)``.
        state: Current graph state; reads ``content``, ``task``, ``plan`` and
            ``revision_number`` (a missing ``revision_number`` counts as 1).

    Returns:
        Partial state update with the new ``draft``, the incremented
        ``revision_number``, ``lnode`` set to ``"generate"`` and a ``count``
        increment of 1.
    """
    writer_template = (
        "You are an gym trainer assistant tasked with writing excellent workout tables. "
        "Generate the best workout table possible for the user's request and the initial outline. "
        "If the user provides feedback, respond with a revised version of your previous attempts. "
        "Utilize all the information below as needed: \n"
        "------\n"
        "{content}"
    )
    research_text = "\n\n".join(state['content'] or [])
    request_text = f"{state['task']}\n\nHere is my workout table:\n\n{state['plan']}"
    # NOTE(review): the prompt mentions feedback and previous attempts, but
    # neither state['feedback'] nor state['draft'] is included in the messages.
    prompt_messages = [
        SystemMessage(content=writer_template.format(content=research_text)),
        HumanMessage(content=request_text),
    ]
    reply = model.invoke(prompt_messages)

    next_revision = state.get("revision_number", 1) + 1
    return {
        "draft": reply.content,
        "revision_number": next_revision,
        "lnode": "generate",
        "count": 1,
    }
def reflection_node(model, state: AgentState):
    """Ask the model to critique the current draft.

    Only ``state['draft']`` is sent as the human message (not the task or
    the outline); the model's reply is stored as feedback for the next
    research/revision round.

    Args:
        model: Chat model exposing ``invoke(messages)``.
        state: Current graph state; only ``draft`` is read.

    Returns:
        Partial state update with ``feedback``, ``lnode`` set to
        ``"reflect"`` and a ``count`` increment of 1.
    """
    grader_instructions = (
        "You are an instructor personal grading an workout table submission. "
        "Generate feedback and recommendations for the user's submission. "
        "Provide detailed recommendations, including requests for objectives, level of intensity, health benefits, health conditions, etc."
    )
    critique = model.invoke([
        SystemMessage(content=grader_instructions),
        HumanMessage(content=state['draft']),
    ])
    return {"feedback": critique.content, "lnode": "reflect", "count": 1}
def research_feedback_node(model, tavily, state: AgentState):
    """Turn the latest feedback into web searches and add the hits to the notes.

    Asks ``model`` (via structured output into ``Queries``) for search queries
    based on ``state['feedback']``, runs each through ``tavily.search`` with
    ``max_results=2`` and appends every result's ``content`` text to a copy of
    the existing research notes. Query generation and searching happen exactly
    once per call.

    Args:
        model: Chat model supporting ``with_structured_output``.
        tavily: Tavily client used to run the searches.
        state: Current graph state; reads ``feedback`` and, if present,
            ``content``.

    Returns:
        Partial state update with the extended ``content`` list, ``lnode`` set
        to ``"research_feedback"`` and a ``count`` increment of 1.
    """
    RESEARCH_FEEDBACK_PROMPT = ("You are a researcher charged with providing information that can "
                                "be used when writing the following workout table. Generate a list of search "
                                "queries that will gather "
                                "any relevant information. Only generate 3 queries max.")
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_FEEDBACK_PROMPT),
        HumanMessage(content=state['feedback']),
    ])
    # Copy so the list held in the graph state is never mutated in place;
    # a missing or None ``content`` key starts from an empty list.
    content = list(state.get('content') or [])
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        content.extend(r['content'] for r in response['results'])
    return {"content": content,
            "lnode": "research_feedback",
            "count": 1,
            }
def should_continue(state):
    """Route after the ``generate`` node.

    Returns ``END`` once ``state["revision_number"]`` is strictly greater
    than ``state["max_revisions"]``; otherwise returns ``"reflect"`` so the
    draft is critiqued and regenerated.
    """
    budget_spent = state["revision_number"] > state["max_revisions"]
    return END if budget_spent else "reflect"
# Function to create the graph
def create_workout_table_graph(*, model_name: str = "gpt-3.5-turbo", temperature: float = 0):
    """Build and compile the workout-table LangGraph.

    Flow: planner -> research_plan -> generate. After every generate step,
    ``should_continue`` either ends the run or routes through
    reflect -> research_feedback -> generate again.

    Args:
        model_name: OpenAI chat model shared by every node. Defaults to
            ``"gpt-3.5-turbo"``, the value previously hard-coded here.
        temperature: Sampling temperature for that model (default 0).

    Returns:
        The compiled graph, checkpointed to an in-memory SQLite database and
        configured to interrupt after every node.

    Raises:
        KeyError: if ``TAVILY_API_KEY`` is not set in the environment.
    """
    model = ChatOpenAI(model=model_name, temperature=temperature)
    tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])

    builder = StateGraph(AgentState)
    # Each node closes over the shared model / Tavily client and forwards the state.
    nodes = {
        "planner": lambda state: plan_node(model, state),
        "research_plan": lambda state: research_plan_node(model, tavily, state),
        "generate": lambda state: generation_node(model, state),
        "reflect": lambda state: reflection_node(model, state),
        "research_feedback": lambda state: research_feedback_node(model, tavily, state),
    }
    for name, node in nodes.items():
        builder.add_node(name, node)

    builder.set_entry_point("planner")
    builder.add_edge("planner", "research_plan")
    builder.add_edge("research_plan", "generate")
    builder.add_conditional_edges(
        "generate",
        should_continue,
        {END: END, "reflect": "reflect"},
    )
    builder.add_edge("reflect", "research_feedback")
    builder.add_edge("research_feedback", "generate")

    # ":memory:" keeps checkpoints only for the lifetime of this connection;
    # check_same_thread=False allows the connection to be used from other threads.
    memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))
    return builder.compile(
        checkpointer=memory,
        # Interrupt after every registered node (kept in sync with ``nodes``).
        interrupt_after=list(nodes),
    )
|