├── .gitignore ├── API ├── __init__.py ├── aug_dataset.py ├── dataloader.py ├── dataloader_gtrans.py ├── featurizer.py ├── recorder.py ├── rpuzzles_dataset.py ├── struct2seq_dataset.py └── utils.py ├── assets ├── dataset_stat.png └── visualization_example.png ├── checkpoints ├── checkpoint.pth ├── log.log └── model_param.json ├── environment.yml ├── example.pdb ├── main.py ├── methods ├── __init__.py ├── rdesign.py └── utils.py ├── model ├── __init__.py ├── feature.py ├── module.py └── rdesign_model.py ├── parser.py ├── readme.md └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | .ipynb_checkpoints 4 | data/ 5 | results/ 6 | tmp/ 7 | wandb/ 8 | */test.ipynb 9 | fasta_pred/ -------------------------------------------------------------------------------- /API/__init__.py: -------------------------------------------------------------------------------- 1 | from .recorder import Recorder 2 | from .dataloader import load_data -------------------------------------------------------------------------------- /API/aug_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | import _pickle as cPickle 5 | 6 | import torch.utils.data as data 7 | from .utils import cached_property 8 | 9 | 10 | class AugDataset(data.Dataset): 11 | def __init__(self, path='./', mode='train', load_full_data= True): 12 | self.path = path 13 | self.mode = mode 14 | self.load_full_data = load_full_data 15 | self.data = self.cache_data[mode] 16 | 17 | @cached_property 18 | def cache_data(self): 19 | alphabet_set = set(['A', 'U', 'C', 'G']) 20 | if os.path.exists(self.path): 21 | data_dict = {'train': [], 'val': [], 'test': []} 22 | # val and test data 23 | if self.load_full_data: 24 | data_list = ['train', 'val', 'test'] 25 | else: 26 | data_list = ['test'] 27 | for split in data_list: 28 | data = cPickle.load(open(os.path.join(self.path, split + '_data.pt'), 'rb')) 29 | for entry in tqdm(data): 30 | for key, val in entry['coords'].items(): 31 | entry['coords'][key] = np.asarray(val) 32 | bad_chars = set([s for s in entry['seq']]).difference(alphabet_set) 33 | if len(bad_chars) == 0: 34 | data_dict[split].append(entry) 35 | return data_dict 36 | else: 37 | raise "no such file:{} !!!".format(self.path) 38 | 39 | def change_mode(self, mode): 40 | self.data = self.cache_data[mode] 41 | 42 | def __len__(self): 43 | return len(self.data) 44 | 45 | def __getitem__(self, index): 46 | return self.data[index] 47 | -------------------------------------------------------------------------------- /API/dataloader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from .aug_dataset import AugDataset 3 | from .dataloader_gtrans import DataLoader_GTrans 4 | from .featurizer import featurize_HC, featurize_HC_Aug 5 | 6 | 7 | def load_data(batch_size, data_root, num_workers=8, load_full_data = True, **kwargs): 8 | if load_full_data: 9 | dataset = AugDataset(data_root, mode='train') 10 | train_set, valid_set, test_set = map(lambda x: copy.deepcopy(x), [dataset] * 3) 11 | valid_set.change_mode('val') 12 | test_set.change_mode('test') 13 | 14 | train_loader = DataLoader_GTrans(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=featurize_HC_Aug) 15 | valid_loader = DataLoader_GTrans(valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=featurize_HC) 16 | test_loader = DataLoader_GTrans(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=featurize_HC) 17 | return train_loader, valid_loader, test_loader 18 | else: 19 | dataset = AugDataset(data_root, mode='train', load_full_data = False) 20 | dataset.change_mode('test') 21 | test_set = dataset 22 | 23 | test_loader = DataLoader_GTrans(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=featurize_HC) 24 | return test_loader 25 | -------------------------------------------------------------------------------- /API/dataloader_gtrans.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DataLoader_GTrans(torch.utils.data.DataLoader): 5 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, **kwargs): 6 | super(DataLoader_GTrans, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn,**kwargs) 7 | self.featurizer = collate_fn -------------------------------------------------------------------------------- /API/featurizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def find_bracket_pairs(ss, seq): 6 | pairs = [] 7 | stack = [] 8 | for i, c in enumerate(ss): 9 | if c == '(': 10 | stack.append(i) 11 | elif c == ')': 12 | if stack: 13 | pairs.append((stack.pop(), i)) 14 | else: 15 | pairs.append((None, i)) 16 | if stack: 17 | pairs.extend(zip(stack[::-1], range(i, i - len(stack), -1))) 18 | 19 | npairs = [] 20 | for pair in pairs: 21 | if None in pair: 22 | continue 23 | p_a, p_b = pair 24 | if (seq[p_a], seq[p_b]) in (('A', 'U'), ('U', 'A'), ('C', 'G'), ('G', 'C')): 25 | npairs.append(pair) 26 | return npairs 27 | 28 | def shuffle_subset(n, p): 29 | n_shuffle = np.random.binomial(n, p) 30 | ix = np.arange(n) 31 | ix_subset = np.random.choice(ix, size=n_shuffle, replace=False) 32 | ix_subset_shuffled = np.copy(ix_subset) 33 | np.random.shuffle(ix_subset_shuffled) 34 | ix[ix_subset] = ix_subset_shuffled 35 | return ix 36 | 37 | def featurize_HC(batch): 38 | alphabet = 'AUCG' 39 | B = len(batch) 40 | lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32) 41 | L_max = max([len(b['seq']) for b in batch]) 42 | X = np.zeros([B, L_max, 6, 3]) 43 | S = np.zeros([B, L_max], dtype=np.int32) 44 | clus = np.zeros([B], dtype=np.int32) 45 | ss_pos = np.zeros([B, L_max], dtype=np.int32) 46 | 47 | ss_pair = [] 48 | names = [] 49 | 50 | # Build the batch 51 | for i, b in enumerate(batch): 52 | x = np.stack([b['coords'][c] for c in ['P', 'O5\'', 'C5\'', 'C4\'', 'C3\'', 'O3\'']], 1) 53 | 54 | l = len(b['seq']) 55 | x_pad = np.pad(x, [[0, L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, )) 56 | X[i,:,:,:] = x_pad 57 | 58 | indices = np.asarray([alphabet.index(a) for a in b['seq']], dtype=np.int32) 59 | S[i, :l] = indices 60 | ss_pos[i, :l] = np.asarray([1 if ss_val!='.' else 0 for ss_val in b['ss']], dtype=np.int32) 61 | ss_pair.append(find_bracket_pairs(b['ss'], b['seq'])) 62 | names.append(b['name']) 63 | 64 | clus[i] = b['cluster'] 65 | 66 | mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32) # atom mask 67 | numbers = np.sum(mask, axis=1).astype(np.int) 68 | S_new = np.zeros_like(S) 69 | X_new = np.zeros_like(X)+np.nan 70 | for i, n in enumerate(numbers): 71 | X_new[i,:n,::] = X[i][mask[i]==1] 72 | S_new[i,:n] = S[i][mask[i]==1] 73 | 74 | X = X_new 75 | S = S_new 76 | isnan = np.isnan(X) 77 | mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32) 78 | X[isnan] = 0. 79 | # Conversion 80 | S = torch.from_numpy(S).to(dtype=torch.long) 81 | X = torch.from_numpy(X).to(dtype=torch.float32) 82 | mask = torch.from_numpy(mask).to(dtype=torch.float32) 83 | clus = torch.from_numpy(clus).to(dtype=torch.long) 84 | return X, S, mask, lengths, clus, ss_pos, ss_pair, names 85 | 86 | def featurize_HC_Aug(batch): 87 | alphabet = 'AUCG' 88 | B = len(batch) 89 | lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32) 90 | L_max = max([len(b['seq']) for b in batch]) 91 | X = np.zeros([B, L_max, 6, 3]) 92 | S = np.zeros([B, L_max], dtype=np.int32) 93 | clus = np.zeros([B], dtype=np.int32) 94 | ss_pos = np.zeros([B, L_max], dtype=np.int32) 95 | 96 | 97 | aug_idxs = [] 98 | aug_Xs = [] 99 | aug_tms = [] 100 | aug_rms = [] 101 | ss_pair = [] 102 | names = [] 103 | 104 | # Build the batch 105 | for i, b in enumerate(batch): 106 | x = np.stack([b['coords'][c] for c in ['P', 'O5\'', 'C5\'', 'C4\'', 'C3\'', 'O3\'']], 1) 107 | 108 | l = len(b['seq']) 109 | x_pad = np.pad(x, [[0, L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, )) 110 | X[i,:,:,:] = x_pad 111 | 112 | indices = np.asarray([alphabet.index(a) for a in b['seq']], dtype=np.int32) 113 | S[i, :l] = indices 114 | ss_pos[i, :l] = np.asarray([1 if ss_val!='.' else 0 for ss_val in b['ss']], dtype=np.int32) 115 | ss_pair.append(find_bracket_pairs(b['ss'], b['seq'])) 116 | names.append(b['name']) 117 | 118 | clus[i] = b['cluster'] 119 | 120 | aug_Xs.append([]) 121 | aug_tms.append([]) 122 | aug_rms.append([]) 123 | if len(batch[i]['augs']) > 0: 124 | aug_idxs.append(i) 125 | for aug_item in batch[i]['augs']: 126 | aug_x = np.stack([aug_item['coords'][c] for c in ['P', 'O5\'', 'C5\'', 'C4\'', 'C3\'', 'O3\'']], 1) 127 | aug_x_pad = np.pad(aug_x, [[0, L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, )) 128 | aug_x_pad[np.isnan(aug_x_pad)] = 0. 129 | aug_Xs[i].append(torch.from_numpy(aug_x_pad).to(dtype=torch.float32)) 130 | aug_tms[i].append(aug_item['tm-score']) 131 | aug_rms[i].append(aug_item['rmsd']) 132 | 133 | 134 | mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32) # atom mask 135 | numbers = np.sum(mask, axis=1).astype(np.int) 136 | S_new = np.zeros_like(S) 137 | X_new = np.zeros_like(X)+np.nan 138 | for i, n in enumerate(numbers): 139 | X_new[i,:n,::] = X[i][mask[i]==1] 140 | S_new[i,:n] = S[i][mask[i]==1] 141 | 142 | X = X_new 143 | S = S_new 144 | isnan = np.isnan(X) 145 | mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32) 146 | X[isnan] = 0. 147 | # Conversion 148 | S = torch.from_numpy(S).to(dtype=torch.long) 149 | X = torch.from_numpy(X).to(dtype=torch.float32) 150 | mask = torch.from_numpy(mask).to(dtype=torch.float32) 151 | clus = torch.from_numpy(clus).to(dtype=torch.long) 152 | return X, aug_Xs, aug_idxs, aug_tms, aug_rms, S, mask, lengths, clus, ss_pos, names -------------------------------------------------------------------------------- /API/recorder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Recorder: 6 | def __init__(self, verbose=False, delta=0): 7 | self.verbose = verbose 8 | self.best_score = None 9 | self.val_loss_min = np.Inf 10 | self.delta = delta 11 | 12 | def __call__(self, val_loss, model, path): 13 | score = -val_loss 14 | if self.best_score is None: 15 | self.best_score = score 16 | self.save_checkpoint(val_loss, model, path) 17 | elif score >= self.best_score + self.delta: 18 | self.best_score = score 19 | self.save_checkpoint(val_loss, model, path) 20 | 21 | def save_checkpoint(self, val_loss, model, path): 22 | if self.verbose: 23 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 24 | torch.save(model.state_dict(), path+'/'+'checkpoint.pth') 25 | self.val_loss_min = val_loss -------------------------------------------------------------------------------- /API/rpuzzles_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | import _pickle as cPickle 5 | 6 | import torch.utils.data as data 7 | from .utils import cached_property 8 | 9 | 10 | class RPuzzlesDataset(data.Dataset): 11 | def __init__(self, path='./'): 12 | self.path = path 13 | self.data = self.cache_data 14 | 15 | @cached_property 16 | def cache_data(self): 17 | alphabet_set = set(['A', 'U', 'C', 'G']) 18 | rna_puzzles_data = [] 19 | if os.path.exists(self.path): 20 | data = cPickle.load(open(os.path.join(self.path), 'rb')) 21 | for entry in tqdm(data): 22 | for key, val in entry['coords'].items(): 23 | entry['coords'][key] = np.asarray(val) 24 | bad_chars = set([s for s in entry['seq']]).difference(alphabet_set) 25 | if len(bad_chars) == 0: 26 | rna_puzzles_data.append(entry) 27 | return rna_puzzles_data 28 | else: 29 | raise "no such file:{} !!!".format(self.path) 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | def __getitem__(self, index): 35 | return self.data[index] -------------------------------------------------------------------------------- /API/struct2seq_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | import _pickle as cPickle 5 | 6 | import torch.utils.data as data 7 | from .utils import cached_property 8 | 9 | 10 | class Struct2SeqDataset(data.Dataset): 11 | def __init__(self, path='./', mode='train'): 12 | self.path = path 13 | self.mode = mode 14 | self.data = self.cache_data[mode] 15 | 16 | @cached_property 17 | def cache_data(self): 18 | alphabet_set = set(['A', 'U', 'C', 'G']) 19 | if os.path.exists(self.path): 20 | data_dict = {'train': [], 'val': [], 'test': []} 21 | for split in ['train', 'val', 'test']: 22 | data = cPickle.load(open(os.path.join(self.path, split + '_data.pt'), 'rb')) 23 | for entry in tqdm(data): 24 | for key, val in entry['coords'].items(): 25 | entry['coords'][key] = np.asarray(val) 26 | bad_chars = set([s for s in entry['seq']]).difference(alphabet_set) 27 | if len(bad_chars) == 0: 28 | data_dict[split].append(entry) 29 | return data_dict 30 | else: 31 | raise "no such file:{} !!!".format(self.path) 32 | 33 | def change_mode(self, mode): 34 | self.data = self.cache_data[mode] 35 | 36 | def __len__(self): 37 | return len(self.data) 38 | 39 | def __getitem__(self, index): 40 | return self.data[index] -------------------------------------------------------------------------------- /API/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class cached_property(object): 4 | """ 5 | Descriptor (non-data) for building an attribute on-demand on first use. 6 | """ 7 | def __init__(self, factory): 8 | """ 9 | is called such: factory(instance) to build the attribute. 10 | """ 11 | self._attr_name = factory.__name__ 12 | self._factory = factory 13 | 14 | def __get__(self, instance, owner): 15 | # Build the attribute. 16 | attr = self._factory(instance) 17 | 18 | # Cache the value; hide ourselves. 19 | setattr(instance, self._attr_name, attr) 20 | return attr -------------------------------------------------------------------------------- /assets/dataset_stat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/RDesign/12716159363cbcfc582573032c5ec2015d14d92c/assets/dataset_stat.png -------------------------------------------------------------------------------- /assets/visualization_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/RDesign/12716159363cbcfc582573032c5ec2015d14d92c/assets/visualization_example.png -------------------------------------------------------------------------------- /checkpoints/checkpoint.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/RDesign/12716159363cbcfc582573032c5ec2015d14d92c/checkpoints/checkpoint.pth -------------------------------------------------------------------------------- /checkpoints/log.log: -------------------------------------------------------------------------------- 1 | 2023-06-20 15:29:49,427 - 2 | device: cuda 3 | display_step: 10 4 | res_dir: ./results 5 | ex_name: ex_s222 6 | use_gpu: True 7 | gpu: 0 8 | seed: 222 9 | data_root: ./data/RNAsolo/ 10 | batch_size: 64 11 | num_workers: 0 12 | epoch: 200 13 | log_step: 1 14 | lr: 0.001 15 | node_feat_types: ['angle', 'distance', 'direction'] 16 | edge_feat_types: ['orientation', 'distance', 'direction'] 17 | num_encoder_layers: 3 18 | num_decoder_layers: 3 19 | hidden: 128 20 | k_neighbors: 30 21 | vocab_size: 4 22 | shuffle: 0.0 23 | dropout: 0.1 24 | smoothing: 0.1 25 | weigth_clu_con: 0.5 26 | weigth_sam_con: 0.5 27 | ss_temp: 0.5 28 | 2023-06-20 15:29:57,636 - Valid Perp: 4.0113 29 | 2023-06-20 15:29:57,637 - Epoch: 1, Steps: 28 | Train Loss: 1.4974 Train Perp: 4.4698 Valid Loss: 1.3891 Valid Perp: 4.0113 30 | 31 | 2023-06-20 15:30:04,766 - Valid Perp: 3.9455 32 | 2023-06-20 15:30:04,767 - Epoch: 2, Steps: 28 | Train Loss: 1.4208 Train Perp: 4.1403 Valid Loss: 1.3726 Valid Perp: 3.9455 33 | 34 | 2023-06-20 15:30:11,776 - Valid Perp: 3.9391 35 | 2023-06-20 15:30:11,777 - Epoch: 3, Steps: 28 | Train Loss: 1.4045 Train Perp: 4.0737 Valid Loss: 1.3709 Valid Perp: 3.9391 36 | 37 | 2023-06-20 15:30:18,706 - Valid Perp: 3.9543 38 | 2023-06-20 15:30:18,707 - Epoch: 4, Steps: 28 | Train Loss: 1.3948 Train Perp: 4.0342 Valid Loss: 1.3748 Valid Perp: 3.9543 39 | 40 | 2023-06-20 15:30:25,780 - Valid Perp: 3.9102 41 | 2023-06-20 15:30:25,780 - Epoch: 5, Steps: 28 | Train Loss: 1.3874 Train Perp: 4.0043 Valid Loss: 1.3636 Valid Perp: 3.9102 42 | 43 | 2023-06-20 15:30:33,073 - Valid Perp: 3.9071 44 | 2023-06-20 15:30:33,073 - Epoch: 6, Steps: 28 | Train Loss: 1.3831 Train Perp: 3.9871 Valid Loss: 1.3628 Valid Perp: 3.9071 45 | 46 | 2023-06-20 15:30:40,128 - Valid Perp: 3.8964 47 | 2023-06-20 15:30:40,128 - Epoch: 7, Steps: 28 | Train Loss: 1.3782 Train Perp: 3.9677 Valid Loss: 1.3601 Valid Perp: 3.8964 48 | 49 | 2023-06-20 15:30:47,058 - Valid Perp: 3.8791 50 | 2023-06-20 15:30:47,059 - Epoch: 8, Steps: 28 | Train Loss: 1.3734 Train Perp: 3.9488 Valid Loss: 1.3556 Valid Perp: 3.8791 51 | 52 | 2023-06-20 15:30:54,146 - Valid Perp: 3.9207 53 | 2023-06-20 15:30:54,146 - Epoch: 9, Steps: 28 | Train Loss: 1.3713 Train Perp: 3.9406 Valid Loss: 1.3663 Valid Perp: 3.9207 54 | 55 | 2023-06-20 15:31:01,522 - Valid Perp: 3.8854 56 | 2023-06-20 15:31:01,522 - Epoch: 10, Steps: 28 | Train Loss: 1.3665 Train Perp: 3.9216 Valid Loss: 1.3572 Valid Perp: 3.8854 57 | 58 | 2023-06-20 15:31:08,773 - Valid Perp: 3.8997 59 | 2023-06-20 15:31:08,773 - Epoch: 11, Steps: 28 | Train Loss: 1.3627 Train Perp: 3.9067 Valid Loss: 1.3609 Valid Perp: 3.8997 60 | 61 | 2023-06-20 15:31:15,739 - Valid Perp: 3.8739 62 | 2023-06-20 15:31:15,740 - Epoch: 12, Steps: 28 | Train Loss: 1.3597 Train Perp: 3.8952 Valid Loss: 1.3543 Valid Perp: 3.8739 63 | 64 | 2023-06-20 15:31:23,119 - Valid Perp: 3.8685 65 | 2023-06-20 15:31:23,120 - Epoch: 13, Steps: 28 | Train Loss: 1.3574 Train Perp: 3.8860 Valid Loss: 1.3529 Valid Perp: 3.8685 66 | 67 | 2023-06-20 15:31:30,168 - Valid Perp: 3.9257 68 | 2023-06-20 15:31:30,169 - Epoch: 14, Steps: 28 | Train Loss: 1.3575 Train Perp: 3.8864 Valid Loss: 1.3675 Valid Perp: 3.9257 69 | 70 | 2023-06-20 15:31:37,169 - Valid Perp: 3.8501 71 | 2023-06-20 15:31:37,170 - Epoch: 15, Steps: 28 | Train Loss: 1.3525 Train Perp: 3.8669 Valid Loss: 1.3481 Valid Perp: 3.8501 72 | 73 | 2023-06-20 15:31:44,152 - Valid Perp: 3.8653 74 | 2023-06-20 15:31:44,153 - Epoch: 16, Steps: 28 | Train Loss: 1.3493 Train Perp: 3.8545 Valid Loss: 1.3520 Valid Perp: 3.8653 75 | 76 | 2023-06-20 15:31:51,238 - Valid Perp: 3.8830 77 | 2023-06-20 15:31:51,238 - Epoch: 17, Steps: 28 | Train Loss: 1.3530 Train Perp: 3.8688 Valid Loss: 1.3566 Valid Perp: 3.8830 78 | 79 | 2023-06-20 15:31:58,402 - Valid Perp: 3.8601 80 | 2023-06-20 15:31:58,402 - Epoch: 18, Steps: 28 | Train Loss: 1.3427 Train Perp: 3.8292 Valid Loss: 1.3507 Valid Perp: 3.8601 81 | 82 | 2023-06-20 15:32:05,614 - Valid Perp: 3.8824 83 | 2023-06-20 15:32:05,615 - Epoch: 19, Steps: 28 | Train Loss: 1.3415 Train Perp: 3.8248 Valid Loss: 1.3565 Valid Perp: 3.8824 84 | 85 | 2023-06-20 15:32:12,764 - Valid Perp: 3.8298 86 | 2023-06-20 15:32:12,765 - Epoch: 20, Steps: 28 | Train Loss: 1.3354 Train Perp: 3.8016 Valid Loss: 1.3428 Valid Perp: 3.8298 87 | 88 | 2023-06-20 15:32:19,808 - Valid Perp: 3.8412 89 | 2023-06-20 15:32:19,808 - Epoch: 21, Steps: 28 | Train Loss: 1.3299 Train Perp: 3.7806 Valid Loss: 1.3458 Valid Perp: 3.8412 90 | 91 | 2023-06-20 15:32:26,933 - Valid Perp: 3.8222 92 | 2023-06-20 15:32:26,935 - Epoch: 22, Steps: 28 | Train Loss: 1.3282 Train Perp: 3.7741 Valid Loss: 1.3408 Valid Perp: 3.8222 93 | 94 | 2023-06-20 15:32:34,110 - Valid Perp: 3.8284 95 | 2023-06-20 15:32:34,111 - Epoch: 23, Steps: 28 | Train Loss: 1.3172 Train Perp: 3.7331 Valid Loss: 1.3424 Valid Perp: 3.8284 96 | 97 | 2023-06-20 15:32:41,229 - Valid Perp: 3.7813 98 | 2023-06-20 15:32:41,229 - Epoch: 24, Steps: 28 | Train Loss: 1.3167 Train Perp: 3.7310 Valid Loss: 1.3301 Valid Perp: 3.7813 99 | 100 | 2023-06-20 15:32:48,673 - Valid Perp: 3.8567 101 | 2023-06-20 15:32:48,673 - Epoch: 25, Steps: 28 | Train Loss: 1.3071 Train Perp: 3.6955 Valid Loss: 1.3498 Valid Perp: 3.8567 102 | 103 | 2023-06-20 15:32:55,755 - Valid Perp: 3.8285 104 | 2023-06-20 15:32:55,755 - Epoch: 26, Steps: 28 | Train Loss: 1.3056 Train Perp: 3.6900 Valid Loss: 1.3425 Valid Perp: 3.8285 105 | 106 | 2023-06-20 15:33:02,858 - Valid Perp: 3.7847 107 | 2023-06-20 15:33:02,861 - Epoch: 27, Steps: 28 | Train Loss: 1.2987 Train Perp: 3.6646 Valid Loss: 1.3310 Valid Perp: 3.7847 108 | 109 | 2023-06-20 15:33:10,345 - Valid Perp: 3.7935 110 | 2023-06-20 15:33:10,346 - Epoch: 28, Steps: 28 | Train Loss: 1.2908 Train Perp: 3.6358 Valid Loss: 1.3333 Valid Perp: 3.7935 111 | 112 | 2023-06-20 15:33:17,530 - Valid Perp: 3.8046 113 | 2023-06-20 15:33:17,530 - Epoch: 29, Steps: 28 | Train Loss: 1.2934 Train Perp: 3.6450 Valid Loss: 1.3362 Valid Perp: 3.8046 114 | 115 | 2023-06-20 15:33:24,518 - Valid Perp: 3.7515 116 | 2023-06-20 15:33:24,518 - Epoch: 30, Steps: 28 | Train Loss: 1.2806 Train Perp: 3.5989 Valid Loss: 1.3222 Valid Perp: 3.7515 117 | 118 | 2023-06-20 15:33:31,621 - Valid Perp: 3.8508 119 | 2023-06-20 15:33:31,622 - Epoch: 31, Steps: 28 | Train Loss: 1.2747 Train Perp: 3.5775 Valid Loss: 1.3483 Valid Perp: 3.8508 120 | 121 | 2023-06-20 15:33:38,945 - Valid Perp: 3.7500 122 | 2023-06-20 15:33:38,945 - Epoch: 32, Steps: 28 | Train Loss: 1.2747 Train Perp: 3.5777 Valid Loss: 1.3217 Valid Perp: 3.7500 123 | 124 | 2023-06-20 15:33:45,736 - Valid Perp: 3.7523 125 | 2023-06-20 15:33:45,739 - Epoch: 33, Steps: 28 | Train Loss: 1.2811 Train Perp: 3.6007 Valid Loss: 1.3224 Valid Perp: 3.7523 126 | 127 | 2023-06-20 15:33:52,817 - Valid Perp: 3.8763 128 | 2023-06-20 15:33:52,818 - Epoch: 34, Steps: 28 | Train Loss: 1.2600 Train Perp: 3.5256 Valid Loss: 1.3549 Valid Perp: 3.8763 129 | 130 | 2023-06-20 15:33:59,976 - Valid Perp: 3.8470 131 | 2023-06-20 15:33:59,978 - Epoch: 35, Steps: 28 | Train Loss: 1.2564 Train Perp: 3.5128 Valid Loss: 1.3473 Valid Perp: 3.8470 132 | 133 | 2023-06-20 15:34:07,293 - Valid Perp: 3.7222 134 | 2023-06-20 15:34:07,296 - Epoch: 36, Steps: 28 | Train Loss: 1.2521 Train Perp: 3.4976 Valid Loss: 1.3143 Valid Perp: 3.7222 135 | 136 | 2023-06-20 15:34:14,577 - Valid Perp: 3.7631 137 | 2023-06-20 15:34:14,578 - Epoch: 37, Steps: 28 | Train Loss: 1.2449 Train Perp: 3.4724 Valid Loss: 1.3252 Valid Perp: 3.7631 138 | 139 | 2023-06-20 15:34:21,478 - Valid Perp: 3.7425 140 | 2023-06-20 15:34:21,479 - Epoch: 38, Steps: 28 | Train Loss: 1.2273 Train Perp: 3.4121 Valid Loss: 1.3197 Valid Perp: 3.7425 141 | 142 | 2023-06-20 15:34:28,433 - Valid Perp: 3.7482 143 | 2023-06-20 15:34:28,433 - Epoch: 39, Steps: 28 | Train Loss: 1.2227 Train Perp: 3.3965 Valid Loss: 1.3213 Valid Perp: 3.7482 144 | 145 | 2023-06-20 15:34:35,560 - Valid Perp: 3.7572 146 | 2023-06-20 15:34:35,560 - Epoch: 40, Steps: 28 | Train Loss: 1.2219 Train Perp: 3.3937 Valid Loss: 1.3237 Valid Perp: 3.7572 147 | 148 | 2023-06-20 15:34:42,911 - Valid Perp: 3.7585 149 | 2023-06-20 15:34:42,911 - Epoch: 41, Steps: 28 | Train Loss: 1.2129 Train Perp: 3.3632 Valid Loss: 1.3240 Valid Perp: 3.7585 150 | 151 | 2023-06-20 15:34:49,896 - Valid Perp: 3.6944 152 | 2023-06-20 15:34:49,896 - Epoch: 42, Steps: 28 | Train Loss: 1.1992 Train Perp: 3.3176 Valid Loss: 1.3068 Valid Perp: 3.6944 153 | 154 | 2023-06-20 15:34:57,023 - Valid Perp: 3.6606 155 | 2023-06-20 15:34:57,023 - Epoch: 43, Steps: 28 | Train Loss: 1.2014 Train Perp: 3.3246 Valid Loss: 1.2976 Valid Perp: 3.6606 156 | 157 | 2023-06-20 15:35:04,129 - Valid Perp: 3.6818 158 | 2023-06-20 15:35:04,130 - Epoch: 44, Steps: 28 | Train Loss: 1.1831 Train Perp: 3.2646 Valid Loss: 1.3034 Valid Perp: 3.6818 159 | 160 | 2023-06-20 15:35:11,247 - Valid Perp: 3.6219 161 | 2023-06-20 15:35:11,248 - Epoch: 45, Steps: 28 | Train Loss: 1.1846 Train Perp: 3.2695 Valid Loss: 1.2870 Valid Perp: 3.6219 162 | 163 | 2023-06-20 15:35:18,702 - Valid Perp: 3.6168 164 | 2023-06-20 15:35:18,702 - Epoch: 46, Steps: 28 | Train Loss: 1.1805 Train Perp: 3.2559 Valid Loss: 1.2856 Valid Perp: 3.6168 165 | 166 | 2023-06-20 15:35:26,375 - Valid Perp: 3.6597 167 | 2023-06-20 15:35:26,376 - Epoch: 47, Steps: 28 | Train Loss: 1.1685 Train Perp: 3.2173 Valid Loss: 1.2974 Valid Perp: 3.6597 168 | 169 | 2023-06-20 15:35:33,761 - Valid Perp: 3.6778 170 | 2023-06-20 15:35:33,762 - Epoch: 48, Steps: 28 | Train Loss: 1.1651 Train Perp: 3.2064 Valid Loss: 1.3023 Valid Perp: 3.6778 171 | 172 | 2023-06-20 15:35:40,858 - Valid Perp: 3.6490 173 | 2023-06-20 15:35:40,859 - Epoch: 49, Steps: 28 | Train Loss: 1.1465 Train Perp: 3.1471 Valid Loss: 1.2944 Valid Perp: 3.6490 174 | 175 | 2023-06-20 15:35:48,121 - Valid Perp: 3.5758 176 | 2023-06-20 15:35:48,121 - Epoch: 50, Steps: 28 | Train Loss: 1.1444 Train Perp: 3.1406 Valid Loss: 1.2742 Valid Perp: 3.5758 177 | 178 | 2023-06-20 15:35:55,150 - Valid Perp: 3.6223 179 | 2023-06-20 15:35:55,151 - Epoch: 51, Steps: 28 | Train Loss: 1.1263 Train Perp: 3.0842 Valid Loss: 1.2871 Valid Perp: 3.6223 180 | 181 | 2023-06-20 15:36:02,124 - Valid Perp: 3.5903 182 | 2023-06-20 15:36:02,125 - Epoch: 52, Steps: 28 | Train Loss: 1.1265 Train Perp: 3.0847 Valid Loss: 1.2782 Valid Perp: 3.5903 183 | 184 | 2023-06-20 15:36:09,138 - Valid Perp: 3.6537 185 | 2023-06-20 15:36:09,138 - Epoch: 53, Steps: 28 | Train Loss: 1.1274 Train Perp: 3.0877 Valid Loss: 1.2957 Valid Perp: 3.6537 186 | 187 | 2023-06-20 15:36:16,099 - Valid Perp: 3.5999 188 | 2023-06-20 15:36:16,100 - Epoch: 54, Steps: 28 | Train Loss: 1.1027 Train Perp: 3.0124 Valid Loss: 1.2809 Valid Perp: 3.5999 189 | 190 | 2023-06-20 15:36:23,082 - Valid Perp: 3.6385 191 | 2023-06-20 15:36:23,082 - Epoch: 55, Steps: 28 | Train Loss: 1.0971 Train Perp: 2.9955 Valid Loss: 1.2916 Valid Perp: 3.6385 192 | 193 | 2023-06-20 15:36:30,082 - Valid Perp: 3.6014 194 | 2023-06-20 15:36:30,082 - Epoch: 56, Steps: 28 | Train Loss: 1.0858 Train Perp: 2.9619 Valid Loss: 1.2813 Valid Perp: 3.6014 195 | 196 | 2023-06-20 15:36:36,928 - Valid Perp: 3.6744 197 | 2023-06-20 15:36:36,931 - Epoch: 57, Steps: 28 | Train Loss: 1.0695 Train Perp: 2.9139 Valid Loss: 1.3014 Valid Perp: 3.6744 198 | 199 | 2023-06-20 15:36:43,965 - Valid Perp: 3.6513 200 | 2023-06-20 15:36:43,965 - Epoch: 58, Steps: 28 | Train Loss: 1.0569 Train Perp: 2.8775 Valid Loss: 1.2951 Valid Perp: 3.6513 201 | 202 | 2023-06-20 15:36:51,334 - Valid Perp: 3.6804 203 | 2023-06-20 15:36:51,335 - Epoch: 59, Steps: 28 | Train Loss: 1.0435 Train Perp: 2.8392 Valid Loss: 1.3030 Valid Perp: 3.6804 204 | 205 | 2023-06-20 15:36:58,474 - Valid Perp: 3.6399 206 | 2023-06-20 15:36:58,475 - Epoch: 60, Steps: 28 | Train Loss: 1.0363 Train Perp: 2.8186 Valid Loss: 1.2920 Valid Perp: 3.6399 207 | 208 | 2023-06-20 15:37:05,383 - Valid Perp: 3.6939 209 | 2023-06-20 15:37:05,383 - Epoch: 61, Steps: 28 | Train Loss: 1.0332 Train Perp: 2.8099 Valid Loss: 1.3067 Valid Perp: 3.6939 210 | 211 | 2023-06-20 15:37:12,638 - Valid Perp: 3.6957 212 | 2023-06-20 15:37:12,641 - Epoch: 62, Steps: 28 | Train Loss: 1.0134 Train Perp: 2.7550 Valid Loss: 1.3072 Valid Perp: 3.6957 213 | 214 | 2023-06-20 15:37:19,773 - Valid Perp: 3.7866 215 | 2023-06-20 15:37:19,773 - Epoch: 63, Steps: 28 | Train Loss: 0.9997 Train Perp: 2.7174 Valid Loss: 1.3315 Valid Perp: 3.7866 216 | 217 | 2023-06-20 15:37:26,782 - Valid Perp: 3.6891 218 | 2023-06-20 15:37:26,782 - Epoch: 64, Steps: 28 | Train Loss: 0.9950 Train Perp: 2.7048 Valid Loss: 1.3054 Valid Perp: 3.6891 219 | 220 | 2023-06-20 15:37:33,756 - Valid Perp: 3.6879 221 | 2023-06-20 15:37:33,757 - Epoch: 65, Steps: 28 | Train Loss: 0.9658 Train Perp: 2.6269 Valid Loss: 1.3051 Valid Perp: 3.6879 222 | 223 | 2023-06-20 15:37:40,720 - Valid Perp: 3.8919 224 | 2023-06-20 15:37:40,720 - Epoch: 66, Steps: 28 | Train Loss: 0.9605 Train Perp: 2.6130 Valid Loss: 1.3589 Valid Perp: 3.8919 225 | 226 | 2023-06-20 15:37:47,697 - Valid Perp: 3.8482 227 | 2023-06-20 15:37:47,698 - Epoch: 67, Steps: 28 | Train Loss: 0.9282 Train Perp: 2.5299 Valid Loss: 1.3476 Valid Perp: 3.8482 228 | 229 | 2023-06-20 15:37:54,788 - Valid Perp: 3.8788 230 | 2023-06-20 15:37:54,789 - Epoch: 68, Steps: 28 | Train Loss: 0.9195 Train Perp: 2.5081 Valid Loss: 1.3555 Valid Perp: 3.8788 231 | 232 | 2023-06-20 15:38:01,812 - Valid Perp: 3.8352 233 | 2023-06-20 15:38:01,813 - Epoch: 69, Steps: 28 | Train Loss: 0.9152 Train Perp: 2.4973 Valid Loss: 1.3442 Valid Perp: 3.8352 234 | 235 | 2023-06-20 15:38:09,008 - Valid Perp: 3.8573 236 | 2023-06-20 15:38:09,008 - Epoch: 70, Steps: 28 | Train Loss: 0.8844 Train Perp: 2.4215 Valid Loss: 1.3500 Valid Perp: 3.8573 237 | 238 | 2023-06-20 15:38:16,200 - Valid Perp: 3.8358 239 | 2023-06-20 15:38:16,200 - Epoch: 71, Steps: 28 | Train Loss: 0.8649 Train Perp: 2.3748 Valid Loss: 1.3444 Valid Perp: 3.8358 240 | 241 | 2023-06-20 15:38:23,614 - Valid Perp: 3.9688 242 | 2023-06-20 15:38:23,615 - Epoch: 72, Steps: 28 | Train Loss: 0.8429 Train Perp: 2.3230 Valid Loss: 1.3785 Valid Perp: 3.9688 243 | 244 | 2023-06-20 15:38:30,863 - Valid Perp: 3.9993 245 | 2023-06-20 15:38:30,865 - Epoch: 73, Steps: 28 | Train Loss: 0.8225 Train Perp: 2.2763 Valid Loss: 1.3861 Valid Perp: 3.9993 246 | 247 | 2023-06-20 15:38:38,109 - Valid Perp: 4.1163 248 | 2023-06-20 15:38:38,110 - Epoch: 74, Steps: 28 | Train Loss: 0.8114 Train Perp: 2.2511 Valid Loss: 1.4149 Valid Perp: 4.1163 249 | 250 | 2023-06-20 15:38:44,987 - Valid Perp: 4.0428 251 | 2023-06-20 15:38:44,988 - Epoch: 75, Steps: 28 | Train Loss: 0.7918 Train Perp: 2.2075 Valid Loss: 1.3969 Valid Perp: 4.0428 252 | 253 | 2023-06-20 15:38:52,030 - Valid Perp: 4.1999 254 | 2023-06-20 15:38:52,033 - Epoch: 76, Steps: 28 | Train Loss: 0.7829 Train Perp: 2.1878 Valid Loss: 1.4351 Valid Perp: 4.1999 255 | 256 | 2023-06-20 15:38:59,064 - Valid Perp: 4.0748 257 | 2023-06-20 15:38:59,065 - Epoch: 77, Steps: 28 | Train Loss: 0.7645 Train Perp: 2.1478 Valid Loss: 1.4048 Valid Perp: 4.0748 258 | 259 | 2023-06-20 15:39:06,050 - Valid Perp: 4.2477 260 | 2023-06-20 15:39:06,050 - Epoch: 78, Steps: 28 | Train Loss: 0.7475 Train Perp: 2.1117 Valid Loss: 1.4464 Valid Perp: 4.2477 261 | 262 | 2023-06-20 15:39:13,237 - Valid Perp: 4.2654 263 | 2023-06-20 15:39:13,238 - Epoch: 79, Steps: 28 | Train Loss: 0.7315 Train Perp: 2.0782 Valid Loss: 1.4505 Valid Perp: 4.2654 264 | 265 | 2023-06-20 15:39:20,257 - Valid Perp: 4.4625 266 | 2023-06-20 15:39:20,257 - Epoch: 80, Steps: 28 | Train Loss: 0.7080 Train Perp: 2.0299 Valid Loss: 1.4957 Valid Perp: 4.4625 267 | 268 | 2023-06-20 15:39:27,197 - Valid Perp: 4.3577 269 | 2023-06-20 15:39:27,197 - Epoch: 81, Steps: 28 | Train Loss: 0.7009 Train Perp: 2.0155 Valid Loss: 1.4719 Valid Perp: 4.3577 270 | 271 | 2023-06-20 15:39:34,432 - Valid Perp: 4.3383 272 | 2023-06-20 15:39:34,432 - Epoch: 82, Steps: 28 | Train Loss: 0.6853 Train Perp: 1.9845 Valid Loss: 1.4675 Valid Perp: 4.3383 273 | 274 | 2023-06-20 15:39:41,716 - Valid Perp: 4.6709 275 | 2023-06-20 15:39:41,716 - Epoch: 83, Steps: 28 | Train Loss: 0.6593 Train Perp: 1.9334 Valid Loss: 1.5413 Valid Perp: 4.6709 276 | 277 | 2023-06-20 15:39:48,744 - Valid Perp: 4.7931 278 | 2023-06-20 15:39:48,747 - Epoch: 84, Steps: 28 | Train Loss: 0.6481 Train Perp: 1.9119 Valid Loss: 1.5672 Valid Perp: 4.7931 279 | 280 | 2023-06-20 15:39:55,787 - Valid Perp: 4.7067 281 | 2023-06-20 15:39:55,787 - Epoch: 85, Steps: 28 | Train Loss: 0.6306 Train Perp: 1.8787 Valid Loss: 1.5490 Valid Perp: 4.7067 282 | 283 | 2023-06-20 15:40:02,776 - Valid Perp: 4.7553 284 | 2023-06-20 15:40:02,776 - Epoch: 86, Steps: 28 | Train Loss: 0.6132 Train Perp: 1.8464 Valid Loss: 1.5593 Valid Perp: 4.7553 285 | 286 | 2023-06-20 15:40:09,774 - Valid Perp: 4.4594 287 | 2023-06-20 15:40:09,776 - Epoch: 87, Steps: 28 | Train Loss: 0.6030 Train Perp: 1.8276 Valid Loss: 1.4950 Valid Perp: 4.4594 288 | 289 | 2023-06-20 15:40:16,875 - Valid Perp: 4.7817 290 | 2023-06-20 15:40:16,876 - Epoch: 88, Steps: 28 | Train Loss: 0.5928 Train Perp: 1.8091 Valid Loss: 1.5648 Valid Perp: 4.7817 291 | 292 | 2023-06-20 15:40:23,885 - Valid Perp: 5.0242 293 | 2023-06-20 15:40:23,886 - Epoch: 89, Steps: 28 | Train Loss: 0.5765 Train Perp: 1.7797 Valid Loss: 1.6143 Valid Perp: 5.0242 294 | 295 | 2023-06-20 15:40:31,086 - Valid Perp: 4.7411 296 | 2023-06-20 15:40:31,087 - Epoch: 90, Steps: 28 | Train Loss: 0.5648 Train Perp: 1.7591 Valid Loss: 1.5563 Valid Perp: 4.7411 297 | 298 | 2023-06-20 15:40:37,972 - Valid Perp: 5.4589 299 | 2023-06-20 15:40:37,972 - Epoch: 91, Steps: 28 | Train Loss: 0.5430 Train Perp: 1.7211 Valid Loss: 1.6972 Valid Perp: 5.4589 300 | 301 | 2023-06-20 15:40:45,181 - Valid Perp: 5.3676 302 | 2023-06-20 15:40:45,181 - Epoch: 92, Steps: 28 | Train Loss: 0.5305 Train Perp: 1.6999 Valid Loss: 1.6804 Valid Perp: 5.3676 303 | 304 | 2023-06-20 15:40:52,086 - Valid Perp: 5.4124 305 | 2023-06-20 15:40:52,086 - Epoch: 93, Steps: 28 | Train Loss: 0.5113 Train Perp: 1.6674 Valid Loss: 1.6887 Valid Perp: 5.4124 306 | 307 | 2023-06-20 15:40:59,098 - Valid Perp: 5.2817 308 | 2023-06-20 15:40:59,098 - Epoch: 94, Steps: 28 | Train Loss: 0.5003 Train Perp: 1.6492 Valid Loss: 1.6643 Valid Perp: 5.2817 309 | 310 | 2023-06-20 15:41:06,173 - Valid Perp: 5.2428 311 | 2023-06-20 15:41:06,173 - Epoch: 95, Steps: 28 | Train Loss: 0.5009 Train Perp: 1.6502 Valid Loss: 1.6569 Valid Perp: 5.2428 312 | 313 | 2023-06-20 15:41:13,364 - Valid Perp: 5.3867 314 | 2023-06-20 15:41:13,365 - Epoch: 96, Steps: 28 | Train Loss: 0.4704 Train Perp: 1.6006 Valid Loss: 1.6839 Valid Perp: 5.3867 315 | 316 | 2023-06-20 15:41:20,485 - Valid Perp: 5.2434 317 | 2023-06-20 15:41:20,486 - Epoch: 97, Steps: 28 | Train Loss: 0.4613 Train Perp: 1.5861 Valid Loss: 1.6570 Valid Perp: 5.2434 318 | 319 | 2023-06-20 15:41:27,413 - Valid Perp: 6.2728 320 | 2023-06-20 15:41:27,413 - Epoch: 98, Steps: 28 | Train Loss: 0.4517 Train Perp: 1.5709 Valid Loss: 1.8362 Valid Perp: 6.2728 321 | 322 | 2023-06-20 15:41:34,438 - Valid Perp: 6.0197 323 | 2023-06-20 15:41:34,438 - Epoch: 99, Steps: 28 | Train Loss: 0.4382 Train Perp: 1.5500 Valid Loss: 1.7950 Valid Perp: 6.0197 324 | 325 | 2023-06-20 15:41:41,550 - Valid Perp: 5.4782 326 | 2023-06-20 15:41:41,551 - Epoch: 100, Steps: 28 | Train Loss: 0.4258 Train Perp: 1.5309 Valid Loss: 1.7008 Valid Perp: 5.4782 327 | 328 | 2023-06-20 15:41:48,692 - Valid Perp: 5.9376 329 | 2023-06-20 15:41:48,748 - Epoch: 101, Steps: 28 | Train Loss: 0.4161 Train Perp: 1.5161 Valid Loss: 1.7813 Valid Perp: 5.9376 330 | 331 | 2023-06-20 15:41:55,896 - Valid Perp: 5.8198 332 | 2023-06-20 15:41:55,896 - Epoch: 102, Steps: 28 | Train Loss: 0.4141 Train Perp: 1.5130 Valid Loss: 1.7613 Valid Perp: 5.8198 333 | 334 | 2023-06-20 15:42:02,769 - Valid Perp: 6.2861 335 | 2023-06-20 15:42:02,770 - Epoch: 103, Steps: 28 | Train Loss: 0.3991 Train Perp: 1.4905 Valid Loss: 1.8383 Valid Perp: 6.2861 336 | 337 | 2023-06-20 15:42:09,769 - Valid Perp: 5.7898 338 | 2023-06-20 15:42:09,770 - Epoch: 104, Steps: 28 | Train Loss: 0.3843 Train Perp: 1.4686 Valid Loss: 1.7561 Valid Perp: 5.7898 339 | 340 | 2023-06-20 15:42:16,695 - Valid Perp: 5.8168 341 | 2023-06-20 15:42:16,695 - Epoch: 105, Steps: 28 | Train Loss: 0.3816 Train Perp: 1.4646 Valid Loss: 1.7608 Valid Perp: 5.8168 342 | 343 | 2023-06-20 15:42:23,870 - Valid Perp: 5.9182 344 | 2023-06-20 15:42:23,870 - Epoch: 106, Steps: 28 | Train Loss: 0.3614 Train Perp: 1.4353 Valid Loss: 1.7780 Valid Perp: 5.9182 345 | 346 | 2023-06-20 15:42:31,027 - Valid Perp: 7.7883 347 | 2023-06-20 15:42:31,027 - Epoch: 107, Steps: 28 | Train Loss: 0.3630 Train Perp: 1.4376 Valid Loss: 2.0526 Valid Perp: 7.7883 348 | 349 | 2023-06-20 15:42:38,165 - Valid Perp: 6.4514 350 | 2023-06-20 15:42:38,166 - Epoch: 108, Steps: 28 | Train Loss: 0.3520 Train Perp: 1.4219 Valid Loss: 1.8643 Valid Perp: 6.4514 351 | 352 | 2023-06-20 15:42:45,418 - Valid Perp: 6.6854 353 | 2023-06-20 15:42:45,419 - Epoch: 109, Steps: 28 | Train Loss: 0.3383 Train Perp: 1.4025 Valid Loss: 1.8999 Valid Perp: 6.6854 354 | 355 | 2023-06-20 15:42:52,971 - Valid Perp: 8.2627 356 | 2023-06-20 15:42:52,971 - Epoch: 110, Steps: 28 | Train Loss: 0.3289 Train Perp: 1.3895 Valid Loss: 2.1117 Valid Perp: 8.2627 357 | 358 | 2023-06-20 15:42:59,902 - Valid Perp: 7.0740 359 | 2023-06-20 15:42:59,902 - Epoch: 111, Steps: 28 | Train Loss: 0.3275 Train Perp: 1.3876 Valid Loss: 1.9564 Valid Perp: 7.0740 360 | 361 | 2023-06-20 15:43:06,875 - Valid Perp: 7.1661 362 | 2023-06-20 15:43:06,877 - Epoch: 112, Steps: 28 | Train Loss: 0.3048 Train Perp: 1.3563 Valid Loss: 1.9694 Valid Perp: 7.1661 363 | 364 | 2023-06-20 15:43:13,851 - Valid Perp: 7.2806 365 | 2023-06-20 15:43:13,852 - Epoch: 113, Steps: 28 | Train Loss: 0.3099 Train Perp: 1.3633 Valid Loss: 1.9852 Valid Perp: 7.2806 366 | 367 | 2023-06-20 15:43:20,877 - Valid Perp: 7.6165 368 | 2023-06-20 15:43:20,877 - Epoch: 114, Steps: 28 | Train Loss: 0.3000 Train Perp: 1.3498 Valid Loss: 2.0303 Valid Perp: 7.6165 369 | 370 | 2023-06-20 15:43:28,245 - Valid Perp: 7.2235 371 | 2023-06-20 15:43:28,245 - Epoch: 115, Steps: 28 | Train Loss: 0.2952 Train Perp: 1.3434 Valid Loss: 1.9773 Valid Perp: 7.2235 372 | 373 | 2023-06-20 15:43:35,430 - Valid Perp: 7.4002 374 | 2023-06-20 15:43:35,430 - Epoch: 116, Steps: 28 | Train Loss: 0.2968 Train Perp: 1.3455 Valid Loss: 2.0015 Valid Perp: 7.4002 375 | 376 | 2023-06-20 15:43:42,502 - Valid Perp: 8.0922 377 | 2023-06-20 15:43:42,502 - Epoch: 117, Steps: 28 | Train Loss: 0.2821 Train Perp: 1.3260 Valid Loss: 2.0909 Valid Perp: 8.0922 378 | 379 | 2023-06-20 15:43:49,444 - Valid Perp: 6.2877 380 | 2023-06-20 15:43:49,445 - Epoch: 118, Steps: 28 | Train Loss: 0.2637 Train Perp: 1.3018 Valid Loss: 1.8386 Valid Perp: 6.2877 381 | 382 | 2023-06-20 15:43:56,562 - Valid Perp: 7.9754 383 | 2023-06-20 15:43:56,562 - Epoch: 119, Steps: 28 | Train Loss: 0.2585 Train Perp: 1.2950 Valid Loss: 2.0764 Valid Perp: 7.9754 384 | 385 | 2023-06-20 15:44:03,557 - Valid Perp: 7.9007 386 | 2023-06-20 15:44:03,558 - Epoch: 120, Steps: 28 | Train Loss: 0.2580 Train Perp: 1.2943 Valid Loss: 2.0669 Valid Perp: 7.9007 387 | 388 | 2023-06-20 15:44:10,645 - Valid Perp: 10.4021 389 | 2023-06-20 15:44:10,647 - Epoch: 121, Steps: 28 | Train Loss: 0.2464 Train Perp: 1.2794 Valid Loss: 2.3420 Valid Perp: 10.4021 390 | 391 | 2023-06-20 15:44:17,777 - Valid Perp: 9.2737 392 | 2023-06-20 15:44:17,778 - Epoch: 122, Steps: 28 | Train Loss: 0.2454 Train Perp: 1.2781 Valid Loss: 2.2272 Valid Perp: 9.2737 393 | 394 | 2023-06-20 15:44:24,932 - Valid Perp: 7.7123 395 | 2023-06-20 15:44:24,935 - Epoch: 123, Steps: 28 | Train Loss: 0.2330 Train Perp: 1.2624 Valid Loss: 2.0428 Valid Perp: 7.7123 396 | 397 | 2023-06-20 15:44:31,902 - Valid Perp: 10.4214 398 | 2023-06-20 15:44:31,905 - Epoch: 124, Steps: 28 | Train Loss: 0.2322 Train Perp: 1.2613 Valid Loss: 2.3439 Valid Perp: 10.4214 399 | 400 | 2023-06-20 15:44:39,033 - Valid Perp: 9.4715 401 | 2023-06-20 15:44:39,034 - Epoch: 125, Steps: 28 | Train Loss: 0.2211 Train Perp: 1.2474 Valid Loss: 2.2483 Valid Perp: 9.4715 402 | 403 | 2023-06-20 15:44:45,985 - Valid Perp: 9.4634 404 | 2023-06-20 15:44:45,987 - Epoch: 126, Steps: 28 | Train Loss: 0.2183 Train Perp: 1.2439 Valid Loss: 2.2474 Valid Perp: 9.4634 405 | 406 | 2023-06-20 15:44:53,204 - Valid Perp: 9.6315 407 | 2023-06-20 15:44:53,207 - Epoch: 127, Steps: 28 | Train Loss: 0.2089 Train Perp: 1.2323 Valid Loss: 2.2650 Valid Perp: 9.6315 408 | 409 | 2023-06-20 15:45:00,324 - Valid Perp: 11.2635 410 | 2023-06-20 15:45:00,324 - Epoch: 128, Steps: 28 | Train Loss: 0.2009 Train Perp: 1.2225 Valid Loss: 2.4216 Valid Perp: 11.2635 411 | 412 | 2023-06-20 15:45:07,455 - Valid Perp: 11.4176 413 | 2023-06-20 15:45:07,458 - Epoch: 129, Steps: 28 | Train Loss: 0.2023 Train Perp: 1.2242 Valid Loss: 2.4352 Valid Perp: 11.4176 414 | 415 | 2023-06-20 15:45:14,509 - Valid Perp: 9.6132 416 | 2023-06-20 15:45:14,510 - Epoch: 130, Steps: 28 | Train Loss: 0.1969 Train Perp: 1.2176 Valid Loss: 2.2631 Valid Perp: 9.6132 417 | 418 | 2023-06-20 15:45:21,530 - Valid Perp: 8.3788 419 | 2023-06-20 15:45:21,531 - Epoch: 131, Steps: 28 | Train Loss: 0.1857 Train Perp: 1.2041 Valid Loss: 2.1257 Valid Perp: 8.3788 420 | 421 | 2023-06-20 15:45:28,651 - Valid Perp: 8.0204 422 | 2023-06-20 15:45:28,654 - Epoch: 132, Steps: 28 | Train Loss: 0.1867 Train Perp: 1.2053 Valid Loss: 2.0820 Valid Perp: 8.0204 423 | 424 | 2023-06-20 15:45:35,595 - Valid Perp: 10.9108 425 | 2023-06-20 15:45:35,596 - Epoch: 133, Steps: 28 | Train Loss: 0.1796 Train Perp: 1.1968 Valid Loss: 2.3898 Valid Perp: 10.9108 426 | 427 | 2023-06-20 15:45:42,661 - Valid Perp: 9.8892 428 | 2023-06-20 15:45:42,661 - Epoch: 134, Steps: 28 | Train Loss: 0.1698 Train Perp: 1.1851 Valid Loss: 2.2914 Valid Perp: 9.8892 429 | 430 | 2023-06-20 15:45:49,837 - Valid Perp: 11.9400 431 | 2023-06-20 15:45:49,839 - Epoch: 135, Steps: 28 | Train Loss: 0.1650 Train Perp: 1.1794 Valid Loss: 2.4799 Valid Perp: 11.9400 432 | 433 | 2023-06-20 15:45:56,844 - Valid Perp: 10.6555 434 | 2023-06-20 15:45:56,845 - Epoch: 136, Steps: 28 | Train Loss: 0.1644 Train Perp: 1.1787 Valid Loss: 2.3661 Valid Perp: 10.6555 435 | 436 | 2023-06-20 15:46:03,811 - Valid Perp: 9.8457 437 | 2023-06-20 15:46:03,814 - Epoch: 137, Steps: 28 | Train Loss: 0.1631 Train Perp: 1.1771 Valid Loss: 2.2870 Valid Perp: 9.8457 438 | 439 | 2023-06-20 15:46:10,698 - Valid Perp: 14.2152 440 | 2023-06-20 15:46:10,698 - Epoch: 138, Steps: 28 | Train Loss: 0.1623 Train Perp: 1.1762 Valid Loss: 2.6543 Valid Perp: 14.2152 441 | 442 | 2023-06-20 15:46:17,745 - Valid Perp: 13.4372 443 | 2023-06-20 15:46:17,746 - Epoch: 139, Steps: 28 | Train Loss: 0.1552 Train Perp: 1.1679 Valid Loss: 2.5980 Valid Perp: 13.4372 444 | 445 | 2023-06-20 15:46:24,750 - Valid Perp: 13.6954 446 | 2023-06-20 15:46:24,751 - Epoch: 140, Steps: 28 | Train Loss: 0.1468 Train Perp: 1.1582 Valid Loss: 2.6171 Valid Perp: 13.6954 447 | 448 | 2023-06-20 15:46:31,871 - Valid Perp: 15.0808 449 | 2023-06-20 15:46:31,871 - Epoch: 141, Steps: 28 | Train Loss: 0.1435 Train Perp: 1.1543 Valid Loss: 2.7134 Valid Perp: 15.0808 450 | 451 | 2023-06-20 15:46:39,180 - Valid Perp: 14.9466 452 | 2023-06-20 15:46:39,180 - Epoch: 142, Steps: 28 | Train Loss: 0.1446 Train Perp: 1.1555 Valid Loss: 2.7045 Valid Perp: 14.9466 453 | 454 | 2023-06-20 15:46:46,185 - Valid Perp: 14.9847 455 | 2023-06-20 15:46:46,186 - Epoch: 143, Steps: 28 | Train Loss: 0.1400 Train Perp: 1.1503 Valid Loss: 2.7070 Valid Perp: 14.9847 456 | 457 | 2023-06-20 15:46:53,344 - Valid Perp: 15.4675 458 | 2023-06-20 15:46:53,345 - Epoch: 144, Steps: 28 | Train Loss: 0.1378 Train Perp: 1.1477 Valid Loss: 2.7387 Valid Perp: 15.4675 459 | 460 | 2023-06-20 15:47:00,475 - Valid Perp: 13.0683 461 | 2023-06-20 15:47:00,475 - Epoch: 145, Steps: 28 | Train Loss: 0.1313 Train Perp: 1.1403 Valid Loss: 2.5702 Valid Perp: 13.0683 462 | 463 | 2023-06-20 15:47:07,345 - Valid Perp: 14.2797 464 | 2023-06-20 15:47:07,347 - Epoch: 146, Steps: 28 | Train Loss: 0.1295 Train Perp: 1.1383 Valid Loss: 2.6588 Valid Perp: 14.2797 465 | 466 | 2023-06-20 15:47:14,308 - Valid Perp: 14.2807 467 | 2023-06-20 15:47:14,311 - Epoch: 147, Steps: 28 | Train Loss: 0.1254 Train Perp: 1.1336 Valid Loss: 2.6589 Valid Perp: 14.2807 468 | 469 | 2023-06-20 15:47:21,305 - Valid Perp: 11.2100 470 | 2023-06-20 15:47:21,306 - Epoch: 148, Steps: 28 | Train Loss: 0.1216 Train Perp: 1.1293 Valid Loss: 2.4168 Valid Perp: 11.2100 471 | 472 | 2023-06-20 15:47:28,459 - Valid Perp: 13.3194 473 | 2023-06-20 15:47:28,461 - Epoch: 149, Steps: 28 | Train Loss: 0.1247 Train Perp: 1.1329 Valid Loss: 2.5892 Valid Perp: 13.3194 474 | 475 | 2023-06-20 15:47:35,477 - Valid Perp: 11.9576 476 | 2023-06-20 15:47:35,479 - Epoch: 150, Steps: 28 | Train Loss: 0.1186 Train Perp: 1.1260 Valid Loss: 2.4814 Valid Perp: 11.9576 477 | 478 | 2023-06-20 15:47:42,733 - Valid Perp: 14.3778 479 | 2023-06-20 15:47:42,734 - Epoch: 151, Steps: 28 | Train Loss: 0.1173 Train Perp: 1.1245 Valid Loss: 2.6657 Valid Perp: 14.3778 480 | 481 | 2023-06-20 15:47:49,956 - Valid Perp: 14.1759 482 | 2023-06-20 15:47:49,958 - Epoch: 152, Steps: 28 | Train Loss: 0.1118 Train Perp: 1.1183 Valid Loss: 2.6515 Valid Perp: 14.1759 483 | 484 | 2023-06-20 15:47:56,923 - Valid Perp: 12.7249 485 | 2023-06-20 15:47:56,923 - Epoch: 153, Steps: 28 | Train Loss: 0.1106 Train Perp: 1.1169 Valid Loss: 2.5436 Valid Perp: 12.7249 486 | 487 | 2023-06-20 15:48:04,143 - Valid Perp: 14.8805 488 | 2023-06-20 15:48:04,145 - Epoch: 154, Steps: 28 | Train Loss: 0.1057 Train Perp: 1.1115 Valid Loss: 2.7001 Valid Perp: 14.8805 489 | 490 | 2023-06-20 15:48:11,311 - Valid Perp: 14.9733 491 | 2023-06-20 15:48:11,312 - Epoch: 155, Steps: 28 | Train Loss: 0.1058 Train Perp: 1.1116 Valid Loss: 2.7063 Valid Perp: 14.9733 492 | 493 | 2023-06-20 15:48:18,653 - Valid Perp: 22.2206 494 | 2023-06-20 15:48:18,653 - Epoch: 156, Steps: 28 | Train Loss: 0.1024 Train Perp: 1.1078 Valid Loss: 3.1010 Valid Perp: 22.2206 495 | 496 | 2023-06-20 15:48:25,690 - Valid Perp: 17.0409 497 | 2023-06-20 15:48:25,691 - Epoch: 157, Steps: 28 | Train Loss: 0.1034 Train Perp: 1.1089 Valid Loss: 2.8356 Valid Perp: 17.0409 498 | 499 | 2023-06-20 15:48:32,707 - Valid Perp: 15.8458 500 | 2023-06-20 15:48:32,707 - Epoch: 158, Steps: 28 | Train Loss: 0.0959 Train Perp: 1.1007 Valid Loss: 2.7629 Valid Perp: 15.8458 501 | 502 | 2023-06-20 15:48:39,693 - Valid Perp: 16.7121 503 | 2023-06-20 15:48:39,693 - Epoch: 159, Steps: 28 | Train Loss: 0.0940 Train Perp: 1.0986 Valid Loss: 2.8161 Valid Perp: 16.7121 504 | 505 | 2023-06-20 15:48:46,623 - Valid Perp: 16.2141 506 | 2023-06-20 15:48:46,624 - Epoch: 160, Steps: 28 | Train Loss: 0.0929 Train Perp: 1.0973 Valid Loss: 2.7859 Valid Perp: 16.2141 507 | 508 | 2023-06-20 15:48:53,489 - Valid Perp: 20.1110 509 | 2023-06-20 15:48:53,491 - Epoch: 161, Steps: 28 | Train Loss: 0.0932 Train Perp: 1.0976 Valid Loss: 3.0013 Valid Perp: 20.1110 510 | 511 | 2023-06-20 15:49:00,619 - Valid Perp: 16.0534 512 | 2023-06-20 15:49:00,620 - Epoch: 162, Steps: 28 | Train Loss: 0.0878 Train Perp: 1.0918 Valid Loss: 2.7759 Valid Perp: 16.0534 513 | 514 | 2023-06-20 15:49:07,789 - Valid Perp: 20.2584 515 | 2023-06-20 15:49:07,791 - Epoch: 163, Steps: 28 | Train Loss: 0.0858 Train Perp: 1.0895 Valid Loss: 3.0086 Valid Perp: 20.2584 516 | 517 | 2023-06-20 15:49:14,789 - Valid Perp: 16.3666 518 | 2023-06-20 15:49:14,792 - Epoch: 164, Steps: 28 | Train Loss: 0.0891 Train Perp: 1.0931 Valid Loss: 2.7952 Valid Perp: 16.3666 519 | 520 | 2023-06-20 15:49:21,840 - Valid Perp: 21.6643 521 | 2023-06-20 15:49:21,840 - Epoch: 165, Steps: 28 | Train Loss: 0.0873 Train Perp: 1.0913 Valid Loss: 3.0757 Valid Perp: 21.6643 522 | 523 | 2023-06-20 15:49:29,039 - Valid Perp: 18.8727 524 | 2023-06-20 15:49:29,041 - Epoch: 166, Steps: 28 | Train Loss: 0.0835 Train Perp: 1.0871 Valid Loss: 2.9377 Valid Perp: 18.8727 525 | 526 | 2023-06-20 15:49:36,203 - Valid Perp: 15.0239 527 | 2023-06-20 15:49:36,206 - Epoch: 167, Steps: 28 | Train Loss: 0.0813 Train Perp: 1.0847 Valid Loss: 2.7096 Valid Perp: 15.0239 528 | 529 | 2023-06-20 15:49:43,256 - Valid Perp: 16.7479 530 | 2023-06-20 15:49:43,259 - Epoch: 168, Steps: 28 | Train Loss: 0.0856 Train Perp: 1.0894 Valid Loss: 2.8183 Valid Perp: 16.7479 531 | 532 | 2023-06-20 15:49:50,383 - Valid Perp: 20.1794 533 | 2023-06-20 15:49:50,386 - Epoch: 169, Steps: 28 | Train Loss: 0.0817 Train Perp: 1.0851 Valid Loss: 3.0047 Valid Perp: 20.1794 534 | 535 | 2023-06-20 15:49:57,524 - Valid Perp: 23.0520 536 | 2023-06-20 15:49:57,525 - Epoch: 170, Steps: 28 | Train Loss: 0.0749 Train Perp: 1.0778 Valid Loss: 3.1378 Valid Perp: 23.0520 537 | 538 | 2023-06-20 15:50:04,826 - Valid Perp: 19.3180 539 | 2023-06-20 15:50:04,827 - Epoch: 171, Steps: 28 | Train Loss: 0.0755 Train Perp: 1.0785 Valid Loss: 2.9610 Valid Perp: 19.3180 540 | 541 | 2023-06-20 15:50:12,113 - Valid Perp: 19.5884 542 | 2023-06-20 15:50:12,114 - Epoch: 172, Steps: 28 | Train Loss: 0.0742 Train Perp: 1.0770 Valid Loss: 2.9749 Valid Perp: 19.5884 543 | 544 | 2023-06-20 15:50:19,473 - Valid Perp: 15.8083 545 | 2023-06-20 15:50:19,477 - Epoch: 173, Steps: 28 | Train Loss: 0.0762 Train Perp: 1.0792 Valid Loss: 2.7605 Valid Perp: 15.8083 546 | 547 | 2023-06-20 15:50:26,809 - Valid Perp: 19.1729 548 | 2023-06-20 15:50:26,809 - Epoch: 174, Steps: 28 | Train Loss: 0.0756 Train Perp: 1.0785 Valid Loss: 2.9535 Valid Perp: 19.1729 549 | 550 | 2023-06-20 15:50:34,188 - Valid Perp: 21.9685 551 | 2023-06-20 15:50:34,188 - Epoch: 175, Steps: 28 | Train Loss: 0.0706 Train Perp: 1.0731 Valid Loss: 3.0896 Valid Perp: 21.9685 552 | 553 | 2023-06-20 15:50:41,419 - Valid Perp: 19.8605 554 | 2023-06-20 15:50:41,420 - Epoch: 176, Steps: 28 | Train Loss: 0.0730 Train Perp: 1.0758 Valid Loss: 2.9887 Valid Perp: 19.8605 555 | 556 | 2023-06-20 15:50:48,587 - Valid Perp: 19.9314 557 | 2023-06-20 15:50:48,587 - Epoch: 177, Steps: 28 | Train Loss: 0.0683 Train Perp: 1.0706 Valid Loss: 2.9923 Valid Perp: 19.9314 558 | 559 | 2023-06-20 15:50:55,804 - Valid Perp: 18.4600 560 | 2023-06-20 15:50:55,805 - Epoch: 178, Steps: 28 | Train Loss: 0.0691 Train Perp: 1.0716 Valid Loss: 2.9156 Valid Perp: 18.4600 561 | 562 | 2023-06-20 15:51:03,205 - Valid Perp: 18.8280 563 | 2023-06-20 15:51:03,205 - Epoch: 179, Steps: 28 | Train Loss: 0.0695 Train Perp: 1.0720 Valid Loss: 2.9353 Valid Perp: 18.8280 564 | 565 | 2023-06-20 15:51:10,335 - Valid Perp: 18.7145 566 | 2023-06-20 15:51:10,335 - Epoch: 180, Steps: 28 | Train Loss: 0.0678 Train Perp: 1.0701 Valid Loss: 2.9293 Valid Perp: 18.7145 567 | 568 | 2023-06-20 15:51:17,272 - Valid Perp: 20.9156 569 | 2023-06-20 15:51:17,273 - Epoch: 181, Steps: 28 | Train Loss: 0.0674 Train Perp: 1.0697 Valid Loss: 3.0405 Valid Perp: 20.9156 570 | 571 | 2023-06-20 15:51:24,658 - Valid Perp: 22.2330 572 | 2023-06-20 15:51:24,658 - Epoch: 182, Steps: 28 | Train Loss: 0.0659 Train Perp: 1.0682 Valid Loss: 3.1016 Valid Perp: 22.2330 573 | 574 | 2023-06-20 15:51:31,735 - Valid Perp: 20.0339 575 | 2023-06-20 15:51:31,735 - Epoch: 183, Steps: 28 | Train Loss: 0.0709 Train Perp: 1.0734 Valid Loss: 2.9974 Valid Perp: 20.0339 576 | 577 | 2023-06-20 15:51:38,705 - Valid Perp: 21.7607 578 | 2023-06-20 15:51:38,706 - Epoch: 184, Steps: 28 | Train Loss: 0.0632 Train Perp: 1.0653 Valid Loss: 3.0801 Valid Perp: 21.7607 579 | 580 | 2023-06-20 15:51:45,774 - Valid Perp: 20.1676 581 | 2023-06-20 15:51:45,774 - Epoch: 185, Steps: 28 | Train Loss: 0.0674 Train Perp: 1.0697 Valid Loss: 3.0041 Valid Perp: 20.1676 582 | 583 | 2023-06-20 15:51:52,894 - Valid Perp: 19.4263 584 | 2023-06-20 15:51:52,894 - Epoch: 186, Steps: 28 | Train Loss: 0.0701 Train Perp: 1.0726 Valid Loss: 2.9666 Valid Perp: 19.4263 585 | 586 | 2023-06-20 15:51:59,937 - Valid Perp: 21.9782 587 | 2023-06-20 15:51:59,938 - Epoch: 187, Steps: 28 | Train Loss: 0.0657 Train Perp: 1.0679 Valid Loss: 3.0900 Valid Perp: 21.9782 588 | 589 | 2023-06-20 15:52:07,224 - Valid Perp: 20.4693 590 | 2023-06-20 15:52:07,224 - Epoch: 188, Steps: 28 | Train Loss: 0.0683 Train Perp: 1.0707 Valid Loss: 3.0189 Valid Perp: 20.4693 591 | 592 | 2023-06-20 15:52:14,373 - Valid Perp: 22.1787 593 | 2023-06-20 15:52:14,375 - Epoch: 189, Steps: 28 | Train Loss: 0.0640 Train Perp: 1.0661 Valid Loss: 3.0991 Valid Perp: 22.1787 594 | 595 | 2023-06-20 15:52:21,446 - Valid Perp: 21.8519 596 | 2023-06-20 15:52:21,446 - Epoch: 190, Steps: 28 | Train Loss: 0.0627 Train Perp: 1.0647 Valid Loss: 3.0843 Valid Perp: 21.8519 597 | 598 | 2023-06-20 15:52:28,709 - Valid Perp: 21.1443 599 | 2023-06-20 15:52:28,710 - Epoch: 191, Steps: 28 | Train Loss: 0.0628 Train Perp: 1.0648 Valid Loss: 3.0514 Valid Perp: 21.1443 600 | 601 | 2023-06-20 15:52:35,725 - Valid Perp: 20.3039 602 | 2023-06-20 15:52:35,728 - Epoch: 192, Steps: 28 | Train Loss: 0.0654 Train Perp: 1.0676 Valid Loss: 3.0108 Valid Perp: 20.3039 603 | 604 | 2023-06-20 15:52:42,881 - Valid Perp: 20.4082 605 | 2023-06-20 15:52:42,882 - Epoch: 193, Steps: 28 | Train Loss: 0.0655 Train Perp: 1.0677 Valid Loss: 3.0159 Valid Perp: 20.4082 606 | 607 | 2023-06-20 15:52:49,853 - Valid Perp: 21.4385 608 | 2023-06-20 15:52:49,853 - Epoch: 194, Steps: 28 | Train Loss: 0.0634 Train Perp: 1.0655 Valid Loss: 3.0652 Valid Perp: 21.4385 609 | 610 | 2023-06-20 15:52:56,942 - Valid Perp: 21.8357 611 | 2023-06-20 15:52:56,943 - Epoch: 195, Steps: 28 | Train Loss: 0.0627 Train Perp: 1.0647 Valid Loss: 3.0835 Valid Perp: 21.8357 612 | 613 | 2023-06-20 15:53:04,019 - Valid Perp: 21.9943 614 | 2023-06-20 15:53:04,020 - Epoch: 196, Steps: 28 | Train Loss: 0.0624 Train Perp: 1.0644 Valid Loss: 3.0908 Valid Perp: 21.9943 615 | 616 | 2023-06-20 15:53:11,377 - Valid Perp: 22.1544 617 | 2023-06-20 15:53:11,380 - Epoch: 197, Steps: 28 | Train Loss: 0.0625 Train Perp: 1.0645 Valid Loss: 3.0980 Valid Perp: 22.1544 618 | 619 | 2023-06-20 15:53:18,569 - Valid Perp: 22.2445 620 | 2023-06-20 15:53:18,569 - Epoch: 198, Steps: 28 | Train Loss: 0.0619 Train Perp: 1.0639 Valid Loss: 3.1021 Valid Perp: 22.2445 621 | 622 | 2023-06-20 15:53:25,553 - Valid Perp: 22.1944 623 | 2023-06-20 15:53:25,554 - Epoch: 199, Steps: 28 | Train Loss: 0.0623 Train Perp: 1.0643 Valid Loss: 3.0998 Valid Perp: 22.1944 624 | 625 | 2023-06-20 15:53:32,862 - Valid Perp: 22.1759 626 | 2023-06-20 15:53:32,862 - Epoch: 200, Steps: 28 | Train Loss: 0.0608 Train Perp: 1.0627 Valid Loss: 3.0990 Valid Perp: 22.1759 627 | 628 | 2023-06-20 15:53:35,033 - Test Perp: 3.6372, Test Rec: 0.4000 629 | 630 | -------------------------------------------------------------------------------- /checkpoints/model_param.json: -------------------------------------------------------------------------------- 1 | {"device": "cuda", "display_step": 10, "res_dir": "./results", "ex_name": "camera_ready", "use_gpu": true, "gpu": 0, "seed": 222, "data_name": "", "data_root": "/content/RDesign/data/", "batch_size": 64, "num_workers": 0, "method": "RDesign", "config_file": "default.py", "epoch": 200, "log_step": 1, "lr": 0.001, "node_feat_types": ["angle", "distance", "direction"], "edge_feat_types": ["orientation", "distance", "direction"], "original": 0, "nat": 0, "num_encoder_layers": 3, "num_decoder_layers": 3, "hidden": 128, "k_neighbors": 30, "vocab_size": 4, "shuffle": 0.0, "dropout": 0.2, "smoothing": 0.1, "weigth_clu_con": 0.5, "weigth_sam_con": 0.5, "ss_temp": 0.5, "conf_case": 1, "aug_log_case": 1, "ss_case": 0, "wandb": 0} -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: RDesign 2 | channels: 3 | - pyg 4 | - bioconda 5 | - pytorch 6 | - http://mirrors.ustc.edu.cn/anaconda/pkgs/free 7 | - http://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge 8 | - http://mirrors.ustc.edu.cn/anaconda/pkgs/main 9 | - conda-forge 10 | - http://mirrors.ustc.edu.cn/anaconda/cloud/menpo/ 11 | - http://mirrors.ustc.edu.cn/anaconda/cloud/bioconda/ 12 | - http://mirrors.ustc.edu.cn/anaconda/cloud/msys2/ 13 | - http://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/ 14 | - http://mirrors.ustc.edu.cn/anaconda/pkgs/free/ 15 | - http://mirrors.ustc.edu.cn/anaconda/pkgs/main/ 16 | - defaults 17 | dependencies: 18 | - _libgcc_mutex=0.1=conda_forge 19 | - _openmp_mutex=4.5=1_gnu 20 | - _pytorch_select=0.1=cpu_0 21 | - alsa-lib=1.2.3.2=h166bdaf_0 22 | - argon2-cffi=21.3.0=pyhd8ed1ab_0 23 | - argon2-cffi-bindings=21.2.0=py39hb9d737c_2 24 | - asttokens=2.0.5=pyhd8ed1ab_0 25 | - attrs=22.1.0=pyh71513ae_1 26 | - backcall=0.2.0=pyh9f0ad1d_0 27 | - backports=1.0=py_2 28 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 29 | - beautifulsoup4=4.11.1=pyha770c72_0 30 | - blas=1.0=mkl 31 | - bleach=5.0.1=pyhd8ed1ab_0 32 | - boost=1.74.0=py39h5472131_4 33 | - boost-cpp=1.74.0=h359cf19_5 34 | - brotli=1.0.9=h7f98852_6 35 | - brotli-bin=1.0.9=h7f98852_6 36 | - brotlipy=0.7.0=py39hb9d737c_1004 37 | - bzip2=1.0.8=h7f98852_4 38 | - c-ares=1.18.1=h7f98852_0 39 | - ca-certificates=2022.6.15=ha878542_0 40 | - cached-property=1.5.2=hd8ed1ab_1 41 | - cached_property=1.5.2=pyha770c72_1 42 | - cairo=1.16.0=ha00ac49_1009 43 | - certifi=2022.6.15=py39hf3d152e_0 44 | - cffi=1.15.0=py39h4bc2ebd_0 45 | - cryptography=37.0.1=py39h9ce1e76_0 46 | - cudatoolkit=11.3.1=h9edb442_10 47 | - cycler=0.11.0=pyhd8ed1ab_0 48 | - dbus=1.13.6=h5008d03_3 49 | - debugpy=1.6.0=py39h5a03fae_0 50 | - decorator=5.1.1=pyhd8ed1ab_0 51 | - defusedxml=0.7.1=pyhd8ed1ab_0 52 | - entrypoints=0.4=pyhd8ed1ab_0 53 | - executing=0.8.3=pyhd8ed1ab_0 54 | - expat=2.4.8=h27087fc_0 55 | - ffmpeg=4.3=hf484d3e_0 56 | - flit-core=3.7.1=pyhd8ed1ab_0 57 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 58 | - font-ttf-inconsolata=3.000=h77eed37_0 59 | - font-ttf-source-code-pro=2.038=h77eed37_0 60 | - font-ttf-ubuntu=0.83=hab24e00_0 61 | - fontconfig=2.13.1=hba837de_1005 62 | - fonts-conda-ecosystem=1=0 63 | - fonts-conda-forge=1=0 64 | - fonttools=4.28.5=py39h3811e60_0 65 | - freetype=2.10.4=h0708190_1 66 | - gawk=5.1.0=h7f98852_0 67 | - gettext=0.19.8.1=h73d1719_1008 68 | - gmp=6.2.1=h58526e2_0 69 | - gnutls=3.6.13=h85f3911_1 70 | - greenlet=1.1.2=py39he80948d_1 71 | - gst-plugins-base=1.18.5=hf529b03_3 72 | - gstreamer=1.18.5=h9f60fe5_3 73 | - h5py=3.6.0=nompi_py39h7e08c79_100 74 | - hdf5=1.12.1=nompi_h2386368_104 75 | - icu=69.1=h9c3ff4c_0 76 | - idna=3.3=pyhd8ed1ab_0 77 | - importlib-metadata=4.11.4=py39hf3d152e_0 78 | - importlib_resources=5.9.0=pyhd8ed1ab_0 79 | - intel-openmp=2019.4=243 80 | - ipykernel=6.15.1=pyh210e3f2_0 81 | - ipython=8.4.0=py39hf3d152e_0 82 | - ipython_genutils=0.2.0=py_1 83 | - ipywidgets=7.7.1=pyhd8ed1ab_0 84 | - jbig=2.1=h7f98852_2003 85 | - jedi=0.18.1=py39hf3d152e_1 86 | - jinja2=3.1.2=pyhd8ed1ab_1 87 | - joblib=1.1.0=pyhd8ed1ab_0 88 | - jpeg=9d=h36c2ea0_0 89 | - jsonschema=4.9.1=pyhd8ed1ab_0 90 | - jupyter=1.0.0=py39hf3d152e_7 91 | - jupyter_client=7.2.2=pyhd8ed1ab_1 92 | - jupyter_console=6.4.4=pyhd8ed1ab_0 93 | - jupyter_core=4.11.0=py39hf3d152e_0 94 | - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0 95 | - jupyterlab_widgets=1.1.1=pyhd8ed1ab_0 96 | - keyutils=1.6.1=h166bdaf_0 97 | - kiwisolver=1.3.2=py39h1a9c180_1 98 | - krb5=1.19.3=h3790be6_0 99 | - lame=3.100=h7f98852_1001 100 | - lcms2=2.12=hddcbb42_0 101 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 102 | - lerc=3.0=h9c3ff4c_0 103 | - libblas=3.9.0=1_h86c2bf4_netlib 104 | - libbrotlicommon=1.0.9=h7f98852_6 105 | - libbrotlidec=1.0.9=h7f98852_6 106 | - libbrotlienc=1.0.9=h7f98852_6 107 | - libcblas=3.9.0=5_h92ddd45_netlib 108 | - libclang=13.0.1=default_hc23dcda_0 109 | - libcurl=7.83.1=h7bff187_0 110 | - libdeflate=1.8=h7f98852_0 111 | - libedit=3.1.20191231=he28a2e2_2 112 | - libev=4.33=h516909a_1 113 | - libevent=2.1.10=h9b69904_4 114 | - libffi=3.4.2=h7f98852_5 115 | - libgcc-ng=11.2.0=h1d223b6_11 116 | - libgfortran-ng=11.2.0=h69a702a_11 117 | - libgfortran5=11.2.0=h5c6108e_11 118 | - libglib=2.70.2=h174f98d_1 119 | - libgomp=11.2.0=h1d223b6_11 120 | - libiconv=1.16=h516909a_0 121 | - libidn2=2.3.2=h7f8727e_0 122 | - liblapack=3.9.0=5_h92ddd45_netlib 123 | - libllvm13=13.0.1=hf817b99_2 124 | - libmklml=2019.0.5=h06a4308_0 125 | - libnghttp2=1.47.0=h727a467_0 126 | - libnsl=2.0.0=h7f98852_0 127 | - libogg=1.3.4=h7f98852_1 128 | - libopenblas=0.3.18=pthreads_h8fe5266_0 129 | - libopus=1.3.1=h7f98852_1 130 | - libpng=1.6.37=h21135ba_2 131 | - libpq=14.3=hd77ab85_0 132 | - libsodium=1.0.18=h36c2ea0_1 133 | - libssh2=1.10.0=ha56f1ee_2 134 | - libstdcxx-ng=11.2.0=he4da1e4_11 135 | - libtiff=4.3.0=h6f004c6_2 136 | - libunistring=0.9.10=h7f98852_0 137 | - libuuid=2.32.1=h7f98852_1000 138 | - libuv=1.43.0=h7f98852_0 139 | - libvorbis=1.3.7=h9c3ff4c_0 140 | - libwebp-base=1.2.1=h7f98852_0 141 | - libxcb=1.13=h7f98852_1004 142 | - libxkbcommon=1.0.3=he3ba5ed_0 143 | - libxml2=2.9.12=h885dcf4_1 144 | - libzlib=1.2.11=h36c2ea0_1013 145 | - lz4-c=1.9.3=h9c3ff4c_1 146 | - markupsafe=2.1.1=py39hb9d737c_1 147 | - matplotlib-base=3.5.1=py39h2fa2bec_0 148 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 149 | - mistune=0.8.4=py39h3811e60_1005 150 | - mkl=2020.2=256 151 | - mmseqs2=13.45111=pl5321hf1761c0_2 152 | - munkres=1.1.4=pyh9f0ad1d_0 153 | - mysql-common=8.0.28=ha770c72_0 154 | - mysql-libs=8.0.28=hfa10184_0 155 | - nbclient=0.6.6=pyhd8ed1ab_0 156 | - nbconvert=6.5.0=pyhd8ed1ab_0 157 | - nbconvert-core=6.5.0=pyhd8ed1ab_0 158 | - nbconvert-pandoc=6.5.0=pyhd8ed1ab_0 159 | - nbformat=5.4.0=pyhd8ed1ab_0 160 | - ncurses=6.2=h58526e2_4 161 | - nest-asyncio=1.5.5=pyhd8ed1ab_0 162 | - nettle=3.6=he412f7d_0 163 | - ninja=1.10.2=h4bd325d_1 164 | - notebook=6.4.12=pyha770c72_0 165 | - nspr=4.32=h9c3ff4c_1 166 | - nss=3.74=hb5efdd6_0 167 | - numpy=1.22.0=py39h91f2184_0 168 | - olefile=0.46=pyh9f0ad1d_1 169 | - openh264=2.1.1=h780b84a_0 170 | - openjpeg=2.4.0=hb52868f_1 171 | - openssl=1.1.1o=h166bdaf_0 172 | - packaging=21.3=pyhd8ed1ab_0 173 | - pandas=1.3.5=py39hde0f152_0 174 | - pandoc=2.19=ha770c72_0 175 | - pandocfilters=1.5.0=pyhd8ed1ab_0 176 | - parso=0.8.3=pyhd8ed1ab_0 177 | - pcre=8.45=h9c3ff4c_0 178 | - perl=5.32.1=2_h7f98852_perl5 179 | - pexpect=4.8.0=pyh9f0ad1d_2 180 | - pickleshare=0.7.5=py_1003 181 | - pixman=0.40.0=h36c2ea0_0 182 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0 183 | - prometheus_client=0.14.1=pyhd8ed1ab_0 184 | - prompt-toolkit=3.0.30=pyha770c72_0 185 | - prompt_toolkit=3.0.30=hd8ed1ab_0 186 | - psutil=5.9.1=py39hb9d737c_0 187 | - pthread-stubs=0.4=h36c2ea0_1001 188 | - ptyprocess=0.7.0=pyhd3deb0d_0 189 | - pure_eval=0.2.2=pyhd8ed1ab_0 190 | - pycairo=1.20.1=py39hedcb9fc_1 191 | - pycparser=2.21=pyhd8ed1ab_0 192 | - pyg=2.0.4=py39_torch_1.11.0_cu113 193 | - pygments=2.12.0=pyhd8ed1ab_0 194 | - pyopenssl=22.0.0=pyhd8ed1ab_0 195 | - pyparsing=3.0.6=pyhd8ed1ab_0 196 | - pyqt=5.12.3=py39h03dd644_4 197 | - pyrsistent=0.18.1=py39hb9d737c_1 198 | - pysocks=1.7.1=py39hf3d152e_5 199 | - python=3.9.9=h62f1059_0_cpython 200 | - python-dateutil=2.8.2=pyhd8ed1ab_0 201 | - python-fastjsonschema=2.16.1=pyhd8ed1ab_0 202 | - python-louvain=0.15=pyhd8ed1ab_1 203 | - python_abi=3.9=2_cp39 204 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 205 | - pytorch-cluster=1.6.0=py39_torch_1.11.0_cu113 206 | - pytorch-mutex=1.0=cuda 207 | - pytorch-scatter=2.0.9=py39_torch_1.11.0_cu113 208 | - pytorch-sparse=0.6.14=py39_torch_1.11.0_cu113 209 | - pytorch-spline-conv=1.2.1=py39_torch_1.11.0_cu113 210 | - pytz=2021.3=pyhd8ed1ab_0 211 | - pyyaml=6.0=py39hb9d737c_4 212 | - pyzmq=22.3.0=py39headdf64_2 213 | - qt=5.12.9=ha98a1a1_5 214 | - qtconsole=5.3.1=pyhd8ed1ab_0 215 | - qtconsole-base=5.3.1=pyha770c72_0 216 | - qtpy=2.1.0=pyhd8ed1ab_0 217 | - rdkit=2021.09.4=py39hccf6a74_0 218 | - readline=8.1=h46c0cb4_0 219 | - reportlab=3.5.68=py39he59360d_1 220 | - send2trash=1.8.0=pyhd8ed1ab_0 221 | - setuptools=59.8.0=py39hf3d152e_0 222 | - six=1.16.0=pyh6c4a22f_0 223 | - soupsieve=2.3.2.post1=pyhd8ed1ab_0 224 | - sqlalchemy=1.4.29=py39h3811e60_0 225 | - sqlite=3.37.0=h9cd32fc_0 226 | - stack_data=0.3.0=pyhd8ed1ab_0 227 | - terminado=0.15.0=py39hf3d152e_0 228 | - tinycss2=1.1.1=pyhd8ed1ab_0 229 | - tk=8.6.11=h27826a3_1 230 | - torchaudio=0.11.0=py39_cu113 231 | - torchvision=0.12.0=py39_cu113 232 | - tornado=6.1=py39hb9d737c_3 233 | - traitlets=5.3.0=pyhd8ed1ab_0 234 | - typing-extensions=4.0.1=hd8ed1ab_0 235 | - typing_extensions=4.0.1=pyha770c72_0 236 | - tzdata=2021e=he74cb21_0 237 | - wcwidth=0.2.5=pyh9f0ad1d_2 238 | - webencodings=0.5.1=py_1 239 | - wget=1.20.3=ha56f1ee_1 240 | - wheel=0.37.1=pyhd8ed1ab_0 241 | - widgetsnbextension=3.6.1=pyha770c72_0 242 | - xorg-kbproto=1.0.7=h7f98852_1002 243 | - xorg-libice=1.0.10=h7f98852_0 244 | - xorg-libsm=1.2.3=hd9c2040_1000 245 | - xorg-libx11=1.7.2=h7f98852_0 246 | - xorg-libxau=1.0.9=h7f98852_0 247 | - xorg-libxdmcp=1.1.3=h7f98852_0 248 | - xorg-libxext=1.3.4=h7f98852_1 249 | - xorg-libxrender=0.9.10=h7f98852_1003 250 | - xorg-renderproto=0.11.1=h7f98852_1002 251 | - xorg-xextproto=7.3.0=h7f98852_1002 252 | - xorg-xproto=7.0.31=h7f98852_1007 253 | - xz=5.2.5=h516909a_1 254 | - yacs=0.1.8=pyhd8ed1ab_0 255 | - yaml=0.2.5=h7f98852_2 256 | - zeromq=4.3.4=h9c3ff4c_1 257 | - zipp=3.8.1=pyhd8ed1ab_0 258 | - zlib=1.2.11=h36c2ea0_1013 259 | - zstd=1.5.1=ha95c52a_0 260 | - pip: 261 | - addict==2.4.0 262 | - antlr4-python3-runtime==4.9.3 263 | - appdirs==1.4.4 264 | - astor==0.8.1 265 | - autopep8==1.6.0 266 | - biopython==1.80 267 | - captum==0.2.0 268 | - cftime==1.6.1 269 | - charset-normalizer==2.0.10 270 | - cilog==1.2.3 271 | - click==8.1.3 272 | - cloudpickle==2.0.0 273 | - colorama==0.4.4 274 | - contextlib2==21.6.0 275 | - dask==2022.6.1 276 | - dill==0.3.4 277 | - dive-into-graphs==1.0.0 278 | - docker-pycreds==0.4.0 279 | - et-xmlfile==1.1.0 280 | - filelock==3.4.2 281 | - focal-frequency-loss==0.3.0 282 | - fsspec==2022.5.0 283 | - future==0.18.2 284 | - fvcore==0.1.5.post20220506 285 | - gitdb==4.0.10 286 | - gitpython==3.1.30 287 | - h5netcdf==1.0.1 288 | - hickle==5.0.0.dev0 289 | - hydra-core==1.2.0 290 | - hyperopt==0.1.2 291 | - imageio==2.16.1 292 | - iopath==0.1.9 293 | - json-tricks==3.15.5 294 | - llvmlite==0.39.0 295 | - locket==1.0.0 296 | - lpips==0.1.4 297 | - mmcv-full==1.4.8 298 | - mock==4.0.3 299 | - mpmath==1.2.1 300 | - mypy-extensions==0.4.3 301 | - netcdf4==1.6.0 302 | - networkx==2.6.3 303 | - nni==2.6.1 304 | - numba==0.56.0 305 | - omegaconf==2.2.2 306 | - opencv-python==4.5.5.64 307 | - opencv-python-headless==4.5.5.64 308 | - openpyxl==3.0.10 309 | - partd==1.2.0 310 | - pathtools==0.1.2 311 | - pillow==9.0.1 312 | - pip==22.0.4 313 | - portalocker==2.4.0 314 | - prettytable==3.0.0 315 | - protobuf==4.21.12 316 | - ptflops==0.6.8 317 | - pycodestyle==2.8.0 318 | - pymongo==4.0.1 319 | - pyqt5-sip==4.19.18 320 | - pyqtchart==5.12 321 | - pyqtwebengine==5.12.1 322 | - pyscf==1.7.6.post1 323 | - pythonwebhdfs==0.2.3 324 | - pywavelets==1.3.0 325 | - rdkit-pypi==2022.3.4 326 | - requests==2.27.1 327 | - responses==0.17.0 328 | - schema==0.7.5 329 | - scikit-image==0.16.2 330 | - scikit-learn==1.0.2 331 | - scipy==1.8.0 332 | - sentry-sdk==1.15.0 333 | - setproctitle==1.3.2 334 | - shap==0.41.0 335 | - simplejson==3.17.6 336 | - slicer==0.0.7 337 | - smmap==5.0.0 338 | - subword-nmt==0.3.8 339 | - sympy==1.10.1 340 | - tabulate==0.8.9 341 | - termcolor==1.1.0 342 | - threadpoolctl==3.0.0 343 | - timm==0.5.4 344 | - toml==0.10.2 345 | - toolz==0.11.2 346 | - tqdm==4.62.3 347 | - typed-argument-parser==1.7.2 348 | - typeguard==2.13.3 349 | - typing-inspect==0.7.1 350 | - urllib3==1.26.8 351 | - wandb==0.13.10 352 | - websockets==10.1 353 | - xarray==2022.3.0 354 | - yapf==0.32.0 355 | prefix: /root/anaconda3/envs/RDesign 356 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | import json 4 | import torch 5 | import os.path as osp 6 | from parser import create_parser 7 | 8 | import warnings 9 | warnings.filterwarnings('ignore') 10 | 11 | from methods import RDesign 12 | from API import Recorder 13 | from utils import * 14 | 15 | 16 | class Exp: 17 | def __init__(self, args): 18 | self.args = args 19 | self.config = args.__dict__ 20 | self.device = self._acquire_device() 21 | self.total_step = 0 22 | self._preparation() 23 | print_log(output_namespace(self.args)) 24 | 25 | def _acquire_device(self): 26 | if self.args.use_gpu: device = torch.device('cuda:0') 27 | else: device = torch.device('cpu') 28 | return device 29 | 30 | def _preparation(self): 31 | set_seed(self.args.seed) 32 | # log and checkpoint 33 | self.path = osp.join(self.args.res_dir, self.args.ex_name) 34 | check_dir(self.path) 35 | 36 | self.checkpoints_path = osp.join(self.path, 'checkpoints') 37 | check_dir(self.checkpoints_path) 38 | 39 | sv_param = osp.join(self.path, 'model_param.json') 40 | with open(sv_param, 'w') as file_obj: 41 | json.dump(self.args.__dict__, file_obj) 42 | 43 | for handler in logging.root.handlers[:]: 44 | logging.root.removeHandler(handler) 45 | logging.basicConfig(level=logging.INFO, filename=osp.join(self.path, 'log.log'), 46 | filemode='a', format='%(asctime)s - %(message)s') 47 | # prepare data, only using self._get_data() for training, otherwise just comment it 48 | self._get_data() 49 | # build the method 50 | self._build_method() 51 | 52 | # def _build_method(self): 53 | # steps_per_epoch = 1 54 | # # If training, uncomment next line 55 | # # steps_per_epoch = len(self.train_loader) 56 | # self.method = RDesign(self.args, self.device, steps_per_epoch) 57 | 58 | # def _get_data(self): 59 | # self.train_loader, self.valid_loader, self.test_loader = get_dataset(self.config) 60 | def _build_method(self): 61 | if self.args.load_full_data: 62 | steps_per_epoch = len(self.train_loader) 63 | else: 64 | steps_per_epoch = 1 65 | self.method = RDesign(self.args, self.device, steps_per_epoch) 66 | 67 | def _get_data(self): 68 | if self.args.load_full_data: 69 | self.train_loader, self.valid_loader, self.test_loader = get_dataset(self.config) 70 | else: 71 | self.test_loader = get_dataset(self.config) 72 | 73 | def _save(self, name=''): 74 | torch.save(self.method.model.state_dict(), osp.join(self.checkpoints_path, name + '.pth')) 75 | fw = open(osp.join(self.checkpoints_path, name + '.pkl'), 'wb') 76 | state = self.method.scheduler.state_dict() 77 | pickle.dump(state, fw) 78 | 79 | def _load(self, epoch): 80 | self.method.model.load_state_dict(torch.load(osp.join(self.checkpoints_path, str(epoch) + '.pth'))) 81 | fw = open(osp.join(self.checkpoints_path, str(epoch) + '.pkl'), 'rb') 82 | state = pickle.load(fw) 83 | self.method.scheduler.load_state_dict(state) 84 | 85 | def test(self): 86 | test_perplexity, test_recovery = self.method.test_one_epoch(self.test_loader) 87 | print_log('Test Perp: {0:.4f}, Test Rec: {1:.4f}\n'.format(test_perplexity, test_recovery)) 88 | return test_perplexity, test_recovery 89 | 90 | 91 | if __name__ == '__main__': 92 | args = create_parser() 93 | config = args.__dict__ 94 | 95 | exp = Exp(args) 96 | exp.method.model.load_state_dict(torch.load('checkpoints/checkpoint.pth')) 97 | print('>>>>>>>>>>>>>>>>>>>>>>>>>> testing <<<<<<<<<<<<<<<<<<<<<<<<<<') 98 | test_perp, test_rec = exp.test() 99 | -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .rdesign import RDesign -------------------------------------------------------------------------------- /methods/rdesign.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from .utils import cuda, loss_nll_flatten 5 | from model import RDesign_Model 6 | import torch.nn.functional as F 7 | from sklearn.metrics import precision_recall_fscore_support 8 | 9 | 10 | alphabet = 'AUCG' 11 | pre_base_pairs = {0: 1, 1: 0, 2: 3, 3: 2} 12 | pre_great_pairs = ((0, 1), (1, 0), (2, 3), (3, 2)) 13 | 14 | class RDesign: 15 | def __init__(self, args, device, steps_per_epoch): 16 | self.args = args 17 | self.device = device 18 | self.config = args.__dict__ 19 | 20 | self.model = self._build_model() 21 | 22 | def _build_model(self): 23 | return RDesign_Model(self.args).to(self.device) 24 | 25 | def _cal_recovery(self, dataset, featurizer): 26 | recovery = [] 27 | S_preds, S_trues = [], [] 28 | for sample in tqdm(dataset): 29 | sample = featurizer([sample]) 30 | X, S, mask, lengths, clus, ss_pos, ss_pair, names = sample 31 | X, S, mask, ss_pos = cuda((X, S, mask, ss_pos), device=self.device) 32 | logits, gt_S = self.model.sample(X=X, S=S, mask=mask) 33 | log_probs = F.log_softmax(logits, dim=-1) 34 | 35 | # secondary sharpen 36 | ss_pos = ss_pos[mask == 1].long() 37 | log_probs = log_probs.clone() 38 | log_probs[ss_pos] = log_probs[ss_pos] / self.args.ss_temp 39 | S_pred = torch.argmax(log_probs, dim=1) 40 | 41 | pos_log_probs = log_probs.softmax(-1) 42 | for pair in ss_pair[0]: 43 | s_pos_a, s_pos_b = pair 44 | if s_pos_a == None or s_pos_b == None or s_pos_b >= S_pred.shape[0]: 45 | continue 46 | if (S_pred[s_pos_a].item(), S_pred[s_pos_b].item()) in pre_great_pairs: 47 | continue 48 | 49 | if pos_log_probs[s_pos_a][S_pred[s_pos_a]] > pos_log_probs[s_pos_b][S_pred[s_pos_b]]: 50 | S_pred[s_pos_b] = pre_base_pairs[S_pred[s_pos_a].item()] 51 | elif pos_log_probs[s_pos_a][S_pred[s_pos_a]] < pos_log_probs[s_pos_b][S_pred[s_pos_b]]: 52 | S_pred[s_pos_a] = pre_base_pairs[S_pred[s_pos_b].item()] 53 | 54 | cmp = S_pred.eq(gt_S) 55 | recovery_ = cmp.float().mean().cpu().numpy() 56 | S_preds += S_pred.cpu().numpy().tolist() 57 | S_trues += gt_S.cpu().numpy().tolist() 58 | if np.isnan(recovery_): recovery_ = 0.0 59 | recovery.append(recovery_) 60 | recovery = np.median(recovery) 61 | precision, recall, f1, _ = precision_recall_fscore_support(S_trues, S_preds, average=None) 62 | macro_f1 = f1.mean() 63 | print('macro f1', macro_f1) 64 | return recovery 65 | 66 | def valid_one_epoch(self, valid_loader): 67 | self.model.eval() 68 | with torch.no_grad(): 69 | valid_sum, valid_weights = 0., 0. 70 | valid_pbar = tqdm(valid_loader) 71 | for batch in valid_pbar: 72 | X, S, mask, lengths, clus, ss_pos, ss_pair, names = batch 73 | X, S, mask, lengths, clus, ss_pos = cuda((X, S, mask, lengths, clus, ss_pos), device=self.device) 74 | logits, S, _ = self.model(X, S, mask) 75 | 76 | log_probs = F.log_softmax(logits, dim=-1) 77 | loss, _ = loss_nll_flatten(S, log_probs) 78 | 79 | valid_sum += torch.sum(loss).cpu().data.numpy() 80 | valid_weights += len(loss) 81 | valid_pbar.set_description('valid loss: {:.4f}'.format(loss.mean().item())) 82 | 83 | valid_loss = valid_sum / valid_weights 84 | valid_perplexity = np.exp(valid_loss) 85 | return valid_loss, valid_perplexity 86 | 87 | def test_one_epoch(self, test_loader): 88 | self.model.eval() 89 | with torch.no_grad(): 90 | test_sum, test_weights = 0., 0. 91 | test_pbar = tqdm(test_loader) 92 | for batch in test_pbar: 93 | X, S, mask, lengths, clus, ss_pos, ss_pair, names = batch 94 | X, S, mask, lengths, clus, ss_pos = cuda((X, S, mask, lengths, clus, ss_pos), device=self.device) 95 | logits, S, _ = self.model(X, S, mask) 96 | 97 | log_probs = F.log_softmax(logits, dim=-1) 98 | loss, _ = loss_nll_flatten(S, log_probs) 99 | 100 | test_sum += torch.sum(loss).cpu().data.numpy() 101 | test_weights += len(loss) 102 | test_pbar.set_description('test loss: {:.4f}'.format(loss.mean().item())) 103 | 104 | test_recovery = self._cal_recovery(test_loader.dataset, test_loader.featurizer) 105 | 106 | test_loss = test_sum / test_weights 107 | test_perplexity = np.exp(test_loss) 108 | return test_perplexity, test_recovery -------------------------------------------------------------------------------- /methods/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from collections.abc import Mapping, Sequence 4 | 5 | 6 | def cuda(obj, *args, **kwargs): 7 | """ 8 | Transfer any nested conatiner of tensors to CUDA. 9 | """ 10 | if hasattr(obj, "cuda"): 11 | return obj.cuda(*args, **kwargs) 12 | elif isinstance(obj, Mapping): 13 | return type(obj)({k: cuda(v, *args, **kwargs) for k, v in obj.items()}) 14 | elif isinstance(obj, Sequence): 15 | return type(obj)(cuda(x, *args, **kwargs) for x in obj) 16 | elif isinstance(obj, np.ndarray): 17 | return torch.tensor(obj, *args, **kwargs) 18 | raise TypeError("Can't transfer object type `%s`" % type(obj)) 19 | 20 | def loss_nll_flatten(S, log_probs): 21 | """ Negative log probabilities """ 22 | criterion = torch.nn.NLLLoss(reduction='none') 23 | loss = criterion(log_probs, S) 24 | loss_av = loss.mean() 25 | return loss, loss_av -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .rdesign_model import RDesign_Model -------------------------------------------------------------------------------- /model/feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .module import gather_nodes, gather_edges, Normalize 5 | 6 | 7 | feat_dims = { 8 | 'node': { 9 | 'angle': 12, 10 | 'distance': 80, 11 | 'direction': 9, 12 | }, 13 | 'edge': { 14 | 'orientation': 4, 15 | 'distance': 96, 16 | 'direction': 15, 17 | } 18 | } 19 | 20 | 21 | def nan_to_num(tensor, nan=0.0): 22 | idx = torch.isnan(tensor) 23 | tensor[idx] = nan 24 | return tensor 25 | 26 | def _normalize(tensor, dim=-1): 27 | return nan_to_num( 28 | torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) 29 | 30 | 31 | class RNAFeatures(nn.Module): 32 | def __init__(self, edge_features, node_features, node_feat_types=[], edge_feat_types=[], num_rbf=16, top_k=30, augment_eps=0., dropout=0.1, args=None): 33 | super(RNAFeatures, self).__init__() 34 | """Extract RNA Features""" 35 | self.edge_features = edge_features 36 | self.node_features = node_features 37 | self.top_k = top_k 38 | self.augment_eps = augment_eps 39 | self.num_rbf = num_rbf 40 | self.dropout = nn.Dropout(dropout) 41 | self.node_feat_types = node_feat_types 42 | self.edge_feat_types = edge_feat_types 43 | 44 | node_in = sum([feat_dims['node'][feat] for feat in node_feat_types]) 45 | edge_in = sum([feat_dims['edge'][feat] for feat in edge_feat_types]) 46 | self.node_embedding = nn.Linear(node_in, node_features, bias=True) 47 | self.edge_embedding = nn.Linear(edge_in, edge_features, bias=True) 48 | self.norm_nodes = Normalize(node_features) 49 | self.norm_edges = Normalize(edge_features) 50 | 51 | def _dist(self, X, mask, eps=1E-6): 52 | mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2) 53 | dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2) 54 | D = (1. - mask_2D)*10000 + mask_2D* torch.sqrt(torch.sum(dX**2, 3) + eps) 55 | 56 | D_max, _ = torch.max(D, -1, keepdim=True) 57 | D_adjust = D + (1. - mask_2D) * (D_max+1) 58 | D_neighbors, E_idx = torch.topk(D_adjust, min(self.top_k, D_adjust.shape[-1]), dim=-1, largest=False) 59 | return D_neighbors, E_idx 60 | 61 | def _rbf(self, D): 62 | D_min, D_max, D_count = 0., 20., self.num_rbf 63 | D_mu = torch.linspace(D_min, D_max, D_count, device=D.device) 64 | D_mu = D_mu.view([1,1,1,-1]) 65 | D_sigma = (D_max - D_min) / D_count 66 | D_expand = torch.unsqueeze(D, -1) 67 | return torch.exp(-((D_expand - D_mu) / D_sigma)**2) 68 | 69 | def _get_rbf(self, A, B, E_idx=None, num_rbf=16): 70 | if E_idx is not None: 71 | D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,None,:,:])**2,-1) + 1e-6) 72 | D_A_B_neighbors = gather_edges(D_A_B[:,:,:,None], E_idx)[:,:,:,0] 73 | RBF_A_B = self._rbf(D_A_B_neighbors) 74 | else: 75 | D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,:,None,:])**2,-1) + 1e-6) 76 | RBF_A_B = self._rbf(D_A_B) 77 | return RBF_A_B 78 | 79 | def _quaternions(self, R): 80 | diag = torch.diagonal(R, dim1=-2, dim2=-1) 81 | Rxx, Ryy, Rzz = diag.unbind(-1) 82 | magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([ 83 | Rxx - Ryy - Rzz, 84 | - Rxx + Ryy - Rzz, 85 | - Rxx - Ryy + Rzz 86 | ], -1))) 87 | _R = lambda i,j: R[:,:,:,i,j] 88 | signs = torch.sign(torch.stack([ 89 | _R(2,1) - _R(1,2), 90 | _R(0,2) - _R(2,0), 91 | _R(1,0) - _R(0,1) 92 | ], -1)) 93 | xyz = signs * magnitudes 94 | w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2. 95 | Q = torch.cat((xyz, w), -1) 96 | Q = F.normalize(Q, dim=-1) 97 | return Q 98 | 99 | def _orientations_coarse(self, X, E_idx, eps=1e-6): 100 | V = X.clone() 101 | X = X[:,:,:6,:].reshape(X.shape[0], 6*X.shape[1], 3) 102 | dX = X[:,1:,:] - X[:,:-1,:] 103 | U = _normalize(dX, dim=-1) 104 | u_0, u_1 = U[:,:-2,:], U[:,1:-1,:] 105 | n_0 = _normalize(torch.cross(u_0, u_1), dim=-1) 106 | b_1 = _normalize(u_0 - u_1, dim=-1) 107 | 108 | # select C3' 109 | n_0 = n_0[:,4::6,:] 110 | b_1 = b_1[:,4::6,:] 111 | X = X[:,4::6,:] 112 | 113 | Q = torch.stack((b_1, n_0, torch.cross(b_1, n_0)), 2) 114 | Q = Q.view(list(Q.shape[:2]) + [9]) 115 | Q = F.pad(Q, (0,0,0,1), 'constant', 0) # [16, 464, 9] 116 | 117 | Q_neighbors = gather_nodes(Q, E_idx) # [16, 464, 30, 9] 118 | P_neighbors = gather_nodes(V[:,:,0,:], E_idx) # [16, 464, 30, 3] 119 | O5_neighbors = gather_nodes(V[:,:,1,:], E_idx) 120 | C5_neighbors = gather_nodes(V[:,:,2,:], E_idx) 121 | C4_neighbors = gather_nodes(V[:,:,3,:], E_idx) 122 | O3_neighbors = gather_nodes(V[:,:,5,:], E_idx) 123 | 124 | Q = Q.view(list(Q.shape[:2]) + [3,3]).unsqueeze(2) # [16, 464, 1, 3, 3] 125 | Q_neighbors = Q_neighbors.view(list(Q_neighbors.shape[:3]) + [3,3]) # [16, 464, 30, 3, 3] 126 | 127 | dX = torch.stack([P_neighbors,O5_neighbors,C5_neighbors,C4_neighbors,O3_neighbors], dim=3) - X[:,:,None,None,:] # [16, 464, 30, 3] 128 | dU = torch.matmul(Q[:,:,:,None,:,:], dX[...,None]).squeeze(-1) # [16, 464, 30, 3] 邻居的相对坐标 129 | B, N, K = dU.shape[:3] 130 | E_direct = _normalize(dU, dim=-1) 131 | E_direct = E_direct.reshape(B, N, K,-1) 132 | R = torch.matmul(Q.transpose(-1,-2), Q_neighbors) 133 | E_orient = self._quaternions(R) 134 | 135 | dX_inner = V[:,:,[0,2,3],:] - X.unsqueeze(-2) 136 | dU_inner = torch.matmul(Q, dX_inner.unsqueeze(-1)).squeeze(-1) 137 | dU_inner = _normalize(dU_inner, dim=-1) 138 | V_direct = dU_inner.reshape(B,N,-1) 139 | return V_direct, E_direct, E_orient 140 | 141 | def _dihedrals(self, X, eps=1e-7): 142 | # P, O5', C5', C4', C3', O3' 143 | X = X[:,:,:6,:].reshape(X.shape[0], 6*X.shape[1], 3) 144 | 145 | # Shifted slices of unit vectors 146 | # https://iupac.qmul.ac.uk/misc/pnuc2.html#220 147 | # https://x3dna.org/highlights/torsion-angles-of-nucleic-acid-structures 148 | # alpha: O3'_{i-1} P_i O5'_i C5'_i 149 | # beta: P_i O5'_i C5'_i C4'_i 150 | # gamma: O5'_i C5'_i C4'_i C3'_i 151 | # delta: C5'_i C4'_i C3'_i O3'_i 152 | # epsilon: C4'_i C3'_i O3'_i P_{i+1} 153 | # zeta: C3'_i O3'_i P_{i+1} O5'_{i+1} 154 | # What's more: 155 | # chi: C1' - N9 156 | # chi is different for (C, T, U) and (A, G) https://x3dna.org/highlights/the-chi-x-torsion-angle-characterizes-base-sugar-relative-orientation 157 | 158 | dX = X[:, 5:, :] - X[:, :-5, :] # O3'-P, P-O5', O5'-C5', C5'-C4', ... 159 | U = F.normalize(dX, dim=-1) 160 | u_2 = U[:,:-2,:] # O3'-P, P-O5', ... 161 | u_1 = U[:,1:-1,:] # P-O5', O5'-C5', ... 162 | u_0 = U[:,2:,:] # O5'-C5', C5'-C4', ... 163 | # Backbone normals 164 | n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1) 165 | n_1 = F.normalize(torch.cross(u_1, u_0), dim=-1) 166 | 167 | # Angle between normals 168 | cosD = (n_2 * n_1).sum(-1) 169 | cosD = torch.clamp(cosD, -1+eps, 1-eps) 170 | D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD) 171 | 172 | D = F.pad(D, (3,4), 'constant', 0) 173 | D = D.view((D.size(0), D.size(1) //6, 6)) 174 | return torch.cat((torch.cos(D), torch.sin(D)), 2) # return D_features 175 | 176 | def forward(self, X, S, mask): 177 | if self.training and self.augment_eps > 0: 178 | X = X + self.augment_eps * torch.randn_like(X) 179 | 180 | # Build k-Nearest Neighbors graph 181 | B, N, _,_ = X.shape 182 | # P, O5', C5', C4', C3', O3' 183 | atom_P = X[:, :, 0, :] 184 | atom_O5_ = X[:, :, 1, :] 185 | atom_C5_ = X[:, :, 2, :] 186 | atom_C4_ = X[:, :, 3, :] 187 | atom_C3_ = X[:, :, 4, :] 188 | atom_O3_ = X[:, :, 5, :] 189 | 190 | X_backbone = atom_P 191 | D_neighbors, E_idx = self._dist(X_backbone, mask) 192 | 193 | mask_bool = (mask==1) 194 | mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) 195 | mask_attend = (mask.unsqueeze(-1) * mask_attend) == 1 196 | edge_mask_select = lambda x: torch.masked_select(x, mask_attend.unsqueeze(-1)).reshape(-1,x.shape[-1]) 197 | node_mask_select = lambda x: torch.masked_select(x, mask_bool.unsqueeze(-1)).reshape(-1, x.shape[-1]) 198 | 199 | # node features 200 | h_V = [] 201 | # angle 202 | V_angle = node_mask_select(self._dihedrals(X)) 203 | # distance 204 | node_list = ['O5_-P', 'C5_-P', 'C4_-P', 'C3_-P', 'O3_-P'] 205 | V_dist = [] 206 | for pair in node_list: 207 | atom1, atom2 = pair.split('-') 208 | V_dist.append(node_mask_select(self._get_rbf(vars()['atom_' + atom1], vars()['atom_' + atom2], None, self.num_rbf).squeeze())) 209 | V_dist = torch.cat(tuple(V_dist), dim=-1).squeeze() 210 | # direction 211 | V_direct, E_direct, E_orient = self._orientations_coarse(X, E_idx) 212 | V_direct = node_mask_select(V_direct) 213 | E_direct, E_orient = list(map(lambda x: edge_mask_select(x), [E_direct, E_orient])) 214 | 215 | # edge features 216 | h_E = [] 217 | # dist 218 | edge_list = ['P-P', 'O5_-P', 'C5_-P', 'C4_-P', 'C3_-P', 'O3_-P'] 219 | E_dist = [] 220 | for pair in edge_list: 221 | atom1, atom2 = pair.split('-') 222 | E_dist.append(edge_mask_select(self._get_rbf(vars()['atom_' + atom1], vars()['atom_' + atom2], E_idx, self.num_rbf))) 223 | E_dist = torch.cat(tuple(E_dist), dim=-1) 224 | 225 | if 'angle' in self.node_feat_types: 226 | h_V.append(V_angle) 227 | if 'distance' in self.node_feat_types: 228 | h_V.append(V_dist) 229 | if 'direction' in self.node_feat_types: 230 | h_V.append(V_direct) 231 | 232 | if 'orientation' in self.edge_feat_types: 233 | h_E.append(E_orient) 234 | if 'distance' in self.edge_feat_types: 235 | h_E.append(E_dist) 236 | if 'direction' in self.edge_feat_types: 237 | h_E.append(E_direct) 238 | 239 | # Embed the nodes 240 | h_V = self.norm_nodes(self.node_embedding(torch.cat(h_V, dim=-1))) 241 | h_E = self.norm_edges(self.edge_embedding(torch.cat(h_E, dim=-1))) 242 | 243 | # prepare the variables to return 244 | S = torch.masked_select(S, mask_bool) 245 | shift = mask.sum(dim=1).cumsum(dim=0) - mask.sum(dim=1) 246 | src = shift.view(B,1,1) + E_idx 247 | src = torch.masked_select(src, mask_attend).view(1,-1) 248 | dst = shift.view(B,1,1) + torch.arange(0, N, device=src.device).view(1,-1,1).expand_as(mask_attend) 249 | dst = torch.masked_select(dst, mask_attend).view(1,-1) 250 | E_idx = torch.cat((dst, src), dim=0).long() 251 | 252 | sparse_idx = mask.nonzero() 253 | X = X[sparse_idx[:,0], sparse_idx[:,1], :, :] 254 | batch_id = sparse_idx[:,0] 255 | return X, S, h_V, h_E, E_idx, batch_id -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch_scatter import scatter_sum, scatter_softmax 5 | 6 | 7 | def gather_edges(edges, neighbor_idx): 8 | neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1)) 9 | return torch.gather(edges, 2, neighbors) 10 | 11 | def gather_nodes(nodes, neighbor_idx): 12 | neighbors_flat = neighbor_idx.view((neighbor_idx.shape[0], -1)) 13 | neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2)) 14 | neighbor_features = torch.gather(nodes, 1, neighbors_flat) 15 | neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1]) 16 | return neighbor_features 17 | 18 | def gather_nodes_t(nodes, neighbor_idx): 19 | idx_flat = neighbor_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2)) 20 | return torch.gather(nodes, 1, idx_flat) 21 | 22 | def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx): 23 | h_nodes = gather_nodes(h_nodes, E_idx) 24 | return torch.cat([h_neighbors, h_nodes], -1) 25 | 26 | 27 | class MPNNLayer(nn.Module): 28 | def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30): 29 | super(MPNNLayer, self).__init__() 30 | self.num_hidden = num_hidden 31 | self.num_in = num_in 32 | self.scale = scale 33 | self.dropout = nn.Dropout(dropout) 34 | self.norm1 = nn.LayerNorm(num_hidden) 35 | self.norm2 = nn.LayerNorm(num_hidden) 36 | 37 | self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True) 38 | self.W2 = nn.Linear(num_hidden, num_hidden, bias=True) 39 | self.W3 = nn.Linear(num_hidden, num_hidden, bias=True) 40 | self.act = nn.ReLU() 41 | 42 | self.dense = nn.Sequential( 43 | nn.Linear(num_hidden, num_hidden*4), 44 | nn.ReLU(), 45 | nn.Linear(num_hidden*4, num_hidden) 46 | ) 47 | 48 | def forward(self, h_V, h_E, edge_idx, batch_id=None): 49 | src_idx, dst_idx = edge_idx[0], edge_idx[1] 50 | h_message = self.W3(self.act(self.W2(self.act(self.W1(h_E))))) 51 | dh = scatter_sum(h_message, src_idx, dim=0) / self.scale 52 | h_V = self.norm1(h_V + self.dropout(dh)) 53 | dh = self.dense(h_V) 54 | h_V = self.norm2(h_V + self.dropout(dh)) 55 | return h_V 56 | 57 | 58 | class TransformerLayer(nn.Module): 59 | def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.0): 60 | super(TransformerLayer, self).__init__() 61 | self.num_heads = num_heads 62 | self.num_hidden = num_hidden 63 | self.num_in = num_in 64 | self.dropout = nn.Dropout(dropout) 65 | self.norm = nn.ModuleList([nn.BatchNorm1d(num_hidden) for _ in range(2)]) 66 | self.attention = NeighborAttention(num_hidden, num_hidden + num_in, num_heads) 67 | self.dense = nn.Sequential( 68 | nn.Linear(num_hidden, num_hidden*4), 69 | nn.ReLU(), 70 | nn.Linear(num_hidden*4, num_hidden) 71 | ) 72 | 73 | def forward(self, h_V, h_E, edge_idx, batch_id=None): 74 | center_id = edge_idx[0] 75 | dh = self.attention(h_V, h_E, center_id, batch_id) 76 | h_V = self.norm[0](h_V + self.dropout(dh)) 77 | dh = self.dense(h_V) 78 | h_V = self.norm[1](h_V + self.dropout(dh)) 79 | return h_V 80 | 81 | 82 | class Normalize(nn.Module): 83 | def __init__(self, features, epsilon=1e-6): 84 | super(Normalize, self).__init__() 85 | self.gain = nn.Parameter(torch.ones(features)) 86 | self.bias = nn.Parameter(torch.zeros(features)) 87 | self.epsilon = epsilon 88 | 89 | def forward(self, x, dim=-1): 90 | mu = x.mean(dim, keepdim=True) 91 | sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon) 92 | gain = self.gain 93 | bias = self.bias 94 | if dim != -1: 95 | shape = [1] * len(mu.size()) 96 | shape[dim] = self.gain.size()[0] 97 | gain = gain.view(shape) 98 | bias = bias.view(shape) 99 | return gain * (x - mu) / (sigma + self.epsilon) + bias 100 | 101 | 102 | class NeighborAttention(nn.Module): 103 | def __init__(self, num_hidden, num_in, num_heads=4): 104 | super(NeighborAttention, self).__init__() 105 | self.num_heads = num_heads 106 | self.num_hidden = num_hidden 107 | 108 | self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False) 109 | self.W_K = nn.Linear(num_in, num_hidden, bias=False) 110 | self.W_V = nn.Linear(num_in, num_hidden, bias=False) 111 | self.Bias = nn.Sequential( 112 | nn.Linear(num_hidden*3, num_hidden), 113 | nn.ReLU(), 114 | nn.Linear(num_hidden,num_hidden), 115 | nn.ReLU(), 116 | nn.Linear(num_hidden,num_heads) 117 | ) 118 | self.W_O = nn.Linear(num_hidden, num_hidden, bias=False) 119 | 120 | def forward(self, h_V, h_E, center_id, batch_id): 121 | N = h_V.shape[0] 122 | E = h_E.shape[0] 123 | n_heads = self.num_heads 124 | d = int(self.num_hidden / n_heads) 125 | 126 | Q = self.W_Q(h_V).view(N, n_heads, 1, d)[center_id] 127 | K = self.W_K(h_E).view(E, n_heads, d, 1) 128 | attend_logits = torch.matmul(Q, K).view(E, n_heads, 1) 129 | attend_logits = attend_logits / np.sqrt(d) 130 | 131 | V = self.W_V(h_E).view(-1, n_heads, d) 132 | attend = scatter_softmax(attend_logits, index=center_id, dim=0) 133 | h_V = scatter_sum(attend*V, center_id, dim=0).view([N, self.num_hidden]) 134 | h_V_update = self.W_O(h_V) 135 | return h_V_update -------------------------------------------------------------------------------- /model/rdesign_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .module import MPNNLayer 4 | from .feature import RNAFeatures 5 | 6 | 7 | def gather_nodes(nodes, neighbor_idx): 8 | neighbors_flat = neighbor_idx.view((neighbor_idx.shape[0], -1)) 9 | neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2)) 10 | neighbor_features = torch.gather(nodes, 1, neighbors_flat) 11 | neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1]) 12 | return neighbor_features 13 | 14 | 15 | class RDesign_Model(nn.Module): 16 | def __init__(self, args): 17 | super(RDesign_Model, self).__init__() 18 | 19 | self.device = 'cuda:0' 20 | self.smoothing = args.smoothing 21 | self.node_features = self.edge_features = args.hidden 22 | self.hidden_dim = args.hidden 23 | self.vocab = args.vocab_size 24 | 25 | self.features = RNAFeatures( 26 | args.hidden, args.hidden, 27 | top_k=args.k_neighbors, 28 | dropout=args.dropout, 29 | node_feat_types=args.node_feat_types, 30 | edge_feat_types=args.edge_feat_types, 31 | args=args 32 | ) 33 | 34 | layer = MPNNLayer 35 | self.W_s = nn.Embedding(args.vocab_size, self.hidden_dim) 36 | self.encoder_layers = nn.ModuleList([ 37 | layer(self.hidden_dim, self.hidden_dim*2, dropout=args.dropout) 38 | for _ in range(args.num_encoder_layers)]) 39 | self.decoder_layers = nn.ModuleList([ 40 | layer(self.hidden_dim, self.hidden_dim*2, dropout=args.dropout) 41 | for _ in range(args.num_decoder_layers)]) 42 | 43 | self.projection_head = nn.Sequential( 44 | nn.Linear(self.hidden_dim, self.hidden_dim, bias=False), 45 | nn.ReLU(inplace=True), 46 | nn.Linear(self.hidden_dim, self.hidden_dim, bias=True) 47 | ) 48 | 49 | self.readout = nn.Linear(self.hidden_dim, args.vocab_size, bias=True) 50 | 51 | for p in self.parameters(): 52 | if p.dim() > 1: 53 | nn.init.xavier_uniform_(p) 54 | 55 | def forward(self, X, S, mask): 56 | X, S, h_V, h_E, E_idx, batch_id = self.features(X, S, mask) 57 | 58 | for enc_layer in self.encoder_layers: 59 | h_EV = torch.cat([h_E, h_V[E_idx[0]], h_V[E_idx[1]]], dim=-1) 60 | h_V = enc_layer(h_V, h_EV, E_idx, batch_id) 61 | 62 | for dec_layer in self.decoder_layers: 63 | h_EV = torch.cat([h_E, h_V[E_idx[0]], h_V[E_idx[1]]], dim=-1) 64 | h_V = dec_layer(h_V, h_EV, E_idx, batch_id) 65 | 66 | graph_embs = [] 67 | for b_id in range(batch_id[-1].item()+1): 68 | b_data = h_V[batch_id == b_id].mean(0) 69 | graph_embs.append(b_data) 70 | graph_embs = torch.stack(graph_embs, dim=0) 71 | graph_prjs = self.projection_head(graph_embs) 72 | 73 | logits = self.readout(h_V) 74 | return logits, S, graph_prjs 75 | 76 | def sample(self, X, S, mask=None): 77 | X, gt_S, h_V, h_E, E_idx, batch_id = self.features(X, S, mask) 78 | 79 | for enc_layer in self.encoder_layers: 80 | h_EV = torch.cat([h_E, h_V[E_idx[0]], h_V[E_idx[1]]], dim=-1) 81 | h_V = enc_layer(h_V, h_EV, E_idx, batch_id) 82 | 83 | for dec_layer in self.decoder_layers: 84 | h_EV = torch.cat([h_E, h_V[E_idx[0]], h_V[E_idx[1]]], dim=-1) 85 | h_V = dec_layer(h_V, h_EV, E_idx, batch_id) 86 | 87 | logits = self.readout(h_V) 88 | return logits, gt_S -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def create_parser(): 5 | parser = argparse.ArgumentParser() 6 | # Set-up parameters 7 | parser.add_argument('--device', default='cuda', type=str, help='Name of device to use for tensor computations (cuda/cpu)') 8 | parser.add_argument('--display_step', default=10, type=int, help='Interval in batches between display of training metrics') 9 | parser.add_argument('--res_dir', default='./results', type=str) 10 | parser.add_argument('--ex_name', default='debug', type=str) 11 | parser.add_argument('--use_gpu', default=True, type=bool) 12 | parser.add_argument('--gpu', default=0, type=int) 13 | parser.add_argument('--seed', default=111, type=int) 14 | 15 | # dataset parameters 16 | parser.add_argument('--data_root', default='./data/RNAsolo/') 17 | parser.add_argument('--batch_size', default=64, type=int) 18 | parser.add_argument('--num_workers', default=0, type=int) 19 | 20 | # training parameters 21 | parser.add_argument('--epoch', default=2, type=int, help='end epoch') 22 | parser.add_argument('--log_step', default=1, type=int) 23 | parser.add_argument('--lr', default=0.001, type=float, help='Learning rate') 24 | 25 | # feature parameters 26 | parser.add_argument('--node_feat_types', default=['angle', 'distance', 'direction'], type=list) 27 | parser.add_argument('--edge_feat_types', default=['orientation', 'distance', 'direction'], type=list) 28 | 29 | # model parameters 30 | parser.add_argument('--num_encoder_layers', default=3, type=int) 31 | parser.add_argument('--num_decoder_layers', default=3, type=int) 32 | 33 | parser.add_argument('--hidden', default=128, type=int) 34 | parser.add_argument('--k_neighbors', default=30, type=int) 35 | parser.add_argument('--vocab_size', default=4, type=int) 36 | parser.add_argument('--shuffle', default=0., type=float) 37 | parser.add_argument('--dropout', default=0.1, type=float) 38 | parser.add_argument('--smoothing', default=0.1, type=float) 39 | 40 | parser.add_argument('--weigth_clu_con', default=0.5, type=float) 41 | parser.add_argument('--weigth_sam_con', default=0.5, type=float) 42 | parser.add_argument('--ss_temp', default=0.5, type=float) 43 | 44 | return parser.parse_args() -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # RDesign: Hierarchical Data-efficient Representation Learning for Tertiary Structure-based RNA Design 2 | 3 | ![GitHub stars](https://img.shields.io/github/stars/A4Bio/RDesign) ![GitHub forks](https://img.shields.io/github/forks/A4Bio/RDesign?color=green) 4 | Open In Colab 5 | 6 | **[2024-08-21] News:** We provided a comprehensive evaluation system for RNA sequence design and prediction named **R3Design**. APIs and Colab demos are also provided. Feel free to check out our new [repo](https://github.com/A4Bio/R3Design)! 7 | 8 | **[2024-08-15] Update:** Thank you all for the interests and inquries about our paper, we are sorry that we haven't provided detailed documentation and demo of the paper for such a long time. Now, it has been solved. Feel free to check out our updated documentation and colab! :) 9 | ## Introduction 10 | 11 | While artificial intelligence has made remarkable strides in revealing the relationship between biological macromolecules' primary sequence and tertiary structure, designing RNA sequences based on specified tertiary structures remains challenging. Though existing approaches in protein design have thoroughly explored structure-to-sequence dependencies in proteins, RNA design still confronts difficulties due to structural complexity and data scarcity. 12 | 13 | In this study, we aim to systematically construct a data-driven RNA design pipeline. We crafted a large, well-curated benchmark dataset and designed a comprehensive structural modeling approach to represent the complex RNA tertiary structure. More importantly, we proposed a hierarchical data-efficient representation learning framework that learns structural representations through contrastive learning at both cluster-level and sample-level to fully leverage the limited data. Extensive experiments demonstrate the effectiveness of our proposed method, providing a reliable baseline for future RNA design tasks. 14 | 15 |

