import re
import torch
import gradio as gr
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
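
# Gradio demo: loads Qwen2.5-3B-Instruct with a GRPO-trained LoRA adapter
# (DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO) and generates SQL queries
# preceded by an explicit reasoning trace.
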
def load_model_tokenizer():
    # Load the base model, then attach the GRPO-trained LoRA adapter for inference only.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
    model = PeftModel.from_pretrained(model, "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO", is_trainable=False)
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct", model_max_length=2560)
    return model, tokenizer


model, tokenizer = load_model_tokenizer()
def create_prompt(schemas, question):
    # Build the chat prompt: the system message mandates the <reasoning>/<answer>
    # output format that extract_answer() parses below; the user message carries
    # the schemas and the question.
    prompt = [
        {
            'role': 'system',
            'content': """\
You are an expert SQL Query Writer.
Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
Once you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question leveraging some or all schemas.
Remember that you should place all your reasoning between <reasoning> and </reasoning> tags.
Also, you should provide your solution between <answer> and </answer> tags.
An example generation is as follows:
<reasoning>This is a sample reasoning that solves the question based on the schema.</reasoning>
<answer>
SELECT
    COLUMN
FROM TABLE_NAME
WHERE
    CONDITION
</answer>
"""
        },
        {
            'role': 'user',
            'content': f"""\
SCHEMAS:
---------------
{schemas}
---------------
QUESTION: "{question}"\
"""
        }
    ]
    return prompt
def extract_answer(gen_output):
    # These tag strings must match the <answer> ... </answer> format mandated
    # in the system prompt above.
    answer_start_token = "<answer>"
    answer_end_token = "</answer>"
    answer_match_format = re.compile(rf"{answer_start_token}(.+?){answer_end_token}", flags=re.MULTILINE | re.DOTALL | re.IGNORECASE)

    answer_match = answer_match_format.search(gen_output)

    final_answer = None
    if answer_match is not None:
        final_answer = answer_match.group(1)

    return final_answer
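
# Example (using the tag format assumed above):
#   extract_answer("<reasoning>...</reasoning><answer>SELECT 1</answer>") returns "SELECT 1"
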
def response(user_schemas, user_question):
    user_prompt = create_prompt(user_schemas, user_question)

    inputs = tokenizer.apply_chat_template(user_prompt,
                                           tokenize=True,
                                           add_generation_prompt=True,
                                           return_dict=True,
                                           return_tensors="pt").to(model.device)

    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    outputs = tokenizer.batch_decode(outputs)
    # Keep only the assistant turn of the decoded ChatML transcript.
    output = outputs[0].split("<|im_start|>assistant")[-1].strip()

    final_answer = extract_answer(output)
    # Guard against a missing <answer> block so the string concatenation below cannot fail.
    if final_answer is None:
        final_answer = "No <answer> tag found in the generation."

    final_output = output + "\n\n" + "="*20 + "\n\nFinal Answer: \n" + final_answer
    return final_output
desc="""
**NOTE: This HF Space is running on Free Version so the generation process will be very slow.**
Please use the "Table Schemas" field to provide the required schemas to to generate the SQL Query for - separated by new lines.
**Example:**
```python
CREATE TABLE demographic (
subject_id text,
admission_type text,
hadm_id text)
CREATE TABLE diagnoses (
subject_id text,
hadm_id text)
```
Finally, use the "Question" field to provide the relevant question to be answered based on the provided schemas.
**Example:** How many patients whose admission type is emergency.
"""
demo = gr.Interface(
    fn=response,
    inputs=[gr.Textbox(label="Table Schemas",
                       placeholder="Expected to have CREATE TABLE statements with datatypes, separated by new lines"),
            gr.Textbox(label="Question",
                       placeholder="E.g. How many patients have an admission type of emergency?")],
    outputs=gr.Textbox(label="Generated SQL Query with reasoning"),
    title="SQL Query Generator trained with GRPO to elicit reasoning",
    description=desc
)

demo.launch()
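
# When running locally, demo.launch(share=True) can be used to get a temporary
# public URL; on Spaces the default launch() is enough.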