File size: 2,715 Bytes
c58daf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import os
import matplotlib.pyplot as plt
import seaborn as sns


def netParams(model):
    '''
    Count the total number of scalar parameters in a network.

    :param model: object whose ``parameters()`` method yields tensors
        (e.g. a ``torch.nn.Module``)
    :return: sum of the element counts of every parameter tensor; 0 when
        the model has no parameters
    '''
    # numel() is the product of a tensor's dimension sizes (1 for a 0-dim
    # tensor), which is exactly what the manual per-dimension loop computed.
    return sum(parameter.numel() for parameter in model.parameters())

def save_checkpoint(state, save_path):
    """
    Save a checkpoint as ``checkpoint.pth.tar`` inside ``save_path``.

    :param state: object to serialize with ``torch.save`` (e.g. a dict
        holding model/optimizer state)
    :param save_path: directory the checkpoint file is written into; it is
        created when missing. An empty path means the current directory.
    """
    filename = 'checkpoint.pth.tar'
    # Create the target directory up front so saving into a fresh output
    # location does not fail. An empty path is skipped because there is
    # nothing to create (and os.makedirs('') would raise).
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    torch.save(state, os.path.join(save_path, filename))

def plotSignal(mode, signal, titlename):
    """
    Plot the slice ``signal[0, :]`` as a line chart and save it as a PNG.

    The image is written to ``./testFig/<titlename>.png``.

    :param mode: drawing mode selected by the caller; not used here, kept
        so existing call sites keep working
    :param signal: array indexable as ``signal[0, :]``; only that slice is
        plotted
    :param titlename: figure title, also used as the output file name
        (without extension)
    """
    folderpath = "./testFig/"
    output_path = folderpath + titlename + '.png'

    # Selecting one specific channel (e.g., the first channel)
    channel_to_plot = signal[0, :]

    # savefig does not create missing directories, so make sure the output
    # folder exists before drawing anything.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    figure = plt.figure(figsize=(10, 6))
    try:
        plt.plot(channel_to_plot)

        # Add labels and title
        plt.xlabel('Time or Sample Index')
        plt.ylabel('Amplitude')
        plt.title(titlename)

        # Save the figure to a file
        plt.savefig(output_path)
    finally:
        # pyplot keeps every figure it creates until it is closed; release
        # it here so repeated calls do not accumulate open figures.
        plt.close(figure)

def plotHeatmap(mode, data, titlename):
    """
    Render ``data`` as a seaborn heatmap and save it as a PNG.

    The image is written to ``./testFig/<titlename>.png``.

    :param mode: drawing mode selected by the caller; not used here, kept
        so existing call sites keep working
    :param data: 2D dataset passed straight to ``sns.heatmap``
    :param titlename: figure title, also used as the output file name
        (without extension)
    """
    folderpath = "./testFig/"
    output_path = folderpath + titlename + '.png'

    # savefig does not create missing directories, so make sure the output
    # folder exists before drawing anything.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    figure = plt.figure(figsize=(10, 8))
    try:
        # Draw the whole matrix with a shrunken colour bar
        sns.heatmap(data, cmap="YlGnBu", cbar_kws={"shrink": 0.75})

        # Add labels and title
        plt.xlabel('Time point index')
        plt.ylabel('Time point index')
        plt.title(titlename)

        # Save the figure to a file
        plt.savefig(output_path)
    finally:
        # pyplot keeps every figure it creates until it is closed; release
        # it here so repeated calls do not accumulate open figures.
        plt.close(figure)

def draw(mode, signal, titlename):
    """
    Convert a tensor to a NumPy array and plot it according to ``mode``.

    Supported modes:

    * ``0`` - line plot of the tensor as given
    * ``1`` - line plot of ``signal[0, :, :]``
    * ``2`` - line plot after swapping dimensions 0 and 1
    * ``3`` - heatmap of the tensor as given

    Any other mode is ignored and nothing is drawn.

    :param mode: integer selecting the plot variant (see above)
    :param signal: tensor to visualise; it is detached and moved to the CPU
        before plotting
    :param titlename: title forwarded to the plotting helper
    """
    # Unknown modes are a silent no-op; bail out before touching the tensor.
    if mode not in (0, 1, 2, 3):
        return

    # Mode-specific reshaping happens while the data is still a tensor.
    if mode == 1:
        signal = signal[0, :, :]
    elif mode == 2:
        signal = torch.transpose(signal, 0, 1)

    as_array = signal.cpu().detach().numpy()

    # Mode 3 is the heatmap; every other supported mode is a line plot.
    if mode == 3:
        plotHeatmap(mode, as_array, titlename)
    else:
        plotSignal(mode, as_array, titlename)