<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Compare Models - MNIST</title>
  <!-- Project stylesheet, resolved by the server-side template's url_for helper. -->
  <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
  <!-- Open the Google Fonts connections early; the font files are served from fonts.gstatic.com with CORS. -->
  <link rel="preconnect" href="https://fonts.googleapis.com">
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Roboto+Mono&display=swap" rel="stylesheet">
  <!-- Plotly draws the loss/accuracy charts: the page script calls Plotly.newPlot and
       Plotly.extendTraces, so the library must be loaded. `defer` executes it after the
       document is parsed and before DOMContentLoaded, which is when the charts are created. -->
  <script src="https://cdn.plot.ly/plotly-2.35.2.min.js" charset="utf-8" defer></script>
</head>
<body>
<main class="container">
  <h1>Compare Models</h1>

  <div class="models-grid">
    <!-- Model A configuration: every control id uses the "model1_" prefix read by startComparison(). -->
    <section class="model-config" aria-labelledby="model-a-heading">
      <h3 id="model-a-heading">Model A</h3>
      <div class="network-config">
        <h4>Network Architecture</h4>
        <div class="block-config">
          <div class="block">
            <label for="model1_block1">Block-1:</label>
            <select id="model1_block1" name="block1" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32" selected>32</option>
              <option value="64">64</option>
              <option value="128">128</option>
            </select>
          </div>
          <div class="block">
            <label for="model1_block2">Block-2:</label>
            <select id="model1_block2" name="block2" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32">32</option>
              <option value="64" selected>64</option>
              <option value="128">128</option>
            </select>
          </div>
          <div class="block">
            <label for="model1_block3">Block-3:</label>
            <select id="model1_block3" name="block3" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32">32</option>
              <option value="64">64</option>
              <option value="128" selected>128</option>
            </select>
          </div>
        </div>
      </div>
      <div class="training-config">
        <div class="config-item">
          <label for="model1_optimizer">Optimizer:</label>
          <select id="model1_optimizer" name="optimizer">
            <option value="SGD" selected>SGD</option>
            <option value="Adam">Adam</option>
          </select>
        </div>
        <div class="config-item">
          <label for="model1_batch_size">Batch Size:</label>
          <select id="model1_batch_size" name="batch_size">
            <option value="32">32</option>
            <option value="64" selected>64</option>
            <option value="128">128</option>
          </select>
        </div>
        <div class="config-item">
          <label for="model1_epochs">Epochs:</label>
          <select id="model1_epochs" name="epochs">
            <option value="1">1</option>
            <option value="2">2</option>
            <option value="3">3</option>
          </select>
        </div>
      </div>
    </section>

    <!-- Model B configuration: same controls as Model A, with the "model2_" id prefix. -->
    <section class="model-config" aria-labelledby="model-b-heading">
      <h3 id="model-b-heading">Model B</h3>
      <div class="network-config">
        <h4>Network Architecture</h4>
        <div class="block-config">
          <div class="block">
            <label for="model2_block1">Block-1:</label>
            <select id="model2_block1" name="block1" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32" selected>32</option>
              <option value="64">64</option>
              <option value="128">128</option>
            </select>
          </div>
          <div class="block">
            <label for="model2_block2">Block-2:</label>
            <select id="model2_block2" name="block2" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32">32</option>
              <option value="64" selected>64</option>
              <option value="128">128</option>
            </select>
          </div>
          <div class="block">
            <label for="model2_block3">Block-3:</label>
            <select id="model2_block3" name="block3" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32">32</option>
              <option value="64">64</option>
              <option value="128" selected>128</option>
            </select>
          </div>
        </div>
      </div>
      <div class="training-config">
        <div class="config-item">
          <label for="model2_optimizer">Optimizer:</label>
          <select id="model2_optimizer" name="optimizer">
            <option value="SGD" selected>SGD</option>
            <option value="Adam">Adam</option>
          </select>
        </div>
        <div class="config-item">
          <label for="model2_batch_size">Batch Size:</label>
          <select id="model2_batch_size" name="batch_size">
            <option value="32">32</option>
            <option value="64" selected>64</option>
            <option value="128">128</option>
          </select>
        </div>
        <div class="config-item">
          <label for="model2_epochs">Epochs:</label>
          <select id="model2_epochs" name="epochs">
            <option value="1">1</option>
            <option value="2">2</option>
            <option value="3">3</option>
          </select>
        </div>
      </div>
    </section>
  </div>

  <!-- Training controls; startComparison() and stopComparison() are defined in the page script. -->
  <div class="controls">
    <button type="button" id="startComparison" onclick="startComparison()">Start Comparison</button>
    <button type="button" id="stopComparison" onclick="stopComparison()" disabled>Stop Comparison</button>
  </div>

  <!-- Loss and accuracy charts; Plotly draws into these containers. -->
  <div class="charts-container">
    <div id="lossChart"></div>
    <div id="accuracyChart"></div>
  </div>

  <!-- Progress / completion text written by the page script. -->
  <div class="training-status">
    <p id="training-progress"></p>
  </div>

  <!-- Link to the inference page; stays hidden until a comparison run reports completion. -->
  <div class="inference-controls" style="display: none;">
    <button type="button" id="goToInference" onclick="window.location.href='/inference'" class="inference-button">
      Try Model Inference
    </button>
  </div>
