import re

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
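

# Load the Qwen2.5-3B-Instruct base model and attach the GRPO-trained LoRA
# adapter on top of it; the adapter stays frozen (is_trainable=False) since
# this Space only runs inference.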
def load_model_tokenizer():
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
    model = PeftModel.from_pretrained(model, "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO", is_trainable=False)

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct", model_max_length=2560)

    return model, tokenizer


model, tokenizer = load_model_tokenizer()
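

# Build the chat-format prompt: the system turn pins the <reason>/<answer>
# output format and shows a worked example; the user turn carries the
# schemas and the question.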
def create_prompt(schemas, question):
    prompt = [
        {
            'role': 'system',
            'content': """\
You are an expert SQL Query Writer.
Given the relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
Once you are confident in your reasoning, you will generate the SQL Query as the answer that accurately solves the given question, leveraging some or all of the schemas.

Remember that you should place all your reasoning between <reason> and </reason> tags.
Also, you should provide your solution between <answer> and </answer> tags.

An example generation is as follows:
<reason>
This is a sample reasoning that solves the question based on the schema.
</reason>
<answer>
SELECT
COLUMN
FROM TABLE_NAME
WHERE
CONDITION
</answer>"""
        },
        {
            'role': 'user',
            'content': f"""\
SCHEMAS:
---------------

{schemas}

---------------

QUESTION: "{question}"\
"""
        }
    ]

    return prompt
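

# Pull the text between the <answer> and </answer> tags out of the raw
# generation; returns None when no well-formed answer block is present.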
def extract_answer(gen_output):
    answer_start_token = "<answer>"
    answer_end_token = "</answer>"
    answer_match_format = re.compile(rf"{answer_start_token}(.+?){answer_end_token}",
                                     flags=re.DOTALL | re.IGNORECASE)

    answer_match = answer_match_format.search(gen_output)

    final_answer = None

    if answer_match is not None:
        final_answer = answer_match.group(1)

    return final_answer
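

# End-to-end inference: build the prompt, apply the chat template, generate,
# then keep only the assistant turn and append the extracted SQL as a clearly
# marked final answer.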
def response(user_schemas, user_question):
    user_prompt = create_prompt(user_schemas, user_question)

    inputs = tokenizer.apply_chat_template(user_prompt,
                                           tokenize=True,
                                           add_generation_prompt=True,
                                           return_dict=True,
                                           return_tensors="pt").to(model.device)

    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    outputs = tokenizer.batch_decode(outputs)
    output = outputs[0].split("<|im_start|>assistant")[-1].strip()

    final_answer = extract_answer(output)

    # extract_answer returns None when no <answer> block is present; fall back
    # to a placeholder so the string concatenation below cannot fail.
    if final_answer is None:
        final_answer = "No <answer> block found in the generation."

    final_output = output + "\n\n" + "=" * 20 + "\n\nFinal Answer: \n" + final_answer

    return final_output

desc = """
**NOTE: This HF Space runs on the free tier, so generation will be very slow.**<br>

Please use the "Table Schemas" field to provide the required schemas to generate the SQL Query for - separated by new lines.<br>
**Example:**
```sql
CREATE TABLE demographic (
    subject_id text,
    admission_type text,
    hadm_id text)

CREATE TABLE diagnoses (
    subject_id text,
    hadm_id text)
```

Finally, use the "Question" field to provide the relevant question to be answered based on the provided schemas.<br>
**Example:** How many patients whose admission type is emergency.
"""
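
# Wire the schema and question text boxes to the response function in a
# simple Gradio interface.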
demo = gr.Interface(
    fn=response,
    inputs=[gr.Textbox(label="Table Schemas",
                       placeholder="Expected to have CREATE TABLE statements with datatypes separated by new lines"),
            gr.Textbox(label="Question",
                       placeholder="Eg. How many patients whose admission type is emergency")],
    outputs=gr.Textbox(label="Generated SQL Query with reasoning"),
    title="SQL Query Generator trained with GRPO to elicit reasoning",
    description=desc
)

demo.launch()