import os os.system("git clone https://github.com/SakanaAI/continuous-thought-machines") os.system("pip install -r ./continuous-thought-machines/requirements.txt") from flask import Flask, render_template, request, send_file, jsonify import torch from torchvision import transforms from PIL import Image import os import numpy as np from model import ContinuousThoughtMachine from utils import make_gif, prepare_data, train import time import threading app = Flask(__name__) # Configuration UPLOAD_FOLDER = 'static/output' MODEL_PATH = 'model.pth' DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') TRAINING_IN_PROGRESS = False TRAINING_STATUS = {"status": "idle", "progress": 0, "message": ""} # Ensure upload folder exists os.makedirs(UPLOAD_FOLDER, exist_ok=True) # Initialize model model = ContinuousThoughtMachine( iterations=30, d_model=128, d_input=32, memory_length=15, heads=1, n_synch_out=16, n_synch_action=16, memory_hidden_dims=8, out_dims=10, ).to(DEVICE) # Load pre-trained weights if available if os.path.exists(MODEL_PATH): model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE)) model.eval() else: print("Model weights not found. Use the training option to create a new model.") # Image preprocessing transform = transforms.Compose([ transforms.Grayscale(num_output_channels=1), transforms.Resize((28, 28)), transforms.ToTensor(), ]) @app.route('/') def index(): return render_template('index.html', training_status=TRAINING_STATUS) @app.route('/predict', methods=['POST']) def predict(): if not os.path.exists(MODEL_PATH): return render_template('index.html', error='No trained model available. Please train the model first.', training_status=TRAINING_STATUS) if 'file' not in request.files: return render_template('index.html', error='No file uploaded', training_status=TRAINING_STATUS) file = request.files['file'] if file.filename == '': return render_template('index.html', error='No file selected', training_status=TRAINING_STATUS) try: # Process the uploaded image image = Image.open(file) image = transform(image).unsqueeze(0).to(DEVICE) # Run inference with torch.inference_mode(): predictions, certainties, (synch_out_tracking, synch_action_tracking), \ pre_activations_tracking, post_activations_tracking, attention = model(image, track=True) # Get prediction at the most certain tick where_most_certain = certainties[0, 1].argmax(-1).item() prediction = predictions[0, :, where_most_certain].argmax().item() certainty = certainties[0, 1, where_most_certain].item() # Generate GIF timestamp = int(time.time()) gif_filename = os.path.join(UPLOAD_FOLDER, f'output_{timestamp}.gif') make_gif( predictions.detach().cpu().numpy(), certainties.detach().cpu().numpy(), np.array([prediction]), # Dummy target for visualization pre_activations_tracking, post_activations_tracking, attention, image.detach().cpu().numpy(), gif_filename ) return render_template('index.html', prediction=prediction, certainty=f"{certainty:.2%}", gif_path=gif_filename, training_status=TRAINING_STATUS) except Exception as e: return render_template('index.html', error=str(e), training_status=TRAINING_STATUS) @app.route('/train', methods=['POST']) def train_model(): global TRAINING_IN_PROGRESS, TRAINING_STATUS if TRAINING_IN_PROGRESS: return jsonify({"error": "Training is already in progress."}) try: iterations = int(request.form.get('iterations', 1000)) lr = float(request.form.get('lr', 0.0001)) if iterations < 100 or iterations > 5000: return jsonify({"error": "Iterations must be between 100 and 5000."}) if lr < 
1e-6 or lr > 1e-2: return jsonify({"error": "Learning rate must be between 1e-6 and 1e-2."}) # Start training in a separate thread TRAINING_IN_PROGRESS = True TRAINING_STATUS = {"status": "running", "progress": 0, "message": "Starting training..."} def train_and_save(): global TRAINING_STATUS try: trainloader, testloader = prepare_data() model.train() model = train(model=model, trainloader=trainloader, testloader=testloader, iterations=iterations, device=DEVICE, lr=lr, status=TRAINING_STATUS) torch.save(model.state_dict(), MODEL_PATH) model.eval() TRAINING_STATUS = {"status": "completed", "progress": 100, "message": "Training completed and model saved."} except Exception as e: TRAINING_STATUS = {"status": "error", "progress": 0, "message": f"Training failed: {str(e)}"} finally: global TRAINING_IN_PROGRESS TRAINING_IN_PROGRESS = False threading.Thread(target=train_and_save).start() return jsonify({"message": "Training started."}) except ValueError as e: return jsonify({"error": "Invalid input parameters."}) @app.route('/training_status') def get_training_status(): return jsonify(TRAINING_STATUS) if __name__ == '__main__': app.run(host='0.0.0.0', port=7860)
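
# ---------------------------------------------------------------------------
# Example client usage (illustrative sketch, not part of the app itself).
# Assumes the server is running locally on port 7860, that a sample image
# `digit.png` exists, and that the `requests` package is installed; the
# endpoint names match the routes defined above.
#
#   import requests
#
#   # Kick off training with custom hyperparameters
#   requests.post("http://localhost:7860/train",
#                 data={"iterations": 500, "lr": 1e-4})
#
#   # Poll training progress
#   print(requests.get("http://localhost:7860/training_status").json())
#
#   # Classify an image once a trained model.pth exists
#   with open("digit.png", "rb") as f:
#       requests.post("http://localhost:7860/predict", files={"file": f})
# ---------------------------------------------------------------------------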