from simpletransformers.conv_ai import ConvAIModel
from simpletransformers.conv_ai.conv_ai_utils import get_dataset
import torch
import random
import os
import copy
import json


class ConvAIModelExtended(ConvAIModel):
    """ConvAIModel extended to manage multiple concurrent dialogs.

    Each dialog gets its own deep copy of the persona list and of the model
    args, so per-dialog sampling settings never leak between conversations.
    Dialog state lives in the class-level ``dialogs`` registry, keyed by an
    auto-incrementing integer id — every instance shares the same registry.
    """

    # Remote source of the Indonesian PersonaChat dataset. Defined for
    # reference only; nothing in this class downloads it.
    PERSONACHAT_URL = "https://cloud.uncool.ai/index.php/s/Pazx3rifmFpwNNm/download/id_personachat.json"
    # Local dataset consumed by get_dataset() in __init__.
    dataset_path = "data/id_personachat.json"
    # Persona definitions loaded in __init__ and deep-copied per dialog.
    persona_list_path = "data/persona_list.json"
    # dialog_id -> {"persona_list": [...], "persona_ids": {...}, "args": ...}
    dialogs = {}
    # Monotonically increasing source of dialog ids (never reused).
    dialogs_counter = 0

    def __init__(self, model_type, model_name, args=None, use_cuda=True, **kwargs):
        """Load the base ConvAI model, tokenize the dataset, read personas.

        Parameters mirror ``ConvAIModel.__init__`` (model_type, model_name,
        args, use_cuda) and are forwarded unchanged.

        Raises:
            OSError / json.JSONDecodeError if the persona list file is
            missing or malformed.
        """
        # Python-3 zero-argument super() instead of the legacy two-arg form.
        super().__init__(model_type, model_name, args, use_cuda, **kwargs)
        os.makedirs(self.args.cache_dir, exist_ok=True)
        self.dataset = get_dataset(
            self.tokenizer,
            dataset_path=ConvAIModelExtended.dataset_path,
            dataset_cache=self.args.cache_dir,
            process_count=self.args.process_count,
            proxies=self.__dict__.get("proxies", None),
            interact=False,
            args=self.args,
        )
        # Flatten every dataset split into one list of personalities.
        self.personalities = [
            dialog["personality"]
            for dataset in self.dataset.values()
            for dialog in dataset
        ]
        # BUG FIX: the persona file contains non-ASCII (Indonesian) text, so
        # read it as UTF-8 explicitly instead of relying on the platform
        # default encoding (which breaks on e.g. Windows cp1252).
        with open(ConvAIModelExtended.persona_list_path, "r", encoding="utf-8") as f:
            self.persona_list = json.load(f)

    def new_dialog(self) -> int:
        """Create a fresh dialog and return its id.

        The dialog receives deep copies of the persona list (with each
        personality sentence lower-cased and token-encoded, and an empty
        per-persona chat history) and of the model args.
        """
        tokenizer = self.tokenizer
        ConvAIModelExtended.dialogs_counter += 1
        dialog_id = ConvAIModelExtended.dialogs_counter
        persona_list = copy.deepcopy(self.persona_list)
        for persona in persona_list:
            persona["history"] = []
            # Personalities are stored token-encoded for direct model use.
            persona["personality"] = [
                tokenizer.encode(s.lower()) for s in persona["personality"]
            ]
        persona_ids = {persona["id"]: persona for persona in persona_list}
        ConvAIModelExtended.dialogs[dialog_id] = {
            "persona_list": persona_list,
            "persona_ids": persona_ids,
            # Independent copy so talk() can tweak sampling args per dialog.
            "args": copy.deepcopy(self.args),
        }
        return dialog_id

    @staticmethod
    def delete_dialog(dialog_id):
        """Drop a dialog from the registry. Raises KeyError if unknown."""
        del ConvAIModelExtended.dialogs[dialog_id]

    def get_persona_list(self, dialog_id: int):
        """Return a deep copy of the dialog's personas, personalities decoded
        back to plain text. The stored (encoded) state is left untouched."""
        tokenizer = self.tokenizer
        persona_list = copy.deepcopy(
            ConvAIModelExtended.dialogs[dialog_id]["persona_list"]
        )
        for persona in persona_list:
            persona["personality"] = [
                tokenizer.decode(tokens) for tokens in persona["personality"]
            ]
        return persona_list

    def set_personality(self, dialog_id: int, persona_id: str, personality: list):
        """Overwrite a persona's personality sentences from index 3 onward.

        The first three encoded sentences are deliberately preserved —
        presumably fixed identity sentences; TODO confirm against the
        persona file layout.

        NOTE(review): raises IndexError when ``personality`` has fewer
        entries than the stored personality minus the first three; extra
        trailing entries are silently ignored.
        """
        tokenizer = self.tokenizer
        personality = [tokenizer.encode(s.lower()) for s in personality]
        # Mutate the stored list in place (same object as in the registry).
        stored = ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]["personality"]
        for i in range(3, len(stored)):
            stored[i] = personality[i - 3]

    @staticmethod
    def get_persona_name(dialog_id: int, persona_id: int):
        """Return the display name of a persona within a dialog."""
        return ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]["name"]

    def talk(self, dialog_id: int, persona_id: int, utterance: str,
             do_sample: bool = True, min_length: int = 1, max_length: int = 20,
             temperature: float = 0.7, top_k: int = 0, top_p: float = 0.9):
        """Send an utterance to a persona and return the model's reply text.

        Sampling parameters are written into the dialog's private args copy,
        so they persist for that dialog only. The utterance and the reply
        are appended to the persona's history, which is then trimmed to the
        last ``2 * max_history + 1`` turns (same windowing as the upstream
        simpletransformers interact loop).

        Returns the decoded reply, or a fixed Indonesian apology string when
        the model produces no tokens.
        """
        model = self.model
        args = ConvAIModelExtended.dialogs[dialog_id]["args"]
        args.do_sample = do_sample
        args.min_length = min_length
        args.max_length = max_length
        args.temperature = temperature
        args.top_k = top_k
        args.top_p = top_p
        tokenizer = self.tokenizer
        persona = ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]
        persona["history"].append(tokenizer.encode(utterance))
        with torch.no_grad():
            out_ids = self.sample_sequence(
                persona["personality"],
                persona["history"],
                tokenizer,
                model,
                args,
            )
        if len(out_ids) == 0:
            # Empty generation: apologize without recording a model turn.
            return "Ma'af, saya tidak mengerti. Coba tanya yang lain"
        persona["history"].append(out_ids)
        persona["history"] = persona["history"][-(2 * args.max_history + 1):]
        out_text = tokenizer.decode(
            out_ids, skip_special_tokens=args.skip_special_tokens
        )
        return out_text