Spaces:
Sleeping
Sleeping
File size: 3,681 Bytes
4c1a791 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
// WebSocket of the most recent training run; (re)assigned by trainModel.
let ws;

/**
 * Create (or recreate) the empty loss and accuracy charts.
 *
 * Each chart gets two empty scatter traces: index 0 is "Training …" and
 * index 1 is "Validation …". Plotly.newPlot overwrites any plot already in
 * the target container, so calling this again clears earlier data.
 */
function initializeCharts() {
  // Fresh, empty training/validation trace pair for one metric.
  const makeTraces = (metric) => ['Training', 'Validation'].map((phase) => ({
    name: `${phase} ${metric}`,
    x: [],
    y: [],
    type: 'scatter',
  }));

  // Titles use the `{ text }` object form: plain-string titles are deprecated
  // and were removed in plotly.js v3, while the object form has been
  // supported since plotly.js 1.43.
  const makeLayout = (title, yTitle) => ({
    title: { text: title },
    xaxis: { title: { text: 'Iterations' } },
    yaxis: { title: { text: yTitle } },
  });

  Plotly.newPlot(
    'loss-plot',
    makeTraces('Loss'),
    makeLayout('Training and Validation Loss', 'Loss'),
  );
  Plotly.newPlot(
    'accuracy-plot',
    makeTraces('Accuracy'),
    makeLayout('Training and Validation Accuracy', 'Accuracy (%)'),
  );
}
/**
 * Append one progress sample to both charts and refresh the text log.
 *
 * @param {object} data - Progress message from the training server. Reads
 *   `epoch` (displayed as `epoch + 1`, so presumably 0-based — confirm with
 *   the server), `train_loss`, `val_loss`, `train_acc` and `val_acc`.
 */
function updateCharts(data) {
  // x = number of samples already plotted. The previous `epoch * batch` put
  // every sample of epoch 0 at x = 0 and was not monotonic across epochs
  // (e.g. 1 * 5 > 2 * 2). Plotly keeps the current traces on the graph div as
  // `.data`; Plotly.newPlot with empty traces resets them, so the count
  // restarts whenever the charts are re-created.
  const lossPlot = document.getElementById('loss-plot');
  const iteration = lossPlot?.data?.[0]?.x?.length ?? 0;

  Plotly.extendTraces('loss-plot', {
    x: [[iteration], [iteration]],
    y: [[data.train_loss], [data.val_loss]],
  }, [0, 1]);
  Plotly.extendTraces('accuracy-plot', {
    x: [[iteration], [iteration]],
    y: [[data.train_acc], [data.val_acc]],
  }, [0, 1]);

  // Missing or non-numeric metrics render as "N/A" instead of making
  // toFixed throw and aborting the log update.
  const fmt = (value, digits) => (Number.isFinite(value) ? value.toFixed(digits) : 'N/A');

  // Update training logs. Every interpolated value is a number or a
  // toFixed/"N/A" string, so no server-supplied markup reaches innerHTML.
  const logsDiv = document.getElementById('training-logs');
  logsDiv.innerHTML = `
    <p>Epoch: ${Number(data.epoch) + 1}</p>
    <p>Training Loss: ${fmt(data.train_loss, 4)}</p>
    <p>Training Accuracy: ${fmt(data.train_acc, 2)}%</p>
    <p>Validation Loss: ${fmt(data.val_loss, 4)}</p>
    <p>Validation Accuracy: ${fmt(data.val_acc, 2)}%</p>
  `;
}
// Read the training hyper-parameters from the form inputs (integers parsed base-10).
function readTrainingConfig() {
  const intField = (id) => Number.parseInt(document.getElementById(id).value, 10);
  return {
    kernels: [intField('kernel1'), intField('kernel2'), intField('kernel3')],
    optimizer: document.getElementById('optimizer').value,
    batch_size: intField('batch_size'),
    epochs: intField('epochs'),
  };
}

// Build the training WebSocket URL. Uses wss: on HTTPS pages, because browsers
// block insecure ws: connections from secure origins as mixed content.
function trainingSocketUrl(loc = window.location) {
  const scheme = loc.protocol === 'https:' ? 'wss' : 'ws';
  return `${scheme}://${loc.host}/ws/train`;
}

/**
 * Start a training run from the current form values.
 *
 * Validates the numeric inputs, closes any socket left over from a previous
 * run, resets the charts, and opens a WebSocket to `/ws/train`. The config is
 * sent as JSON once the socket opens. Each incoming message is either a final
 * status (`completed` / `error`, shown via alert) or a progress sample handed
 * to updateCharts.
 *
 * Kept `async` (it awaits nothing) so callers that treat it as returning a
 * Promise keep working.
 */
async function trainModel() {
  console.log("Training started..."); // Debug log
  const config = readTrainingConfig();
  console.log("Config:", config); // Debug log

  // NaN would be serialized as null by JSON.stringify and sent to the server.
  const numericValues = [...config.kernels, config.batch_size, config.epochs];
  if (numericValues.some((value) => Number.isNaN(value))) {
    alert('Please enter whole numbers for the kernel sizes, batch size and epochs.');
    return;
  }

  // Only one run should feed the charts: stop the previous socket, if any.
  // close() on an already-closed socket is a no-op.
  ws?.close();

  // Show progress section and initialize charts
  document.getElementById('training-progress').classList.remove('hidden');
  initializeCharts();

  try {
    // Connect to WebSocket
    console.log("Connecting to WebSocket..."); // Debug log
    const socket = new WebSocket(trainingSocketUrl());
    ws = socket;

    // Handlers close over `socket` rather than the shared `ws`, so a stale
    // handler can never act on a newer run's connection.
    socket.onopen = () => {
      console.log("WebSocket connection established");
      // Send configuration once connected
      socket.send(JSON.stringify(config));
      console.log("Config sent to server"); // Debug log
    };

    socket.onmessage = (event) => {
      console.log("Received message:", event.data); // Debug log
      // The outer try/catch cannot see errors thrown in this async callback.
      let data;
      try {
        data = JSON.parse(event.data);
      } catch (parseError) {
        console.error('Ignoring malformed message from training server:', parseError);
        return;
      }
      if (data.status === "completed") {
        alert('Training completed successfully!');
      } else if (data.status === "error") {
        alert('Error during training: ' + data.message);
      } else {
        updateCharts(data);
      }
    };

    socket.onerror = (error) => {
      console.error('WebSocket error:', error);
      alert('Error connecting to training server');
    };

    socket.onclose = () => {
      console.log('WebSocket connection closed');
    };
  } catch (error) {
    // Only synchronous failures land here (e.g. the WebSocket constructor
    // rejecting the URL).
    console.error('Error:', error);
    alert('Error during training: ' + error.message);
  }
}