import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import logging
import os
import gc

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TextToSQLModel:
    """Text-to-SQL model wrapper for deployment"""
    
    def __init__(self, model_dir="./final-model", base_model="Salesforce/codet5-base"):
        self.model_dir = model_dir
        self.base_model = base_model
        self.max_length = 128
        self.model = None
        self.tokenizer = None
        self._load_model()
    
    def _load_model(self):
        """Load the trained model and tokenizer with optimizations for HF Spaces"""
        try:
            # Check if model directory exists
            if not os.path.exists(self.model_dir):
                raise FileNotFoundError(f"Model directory {self.model_dir} not found")
            
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_dir,
                trust_remote_code=True,
                use_fast=True
            )
            
            logger.info("Loading base model...")
            # Force CPU and full precision: free-tier HF Spaces have no GPU,
            # and float32 avoids half-precision issues on CPU
            device = "cpu"
            torch_dtype = torch.float32
            
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                self.base_model,
                torch_dtype=torch_dtype,
                device_map=device,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            
            logger.info("Loading PEFT model...")
            self.model = PeftModel.from_pretrained(
                base_model, 
                self.model_dir,
                torch_dtype=torch_dtype,
                device_map=device
            )
            
            # Move to CPU and set to eval mode
            self.model = self.model.to(device)
            self.model.eval()
            
            # Clear cache to free memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            
            logger.info("Model loaded successfully!")
            
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            # Clean up on error
            self.model = None
            self.tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            raise
    
    def predict(self, question: str, table_headers: list) -> str:
        """
        Generate SQL query for a given question and table headers
        
        Args:
            question (str): Natural language question
            table_headers (list): List of table column names
            
        Returns:
            str: Generated SQL query
        """
        try:
            if self.model is None or self.tokenizer is None:
                raise RuntimeError("Model not properly loaded")
            
            # Format input text
            table_headers_str = ", ".join(table_headers)
            input_text = f"### Table columns:\n{table_headers_str}\n### Question:\n{question}\n### SQL:"
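            # Resulting prompt format (column names and question below are
            # illustrative values, not from the training data):
            # ### Table columns:
            # id, name, department, salary
            # ### Question:
            # What is the highest salary?
            # ### SQL: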
            
            # Tokenize input
            inputs = self.tokenizer(
                input_text, 
                return_tensors="pt", 
                padding=True, 
                truncation=True, 
                max_length=self.max_length
            )
            
            # Generate prediction with memory optimization
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs, 
                    max_length=self.max_length,
                    num_beams=1,  # Use greedy decoding for speed
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            
            # Decode prediction
            sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Clean up
            del inputs, outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            return sql_query.strip()
            
        except Exception as e:
            logger.error(f"Error generating SQL: {str(e)}")
            raise
    
    def batch_predict(self, queries: list) -> list:
        """
        Generate SQL queries for multiple questions
        
        Args:
            queries (list): List of dicts with 'question' and 'table_headers' keys
            
        Returns:
            list: List of generated SQL queries
        """
        results = []
        for query in queries:
            try:
                sql = self.predict(query['question'], query['table_headers'])
                results.append({
                    'question': query['question'],
                    'table_headers': query['table_headers'],
                    'sql': sql,
                    'status': 'success'
                })
            except Exception as e:
                logger.error(f"Error in batch prediction for query '{query['question']}': {str(e)}")
                results.append({
                    'question': query['question'],
                    'table_headers': query['table_headers'],
                    'sql': None,
                    'status': 'error',
                    'error': str(e)
                })
        
        return results
    
    def health_check(self) -> bool:
        """Check if model is loaded and ready"""
        return (self.model is not None and 
                self.tokenizer is not None and 
                hasattr(self.model, 'generate'))

# Global model instance
_model_instance = None

def get_model():
    """Get or create global model instance"""
    global _model_instance
    if _model_instance is None:
        _model_instance = TextToSQLModel()
    return _model_instance
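
# Example usage (illustrative sketch): assumes ./final-model exists locally
# and the host has enough RAM to load codet5-base on CPU. The question and
# table headers below are hypothetical inputs.
if __name__ == "__main__":
    model = get_model()
    if model.health_check():
        # Single query
        sql = model.predict(
            "What is the highest salary?",
            ["id", "name", "department", "salary"],
        )
        print(sql)

        # Batch of queries; each item uses the dict format batch_predict expects
        results = model.batch_predict([
            {
                "question": "How many employees are in each department?",
                "table_headers": ["id", "name", "department", "salary"],
            },
        ])
        for result in results:
            print(result["status"], result["sql"])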