File size: 12,637 Bytes
2c72e40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import os
import json
import logging
import subprocess
from datetime import datetime
from flask import request, jsonify, render_template
from werkzeug.utils import secure_filename
from threading import Thread
import time

logger = logging.getLogger(__name__)

# Module-level state describing the single background training job.
# The Flask routes and run_training() mutate this dict in place, and the
# training-status endpoint returns it verbatim as JSON.
training_status = dict(
    status='idle',      # one of: idle, training, completed, failed
    progress=0,         # integer percentage, 0-100
    logs=[],            # human-readable log lines, appended in order
    start_time=None,    # datetime at which the current/last run was started
    error=None,         # str() of the last failure's exception, else None
)

def register_mt564_routes(app):
    """Register the MT564 TinyLlama training routes on a Flask app.

    Adds:
      * ``GET  /mt564`` - renders ``mt564.html``.
      * ``POST /api/mt564/upload`` - saves the multipart ``files`` field under
        ``data/uploaded`` and converts it with ``process_uploaded_files``.
      * ``POST /api/mt564/train`` - starts ``run_training`` on a daemon thread,
        using the JSON body (if any) as its config.
      * ``GET  /api/mt564/training-status`` - returns ``training_status``.
      * ``POST /api/mt564/query`` - answers ``{"query": ...}`` through
        ``run_inference`` once the ``mt564_tinyllama_model`` directory exists.

    The JSON endpoints report problems as ``{'success': False, 'error': ...}``
    with Flask's default status code instead of raising.

    Args:
        app: The Flask application to register the routes on.
    """

    @app.route('/mt564')
    def mt564_interface():
        """MT564 training interface"""
        return render_template('mt564.html')

    @app.route('/api/mt564/upload', methods=['POST'])
    def upload_mt564_docs():
        """Save uploaded MT564 documentation files and build training data from them."""
        try:
            if 'files' not in request.files:
                return jsonify({'success': False, 'error': 'No files uploaded'})

            files = request.files.getlist('files')
            if not files or all(f.filename == '' for f in files):
                return jsonify({'success': False, 'error': 'No files selected'})

            # Ensure upload directory exists
            upload_dir = os.path.join('data', 'uploaded')
            os.makedirs(upload_dir, exist_ok=True)

            uploaded_files = []
            for file in files:
                if not (file and file.filename):
                    continue
                # secure_filename() can reduce a name to '' (e.g. '..'); saving
                # to the bare upload directory would raise and abort the batch.
                filename = secure_filename(file.filename)
                if not filename:
                    logger.warning("Skipping upload with unusable filename %r", file.filename)
                    continue
                filepath = os.path.join(upload_dir, filename)
                file.save(filepath)
                uploaded_files.append(filepath)

            if not uploaded_files:
                return jsonify({'success': False, 'error': 'No valid files uploaded'})

            # Process uploaded files to create training data
            processed_data = process_uploaded_files(uploaded_files)

            return jsonify({
                'success': True,
                'files_uploaded': len(uploaded_files),
                'training_examples': len(processed_data)
            })

        except Exception as e:
            logger.exception(f"Upload error: {e}")
            return jsonify({'success': False, 'error': str(e)})

    @app.route('/api/mt564/train', methods=['POST'])
    def start_mt564_training():
        """Start a background MT564 training run unless one is already active."""
        try:
            # silent=True: a missing or malformed body yields None instead of
            # raising, so a bad request can never reach the except-branch below
            # and mark an already-running job as failed.
            config = request.get_json(silent=True) or {}
            if not isinstance(config, dict):
                # run_training() reads options with config.get(...).
                return jsonify({'success': False, 'error': 'Training config must be a JSON object'})

            if training_status['status'] == 'training':
                return jsonify({'success': False, 'error': 'Training already in progress'})

            # Reset training status
            training_status.update({
                'status': 'training',
                'progress': 0,
                'logs': [],
                'start_time': datetime.now(),
                'error': None
            })

            # Start training in background thread
            training_thread = Thread(target=run_training, args=(config,), daemon=True)
            training_thread.start()

            return jsonify({'success': True, 'message': 'Training started'})

        except Exception as e:
            # Only the status update / thread start can fail here, i.e. after
            # this request has already claimed the 'training' state.
            logger.exception(f"Training start error: {e}")
            training_status.update({
                'status': 'failed',
                'error': str(e)
            })
            return jsonify({'success': False, 'error': str(e)})

    @app.route('/api/mt564/training-status', methods=['GET'])
    def get_training_status():
        """Get current training status"""
        return jsonify(training_status)

    @app.route('/api/mt564/query', methods=['POST'])
    def query_mt564_model():
        """Answer a query against the trained MT564 model."""
        try:
            data = request.get_json(silent=True)
            raw_query = data.get('query') if isinstance(data, dict) else None
            # Missing, null or non-string queries are treated as empty.
            query = raw_query.strip() if isinstance(raw_query, str) else ''

            if not query:
                return jsonify({'success': False, 'error': 'Empty query'})

            # Check if trained model exists
            model_path = 'mt564_tinyllama_model'
            if not os.path.exists(model_path):
                return jsonify({
                    'success': False,
                    'error': 'No trained model found. Please train a model first.'
                })

            # Run inference
            response = run_inference(query, model_path)

            return jsonify({
                'success': True,
                'query': query,
                'response': response
            })

        except Exception as e:
            logger.exception(f"Query error: {e}")
            return jsonify({'success': False, 'error': str(e)})

