|
import gradio as gr |
|
import os |
|
import time |
|
from tools import create_workout_table_graph |
|
|
|
class WorkoutTableGUI:
    """Gradio control panel for running and inspecting the workout-table graph.

    The wrapped ``graph`` object is driven through ``invoke``, ``get_state``,
    ``get_state_history`` and ``update_state``; its ``nodes`` mapping supplies
    the names offered in the "Interrupt After State" checkbox group.  Every
    run started from the UI gets its own integer thread id, and
    ``self.thread`` always holds the config dict that addresses the currently
    selected thread.
    """

    # Label reused each time the step dropdown is rebuilt.
    _STEP_LABEL = "update_state from: thread:count:last_node:next_node:rev:thread_ts"

    def __init__(self, graph=None, share=False):
        """Store the graph, reset the run bookkeeping and build the UI.

        Args:
            graph: Graph object to drive.  When falsy, one is created with
                ``create_workout_table_graph()``.
            share: Default ``share`` flag used by :meth:`launch`.
        """
        self.graph = graph if graph else create_workout_table_graph()
        self.share = share
        self.partial_message = ""   # text accumulated for the live output box
        self.response = {}          # last value returned by graph.invoke
        self.max_iterations = 10    # upper bound on invoke calls per thread
        self.iterations = []        # invoke count, indexed by thread id
        self.threads = []           # every thread id handed out so far
        self.thread_id = -1         # -1 means "no run started yet"
        self.thread = self._thread_config(self.thread_id)
        self.demo = self.create_interface()

    @staticmethod
    def _thread_config(thread_id):
        """Return the config dict that addresses ``thread_id`` in the graph."""
        return {"configurable": {"thread_id": str(thread_id)}}

    def run_agent(self, start, topic, stop_after):
        """Run the graph step by step, yielding after every ``invoke`` call.

        Args:
            start: Truthy to begin a new thread for ``topic``; falsy to resume
                the currently selected thread.
            topic: Task text stored under ``'task'`` in the initial state.
            stop_after: Collection of node names; the run stops once the last
                executed node is one of them.

        Yields:
            ``(partial_message, lnode, nnode, thread_id, rev, count)`` after
            each step.  Nothing is yielded when asked to resume while no
            thread exists.
        """
        if start:
            # Always hand out the next unused id.  Deriving it from the
            # current thread_id would reuse an existing id (and its iteration
            # counter) after the user had switched back to an older thread.
            self.thread_id = len(self.threads)
            self.threads.append(self.thread_id)
            self.iterations.append(0)
            # NOTE(review): the plan is seeded under 'planner' here, while the
            # plan tab reads and writes the key 'plan' -- confirm against the
            # graph's state schema.
            config = {'task': topic, "max_revisions": 2, "revision_number": 0,
                      'lnode': "", 'planner': "no plan", 'draft': "no draft", 'feedback': "no feedback",
                      'content': ["no content",], 'queries': "no queries", 'count': 0}
        else:
            config = None
            if not 0 <= self.thread_id < len(self.iterations):
                # "Continue" was requested before any run exists; there is
                # nothing to resume, so finish without touching the graph.
                return

        self.thread = self._thread_config(self.thread_id)
        while self.iterations[self.thread_id] < self.max_iterations:
            self.response = self.graph.invoke(config, self.thread)
            self.iterations[self.thread_id] += 1
            self.partial_message += str(self.response)
            self.partial_message += "\n------------------\n\n"
            lnode, nnode, _, rev, acount = self.get_disp_state()
            yield self.partial_message, lnode, nnode, self.thread_id, rev, acount
            # Only the first invoke of a new run carries the initial state;
            # later calls pass None as the input.
            config = None
            if not nnode:
                return  # the graph reports no next node
            if lnode in stop_after:
                return  # pause requested after this node

    def get_disp_state(self):
        """Return ``(lnode, nnode, thread_id, rev, count)`` for the current thread."""
        current_state = self.graph.get_state(self.thread)
        values = current_state.values
        return (values["lnode"], current_state.next, self.thread_id,
                values["revision_number"], values["count"])

    def get_state(self, key):
        """Return a textbox update showing state value ``key``.

        Returns an empty string when ``key`` is not present in the current
        state; otherwise a ``gr.update`` whose label summarises the last node,
        thread, revision and step count.
        """
        current_values = self.graph.get_state(self.thread)
        if key not in current_values.values:
            return ""
        lnode, _, _, rev, astep = self.get_disp_state()
        new_label = f"last_node: {lnode}, thread_id: {self.thread_id}, rev: {rev}, step: {astep}"
        return gr.update(label=new_label, value=current_values.values[key])

    def get_content(self):
        """Return a textbox update with the ``content`` items joined by blank lines.

        Returns an empty string when the current state has no ``content`` key.
        """
        current_values = self.graph.get_state(self.thread)
        if "content" not in current_values.values:
            return ""
        content = current_values.values["content"]
        lnode, _, _, rev, astep = self.get_disp_state()
        new_label = f"last_node: {lnode}, thread_id: {self.thread_id}, rev: {rev}, step: {astep}"
        return gr.update(label=new_label, value="\n\n".join(content) + "\n\n")

    def _history_choices(self):
        """Describe each history entry of the current thread as one string.

        Entries whose metadata ``step`` is below 1 are skipped.  The format is
        ``thread:count:last_node:next_node:rev:thread_ts``, in the order
        ``get_state_history`` yields the entries.
        """
        hist = []
        for state in self.graph.get_state_history(self.thread):
            if state.metadata['step'] < 1:
                continue
            configurable = state.config['configurable']
            thread_ts = configurable.get('thread_ts', 'N/A')
            tid = configurable['thread_id']
            count = state.values['count']
            lnode = state.values['lnode']
            rev = state.values['revision_number']
            nnode = state.next
            hist.append(f"{tid}:{count}:{lnode}:{nnode}:{rev}:{thread_ts}")
        return hist

    def update_hist_pd(self):
        """Return a step dropdown populated from the current thread's history."""
        hist = self._history_choices()
        return gr.Dropdown(label=self._STEP_LABEL,
                           choices=hist, value=hist[0] if hist else None, interactive=True)

    def find_config(self, thread_ts):
        """Return the config of the history entry whose ``thread_ts`` matches.

        Entries without a ``thread_ts`` never match.  Returns ``None`` when no
        entry matches.
        """
        for state in self.graph.get_state_history(self.thread):
            configurable = state.config['configurable']
            if 'thread_ts' in configurable and configurable['thread_ts'] == thread_ts:
                return state.config
        return None

    def copy_state(self, hist_str):
        """Copy the history entry named by ``hist_str`` onto the current thread.

        Args:
            hist_str: A string in the step dropdown's
                ``thread:count:last_node:next_node:rev:thread_ts`` format.

        Returns:
            ``(lnode, nnode, thread_ts, rev, count)`` read back after the
            update, or five ``None`` values when nothing is selected or no
            matching entry is found.
        """
        if not hist_str:
            return None, None, None, None, None
        # thread_ts is the sixth field and may itself contain ':' (an ISO
        # timestamp does), so split on the first five separators only.
        # Assumes the preceding fields contain no ':' -- TODO confirm for
        # node names.
        thread_ts = hist_str.split(":", 5)[-1]
        config = self.find_config(thread_ts)
        if config is None:
            return None, None, None, None, None
        state = self.graph.get_state(config)
        self.graph.update_state(self.thread, state.values, as_node=state.values['lnode'])
        new_state = self.graph.get_state(self.thread)
        new_thread_ts = new_state.config['configurable'].get('thread_ts', 'N/A')
        count = new_state.values['count']
        lnode = new_state.values['lnode']
        rev = new_state.values['revision_number']
        nnode = new_state.next
        return lnode, nnode, new_thread_ts, rev, count

    def update_thread_pd(self):
        """Return a thread dropdown listing all threads, with the current one selected."""
        return gr.Dropdown(label="choose thread", choices=self.threads,
                           value=self.thread_id, interactive=True)

    def switch_thread(self, new_thread_id):
        """Make ``new_thread_id`` the thread all later calls operate on."""
        self.thread = self._thread_config(new_thread_id)
        self.thread_id = new_thread_id

    def modify_state(self, key, asnode, new_state):
        """Overwrite state value ``key`` with ``new_state``, written as node ``asnode``."""
        current_values = self.graph.get_state(self.thread)
        current_values.values[key] = new_state
        self.graph.update_state(self.thread, current_values.values, as_node=asnode)

    def create_interface(self):
        """Build the Gradio ``Blocks`` layout, wire its events and return it."""
        with gr.Blocks(theme=gr.themes.Default(spacing_size='sm', text_size="sm")) as demo:

            def updt_disp():
                """Return fresh values for every status widget of the Agent tab."""
                current_state = self.graph.get_state(self.thread)
                if not current_state.metadata:
                    return {}  # empty mapping: no widget is updated
                return {
                    topic_bx: current_state.values["task"],
                    lnode_bx: current_state.values["lnode"],
                    count_bx: current_state.values["count"],
                    revision_bx: current_state.values["revision_number"],
                    nnode_bx: current_state.next,
                    threadid_bx: self.thread_id,
                    thread_pd: self.update_thread_pd(),
                    step_pd: self.update_hist_pd(),
                }

            def get_snapshots():
                """Return a textbox update listing abbreviated history snapshots."""
                new_label = f"thread_id: {self.thread_id}, Summary of snapshots"
                pieces = []
                for state in self.graph.get_state_history(self.thread):
                    # Long text fields are shortened in place before the
                    # snapshot is rendered with str().
                    for key in ['plan', 'draft', 'feedback']:
                        if key in state.values:
                            state.values[key] = state.values[key][:80] + "..."
                    if 'content' in state.values:
                        content = state.values['content']
                        content[:] = [item[:20] + '...' for item in content]
                    if 'writes' in state.metadata:
                        state.metadata['writes'] = "not shown"
                    pieces.append(str(state) + "\n\n")
                return gr.update(label=new_label, value="".join(pieces))

            def vary_btn(stat):
                """Return a button update that switches to variant ``stat``."""
                return gr.update(variant=stat)

            def const(value):
                """Create a hidden component that feeds ``value`` to a handler."""
                return gr.Number(value, visible=False)

            def editable_tab(title, label, key, as_node):
                """Add a tab that shows state ``key`` and lets the user overwrite it."""
                with gr.Tab(title):
                    with gr.Row():
                        refresh_btn = gr.Button("Refresh")
                        modify_btn = gr.Button("Modify")
                    box = gr.Textbox(label=label, lines=10, interactive=True)
                    refresh_btn.click(fn=self.get_state, inputs=const(key), outputs=box)
                    modify_btn.click(fn=self.modify_state,
                                     inputs=[const(key), const(as_node), box],
                                     outputs=None).then(
                        fn=updt_disp, inputs=None, outputs=sdisps)

            with gr.Tab("Agent"):
                with gr.Row():
                    topic_bx = gr.Textbox(label="Workout Table", value="Workout Table for a 30 year old male who wants to gain muscle mass and strength")
                    gen_btn = gr.Button("Generate Workout Table", scale=0, min_width=80, variant='primary')
                    cont_btn = gr.Button("Continue Workout Table", scale=0, min_width=80)
                with gr.Row():
                    lnode_bx = gr.Textbox(label="last node", min_width=100)
                    nnode_bx = gr.Textbox(label="next node", min_width=100)
                    threadid_bx = gr.Textbox(label="Thread", scale=0, min_width=80)
                    revision_bx = gr.Textbox(label="Draft Rev", scale=0, min_width=80)
                    count_bx = gr.Textbox(label="count", scale=0, min_width=80)
                with gr.Accordion("Manage Agent", open=False):
                    # Every node except the start marker can be a stop point;
                    # filtering (rather than list.remove) tolerates a graph
                    # that has no '__start__' entry.
                    checks = [name for name in self.graph.nodes.keys() if name != '__start__']
                    stop_after = gr.CheckboxGroup(checks, label="Interrupt After State", value=checks, scale=0, min_width=400)
                    with gr.Row():
                        thread_pd = gr.Dropdown(choices=self.threads, interactive=True, label="select thread", min_width=120, scale=0)
                        step_pd = gr.Dropdown(choices=['N/A'], interactive=True, label="select step", min_width=160, scale=1)
                live = gr.Textbox(label="Live Agent Output", lines=5, max_lines=5)

                # Widgets refreshed by updt_disp after every action.
                sdisps = [topic_bx, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx, step_pd, thread_pd]

                thread_pd.input(self.switch_thread, [thread_pd], None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)
                step_pd.input(self.copy_state, [step_pd], None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)

                # Generate: switch the button to 'secondary' for the duration
                # of the run, stream the run into the live box, refresh the
                # status widgets, then set both buttons back to 'primary'.
                gen_btn.click(vary_btn, const("secondary"), gen_btn).then(
                    fn=self.run_agent, inputs=[const(True), topic_bx, stop_after],
                    outputs=[live], show_progress=True).then(
                    fn=updt_disp, inputs=None, outputs=sdisps).then(
                    vary_btn, const("primary"), gen_btn).then(
                    vary_btn, const("primary"), cont_btn)

                # Continue: same sequence, resuming the selected thread.
                cont_btn.click(vary_btn, const("secondary"), cont_btn).then(
                    fn=self.run_agent, inputs=[const(False), topic_bx, stop_after],
                    outputs=[live]).then(
                    fn=updt_disp, inputs=None, outputs=sdisps).then(
                    vary_btn, const("primary"), cont_btn)

            editable_tab("Workout Table", "Plan", "plan", "planner")

            with gr.Tab("Research Content"):
                refresh_btn = gr.Button("Refresh")
                content_bx = gr.Textbox(label="content", lines=10)
                refresh_btn.click(fn=self.get_content, inputs=None, outputs=content_bx)

            editable_tab("Draft", "draft", "draft", "generate")
            editable_tab("Feedback", "Feedback", "feedback", "reflect")

            with gr.Tab("StateSnapShots"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                snapshots = gr.Textbox(label="State Snapshots Summaries")
                refresh_btn.click(fn=get_snapshots, inputs=None, outputs=snapshots)
        return demo

    def launch(self, share=None):
        """Start the Gradio server.

        When the ``PORT1`` environment variable is set, the app is served on
        that port on ``0.0.0.0`` with sharing enabled.  Otherwise ``share``
        is used, falling back to the value given to the constructor when it
        is ``None``.
        """
        if port := os.getenv("PORT1"):
            self.demo.launch(share=True, server_port=int(port), server_name="0.0.0.0")
        else:
            self.demo.launch(share=self.share if share is None else share)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the GUI with its default graph and serve it.
    WorkoutTableGUI().launch()
|
|