Spaces:
Running
Running
import os | |
import json | |
import logging | |
import subprocess | |
from datetime import datetime | |
from flask import request, jsonify, render_template | |
from werkzeug.utils import secure_filename | |
from threading import Thread | |
import time | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared, mutable training state. The start/status views and the background
# training function in this file read and update this single dict in place.
training_status = dict(
    status='idle',    # one of: idle, training, completed, failed
    progress=0,       # integer percentage written by the training loop
    logs=[],          # human-readable log lines, appended in order
    start_time=None,  # datetime recorded when a run is started
    error=None,       # message of the last failure, if any
)
def register_mt564_routes(app):
    """Register MT564 TinyLlama training routes.

    Defines the view functions of the MT564 training UI: the HTML page,
    document upload, training start, status polling and model query.

    NOTE(review): no ``@app.route`` decorators or ``app.add_url_rule`` calls
    are visible in this section, so as shown the views below are defined but
    never attached to ``app``. The URL rules were presumably lost from this
    copy of the file -- restore them from the original; they are deliberately
    not invented here.

    Args:
        app: The Flask application the views belong to.
    """

    def _json_body():
        """Return the request's JSON object, or {} when it is missing or invalid."""
        # silent=True makes get_json() return None instead of raising on a
        # missing/invalid body, so every view can fall back to defaults.
        payload = request.get_json(silent=True)
        return payload if isinstance(payload, dict) else {}

    def mt564_interface():
        """MT564 training interface"""
        return render_template('mt564.html')

    def upload_mt564_docs():
        """Upload MT564 documentation files and turn them into training data.

        Returns a JSON response with ``success`` plus either the number of
        stored files and generated examples, or an ``error`` message.
        """
        try:
            if 'files' not in request.files:
                return jsonify({'success': False, 'error': 'No files uploaded'})
            files = request.files.getlist('files')
            if not files or all(f.filename == '' for f in files):
                return jsonify({'success': False, 'error': 'No files selected'})

            # Ensure upload directory exists
            upload_dir = os.path.join('data', 'uploaded')
            os.makedirs(upload_dir, exist_ok=True)

            uploaded_files = []
            for file in files:
                if not (file and file.filename):
                    continue
                filename = secure_filename(file.filename)
                if not filename:
                    # secure_filename() may reduce a name to ''; saving to the
                    # bare upload directory would raise and abort the batch.
                    logger.warning(f"Skipping upload with unusable file name: {file.filename!r}")
                    continue
                filepath = os.path.join(upload_dir, filename)
                file.save(filepath)
                uploaded_files.append(filepath)

            if not uploaded_files:
                # Do not regenerate the training data from an empty batch.
                return jsonify({'success': False, 'error': 'No valid file names in upload'})

            # Process uploaded files to create training data
            processed_data = process_uploaded_files(uploaded_files)
            return jsonify({
                'success': True,
                'files_uploaded': len(uploaded_files),
                'training_examples': len(processed_data)
            })
        except Exception as e:
            logger.error(f"Upload error: {e}")
            return jsonify({'success': False, 'error': str(e)})

    def start_mt564_training():
        """Start MT564 model training in a background thread.

        Rejects the request while a run is in progress; otherwise resets the
        shared status and hands the request's JSON config to ``run_training``.
        """
        try:
            # Check first: a rejected or malformed request must never reach
            # the except branch below and mark a *running* job as failed.
            if training_status['status'] == 'training':
                return jsonify({'success': False, 'error': 'Training already in progress'})

            config = _json_body()

            # Reset training status
            training_status.update({
                'status': 'training',
                'progress': 0,
                'logs': [],
                'start_time': datetime.now(),
                'error': None
            })

            # Start training in background thread
            training_thread = Thread(target=run_training, args=(config,), daemon=True)
            training_thread.start()
            return jsonify({'success': True, 'message': 'Training started'})
        except Exception as e:
            logger.error(f"Training start error: {e}")
            training_status.update({
                'status': 'failed',
                'error': str(e)
            })
            return jsonify({'success': False, 'error': str(e)})

    def get_training_status():
        """Get current training status"""
        return jsonify(training_status)

    def query_mt564_model():
        """Query the trained MT564 model.

        Expects a JSON body with a non-empty string under ``query`` and
        returns the answer produced by ``run_inference``.
        """
        try:
            query = _json_body().get('query', '')
            if not isinstance(query, str) or not query.strip():
                return jsonify({'success': False, 'error': 'Empty query'})
            query = query.strip()

            # Check if trained model exists
            model_path = 'mt564_tinyllama_model'
            if not os.path.exists(model_path):
                return jsonify({
                    'success': False,
                    'error': 'No trained model found. Please train a model first.'
                })

            # Run inference
            response = run_inference(query, model_path)
            return jsonify({
                'success': True,
                'query': query,
                'response': response
            })
        except Exception as e:
            logger.error(f"Query error: {e}")
            return jsonify({'success': False, 'error': str(e)})
