File size: 5,188 Bytes
e3f06ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import asyncio
import websockets
import logging
import json
from convaimodel_extended import ConvAIModelExtended

# Default root logging setup (WARNING level, stderr) so that the
# logging.error() calls in this server are actually emitted.
logging.basicConfig()

# Shared counter changed by the "minus"/"plus" actions in chatbot() and sent
# to clients via state_event(); one instance shared by all connections.
STATE = {"value": 0}

# Websocket connections registered by chatbot(); used for broadcasts and for
# the connected-user count reported by users_event().
USERS = set()

# Extra options forwarded to ConvAIModelExtended as `args`; left empty here.
train_args = {}
# A single model instance, loaded at import time with use_cuda=False and used
# by every connection. The two positional arguments look like the model type
# and the pretrained checkpoint name — TODO confirm against
# convaimodel_extended.
model = ConvAIModelExtended("gpt2", "cahya/gpt2-small-indonesian-personachat",
                            args=train_args, use_cuda=False)


def _event(event_type, payload):
    """Serialize one server-to-client event as a JSON string.

    Args:
        event_type: Value of the ``"type"`` key, which is always emitted first.
        payload: Mapping whose keys are merged after ``"type"``; its values
            must be JSON-serializable.

    Returns:
        The JSON text to send over the websocket.
    """
    return json.dumps({"type": event_type, **payload})


def connection_event():
    """Return the connection acknowledgement event (``"value": true``)."""
    return _event("connection", {"value": True})


def state_event():
    """Return the event carrying every key of the shared ``STATE`` dict."""
    return _event("state", STATE)


def dialog_event(message):
    """Return a ``"dialog"`` status event carrying *message*."""
    return _event("dialog", {"message": message})


def personality_event(message):
    """Return a ``"personality"`` event carrying *message*."""
    return _event("personality", {"message": message})


def persona_list_event(message):
    """Return a ``"persona_list"`` event; *message* must be JSON-serializable."""
    return _event("persona_list", {"message": message})


def personality_reply_event(message):
    """Return a ``"personality_reply"`` event carrying *message*."""
    return _event("personality_reply", {"message": message})


def persona_greeting_event(message):
    """Return a ``"persona_greeting"`` event carrying *message*."""
    return _event("persona_greeting", {"message": message})


def talk_event(message):
    """Return a ``"talk"`` event carrying the model's reply *message*."""
    return _event("talk", {"message": message})


def users_event():
    """Return the event reporting how many connections are in ``USERS``."""
    return _event("users", {"count": len(USERS)})


# Optional sampling parameters of the "talk" action, as
# (request key, converter applied to a client-supplied value, default).
_GENERATION_PARAMS = (
    ("do_sample", bool, True),
    ("min_length", int, 1),
    ("max_length", int, 20),
    ("temperature", float, 0.7),
    ("top_k", int, 0),
    ("top_p", float, 0.9),
)


def _generation_params(data):
    """Return the sampling keyword arguments for ``model.talk`` from a request.

    Args:
        data: The decoded "talk" request dict.

    Returns:
        A dict with one entry per key in ``_GENERATION_PARAMS``. Values the
        client supplied are converted; missing ones take the default.

    Raises:
        ValueError, TypeError: if a supplied value cannot be converted.
    """
    # NOTE(review): bool() treats any non-empty string as True, so a client
    # sending "do_sample": "false" still gets sampling; JSON true/false work.
    return {
        key: convert(data[key]) if key in data else default
        for key, convert, default in _GENERATION_PARAMS
    }


async def chatbot(websocket, path=None):
    """Serve one websocket client for the lifetime of its connection.

    Registers the connection in ``USERS`` and sends the connection and state
    events. It then handles JSON requests of the form
    ``{"action": ..., ...}`` until the client disconnects. On exit the
    connection is unregistered, the user count is re-broadcast, and the
    connection's dialog (if one was created) is deleted.

    Malformed requests are logged and skipped; they do not close the
    connection. Examples are invalid JSON, a non-object payload, an unknown or
    missing action, a missing field, or an unconvertible parameter.

    Args:
        websocket: The client connection supplied by ``websockets.serve``.
        path: Request path passed by older ``websockets`` releases; unused.
            Optional so the handler also works with releases that call it
            with the connection only.
    """
    # 0 means "no dialog yet": new_dialog() returning 0 is treated as failure.
    dialog_id = 0
    try:
        # Register user
        USERS.add(websocket)
        await websocket.send(connection_event())
        websockets.broadcast(USERS, users_event())
        # Send current state to user
        await websocket.send(state_event())
        async for message in websocket:
            message = message.strip()
            if not message:
                continue
            try:
                data = json.loads(message)
            except ValueError as error:  # JSONDecodeError or undecodable bytes
                logging.error("invalid JSON message: %s", error)
                continue
            if not isinstance(data, dict):
                logging.error("unsupported event: %s", data)
                continue
            action = data.get("action")
            # Missing fields or unconvertible values are logged and the
            # message is skipped instead of tearing down the connection.
            try:
                if action == "minus":
                    STATE["value"] -= 1
                    websockets.broadcast(USERS, state_event())
                elif action == "plus":
                    STATE["value"] += 1
                    websockets.broadcast(USERS, state_event())
                elif action == "get_users":
                    await websocket.send(users_event())
                elif action == "dialog":
                    # Only one dialog per connection; repeat requests are ignored.
                    if dialog_id == 0:
                        dialog_id = model.new_dialog()
                        if dialog_id != 0:
                            await websocket.send(dialog_event("New dialog is created"))
                            persona_list = model.get_persona_list(dialog_id)
                            await websocket.send(persona_list_event(persona_list))
                        else:
                            await websocket.send(dialog_event("Dialog is not created"))
                elif action == "talk":
                    if dialog_id != 0:
                        params = _generation_params(data)
                        # NOTE(review): model.talk() runs synchronously on the
                        # event loop, so other clients wait while a reply is
                        # generated.
                        reply = model.talk(dialog_id,
                                           persona_id=data["persona_id"],
                                           utterance=data["utterance"],
                                           **params)
                        await websocket.send(talk_event(reply))
                elif action == "personality":
                    if dialog_id != 0:
                        model.set_personality(dialog_id, persona_id=data["persona_id"], personality=data["message"])
                        await websocket.send(personality_reply_event("Personality has been updated"))
                elif action == "persona_chosen":
                    if dialog_id != 0:
                        # Called on the shared instance like every other model
                        # call; this also works if it is a class/static method.
                        name = model.get_persona_name(dialog_id, data["persona_id"])
                        greeting = f"Hi, I am {name}. Nice to meet you. Feel free to talk in English or Indonesian."
                        await websocket.send(persona_greeting_event(greeting))
                else:
                    logging.error("unsupported event: %s", data)
            except (KeyError, TypeError, ValueError):
                logging.exception("malformed event: %s", data)
    finally:
        # Unregister user first so a failure while releasing the dialog
        # cannot leave a closed connection in USERS.
        USERS.discard(websocket)
        try:
            # 0 means no dialog was ever created for this connection.
            if dialog_id != 0:
                model.delete_dialog(dialog_id)
        finally:
            websockets.broadcast(USERS, users_event())


async def main(host="0.0.0.0", port=8502):
    """Run the chatbot websocket server until the process is stopped.

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.
    """
    async with websockets.serve(chatbot, host, port):
        print("Websocket is running")
        # A bare Future never completes, so the server keeps serving forever.
        await asyncio.Future()


if __name__ == "__main__":
    asyncio.run(main())