naman1102 committed on
Commit 5854ce9 · 1 Parent(s): fc853f2
Files changed (3)
  1. app.py +71 -1
  2. code_interpreter.py +135 -0
  3. decision_maker.py +162 -0
app.py CHANGED
@@ -16,6 +16,8 @@ from huggingface_hub import InferenceClient
 import io
 import mimetypes
 import base64
+from code_interpreter import CodeInterpreter
+from decision_maker import DecisionMaker, ToolType

 # -------------------------
 # Environment & constants
@@ -94,6 +96,7 @@ class AgentState(TypedDict):
     search_query: Annotated[str, override]
     task_id: Annotated[str, override]
     logs: Annotated[Dict[str, Any], merge_dicts]
+    code_blocks: Annotated[List[Dict[str, str]], list.__add__]

 # -------------------------
 # BasicAgent implementation
@@ -104,6 +107,8 @@ class BasicAgent:
         if not OPENAI_API_KEY:
             raise EnvironmentError("OPENAI_API_KEY not set")
         self.llm = OpenAI(api_key=OPENAI_API_KEY)
+        self.code_interpreter = CodeInterpreter()
+        self.decision_maker = DecisionMaker()
         self.workflow = self._build_workflow()

     # ---- Low‑level LLM call
@@ -121,6 +126,19 @@ class BasicAgent:

     # ---- Workflow nodes
     def _analyze_question(self, state: AgentState) -> AgentState:
+        # First, analyze the request using the decision maker
+        request_analysis = self.decision_maker.analyze_request(state["question"])
+        state["logs"]["request_analysis"] = request_analysis
+
+        # Check for code-related content
+        if "code" in request_analysis["intent"]:
+            # Extract code blocks from the question
+            code_blocks = self._extract_code_blocks(state["question"])
+            if code_blocks:
+                state["current_step"] = "code_analysis"
+                state["code_blocks"] = code_blocks
+                return state
+
         # Check for multimodal content
         q = state["question"].lower()
         if "video" in q or q.endswith(".mp4"):
@@ -153,6 +171,53 @@ class BasicAgent:
         state["history"].append({"step": "analyze", "output": decision})
         return state

+    def _extract_code_blocks(self, text: str) -> List[Dict[str, str]]:
+        """Extract code blocks from text using markdown-style code blocks."""
+        code_blocks = []
+        pattern = r"```(\w+)?\n(.*?)```"
+        matches = re.finditer(pattern, text, re.DOTALL)
+
+        for match in matches:
+            language = match.group(1) or "python"
+            code = match.group(2).strip()
+            code_blocks.append({
+                "language": language,
+                "code": code
+            })
+
+        return code_blocks
+
+    def _code_analysis_node(self, state: AgentState) -> AgentState:
+        """Handle code analysis requests."""
+        try:
+            results = []
+            for block in state["code_blocks"]:
+                # Analyze code using the code interpreter
+                analysis = self.code_interpreter.analyze_code(
+                    block["code"],
+                    language=block["language"]
+                )
+
+                # Get improvement suggestions
+                suggestions = self.code_interpreter.suggest_improvements(analysis)
+
+                # Format the results
+                result = {
+                    "language": block["language"],
+                    "analysis": analysis,
+                    "suggestions": suggestions
+                }
+                results.append(result)
+
+            state["history"].append({"step": "code_analysis", "output": results})
+            state["current_step"] = "answer"
+
+        except Exception as e:
+            state["logs"]["code_analysis_error"] = str(e)
+            state["current_step"] = "answer"
+
+        return state
+
     def _image_node(self, state: AgentState) -> AgentState:
         """Handle image-based questions."""
         try:
@@ -284,6 +349,7 @@ Think step-by-step. Write ANSWER: <answer> on its own line.
         sg.add_node("image", self._image_node)
         sg.add_node("video", self._video_node)
         sg.add_node("sheet", self._sheet_node)
+        sg.add_node("code_analysis", self._code_analysis_node)

         # Add edges
         sg.add_edge("analyze", "search")
@@ -292,6 +358,7 @@ Think step-by-step. Write ANSWER: <answer> on its own line.
         sg.add_edge("image", "answer")
         sg.add_edge("video", "answer")
         sg.add_edge("sheet", "answer")
+        sg.add_edge("code_analysis", "answer")

         def router(state: AgentState):
             return state["current_step"]
@@ -301,7 +368,8 @@ Think step-by-step. Write ANSWER: <answer> on its own line.
             "answer": "answer",
             "image": "image",
             "video": "video",
-            "sheet": "sheet"
+            "sheet": "sheet",
+            "code_analysis": "code_analysis"
         })
         sg.add_conditional_edges("recheck", router, {
             "search": "search",
@@ -323,6 +391,7 @@ Think step-by-step. Write ANSWER: <answer> on its own line.
             "search_query": "",
             "task_id": task_id,
             "logs": {},
+            "code_blocks": [],
         }
         final_state = self.workflow.invoke(state)
         return final_state["final_answer"]
@@ -409,6 +478,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
             "search_query": "",
             "task_id": task_id,
             "logs": {},
+            "code_blocks": [],
         }

         # Run the workflow
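For reference, here is a minimal standalone sketch of the fenced-code-block extraction that the new _extract_code_blocks helper performs; the sample question text is made up for illustration, and the pattern is the one added in this commit.

import re

# Same pattern as _extract_code_blocks in app.py: optional language tag, then the block body.
PATTERN = r"```(\w+)?\n(.*?)```"

sample_question = (
    "Can you review this snippet?\n"
    "```python\n"
    "def add(a, b):\n"
    "    return a + b\n"
    "```\n"
)

blocks = [
    {"language": m.group(1) or "python", "code": m.group(2).strip()}
    for m in re.finditer(PATTERN, sample_question, re.DOTALL)
]
print(blocks)  # [{'language': 'python', 'code': 'def add(a, b):\n    return a + b'}]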
code_interpreter.py ADDED
@@ -0,0 +1,135 @@
+import ast
+from typing import Dict, List, Any, Optional
+import re
+
+class CodeInterpreter:
+    def __init__(self):
+        self.language = 'python'  # Only support Python
+
+    def analyze_code(self, code: str, language: str = "python") -> Dict[str, Any]:  # optional language accepted to match the call in app.py; only Python is analyzed
+        """
+        Analyze Python code and extract key information about its structure and functionality.
+        """
+        try:
+            return self._analyze_python_code(code)
+        except Exception as e:
+            return {"error": f"Python code analysis failed: {str(e)}"}
+
+    def _analyze_python_code(self, code: str) -> Dict[str, Any]:
+        """
+        Analyze Python code using AST.
+        """
+        try:
+            tree = ast.parse(code)
+            analysis = {
+                "imports": [],
+                "functions": [],
+                "classes": [],
+                "variables": [],
+                "complexity": 0,
+                "docstrings": [],
+                "decorators": []
+            }
+
+            for node in ast.walk(tree):
+                if isinstance(node, ast.Import):
+                    for name in node.names:
+                        analysis["imports"].append(name.name)
+                elif isinstance(node, ast.ImportFrom):
+                    analysis["imports"].append(f"{node.module}.{node.names[0].name}")
+                elif isinstance(node, ast.FunctionDef):
+                    func_info = {
+                        "name": node.name,
+                        "args": [arg.arg for arg in node.args.args],
+                        "returns": self._get_return_type(node),
+                        "complexity": self._calculate_complexity(node),
+                        "docstring": ast.get_docstring(node),
+                        "decorators": [d.id for d in node.decorator_list if isinstance(d, ast.Name)]
+                    }
+                    analysis["functions"].append(func_info)
+                elif isinstance(node, ast.ClassDef):
+                    class_info = {
+                        "name": node.name,
+                        "methods": [],
+                        "bases": [base.id for base in node.bases if isinstance(base, ast.Name)],
+                        "docstring": ast.get_docstring(node),
+                        "decorators": [d.id for d in node.decorator_list if isinstance(d, ast.Name)]
+                    }
+                    for item in node.body:
+                        if isinstance(item, ast.FunctionDef):
+                            class_info["methods"].append(item.name)
+                    analysis["classes"].append(class_info)
+                elif isinstance(node, ast.Assign):
+                    for target in node.targets:
+                        if isinstance(target, ast.Name):
+                            analysis["variables"].append(target.id)
+                elif isinstance(node, ast.Expr) and isinstance(node.value, ast.Str):
+                    analysis["docstrings"].append(node.value.s)
+
+            analysis["complexity"] = sum(func["complexity"] for func in analysis["functions"])
+            return analysis
+
+        except Exception as e:
+            return {"error": f"Python code analysis failed: {str(e)}"}
+
+    def _get_return_type(self, node: ast.FunctionDef) -> Optional[str]:
+        """Extract return type annotation if present."""
+        if node.returns:
+            if isinstance(node.returns, ast.Name):
+                return node.returns.id
+            elif isinstance(node.returns, ast.Subscript):
+                return f"{node.returns.value.id}[{node.returns.slice.value.id}]"
+        return None
+
+    def _calculate_complexity(self, node: ast.AST) -> int:
+        """Calculate cyclomatic complexity of a function."""
+        complexity = 1
+        for child in ast.walk(node):
+            if isinstance(child, (ast.If, ast.While, ast.For, ast.Try, ast.ExceptHandler)):
+                complexity += 1
+        return complexity
+
+    def suggest_improvements(self, analysis: Dict[str, Any]) -> List[str]:
+        """
+        Suggest code improvements based on analysis.
+        """
+        suggestions = []
+
+        # Check function complexity
+        for func in analysis.get("functions", []):
+            if func["complexity"] > 10:
+                suggestions.append(f"Function '{func['name']}' is too complex (complexity: {func['complexity']}). Consider breaking it down into smaller functions.")
+
+        # Check for missing type hints
+        for func in analysis.get("functions", []):
+            if not func["returns"]:
+                suggestions.append(f"Function '{func['name']}' is missing a return type annotation.")
+
+        # Check for missing docstrings
+        for func in analysis.get("functions", []):
+            if not func["docstring"]:
+                suggestions.append(f"Function '{func['name']}' is missing a docstring.")
+
+        # Check for unused imports
+        if len(analysis.get("imports", [])) > 10:
+            suggestions.append("Consider removing unused imports to improve code clarity.")
+
+        # Check for long functions
+        for func in analysis.get("functions", []):
+            if len(func["args"]) > 5:
+                suggestions.append(f"Function '{func['name']}' has too many parameters ({len(func['args'])}). Consider using a data class or dictionary.")
+
+        return suggestions
+
+    def extract_code_context(self, code: str, line_number: int) -> Dict[str, Any]:
+        """
+        Extract context around a specific line of code.
+        """
+        lines = code.split('\n')
+        context = {
+            "line": lines[line_number - 1] if 0 <= line_number - 1 < len(lines) else "",
+            "before": lines[max(0, line_number - 3):line_number - 1],
+            "after": lines[line_number:min(len(lines), line_number + 3)],
+            "indentation": len(re.match(r'^\s*', lines[line_number - 1]).group()) if 0 <= line_number - 1 < len(lines) else 0
+        }
+        return context
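Below is a short, hedged usage sketch of the new CodeInterpreter, assuming it is importable as code_interpreter (which is how app.py now imports it); the snippet being analyzed is a made-up example.

from code_interpreter import CodeInterpreter

interpreter = CodeInterpreter()
snippet = "def greet(name):\n    return 'hello ' + name\n"

analysis = interpreter.analyze_code(snippet)         # AST-based structure report
print(analysis["functions"][0]["name"])              # greet
print(interpreter.suggest_improvements(analysis))    # return-type and docstring suggestions for greet
print(interpreter.extract_code_context(snippet, 2))  # surrounding lines and indentation for line 2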
decision_maker.py ADDED
@@ -0,0 +1,162 @@
+from typing import Dict, List, Any, Optional
+from dataclasses import dataclass
+from enum import Enum
+import json
+import time
+
+class ToolType(Enum):
+    SEARCH = "search"
+    CODE_ANALYSIS = "code_analysis"
+    FILE_OPERATION = "file_operation"
+    UNKNOWN = "unknown"
+
+@dataclass
+class Tool:
+    name: str
+    type: ToolType
+    description: str
+    required_params: List[str]
+    optional_params: List[str]
+    confidence_threshold: float = 0.7
+
+class DecisionMaker:
+    def __init__(self):
+        self.tools = self._initialize_tools()
+        self.decision_history = []
+
+    def _initialize_tools(self) -> Dict[str, Tool]:
+        """Initialize available tools with their metadata."""
+        return {
+            "simple_search": Tool(
+                name="simple_search",
+                type=ToolType.SEARCH,
+                description="Perform web search using DuckDuckGo",
+                required_params=["query"],
+                optional_params=["max_results"],
+                confidence_threshold=0.6
+            ),
+            "code_analysis": Tool(
+                name="code_analysis",
+                type=ToolType.CODE_ANALYSIS,
+                description="Analyze Python code structure and provide insights",
+                required_params=["code"],
+                optional_params=[],
+                confidence_threshold=0.8
+            ),
+            "file_operation": Tool(
+                name="file_operation",
+                type=ToolType.FILE_OPERATION,
+                description="Perform file operations like read/write",
+                required_params=["path"],
+                optional_params=["mode"],
+                confidence_threshold=0.9
+            )
+        }
+
+    def analyze_request(self, request: str) -> Dict[str, Any]:
+        """
+        Analyze the user request to determine the best course of action.
+        """
+        analysis = {
+            "intent": self._detect_intent(request),
+            "required_tools": [],
+            "confidence": 0.0,
+            "suggested_actions": []
+        }
+
+        # Determine required tools based on intent
+        if "search" in analysis["intent"]:
+            analysis["required_tools"].append(self.tools["simple_search"])
+        if "code" in analysis["intent"]:
+            analysis["required_tools"].append(self.tools["code_analysis"])
+        if "file" in analysis["intent"]:
+            analysis["required_tools"].append(self.tools["file_operation"])
+
+        # Calculate confidence based on tool requirements
+        if analysis["required_tools"]:
+            analysis["confidence"] = min(tool.confidence_threshold for tool in analysis["required_tools"])
+
+        # Generate suggested actions
+        analysis["suggested_actions"] = self._generate_actions(analysis)
+
+        return analysis
+
+    def _detect_intent(self, request: str) -> List[str]:
+        """Detect the intent(s) from the user request."""
+        intents = []
+
+        # Python-specific keyword-based intent detection
+        keywords = {
+            "search": ["search", "find", "look up", "query"],
+            "code": ["python", "code", "function", "class", "analyze", "def", "import", "from"],
+            "file": ["file", "read", "write", "save", "load", ".py"]
+        }
+
+        request_lower = request.lower()
+        for intent, words in keywords.items():
+            if any(word in request_lower for word in words):
+                intents.append(intent)
+
+        return intents if intents else ["unknown"]
+
+    def _generate_actions(self, analysis: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Generate suggested actions based on the analysis."""
+        actions = []
+
+        for tool in analysis["required_tools"]:
+            action = {
+                "tool": tool.name,
+                "type": tool.type.value,
+                "confidence": tool.confidence_threshold,
+                "required_params": tool.required_params,
+                "optional_params": tool.optional_params
+            }
+            actions.append(action)
+
+        return actions
+
+    def validate_tool_usage(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Validate if a tool can be used with the given parameters.
+        """
+        if tool_name not in self.tools:
+            return {
+                "valid": False,
+                "error": f"Unknown tool: {tool_name}"
+            }
+
+        tool = self.tools[tool_name]
+        validation = {
+            "valid": True,
+            "missing_params": [],
+            "extra_params": []
+        }
+
+        # Check required parameters
+        for param in tool.required_params:
+            if param not in params:
+                validation["valid"] = False
+                validation["missing_params"].append(param)
+
+        # Check for extra parameters
+        for param in params:
+            if param not in tool.required_params and param not in tool.optional_params:
+                validation["extra_params"].append(param)
+
+        return validation
+
+    def log_decision(self, request: str, analysis: Dict[str, Any], outcome: Dict[str, Any]):
+        """
+        Log a decision made by the system for future reference.
+        """
+        decision = {
+            "timestamp": time.time(),
+            "request": request,
+            "analysis": analysis,
+            "outcome": outcome
+        }
+        self.decision_history.append(decision)
+
+    def get_decision_history(self) -> List[Dict[str, Any]]:
+        """Get the history of decisions made."""
+        return self.decision_history
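Finally, a short usage sketch of DecisionMaker as committed; the request string and outcome dict are made-up examples.

from decision_maker import DecisionMaker

dm = DecisionMaker()
request = "Please analyze this Python function and search for best practices."

analysis = dm.analyze_request(request)
print(analysis["intent"])      # ['search', 'code'] (keyword-based detection)
print(analysis["confidence"])  # 0.6, the minimum threshold among the matched tools

print(dm.validate_tool_usage("code_analysis", {"code": "print('hi')"}))
# {'valid': True, 'missing_params': [], 'extra_params': []}

dm.log_decision(request, analysis, outcome={"status": "ok"})
print(len(dm.get_decision_history()))  # 1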