def process_uploaded_files(file_paths):
    """Convert uploaded documentation files into MT564 training examples.

    Handling by extension (matched case-insensitively):
      * ``.json`` - parsed and passed to ``create_mt564_examples``.
      * ``.txt``  - read as text and passed to ``create_text_examples``.
      * ``.pdf``  - not supported yet; a warning is logged and the file skipped.
      * anything else - skipped with a warning.
    A file that fails to read or convert is logged and skipped; it does not
    abort processing of the remaining files.

    The collected examples are written to
    ``data/processed/mt564_training_data.json``, replacing its contents.
    When no examples were produced at all, the existing file is left alone.
    Otherwise an upload that yielded nothing would replace usable data with an
    empty list, and the file's mere existence stops run_training() from
    generating sample data.

    Args:
        file_paths: Paths of the saved upload files.

    Returns:
        list: The training examples built from ``file_paths`` (may be empty).
    """
    training_data = []

    for filepath in file_paths:
        extension = os.path.splitext(filepath)[1].lower()
        try:
            if extension == '.json':
                # utf-8-sig also accepts a leading UTF-8 byte-order mark
                # (common from Windows editors), which json.load rejects.
                with open(filepath, 'r', encoding='utf-8-sig') as f:
                    data = json.load(f)
                # Convert to instruction-response pairs
                training_data.extend(create_mt564_examples(data))
            elif extension == '.txt':
                with open(filepath, 'r', encoding='utf-8-sig') as f:
                    content = f.read()
                # Create examples from text content
                training_data.extend(create_text_examples(content))
            elif extension == '.pdf':
                # For PDF processing, we'd need additional libraries
                logger.warning(f"PDF processing not implemented for {filepath}")
            else:
                logger.warning(f"Unsupported file type skipped: {filepath}")
        except Exception as e:
            logger.exception(f"Error processing {filepath}: {e}")

    if not training_data:
        logger.warning("No training examples produced; existing training data left unchanged")
        return training_data

    # Save processed training data
    os.makedirs('data/processed', exist_ok=True)
    output_file = 'data/processed/mt564_training_data.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, ensure_ascii=False, indent=2)

    return training_data

def create_mt564_examples(data):
    """Build instruction/response training examples from MT564 spec data.

    Recognised keys of a spec dict:
      * ``message_type`` equal to ``'MT564'`` (optional ``description``) -
        one example about what the message type is used for.
      * ``fields`` - list of dicts with ``tag`` and optional ``description``;
        one example per field.
      * ``sequences`` - list of dicts with ``name`` and optional
        ``description``; one example per sequence.

    ``data`` may also be a list of such dicts; their examples are concatenated
    in order. Entries of ``fields``/``sequences`` that are not dicts are
    skipped instead of aborting the whole document, and a ``null`` list is
    treated as empty. Any other input yields no examples.

    Args:
        data: Parsed JSON, normally a dict (optionally a list of dicts).

    Returns:
        list[dict]: Examples shaped ``{"text": "Instruction: ...\\nResponse: ..."}``.
    """
    if isinstance(data, list):
        # A JSON array of spec documents: flatten their examples in order.
        return [example for item in data for example in create_mt564_examples(item)]

    if not isinstance(data, dict):
        return []

    examples = []

    # Message structure examples
    if data.get('message_type') == 'MT564':
        examples.append({
            "text": f"Instruction: What is the MT564 message type used for?\nResponse: The MT564 message type is used for {data.get('description', 'Corporate Action Notification messages in SWIFT financial messaging')}."
        })

    # Field definitions (`or []` also covers an explicit JSON null)
    for field in data.get('fields') or []:
        if not isinstance(field, dict):
            continue
        tag = field.get('tag', '')
        examples.append({
            "text": f"Instruction: What is field {tag} in MT564?\nResponse: Field {tag} is {field.get('description', 'a field in MT564 message')}."
        })

    # Sequence information
    for sequence in data.get('sequences') or []:
        if not isinstance(sequence, dict):
            continue
        name = sequence.get('name', '')
        examples.append({
            "text": f"Instruction: Describe sequence {name} in MT564.\nResponse: Sequence {name} {sequence.get('description', 'is part of the MT564 message structure')}."
        })

    return examples

