import operator
import os
import sqlite3
from typing import Annotated, List, TypedDict

from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ChatMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, END
from pydantic import BaseModel
from tavily import TavilyClient


class AgentState(TypedDict):
    task: str
    lnode: str  # name of the last node that ran
    plan: str
    draft: str
    feedback: str
    content: List[str]
    queries: List[str]
    revision_number: int
    max_revisions: int
    count: Annotated[int, operator.add]  # accumulated across node runs


class Queries(BaseModel):
    queries: List[str]
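# Hedged usage note: the research nodes below call model.with_structured_output(Queries),
# so the model returns a parsed Queries instance instead of free text, e.g.:
#   parsed = model.with_structured_output(Queries).invoke("beginner full-body workout queries")
#   parsed.queries  # a list of search-query strings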


def plan_node(model, state: AgentState):
    """Draft a high-level outline for the requested workout table."""
    table_output = """
    Workout Table Sequence:
    - Section 1: Warm-up
    - Section 2: Strength Training
    - Section 3: Cardio
    - Section 4: Cool-down

    Workout Table Example:
    Workout Table Full Body (3 times per week):
    Day 1:
    Squat on Chair: 2-3 sets of 8-10 repetitions
    Bench Press with Dumbbells: 2-3 sets of 8-10 repetitions
    Row on the Machine: 2-3 sets of 8-10 repetitions
    Glute Machine or Glute Bridge: 2-3 sets of 10-12 repetitions
    Shoulder Press with Dumbbells: 2-3 sets of 8-10 repetitions
    Dumbbell Curls: 2-3 sets of 10-12 repetitions

    Day 2:
    Leg Extension: 2-3 sets of 10-12 repetitions
    Arm Pushdowns (supporting on knees): 2-3 sets of 8-10 repetitions
    Dumbbell Row: 2-3 sets of 8-10 repetitions
    Calf Raises (standing): 2-3 sets of 10-12 repetitions
    Lateral Raises: 2-3 sets of 10-12 repetitions
    Triceps Extension: 2-3 sets of 10-12 repetitions

    Day 3:
    Cardio (bike or treadmill): 30 minutes at a moderate pace
    Abdominal exercises: 2-3 sets of 12-15 repetitions
    Plank: 2-3 sets, holding for 20-30 seconds

    Remember to always warm up before training and stretch after.
    """

    PLAN_PROMPT = (
        "You are an expert gym trainer tasked with writing a high-level workout table. "
        "Write such an outline for the user-provided workout. Give the three main headers of an outline of "
        "the workout table along with any relevant notes or instructions for the sections. "
        f"Here is the user's workout table: {table_output}"
    )
    messages = [
        SystemMessage(content=PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ]
    response = model.invoke(messages)
    return {
        "plan": response.content,
        "lnode": "planner",
        "count": 1,
    }


def research_plan_node(model, tavily, state: AgentState):
    """Generate search queries for the task and collect supporting content via Tavily."""
    RESEARCH_PLAN_PROMPT = (
        "You are a researcher charged with providing information that can "
        "be used when writing the following workout table. Generate a list of search "
        "queries that will gather any relevant information. Only generate 3 queries max."
    )
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ])
    content = state.get('content') or []
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {
        "content": content,
        "queries": queries.queries,
        "lnode": "research_plan",
        "count": 1,
    }


def generation_node(model, state: AgentState):
    """Write (or revise) the workout table draft from the plan, research content, and feedback."""
    WRITER_PROMPT = (
        "You are a gym trainer assistant tasked with writing excellent workout tables. "
        "Generate the best workout table possible for the user's request and the initial outline. "
        "If the user provides feedback, respond with a revised version of your previous attempts. "
        "Utilize all the information below as needed:\n"
        "------\n"
        "{content}"
    )
    content = "\n\n".join(state.get('content') or [])
    user_message = HumanMessage(
        content=f"{state['task']}\n\nHere is my workout table:\n\n{state['plan']}")
    messages = [
        SystemMessage(content=WRITER_PROMPT.format(content=content)),
        user_message,
    ]
    response = model.invoke(messages)
    return {
        "draft": response.content,
        "revision_number": state.get("revision_number", 1) + 1,
        "lnode": "generate",
        "count": 1,
    }


def reflection_node(model, state: AgentState):
    """Critique the current draft and produce feedback for the next revision."""
    REFLECTION_PROMPT = (
        "You are a personal instructor grading a workout table submission. "
        "Generate feedback and recommendations for the user's submission. "
        "Provide detailed recommendations, including requests for objectives, level of intensity, "
        "health benefits, health conditions, etc."
    )
    messages = [
        SystemMessage(content=REFLECTION_PROMPT),
        HumanMessage(content=state['draft']),
    ]
    response = model.invoke(messages)
    return {
        "feedback": response.content,
        "lnode": "reflect",
        "count": 1,
    }


def research_feedback_node(model, tavily, state: AgentState):
    """Generate search queries from the reviewer's feedback and collect supporting content."""
    RESEARCH_FEEDBACK_PROMPT = (
        "You are a researcher charged with providing information that can "
        "be used when writing the following workout table. Generate a list of search "
        "queries that will gather any relevant information. Only generate 3 queries max."
    )
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_FEEDBACK_PROMPT),
        HumanMessage(content=state['feedback']),
    ])
    content = state.get('content') or []
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {
        "content": content,
        "lnode": "research_feedback",
        "count": 1,
    }


def should_continue(state):
    # Stop once the draft has been revised max_revisions times; otherwise reflect again.
    if state["revision_number"] > state["max_revisions"]:
        return END
    return "reflect"


def create_workout_table_graph():
    model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])

    builder = StateGraph(AgentState)
    builder.add_node("planner", lambda state: plan_node(model, state))
    builder.add_node("research_plan", lambda state: research_plan_node(model, tavily, state))
    builder.add_node("generate", lambda state: generation_node(model, state))
    builder.add_node("reflect", lambda state: reflection_node(model, state))
    builder.add_node("research_feedback", lambda state: research_feedback_node(model, tavily, state))
    builder.set_entry_point("planner")
    builder.add_conditional_edges(
        "generate",
        should_continue,
        {END: END, "reflect": "reflect"},
    )
    builder.add_edge("planner", "research_plan")
    builder.add_edge("research_plan", "generate")
    builder.add_edge("reflect", "research_feedback")
    builder.add_edge("research_feedback", "generate")

    # In-memory SQLite checkpointer; the graph interrupts after every node so callers
    # can inspect or edit state at each checkpoint before resuming.
    memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))
    graph = builder.compile(
        checkpointer=memory,
        interrupt_after=['planner', 'generate', 'reflect', 'research_plan', 'research_feedback'],
    )
    return graph
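

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal, hedged example of driving the graph, assuming OPENAI_API_KEY and
# TAVILY_API_KEY are set; the task text and thread_id below are placeholders.
if __name__ == "__main__":
    graph = create_workout_table_graph()
    thread = {"configurable": {"thread_id": "1"}}
    initial_state = {
        "task": "Create a 3-day full-body workout table for a beginner",
        "max_revisions": 2,
        "revision_number": 1,
        "content": [],
        "count": 0,
    }
    # Because the graph interrupts after every node, the first stream() call runs
    # only up to the first checkpoint...
    for event in graph.stream(initial_state, thread):
        print(event)
    # ...and passing None resumes from the saved checkpoint until no node is pending.
    while graph.get_state(thread).next:
        for event in graph.stream(None, thread):
            print(event)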