</main>
<style>
  /* Layout for the model-comparison page; supplements the linked static stylesheet. */
  .container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
  }

  /* Model A and Model B configuration panels, side by side. */
  .models-grid {
    display: grid;
    grid-template-columns: 1fr 1fr;
    gap: 20px;
    margin-bottom: 20px;
  }

  .model-config {
    padding: 20px;
    border: 1px solid #ddd;
    border-radius: 5px;
    margin-bottom: 20px;
  }

  .network-config {
    margin-bottom: 20px;
  }

  .network-config h4 {
    margin: 0 0 15px 0;
    font-size: 1.1em;
  }

  /* The three Block-N selects share one row. */
  .block-config {
    display: flex;
    justify-content: space-between;
    gap: 20px;
  }

  .block {
    flex: 1;
  }

  /* Optimizer / batch size / epochs selects share one row. */
  .training-config {
    display: flex;
    gap: 20px;
  }

  .config-item {
    flex: 1;
  }

  .block label,
  .config-item label {
    display: block;
    margin-bottom: 5px;
    font-weight: bold;
  }

  select {
    width: 100%;
    padding: 8px;
    border: 1px solid #ddd;
    border-radius: 4px;
  }

  .controls {
    margin: 20px 0;
    text-align: center;
  }

  button {
    padding: 10px 20px;
    /* Symmetric horizontal margin so a centred row of buttons is actually centred
       (a right-only margin shifted the group left). */
    margin: 0 5px;
    border: none;
    border-radius: 4px;
    background-color: #007bff;
    color: white;
    cursor: pointer;
  }

  button:disabled {
    background-color: #ccc;
    cursor: not-allowed;
  }

  .charts-container {
    display: flex;
    flex-direction: column;
    gap: 20px;
    margin-top: 20px;
  }

  /* Fixed chart height; Plotly fills the container width. */
  #lossChart,
  #accuracyChart {
    height: 400px;
    width: 100%;
  }

  h4 {
    margin: 0 0 10px 0;
  }

  .training-status {
    text-align: center;
    margin: 20px 0;
    font-weight: bold;
  }

  .inference-controls {
    margin: 20px 0;
    text-align: center;
  }

  .inference-button {
    background-color: #28a745;
    padding: 12px 24px;
    font-size: 1.1em;
    transition: background-color 0.3s;
  }

  .inference-button:hover {
    background-color: #218838;
  }

  /* Narrow viewports: stack the model panels and their select rows instead of
     squeezing two panels (and three selects per row) side by side. Declared last
     so it overrides the base rules of equal specificity above. */
  @media (max-width: 700px) {
    .models-grid {
      grid-template-columns: 1fr;
    }

    .block-config,
    .training-config {
      flex-direction: column;
      gap: 10px;
    }
  }