16 | 17 |

18 | 19 | 20 | ## Dataset 21 | 22 | We carefully collected representative RNA tertiary structure data from two sources, RNAsolo and the Protein Data Bank (PDB). The refined data has been released [here](https://github.com/A4Bio/RDesign/releases/tag/data). Please download the datasets and organize them as follows. 23 | 24 | ``` 25 | RDesign 26 | ├── API 27 | ├── assets 28 | ├── checkpoints 29 | ├── methods 30 | ├── model 31 | └── data 32 | ├── RNAsolo 33 | │ ├── train_data.pt 34 | │ ├── val_data.pt 35 | │ ├── test_data.pt 36 | ``` 37 | 38 | ### Main Environment 39 | 40 | ```shell 41 | cd RDesign 42 | conda env create -f environment.yml 43 | conda activate RDesign 44 | ``` 45 | 46 | ### Load Data 47 | 48 | ```shell 49 | # If you want to see the details inside our dataset, you could use Pickle package from Python 50 | import _pickle as cPickle 51 | train_data = cPickle.load(open('data/train_data.pt', 'rb')) 52 | print(train_data[0].keys()) 53 | 54 | #For external datasets, loading data could be in this way: 55 | from API.rpuzzles_dataset import RPuzzlesDataset 56 | rfam_dataset = RPuzzlesDataset('./data/rfam_data.pt') 57 | rpuz_dataset = RPuzzlesDataset('./data/rpuz_data.pt') 58 | ``` 59 | 60 | ### Test the model 61 | 62 | ```shell 63 | # For more details, please refer to the colab 64 | # We provided detailed functions and pipeline to show how our model operates 65 | ``` 66 | Colab Link: 67 | 68 | Open In Colab 69 | 70 | 71 | ## Citation 72 | 73 | If you are interested in our repository and our paper, please cite the following paper: 74 | 75 | ``` 76 | @inproceedings{tan2024rdesign, 77 | title={RDesign: Hierarchical Data-efficient Representation Learning for Tertiary Structure-based RNA Design}, 78 | author={Tan, Cheng and Zhang, Yijie and Gao, Zhangyang and Hu, Bozhen and Li, Siyuan and Liu, Zicheng and Li, Stan Z}, 79 | booktitle={The Twelfth International Conference on Learning Representations}, 80 | year={2024} 81 | } 82 | ``` 83 | 84 | ## Feedback 85 | If you have any issue about this work, please feel free to contact me by email: 86 | * Cheng Tan: tancheng@westlake.edu.cn 87 | * Yijie Zhang: yj.zhang@mail.mcgill.ca 88 | 89 | ## License 90 | 91 | This project is released under the [Apache 2.0 license](LICENSE). See `LICENSE` for more information. 92 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import random 4 | import torch 5 | import numpy as np 6 | import torch.backends.cudnn as cudnn 7 | 8 | 9 | def set_seed(seed): 10 | random.seed(seed) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | cudnn.deterministic = True 14 | 15 | def print_log(message): 16 | print(message) 17 | logging.info(message) 18 | 19 | def output_namespace(namespace): 20 | configs = namespace.__dict__ 21 | message = '' 22 | for k, v in configs.items(): 23 | message += '\n' + k + ': \t' + str(v) + '\t' 24 | return message 25 | 26 | def check_dir(path): 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | 30 | def get_dataset(config): 31 | from API import load_data 32 | return load_data(**config) --------------------------------------------------------------------------------