def process_uploaded_files(file_paths):
    """Process uploaded files into training data.

    ``.json`` files are parsed and passed to ``create_mt564_examples``;
    ``.txt`` files are read and passed to ``create_text_examples``. PDFs and
    any other file type are logged and skipped. A file that fails to process
    is logged and skipped so the rest of the batch is still handled.

    The combined examples are written to
    ``data/processed/mt564_training_data.json`` (replacing any previous file).

    Args:
        file_paths: Iterable of paths to the uploaded files.

    Returns:
        The list of training examples generated from the batch.
    """
    training_data = []
    for filepath in file_paths:
        # Match extensions case-insensitively so e.g. 'SPEC.TXT' is processed.
        lowered = filepath.lower()
        try:
            if lowered.endswith('.json'):
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # Convert to instruction-response pairs
                training_data.extend(create_mt564_examples(data))
            elif lowered.endswith('.txt'):
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Create examples from text content
                training_data.extend(create_text_examples(content))
            elif lowered.endswith('.pdf'):
                # For PDF processing, we'd need additional libraries
                logger.warning(f"PDF processing not implemented for {filepath}")
            else:
                # Make the skip visible instead of dropping the file silently.
                logger.warning(f"Unsupported file type skipped: {filepath}")
        except Exception as e:
            # Best effort: one bad file must not discard the whole batch.
            logger.error(f"Error processing {filepath}: {e}")

    # Save processed training data
    os.makedirs('data/processed', exist_ok=True)
    output_file = 'data/processed/mt564_training_data.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, ensure_ascii=False, indent=2)
    return training_data
def _dict_entries(value):
    """Return the dict items of a list/tuple; anything else yields an empty list."""
    if not isinstance(value, (list, tuple)):
        return []
    return [entry for entry in value if isinstance(entry, dict)]


def create_mt564_examples(data):
    """Create training examples from MT564 specification data.

    Args:
        data: Parsed JSON. Either one specification object (dict) or a list
            of such objects. A dict may carry ``message_type``/``description``,
            a ``fields`` list of ``{'tag', 'description'}`` dicts and a
            ``sequences`` list of ``{'name', 'description'}`` dicts.

    Returns:
        A list of ``{"text": ...}`` instruction/response examples. Input of
        any other shape, and malformed entries inside ``fields`` or
        ``sequences``, contribute nothing instead of raising.
    """
    # A JSON document may bundle several specification objects in one list.
    if isinstance(data, list):
        examples = []
        for item in data:
            examples.extend(create_mt564_examples(item))
        return examples

    if not isinstance(data, dict):
        return []

    examples = []

    # Message structure examples
    if data.get('message_type') == 'MT564':
        examples.append({
            "text": f"Instruction: What is the MT564 message type used for?\nResponse: The MT564 message type is used for {data.get('description', 'Corporate Action Notification messages in SWIFT financial messaging')}."
        })

    # Field definitions (non-dict entries are skipped rather than crashing)
    for field in _dict_entries(data.get('fields')):
        examples.append({
            "text": f"Instruction: What is field {field.get('tag', '')} in MT564?\nResponse: Field {field.get('tag', '')} is {field.get('description', 'a field in MT564 message')}."
        })

    # Sequence information
    for sequence in _dict_entries(data.get('sequences')):
        examples.append({
            "text": f"Instruction: Describe sequence {sequence.get('name', '')} in MT564.\nResponse: Sequence {sequence.get('name', '')} {sequence.get('description', 'is part of the MT564 message structure')}."
        })

    return examples
def create_text_examples(content):
    """Build one "explain this concept" example per substantial paragraph.

    The text is split on blank lines (``'\\n\\n'``); each piece is stripped
    and kept only when more than 50 characters remain.

    Args:
        content: The full text of an uploaded document.

    Returns:
        A list of ``{"text": ...}`` training examples, in document order.
    """
    paragraphs = (piece.strip() for piece in content.split('\n\n'))
    return [
        {"text": f"Instruction: Explain this MT564 concept.\nResponse: {paragraph}"}
        for paragraph in paragraphs
        if len(paragraph) > 50  # shorter pieces are not meaningful chunks
    ]
