# chat-server / app / convaimodel_extended.py
# cahya -- first commit (e3f06ac)
from simpletransformers.conv_ai import ConvAIModel
from simpletransformers.conv_ai.conv_ai_utils import get_dataset
import torch
import random
import os
import copy
import json
class ConvAIModelExtended(ConvAIModel):
    """ConvAIModel subclass that serves multiple concurrent dialogs.

    Each dialog gets its own deep copy of the persona list and of the model
    args, registered in the class-level ``dialogs`` dict under an integer id,
    so several chat sessions can share one loaded model.

    NOTE(review): the class-level counter/registry are not protected by a
    lock, so concurrent ``new_dialog`` calls from multiple threads could
    race -- confirm the server calls this from a single thread.
    """

    # Source URL of the Indonesian PersonaChat dataset (kept for reference).
    PERSONACHAT_URL = "https://cloud.uncool.ai/index.php/s/Pazx3rifmFpwNNm/download/id_personachat.json"
    dataset_path = "data/id_personachat.json"
    persona_list_path = "data/persona_list.json"
    # Registry of active dialogs, keyed by dialog id; shared across instances.
    dialogs = {}
    # Monotonically increasing id source for new dialogs.
    dialogs_counter = 0

    def __init__(self, model_type, model_name, args=None, use_cuda=True, **kwargs):
        """Load the base ConvAI model, the PersonaChat dataset and the persona list.

        Args:
            model_type: model architecture name passed through to ConvAIModel.
            model_name: pretrained model name or path.
            args: optional model args (dict or args object).
            use_cuda: whether to run the model on GPU.
            **kwargs: forwarded to the ConvAIModel constructor.
        """
        super().__init__(model_type, model_name, args, use_cuda, **kwargs)
        os.makedirs(self.args.cache_dir, exist_ok=True)
        self.dataset = get_dataset(
            self.tokenizer,
            dataset_path=ConvAIModelExtended.dataset_path,
            dataset_cache=self.args.cache_dir,
            process_count=self.args.process_count,
            proxies=self.__dict__.get("proxies", None),
            interact=False,
            args=self.args,
        )
        # Flatten every dialog's personality out of the train/valid splits.
        self.personalities = [
            dialog["personality"]
            for dataset in self.dataset.values()
            for dialog in dataset
        ]
        # JSON is UTF-8 by spec; be explicit so the read does not depend on
        # the platform's default encoding.
        with open(ConvAIModelExtended.persona_list_path, "r", encoding="utf-8") as f:
            self.persona_list = json.load(f)

    def new_dialog(self):
        """Create a new dialog session and return its integer id.

        Each session owns an independent deep copy of the persona list (with
        personalities pre-encoded to token ids and empty histories) and of the
        model args, so per-dialog generation settings don't leak across sessions.
        """
        tokenizer = self.tokenizer
        ConvAIModelExtended.dialogs_counter += 1
        dialog_id = ConvAIModelExtended.dialogs_counter
        persona_list = copy.deepcopy(self.persona_list)
        for persona in persona_list:
            persona["history"] = []
            persona["personality"] = [tokenizer.encode(s.lower()) for s in persona["personality"]]
        persona_ids = {persona["id"]: persona for persona in persona_list}
        ConvAIModelExtended.dialogs[dialog_id] = {
            "persona_list": persona_list,
            "persona_ids": persona_ids,
            "args": copy.deepcopy(self.args),  # each dialog has its own independent copy of args
        }
        return dialog_id

    @staticmethod
    def delete_dialog(dialog_id):
        """Remove a dialog from the registry. Raises KeyError if unknown."""
        del ConvAIModelExtended.dialogs[dialog_id]

    def get_persona_list(self, dialog_id: int):
        """Return a copy of the dialog's persona list with personalities decoded back to text."""
        tokenizer = self.tokenizer
        persona_list = copy.deepcopy(ConvAIModelExtended.dialogs[dialog_id]["persona_list"])
        for persona in persona_list:
            persona["personality"] = [tokenizer.decode(tokens) for tokens in persona["personality"]]
        return persona_list

    def set_personality(self, dialog_id: int, persona_id, personality: list):
        """Overwrite the editable part of a persona's personality sentences.

        The first 3 stored entries are preserved (presumably the persona's
        fixed identity lines -- TODO confirm against persona_list.json); the
        remaining slots are replaced, in order, by the encoded sentences from
        ``personality``.

        NOTE(review): persona_id was annotated ``str`` here but ``int`` in the
        original ``talk``/``get_persona_name`` -- its real type is whatever
        persona_list.json uses for "id"; annotation dropped until confirmed.
        """
        tokenizer = self.tokenizer
        encoded = [tokenizer.encode(s.lower()) for s in personality]
        stored = ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]["personality"]
        # zip() stops at the shorter sequence, so a short ``personality`` list
        # no longer raises IndexError (the original indexed personality[i - 3]).
        for i, tokens in zip(range(3, len(stored)), encoded):
            stored[i] = tokens

    @staticmethod
    def get_persona_name(dialog_id: int, persona_id):
        """Return the display name of a persona within a dialog."""
        return ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]["name"]

    def talk(self, dialog_id: int, persona_id, utterance: str,
             do_sample: bool = True, min_length: int = 1, max_length: int = 20,
             temperature: float = 0.7, top_k: int = 0, top_p: float = 0.9):
        """Send a user utterance to a persona and return the generated reply.

        Args:
            dialog_id: id returned by ``new_dialog``.
            persona_id: key of the persona within the dialog.
            utterance: raw user text; it is tokenized and appended to history.
            do_sample / min_length / max_length / temperature / top_k / top_p:
                per-call generation settings, written into the dialog's own args.

        Returns:
            The decoded model reply, or a fixed Indonesian apology string when
            generation produced no tokens.
        """
        model = self.model
        args = ConvAIModelExtended.dialogs[dialog_id]["args"]
        # Generation settings live on the dialog-local args copy, so they do
        # not interfere with other concurrent dialogs.
        args.do_sample = do_sample
        args.min_length = min_length
        args.max_length = max_length
        args.temperature = temperature
        args.top_k = top_k
        args.top_p = top_p
        tokenizer = self.tokenizer
        persona = ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]
        persona["history"].append(tokenizer.encode(utterance))
        with torch.no_grad():
            out_ids = self.sample_sequence(
                persona["personality"],
                persona["history"],
                tokenizer, model, args,
            )
        if not out_ids:
            # NOTE(review): the user utterance stays in history even though no
            # reply is recorded, which breaks the user/bot alternation on the
            # next turn -- confirm whether it should be popped here.
            return "Ma'af, saya tidak mengerti. Coba tanya yang lain"
        persona["history"].append(out_ids)
        # Keep only the most recent max_history exchanges (pairs) plus the
        # latest reply, mirroring the upstream ConvAI interaction loop.
        persona["history"] = persona["history"][-(2 * args.max_history + 1):]
        return tokenizer.decode(out_ids, skip_special_tokens=args.skip_special_tokens)