|
|
import os |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
import json |
|
|
|
|
|
from torch_geometric.data import HeteroData |
|
|
import networkx as nx |
|
|
|
|
|
class PowerFlowDataset(Dataset):
    """Dataset of saved graph samples listed in a JSON-lines split file.

    Each non-blank line of ``split_txt`` is a JSON object that must contain
    ``'file_path'`` (the file handed to ``torch.load``) and may contain
    ``'masked_node'`` (a list of PQ row indices to mask).

    ``__getitem__`` loads one sample, copies the first four feature columns
    of the ``'PQ'``, ``'PV'`` and ``'Slack'`` node stores into ``.y`` and then
    overwrites selected input columns in ``.x`` (see ``_prepare_sample``).

    Args:
        data_root: Stored on the instance as ``self.data_root``.
        split_txt: Path of the JSON-lines split file.
        pq_len, pv_len, slack_len: Stored on the instance unchanged; not
            used by any method of this class.
        mask_num: Maximum number of PQ rows whose ``Q_net`` input is zeroed
            per sample. ``0`` disables masking.
    """

    # Column positions inside every node-feature matrix ``x``. Columns 4 and 5
    # (named ``Gs`` and ``Bs`` in the original index list) are never modified.
    _VM, _VA, _P_NET, _Q_NET = 0, 1, 2, 3
    _TARGET_COLUMNS = [_VM, _VA, _P_NET, _Q_NET]

    def __init__(self, data_root, split_txt, pq_len, pv_len, slack_len, mask_num=0):
        self.data_root = data_root
        with open(split_txt, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of letting
            # json.loads fail on them.
            self.file_list = [json.loads(line) for line in f if line.strip()]
        self.pq_len = pq_len
        self.pv_len = pv_len
        self.slack_len = slack_len
        self.mask_num = mask_num

        # Path-distance bookkeeping; nothing in this class fills these in,
        # they are expected to be assigned from outside.
        self.flag_distance_once_calculated = False
        self.shortest_paths = None
        self.node_type_to_global_index = None
        self.max_depth = 16

    def __len__(self):
        """Return the number of entries read from the split file."""
        return len(self.file_list)

    def update_max_depth(self):
        """Lower ``max_depth`` to the largest value stored in ``shortest_paths``.

        ``max_depth`` is only ever reduced, never increased. The call is a
        no-op while ``shortest_paths`` is ``None`` or empty.
        """
        if not self.shortest_paths:
            return
        longest = max(self.shortest_paths.values())
        if longest < self.max_depth:
            self.max_depth = longest

    def __getitem__(self, idx):
        """Load the sample at position ``idx`` and turn it into model inputs.

        Returns:
            The loaded object, modified in place by ``_prepare_sample``.
        """
        file_dict = self.file_list[idx]
        # NOTE(review): ``self.data_root`` is not part of this path, so a
        # relative 'file_path' resolves against the current working
        # directory -- confirm that this is intended.
        # NOTE(security): torch.load unpickles the file; only load samples
        # from trusted sources.
        data = torch.load(os.path.join(file_dict['file_path']))
        return self._prepare_sample(data, file_dict)

    def _prepare_sample(self, data, file_dict):
        """Store targets in ``.y`` and blank the input columns, in place.

        Args:
            data: Mapping with ``'PQ'``, ``'PV'`` and ``'Slack'`` entries, each
                exposing a 2-D feature tensor ``x``.
            file_dict: The split-file entry for this sample; its optional
                ``'masked_node'`` list selects which PQ rows get masked.

        Returns:
            ``data`` itself.
        """
        pq, pv, slack = data['PQ'], data['PV'], data['Slack']
        targets = self._TARGET_COLUMNS
        # The angle of the first Slack row is used as the initial angle for
        # every PQ and PV row.
        ref_angle = slack.x[0, self._VA].item()

        # PQ rows: keep the original values as targets, then reset magnitude
        # and angle inputs.
        pq.y = pq.x[:, targets].clone().detach()
        pq.x[:, self._VM] = 1.0
        pq.x[:, self._VA] = ref_angle

        # q_mask is True where the Q_net input is kept, False where it is
        # zeroed out.
        pq.q_mask = torch.ones((pq.x.shape[0],), dtype=torch.bool)
        if self.mask_num > 0:
            preset = file_dict.get('masked_node')
            if preset is None:
                # No preset list: pick random rows among those with a
                # non-zero Q_net value.
                candidates = torch.nonzero(pq.x[:, self._Q_NET])
                chosen = torch.randperm(candidates.shape[0])[:self.mask_num]
                mask_indices = candidates[chosen]
            else:
                mask_indices = preset[:self.mask_num]
            pq.q_mask[mask_indices] = False
            pq.x[~pq.q_mask, self._Q_NET] = 0

        # PV rows: angle and Q_net inputs are replaced, the rest is kept.
        pv.y = pv.x[:, targets].clone().detach()
        pv.x[:, self._VA] = ref_angle
        pv.x[:, self._Q_NET] = 0

        # Slack rows: P_net and Q_net inputs are zeroed.
        slack.y = slack.x[:, targets].clone().detach()
        slack.x[:, self._P_NET] = 0
        slack.x[:, self._Q_NET] = 0

        return data
|
|
|