MnistStudio / templates /train_compare.html
Shilpaj's picture
Refactor: css file address
244431c
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Compare Models - MNIST</title>
  <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
  <!-- Open the font connections early; fonts.gstatic.com serves the font files via CORS. -->
  <link rel="preconnect" href="https://fonts.googleapis.com">
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Roboto+Mono&display=swap" rel="stylesheet">
  <!--
    Plotly.js draws the loss/accuracy charts: the page script calls Plotly.newPlot
    and Plotly.extendTraces, and nothing else on this page loads the library.
    Pinned to a 2.x release because the chart layouts pass `title` as a plain
    string, a shorthand Plotly.js 3 removed. `defer` guarantees it has executed
    before DOMContentLoaded, which is when the charts are created.
  -->
  <script src="https://cdn.plot.ly/plotly-2.27.0.min.js" defer></script>
</head>
<body>
<main class="container">
  <h1>Compare Models</h1>
  <div class="models-grid">
    <!-- Model A configuration: control ids use the "model1_" prefix read by the page script -->
    <div class="model-config">
      <h3>Model A</h3>
      <div class="network-config">
        <h4>Network Architecture</h4>
        <div class="block-config">
          <div class="block">
            <label for="model1_block1">Block-1:</label>
            <select id="model1_block1" name="block1" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32" selected>32</option>
              <option value="64">64</option>
              <option value="128">128</option>
            </select>
          </div>
          <div class="block">
            <label for="model1_block2">Block-2:</label>
            <select id="model1_block2" name="block2" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32">32</option>
              <option value="64" selected>64</option>
              <option value="128">128</option>
            </select>
          </div>
          <div class="block">
            <label for="model1_block3">Block-3:</label>
            <select id="model1_block3" name="block3" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32">32</option>
              <option value="64">64</option>
              <option value="128" selected>128</option>
            </select>
          </div>
        </div>
      </div>
      <div class="training-config">
        <div class="config-item">
          <label for="model1_optimizer">Optimizer:</label>
          <select id="model1_optimizer" name="optimizer">
            <option value="SGD" selected>SGD</option>
            <option value="Adam">Adam</option>
          </select>
        </div>
        <div class="config-item">
          <label for="model1_batch_size">Batch Size:</label>
          <select id="model1_batch_size" name="batch_size">
            <option value="32">32</option>
            <option value="64" selected>64</option>
            <option value="128">128</option>
          </select>
        </div>
        <div class="config-item">
          <label for="model1_epochs">Epochs:</label>
          <select id="model1_epochs" name="epochs">
            <option value="1">1</option>
            <option value="2">2</option>
            <option value="3">3</option>
          </select>
        </div>
      </div>
    </div>
    <!-- Model B configuration: control ids use the "model2_" prefix read by the page script -->
    <div class="model-config">
      <h3>Model B</h3>
      <div class="network-config">
        <h4>Network Architecture</h4>
        <div class="block-config">
          <div class="block">
            <label for="model2_block1">Block-1:</label>
            <select id="model2_block1" name="block1" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32" selected>32</option>
              <option value="64">64</option>
              <option value="128">128</option>
            </select>
          </div>
          <div class="block">
            <label for="model2_block2">Block-2:</label>
            <select id="model2_block2" name="block2" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32">32</option>
              <option value="64" selected>64</option>
              <option value="128">128</option>
            </select>
          </div>
          <div class="block">
            <label for="model2_block3">Block-3:</label>
            <select id="model2_block3" name="block3" class="form-select">
              <option value="8">8</option>
              <option value="16">16</option>
              <option value="32">32</option>
              <option value="64">64</option>
              <option value="128" selected>128</option>
            </select>
          </div>
        </div>
      </div>
      <div class="training-config">
        <div class="config-item">
          <label for="model2_optimizer">Optimizer:</label>
          <select id="model2_optimizer" name="optimizer">
            <option value="SGD" selected>SGD</option>
            <option value="Adam">Adam</option>
          </select>
        </div>
        <div class="config-item">
          <label for="model2_batch_size">Batch Size:</label>
          <select id="model2_batch_size" name="batch_size">
            <option value="32">32</option>
            <option value="64" selected>64</option>
            <option value="128">128</option>
          </select>
        </div>
        <div class="config-item">
          <label for="model2_epochs">Epochs:</label>
          <select id="model2_epochs" name="epochs">
            <option value="1">1</option>
            <option value="2">2</option>
            <option value="3">3</option>
          </select>
        </div>
      </div>
    </div>
  </div>
  <!-- Training controls -->
  <div class="controls">
    <button type="button" id="startComparison" onclick="startComparison()">Start Comparison</button>
    <button type="button" id="stopComparison" onclick="stopComparison()" disabled>Stop Comparison</button>
  </div>
  <!-- Training progress charts, drawn by Plotly from the page script -->
  <div class="charts-container">
    <div id="lossChart"></div>
    <div id="accuracyChart"></div>
  </div>
  <!-- Training status; role="status" makes progress updates a polite live region for screen readers -->
  <div class="training-status">
    <p id="training-progress" role="status"></p>
  </div>
  <!-- Shortcut to the inference page; shown by the page script once training reports "complete" -->
  <div class="inference-controls" style="display: none;">
    <button type="button" id="goToInference" onclick="window.location.href='/inference'" class="inference-button">
      Try Model Inference
    </button>
  </div>
