Spaces:
Running
Running
''' | |
Chat State and Logging | |
''' | |
import json | |
import os | |
from typing import Any, Literal, Optional | |
from conversation import Conversation | |
import datetime | |
import uuid | |
LOG_DIR = os.environ.get("LOGDIR", "./logs")
'''
Default output directory for log files.
Read from the LOGDIR environment variable, falling back to ./logs.
'''
class ModelChatState: | |
''' | |
The state of a chat with a model. | |
''' | |
is_vision: bool | |
''' | |
Whether the model is vision based. | |
''' | |
conv: Conversation | |
''' | |
The conversation | |
''' | |
conv_id: str | |
''' | |
Unique identifier for the model conversation. | |
Unique per chat per model. | |
''' | |
chat_session_id: str | |
''' | |
Unique identifier for the chat session. | |
Unique per chat. The two battle models share the same chat session id. | |
''' | |
skip_next: bool | |
''' | |
Flag to indicate skipping the next operation. | |
''' | |
model_name: str | |
''' | |
Name of the model being used. | |
''' | |
oai_thread_id: Optional[str] | |
''' | |
Identifier for the OpenAI thread. | |
''' | |
has_csam_image: bool | |
''' | |
Indicates if a CSAM image has been uploaded. | |
''' | |
regen_support: bool | |
''' | |
Indicates if regeneration is supported for the model. | |
''' | |
chat_start_time: datetime.datetime | |
''' | |
Chat start time. | |
''' | |
chat_mode: Literal['battle_anony', 'battle_named', 'direct'] | |
''' | |
Chat mode. | |
''' | |
curr_response_type: Literal['chat_multi', 'chat_single', 'regenerate_multi', 'regenerate_single'] | None | |
''' | |
Current response type. Used for logging. | |
''' | |
def create_chat_session_id() -> str: | |
''' | |
Create a new chat session id. | |
''' | |
return uuid.uuid4().hex | |
def create_battle_chat_states( | |
model_name_1: str, model_name_2: str, | |
chat_mode: Literal['battle_anony', 'battle_named'], | |
is_vision: bool, | |
) -> tuple['ModelChatState', 'ModelChatState']: | |
''' | |
Create two chat states for a battle. | |
''' | |
chat_session_id = ModelChatState.create_chat_session_id() | |
return ( | |
ModelChatState(model_name_1, chat_mode, | |
is_vision=is_vision, | |
chat_session_id=chat_session_id), | |
ModelChatState(model_name_2, chat_mode, | |
is_vision=is_vision, | |
chat_session_id=chat_session_id), | |
) | |
def __init__(self, | |
model_name: str, | |
chat_mode: Literal['battle_anony', 'battle_named', 'direct'], | |
is_vision: bool, | |
chat_session_id: str | None = None, | |
): | |
from fastchat.model.model_adapter import get_conversation_template | |
self.conv = get_conversation_template(model_name) | |
self.conv_id = uuid.uuid4().hex | |
# if no chat session id is provided, use the conversation id | |
self.chat_session_id = chat_session_id if chat_session_id else self.conv_id | |
self.chat_start_time = datetime.datetime.now() | |
self.chat_mode = chat_mode | |
self.skip_next = False | |
self.model_name = model_name | |
self.oai_thread_id = None | |
self.is_vision = is_vision | |
# NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes. | |
self.has_csam_image = False | |
self.regen_support = True | |
if "browsing" in model_name: | |
self.regen_support = False | |
self.init_system_prompt(self.conv, is_vision) | |
def init_system_prompt(self, conv, is_vision): | |
system_prompt = conv.get_system_message(is_vision) | |
if len(system_prompt) == 0: | |
return | |
current_date = datetime.datetime.now().strftime("%Y-%m-%d") | |
system_prompt = system_prompt.replace("{{currentDateTime}}", current_date) | |
current_date_v2 = datetime.datetime.now().strftime("%d %b %Y") | |
system_prompt = system_prompt.replace("{{currentDateTimev2}}", current_date_v2) | |
current_date_v3 = datetime.datetime.now().strftime("%B %Y") | |
system_prompt = system_prompt.replace("{{currentDateTimev3}}", current_date_v3) | |
conv.set_system_message(system_prompt) | |
def set_response_type( | |
self, | |
response_type: Literal['chat_multi', 'chat_single', 'regenerate_multi', 'regenerate_single'] | |
): | |
''' | |
Set the response type for the chat state. | |
''' | |
self.curr_response_type = response_type | |
def to_gradio_chatbot(self): | |
''' | |
Convert to a Gradio chatbot. | |
''' | |
return self.conv.to_gradio_chatbot() | |
def get_conv_log_filepath(self, path_prefix: str): | |
''' | |
Get the filepath for the conversation log. | |
Expected directory structure: | |
softwarearenlog/ | |
βββ YEAR_MONTH_DAY/ | |
βββ conv_logs/ | |
βββ sandbox_logs/ | |
''' | |
date_str = self.chat_start_time.strftime('%Y_%m_%d') | |
filepath = os.path.join( | |
path_prefix, | |
date_str, | |
'conv_logs', | |
self.chat_mode, | |
f"conv-log-{self.chat_session_id}.json" | |
) | |
return filepath | |
def to_dict(self): | |
base = self.conv.to_dict() | |
base.update( | |
{ | |
"chat_session_id": self.chat_session_id, | |
"conv_id": self.conv_id, | |
"chat_mode": self.chat_mode, | |
"chat_start_time": self.chat_start_time, | |
"model_name": self.model_name, | |
} | |
) | |
if self.is_vision: | |
base.update({"has_csam_image": self.has_csam_image}) | |
return base | |
def generate_vote_record( | |
self, | |
vote_type: str, | |
ip: str | |
) -> dict[str, Any]: | |
''' | |
Generate a vote record for telemertry. | |
''' | |
data = { | |
"tstamp": round(datetime.datetime.now().timestamp(), 4), | |
"type": vote_type, | |
"model": self.model_name, | |
"state": self.to_dict(), | |
"ip": ip, | |
} | |
return data | |
def generate_response_record( | |
self, | |
gen_params: dict[str, Any], | |
start_ts: float, | |
end_ts: float, | |
ip: str | |
) -> dict[str, Any]: | |
''' | |
Generate a vote record for telemertry. | |
''' | |
data = { | |
"tstamp": round(datetime.datetime.now().timestamp(), 4), | |
"type": self.curr_response_type, | |
"model": self.model_name, | |
"start_ts": round(start_ts, 4), | |
"end_ts": round(end_ts, 4), | |
"gen_params": gen_params, | |
"state": self.to_dict(), | |
"ip": ip, | |
} | |
return data | |
def save_log_to_local(
    log_data: dict[str, Any],
    log_path: str,
    write_mode: Literal['overwrite', 'append'] = 'append'
) -> None:
    '''
    Save the log locally as a single JSON line.

    The record is serialized with json.dumps and written to log_path followed
    by a newline, so repeated appends produce a JSON Lines file. Missing
    parent directories are created first.

    Args:
        log_data: The record to write. Values json cannot encode natively
            (e.g. datetime objects) are converted with str().
        log_path: Destination file. May be a bare filename, in which case it
            is written relative to the current working directory.
        write_mode: 'overwrite' truncates the file before writing; any other
            value appends to it.
    '''
    log_json = json.dumps(log_data, default=str)

    # os.path.dirname() is '' for a bare filename and os.makedirs('') raises,
    # so only create directories when the path actually has one.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    file_mode = "w" if write_mode == 'overwrite' else "a"
    with open(log_path, file_mode, encoding="utf-8") as fout:
        fout.write(log_json + "\n")