Spaces:
Sleeping
Sleeping
File size: 17,799 Bytes
30851ff f5fb70e a318441 30851ff f5fb70e 30851ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 |
# --- Imports ---
import gradio as gr
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults # Updated import
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
import os
from dotenv import load_dotenv
from typing import TypedDict, List, Optional, Dict, Sequence
from langgraph.graph import StateGraph, END
# --- Environment Setup ---
# Load variables from a local .env file into the process environment.
load_dotenv()
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# Fail fast at import time: both external services are required.
if not (TAVILY_API_KEY and GROQ_API_KEY):
    raise ValueError("API keys for Tavily and Groq must be set in .env file")
# --- Constants ---
# Order in which trip preferences are requested; the first entry is asked
# immediately after the user types START.
ORDERED_FIELDS = [
    "destination",
    "budget",
    "activities",
    "duration",
    "accommodation",
]
# Question shown to the user for each field in ORDERED_FIELDS.
QUESTIONS_MAP = {
    "destination": "🌍 Where would you like to travel?",
    "budget": "💰 What is your budget for this trip?",
    "activities": "🤸 What kind of activities do you prefer? (e.g., adventure, relaxation, sightseeing)",
    "duration": "⏳ How many days do you plan to stay?",
    "accommodation": "🏨 Do you prefer hotels, hostels, or Airbnbs?",
}
# User-facing chat messages. Each literal must stay on a single line; the
# emoji were previously mis-encoded, which split two of them across lines.
INITIAL_MESSAGE = "👋 Welcome! Type `START` to begin planning your travel itinerary."
START_CONFIRMATION = "🎉 Great! Let's plan your trip. " + QUESTIONS_MAP[ORDERED_FIELDS[0]]
RESTART_MESSAGE = "✅ Restarted! Type `START` to begin again."
INVALID_START_MESSAGE = "❌ Please type `START` to begin the travel itinerary process."
ITINERARY_READY_MESSAGE = "✅ Your travel itinerary is ready!"
# --- LangChain / LangGraph Components ---
# Groq-hosted chat model that writes the final itinerary.
llm = ChatGroq(model="llama3-8b-8192", api_key=GROQ_API_KEY)
# Tavily web-search tool, capped at max_results=5 per query.
search_tool = TavilySearchResults(tavily_api_key=TAVILY_API_KEY, max_results=5)
# Shared state schema passed between the LangGraph nodes below.
AgentState = TypedDict(
    "AgentState",
    {
        "messages": Sequence[BaseMessage],  # conversation history
        "user_info": Dict[str, str],  # collected user preferences, keyed by field name
        "missing_fields": List[str],  # fields still to be asked, in order
        "search_results": Optional[str],  # text produced by the Tavily search node
        "itinerary": Optional[str],  # final generated itinerary
    },
)
# --- Node Functions ---
def process_user_input(state: AgentState) -> Dict:
    """Record the latest user reply as the answer to the next pending question.

    The text of ``state['messages'][-1]`` is stored under the first entry of
    ``state['missing_fields']``, and that entry is removed from the list.

    Returns:
        A partial state update ``{"user_info": ..., "missing_fields": ...}``,
        or ``{}`` when there is nothing to record: the last message is not a
        HumanMessage, no fields are pending, or the reply is blank. A blank
        reply leaves ``missing_fields`` unchanged, so the router re-asks the
        same question.
    """
    messages = state.get('messages') or []
    if not messages or not isinstance(messages[-1], HumanMessage):
        # Defensive: the UI appends a HumanMessage before invoking the graph.
        return {}
    user_response = str(messages[-1].content).strip()

    missing_fields = list(state.get('missing_fields') or [])
    if not missing_fields:
        # All info already gathered; this node only captures answers.
        return {}
    if not user_response:
        print("--- Empty answer received; re-asking ---")
        return {}

    # The user is answering the question for the *first* missing field.
    field_to_update, *updated_missing_fields = missing_fields
    # Build a new dict rather than mutating the caller's user_info in place.
    user_info = {**(state.get('user_info') or {}), field_to_update: user_response}

    print(f"--- Processed Input: Field '{field_to_update}' = '{user_response}' ---")
    print(f"--- Remaining Fields: {updated_missing_fields} ---")
    return {"user_info": user_info, "missing_fields": updated_missing_fields}