</main>
<style>
  /* Centred page column. */
  .container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
  }

  /* Model A / Model B configuration cards, side by side. */
  .models-grid {
    display: grid;
    grid-template-columns: 1fr 1fr;
    gap: 20px;
    margin-bottom: 20px;
  }

  .model-config {
    padding: 20px;
    border: 1px solid #ddd;
    border-radius: 5px;
    margin-bottom: 20px;
  }

  .network-config {
    margin-bottom: 20px;
  }

  .network-config h4 {
    margin: 0 0 15px 0;
    font-size: 1.1em;
  }

  /* The three channel-width selectors sit in one row. */
  .block-config {
    display: flex;
    justify-content: space-between;
    gap: 20px;
  }

  .block {
    flex: 1;
  }

  .block label {
    display: block;
    margin-bottom: 5px;
    font-weight: bold;
  }

  /* Optimizer / batch size / epochs row. */
  .training-config {
    display: flex;
    gap: 20px;
  }

  .config-item {
    flex: 1;
  }

  .config-item label {
    display: block;
    margin-bottom: 5px;
    font-weight: bold;
  }

  select {
    width: 100%;
    padding: 8px;
    border: 1px solid #ddd;
    border-radius: 4px;
  }

  .controls {
    margin: 20px 0;
    text-align: center;
  }

  button {
    padding: 10px 20px;
    margin-right: 10px;
    border: none;
    border-radius: 4px;
    background-color: #007bff;
    color: white;
    cursor: pointer;
  }

  /* The spacing margin only belongs between buttons; a trailing margin
     on the last one pushes a centred button group off-centre. */
  button:last-child {
    margin-right: 0;
  }

  button:disabled {
    background-color: #ccc;
    cursor: not-allowed;
  }

  /* Loss and accuracy charts stacked vertically. */
  .charts-container {
    display: flex;
    flex-direction: column;
    gap: 20px;
    margin-top: 20px;
  }

  #lossChart, #accuracyChart {
    height: 400px;
    width: 100%;
  }

  h4 {
    margin: 0 0 10px 0;
  }

  .training-status {
    text-align: center;
    margin: 20px 0;
    font-weight: bold;
  }

  .inference-controls {
    margin: 20px 0;
    text-align: center;
  }

  .inference-button {
    background-color: #28a745;
    padding: 12px 24px;
    font-size: 1.1em;
    transition: background-color 0.3s;
  }

  .inference-button:hover {
    background-color: #218838;
  }

  /* On narrow screens stack the two model cards instead of squeezing them side by side. */
  @media (max-width: 768px) {
    .models-grid {
      grid-template-columns: 1fr;
    }
  }
</style>
<script>
let ws;
let lossChart;
let accuracyChart;
// Initialize charts
document.addEventListener('DOMContentLoaded', function() {
// Loss chart configuration
const lossData = [
{
x: [],
y: [],
name: 'Model A Training Loss',
type: 'scatter'
},
{
x: [],
y: [],
name: 'Model B Training Loss',
type: 'scatter'
}
];
const lossLayout = {
title: 'Training Loss Comparison',
xaxis: {
title: 'Iterations',
rangemode: 'nonnegative'
},
yaxis: {
title: 'Loss',
rangemode: 'nonnegative'
}
};
// Accuracy chart configuration
const accuracyData = [
{
x: [],
y: [],
name: 'Model A Training Accuracy',
type: 'scatter'
},
{
x: [],
y: [],
name: 'Model B Training Accuracy',
type: 'scatter'
}
];
const accuracyLayout = {
title: 'Training Accuracy Comparison',
xaxis: {
title: 'Iterations',
rangemode: 'nonnegative'
},
yaxis: {
title: 'Accuracy (%)',
range: [0, 100]
}
};
// Create charts
Plotly.newPlot('lossChart', lossData, lossLayout);
Plotly.newPlot('accuracyChart', accuracyData, accuracyLayout);
});
/**
 * Reflect whether a comparison run is active in the Start/Stop buttons.
 * @param {boolean} running - true while a run is in progress.
 */
function setComparisonRunning(running) {
  document.getElementById('startComparison').disabled = running;
  document.getElementById('stopComparison').disabled = !running;
}

/**
 * Read one model's configuration from the controls whose ids start with
 * `prefix` (e.g. "model1" -> model1_block1, ..., model1_epochs).
 * @param {string} prefix - id prefix of the model's form controls.
 * @returns {{block1: number, block2: number, block3: number,
 *            optimizer: string, batch_size: number, epochs: number}}
 */
