# NOTE(review): removed non-Python paste artifacts that preceded the imports
# ("Spaces:" and two "Runtime error" lines) — they were extraction residue,
# not part of the program.
# Standard library
import copy
import json
import os
import random

# Third-party
import torch
from simpletransformers.conv_ai import ConvAIModel
from simpletransformers.conv_ai.conv_ai_utils import get_dataset
class ConvAIModelExtended(ConvAIModel):
    """ConvAIModel extended with multi-dialog, multi-persona chat state.

    Dialog state lives in the class-level ``dialogs`` registry keyed by an
    auto-incrementing integer id, so all instances share the same registry.
    Each dialog holds its own deep copy of the persona list and of ``args``,
    keeping per-dialog history and generation settings independent.
    """

    # Source URL of the Indonesian PersonaChat dataset (kept for reference;
    # nothing in this class downloads it — the local file below is used).
    PERSONACHAT_URL = "https://cloud.uncool.ai/index.php/s/Pazx3rifmFpwNNm/download/id_personachat.json"
    dataset_path = "data/id_personachat.json"
    persona_list_path = "data/persona_list.json"
    # Shared registry: dialog_id -> {"persona_list", "persona_ids", "args"}.
    dialogs = {}
    dialogs_counter = 0

    def __init__(self, model_type, model_name, args=None, use_cuda=True, **kwargs):
        """Initialize the base ConvAI model, then load the cached dataset
        and the persona list from disk.

        Args:
            model_type: Base transformer architecture name.
            model_name: Pretrained model name or path.
            args: Optional model args (dict or args object).
            use_cuda: Whether to run the model on GPU.
            **kwargs: Forwarded to ``ConvAIModel.__init__``.
        """
        super().__init__(model_type, model_name, args, use_cuda, **kwargs)
        os.makedirs(self.args.cache_dir, exist_ok=True)
        self.dataset = get_dataset(
            self.tokenizer,
            dataset_path=ConvAIModelExtended.dataset_path,
            dataset_cache=self.args.cache_dir,
            process_count=self.args.process_count,
            proxies=self.__dict__.get("proxies", None),
            interact=False,
            args=self.args,
        )
        # Flatten every dialog's personality across all dataset splits.
        self.personalities = [
            dialog["personality"]
            for dataset in self.dataset.values()
            for dialog in dataset
        ]
        with open(ConvAIModelExtended.persona_list_path, "r") as f:
            self.persona_list = json.load(f)

    def new_dialog(self):
        """Create a new dialog and return its integer id.

        Every persona gets an empty history and its personality sentences
        tokenized (lowercased) once up front; ``args`` is deep-copied so
        per-dialog generation settings don't leak between dialogs.
        """
        tokenizer = self.tokenizer
        ConvAIModelExtended.dialogs_counter += 1
        dialog_id = ConvAIModelExtended.dialogs_counter
        persona_list = copy.deepcopy(self.persona_list)
        for persona in persona_list:
            persona["history"] = []
            persona["personality"] = [
                tokenizer.encode(s.lower()) for s in persona["personality"]
            ]
        # Index personas by their id for O(1) lookup; entries are shared
        # (not copied) with persona_list, so mutations are visible in both.
        persona_ids = {persona["id"]: persona for persona in persona_list}
        ConvAIModelExtended.dialogs[dialog_id] = {
            "persona_list": persona_list,
            "persona_ids": persona_ids,
            # Each dialog has its own independent copy of args.
            "args": copy.deepcopy(self.args),
        }
        return dialog_id

    @staticmethod
    def delete_dialog(dialog_id):
        """Remove a dialog from the shared registry.

        Raises:
            KeyError: If ``dialog_id`` is unknown.

        NOTE(review): the original definition lacked ``self``; declared as a
        staticmethod so both ``ConvAIModelExtended.delete_dialog(i)`` and
        ``instance.delete_dialog(i)`` now work.
        """
        del ConvAIModelExtended.dialogs[dialog_id]

    def get_persona_list(self, dialog_id: int):
        """Return a deep copy of a dialog's personas with their personality
        token ids decoded back to plain strings (histories left as token ids).
        """
        tokenizer = self.tokenizer
        persona_list = copy.deepcopy(
            ConvAIModelExtended.dialogs[dialog_id]["persona_list"]
        )
        for persona in persona_list:
            persona["personality"] = [
                tokenizer.decode(tokens) for tokens in persona["personality"]
            ]
        return persona_list

    def set_personality(self, dialog_id: int, persona_id: str, personality: list):
        """Overwrite a persona's personality sentences from index 3 onward.

        The first three stored personality entries are preserved —
        presumably fixed identity sentences; TODO confirm against the
        persona_list.json schema.

        NOTE(review): if ``personality`` has fewer entries than the stored
        list minus 3, this raises IndexError (original behavior, preserved).
        """
        tokenizer = self.tokenizer
        personality = [tokenizer.encode(s.lower()) for s in personality]
        stored = ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]["personality"]
        for i in range(3, len(stored)):
            stored[i] = personality[i - 3]

    @staticmethod
    def get_persona_name(dialog_id: int, persona_id: int):
        """Return the display name of a persona in a dialog.

        NOTE(review): the original definition lacked ``self``; declared as a
        staticmethod for backward compatibility with class-level calls.
        ``persona_id`` is annotated ``int`` here but ``str`` in
        ``set_personality`` — the actual key type comes from
        persona_list.json; verify and unify.
        """
        return ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]["name"]

    def talk(self, dialog_id: int, persona_id: int, utterance: str,
             do_sample: bool = True, min_length: int = 1, max_length: int = 20,
             temperature: float = 0.7, top_k: int = 0, top_p: float = 0.9):
        """Generate one reply from a persona and update its history.

        Args:
            dialog_id: Id returned by ``new_dialog``.
            persona_id: Persona key within the dialog.
            utterance: User message to respond to.
            do_sample / min_length / max_length / temperature / top_k / top_p:
                Generation settings written into the dialog's own args copy.

        Returns:
            The decoded reply string, or a fixed Indonesian fallback message
            when generation produces no tokens.
        """
        model = self.model
        args = ConvAIModelExtended.dialogs[dialog_id]["args"]
        args.do_sample = do_sample
        args.min_length = min_length
        args.max_length = max_length
        args.temperature = temperature
        args.top_k = top_k
        args.top_p = top_p
        tokenizer = self.tokenizer
        persona = ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]
        persona["history"].append(tokenizer.encode(utterance))
        with torch.no_grad():
            out_ids = self.sample_sequence(
                persona["personality"],
                persona["history"],
                tokenizer, model, args,
            )
        if len(out_ids) == 0:
            # Generation produced nothing; the user's utterance stays in
            # history (original behavior, preserved).
            return "Ma'af, saya tidak mengerti. Coba tanya yang lain"
        persona["history"].append(out_ids)
        # Keep only the last max_history exchanges (2 turns each) plus the
        # newest reply.
        persona["history"] = persona["history"][-(2 * args.max_history + 1):]
        out_text = tokenizer.decode(
            out_ids, skip_special_tokens=args.skip_special_tokens
        )
        return out_text