| | import os, json
|
| | import torch
|
| | import utils
|
| |
|
def calc_feats(smi, ms, nls, cfg):
    """Build the feature dict for one spectrum/molecule pair.

    Args:
        smi: SMILES string handed to the molecule encoders in ``utils``.
        ms: Spectrum data passed through to ``utils.ms_binner``.
        nls: Second spectrum argument passed to ``utils.ms_binner``
            (presumably neutral losses -- confirm against ``utils``).
        cfg: Config object; reads ``min_mz``, ``max_mz``, ``bin_size``,
            ``add_nl``, ``binary_intn``, ``mol_encoder``, ``fptype`` and
            ``mol_embedding_dim``.

    Returns:
        A dict that always holds ``'ms_bins'`` and, depending on which of
        ``'fp'``, ``'gnn'`` and ``'fm'`` occur in ``cfg.mol_encoder``, also
        ``'mol_fps'``, the entries produced by
        ``utils.mol_graph_featurizer`` and ``'mol_fmvec'``.  ``None`` is
        returned when the graph featurizer yields a falsy result.
    """
    feats = {
        'ms_bins': utils.ms_binner(
            ms,
            nls,
            min_mz=cfg.min_mz,
            max_mz=cfg.max_mz,
            bin_size=cfg.bin_size,
            add_nl=cfg.add_nl,
            binary_intn=cfg.binary_intn,
        ),
    }

    encoder = cfg.mol_encoder
    wants_fm = 'fm' in encoder
    fm_done = False

    if 'fp' in encoder:
        if wants_fm:
            # The combined encoder returns fingerprint and fm vector together,
            # so the stand-alone fm step at the end is skipped.
            fps, fmvec = utils.mol_fp_fm_encoder(
                smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
            )
            feats['mol_fps'] = fps
            feats['mol_fmvec'] = fmvec
            fm_done = True
        else:
            feats['mol_fps'] = utils.mol_fp_encoder(
                smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
            )

    if 'gnn' in encoder:
        graph_feats = utils.mol_graph_featurizer(smi)
        if not graph_feats:
            # Featurization failed: signal "no sample" to the caller.
            return None
        feats.update(graph_feats)

    if wants_fm and not fm_done:
        feats['mol_fmvec'] = utils.smi2fmvec(smi)

    return feats
