File size: 3,947 Bytes
c52a50c
 
78768a5
 
c52a50c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78768a5
 
c52a50c
 
e595409
c52a50c
 
 
e595409
 
 
c52a50c
 
 
4ebb27c
 
e595409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c52a50c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import re

import torch

import gradio as gr

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model_tokenizer():
    """Load the base causal LM wrapped with its PEFT adapter, plus the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model`` is the ``PeftModel`` built on
        top of the base checkpoint; ``tokenizer`` comes from the same base
        checkpoint.
    """
    base_checkpoint = "Qwen/Qwen2.5-3B-Instruct"
    adapter_checkpoint = "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO"

    base_model = AutoModelForCausalLM.from_pretrained(base_checkpoint, max_length=2560)
    # is_trainable=False: the adapter is loaded for inference, not further training.
    adapted_model = PeftModel.from_pretrained(
        base_model,
        adapter_checkpoint,
        is_trainable=False,
    )

    base_tokenizer = AutoTokenizer.from_pretrained(base_checkpoint, max_length=2560)

    return adapted_model, base_tokenizer


# Loaded once at import time; response() reads these module-level objects on every call.
model, tokenizer = load_model_tokenizer()


def create_prompt(schemas, question):
    """Build the two-message chat prompt (system + user) for SQL generation.

    Args:
        schemas: Text of the table schemas; interpolated into the user message.
        question: The question to answer; interpolated (in double quotes) into
            the user message.

    Returns:
        A list of two dicts with ``'role'`` and ``'content'`` keys: the fixed
        system instructions followed by the user message.
    """
    # Fixed instructions, one entry per line of the system message. The first
    # entry intentionally ends with a space to match the prompt text exactly.
    system_lines = (
        "You are an expert SQL Query Writer. ",
        "Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.",
        "Once, you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question leveraging some or all schemas.",
        "",
        "Remember that you should place all your reasoning between <reason> and </reason> tags.",
        "Also, you should provide your solution between <answer> and </answer> tags.",
        "",
        "An example generation is as follows:",
        "<reason>",
        "This is a sample reasoning that solves the question based on the schema.",
        "</reason>",
        "<answer>",
        "SELECT",
        "    COLUMN",
        "FROM TABLE_NAME",
        "WHERE",
        "    CONDITION",
        "</answer>",
    )
    system_message = {'role': 'system', 'content': "\n".join(system_lines)}

    # Horizontal rule placed above and below the schemas section.
    rule = "---------------"
    user_text = f'SCHEMAS:\n{rule}\n\n{schemas}\n\n{rule}\n\nQUESTION: "{question}"'
    user_message = {'role': 'user', 'content': user_text}

    return [system_message, user_message]


def extract_answer(gen_output):
    """Return the text inside the first ``<answer>...</answer>`` span.

    Tag matching is case-insensitive and the span may cover several lines.

    Args:
        gen_output: The generated text to search.

    Returns:
        The captured text between the tags, exactly as it appears (surrounding
        whitespace included), or ``None`` when no non-empty span is found.
    """
    found = re.search(
        r"<answer>(.+?)</answer>",
        gen_output,
        flags=re.MULTILINE | re.DOTALL | re.IGNORECASE,
    )

    return found.group(1) if found is not None else None


def response(user_schemas, user_question):
    """Generate reasoning and a SQL query for a question over the given schemas.

    Builds the chat prompt, runs generation with the module-level ``model`` and
    ``tokenizer``, and returns the assistant's text followed by the contents of
    its ``<answer>`` span.

    Args:
        user_schemas: Text of the table schemas, inserted into the prompt.
        user_question: The question to answer, inserted into the prompt.

    Returns:
        A string made of the decoded assistant turn, a separator line, and a
        "Final Answer" section. When the generated text has no
        ``<answer>...</answer>`` span, that section holds an explanatory
        message instead of a query.
    """
    user_prompt = create_prompt(user_schemas, user_question)

    inputs = tokenizer.apply_chat_template(user_prompt,
                                           tokenize=True,
                                           add_generation_prompt=True,
                                           return_dict=True,
                                           return_tensors="pt")
    # Keep the prompt tensors on the same device as the model weights.
    inputs = inputs.to(model.device)

    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    # The decoded text still contains the chat markers, so the assistant turn
    # is whatever follows the last assistant marker.
    outputs = tokenizer.batch_decode(outputs)
    output = outputs[0].split("<|im_start|>assistant")[-1].strip()

    final_answer = extract_answer(output)
    if final_answer is None:
        # extract_answer() returns None when the tags are missing (for example
        # if generation stopped before closing them); concatenating None below
        # would raise TypeError.
        final_answer = "No <answer>...</answer> block was found in the generated output."

    final_output = output + "\n\n" + "="*20 + "\n\nFinal Answer: \n" + final_answer

    return final_output


# Markdown passed as the interface description; explains the two input fields.
desc = """
**NOTE: This HF Space is running on Free Version so the generation process will be very slow.**<br>

Please use the "Table Schemas" field to provide the required schemas to generate the SQL Query for - separated by new lines.<br>
**Example:**
```sql
CREATE TABLE demographic (
    subject_id text, 
    admission_type text, 
    hadm_id text)

CREATE TABLE diagnoses (
    subject_id text,
    hadm_id text)
```

Finally, use the "Question" field to provide the relevant question to be answered based on the provided schemas.<br>
**Example:** How many patients whose admission type is emergency.
"""

# Input widgets, listed in the same order as response()'s parameters.
schemas_box = gr.Textbox(
    label="Table Schemas",
    placeholder="Expected to have CREATE TABLE statements with datatypes separated by new lines",
)
question_box = gr.Textbox(
    label="Question",
    placeholder="Eg. How many patients whose admission type is emergency",
)

# Single text output holding the string returned by response().
result_box = gr.Textbox(label="Generated SQL Query with reasoning")

demo = gr.Interface(
    fn=response,
    inputs=[schemas_box, question_box],
    outputs=result_box,
    title="SQL Query Generator trained with GRPO to elicit reasoning",
    description=desc,
)

# Start the app as soon as the module is executed.
demo.launch()