def create_text_examples(content, min_length=50):
    """Turn each substantial paragraph of ``content`` into a training example.

    Paragraphs are separated by one or more blank lines. A line containing
    only whitespace (including a stray ``\\r`` from CRLF text) also counts as
    blank. Each paragraph is stripped. Only paragraphs whose stripped length
    is strictly greater than ``min_length`` are kept.

    Args:
        content: Plain text to split.
        min_length: Minimum length a paragraph must exceed (default 50).

    Returns:
        list[dict]: ``{"text": "Instruction: Explain this MT564 concept.\\nResponse: <paragraph>"}``
        entries, in document order.
    """
    import re  # local import keeps this helper self-contained

    examples = []
    # `\n\s*\n` treats whitespace-only lines as paragraph breaks too;
    # a plain split('\n\n') would merge such paragraphs into one chunk.
    for chunk in re.split(r'\n\s*\n', content):
        paragraph = chunk.strip()
        if len(paragraph) > min_length:  # Only meaningful chunks
            examples.append({
                "text": f"Instruction: Explain this MT564 concept.\nResponse: {paragraph}"
            })

    return examples

def run_training(config):
    """Run the MT564 training job (currently simulated) and report progress.

    Meant to run on a background thread. Progress, log lines and the final
    outcome are written into the module-level ``training_status`` dict, whose
    ``status`` the caller sets to ``'training'`` before starting the thread.
    If ``status`` changes while the loop runs, the simulation stops early and
    the status is left as it is.

    The ``train_mt564_model.py`` command line is built and logged but NOT
    executed. Progress is simulated in 101 steps, 0.5 s apart. An
    "Epoch N completed" line is logged as progress passes each of the
    configured epochs, N = 1..epochs, with the last at 100%.

    Args:
        config: Dict with optional ``model_name``, ``epochs``, ``batch_size``
            and ``learning_rate`` keys; ``None`` is treated as ``{}``.
            ``epochs`` must convert to a positive integer.

    Any exception marks the job ``'failed'`` and records its message.
    """
    try:
        config = config or {}
        training_status['logs'].append("Starting MT564 TinyLlama training...")

        # Fall back to the built-in sample dataset when nothing was uploaded yet.
        training_data_file = 'data/processed/mt564_training_data.json'
        if not os.path.exists(training_data_file):
            create_sample_training_data()

        # Validate up front so a bad value fails the job with a clear message
        # instead of producing a bogus command line and epoch log.
        num_epochs = int(config.get('epochs', 3))
        if num_epochs < 1:
            raise ValueError(f"epochs must be a positive integer, got {num_epochs}")

        cmd = [
            'python', 'train_mt564_model.py',
            # NOTE(review): tiny test model; the TinyLlama naming suggests
            # 'TinyLlama/TinyLlama-1.1B-Chat-v1.0' may be the intended default - confirm.
            '--model_name', str(config.get('model_name', 'sshleifer/tiny-gpt2')),
            '--training_data', training_data_file,
            # Must match the model directory the /api/mt564/query route checks for.
            '--output_dir', 'mt564_tinyllama_model',
            '--epochs', str(num_epochs),
            '--batch_size', str(config.get('batch_size', 1)),
            '--learning_rate', str(config.get('learning_rate', 0.0001)),
        ]

        training_status['logs'].append(f"Running command: {' '.join(cmd)}")

        # Simulate training progress (in real implementation, parse actual training logs)
        completed_epochs = 0
        for step in range(101):
            if training_status['status'] != 'training':
                break

            training_status['progress'] = step
            training_status['logs'].append(f"Training progress: {step}%")

            # Epoch k is done once progress reaches k/num_epochs of 100%.
            # The while-loop also copes with more than 100 epochs.
            reached = step * num_epochs // 100
            while completed_epochs < reached:
                completed_epochs += 1
                training_status['logs'].append(f"Epoch {completed_epochs} completed")

            time.sleep(0.5)  # Simulate training time

        if training_status['status'] == 'training':
            training_status['status'] = 'completed'
            training_status['progress'] = 100
            training_status['logs'].append("Training completed successfully!")

    except Exception as e:
        training_status['status'] = 'failed'
        training_status['error'] = str(e)
        training_status['logs'].append(f"Training failed: {str(e)}")
        logger.exception(f"Training error: {e}")

