Spaces:
Sleeping
Sleeping
File size: 5,572 Bytes
14832ec 10e4df1 2938ce7 6982352 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 7e2d457 0ff91c7 6982352 7e2d457 6982352 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
# Fetch the upstream Continuous Thought Machines repository and install its
# Python requirements. This is best-effort, as with the original os.system
# calls: failures are reported on stdout but never stop the app from starting.
import subprocess
import sys

_CTM_REPO_URL = "https://github.com/SakanaAI/continuous-thought-machines"
_CTM_REPO_DIR = "./continuous-thought-machines"


def _run_best_effort(cmd):
    """Run *cmd* (an argv list); report, but never raise on, failure."""
    try:
        result = subprocess.run(cmd, check=False)
    except OSError as exc:  # e.g. the executable is not installed
        print(f"Command {cmd!r} could not be started: {exc}")
        return
    if result.returncode != 0:
        print(f"Command {cmd!r} exited with status {result.returncode}")


# Skip the clone on restarts: git refuses to clone into an existing non-empty
# checkout, so re-running it would only print an error.
if not os.path.isdir(os.path.join(_CTM_REPO_DIR, ".git")):
    _run_best_effort(["git", "clone", _CTM_REPO_URL])
# Use the running interpreter's pip so packages are installed into the same
# environment that will import them (a bare `pip` on PATH may belong to a
# different Python installation).
_run_best_effort(
    [sys.executable, "-m", "pip", "install", "-r", f"{_CTM_REPO_DIR}/requirements.txt"]
)
from flask import Flask, render_template, request, send_file, jsonify
import torch
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from model import ContinuousThoughtMachine
from utils import make_gif, prepare_data, train
import time
import threading
app = Flask(__name__)

# --- Configuration -----------------------------------------------------------
UPLOAD_FOLDER = "static/output"  # generated prediction GIFs are written here
MODEL_PATH = "model.pth"  # checkpoint loaded at startup and saved after training
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shared training state: updated by the /train background job and returned
# by the /training_status endpoint.
TRAINING_IN_PROGRESS = False
TRAINING_STATUS = {"status": "idle", "progress": 0, "message": ""}

# Create the output directory up front so the first prediction can write to it.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Architecture hyper-parameters for the Continuous Thought Machine. A saved
# checkpoint must come from a model built with the same settings, because
# load_state_dict below rejects mismatched parameter names/shapes by default.
_CTM_KWARGS = dict(
    iterations=30,
    d_model=128,
    d_input=32,
    memory_length=15,
    heads=1,
    n_synch_out=16,
    n_synch_action=16,
    memory_hidden_dims=8,
    out_dims=10,
)
model = ContinuousThoughtMachine(**_CTM_KWARGS).to(DEVICE)

# Restore trained weights if a checkpoint is on disk. Otherwise keep the freshly
# constructed model and tell the operator how to create one.
if not os.path.exists(MODEL_PATH):
    print("Model weights not found. Use the training option to create a new model.")
