# Spaces: Sleeping
# File size: 9,380 Bytes
# d0c79e9
import logging
from abc import ABC
from typing import Dict, Any, Union, List, Literal, Optional
from datetime import datetime
import uuid
from aworld.models.model_response import ToolCall
from examples.debate.agent.base import DebateSpeech
from examples.debate.agent.prompts import user_assignment_prompt, user_assignment_system_prompt, affirmative_few_shots, \
negative_few_shots, \
user_debate_prompt
from examples.debate.agent.search.search_engine import SearchEngine
from examples.debate.agent.search.tavily_search_engine import TavilySearchEngine
from examples.debate.agent.stream_output_agent import StreamOutputAgent
from aworld.config import AgentConfig
from aworld.core.common import Observation, ActionModel
from aworld.output import SearchOutput, SearchItem, MessageOutput
from aworld.output.artifact import ArtifactType
def truncate_content(raw_content, char_limit, suffix="... [truncated]"):
    """Shorten *raw_content* to *char_limit* characters, marking any cut.

    Args:
        raw_content: Text to shorten. ``None`` is treated as an empty string.
        char_limit: Number of leading characters kept when the text is too long.
        suffix: Marker appended only when the text was actually cut. Defaults
            to the marker this function has always used.

    Returns:
        The text unchanged when its length does not exceed ``char_limit``;
        otherwise its first ``char_limit`` characters followed by ``suffix``.
    """
    text = '' if raw_content is None else raw_content
    if len(text) > char_limit:
        return text[:char_limit] + suffix
    return text