def ask_next_question(state: AgentState) -> Dict:
    """Append the question for the next missing field to the conversation.

    Returns:
        ``{"messages": <history + one AIMessage>}``. If no field is pending
        (not expected, since the router only sends us here when one is), a
        fallback error message is appended instead.
    """
    # Copy into a list so this works for any Sequence (AgentState declares
    # Sequence[BaseMessage]) and matches generate_itinerary's handling.
    messages = list(state.get('messages') or [])
    missing_fields = state.get('missing_fields') or []
    if not missing_fields:
        return {"messages": messages + [AIMessage(content="Something went wrong, no more questions to ask.")]}

    next_field = missing_fields[0]
    question = QUESTIONS_MAP[next_field]
    print(f"--- Asking Question for: {next_field} ---")
    return {"messages": messages + [AIMessage(content=question)]}
def _build_search_query(user_info: Dict[str, str]) -> str:
    """Compose a descriptive natural-language search query from the user's answers."""
    query_parts = [f"Travel itinerary ideas for {user_info.get('destination') or 'anywhere'}"]
    if user_info.get('duration'):
        # Duration is used verbatim; it may be free text rather than a number.
        query_parts.append(f"for about {user_info.get('duration')} days.")
    if user_info.get('budget'):
        query_parts.append(f"User's budget is described as: '{user_info.get('budget')}'.")
    if user_info.get('activities'):
        query_parts.append(f"User is looking for activities like: '{user_info.get('activities')}'.")
    if user_info.get('accommodation'):
        query_parts.append(f"User's accommodation preference is: '{user_info.get('accommodation')}'.")
    return " ".join(query_parts)


def _format_search_results(results) -> str:
    """Flatten the search tool's output into a single text block for the LLM prompt."""
    if isinstance(results, list):
        chunks = []
        for doc in results:
            # Presumably dicts like {"url": ..., "content": ...} (per the
            # TavilySearchResults docs); otherwise fall back to Document /
            # plain-string handling.
            if isinstance(doc, dict) and doc.get('content'):
                url = doc.get('url')
                chunks.append(f"Source: {url}\n{doc['content']}" if url else str(doc['content']))
            else:
                chunks.append(getattr(doc, 'page_content', str(doc)))
        return "\n\n".join(chunks)
    if isinstance(results, dict):
        for key in ('answer', 'result'):
            if key in results:
                return results[key]
    return str(results)


def run_search(state: AgentState) -> Dict:
    """Run a Tavily web search built from the collected trip preferences.

    Returns:
        ``{"search_results": <str>}``. Any exception raised by the tool is
        caught and reported as text, so the graph still proceeds to
        itinerary generation.
    """
    user_info = state.get('user_info') or {}
    print("--- Running Search ---")
    search_query = _build_search_query(user_info)
    print(f"--- Refined Search Query: {search_query} ---")
    try:
        results = search_tool.invoke(search_query)
    except Exception as e:  # best-effort: network/API errors become prompt text
        print(f"--- Search Failed: {e} ---")
        return {"search_results": f"Search failed or timed out. Details: {e}"}
    print("--- Search Results Found ---")
    return {"search_results": _format_search_results(results)}
