Spaces:
Sleeping
Sleeping
ShaswatSingh
commited on
Update wanderlust.py
Browse files- wanderlust.py +391 -134
wanderlust.py
CHANGED
@@ -1,134 +1,391 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
from
|
4 |
-
import
|
5 |
-
from
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
"
|
24 |
-
"
|
25 |
-
"
|
26 |
-
"
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
"
|
34 |
-
"
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
]
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
#
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --- Imports ---
|
2 |
+
import gradio as gr
|
3 |
+
from langchain_groq import ChatGroq
|
4 |
+
from langchain_community.tools.tavily_search import TavilySearchResults # Updated import
|
5 |
+
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
6 |
+
import os
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
from typing import TypedDict, List, Optional, Dict, Sequence
|
9 |
+
from langgraph.graph import StateGraph, END
|
10 |
+
|
11 |
+
# --- Environment Setup ---
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
+
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
|
15 |
+
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
16 |
+
|
17 |
+
if not TAVILY_API_KEY or not GROQ_API_KEY:
|
18 |
+
raise ValueError("API keys for Tavily and Groq must be set in .env file")
|
19 |
+
|
20 |
+
# --- Constants ---
|
21 |
+
ORDERED_FIELDS = [
|
22 |
+
"destination",
|
23 |
+
"budget",
|
24 |
+
"activities",
|
25 |
+
"duration",
|
26 |
+
"accommodation",
|
27 |
+
]
|
28 |
+
|
29 |
+
QUESTIONS_MAP = {
|
30 |
+
"destination": "🌍 Where would you like to travel?",
|
31 |
+
"budget": "💰 What is your budget for this trip?",
|
32 |
+
"activities": "🤸 What kind of activities do you prefer? (e.g., adventure, relaxation, sightseeing)",
|
33 |
+
"duration": "⏳ How many days do you plan to stay?",
|
34 |
+
"accommodation": "🏨 Do you prefer hotels, hostels, or Airbnbs?",
|
35 |
+
}
|
36 |
+
|
37 |
+
INITIAL_MESSAGE = "👋 Welcome! Type `START` to begin planning your travel itinerary."
|
38 |
+
START_CONFIRMATION = "🚀 Great! Let's plan your trip. " + QUESTIONS_MAP[ORDERED_FIELDS[0]]
|
39 |
+
RESTART_MESSAGE = "✅ Restarted! Type `START` to begin again."
|
40 |
+
INVALID_START_MESSAGE = "❗ Please type `START` to begin the travel itinerary process."
|
41 |
+
ITINERARY_READY_MESSAGE = "✅ Your travel itinerary is ready!"
|
42 |
+
|
43 |
+
# --- LangChain / LangGraph Components ---
|
44 |
+
|
45 |
+
# Initialize models and tools
|
46 |
+
llm = ChatGroq(api_key=GROQ_API_KEY, model="llama3-8b-8192")
|
47 |
+
# Note: TavilySearchResults often works better when integrated as a Langchain Tool
|
48 |
+
# but for direct use like this, the previous langchain_tavily.TavilySearch is also fine.
|
49 |
+
# Let's use the more standard Tool wrapper approach for better compatibility.
|
50 |
+
search_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
|
51 |
+
|
52 |
+
|
53 |
+
# Define the state for our graph
|
54 |
+
class AgentState(TypedDict):
|
55 |
+
messages: Sequence[BaseMessage] # Conversation history
|
56 |
+
user_info: Dict[str, str] # Collected user preferences
|
57 |
+
missing_fields: List[str] # Fields still needed
|
58 |
+
search_results: Optional[str] # Results from Tavily search
|
59 |
+
itinerary: Optional[str] # Final generated itinerary
|
60 |
+
|
61 |
+
# --- Node Functions ---
|
62 |
+
|
63 |
+
def process_user_input(state: AgentState) -> Dict:
|
64 |
+
"""Processes the latest user message to update user_info."""
|
65 |
+
last_message = state['messages'][-1]
|
66 |
+
if not isinstance(last_message, HumanMessage):
|
67 |
+
# Should not happen in normal flow, but good safeguard
|
68 |
+
return {}
|
69 |
+
|
70 |
+
user_response = last_message.content.strip()
|
71 |
+
missing_fields = state.get('missing_fields', [])
|
72 |
+
user_info = state.get('user_info', {})
|
73 |
+
|
74 |
+
if not missing_fields:
|
75 |
+
# All info gathered, nothing to process here for now
|
76 |
+
# This node is primarily for capturing answers to questions
|
77 |
+
return {}
|
78 |
+
|
79 |
+
# Assume the user is answering the question related to the *first* missing field
|
80 |
+
field_to_update = missing_fields[0]
|
81 |
+
user_info[field_to_update] = user_response
|
82 |
+
|
83 |
+
# Remove the field we just got
|
84 |
+
updated_missing_fields = missing_fields[1:]
|
85 |
+
|
86 |
+
print(f"--- Processed Input: Field '{field_to_update}' = '{user_response}' ---")
|
87 |
+
print(f"--- Remaining Fields: {updated_missing_fields} ---")
|
88 |
+
|
89 |
+
return {"user_info": user_info, "missing_fields": updated_missing_fields}
|
90 |
+
|
91 |
+
def ask_next_question(state: AgentState) -> Dict:
|
92 |
+
"""Adds the next question to the messages list."""
|
93 |
+
missing_fields = state.get('missing_fields', [])
|
94 |
+
if not missing_fields:
|
95 |
+
# Should not be called if no fields are missing, but handle defensively
|
96 |
+
return {"messages": state['messages'] + [AIMessage(content="Something went wrong, no more questions to ask.")]} # type: ignore
|
97 |
+
|
98 |
+
next_field = missing_fields[0]
|
99 |
+
question = QUESTIONS_MAP[next_field]
|
100 |
+
|
101 |
+
print(f"--- Asking Question for: {next_field} ---")
|
102 |
+
# Append the new question
|
103 |
+
updated_messages = state['messages'] + [AIMessage(content=question)] # type: ignore
|
104 |
+
return {"messages": updated_messages}
|
105 |
+
|
106 |
+
|
107 |
+
def run_search(state: AgentState) -> Dict:
|
108 |
+
"""Runs Tavily search based on collected user info, handling descriptive terms."""
|
109 |
+
user_info = state['user_info']
|
110 |
+
print("--- Running Search ---")
|
111 |
+
|
112 |
+
# Construct a more descriptive query for the search engine
|
113 |
+
query_parts = [f"Travel itinerary ideas for {user_info.get('destination', 'anywhere')}"]
|
114 |
+
if user_info.get('duration'):
|
115 |
+
# Try to clarify duration if it's non-numeric, otherwise use as is
|
116 |
+
duration_desc = user_info.get('duration')
|
117 |
+
query_parts.append(f"for about {duration_desc} days.")
|
118 |
+
|
119 |
+
if user_info.get('budget'):
|
120 |
+
# Frame budget as a description
|
121 |
+
query_parts.append(f"User's budget is described as: '{user_info.get('budget')}'.")
|
122 |
+
|
123 |
+
if user_info.get('activities'):
|
124 |
+
# Frame activities as a description
|
125 |
+
query_parts.append(f"User is looking for activities like: '{user_info.get('activities')}'.")
|
126 |
+
|
127 |
+
if user_info.get('accommodation'):
|
128 |
+
# Frame accommodation as a description
|
129 |
+
query_parts.append(f"User's accommodation preference is: '{user_info.get('accommodation')}'.")
|
130 |
+
|
131 |
+
search_query = " ".join(query_parts)
|
132 |
+
print(f"--- Refined Search Query: {search_query} ---")
|
133 |
+
|
134 |
+
try:
|
135 |
+
results = search_tool.invoke(search_query)
|
136 |
+
# Handle potential result formats (string, list of docs, dict)
|
137 |
+
if isinstance(results, list):
|
138 |
+
search_results_str = "\n\n".join([getattr(doc, 'page_content', str(doc)) for doc in results]) # Safer access
|
139 |
+
elif isinstance(results, dict) and 'answer' in results:
|
140 |
+
search_results_str = results['answer']
|
141 |
+
elif isinstance(results, dict) and 'result' in results: # Another common format
|
142 |
+
search_results_str = results['result']
|
143 |
+
else:
|
144 |
+
search_results_str = str(results)
|
145 |
+
|
146 |
+
print(f"--- Search Results Found ---")
|
147 |
+
return {"search_results": search_results_str}
|
148 |
+
except Exception as e:
|
149 |
+
print(f"--- Search Failed: {e} ---")
|
150 |
+
# Provide a more informative error message if possible
|
151 |
+
error_details = str(e)
|
152 |
+
return {"search_results": f"Search failed or timed out. Details: {error_details}"}
|
153 |
+
|
154 |
+
|
155 |
+
def generate_itinerary(state: AgentState) -> Dict:
|
156 |
+
"""Generates the final itinerary using the LLM, interpreting flexible user inputs."""
|
157 |
+
user_info = state['user_info']
|
158 |
+
search_results = state.get('search_results', "No search results available.") # Provide default
|
159 |
+
print("--- Generating Itinerary ---")
|
160 |
+
|
161 |
+
# --- Enhanced Prompt ---
|
162 |
+
itinerary_prompt = f"""
|
163 |
+
You are an expert travel planner. Create a detailed and engaging travel itinerary based on the following user preferences:
|
164 |
+
|
165 |
+
**User Preferences:**
|
166 |
+
- Destination: {user_info.get('destination', 'Not specified')}
|
167 |
+
- Duration: {user_info.get('duration', 'Not specified')} days
|
168 |
+
- Budget Description: '{user_info.get('budget', 'Not specified')}'
|
169 |
+
- Preferred Activities Description: '{user_info.get('activities', 'Not specified')}'
|
170 |
+
- Preferred Accommodation Description: '{user_info.get('accommodation', 'Not specified')}'
|
171 |
+
|
172 |
+
**Your Task:**
|
173 |
+
1. **Interpret Preferences:** Carefully interpret the user's descriptions for budget, activities, and accommodation.
|
174 |
+
* If the budget is descriptive (e.g., 'moderate', 'budget-friendly', 'a bit flexible', 'around $X'), tailor suggestions to match that level for the specific destination. Avoid extreme high-cost or only free options unless explicitly requested. 'Moderate' usually implies a balance of value, comfort, and experiences.
|
175 |
+
* If activities are described generally (e.g., 'mix of famous and offbeat', 'relaxing', 'cultural immersion', 'adventure'), create an itinerary that reflects this. A 'mix' should include popular landmarks and hidden gems. 'Relaxing' should include downtime.
|
176 |
+
* Interpret accommodation descriptions (e.g., 'mid-range', 'cheap but clean', 'boutique hotel') based on typical offerings at the destination.
|
177 |
+
2. **Use Search Results:** Incorporate relevant and specific suggestions from the search results below, but *only* if they align with the interpreted user preferences. Do not blindly copy search results.
|
178 |
+
3. **Create a Coherent Plan:** Structure the itinerary logically, often day-by-day. Include suggestions for specific activities, potential dining spots (matching budget), and estimated timings where appropriate.
|
179 |
+
4. **Engaging Tone:** Present the itinerary in an exciting and appealing way.
|
180 |
+
|
181 |
+
**Supporting Search Results:**
|
182 |
+
```
|
183 |
+
{search_results}
|
184 |
+
```
|
185 |
+
|
186 |
+
**Generate the Itinerary:**
|
187 |
+
"""
|
188 |
+
# --- End of Enhanced Prompt ---
|
189 |
+
|
190 |
+
try:
|
191 |
+
response = llm.invoke(itinerary_prompt)
|
192 |
+
# Ensure content extraction handles potential variations
|
193 |
+
itinerary_content = getattr(response, 'content', str(response))
|
194 |
+
|
195 |
+
print("--- Itinerary Generated ---")
|
196 |
+
|
197 |
+
final_message = AIMessage(content=f"{ITINERARY_READY_MESSAGE}\n\n{itinerary_content}")
|
198 |
+
# Ensure messages list exists and append correctly
|
199 |
+
updated_messages = list(state.get('messages', [])) + [final_message]
|
200 |
+
|
201 |
+
return {"itinerary": itinerary_content, "messages": updated_messages}
|
202 |
+
except Exception as e:
|
203 |
+
print(f"--- Itinerary Generation Failed: {e} ---")
|
204 |
+
error_details = str(e)
|
205 |
+
error_message = AIMessage(content=f"Sorry, I encountered an error while generating the itinerary. Details: {error_details}")
|
206 |
+
# Ensure messages list exists and append correctly
|
207 |
+
updated_messages = list(state.get('messages', [])) + [error_message]
|
208 |
+
return {"itinerary": None, "messages": updated_messages}
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
def should_ask_question_or_search(state: AgentState) -> str:
|
213 |
+
"""Determines the next step based on whether all info is collected."""
|
214 |
+
missing_fields = state.get('missing_fields', [])
|
215 |
+
if not missing_fields:
|
216 |
+
print("--- Condition: All info gathered, proceed to search ---")
|
217 |
+
return "run_search"
|
218 |
+
else:
|
219 |
+
print("--- Condition: More info needed, ask next question ---")
|
220 |
+
return "ask_next_question"
|
221 |
+
|
222 |
+
# --- Build the Graph ---
|
223 |
+
|
224 |
+
graph_builder = StateGraph(AgentState)
|
225 |
+
|
226 |
+
# Define nodes
|
227 |
+
graph_builder.add_node("process_user_input", process_user_input)
|
228 |
+
graph_builder.add_node("ask_next_question", ask_next_question)
|
229 |
+
graph_builder.add_node("run_search", run_search) # Uses the updated function
|
230 |
+
graph_builder.add_node("generate_itinerary", generate_itinerary)
|
231 |
+
|
232 |
+
# Define edges
|
233 |
+
graph_builder.set_entry_point("process_user_input")
|
234 |
+
graph_builder.add_conditional_edges(
|
235 |
+
"process_user_input",
|
236 |
+
should_ask_question_or_search,
|
237 |
+
{"ask_next_question": "ask_next_question", "run_search": "run_search"}
|
238 |
+
)
|
239 |
+
graph_builder.add_edge("ask_next_question", END)
|
240 |
+
graph_builder.add_edge("run_search", "generate_itinerary")
|
241 |
+
graph_builder.add_edge("generate_itinerary", END)
|
242 |
+
|
243 |
+
# Compile the graph
|
244 |
+
travel_agent_app = graph_builder.compile()
|
245 |
+
|
246 |
+
# --- Gradio Interface ---
|
247 |
+
|
248 |
+
# Function to handle the conversation logic with LangGraph state
|
249 |
+
def handle_user_message(user_input: str, history: List[List[str | None]], current_state_dict: Optional[dict]) -> tuple:
|
250 |
+
user_input_cleaned = user_input.strip().lower()
|
251 |
+
|
252 |
+
# Initialize state if it doesn't exist (first interaction)
|
253 |
+
if current_state_dict is None:
|
254 |
+
current_state_dict = {
|
255 |
+
"messages": [AIMessage(content=INITIAL_MESSAGE)],
|
256 |
+
"user_info": {},
|
257 |
+
"missing_fields": [], # Will be populated if user types START
|
258 |
+
"search_results": None,
|
259 |
+
"itinerary": None,
|
260 |
+
}
|
261 |
+
|
262 |
+
# Handle START command
|
263 |
+
if user_input_cleaned == "start" and not current_state_dict.get("missing_fields"): # Only start if not already started
|
264 |
+
print("--- Received START ---")
|
265 |
+
current_state_dict["missing_fields"] = list(ORDERED_FIELDS) # Initialize missing fields
|
266 |
+
current_state_dict["user_info"] = {} # Reset user info
|
267 |
+
current_state_dict["messages"] = [AIMessage(content=START_CONFIRMATION)] # type: ignore
|
268 |
+
# No graph execution needed yet, just update state and return the first question
|
269 |
+
history.append([None, START_CONFIRMATION]) # Gradio format needs None for user message here
|
270 |
+
|
271 |
+
# Handle case where user types something other than START initially
|
272 |
+
elif not current_state_dict.get("missing_fields") and user_input_cleaned != "start":
|
273 |
+
print("--- Waiting for START ---")
|
274 |
+
current_state_dict["messages"] = current_state_dict["messages"] + [ # type: ignore
|
275 |
+
HumanMessage(content=user_input),
|
276 |
+
AIMessage(content=INVALID_START_MESSAGE)
|
277 |
+
]
|
278 |
+
history.append([user_input, INVALID_START_MESSAGE])
|
279 |
+
|
280 |
+
|
281 |
+
# Handle user responses after START
|
282 |
+
elif current_state_dict.get("missing_fields") or current_state_dict.get("itinerary"): # Process if questions pending or itinerary just generated
|
283 |
+
print(f"--- User Input: {user_input} ---")
|
284 |
+
# Prevent processing if itinerary was just generated and user typed something else
|
285 |
+
if current_state_dict.get("itinerary") and not current_state_dict.get("missing_fields"):
|
286 |
+
print("--- Itinerary already generated, waiting for START OVER ---")
|
287 |
+
# Optionally add a message like "Type START OVER to begin again."
|
288 |
+
history.append([user_input, "Itinerary generated. Please click 'Start Over' to plan a new trip."])
|
289 |
+
# Keep state as is, just update history
|
290 |
+
return history, current_state_dict, ""
|
291 |
+
|
292 |
+
|
293 |
+
# Add user message to state's messages list
|
294 |
+
current_messages = current_state_dict.get("messages", [])
|
295 |
+
current_messages.append(HumanMessage(content=user_input))
|
296 |
+
current_state_dict["messages"] = current_messages
|
297 |
+
|
298 |
+
# Invoke the graph
|
299 |
+
print("--- Invoking Graph ---")
|
300 |
+
# Ensure state keys match AgentState before invoking
|
301 |
+
graph_input_state = AgentState(
|
302 |
+
messages=current_state_dict.get("messages", []),
|
303 |
+
user_info=current_state_dict.get("user_info", {}),
|
304 |
+
missing_fields=current_state_dict.get("missing_fields", []),
|
305 |
+
search_results=current_state_dict.get("search_results"),
|
306 |
+
itinerary=current_state_dict.get("itinerary"),
|
307 |
+
)
|
308 |
+
# Use stream or invoke. Invoke is simpler for this request/response cycle.
|
309 |
+
final_state = travel_agent_app.invoke(graph_input_state)
|
310 |
+
print("--- Graph Execution Complete ---")
|
311 |
+
|
312 |
+
# Update the state dictionary from the graph's final state
|
313 |
+
current_state_dict.update(final_state)
|
314 |
+
|
315 |
+
# Update Gradio history
|
316 |
+
# The graph adds the AI response(s) to state['messages']
|
317 |
+
# Get the *last* AI message added by the graph
|
318 |
+
ai_response = final_state['messages'][-1].content if final_state['messages'] and isinstance(final_state['messages'][-1], AIMessage) else "Error: No response."
|
319 |
+
history.append([user_input, ai_response])
|
320 |
+
|
321 |
+
|
322 |
+
# Handle unexpected state (fallback)
|
323 |
+
else:
|
324 |
+
print("--- Unexpected State - Resetting ---")
|
325 |
+
history.append([user_input, "Something went wrong. Please type START to begin."])
|
326 |
+
current_state_dict = { # Reset state
|
327 |
+
"messages": [AIMessage(content=INITIAL_MESSAGE)],
|
328 |
+
"user_info": {},
|
329 |
+
"missing_fields": [],
|
330 |
+
"search_results": None,
|
331 |
+
"itinerary": None,
|
332 |
+
}
|
333 |
+
|
334 |
+
|
335 |
+
# Return updated history, the persistent state dictionary, and clear the input box
|
336 |
+
return history, current_state_dict, ""
|
337 |
+
|
338 |
+
# Function to reset the state (Start Over button)
|
339 |
+
def start_over() -> tuple:
|
340 |
+
print("--- Starting Over ---")
|
341 |
+
initial_state = {
|
342 |
+
"messages": [AIMessage(content=RESTART_MESSAGE)],
|
343 |
+
"user_info": {},
|
344 |
+
"missing_fields": [],
|
345 |
+
"search_results": None,
|
346 |
+
"itinerary": None,
|
347 |
+
}
|
348 |
+
# Gradio history format: List of [user_msg, assistant_msg] pairs
|
349 |
+
initial_history = [[None, RESTART_MESSAGE]]
|
350 |
+
return initial_history, initial_state, "" # Return history, state, clear input
|
351 |
+
|
352 |
+
|
353 |
+
# --- Gradio UI Definition ---
|
354 |
+
with gr.Blocks(theme=gr.themes.Soft()) as app:
|
355 |
+
gr.Markdown("# 🌍 AI-Powered Travel Itinerary Generator (LangGraph Version)")
|
356 |
+
gr.Markdown("Type `START` to begin planning your trip. Answer the questions, and I'll generate a personalized itinerary for you!")
|
357 |
+
|
358 |
+
# Store the LangGraph state between interactions
|
359 |
+
agent_state = gr.State(value=None) # Initialize state as None
|
360 |
+
|
361 |
+
# Chat interface
|
362 |
+
chatbot = gr.Chatbot(
|
363 |
+
label="Travel Bot",
|
364 |
+
bubble_full_width=False,
|
365 |
+
value=[[None, INITIAL_MESSAGE]] # Initial message
|
366 |
+
)
|
367 |
+
user_input = gr.Textbox(label="Your Message", placeholder="Type here...", scale=3)
|
368 |
+
submit_btn = gr.Button("Send", scale=1)
|
369 |
+
start_over_btn = gr.Button("Start Over", scale=1)
|
370 |
+
|
371 |
+
# Button and Textbox actions
|
372 |
+
submit_btn.click(
|
373 |
+
fn=handle_user_message,
|
374 |
+
inputs=[user_input, chatbot, agent_state],
|
375 |
+
outputs=[chatbot, agent_state, user_input] # Update chatbot, state, and clear input
|
376 |
+
)
|
377 |
+
user_input.submit( # Allow Enter key submission
|
378 |
+
fn=handle_user_message,
|
379 |
+
inputs=[user_input, chatbot, agent_state],
|
380 |
+
outputs=[chatbot, agent_state, user_input]
|
381 |
+
)
|
382 |
+
start_over_btn.click(
|
383 |
+
fn=start_over,
|
384 |
+
inputs=[],
|
385 |
+
outputs=[chatbot, agent_state, user_input] # Reset chatbot, state, and clear input
|
386 |
+
)
|
387 |
+
|
388 |
+
|
389 |
+
# --- Run the App ---
|
390 |
+
if __name__ == "__main__":
|
391 |
+
app.launch(debug=True)# Debug=True provides more logs
|