# app.py
"""Streamlit app that converts natural-language questions into SQL.

A pre-trained Hugging Face seq2seq text-to-SQL model generates the query,
and ``sqlparse`` pretty-prints it for display.
"""
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import sqlparse

# Pre-trained text-to-SQL model. Kept in one place so the model that is
# loaded and the model name shown in the debug panel cannot drift apart.
MODEL_NAME = "tscholak/cxmefzzi"

# Text prepended to every question before tokenization.
# NOTE(review): tscholak/* text-to-SQL checkpoints appear to expect
# "question | db_id | table : col, ..." style inputs (with a database
# schema) rather than this instruction prefix -- confirm against the model
# card before relying on output quality.
DEFAULT_PREFIX = "Translate English to SQL: "

# Page-level settings for the app.
st.set_page_config(
    page_title="AI SQL Query Generator",
    page_icon="🤖",
    layout="centered",
)


@st.cache_resource
def load_model(model_name: str = MODEL_NAME):
    """Load the tokenizer and seq2seq model for *model_name*.

    Decorated with ``st.cache_resource`` so the (large) model is loaded once
    and reused across reruns of the script instead of on every interaction.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


def format_sql(sql: str) -> str:
    """Return *sql* re-indented with SQL keywords upper-cased."""
    return sqlparse.format(sql, reindent=True, keyword_case="upper")


def generate_sql(
    input_text: str,
    tokenizer,
    model,
    prefix: str = DEFAULT_PREFIX,
    max_length: int = 256,
) -> str:
    """Generate a SQL string for the natural-language *input_text*.

    Args:
        input_text: The user's question.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Seq2seq model returned by ``load_model``.
        prefix: Text prepended to the question before tokenization.
        max_length: Maximum length (in tokens) of the generated output.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    # Inputs longer than 512 tokens are truncated rather than rejected.
    inputs = tokenizer(
        prefix + input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def _render_result(raw_sql: str) -> None:
    """Display the formatted query plus a debug panel with the raw output."""
    if not raw_sql.strip():
        st.warning("The model returned an empty query. Try rephrasing your question.")
        return

    st.subheader("Generated SQL Query:")
    st.code(format_sql(raw_sql), language="sql")
    st.success("Query generated successfully!")

    # Raw model output for debugging; shown in a code box so backticks or
    # other markdown characters in the output cannot break the rendering.
    with st.expander("Debug Info"):
        st.write(f"Model: {MODEL_NAME}")
        st.write("Raw Output:")
        st.code(raw_sql, language="text")


def _render_help() -> None:
    """Display the usage instructions and example questions footer."""
    st.markdown("---")
    st.markdown("### How to use:")
    st.markdown("1. Enter a question about data you want to query")
    st.markdown("2. Click 'Generate SQL'")
    st.markdown("3. Copy the generated SQL and use it in your database")

    st.markdown("### Example queries:")
    st.code("Show the total sales per product category in 2022", language="text")
    st.code("List employees hired before 2020 with salary above $50,000", language="text")
    st.code("Count orders by customer country and sort descending", language="text")


def main() -> None:
    """Render the app: question input, generate button, results and help."""
    st.title("🤖 AI-Powered SQL Query Generator")
    st.markdown("Convert natural language questions to SQL queries")

    tokenizer, model = load_model()

    user_input = st.text_area(
        "Enter your question in natural language:",
        placeholder="e.g., Show all customers from California who made purchases after January 2023",
        height=150,
    )

    if st.button("Generate SQL"):
        if not user_input.strip():
            st.warning("Please enter a question")
        else:
            with st.spinner("Generating SQL query..."):
                try:
                    raw_sql = generate_sql(user_input, tokenizer, model)
                except Exception as e:
                    # UI boundary: surface any tokenizer/model failure to the
                    # user instead of crashing the Streamlit script.
                    st.error(f"Error generating SQL: {e}")
                else:
                    _render_result(raw_sql)

    _render_help()


if __name__ == "__main__":
    main()