import os

# Fetch the official CTM repository and install its dependencies at startup
# (a common pattern for hosted demo environments).
os.system("git clone https://github.com/SakanaAI/continuous-thought-machines")
os.system("pip install -r ./continuous-thought-machines/requirements.txt")

from flask import Flask, render_template, request, send_file, jsonify
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from model import ContinuousThoughtMachine
from utils import make_gif, prepare_data, train
import time
import threading

app = Flask(__name__)

# Configuration
UPLOAD_FOLDER = 'static/output'
MODEL_PATH = 'model.pth'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
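# TRAINING_STATUS is shared state: it is mutated by the background training
# thread and polled by the /training_status endpoint (this assumes a single
# Flask worker process, as with the default development server).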
TRAINING_IN_PROGRESS = False
TRAINING_STATUS = {"status": "idle", "progress": 0, "message": ""}

# Ensure upload folder exists
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Initialize model
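# A small MNIST-scale configuration. Parameter meanings follow the CTM paper /
# reference repo (exact semantics may differ between repo versions):
#   iterations          - number of internal "thought" ticks per forward pass
#   d_model / d_input   - width of the latent state / of the attention input
#   memory_length       - length of the pre-activation history each neuron-level model sees
#   n_synch_out/action  - neurons used to build the output / action synchronisation
#   out_dims            - 10 logits, one per MNIST digit class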
model = ContinuousThoughtMachine(
    iterations=30,
    d_model=128,
    d_input=32,
    memory_length=15,
    heads=1,
    n_synch_out=16,
    n_synch_action=16,
    memory_hidden_dims=8,
    out_dims=10,
).to(DEVICE)

# Load pre-trained weights if available
if os.path.exists(MODEL_PATH):
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
else:
    print("Model weights not found. Use the training option to create a new model.")

# Image preprocessing: convert uploads to MNIST-style 1x28x28 grayscale tensors
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])

@app.route('/')
def index():
    return render_template('index.html', training_status=TRAINING_STATUS)

@app.route('/predict', methods=['POST'])
def predict():
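    # Accept an uploaded image, run a single CTM forward pass, and render the
    # predicted digit together with a GIF of the model's internal dynamics.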
    if not os.path.exists(MODEL_PATH):
        return render_template('index.html', error='No trained model available. Please train the model first.', training_status=TRAINING_STATUS)

    if 'file' not in request.files:
        return render_template('index.html', error='No file uploaded', training_status=TRAINING_STATUS)

    file = request.files['file']
    if file.filename == '':
        return render_template('index.html', error='No file selected', training_status=TRAINING_STATUS)

    try:
        # Process the uploaded image
        image = Image.open(file)
        image = transform(image).unsqueeze(0).to(DEVICE)

        # Run inference
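        # With track=True the CTM returns per-tick logits and certainties plus
        # synchronisation, pre-/post-activation, and attention traces (output
        # layout as in the SakanaAI reference implementation).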
        with torch.inference_mode():
            predictions, certainties, (synch_out_tracking, synch_action_tracking), \
            pre_activations_tracking, post_activations_tracking, attention = model(image, track=True)

        # Get prediction at the most certain tick
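        # certainties holds one score per tick; channel 1 is the certainty value
        # (1 minus normalised prediction entropy in the reference implementation).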
        where_most_certain = certainties[0, 1].argmax(-1).item()
        prediction = predictions[0, :, where_most_certain].argmax().item()
        certainty = certainties[0, 1, where_most_certain].item()

        # Generate GIF
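        # make_gif (a utility from the CTM repo) animates the per-tick
        # predictions, certainties, and attention over the input image.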
        timestamp = int(time.time())
        gif_filename = os.path.join(UPLOAD_FOLDER, f'output_{timestamp}.gif')
        make_gif(
            predictions.detach().cpu().numpy(),
            certainties.detach().cpu().numpy(),
            np.array([prediction]),  # Dummy target for visualization
            pre_activations_tracking,
            post_activations_tracking,
            attention,
            image.detach().cpu().numpy(),
            gif_filename
        )

        return render_template('index.html',
                              prediction=prediction,
                              certainty=f"{certainty:.2%}",
                              gif_path=gif_filename,
                              training_status=TRAINING_STATUS)

    except Exception as e:
        return render_template('index.html', error=str(e), training_status=TRAINING_STATUS)

@app.route('/train', methods=['POST'])
def train_model():
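    # Launch training in a background thread so the HTTP request returns
    # immediately; progress is reported through the shared TRAINING_STATUS dict.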
    global TRAINING_IN_PROGRESS, TRAINING_STATUS

    if TRAINING_IN_PROGRESS:
        return jsonify({"error": "Training is already in progress."})

    try:
        iterations = int(request.form.get('iterations', 1000))
        lr = float(request.form.get('lr', 0.0001))

        if iterations < 100 or iterations > 5000:
            return jsonify({"error": "Iterations must be between 100 and 5000."})
        if lr < 1e-6 or lr > 1e-2:
            return jsonify({"error": "Learning rate must be between 1e-6 and 1e-2."})

        # Start training in a separate thread
        TRAINING_IN_PROGRESS = True
        TRAINING_STATUS = {"status": "running", "progress": 0, "message": "Starting training..."}

        def train_and_save():
            # Rebind the module-level model so /predict serves the newly trained
            # weights (without the global declaration, `model = train(...)` would
            # make `model` local and raise UnboundLocalError at model.train()).
            global model, TRAINING_STATUS, TRAINING_IN_PROGRESS
            try:
                trainloader, testloader = prepare_data()
                model.train()
                model = train(model=model, trainloader=trainloader, testloader=testloader,
                              iterations=iterations, device=DEVICE, lr=lr, status=TRAINING_STATUS)
                torch.save(model.state_dict(), MODEL_PATH)
                model.eval()
                TRAINING_STATUS = {"status": "completed", "progress": 100, "message": "Training completed and model saved."}
            except Exception as e:
                TRAINING_STATUS = {"status": "error", "progress": 0, "message": f"Training failed: {str(e)}"}
            finally:
                TRAINING_IN_PROGRESS = False

        threading.Thread(target=train_and_save).start()
        return jsonify({"message": "Training started."})

    except ValueError:
        return jsonify({"error": "Invalid input parameters."})

@app.route('/training_status')
def get_training_status():
    return jsonify(TRAINING_STATUS)

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)