Spaces:
Runtime error
Runtime error
import asyncio | |
import websockets | |
import logging | |
import json | |
from convaimodel_extended import ConvAIModelExtended | |
logging.basicConfig()

# Shared counter that every connected client sees; changed by "plus"/"minus".
STATE = dict(value=0)
# Websocket connections currently registered with the server.
USERS = set()

# Extra model arguments (none); presumably forwarded to the underlying
# ConvAI model -- TODO confirm against ConvAIModelExtended.
train_args = {}
model = ConvAIModelExtended(
    "gpt2",
    "cahya/gpt2-small-indonesian-personachat",
    args=train_args,
    use_cuda=False,
)
def connection_event():
    """Return the JSON ``connection`` event sent to a newly connected client."""
    payload = dict(type="connection", value=True)
    return json.dumps(payload)
def state_event():
    """Return the JSON ``state`` event: ``"type": "state"`` merged with ``STATE``."""
    payload = {"type": "state"}
    # Entries from STATE are applied last, exactly like ``{..., **STATE}``.
    payload.update(STATE)
    return json.dumps(payload)
def dialog_event(message):
    """Return a JSON ``dialog`` event carrying *message*."""
    payload = dict(type="dialog", message=message)
    return json.dumps(payload)
def personality_event(message):
    """Return a JSON ``personality`` event carrying *message*."""
    event_type = "personality"
    return json.dumps({"type": event_type, "message": message})
def persona_list_event(message):
    """Return a JSON ``persona_list`` event whose ``message`` is *message*."""
    envelope = dict(type="persona_list", message=message)
    serialized = json.dumps(envelope)
    return serialized
def personality_reply_event(message):
    """Return a JSON ``personality_reply`` event carrying *message*."""
    fields = {"type": "personality_reply"}
    fields["message"] = message
    return json.dumps(fields)
def persona_greeting_event(message):
    """Return a JSON ``persona_greeting`` event carrying *message*."""
    greeting = dict(type="persona_greeting", message=message)
    return json.dumps(greeting)
def talk_event(message):
    """Return a JSON ``talk`` event carrying the model reply *message*."""
    body = {"type": "talk"}
    body.update(message=message)
    return json.dumps(body)
def users_event():
    """Return the JSON ``users`` event reporting how many clients are in ``USERS``."""
    connected = len(USERS)
    return json.dumps(dict(type="users", count=connected))
def _sampling_options(data):
    """Return the generation keyword arguments for a ``talk`` request.

    Options missing from *data* fall back to the defaults below; options that
    are present are converted with ``bool``/``int``/``float``.

    Raises:
        TypeError, ValueError: if a supplied option cannot be converted.
    """
    return {
        # NOTE(review): bool() treats any non-empty string (even "false") as
        # True; clients are expected to send a JSON boolean.
        "do_sample": bool(data.get("do_sample", True)),
        "min_length": int(data.get("min_length", 1)),
        "max_length": int(data.get("max_length", 20)),
        "temperature": float(data.get("temperature", 0.7)),
        "top_k": int(data.get("top_k", 0)),
        "top_p": float(data.get("top_p", 0.9)),
    }


async def chatbot(websocket, path=None):
    """Serve one websocket client for the lifetime of its connection.

    The client is added to ``USERS`` (every client is sent the new user
    count), is sent a ``connection`` and a ``state`` event, and may then send
    JSON objects whose ``"action"`` key selects the operation:

    * ``"minus"`` / ``"plus"`` -- change the shared counter and broadcast it.
    * ``"get_users"`` -- reply with the current user count.
    * ``"dialog"`` -- create this connection's dialog (once) and reply with
      its persona list.
    * ``"talk"`` -- reply with the model's answer to ``"utterance"`` for
      ``"persona_id"``; optional sampling options see ``_sampling_options``.
    * ``"personality"`` -- set ``"persona_id"``'s personality to ``"message"``.
    * ``"persona_chosen"`` -- reply with a greeting from ``"persona_id"``.

    ``"talk"``, ``"personality"`` and ``"persona_chosen"`` are ignored until a
    dialog exists. Malformed messages are logged and skipped instead of
    closing the connection. On disconnect the client is removed from
    ``USERS`` and its dialog, if one was created, is deleted.

    Args:
        websocket: The client connection.
        path: Request path passed by the legacy ``websockets`` handler API.
            Unused; optional so the handler also works with the newer API,
            which passes only the connection.
    """
    dialog_id = 0  # 0 means "no dialog created yet" for this connection
    try:
        # Register the user and announce the new user count to everyone.
        USERS.add(websocket)
        await websocket.send(connection_event())
        websockets.broadcast(USERS, users_event())
        # Send the current shared state to the new user.
        await websocket.send(state_event())

        async for message in websocket:
            message = message.strip()
            # Truthiness also skips empty binary (bytes) frames.
            if not message:
                continue
            try:
                data = json.loads(message)
                action = data["action"]
                if action == "minus":
                    STATE["value"] -= 1
                    websockets.broadcast(USERS, state_event())
                elif action == "plus":
                    STATE["value"] += 1
                    websockets.broadcast(USERS, state_event())
                elif action == "get_users":
                    await websocket.send(users_event())
                elif action == "dialog":
                    if dialog_id == 0:
                        dialog_id = model.new_dialog()
                    if dialog_id != 0:
                        await websocket.send(dialog_event("New dialog is created"))
                        persona_list = model.get_persona_list(dialog_id)
                        await websocket.send(persona_list_event(persona_list))
                    else:
                        await websocket.send(dialog_event("Dialog is not created"))
                elif action == "talk":
                    if dialog_id != 0:
                        # NOTE(review): model.talk is called synchronously, so a
                        # slow generation blocks the event loop for all clients.
                        reply = model.talk(
                            dialog_id,
                            persona_id=data["persona_id"],
                            utterance=data["utterance"],
                            **_sampling_options(data),
                        )
                        await websocket.send(talk_event(reply))
                elif action == "personality":
                    if dialog_id != 0:
                        model.set_personality(
                            dialog_id,
                            persona_id=data["persona_id"],
                            personality=data["message"],
                        )
                        await websocket.send(
                            personality_reply_event("Personality has been updated")
                        )
                elif action == "persona_chosen":
                    if dialog_id != 0:
                        # Called on the model instance, like every other model call.
                        name = model.get_persona_name(dialog_id, data["persona_id"])
                        greeting = (
                            f"Hi, I am {name}. Nice to meet you. "
                            "Feel free to talk in English or Indonesian."
                        )
                        await websocket.send(persona_greeting_event(greeting))
                else:
                    logging.error("unsupported event: %s", data)
            except json.JSONDecodeError as error:
                # JSONDecodeError subclasses ValueError, so it must come first.
                logging.warning("ignoring non-JSON message: %s", error)
            except (KeyError, TypeError, ValueError):
                # Missing keys, a non-object payload or an unconvertible option:
                # log with traceback and keep serving this client.
                logging.exception("failed to handle message: %r", message)
    finally:
        # Unregister first so that a failure while deleting the dialog cannot
        # leave a closed connection in USERS.
        USERS.discard(websocket)
        websockets.broadcast(USERS, users_event())
        if dialog_id != 0:
            model.delete_dialog(dialog_id)
async def main():
    """Serve the chatbot websocket on 0.0.0.0:8502 forever."""
    async with websockets.serve(chatbot, "0.0.0.0", 8502):
        print("Websocket is running")
        # A future nobody ever resolves keeps the server context open forever.
        never_done = asyncio.get_running_loop().create_future()
        await never_done


if __name__ == "__main__":
    asyncio.run(main())