import os
import torch
import random
import selfies as sf
from transformers import AutoTokenizer

################################
def getrandomnumber(numbers, k, weights=None):
    # Sample k values from `numbers` (optionally weighted); return a scalar when k == 1.
    if k == 1:
        return random.choices(numbers, weights=weights, k=k)[0]
    else:
        return random.choices(numbers, weights=weights, k=k)


# simple SMILES tokenizer:
# treat every character as a token
def build_simple_smiles_vocab(dir):
    assert dir is not None, "dir can not be None."
    if not os.path.exists(os.path.join(dir, "simple_smiles_tokenizer_vocab.txt")):
        # print('Generating Vocabulary for {} ...'.format(dir))
        dirs = list(
            os.path.join(dir, i) for i in ["train.txt", "validation.txt", "test.txt"]
        )
        smiles = []
        for idir in dirs:
            with open(idir, "r") as f:
                for i, line in enumerate(f):
                    if i == 0:  # skip the header line
                        continue
                    line = line.split("\t")
                    assert len(line) == 3, "Dataset format error."
                    if line[1] != "*":
                        smiles.append(line[1].strip())
        char_set = set()
        for smi in smiles:
            for c in smi:
                char_set.add(c)
        vocabstring = "".join(char_set)
        with open(os.path.join(dir, "simple_smiles_tokenizer_vocab.txt"), "w") as f:
            f.write(vocabstring)
        return vocabstring
    else:
        print("Reading in Vocabulary...")
        with open(os.path.join(dir, "simple_smiles_tokenizer_vocab.txt"), "r") as f:
            vocabstring = f.readline().strip()
        return vocabstring
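

# Illustrative sketch (not part of the original module): the character vocabulary built
# above can back a token-to-id table like the `toktoid` mapping that Tokenizer.corrupt_one
# expects. The helper name is hypothetical; ids 0, 1 and 2 are reserved for pad, BOS and
# EOS to match the conventions used in corrupt_one below.
def build_simple_toktoid(vocabstring):
    # Offset each character id by 3 so that 0 (pad), 1 (BOS) and 2 (EOS) stay free.
    return {c: i + 3 for i, c in enumerate(vocabstring)}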


class Tokenizer:
    def __init__(
        self,
        pretrained_name="QizhiPei/biot5-base-text2mol",
        selfies_dict_path=os.path.join("dataset", "selfies_dict.txt"),
    ):
        self.tokenizer = self.get_tokenizer(pretrained_name, selfies_dict_path)

    def get_tokenizer(self, pretrained_name, selfies_dict_path):
        tokenizer = AutoTokenizer.from_pretrained(pretrained_name, use_fast=True)
        tokenizer.model_max_length = int(1e9)
        # The 20 standard amino acids are added as "<p>"-prefixed protein tokens.
        amino_acids = [
            "A",
            "C",
            "D",
            "E",
            "F",
            "G",
            "H",
            "I",
            "K",
            "L",
            "M",
            "N",
            "P",
            "Q",
            "R",
            "S",
            "T",
            "V",
            "W",
            "Y",
        ]
        prefixed_amino_acids = [f"<p>{aa}" for aa in amino_acids]
        tokenizer.add_tokens(prefixed_amino_acids)
        # Every SELFIES symbol from the external dictionary becomes an ordinary added token.
        with open(selfies_dict_path) as f:
            selfies_dict_list = [line.strip() for line in f]
        tokenizer.add_tokens(selfies_dict_list)
        special_tokens_dict = {
            "additional_special_tokens": [
                "<bom>",
                "<eom>",
                "<bop>",
                "<eop>",
                "MOLECULE NAME",
                "DESCRIPTION",
                "PROTEIN NAME",
                "FUNCTION",
                "SUBCELLULAR LOCATION",
                "PROTEIN FAMILIES",
            ]
        }
        tokenizer.add_special_tokens(special_tokens_dict)
        return tokenizer

    def __call__(self, *args, **kwds):
        return self.tokenizer(*args, **kwds)

    def __len__(self):
        return len(self.tokenizer)

    def corrupt(self, selfies_list: list):
        tensors = []
        if isinstance(selfies_list, str):
            selfies_list = [selfies_list]
        for selfies in selfies_list:
            tensors.append(self.corrupt_one(selfies))
        return torch.concat(tensors, dim=0)

    # TODO: rewrite this for selfies
    def corrupt_one(self, selfies):
        # NOTE: this method still follows the legacy SMILES pipeline and relies on
        # attributes (self.rg, self.toktoid, self.max_len) and encode_one() that are
        # not defined on this class; it will fail until the TODO above is addressed.
        smi = sf.decoder(selfies)
        # res = [self.toktoid[i] for i in self.rg.findall(smi)]
        res = [i for i in self.rg.findall(smi)]
        total_length = len(res) + 2  # +2 for the BOS/EOS ids appended below
        if total_length > self.max_len:
            return self.encode_one(smi)
        ######################## start corruption ###########################
        # Decide which token classes to corrupt: parentheses, ring digits, or both.
        r = random.random()
        if r < 0.3:
            pa, ring = True, True
        elif r < 0.65:
            pa, ring = True, False
        else:
            pa, ring = False, True
        #########################
        # Collect the positions of parentheses and ring-closure digits.
        max_ring_num = 1
        ringpos = []
        papos = []
        for pos, at in enumerate(res):
            if at == "(" or at == ")":
                papos.append(pos)
            elif at.isnumeric():
                max_ring_num = max(max_ring_num, int(at))
                ringpos.append(pos)
        # ( & ) remove / add decision
        r = random.random()
        if r < 0.3:
            remove, padd = True, True
        elif r < 0.65:
            remove, padd = True, False
        else:
            remove, padd = False, True
        if pa and len(papos) > 0:
            if remove:
                # remove parentheses
                n_remove = getrandomnumber(
                    [1, 2, 3, 4], 1, weights=[0.6, 0.2, 0.1, 0.1]
                )
                p_remove = set(random.choices(papos, weights=None, k=n_remove))
                total_length -= len(p_remove)
                for p in p_remove:
                    res[p] = None
                    # print('debug pa delete {}'.format(p))
        # ring remove / add decision
        r = random.random()
        if r < 0.3:
            remove, radd = True, True
        elif r < 0.65:
            remove, radd = True, False
        else:
            remove, radd = False, True
        if ring and len(ringpos) > 0:
            if remove:
                # remove ring digits
                n_remove = getrandomnumber(
                    [1, 2, 3, 4], 1, weights=[0.7, 0.2, 0.05, 0.05]
                )
                p_remove = set(random.choices(ringpos, weights=None, k=n_remove))
                total_length -= len(p_remove)
                for p in p_remove:
                    res[p] = None
                    # print('debug ring delete {}'.format(p))
        # ring add & ( ) add
        if pa:
            if padd:
                n_add = getrandomnumber([1, 2, 3], 1, weights=[0.8, 0.2, 0.1])
                n_add = min(self.max_len - total_length, n_add)
                for _ in range(n_add):
                    sele = random.randrange(len(res) + 1)
                    res.insert(sele, "(" if random.random() < 0.5 else ")")
                    # print('debug pa add {}'.format(sele))
                    total_length += 1
        if ring:
            if radd:
                n_add = getrandomnumber([1, 2, 3], 1, weights=[0.8, 0.2, 0.1])
                n_add = min(self.max_len - total_length, n_add)
                for _ in range(n_add):
                    sele = random.randrange(len(res) + 1)
                    res.insert(sele, str(random.randrange(1, max_ring_num + 1)))
                    # print('debug ring add {}'.format(sele))
                    total_length += 1
        ########################## end corruption ###############################
        # print('test:', res)
        # print('test:', ''.join([i for i in res if i is not None]))
        # Map surviving tokens to ids, add BOS (1) / EOS (2), then pad with 0 or truncate.
        res = [self.toktoid[i] for i in res if i is not None]
        res = [1] + res + [2]
        if len(res) < self.max_len:
            res += [0] * (self.max_len - len(res))
        else:
            res = res[: self.max_len]
            res[-1] = 2
        return torch.LongTensor([res])

    def decode_one(self, sample):
        return self.tokenizer.decode(sample)

    def decode(self, sample_list):
        if len(sample_list.shape) == 1:
            return [self.decode_one(sample_list)]
        return [self.decode_one(sample) for sample in sample_list]
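

# Sketch under stated assumptions (not part of the original module): because get_tokenizer()
# enlarges the vocabulary with add_tokens / add_special_tokens, any model that consumes these
# ids needs its embedding matrix resized to len(tokenizer). The checkpoint name is reused from
# the default above, and loading it with AutoModelForSeq2SeqLM assumes a T5-style model.
def load_resized_model(tokenizer, pretrained_name="QizhiPei/biot5-base-text2mol"):
    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_name)
    model.resize_token_embeddings(len(tokenizer))  # match the enlarged tokenizer vocabulary
    return model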


if __name__ == "__main__":
    tokenizer = Tokenizer(
        selfies_dict_path=r"D:\molecule\mol-lang-bridge\dataset\selfies_dict.txt"
    )
    smiles = [
        "[210Po]",
        "C[C@H]1C(=O)[C@H]([C@H]([C@H](O1)OP(=O)(O)OP(=O)(O)OC[C@@H]2[C@H](C[C@@H](O2)N3C=C(C(=O)NC3=O)C)O)O)O",
        "C(O)P(=O)(O)[O-]",
        "CCCCCCCCCCCC(=O)OC(=O)CCCCCCCCCCC",
        "C[C@]12CC[C@H](C[C@H]1CC[C@@H]3[C@@H]2CC[C@]4([C@H]3CCC4=O)C)O[C@H]5[C@@H]([C@H]([C@@H]([C@H](O5)C(=O)O)O)O)O",
    ]
    selfies = [sf.encoder(smiles_ele) for smiles_ele in smiles]
    output = tokenizer(
        selfies,
        max_length=512,
        truncation=True,
        padding="max_length",
        add_special_tokens=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    print(output["input_ids"])