</style>
<script>
// Comparison WebSocket shared by startComparison() and stopComparison();
// undefined until the first run is started.
let ws;

// Create the two (initially empty) comparison charts once the DOM is ready.
// In both charts trace 0 is Model A and trace 1 is Model B; the message handler in
// startComparison() appends points to the trace with that index.
document.addEventListener('DOMContentLoaded', function () {
  if (typeof Plotly === 'undefined') {
    console.error('Plotly is not loaded; comparison charts cannot be drawn.');
    return;
  }

  // One empty trace per model, named e.g. "Model A Training Loss".
  function buildTraces(metric) {
    return ['A', 'B'].map(function (model) {
      return {
        x: [],
        y: [],
        name: `Model ${model} Training ${metric}`,
        type: 'scatter'
      };
    });
  }

  // Layout shared by both charts: x-axis is the training iteration.
  // Titles use the {text: ...} object form, which current Plotly versions require.
  function buildLayout(title, yaxis) {
    return {
      title: { text: title },
      xaxis: {
        title: { text: 'Iterations' },
        rangemode: 'nonnegative'
      },
      yaxis: yaxis
    };
  }

  // Redraw the charts when the window (and so the 100%-wide containers) resizes.
  const config = { responsive: true };

  Plotly.newPlot(
    'lossChart',
    buildTraces('Loss'),
    buildLayout('Training Loss Comparison', {
      title: { text: 'Loss' },
      rangemode: 'nonnegative'
    }),
    config
  );

  // Fixed 0-100 y-range, matching the "Accuracy (%)" axis label.
  Plotly.newPlot(
    'accuracyChart',
    buildTraces('Accuracy'),
    buildLayout('Training Accuracy Comparison', {
      title: { text: 'Accuracy (%)' },
      range: [0, 100]
    }),
    config
  );
});
// Read one model's settings from the selects whose ids start with `prefix`
// (e.g. "model1" -> #model1_block1 ... #model1_epochs). Key order matches the
// payload the server has always received.
function readModelConfig(prefix) {
  function intValue(field) {
    return parseInt(document.getElementById(`${prefix}_${field}`).value, 10);
  }
  return {
    block1: intValue('block1'),
    block2: intValue('block2'),
    block3: intValue('block3'),
    optimizer: document.getElementById(`${prefix}_optimizer`).value,
    batch_size: intValue('batch_size'),
    epochs: intValue('epochs')
  };
}

// Empty both traces of an already-drawn Plotly chart, so that a new run does not
// continue the previous run's lines. No-op if Plotly or the chart is missing.
function clearComparisonChart(chartId) {
  const chart = document.getElementById(chartId);
  if (typeof Plotly === 'undefined' || !chart || !chart.data) {
    return;
  }
  Plotly.restyle(chart, { x: [[], []], y: [[], []] }, [0, 1]);
}