|
| |
|
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that featurizes spectrum/molecule records on access.

    Each record is indexed out of ``self.data``.  A record that already
    contains ``'ms_bins'`` is returned unchanged; otherwise its ``'ms'``,
    ``'smiles'`` and optional ``'nls'`` entries are passed to ``calc_feats``.
    """

    def __init__(self, inp, cfg):
        """Store the records and the featurization config.

        Args:
            inp: Either a path (``str``) to a JSON file holding the records,
                or the already-loaded collection of records.
            cfg: Config object forwarded to ``calc_feats``.
        """
        if isinstance(inp, str):
            # Use a context manager so the file handle is closed right away
            # instead of being left to the garbage collector.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp

        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the feature dict for record ``idx``.

        Returns:
            The stored record itself when it already has ``'ms_bins'``,
            otherwise the result of ``calc_feats`` (which may be ``None``).
            Any exception raised while reading or featurizing the record is
            logged to stdout and turned into ``None`` (best effort).
        """
        try:
            record = self.data[idx]

            # Pre-featurized records are passed through untouched.
            if 'ms_bins' in record:
                return record

            nls = record['nls'] if 'nls' in record else []
            ms = record['ms']
            smi = record['smiles']

            return calc_feats(smi, ms, nls, self.cfg)
        except Exception as e:
            print('=' * 50, idx, str(e))
            return None

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)
|
| |
|
class DatasetGNNFP(torch.utils.data.Dataset):
    """Map-style dataset yielding fingerprint plus graph features per record.

    Only the ``'smiles'`` entry of each record is used.
    """

    def __init__(self, inp, cfg):
        """Store the records and the featurization config.

        Args:
            inp: Either a path (``str``) to a JSON file holding the records,
                or the already-loaded collection of records.
            cfg: Config object; ``fptype`` and ``mol_embedding_dim`` are read
                when an item is requested.
        """
        if isinstance(inp, str):
            # Use a context manager so the file handle is closed right away
            # instead of being left to the garbage collector.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp

        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the molecule features for record ``idx``.

        Returns:
            A dict with ``'mol_fps'`` from ``utils.mol_fp_encoder`` merged
            with the entries returned by ``utils.mol_graph_featurizer``.
            Any exception raised along the way is logged to stdout and
            turned into ``None`` (best effort).
        """
        try:
            smi = self.data[idx]['smiles']
            item = {
                'mol_fps': utils.mol_fp_encoder(
                    smi,
                    tp=self.cfg.fptype,
                    nbits=self.cfg.mol_embedding_dim,
                ),
            }
            # A non-mapping result (e.g. None) raises here and is handled by
            # the except branch below.
            item.update(utils.mol_graph_featurizer(smi))
        except Exception as e:
            print('=' * 50, idx, str(e))
            return None

        return item

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)
|
| |
|
class PathDataset(torch.utils.data.Dataset):
    """Map-style dataset that parses one spectrum file per item, lazily.

    Parsed files are cached in ``self.data`` (keyed by index) so each file
    is read at most once per successful parse.
    """

    def __init__(self, pathlist, cfg):
        """Store the file list and the config.

        Args:
            pathlist: Indexable collection of file paths.
            cfg: Config object; ``energy`` selects the block to parse and
                the whole object is forwarded to ``calc_feats``.
        """
        self.fns = pathlist
        self.cfg = cfg
        self.data = {}

    def __getitem__(self, idx):
        """Return the feature dict for the file at position ``idx``.

        Returns:
            The result of ``calc_feats`` for the parsed file, or ``None``
            when the file cannot be parsed or featurization raises.
            Failures are deliberately silent (best effort).
        """
        try:
            if idx not in self.data:
                parsed = self.proc_data(self.fns[idx], self.cfg.energy)
                if parsed is None:
                    # Failed parses are not cached.
                    return None
                self.data[idx] = parsed

            record = self.data[idx]
            ms = record['ms']
            smi = record['smiles']

            # These files carry no separate 'nls' input, so pass an empty list.
            return calc_feats(smi, ms, [], self.cfg)
        except Exception:
            return None

    def proc_data(self, fn, energy='Energy1'):
        """Parse the peak block tagged with ``energy`` out of file ``fn``.

        The first line containing ``energy`` starts the block; its
        second-to-last ``;``-separated field is taken as the SMILES string.
        Every following line up to the next line containing ``'END IONS'``
        must consist of exactly two space-separated numbers.

        Args:
            fn: Path of the file to read.
            energy: Marker substring identifying the header line.

        Returns:
            ``{'ms': [(mz, intn), ...], 'smiles': smi}`` on success, or
            ``None`` when the marker is absent or a line cannot be parsed.

        Raises:
            OSError: If the file cannot be opened or read.
        """
        # Context manager guarantees the handle is closed after reading.
        with open(fn) as fh:
            lines = fh.readlines()

        peaks = []
        smi = None
        in_block = False
        try:
            for line in lines:
                if energy in line:
                    smi = line.split(';')[-2]
                    in_block = True
                    continue
                if in_block:
                    if 'END IONS' in line:
                        break
                    mz, intn = line.split(' ')
                    peaks.append((float(mz), float(intn)))
        except (ValueError, IndexError, TypeError):
            # Malformed header or peak line: treat the file as unusable.
            return None

        if not in_block:
            # Marker never seen, so there is no SMILES/peak block to return.
            return None

        return {'ms': peaks, 'smiles': smi}

    def __len__(self):
        """Return the number of file paths."""
        return len(self.fns)