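# Gradio Space: SQL query generation with explicit reasoning, served by
# Qwen2.5-3B-Instruct plus a GRPO-trained LoRA adapter
# (DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO).
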
import re
import torch
import gradio as gr
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
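
# Load the Qwen2.5-3B-Instruct base model and attach the GRPO-finetuned SQL
# reasoning adapter via PEFT. On the free CPU tier of this Space everything
# stays on CPU, which is why generation is slow.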
def load_model_tokenizer():
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
    model = PeftModel.from_pretrained(model, "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO", is_trainable=False)
    # Tokenizers take `model_max_length` (not `max_length`) at load time.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct", model_max_length=2560)
    return model, tokenizer

model, tokenizer = load_model_tokenizer()
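
# Build the two-message chat prompt: the system message pins the output format
# (<reason>...</reason> for the chain of thought, <answer>...</answer> for the
# SQL), and the user message carries the schemas and the question.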
def create_prompt(schemas, question):
    prompt = [
        {
            'role': 'system',
            'content': """\
You are an expert SQL Query Writer.
Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
Once you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question, leveraging some or all schemas.
Remember that you should place all your reasoning between <reason> and </reason> tags.
Also, you should provide your solution between <answer> and </answer> tags.
An example generation is as follows:
<reason>
This is a sample reasoning that solves the question based on the schema.
</reason>
<answer>
SELECT
    COLUMN
FROM TABLE_NAME
WHERE
    CONDITION
</answer>"""
        },
        {
            'role': 'user',
            'content': f"""\
SCHEMAS:
---------------
{schemas}
---------------
QUESTION: "{question}"\
"""
        }
    ]

    return prompt

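
# Extract the final SQL from the generation: grab the first <answer>...</answer>
# block (case-insensitive, matching across newlines); returns None if absent.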
def extract_answer(gen_output):
    answer_start_token = "<answer>"
    answer_end_token = "</answer>"

    answer_match_format = re.compile(
        rf"{answer_start_token}(.+?){answer_end_token}",
        flags=re.MULTILINE | re.DOTALL | re.IGNORECASE,
    )
    answer_match = answer_match_format.search(gen_output)

    final_answer = None
    if answer_match is not None:
        final_answer = answer_match.group(1)

    return final_answer

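
# End-to-end handler for Gradio: build the prompt, generate up to 1024 new
# tokens, keep only the assistant turn, and append the extracted SQL answer.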
def response(user_schemas, user_question):
    user_prompt = create_prompt(user_schemas, user_question)

    inputs = tokenizer.apply_chat_template(user_prompt,
                                           tokenize=True,
                                           add_generation_prompt=True,
                                           return_dict=True,
                                           return_tensors="pt")

    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    outputs = tokenizer.batch_decode(outputs)
    # Everything before the assistant tag is the echoed prompt; drop it.
    output = outputs[0].split("<|im_start|>assistant")[-1].strip()

    final_answer = extract_answer(output)
    # Guard the concatenation below: extract_answer returns None when the model
    # emits no <answer> block, and `str + None` would raise a TypeError.
    if final_answer is None:
        final_answer = "No <answer> block was found in the generation."

    final_output = output + "\n\n" + "="*20 + "\n\nFinal Answer: \n" + final_answer

    return final_output

desc = """
**NOTE: This HF Space is running on the free tier, so the generation process will be very slow.**<br>
Please use the "Table Schemas" field to provide the required schemas to generate the SQL Query for - separated by new lines.<br>
**Example:**
```sql
CREATE TABLE demographic (
    subject_id text,
    admission_type text,
    hadm_id text)

CREATE TABLE diagnoses (
    subject_id text,
    hadm_id text)
```
Finally, use the "Question" field to provide the relevant question to be answered based on the provided schemas.<br>
**Example:** How many patients have an admission type of emergency?
"""
demo = gr.Interface(
    fn=response,
    inputs=[
        gr.Textbox(label="Table Schemas",
                   placeholder="Expected to have CREATE TABLE statements with datatypes, separated by new lines"),
        gr.Textbox(label="Question",
                   placeholder="E.g. How many patients have an admission type of emergency?")
    ],
    outputs=gr.Textbox(label="Generated SQL Query with reasoning"),
    title="SQL Query Generator trained with GRPO to elicit reasoning",
    description=desc
)

demo.launch()