def create_sample_training_data():
    """Write a small built-in MT564 Q&A dataset for training.

    Creates ``data/processed`` if needed and (over)writes
    ``data/processed/mt564_training_data.json`` with four examples, each
    ``{"text": "Instruction: ...\\nResponse: ..."}``. The file is written as
    UTF-8 JSON with two-space indentation.

    The answers follow the standard MT564 layout. Sequences A (General
    Information) and B (Underlying Securities) are mandatory. Sequences
    C (Intermediate Securities), D (Corporate Action Details), E (Corporate
    Action Options) and F (Additional Information) are optional.
    """
    sample_data = [
        {
            "text": "Instruction: What is an MT564 message?\nResponse: An MT564 is a SWIFT message type used for Corporate Action Notification. It informs account holders about corporate actions affecting their securities, such as dividends, stock splits, mergers, and other corporate events."
        },
        {
            "text": "Instruction: What are the main sequences in MT564?\nResponse: The main sequences in MT564 are Sequence A (General Information), Sequence B (Underlying Securities), Sequence C (Intermediate Securities), Sequence D (Corporate Action Details), Sequence E (Corporate Action Options), and Sequence F (Additional Information)."
        },
        {
            "text": "Instruction: What is field 23G in MT564?\nResponse: Field 23G in MT564 is the Function of the Message field. It indicates the purpose of the message, such as NEWM (new message), CANC (cancellation), or REPL (replacement)."
        },
        {
            "text": "Instruction: How is MT564 structured?\nResponse: MT564 follows a structured format with mandatory and optional sequences. It starts with general information identifying the message and the event (Sequence A) and the underlying securities with their account information (Sequence B), followed by optional sequences for intermediate securities, corporate action details, corporate action options, and additional information."
        }
    ]

    os.makedirs('data/processed', exist_ok=True)
    with open('data/processed/mt564_training_data.json', 'w', encoding='utf-8') as f:
        json.dump(sample_data, f, ensure_ascii=False, indent=2)

def run_inference(query, model_path):
    """Answer ``query`` with a canned MT564 response (stand-in for a real model).

    No model is loaded; ``model_path`` is accepted for interface compatibility
    but not used. The lower-cased query is scanned for the keywords ``mt564``,
    ``corporate action``, ``swift`` and ``sequence``, in that priority order.
    The answer for the first keyword found is returned, or a generic help
    message if none match. If anything fails (e.g. ``query`` has no
    ``lower()``), the error is logged and an ``"Error processing query: ..."``
    string is returned instead of raising.
    """
    keyword_answers = (
        ("mt564", "MT564 is a SWIFT message type used for Corporate Action Notifications in financial messaging."),
        ("corporate action", "A corporate action is an event initiated by a company that affects its shareholders, such as dividends, stock splits, or mergers."),
        ("swift", "SWIFT (Society for Worldwide Interbank Financial Telecommunication) provides secure financial messaging services."),
        ("sequence", "MT564 messages are organized into sequences that group related fields together for better structure and readability."),
    )
    default_answer = (
        "I can help you with MT564 message format questions. "
        "Please ask about MT564 structure, fields, sequences, or corporate actions."
    )

    try:
        haystack = query.lower()
        return next(
            (answer for keyword, answer in keyword_answers if keyword in haystack),
            default_answer,
        )
    except Exception as exc:
        logger.error(f"Inference error: {exc}")
        return f"Error processing query: {str(exc)}"