def run_training(config):
    """Run the training process (currently a simulation).

    Makes sure a training data file exists, builds the command line for
    ``train_mt564_model.py`` and records it in the status log, then advances
    ``training_status['progress']`` from 0 to 100 in half-second steps. The
    command is only logged here, not executed. The loop stops early when
    another party changes ``training_status['status']`` away from 'training'.

    Args:
        config: Dict of optional overrides (``model_name``, ``epochs``,
            ``batch_size``, ``learning_rate``). Anything that is not a dict
            (e.g. None from a request without a JSON body) means "use defaults".

    Returns:
        None. Progress, log lines and any failure are written to the
        module-level ``training_status``.
    """
    try:
        # Fall back to defaults instead of failing on config.get below.
        if not isinstance(config, dict):
            config = {}

        training_status['logs'].append("Starting MT564 TinyLlama training...")

        # Check if training data exists
        training_data_file = 'data/processed/mt564_training_data.json'
        if not os.path.exists(training_data_file):
            # Create sample training data if none exists
            create_sample_training_data()

        # Prepare training command.
        # Earlier defaults, kept for reference:
        #   --model_name TinyLlama/TinyLlama-1.1B-Chat-v1.0
        #   --output_dir mt564_tinyllama_model
        # NOTE(review): the query view looks for 'mt564_tinyllama_model', but
        # the output dir below is 'sshleifer/tiny-gpt2' -- confirm which path
        # is intended before this command is actually executed.
        cmd = [
            'python', 'train_mt564_model.py',
            '--model_name', str(config.get('model_name', 'sshleifer/tiny-gpt2')),
            '--training_data', training_data_file,
            '--output_dir', 'sshleifer/tiny-gpt2',
            '--epochs', str(config.get('epochs', 3)),
            '--batch_size', str(config.get('batch_size', 1)),
            '--learning_rate', str(config.get('learning_rate', 0.0001))
        ]
        training_status['logs'].append(f"Running command: {' '.join(cmd)}")

        # Simulate training progress (in real implementation, parse actual training logs)
        for i in range(101):
            if training_status['status'] != 'training':
                break
            training_status['progress'] = i
            training_status['logs'].append(f"Training progress: {i}%")
            # Skip i == 0: nothing has been trained yet, so no epoch is done.
            if i and i % 20 == 0:
                training_status['logs'].append(f"Epoch {i//20} completed")
            time.sleep(0.5)  # Simulate training time

        if training_status['status'] == 'training':
            training_status['status'] = 'completed'
            training_status['progress'] = 100
            training_status['logs'].append("Training completed successfully!")
    except Exception as e:
        training_status['status'] = 'failed'
        training_status['error'] = str(e)
        training_status['logs'].append(f"Training failed: {str(e)}")
        logger.error(f"Training error: {e}")
def create_sample_training_data():
    """Write a small built-in MT564 training set to disk.

    Creates ``data/processed`` if needed and stores four instruction/response
    examples as JSON in ``data/processed/mt564_training_data.json``,
    overwriting any existing file.
    """
    sample_texts = (
        "Instruction: What is an MT564 message?\nResponse: An MT564 is a SWIFT message type used for Corporate Action Notification. It informs account holders about corporate actions affecting their securities, such as dividends, stock splits, mergers, and other corporate events.",
        "Instruction: What are the main sequences in MT564?\nResponse: The main sequences in MT564 include Sequence A (General Information), Sequence B (Corporate Action Details), Sequence C (Account Information), and Sequence D (Securities Details).",
        "Instruction: What is field 23G in MT564?\nResponse: Field 23G in MT564 is the Function of the Message field. It indicates the purpose of the message, such as NEWM (new message), CANC (cancellation), or REPL (replacement).",
        "Instruction: How is MT564 structured?\nResponse: MT564 follows a structured format with mandatory and optional sequences. It starts with basic message identification, followed by corporate action details, account information, and securities details.",
    )
    # Each example is a single-key record, the same shape the upload path emits.
    records = [{"text": sample} for sample in sample_texts]

    os.makedirs('data/processed', exist_ok=True)
    with open('data/processed/mt564_training_data.json', 'w', encoding='utf-8') as out:
        json.dump(records, out, ensure_ascii=False, indent=2)
def run_inference(query, model_path):
    """Answer a query with a canned response (simulated inference).

    The lower-cased query is checked against a fixed list of keywords, in
    order; the answer of the first keyword found is returned, otherwise a
    generic help text. No model is loaded here.

    Args:
        query: The user's question.
        model_path: Location of the trained model; not used by this
            simulated implementation.

    Returns:
        The answer text, or an "Error processing query: ..." string if the
        lookup itself raises.
    """
    # Simulate model response (in real implementation, load and query the actual model)
    canned_answers = (
        ("mt564", "MT564 is a SWIFT message type used for Corporate Action Notifications in financial messaging."),
        ("corporate action", "A corporate action is an event initiated by a company that affects its shareholders, such as dividends, stock splits, or mergers."),
        ("swift", "SWIFT (Society for Worldwide Interbank Financial Telecommunication) provides secure financial messaging services."),
        ("sequence", "MT564 messages are organized into sequences that group related fields together for better structure and readability."),
    )
    fallback = "I can help you with MT564 message format questions. Please ask about MT564 structure, fields, sequences, or corporate actions."
    try:
        needle = query.lower()
        matches = (answer for keyword, answer in canned_answers if keyword in needle)
        return next(matches, fallback)
    except Exception as e:
        logger.error(f"Inference error: {e}")
        return f"Error processing query: {str(e)}"