File size: 3,681 Bytes
4c1a791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// The WebSocket opened by the most recent trainModel() call; stays undefined
// until training is started for the first time.
let ws;

/**
 * Create (or re-create) the empty loss and accuracy charts.
 *
 * Each chart gets two scatter traces: training at index 0 and validation at
 * index 1, matching the trace indices updateCharts() extends. Plotly.newPlot
 * replaces any plot already in the target div, so calling this again clears
 * the previous run's data.
 */
function initializeCharts() {
    // A fresh empty scatter trace; new arrays are created on every call so the
    // two plots never share mutable data.
    const emptyTrace = (name) => ({ name, x: [], y: [], type: 'scatter' });

    // Titles use the `{ text }` object form. The bare-string form is a
    // deprecated shorthand that newer plotly.js releases no longer accept,
    // while the object form is supported from plotly.js 1.43 onward.
    const layoutFor = (title, yLabel) => ({
        title: { text: title },
        xaxis: { title: { text: 'Iterations' } },
        yaxis: { title: { text: yLabel } }
    });

    Plotly.newPlot(
        'loss-plot',
        [emptyTrace('Training Loss'), emptyTrace('Validation Loss')],
        layoutFor('Training and Validation Loss', 'Loss')
    );

    Plotly.newPlot(
        'accuracy-plot',
        [emptyTrace('Training Accuracy'), emptyTrace('Validation Accuracy')],
        layoutFor('Training and Validation Accuracy', 'Accuracy (%)')
    );
}

function updateCharts(data) {
    const iteration = data.epoch * data.batch;
    
    Plotly.extendTraces('loss-plot', {
        x: [[iteration], [iteration]],
        y: [[data.train_loss], [data.val_loss]]
    }, [0, 1]);

    Plotly.extendTraces('accuracy-plot', {
        x: [[iteration], [iteration]],
        y: [[data.train_acc], [data.val_acc]]
    }, [0, 1]);

    // Update training logs
    const logsDiv = document.getElementById('training-logs');
    logsDiv.innerHTML = `
        <p>Epoch: ${data.epoch + 1}</p>
        <p>Training Loss: ${data.train_loss.toFixed(4)}</p>
        <p>Training Accuracy: ${data.train_acc.toFixed(2)}%</p>
        <p>Validation Loss: ${data.val_loss.toFixed(4)}</p>
        <p>Validation Accuracy: ${data.val_acc.toFixed(2)}%</p>
    `;
}

/**
 * Read a form input's value as a base-10 integer.
 *
 * @param {string} id - DOM id of the input element.
 * @returns {number} The parsed integer, or NaN when the value is not numeric.
 */
function readIntInput(id) {
    return parseInt(document.getElementById(id).value, 10);
}

/**
 * Build the URL of the training WebSocket endpoint for a page location.
 *
 * @param {{protocol: string, host: string}} location - Normally `window.location`.
 * @returns {string} `wss://<host>/ws/train` for HTTPS pages (browsers block a
 *   plain `ws:` connection from a secure page), otherwise `ws://<host>/ws/train`.
 */
function trainingSocketUrl(location) {
    const scheme = location.protocol === 'https:' ? 'wss:' : 'ws:';
    return `${scheme}//${location.host}/ws/train`;
}

/**
 * Start a training run: read the hyper-parameters from the form, reset the
 * charts, and stream progress from the server over a WebSocket.
 *
 * The config is sent as JSON once the socket opens. Each incoming JSON message
 * is handled in one of three ways:
 * - `status === "completed"` triggers a success alert.
 * - `status === "error"` alerts with `message`.
 * - Anything else is passed to updateCharts().
 *
 * @returns {Promise<void>} Resolves once the socket has been set up, not when
 *   training finishes.
 */
async function trainModel() {
    console.log("Training started...");  // Debug log
    const config = {
        kernels: [
            readIntInput('kernel1'),
            readIntInput('kernel2'),
            readIntInput('kernel3')
        ],
        optimizer: document.getElementById('optimizer').value,
        batch_size: readIntInput('batch_size'),
        epochs: readIntInput('epochs')
    };

    console.log("Config:", config);  // Debug log

    // parseInt yields NaN for blank or non-numeric fields, and JSON.stringify
    // would silently send NaN as null; stop here instead.
    const numericValues = [...config.kernels, config.batch_size, config.epochs];
    if (!numericValues.every(Number.isFinite)) {
        alert('Please enter whole numbers for the kernel sizes, batch size and epochs.');
        return;
    }

    // Close a socket left over from a previous run so it cannot keep streaming
    // into the freshly reset charts. Detach its handlers first: closing a
    // socket that is still connecting fires an error event, which would
    // otherwise raise a spurious "Error connecting" alert.
    if (ws) {
        const previous = ws;
        previous.onopen = null;
        previous.onmessage = null;
        previous.onerror = null;
        previous.close();
    }

    // Show progress section and initialize charts
    document.getElementById('training-progress').classList.remove('hidden');
    initializeCharts();

    try {
        console.log("Connecting to WebSocket...");  // Debug log
        // The handlers below use this local binding rather than the
        // module-level `ws`, so a later call that reassigns `ws` cannot
        // redirect them to a different socket.
        const socket = new WebSocket(trainingSocketUrl(window.location));
        ws = socket;

        socket.onopen = function() {
            console.log("WebSocket connection established");
            // Send configuration once connected
            socket.send(JSON.stringify(config));
            console.log("Config sent to server");  // Debug log
        };

        socket.onmessage = function(event) {
            console.log("Received message:", event.data);  // Debug log
            let data;
            try {
                data = JSON.parse(event.data);
            } catch (parseError) {
                // A malformed frame must not throw out of the event handler.
                console.error('Ignoring malformed message from training server:', parseError);
                return;
            }
            if (data.status === "completed") {
                alert('Training completed successfully!');
            } else if (data.status === "error") {
                alert('Error during training: ' + data.message);
            } else {
                updateCharts(data);
            }
        };

        socket.onerror = function(error) {
            console.error('WebSocket error:', error);
            alert('Error connecting to training server');
        };

        socket.onclose = function() {
            console.log('WebSocket connection closed');
        };
    } catch (error) {
        // The WebSocket constructor throws synchronously, e.g. on an invalid URL.
        console.error('Error:', error);
        alert('Error during training: ' + error.message);
    }
}