function readModelConfig(prefix) {
  const value = (field) => document.getElementById(`${prefix}_${field}`).value;
  const intValue = (field) => parseInt(value(field), 10);
  return {
    block1: intValue('block1'),
    block2: intValue('block2'),
    block3: intValue('block3'),
    optimizer: value('optimizer'),
    batch_size: intValue('batch_size'),
    epochs: intValue('epochs')
  };
}

/**
 * Clear what a previous run left on screen: both charts' traces, the
 * progress text and the inference shortcut, so a new run starts from an
 * empty plot instead of appending to the old curves.
 */
function resetComparisonView() {
  if (typeof Plotly !== 'undefined') {
    // restyle applies outer-array entries per trace: [[], []] empties traces 0 and 1.
    // A fresh object per chart avoids the two plots sharing the same data arrays.
    const emptyTraces = () => ({ x: [[], []], y: [[], []] });
    Plotly.restyle('lossChart', emptyTraces());
    Plotly.restyle('accuracyChart', emptyTraces());
  }
  const progressText = document.getElementById('training-progress');
  if (progressText) {
    progressText.textContent = '';
  }
  const inferenceControls = document.querySelector('.inference-controls');
  if (inferenceControls) {
    inferenceControls.style.display = 'none';
  }
}

/**
 * Start training both configured models over a WebSocket to /ws/compare and
 * stream their per-iteration loss/accuracy into the comparison charts.
 *
 * Server messages are JSON objects whose `status` is 'training' (with
 * `model`, `epoch`, `batch_size` and `metrics`), 'complete' or 'error'.
 */
function startComparison() {
  setComparisonRunning(true);
  resetComparisonView();

  const model1Config = readModelConfig('model1');
  const model2Config = readModelConfig('model2');

  // Follow the page's scheme: an https page may not open an insecure ws:// socket.
  const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
  const socket = new WebSocket(`${scheme}://${window.location.host}/ws/compare`);
  ws = socket;

  // After Stop -> Start, the old socket's close/error events arrive late;
  // only the socket that is still current may touch the UI.
  const isCurrent = () => ws === socket;

  socket.onopen = function () {
    console.log('WebSocket connection established');
    // Only send the message after connection is established
    const message = {
      action: 'start_training',
      parameters: {
        model_params: {
          model_a: model1Config,
          model_b: model2Config
        },
        dataset_params: {
          batch_size: model1Config.batch_size,
          shuffle: true
        }
      }
    };
    console.log('Sending message:', message);
    socket.send(JSON.stringify(message));
  };

  socket.onmessage = function (event) {
    if (!isCurrent()) {
      return;
    }
    console.log('Received message:', event.data);

    let data;
    try {
      data = JSON.parse(event.data);
    } catch (err) {
      console.error('Ignoring non-JSON message:', err);
      return;
    }
    const progressText = document.getElementById('training-progress');

    if (data.status === 'training') {
      // Trace 0 belongs to Model A, trace 1 to Model B in both charts.
      const modelIndex = data.model === 'A' ? 0 : 1;
      const iteration = data.metrics.iteration;
      console.log(`Updating charts for model ${data.model} at iteration ${iteration}`);

      Plotly.extendTraces('lossChart', {
        x: [[iteration]],
        y: [[data.metrics.loss]]
      }, [modelIndex]);
      Plotly.extendTraces('accuracyChart', {
        x: [[iteration]],
        y: [[data.metrics.accuracy]]
      }, [modelIndex]);

      if (progressText) {
        const progress = (iteration / data.metrics.total_iterations * 100).toFixed(1);
        progressText.textContent =
          `Training Model ${data.model} - ` +
          `Epoch ${data.epoch + 1} - ` +
          `Iteration ${iteration}/${data.metrics.total_iterations} ` +
          `(${progress}%) - ` +
          `Batch Size: ${data.batch_size}`;
      }
    } else if (data.status === 'complete') {
      setComparisonRunning(false);
      if (progressText) {
        progressText.textContent = 'Training Complete!';
      }
      // Show the inference button
      document.querySelector('.inference-controls').style.display = 'block';
    } else if (data.status === 'error') {
      console.error('Training error:', data.message);
      if (progressText) {
        progressText.textContent = `Training error: ${data.message}`;
      }
      alert(`Training error: ${data.message}`);
      setComparisonRunning(false);
    }
  };

  socket.onerror = function (error) {
    console.error('WebSocket error:', error);
    if (isCurrent()) {
      setComparisonRunning(false);
    }
  };

  socket.onclose = function (event) {
    console.log('WebSocket connection closed:', event);
    if (isCurrent()) {
      setComparisonRunning(false);
    }
  };
}
/**
 * Abort the running comparison: close its WebSocket if one was ever opened,
 * then put the Start/Stop buttons back into their idle state right away
 * rather than waiting for the socket's close event.
 */
function stopComparison() {
  if (ws) ws.close();

  const idleState = [['startComparison', false], ['stopComparison', true]];
  idleState.forEach(function ([buttonId, isDisabled]) {
    document.getElementById(buttonId).disabled = isDisabled;
  });
}
</script>
</body>
</html>