def generate_itinerary(state: AgentState) -> Dict:
    """Ask the LLM for the final itinerary, based on the preferences and search results.

    Returns:
        On success, ``{"itinerary": <text>, "messages": <history + ready message>}``.
        If the LLM call raises, ``{"itinerary": None, "messages": <history + error message>}``;
        the error is reported to the user rather than propagated.
    """
    user_info = state.get('user_info') or {}
    # `or` (not a .get default) so that None and empty results ("" from an
    # empty result list) also fall back to the placeholder text.
    search_results = state.get('search_results') or "No search results available."
    history = list(state.get('messages') or [])
    print("--- Generating Itinerary ---")
    # --- Enhanced Prompt ---
    itinerary_prompt = f"""
You are an expert travel planner. Create a detailed and engaging travel itinerary based on the following user preferences:
**User Preferences:**
- Destination: {user_info.get('destination', 'Not specified')}
- Duration: {user_info.get('duration', 'Not specified')} days
- Budget Description: '{user_info.get('budget', 'Not specified')}'
- Preferred Activities Description: '{user_info.get('activities', 'Not specified')}'
- Preferred Accommodation Description: '{user_info.get('accommodation', 'Not specified')}'
**Your Task:**
1. **Interpret Preferences:** Carefully interpret the user's descriptions for budget, activities, and accommodation.
* If the budget is descriptive (e.g., 'moderate', 'budget-friendly', 'a bit flexible', 'around $X'), tailor suggestions to match that level for the specific destination. Avoid extreme high-cost or only free options unless explicitly requested. 'Moderate' usually implies a balance of value, comfort, and experiences.
* If activities are described generally (e.g., 'mix of famous and offbeat', 'relaxing', 'cultural immersion', 'adventure'), create an itinerary that reflects this. A 'mix' should include popular landmarks and hidden gems. 'Relaxing' should include downtime.
* Interpret accommodation descriptions (e.g., 'mid-range', 'cheap but clean', 'boutique hotel') based on typical offerings at the destination.
2. **Use Search Results:** Incorporate relevant and specific suggestions from the search results below, but *only* if they align with the interpreted user preferences. Do not blindly copy search results.
3. **Create a Coherent Plan:** Structure the itinerary logically, often day-by-day. Include suggestions for specific activities, potential dining spots (matching budget), and estimated timings where appropriate.
4. **Engaging Tone:** Present the itinerary in an exciting and appealing way.
NOTE :- Do not go over the Budget Description. No matter the case , Always try to stay under the Budget Description
**Supporting Search Results:**
```
{search_results}
```
**Generate the Itinerary:**
"""
    # --- End of Enhanced Prompt ---
    try:
        response = llm.invoke(itinerary_prompt)
    except Exception as e:  # surface LLM/API failures in the chat instead of crashing the graph
        print(f"--- Itinerary Generation Failed: {e} ---")
        error_message = AIMessage(content=f"Sorry, I encountered an error while generating the itinerary. Details: {e}")
        return {"itinerary": None, "messages": history + [error_message]}

    # Chat models return a message with .content; fall back to str() otherwise.
    itinerary_content = getattr(response, 'content', str(response))
    print("--- Itinerary Generated ---")
    final_message = AIMessage(content=f"{ITINERARY_READY_MESSAGE}\n\n{itinerary_content}")
    return {"itinerary": itinerary_content, "messages": history + [final_message]}
def should_ask_question_or_search(state: AgentState) -> str:
    """Route after input processing.

    Returns ``"ask_next_question"`` while any field is still missing, and
    ``"run_search"`` once every preference has been collected.
    """
    if state.get('missing_fields', []):
        print("--- Condition: More info needed, ask next question ---")
        return "ask_next_question"
    print("--- Condition: All info gathered, proceed to search ---")
    return "run_search"
# --- Build the Graph ---
graph_builder = StateGraph(AgentState)

# Register each node under the same name as its function.
for _node_fn in (process_user_input, ask_next_question, run_search, generate_itinerary):
    graph_builder.add_node(_node_fn.__name__, _node_fn)

# Every turn starts by recording the user's reply, then branches:
# either ask another question (and stop), or search -> generate -> stop.
graph_builder.set_entry_point("process_user_input")
graph_builder.add_conditional_edges(
    "process_user_input",
    should_ask_question_or_search,
    {_route: _route for _route in ("ask_next_question", "run_search")},
)
for _src, _dst in (
    ("ask_next_question", END),
    ("run_search", "generate_itinerary"),
    ("generate_itinerary", END),
):
    graph_builder.add_edge(_src, _dst)

# Compiled, invokable graph used by the Gradio handler.
travel_agent_app = graph_builder.compile()
# --- Gradio Interface ---
def _blank_state(first_message: str) -> dict:
    """Return a fresh per-session state dict whose history holds one AI message."""
    return {
        "messages": [AIMessage(content=first_message)],
        "user_info": {},
        "missing_fields": [],
        "search_results": None,
        "itinerary": None,
    }


