import json
import logging
import os
import re
from datetime import datetime

import gradio as gr
from dotenv import load_dotenv
from gradio import ChatMessage
from openai import OpenAI

logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(levelname)s] - %(message)s')


# Load configuration from .env, overriding any variables already set in the environment.
load_dotenv(".env", override=True)
HF_TOKEN = os.environ.get("HF_TOKEN")            # API key for the OpenAI-compatible endpoint
BASE_URL = os.environ.get("BASE_URL")            # Base URL of the inference server
EMBEDDINGS = os.environ.get("EMBEDDINGS_MODEL")  # Embeddings model name (not used directly in this file)
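# Expected .env layout (variable names come from this file; the values are placeholders):
#   HF_TOKEN=hf_xxx
#   BASE_URL=https://your-inference-endpoint.example.com
#   EMBEDDINGS_MODEL=your-embeddings-model-id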



from tools import tools, oitools
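# Assumed interfaces from the local `tools` module (its source is not shown here):
#   - `tools`: a dict mapping tool names to objects exposing a LangChain-style
#     .invoke(input=...) method, e.g. tools["retrieve_wiki_data"].invoke(input={"query": "..."}).
#   - `oitools`: the same tools expressed as an OpenAI function-calling schema list,
#     passed to the API through the `tools` request parameter.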

SYSTEM_PROMPT_TEMPLATE = """Today’s date is **{date}**.

You are an AI assistant. Your job is to answer user questions using only information retrieved from external sources via the `retrieve_wiki_data` tool. Follow these rules:

### Tool Use Guidelines:

- **query**: When using `retrieve_wiki_data`, you may rephrase the user's query to improve clarity or specificity. However, **do not remove or change essential names or terms**.

- **missing_info**: If the information needed is **not already present** in the conversation or past tool responses, you **must call** `retrieve_wiki_data`.

- **redundant_search**: Do **not** use `retrieve_wiki_data` if the answer has already been retrieved. Avoid repeating searches unnecessarily.

- **wikipedia_entities**: If the user asks about a **person, place, topic, or concept likely to exist on Wikipedia**, and it hasn’t been discussed yet, you **must** use `retrieve_wiki_data` to find the information.

- **external_info_only**: You are not allowed to use your internal memory or built-in knowledge. Only respond based on the content retrieved using `retrieve_wiki_data`.

- **no_info_found**: If the tool returns no relevant content, clearly inform the user that you couldn’t find the answer.
"""


client = OpenAI(
    base_url=f"{BASE_URL}/v1",
    api_key=HF_TOKEN,
)
logging.info(f"Client initialized for endpoint: {BASE_URL}/v1")

def today_date():
    """Return the current date and time, e.g. 'Friday, June 13, 2025, 03:42 PM'."""
    return datetime.today().strftime('%A, %B %d, %Y, %I:%M %p')


def clean_json_string(json_str):
    """Repair possibly truncated tool-call arguments streamed as JSON.

    Strips trailing spaces, commas, and closing braces, then appends a single '}',
    so '{"query": "Paris", ' becomes '{"query": "Paris"}'. Note this assumes a flat,
    single-level object: a nested payload ending in '}}' would lose a brace.
    """
    return re.sub(r'[ ,}\s]+$', '', json_str) + '}'


def completion(history, model, system_prompt: str, tools=None):
    """Rebuild the OpenAI message list from the Gradio history and request a streamed completion."""
    messages = [{"role": "system", "content": system_prompt.format(date=today_date())}]
    for msg in history:
        if isinstance(msg, dict):
            msg = ChatMessage(**msg)
        if msg.role == "assistant" and hasattr(msg, "metadata") and msg.metadata:
            # Tool calls are stored as a JSON string in the message metadata title;
            # replay them as an assistant tool_calls message followed by the tool result.
            tools_calls = json.loads(msg.metadata.get("title", "[]"))
            messages.append({"role": "assistant", "tool_calls": tools_calls, "content": ""})
            messages.append({"role": "tool", "content": msg.content})
        else:
            messages.append({"role": msg.role, "content": msg.content})
    
    request_params = {
        "model": model,
        "messages": messages,
        "stream": True,
        "max_tokens": 1000,
        "temperature": 0.2,
        #"frequency_penalty": 1,
        "extra_body": {"repetition_penalty": 1.2},
    }
    if tools:
        request_params.update({"tool_choice": "auto", "tools": tools})
    
    return client.chat.completions.create(**request_params)  
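# For a turn that used a tool, completion() replays the exchange to the model roughly as:
#   {"role": "assistant", "tool_calls": [...], "content": ""}
#   {"role": "tool", "content": '{"retrieve_wiki_data": "..."}'}
# before the model produces its final streamed answer.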

def llm_in_loop(history, system_prompt, recursive):
    """Stream one assistant turn; if the model requests a tool, run it and recurse.

    `recursive` is a negative index into `history`: yielding history[recursive:]
    surfaces only the messages produced during this (possibly recursive) turn.
    """
    try:
        models = client.models.list()
        model = models.data[0].id if models.data else "gpt-3.5-turbo"
    except Exception as err:
        gr.Warning("The model is initializing. Please wait; this may take 5 to 10 minutes ⏳.", duration=20)
        raise err

    arguments = ""
    name = ""
    chat_completion = completion(history=history, tools=oitools, model=model, system_prompt=system_prompt)
    appended = False
    for chunk in chat_completion:
        if not chunk.choices:
            # Some servers emit housekeeping chunks (e.g. usage) with no choices.
            continue
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            # The tool-call name and arguments arrive incrementally; accumulate them.
            call = delta.tool_calls[0]
            if hasattr(call.function, "name") and call.function.name:
                name = call.function.name
            if hasattr(call.function, "arguments") and call.function.arguments:
                arguments += call.function.arguments
        elif delta.content:
            if not appended:
                history.append(ChatMessage(role="assistant", content=""))
                appended = True
            history[-1].content += delta.content
            yield history[recursive:]
    
    arguments = json.loads(clean_json_string(arguments) if arguments else "{}")
    logging.info(f"Tool call requested: name={name!r}, arguments={arguments}")
    if appended:
        # A streamed text message was added above, so widen the yielded slice by one.
        recursive -= 1
    if name:
        result = f"💥 Error using tool {name}, tool doesn't exist" if name not in tools else str(tools[name].invoke(input=arguments))
        result = json.dumps({name: result}, ensure_ascii=False)
        # Record the tool call in the message metadata so completion() can replay it later.
        history.append(
            ChatMessage(
                role="assistant",
                content=result,
                metadata={
                    "title": json.dumps(
                        [{"id": "call_id", "function": {"arguments": json.dumps(arguments, ensure_ascii=False), "name": name}, "type": "function"}],
                        ensure_ascii=False,
                    )
                },
            )
        )
        yield history[recursive:]
        yield from llm_in_loop(history, system_prompt, recursive - 1)
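# Worked example of the `recursive` bookkeeping: respond() starts at -1, so after the
# user message only history[-1:] (the new assistant turn) is yielded. Each tool round
# recurses with recursive - 1, keeping the tool message plus the follow-up answer
# inside the yielded slice.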

def respond(message, history, additional_inputs):
    """Gradio chat handler: append the user message and stream the assistant reply."""
    history.append(ChatMessage(role="user", content=message))
    yield from llm_in_loop(history, additional_inputs, -1)

if __name__ == "__main__":
    system_prompt = gr.Textbox(label="System prompt", value=SYSTEM_PROMPT_TEMPLATE, lines=3)  
    demo = gr.ChatInterface(respond, type="messages", additional_inputs=[system_prompt])
    demo.launch()