class DebateAgent(StreamOutputAgent, ABC):
    """Debate participant that researches the topic before each speech.

    One turn (see ``async_policy``) runs four stages: ask the LLM for search
    keywords, search the web for them, store every result set as a
    ``WEB_PAGES`` artifact in the workspace, and ask the LLM for the statement
    that becomes the returned ``DebateSpeech``.

    ``async_policy`` writes artifacts through ``self.workspace``; this class
    only assigns that attribute in ``set_workspace``, so call it before the
    first turn (unless a base class already provides a workspace).
    """

    # Side of the debate this agent argues for.
    stance: Literal["affirmative", "negative"]

    def __init__(self, name: str, stance: Literal["affirmative", "negative"], conf: AgentConfig,
                 search_engine: Optional[SearchEngine] = None):
        """Create a debater.

        Args:
            name: Agent name; written into ``conf.name`` before the base
                class is initialised.
            stance: Side of the debate this agent argues for.
            conf: Agent configuration forwarded to the base class.
            search_engine: Engine used by ``search_webpages``. When omitted
                (or ``None``) a fresh ``TavilySearchEngine`` is created.
        """
        conf.name = name
        super().__init__(conf)
        self.steps = 0
        self.stance = stance
        # A default written as `= TavilySearchEngine()` in the signature is
        # evaluated once, at class-definition time: the engine would be built
        # on import and shared by every agent. Build it per instance instead.
        self.search_engine = TavilySearchEngine() if search_engine is None else search_engine

    async def speech(self, topic: str, opinion: str, oppose_opinion: str, round: int,
                     speech_history: list[DebateSpeech]) -> DebateSpeech:
        """Produce this agent's speech for one debate round.

        Args:
            topic: The debate topic.
            opinion: The position this agent defends.
            oppose_opinion: The position the other side defends.
            round: Round number recorded on the resulting speech.
            speech_history: Speeches so far; the content of the last one (if
                any) is handed to the policy as the observation.

        Returns:
            The ``DebateSpeech`` carried by the first action of ``async_policy``.
        """
        latest_speech = self.get_latest_speech(speech_history)
        observation = Observation(content=latest_speech.content if latest_speech else "")
        info = {
            "topic": topic,
            "round": round,
            "opinion": opinion,
            "oppose_opinion": oppose_opinion,
            "history": speech_history,
        }
        actions = await self.async_policy(observation, info)
        return actions[0].policy_info

    async def async_policy(self, observation: Observation, info: Optional[Dict[str, Any]] = None, **kwargs) -> Union[
            List[ActionModel], None]:
        """Run one full turn: keywords -> web search -> artifacts -> statement.

        Args:
            observation: Its ``content`` is used as the opponent's latest claim.
            info: Must contain the keys ``"round"``, ``"opinion"``,
                ``"oppose_opinion"``, ``"topic"`` and ``"history"``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A single-element list whose ``ActionModel.policy_info`` is the
            generated ``DebateSpeech``.

        Raises:
            KeyError: If a required key is missing from ``info``.
        """
        info = {} if info is None else info

        ## step1: params
        opponent_claim = observation.content
        debate_round = info["round"]
        opinion = info["opinion"]
        oppose_opinion = info["oppose_opinion"]
        topic = info["topic"]
        history: list[DebateSpeech] = info["history"]

        ## step2: gen keywords
        keywords = await self.gen_keywords(topic, opinion, oppose_opinion, opponent_claim, history)
        logging.info("gen keywords = %s", keywords)

        ## step3: search webpages and store each result set as an artifact
        search_results = await self.search_webpages(keywords, max_results=5)
        for search_result in search_results:
            logging.info(
                "keyword#%s-> result size is %s", search_result['query'], len(search_result['results'])
            )
            search_item = {
                "query": search_result.get("query", ""),
                "results": [
                    SearchItem(title=result["title"], url=result["url"], content=result['content'],
                               raw_content=result['raw_content'], metadata={})
                    for result in search_result["results"]
                ],
                # NOTE(review): every artifact records the same synthetic call id
                # and the complete keyword list, not just its own query — confirm
                # that is what consumers of `origin_tool_call` expect.
                "origin_tool_call": ToolCall.from_dict({
                    "id": "call_search",
                    "type": "function",
                    "function": {
                        "name": "search",
                        "arguments": keywords
                    }
                })
            }
            search_output = SearchOutput.from_dict(search_item)
            await self.workspace.create_artifact(
                artifact_type=ArtifactType.WEB_PAGES,
                artifact_id=str(uuid.uuid4()),
                content=search_output,
                metadata={
                    "query": search_output.query,
                    "user": self.name(),
                    "round": debate_round,
                    "opinion": opinion,
                    "oppose_opinion": oppose_opinion,
                    "topic": topic,
                    "tags": [f"user#{self.name()}", f"Rounds#{debate_round}"]
                }
            )

        ## step4: gen result
        user_response = await self.gen_statement(topic, opinion, oppose_opinion, opponent_claim, history,
                                                 search_results)
        logging.info("user_response is %s", user_response)

        ## step5: gen speech
        speech = DebateSpeech.from_dict({
            "round": debate_round,
            "type": "speech",
            "stance": self.stance,
            "name": self.name(),
        })

        async def after_speech_call(message_output_response):
            # Callback handed to convert_to_parts: stores the final text on the
            # speech and marks it finished.
            logging.info("%s#%s: after_speech_call", self.stance, self.name())
            speech.metadata = {}
            speech.content = message_output_response
            speech.finished = True

        await speech.convert_to_parts(user_response, after_speech_call)

        return [ActionModel(policy_info=speech)]

    async def gen_keywords(self, topic, opinion, oppose_opinion, last_oppose_speech_content, history):
        """Ask the LLM for search keywords and return them as a list of strings.

        The finished LLM response is split on commas. Each piece is stripped
        and empty pieces are dropped, so the search step never receives blank
        or whitespace-padded queries. ``history`` is accepted but not used.
        """
        current_time = datetime.now().strftime("%Y-%m-%d-%H")
        human_prompt = user_assignment_prompt.format(
            topic=topic,
            opinion=opinion,
            oppose_opinion=oppose_opinion,
            last_oppose_speech_content=last_oppose_speech_content,
            current_time=current_time,
            limit=2,
        )
        messages = [{'role': 'system', 'content': user_assignment_system_prompt},
                    {'role': 'user', 'content': human_prompt}]

        output = await self.async_call_llm(messages)
        response = await output.get_finished_response()

        stripped = (piece.strip() for piece in response.split(","))
        return [keyword for keyword in stripped if keyword]

    async def search_webpages(self, keywords, max_results):
        """Run one batch search for ``keywords`` on the configured search engine."""
        return await self.search_engine.async_batch_search(queries=keywords, max_results=max_results)

    async def gen_statement(self, topic, opinion, oppose_opinion, opponent_claim, history,
                            search_results) -> MessageOutput:
        """Ask the LLM for this agent's statement and return the LLM call's output.

        The prompt is built from the debate parameters, the search results
        (each snippet truncated to 1000 characters), the chat history seen
        from this agent's side, and the few-shot examples for its stance.
        """
        # One section per query: a header line, then its snippets. The header
        # must not be glued to the first snippet or to the previous section.
        sections = []
        for search_result in search_results:
            snippets = "\n\n".join(
                truncate_content(item['content'], 1000) for item in search_result['results']
            )
            sections.append(f"SearchQuery: {search_result['query']}\n{snippets}")
        search_results_content = "\n\n".join(sections)

        few_shots = ""
        chat_history = ""
        known_stances = ("affirmative", "negative")
        if self.stance in known_stances:
            few_shots = affirmative_few_shots if self.stance == "affirmative" else negative_few_shots
            # The final history entry is left out: its content is already passed
            # to the prompt separately as last_oppose_speech_content.
            history_lines = []
            for past_speech in history[:-1]:
                if past_speech.stance == self.stance:
                    history_lines.append(f"You: {past_speech.content}\n")
                elif past_speech.stance in known_stances:
                    history_lines.append(f"Your Opponent: {past_speech.content}\n")
            chat_history = "".join(history_lines)

        human_prompt = user_debate_prompt.format(
            topic=topic,
            opinion=opinion,
            oppose_opinion=oppose_opinion,
            last_oppose_speech_content=opponent_claim,
            search_results_content=search_results_content,
            chat_history=chat_history,
            few_shots=few_shots,
        )
        # NOTE(review): this reuses the system prompt of the keyword-assignment
        # step — confirm that a debate-specific system prompt is not intended.
        messages = [{'role': 'system', 'content': user_assignment_system_prompt},
                    {'role': 'user', 'content': human_prompt}]

        return await self.async_call_llm(messages)

    def get_latest_speech(self, history: list[DebateSpeech]) -> Optional[DebateSpeech]:
        """Return the last speech in ``history``, or ``None`` when it is empty."""
        return history[-1] if history else None

    def set_workspace(self, workspace):
        """Attach the workspace that ``async_policy`` writes search artifacts to."""
        self.workspace = workspace
|