def handle_user_message(user_input: str, history: List[List[str | None]], current_state_dict: Optional[dict]) -> tuple:
    """Handle one chat submission and advance the trip-planning conversation.

    Args:
        user_input: Raw text from the textbox.
        history: Chatbot transcript as ``[user, assistant]`` pairs.
        current_state_dict: Per-session state held in ``gr.State``; ``None``
            on the first interaction.

    Returns:
        ``(history, state, "")``: the updated transcript, the state to store
        back into ``gr.State``, and an empty string that clears the textbox.
    """
    history = list(history or [])
    user_input = user_input or ""
    user_input_cleaned = user_input.strip().lower()
    if not user_input_cleaned:
        # Ignore blank submissions rather than recording "" as an answer.
        return history, current_state_dict, ""

    if current_state_dict is None:
        current_state_dict = _blank_state(INITIAL_MESSAGE)

    awaiting_answers = bool(current_state_dict.get("missing_fields"))

    if user_input_cleaned == "start" and not awaiting_answers:
        print("--- Received START ---")
        # Reset everything, including search results / itinerary from any
        # previous trip, so stale data cannot leak into the new one.
        current_state_dict = _blank_state(START_CONFIRMATION)
        current_state_dict["missing_fields"] = list(ORDERED_FIELDS)
        # No graph run needed yet: the reply is just the first question.
        history.append([user_input, START_CONFIRMATION])
    elif awaiting_answers:
        print(f"--- User Input: {user_input} ---")
        messages = list(current_state_dict.get("messages", [])) + [HumanMessage(content=user_input)]
        current_state_dict["messages"] = messages
        print("--- Invoking Graph ---")
        graph_input_state = AgentState(
            messages=messages,
            user_info=current_state_dict.get("user_info", {}),
            missing_fields=current_state_dict.get("missing_fields", []),
            search_results=current_state_dict.get("search_results"),
            itinerary=current_state_dict.get("itinerary"),
        )
        final_state = travel_agent_app.invoke(graph_input_state)
        print("--- Graph Execution Complete ---")
        current_state_dict.update(final_state)
        # The graph appends its reply (next question or itinerary) to messages.
        final_messages = final_state.get("messages") or []
        last_message = final_messages[-1] if final_messages else None
        ai_response = last_message.content if isinstance(last_message, AIMessage) else "Error: No response."
        history.append([user_input, ai_response])
    elif current_state_dict.get("itinerary"):
        # Checked before the "waiting for START" branch; in the old branch
        # order this case was unreachable.
        print("--- Itinerary already generated, waiting for START OVER ---")
        history.append([user_input, "Itinerary generated. Please click 'Start Over' to plan a new trip."])
    else:
        print("--- Waiting for START ---")
        current_state_dict["messages"] = list(current_state_dict.get("messages", [])) + [
            HumanMessage(content=user_input),
            AIMessage(content=INVALID_START_MESSAGE),
        ]
        history.append([user_input, INVALID_START_MESSAGE])

    return history, current_state_dict, ""
# Reset handler wired to the "Start Over" button.
def start_over() -> tuple:
    """Discard the session and show only the restart notice.

    Returns:
        ``(history, state, "")``: a transcript containing just the restart
        message, a fresh state dict with no pending fields, and an empty
        string that clears the textbox.
    """
    print("--- Starting Over ---")
    fresh_state = dict(
        messages=[AIMessage(content=RESTART_MESSAGE)],
        user_info={},
        missing_fields=[],
        search_results=None,
        itinerary=None,
    )
    # Chatbot history is a list of [user, assistant] pairs; there is no user turn here.
    return [[None, RESTART_MESSAGE]], fresh_state, ""
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🌍 AI-Powered Travel Itinerary Generator (LangGraph Version)")
    gr.Markdown("Type `START` to begin planning your trip. Answer the questions, and I'll generate a personalized itinerary for you!")
    # Per-session LangGraph state; None until handle_user_message initializes it.
    agent_state = gr.State(value=None)
    # Chat transcript as [user, assistant] pairs, seeded with the welcome message.
    chatbot = gr.Chatbot(
        label="Travel Bot",
        bubble_full_width=False,
        value=[[None, INITIAL_MESSAGE]],
    )
    # `scale` only takes effect for components inside a Row, so place the
    # input box and the buttons side by side (3:1:1 width ratio).
    with gr.Row():
        user_input = gr.Textbox(label="Your Message", placeholder="Type here...", scale=3)
        submit_btn = gr.Button("Send", scale=1)
        start_over_btn = gr.Button("Start Over", scale=1)
    # The Send button and the Enter key share one handler; its "" output clears the textbox.
    submit_btn.click(
        fn=handle_user_message,
        inputs=[user_input, chatbot, agent_state],
        outputs=[chatbot, agent_state, user_input],
    )
    user_input.submit(
        fn=handle_user_message,
        inputs=[user_input, chatbot, agent_state],
        outputs=[chatbot, agent_state, user_input],
    )
    start_over_btn.click(
        fn=start_over,
        inputs=[],
        outputs=[chatbot, agent_state, user_input],  # reset chatbot, state, and input
    )
# --- Run the App ---
if __name__ == "__main__":
    # debug=True makes Gradio print more logs while the app is running.
    app.launch(debug=True)