// Start button handler: reads both model configurations and opens the /ws/compare
// WebSocket. It sends a 'start_training' request once connected, then streams
// per-iteration metrics into the charts and progress text until the server reports
// 'complete' or 'error', or the socket closes.
function startComparison() {
  const startButton = document.getElementById('startComparison');
  const stopButton = document.getElementById('stopComparison');
  const inferenceControls = document.querySelector('.inference-controls');
  const progressText = document.getElementById('training-progress');

  // Start is enabled only while idle; Stop only while a run is active.
  function setRunning(running) {
    startButton.disabled = running;
    stopButton.disabled = !running;
  }

  setRunning(true);

  const model1Config = readModelConfig('model1');
  const model2Config = readModelConfig('model2');

  // Detach and close any socket left from an earlier run (e.g. one still open after
  // 'complete' or 'error'). Otherwise its late close/error events would re-enable
  // Start while this run is in progress.
  if (ws) {
    ws.onopen = null;
    ws.onmessage = null;
    ws.onerror = null;
    ws.onclose = null;
    ws.close();
  }

  // Reset the UI left over from a previous run.
  clearComparisonChart('lossChart');
  clearComparisonChart('accuracyChart');
  if (inferenceControls) {
    inferenceControls.style.display = 'none';
  }
  if (progressText) {
    progressText.textContent = '';
  }

  // Pages served over HTTPS must use wss://; browsers block insecure ws:// there.
  const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
  let socket;
  try {
    socket = new WebSocket(`${scheme}://${window.location.host}/ws/compare`);
  } catch (err) {
    console.error('Could not open WebSocket:', err);
    setRunning(false);
    return;
  }
  ws = socket;

  socket.onopen = function () {
    console.log('WebSocket connection established');
    // Only send the message after connection is established
    const message = {
      action: 'start_training',
      parameters: {
        model_params: {
          model_a: model1Config,
          model_b: model2Config
        },
        dataset_params: {
          // NOTE(review): the shared dataset batch size comes from Model A only;
          // Model B's batch_size is sent inside model_b. Confirm that the server honours it.
          batch_size: model1Config.batch_size,
          shuffle: true
        }
      }
    };
    console.log('Sending message:', message);
    socket.send(JSON.stringify(message));
  };

  socket.onmessage = function (event) {
    console.log('Received message:', event.data);
    let data;
    try {
      data = JSON.parse(event.data);
    } catch (err) {
      console.error('Ignoring malformed message:', err);
      return;
    }

    if (data.status === 'training') {
      const modelIndex = data.model === 'A' ? 0 : 1;
      const metrics = data.metrics;
      const iteration = metrics.iteration;
      console.log(`Updating charts for model ${data.model} at iteration ${iteration}`);

      // Append this iteration's point to the model's trace in each chart.
      Plotly.extendTraces('lossChart', {
        x: [[iteration]],
        y: [[metrics.loss]]
      }, [modelIndex]);
      Plotly.extendTraces('accuracyChart', {
        x: [[iteration]],
        y: [[metrics.accuracy]]
      }, [modelIndex]);

      if (progressText) {
        const total = metrics.total_iterations;
        // Show 0.0% instead of NaN%/Infinity% when no positive total is reported.
        const progress = total > 0 ? (iteration / total * 100).toFixed(1) : (0).toFixed(1);
        progressText.textContent =
          `Training Model ${data.model} - ` +
          `Epoch ${data.epoch + 1} - ` +
          `Iteration ${iteration}/${total} ` +
          `(${progress}%) - ` +
          `Batch Size: ${data.batch_size}`;
      }
    } else if (data.status === 'complete') {
      setRunning(false);
      if (progressText) {
        progressText.textContent = 'Training Complete!';
      }
      // Show the inference button
      if (inferenceControls) {
        inferenceControls.style.display = 'block';
      }
    } else if (data.status === 'error') {
      console.error('Training error:', data.message);
      alert(`Training error: ${data.message}`);
      setRunning(false);
    }
  };

  socket.onerror = function (error) {
    console.error('WebSocket error:', error);
    setRunning(false);
  };

  socket.onclose = function (event) {
    console.log('WebSocket connection closed:', event);
    setRunning(false);
  };
}
/**
 * Stop button handler: closes the comparison socket, if one has been opened,
 * and immediately returns the Start/Stop buttons to their idle state.
 * (The socket's own onclose handler resets the same buttons when the close completes.)
 */
function stopComparison() {
  const startButton = document.getElementById('startComparison');
  const stopButton = document.getElementById('stopComparison');

  if (ws) {
    ws.close();
  }

  stopButton.disabled = true;
  startButton.disabled = false;
}
</script>
</body>
</html>