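"""Text-to-SQL model wrapper for Hugging Face Spaces deployment.

Loads a PEFT adapter from ./final-model on top of Salesforce/codet5-base and
exposes single and batched SQL-generation helpers plus a health check.
"""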
import gc
import logging
import os

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextToSQLModel:
    """Text-to-SQL model wrapper for deployment."""

    def __init__(self, model_dir="./final-model", base_model="Salesforce/codet5-base"):
        self.model_dir = model_dir
        self.base_model = base_model
        self.max_length = 128
        self.model = None
        self.tokenizer = None
        self._load_model()

    def _load_model(self):
        """Load the trained model and tokenizer with optimizations for HF Spaces."""
        try:
            # Check that the model directory exists before loading anything
            if not os.path.exists(self.model_dir):
                raise FileNotFoundError(f"Model directory {self.model_dir} not found")

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_dir,
                trust_remote_code=True,
                use_fast=True
            )

            logger.info("Loading base model...")
            # Force CPU and float32 for stability and compatibility on HF Spaces
            device = "cpu"
            torch_dtype = torch.float32
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                self.base_model,
                torch_dtype=torch_dtype,
                device_map=device,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            logger.info("Loading PEFT model...")
            self.model = PeftModel.from_pretrained(
                base_model,
                self.model_dir,
                torch_dtype=torch_dtype,
                device_map=device
            )

            # Move to CPU and set to eval mode
            self.model = self.model.to(device)
            self.model.eval()

            # Clear cache to free memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            # Clean up on error
            self.model = None
            self.tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            raise

    def predict(self, question: str, table_headers: list) -> str:
        """
        Generate a SQL query for a given question and table headers.

        Args:
            question (str): Natural language question
            table_headers (list): List of table column names

        Returns:
            str: Generated SQL query
        """
        try:
            if self.model is None or self.tokenizer is None:
                raise RuntimeError("Model not properly loaded")

            # Format input text
            table_headers_str = ", ".join(table_headers)
            input_text = f"### Table columns:\n{table_headers_str}\n### Question:\n{question}\n### SQL:"

            # Tokenize input
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )

            # Generate prediction with memory optimization
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=self.max_length,
                    num_beams=1,  # greedy decoding for speed
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode prediction
            sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Free tensors eagerly to keep memory usage low
            del inputs, outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return sql_query.strip()
        except Exception as e:
            logger.error(f"Error generating SQL: {str(e)}")
            raise

    def batch_predict(self, queries: list) -> list:
        """
        Generate SQL queries for multiple questions.

        Args:
            queries (list): List of dicts with 'question' and 'table_headers' keys

        Returns:
            list: One result dict per query with 'question', 'table_headers',
                'sql', 'status', and (on failure) 'error' keys
        """
        results = []
        for query in queries:
            try:
                sql = self.predict(query['question'], query['table_headers'])
                results.append({
                    'question': query['question'],
                    'table_headers': query['table_headers'],
                    'sql': sql,
                    'status': 'success'
                })
            except Exception as e:
                logger.error(f"Error in batch prediction for query '{query['question']}': {str(e)}")
                results.append({
                    'question': query['question'],
                    'table_headers': query['table_headers'],
                    'sql': None,
                    'status': 'error',
                    'error': str(e)
                })
        return results

    def health_check(self) -> bool:
        """Check whether the model and tokenizer are loaded and ready."""
        return (self.model is not None and
                self.tokenizer is not None and
                hasattr(self.model, 'generate'))


# Global model instance (lazily created singleton)
_model_instance = None


def get_model():
    """Get or create the global model instance."""
    global _model_instance
    if _model_instance is None:
        _model_instance = TextToSQLModel()
    return _model_instance
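

# Minimal usage sketch, assuming the fine-tuned adapter is present in
# ./final-model; the question and column names below are illustrative only.
if __name__ == "__main__":
    model = get_model()
    if model.health_check():
        # Single prediction: pass the question and the table's column names
        sql = model.predict(
            "How many employees work in the Sales department?",
            ["id", "name", "department", "salary"]
        )
        print(sql)

        # Batched variant: each item needs 'question' and 'table_headers' keys
        results = model.batch_predict([
            {"question": "What is the highest salary?",
             "table_headers": ["id", "name", "department", "salary"]}
        ])
        print(results[0]["status"], results[0]["sql"])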