else:
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
# Preprocessing for uploaded images: convert to one grayscale channel, resize
# to 28x28, then convert to a float tensor via ToTensor.
_PREPROCESS_STEPS = [
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_PREPROCESS_STEPS)
@app.route('/')
def index():
    """Serve the main page, passing along the current training status."""
    template_name = 'index.html'
    return render_template(
        template_name,
        training_status=TRAINING_STATUS,
    )
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and render the result page.

    Expects a multipart form with a ``file`` field containing an image. The
    image is preprocessed with ``transform`` into a batch of one and run
    through the model with tracking enabled. The prediction from the tick with
    the highest certainty is reported. A GIF of the model's internal activity
    is written under ``UPLOAD_FOLDER``, and its path is passed to the template.

    Returns:
        The rendered ``index.html``. On success it receives ``prediction``,
        ``certainty`` (formatted as a percentage) and ``gif_path``. If there is
        no checkpoint, the upload is missing or empty, or any step fails, it
        receives ``error`` instead.
    """
    import uuid  # local import: only needed for unique GIF filenames

    if not os.path.exists(MODEL_PATH):
        return render_template('index.html', error='No trained model available. Please train the model first.', training_status=TRAINING_STATUS)
    if 'file' not in request.files:
        return render_template('index.html', error='No file uploaded', training_status=TRAINING_STATUS)
    file = request.files['file']
    if file.filename == '':
        return render_template('index.html', error='No file selected', training_status=TRAINING_STATUS)
    # NOTE(review): this runs even while a /train job is mutating `model` in a
    # background thread; consider refusing predictions while training.
    try:
        # Decode and preprocess the upload. The tensor produced by `transform`
        # is independent of the PIL image, so the image can be closed as soon
        # as the tensor exists.
        with Image.open(file) as pil_image:
            image = transform(pil_image).unsqueeze(0).to(DEVICE)

        with torch.inference_mode():
            predictions, certainties, (_synch_out, _synch_action), \
                pre_activations_tracking, post_activations_tracking, attention = model(image, track=True)

        # Report the class predicted at the tick where the model is most
        # certain. The indexing assumes predictions is (batch, classes, ticks)
        # and certainties is (batch, 2, ticks), with row 1 holding the
        # certainty score -- TODO confirm against the model.
        where_most_certain = certainties[0, 1].argmax(-1).item()
        prediction = predictions[0, :, where_most_certain].argmax().item()
        certainty = certainties[0, 1, where_most_certain].item()

        # A random suffix keeps filenames unique when several requests arrive
        # within the same second. With a bare timestamp, they would overwrite
        # each other's GIFs.
        timestamp = int(time.time())
        gif_filename = os.path.join(UPLOAD_FOLDER, f'output_{timestamp}_{uuid.uuid4().hex[:8]}.gif')
        make_gif(
            predictions.detach().cpu().numpy(),
            certainties.detach().cpu().numpy(),
            np.array([prediction]),  # Dummy target for visualization
            pre_activations_tracking,
            post_activations_tracking,
            attention,
            image.detach().cpu().numpy(),
            gif_filename
        )
        return render_template('index.html',
                               prediction=prediction,
                               certainty=f"{certainty:.2%}",
                               gif_path=gif_filename,
                               training_status=TRAINING_STATUS)
    except Exception as e:
        # Request boundary: record the traceback server-side (previously it was
        # lost) and show the message to the user instead of returning a 500.
        app.logger.exception("Prediction failed")
        return render_template('index.html', error=str(e), training_status=TRAINING_STATUS)
@app.route('/train', methods=['POST'])
def train_model():
    """Validate training parameters and start training in a background thread.

    Form fields:
        iterations: number of training iterations, 100-5000 (default 1000).
        lr: learning rate, 1e-6 to 1e-2 inclusive (default 0.0001).

    Returns:
        JSON containing ``message`` when training was started. It contains
        ``error`` instead if training is already running or the parameters are
        invalid. Progress is exposed through ``/training_status``.
    """
    global TRAINING_IN_PROGRESS, TRAINING_STATUS
    # NOTE(review): this check-then-set is not atomic, so two simultaneous
    # requests could both start training; a threading.Lock would close that.
    if TRAINING_IN_PROGRESS:
        return jsonify({"error": "Training is already in progress."})
    try:
        iterations = int(request.form.get('iterations', 1000))
        lr = float(request.form.get('lr', 0.0001))
    except ValueError:
        return jsonify({"error": "Invalid input parameters."})
    if not 100 <= iterations <= 5000:
        return jsonify({"error": "Iterations must be between 100 and 5000."})
    # Use a chained range check rather than `lr < lo or lr > hi`. Every
    # comparison with NaN is False, so the old form accepted lr="nan".
    if not 1e-6 <= lr <= 1e-2:
        return jsonify({"error": "Learning rate must be between 1e-6 and 1e-2."})

    # Start training in a separate thread so the request returns immediately.
    TRAINING_IN_PROGRESS = True
    TRAINING_STATUS = {"status": "running", "progress": 0, "message": "Starting training..."}

    def train_and_save():
        """Train the shared model, save its weights, and publish the final status."""
        # `model` must be declared global. It is reassigned below, and without
        # the declaration Python would treat it as a local of this function,
        # making the earlier `model.train()` raise UnboundLocalError.
        global model, TRAINING_STATUS, TRAINING_IN_PROGRESS
        try:
            trainloader, testloader = prepare_data()
            model.train()
            model = train(model=model, trainloader=trainloader, testloader=testloader,
                          iterations=iterations, device=DEVICE, lr=lr, status=TRAINING_STATUS)
            torch.save(model.state_dict(), MODEL_PATH)
            TRAINING_STATUS = {"status": "completed", "progress": 100, "message": "Training completed and model saved."}
        except Exception as e:
            TRAINING_STATUS = {"status": "error", "progress": 0, "message": f"Training failed: {str(e)}"}
        finally:
            # Put the model back in inference mode even if training failed
            # part-way, so /predict never runs it in training mode.
            model.eval()
            TRAINING_IN_PROGRESS = False

    threading.Thread(target=train_and_save).start()
    return jsonify({"message": "Training started."})
@app.route('/training_status')
def get_training_status():
    """Report the current background-training status as a JSON object."""
    current_status = TRAINING_STATUS
    return jsonify(current_status)
if __name__ == '__main__':
    # Start the server on port 7860, listening on all network interfaces.
    app.run(port=7860, host='0.0.0.0')
|