# Page residue from the Hugging Face file view: author PD03, commit "Update agent/rica_agent.py" (83658db, verified).
"""
RICA Agent optimized for Hugging Face Spaces
Fixed for latest smolagents API
"""
import os
from smolagents import CodeAgent
from smolagents.models import OpenAIServerModel
from agent_tools.ml_tools import predict_customer_churn_hf, get_model_status
def create_rica_agent_hf(model_id: str = "gpt-3.5-turbo"):
    """Create the RICA ``CodeAgent`` wired to the HF Spaces churn tools.

    Args:
        model_id: OpenAI model identifier passed to ``OpenAIServerModel``.
            Defaults to ``"gpt-3.5-turbo"``, the previously hard-coded value.

    Returns:
        A ``CodeAgent`` whose only tools are ``predict_customer_churn_hf`` and
        ``get_model_status`` (smolagents' base tools are not added).

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset or empty.
        RuntimeError: If constructing the ``CodeAgent`` fails; the underlying
            error is chained as ``__cause__``.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OpenAI API key not configured")

    model = OpenAIServerModel(model_id=model_id, api_key=api_key)

    # HF Spaces optimized tools.
    hf_tools = [predict_customer_churn_hf, get_model_status]

    try:
        # Deliberately no max_iterations kwarg, per the original author's note
        # for the smolagents version this module targets.
        return CodeAgent(tools=hf_tools, model=model, add_base_tools=False)
    except Exception as e:
        # Chain the cause so the real constructor traceback is not lost.
        raise RuntimeError(f"Agent creation failed: {e}") from e
_DEFAULT_RISK_THRESHOLD = 0.6


def _risk_threshold(parameters):
    """Return ``parameters["risk_threshold"]``, or 0.6 when it is absent/None or parameters is unusable."""
    getter = getattr(parameters, "get", None)
    if not parameters or getter is None:
        return _DEFAULT_RISK_THRESHOLD
    value = getter("risk_threshold")
    # An explicit None would otherwise be rendered as "risk_threshold=None".
    return _DEFAULT_RISK_THRESHOLD if value is None else value


def _build_hf_goal(analysis_type: str, parameters) -> str:
    """Render the agent prompt for *analysis_type* ("comprehensive" for unknown types).

    Only the selected prompt is rendered, so ``parameters`` is inspected only
    when that prompt actually needs one of its values.
    """
    if analysis_type == "churn_focus":
        threshold = _risk_threshold(parameters)
        return (
            "\n"
            "Focus on customer churn analysis:\n"
            f"1) Predict customer churn with predict_customer_churn_hf(risk_threshold={threshold})\n"
            "2) Identify high-risk customers requiring immediate attention\n"
            "3) Provide specific intervention strategies\n"
            f"Parameters: {parameters}\n"
        )
    if analysis_type == "quick_insights":
        return (
            "\n"
            "Provide quick business insights:\n"
            "1) Get model status with get_model_status()\n"
            "2) Run limited churn analysis with predict_customer_churn_hf()\n"
            "3) Summarize top 3 business priorities\n"
            f"Parameters: {parameters}\n"
        )
    return (
        "\n"
        "Execute business intelligence analysis:\n"
        "1) Check model status with get_model_status()\n"
        "2) Predict customer churn with predict_customer_churn_hf()\n"
        "3) Provide executive summary with key insights and recommendations\n"
        "Focus on actionable insights for business decision-making.\n"
        f"Parameters: {parameters}\n"
    )


def execute_rica_analysis_hf(analysis_type: str, parameters: "dict | None" = None):
    """Run one of the predefined RICA analyses and return the agent's answer.

    Args:
        analysis_type: ``"comprehensive"``, ``"churn_focus"`` or
            ``"quick_insights"``; any other value falls back to
            ``"comprehensive"``.
        parameters: Optional dict echoed into the prompt. For
            ``"churn_focus"`` its ``"risk_threshold"`` entry (default 0.6) is
            forwarded to ``predict_customer_churn_hf`` in the prompt.

    Returns:
        Whatever ``agent.run`` returns on success. Failures (missing API key,
        agent creation, agent run) are reported as a human-readable string
        rather than raised, for display in the Spaces UI.
    """
    # Check the key up front so the user gets the sidebar-specific hint.
    if not os.getenv("OPENAI_API_KEY"):
        return "Error: OpenAI API key not configured. Please set your API key in the sidebar."
    try:
        goal = _build_hf_goal(analysis_type, parameters)
        agent = create_rica_agent_hf()
        return agent.run(goal)
    except Exception as e:
        return f"Analysis failed: {e}. Please check your API key and model status."