Spaces:
Runtime error
Runtime error
import torch | |
import os | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
def netParams(model):
    """
    Return the total number of scalar parameters in ``model``.

    Every tensor yielded by ``model.parameters()`` contributes its element
    count (``numel()``), regardless of ``requires_grad``.

    :param model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``)
    :return: total element count across all parameters (0 if there are none)
    """
    # numel() is the product of all dimension sizes (1 for a 0-dim tensor),
    # which is exactly what the hand-written nested loop used to compute.
    return sum(parameter.numel() for parameter in model.parameters())
def save_checkpoint(state, save_path, filename='checkpoint.pth.tar'):
    """
    Save a checkpoint with ``torch.save``.

    :param state: any object ``torch.save`` can serialize
    :param save_path: directory to write into; created if it does not exist
    :param filename: file name inside ``save_path``
        (default ``'checkpoint.pth.tar'``)
    :return: None
    """
    # os.makedirs('') raises, so only create a directory when one was given
    # (an empty save_path means "current working directory").
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    torch.save(state, os.path.join(save_path, filename))
def plotSignal(mode, signal, titlename, folderpath="./testFig/"):
    """
    Plot row 0 of ``signal`` as a line and save it as ``<folderpath>/<titlename>.png``.

    :param mode: draw mode; accepted for a uniform call signature, not used here
    :param signal: 2-D array-like; only ``signal[0, :]`` is plotted
    :param titlename: plot title, also used as the output file name stem
    :param folderpath: output directory, created if missing
        (default ``./testFig/``)
    """
    # Create the output directory up front; savefig fails if it is missing.
    if folderpath:
        os.makedirs(folderpath, exist_ok=True)
    channel_to_plot = signal[0, :]
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(channel_to_plot)
        plt.xlabel('Time or Sample Index')
        plt.ylabel('Amplitude')
        plt.title(titlename)
        plt.savefig(os.path.join(folderpath, titlename + '.png'))
    finally:
        # Close the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
def plotHeatmap(mode, data, titlename, folderpath="./testFig/"):
    """
    Render ``data`` as a heatmap and save it as ``<folderpath>/<titlename>.png``.

    :param mode: draw mode; accepted for a uniform call signature, not used here
    :param data: 2-D array-like accepted by ``seaborn.heatmap``
    :param titlename: plot title, also used as the output file name stem
    :param folderpath: output directory, created if missing
        (default ``./testFig/``)
    """
    # Create the output directory up front; savefig fails if it is missing.
    if folderpath:
        os.makedirs(folderpath, exist_ok=True)
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(data, cmap="YlGnBu", cbar_kws={"shrink": 0.75})
        plt.xlabel('Time point index')
        plt.ylabel('Time point index')
        plt.title(titlename)
        plt.savefig(os.path.join(folderpath, titlename + '.png'))
    finally:
        # Close the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
def draw(mode, signal, titlename):
    """
    Convert a torch tensor to NumPy and save a diagnostic plot of it.

    :param mode: selects the preprocessing and plot type:
        0 -- line plot of ``signal[0, :]`` (tensor used as-is);
        1 -- take ``signal[0, :, :]``, then line-plot its row 0
             (i.e. ``signal[0, 0, :]``);
        2 -- transpose dims 0 and 1, then line-plot row 0
             (i.e. ``signal[:, 0]``);
        3 -- heatmap of the whole tensor.
    :param signal: torch tensor; mode 1 indexes three dims, the other modes two
    :param titlename: plot title, also used as the output file name stem
    :raises ValueError: if ``mode`` is not 0, 1, 2 or 3
    """
    if mode not in (0, 1, 2, 3):
        # Previously an unknown mode silently produced no plot at all.
        raise ValueError(f"unsupported draw mode: {mode!r}")
    if mode == 1:
        signal = signal[0, :, :]
    elif mode == 2:
        signal = torch.transpose(signal, 0, 1)
    # Detach before moving to CPU so the device copy is not tracked by autograd.
    array = signal.detach().cpu().numpy()
    if mode == 3:
        plotHeatmap(mode, array, titlename)
    else:
        plotSignal(mode, array, titlename)