├── .gitignore ├── README.md ├── datasets ├── __init__.py ├── bms.py ├── extra_prepocess2.py ├── prepocess2.py ├── pseudo_prepocess2.py ├── tokenizer2.py └── transforms.py ├── environment.yml ├── eval.py ├── inference.py ├── losses ├── __init__.py └── ls_loss.py ├── misc ├── __init__.py ├── metrics.py ├── sample_submission_with_length.csv.gz └── utils.py ├── models ├── __init__.py ├── cait.py ├── fairseq_transformer.py ├── swin.py └── vit.py ├── normalize_inchis.py ├── normalize_inchis.sh ├── optims └── __init__.py ├── r09_create_images_from_allowed_inchi.py ├── swa.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | notebooks 3 | __pycache__ 4 | logdir 5 | *.csv 6 | *.pth 7 | *.pickle 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Part of the 9th place solution for [the Bristol-Myers Squibb – Molecular Translation challenge](https://www.kaggle.com/c/bms-molecular-translation/overview) translating 2 | images containing chemical structures into InChI (International Chemical Identifier) texts. 3 | 4 | This repo is partially based on the following resources: 5 | * [Y.Nakama's](https://www.kaggle.com/yasufuminakama) [tokenization](https://www.kaggle.com/yasufuminakama/inchi-preprocess-2) 6 | * [Heng's](https://www.kaggle.com/hengck23) [transformer decoder](https://www.kaggle.com/c/bms-molecular-translation/discussion/231190) 7 | * [Sam Stainsby's](https://www.kaggle.com/stainsby) [external images creation](https://www.kaggle.com/stainsby/improved-synthetic-data-for-bms-competition-v3) updated by [ZFTurbo](https://www.kaggle.com/zfturbo) 8 | 9 | 10 | ## Requirements 11 | * install and activate [the conda environment](environment.yml) 12 | * download and extract the data into `/data/bms/` 13 | * extract and move [sample_submission_with_length.csv.gz](misc/sample_submission_with_length.csv.gz) into `/data/bms/` 14 | * tokenize training inputs: `python datasets/prepocess2.py` 15 | * if you want to use pseudo labeling, execute: `python datasets/pseudo_prepocess2.py your_submission_file.csv` 16 | * if you want to use external images, you can create them with the following commands: 17 | ``` 18 | python r09_create_images_from_allowed_inchi.py 19 | python datasets/extra_prepocess2.py 20 | ``` 21 | * and also install [apex](https://github.com/NVIDIA/apex) 22 | 23 | ## Training 24 | This repo supports training any VIT/SWIN/CAIT transformer model from [timm](https://github.com/rwightman/pytorch-image-models/) as encoder together with the fairseq transformer decoder.
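The encoder is selected purely by the `--model` name prefix (`vit*`, `swin*` or `cait*`, see `models/__init__.py`), and the backbone name is passed straight to `timm.create_model(..., pretrained=True)`. A minimal sketch of the dispatch, assuming an `args` namespace that already carries the decoder settings used by the wrappers (`vocab_size`, `embed_dim`, `ff_dim`, `num_head`, `num_layer`):
```python
from models import get_model

# the name prefix decides the wrapper: vit* -> Vit, swin* -> Swin, cait* -> Cait,
# each pairing the timm backbone with the fairseq transformer decoder
args.model = 'swin_base_patch4_window12_384'
model = get_model(args)
```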
25 | 26 | 27 | Here is an example configuration to train a SWIN `swin_base_patch4_window12_384` as encoder and a 12-layer, 16-head fairseq decoder: 28 | ``` 29 | python -m torch.distributed.launch --nproc_per_node=N train.py --logdir=logdir/ \ 30 | --pipeline --train-batch-size=50 --valid-batch-size=128 --dataload-workers-nums=10 --mixed-precision --amp-level=O2 \ 31 | --aug-rotate90-p=0.5 --aug-crop-p=0.5 --aug-noise-p=0.9 --label-smoothing=0.1 \ 32 | --encoder-lr=1e-3 --decoder-lr=1e-3 --lr-step-ratio=0.3 --lr-policy=step --optim=adam --lr-warmup-steps=1000 --max-epochs=20 --weight-decay=0 --clip-grad-norm=1 \ 33 | --verbose --image-size=384 --model=swin_base_patch4_window12_384 --loss=ce --embed-dim=1024 --num-head=16 --num-layer=12 \ 34 | --fold=0 --train-dataset-size=0 --valid-dataset-size=65536 --valid-dataset-non-sorted 35 | ``` 36 | 37 | For pseudo labeling, use `--pseudo=pseudo.pkl`. If you want to subsample the pseudo dataset, use: `--pseudo-dataset-size=448000`. 38 | For external images, use `--extra` (`--extra-dataset-size=448000`). 39 | 40 | After training, you can also use Stochastic Weight Averaging (SWA), which gives a boost of around 0.02: 41 | ``` 42 | python swa.py --image-size=384 --input logdir/epoch-17.pth,logdir/epoch-18.pth,logdir/epoch-19.pth,logdir/epoch-20.pth 43 | ``` 44 | 45 | ## Inference 46 | 47 | Evaluation: 48 | ``` 49 | python -m torch.distributed.launch --nproc_per_node=N eval.py --mixed-precision --batch-size=128 swa_model.pth 50 | ``` 51 | 52 | Inference: 53 | ``` 54 | python -m torch.distributed.launch --nproc_per_node=N inference.py --mixed-precision --batch-size=128 swa_model.pth 55 | ``` 56 | 57 | Normalization with RDKit: 58 | ``` 59 | ./normalize_inchis.sh submission.csv 60 | ``` 61 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tugstugi/pytorch-bms/ddd0979eb01a08df1c7dee3c435b661c19df38fc/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/bms.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | 3 | import torch 4 | import random 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from pathlib import Path 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | from sklearn.model_selection import StratifiedKFold 12 | 13 | from .tokenizer2 import load_tokenizer 14 | 15 | RANDOM_SEED = 1234 16 | 17 | 18 | class BMSTrainDataset(Dataset): 19 | 20 | def __init__(self, fold=0, mode='train', data_root='/data/bms', transform=None, dataset_size=0, sort_valid=True): 21 | self.mode = mode 22 | self.data_root = Path(data_root) 23 | self.transform = transform 24 | self.dataset_size = dataset_size 25 | self.tokenizer = load_tokenizer(data_root) 26 | 27 | data = pd.read_pickle(str(self.data_root / 'train2.pkl')) 28 | 29 | # data = data[:500000] 30 | 31 | def gen_file_path(image_id): 32 | return str(self.data_root / "train/{}/{}/{}/{}.png".format(image_id[0], image_id[1], image_id[2], image_id)) 33 | 34 | data['file_path'] = data['image_id'].apply(gen_file_path) 35 | 36 | pd.set_option('display.max_colwidth', None) 37 | # print(data['file_path'].head()) 38 | # print(data['InChI_length'].max()) 39 | 40 | skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_SEED) 41 | for n, (train_index, val_index) in enumerate(skf.split(data, 
data['InChI_length'])): 42 | data.loc[val_index, 'fold'] = int(n) 43 | data['fold'] = data['fold'].astype(int) 44 | # print(data.groupby(['fold']).size()) 45 | 46 | if self.mode == 'train': 47 | data_idx = data[data['fold'] != fold].index 48 | else: 49 | data_idx = data[data['fold'] == fold].index 50 | 51 | self.df = data.loc[data_idx].reset_index(drop=True) 52 | if self.mode != 'train' and sort_valid: 53 | # make fast eval by sorting length 54 | self.df = self.df.sort_values(by=['InChI_length']) 55 | # print(self.df.head()) 56 | # print(len(self.df)) 57 | 58 | def __len__(self): 59 | if self.dataset_size > 0: 60 | return self.dataset_size 61 | return len(self.df) 62 | 63 | def __getitem__(self, idx): 64 | if self.dataset_size > 0 and self.mode == 'train': 65 | idx = random.randint(0, len(self.df) - 1) 66 | 67 | row = self.df.iloc[idx] 68 | 69 | image = Image.open(row['file_path']).convert('RGB') 70 | label = self.tokenizer.text_to_sequence(row['InChI_text']) 71 | 72 | data = { 73 | 'input': np.array(image), 74 | 'label': np.array(label), 75 | 'label_length': len(label), 76 | 'inchi': str(row['InChI']) 77 | } 78 | 79 | if self.transform is not None: 80 | data = self.transform(data) 81 | 82 | return data 83 | 84 | 85 | class BMSPseudoDataset(BMSTrainDataset): 86 | 87 | def __init__(self, pseudo_file, data_root='/data/bms', transform=None, dataset_size=0): 88 | self.mode = 'train' 89 | self.data_root = Path(data_root) 90 | self.transform = transform 91 | self.dataset_size = dataset_size 92 | self.tokenizer = load_tokenizer(data_root) 93 | 94 | data = pd.read_pickle(str(self.data_root / pseudo_file)) 95 | 96 | def gen_file_path(image_id): 97 | return str(self.data_root / "test/{}/{}/{}/{}.png".format(image_id[0], image_id[1], image_id[2], image_id)) 98 | 99 | data['file_path'] = data['image_id'].apply(gen_file_path) 100 | 101 | pd.set_option('display.max_colwidth', None) 102 | print(data['file_path'].head()) 103 | print(data['InChI_length'].max()) 104 | 105 | self.df = data 106 | 107 | 108 | class BMSExtraDataset(BMSTrainDataset): 109 | 110 | def __init__(self, data_root='/data/bms', transform=None, dataset_size=0): 111 | self.mode = 'train' 112 | self.data_root = Path(data_root) 113 | self.transform = transform 114 | self.dataset_size = dataset_size 115 | self.tokenizer = load_tokenizer(data_root) 116 | 117 | data = pd.read_pickle(str(self.data_root / 'extra.pkl')) 118 | 119 | def gen_file_path(image_id): 120 | return str(self.data_root / "extra_images/{}/{}/{}/{}.png".format(image_id[0], image_id[1], image_id[2], image_id)) 121 | 122 | data['file_path'] = data['image_id'].apply(gen_file_path) 123 | 124 | self.df = data 125 | self.df = self.df.sort_values(by=['InChI_length'], ascending=False) 126 | # self.df = self.df[:500000] 127 | 128 | pd.set_option('display.max_colwidth', None) 129 | print(self.df['file_path'].head()) 130 | print(self.df['InChI_length'].min(), self.df['InChI_length'].mean(), self.df['InChI_length'].max()) 131 | 132 | 133 | class BMSTestDataset(Dataset): 134 | 135 | def __init__(self, data_root='/data/bms', transform=None): 136 | self.data_root = Path(data_root) 137 | self.transform = transform 138 | self.tokenizer = load_tokenizer(data_root) 139 | 140 | data = pd.read_csv(str(self.data_root / '4174.csv')) 141 | 142 | # data = data.sort_values(by=['InChI_length']) 143 | # data = data.drop(columns=['InChI_length']) 144 | 145 | def gen_file_path(image_id): 146 | return str(self.data_root / "test/{}/{}/{}/{}.png".format(image_id[0], image_id[1], image_id[2], image_id)) 147 | 
148 | data['file_path'] = data['image_id'].apply(gen_file_path) 149 | # data = data[:1024] 150 | self.df = data 151 | 152 | pd.set_option('display.max_colwidth', None) 153 | # print(data['file_path'].head()) 154 | # print(self.df.head()) 155 | # print(len(self.df)) 156 | 157 | def __len__(self): 158 | return len(self.df) 159 | 160 | def __getitem__(self, idx): 161 | row = self.df.iloc[idx] 162 | 163 | image = Image.open(row['file_path']).convert('RGB') 164 | 165 | data = { 166 | 'input': np.array(image), 167 | 'image_id': row['image_id'] 168 | } 169 | 170 | if self.transform is not None: 171 | data = self.transform(data) 172 | 173 | return data 174 | -------------------------------------------------------------------------------- /datasets/extra_prepocess2.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | 3 | import pandas as pd 4 | from tqdm.auto import tqdm 5 | 6 | from tokenizer2 import load_tokenizer 7 | from prepocess2 import split_form, split_form2 8 | 9 | tqdm.pandas() 10 | 11 | if __name__ == '__main__': 12 | # ==================================================== 13 | # Data Loading 14 | # ==================================================== 15 | # cat PSD_43.csv | grep -v "?" | grep -v "p" | grep -v "q" > 43.csv 16 | train = pd.read_csv('/data/bms/extra_approved_InChIs_with_ids.csv') 17 | print(f'train.shape: {train.shape}') 18 | 19 | train['InChI_1'] = train['InChI'].progress_apply(lambda x: x.split('/')[1]) 20 | train['InChI_text'] = train['InChI_1'].progress_apply(split_form) + ' ' + \ 21 | train['InChI'].apply(lambda x: '/'.join(x.split('/')[2:])).progress_apply(split_form2).values 22 | 23 | tokenizer = load_tokenizer() 24 | 25 | # ==================================================== 26 | # preprocess pseudo 27 | # ==================================================== 28 | lengths = [] 29 | tk0 = tqdm(train['InChI_text'].values, total=len(train)) 30 | for text in tk0: 31 | try: 32 | seq = tokenizer.text_to_sequence(text) 33 | length = len(seq) - 2 34 | lengths.append(length) 35 | except KeyError: 36 | lengths.append(-1) 37 | train['InChI_length'] = lengths 38 | train = train[train['InChI_length'] != -1] 39 | print("valid inchis: ", len(train)) 40 | train.to_pickle('/data/bms/extra.pkl') 41 | print('Saved preprocessed pseudo.pkl') 42 | -------------------------------------------------------------------------------- /datasets/prepocess2.py: -------------------------------------------------------------------------------- 1 | # 2 | # Original code: https://www.kaggle.com/yasufuminakama/inchi-preprocess-2 3 | # 4 | 5 | import re 6 | import pandas as pd 7 | from tqdm.auto import tqdm 8 | 9 | tqdm.pandas() 10 | 11 | 12 | # ==================================================== 13 | # Preprocess functions 14 | # ==================================================== 15 | def split_form(form): 16 | string = '' 17 | for i in re.findall(r"[A-Z][^A-Z]*", form): 18 | elem = re.match(r"\D+", i).group() 19 | num = i.replace(elem, "") 20 | if num == "": 21 | string += f"{elem} " 22 | else: 23 | string += f"{elem} {str(num)} " 24 | return string.rstrip(' ') 25 | 26 | 27 | def split_form2(form): 28 | string = '' 29 | for i in re.findall(r"[a-z][^a-z]*", form): 30 | elem = i[0] 31 | num = i.replace(elem, "").replace('/', "") 32 | num_string = '' 33 | for j in re.findall(r"[0-9]+[^0-9]*", num): 34 | num_list = list(re.findall(r'\d+', j)) 35 | assert len(num_list) == 1, f"len(num_list) != 1" 36 | _num = num_list[0] 37 | if 
j == _num: 38 | num_string += f"{_num} " 39 | else: 40 | extra = j.replace(_num, "") 41 | num_string += f"{_num} {' '.join(list(extra))} " 42 | string += f"/{elem} {num_string}" 43 | return string.rstrip(' ') 44 | 45 | 46 | if __name__ == '__main__': 47 | # ==================================================== 48 | # Data Loading 49 | # ==================================================== 50 | data_root = '/data/bms/' 51 | train = pd.read_csv(data_root + 'train_labels.csv') 52 | print(f'train.shape: {train.shape}') 53 | 54 | train['InChI_1'] = train['InChI'].progress_apply(lambda x: x.split('/')[1]) 55 | train['InChI_text'] = train['InChI_1'].progress_apply(split_form) + ' ' + \ 56 | train['InChI'].apply(lambda x: '/'.join(x.split('/')[2:])).progress_apply(split_form2).values 57 | # ==================================================== 58 | # create tokenizer 59 | # ==================================================== 60 | from tokenizer2 import Tokenizer 61 | import pickle 62 | 63 | tokenizer = Tokenizer() 64 | tokenizer.fit_on_texts(train['InChI_text'].values) 65 | pickle.dump({ 66 | 'stoi': tokenizer.stoi, 67 | 'itos': tokenizer.itos 68 | }, open(data_root + 'tokenizer2.pickle', 'wb')) 69 | print('Saved tokenizer2.pickle') 70 | # ==================================================== 71 | # preprocess train.csv 72 | # ==================================================== 73 | lengths = [] 74 | tk0 = tqdm(train['InChI_text'].values, total=len(train)) 75 | for text in tk0: 76 | seq = tokenizer.text_to_sequence(text) 77 | length = len(seq) - 2 78 | lengths.append(length) 79 | train['InChI_length'] = lengths 80 | train.to_pickle(data_root + 'train2.pkl') 81 | print('Saved preprocessed train2.pkl') 82 | -------------------------------------------------------------------------------- /datasets/pseudo_prepocess2.py: -------------------------------------------------------------------------------- 1 | # 2 | # Original code: https://www.kaggle.com/yasufuminakama/inchi-preprocess-2 3 | # 4 | 5 | import sys 6 | import pandas as pd 7 | from tqdm.auto import tqdm 8 | 9 | from tokenizer2 import load_tokenizer 10 | from prepocess2 import split_form, split_form2 11 | 12 | tqdm.pandas() 13 | 14 | if __name__ == '__main__': 15 | # ==================================================== 16 | # Data Loading 17 | # ==================================================== 18 | # cat PSD_43.csv | grep -v "?" 
| grep -v "p" | grep -v "q" > 43.csv 19 | train = pd.read_csv(sys.argv[1]) 20 | print(f'train.shape: {train.shape}') 21 | 22 | train['InChI_1'] = train['InChI'].progress_apply(lambda x: x.split('/')[1]) 23 | train['InChI_text'] = train['InChI_1'].progress_apply(split_form) + ' ' + \ 24 | train['InChI'].apply(lambda x: '/'.join(x.split('/')[2:])).progress_apply(split_form2).values 25 | 26 | tokenizer = load_tokenizer() 27 | 28 | # ==================================================== 29 | # preprocess pseudo 30 | # ==================================================== 31 | lengths = [] 32 | tk0 = tqdm(train['InChI_text'].values, total=len(train)) 33 | for text in tk0: 34 | seq = tokenizer.text_to_sequence(text) 35 | length = len(seq) - 2 36 | lengths.append(length) 37 | train['InChI_length'] = lengths 38 | train.to_pickle('/data/bms/pseudo.pkl') 39 | print('Saved preprocessed pseudo.pkl') 40 | -------------------------------------------------------------------------------- /datasets/tokenizer2.py: -------------------------------------------------------------------------------- 1 | # 2 | # Original code: https://www.kaggle.com/yasufuminakama/inchi-resnet-lstm-with-attention-starter 3 | # 4 | 5 | class Tokenizer(object): 6 | 7 | def __init__(self): 8 | self.stoi = {} 9 | self.itos = {} 10 | 11 | def __len__(self): 12 | return len(self.stoi) 13 | 14 | def fit_on_texts(self, texts): 15 | vocab = set() 16 | for text in texts: 17 | vocab.update(text.split(' ')) 18 | vocab = sorted(vocab) 19 | vocab.append('') 20 | vocab.append('') 21 | vocab.append('') 22 | for i, s in enumerate(vocab): 23 | self.stoi[s] = i 24 | self.itos = {item[1]: item[0] for item in self.stoi.items()} 25 | 26 | def text_to_sequence(self, text): 27 | sequence = [] 28 | sequence.append(self.stoi['']) 29 | for s in text.split(' '): 30 | sequence.append(self.stoi[s]) 31 | sequence.append(self.stoi['']) 32 | return sequence 33 | 34 | def texts_to_sequences(self, texts): 35 | sequences = [] 36 | for text in texts: 37 | sequence = self.text_to_sequence(text) 38 | sequences.append(sequence) 39 | return sequences 40 | 41 | def sequence_to_text(self, sequence): 42 | return ''.join(list(map(lambda i: self.itos[i], sequence))) 43 | 44 | def sequences_to_texts(self, sequences): 45 | texts = [] 46 | for sequence in sequences: 47 | text = self.sequence_to_text(sequence) 48 | texts.append(text) 49 | return texts 50 | 51 | def predict_caption(self, sequence): 52 | caption = '' 53 | for i in sequence: 54 | if i == self.stoi[''] or i == self.stoi['']: 55 | break 56 | caption += self.itos[i] 57 | return caption 58 | 59 | def predict_captions(self, sequences): 60 | captions = [] 61 | for sequence in sequences: 62 | caption = self.predict_caption(sequence) 63 | captions.append(caption) 64 | return captions 65 | 66 | 67 | def load_tokenizer(data_root='/data/bms/'): 68 | import pickle 69 | from pathlib import Path 70 | saved_dicts = pickle.load(open(Path(data_root) / 'tokenizer2.pickle', 'rb')) 71 | tokenizer = Tokenizer() 72 | tokenizer.stoi = saved_dicts['stoi'] 73 | tokenizer.itos = saved_dicts['itos'] 74 | return tokenizer 75 | 76 | 77 | if __name__ == '__main__': 78 | t = load_tokenizer('/data/bms/') 79 | print(f"tokenizer.stoi: {t.stoi}") 80 | print(f"tokenizer.stoi: {t.itos}") 81 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | 3 | import cv2 4 | import random 5 
| import numpy as np 6 | import albumentations as album 7 | from albumentations.pytorch import ToTensorV2 8 | 9 | 10 | def get_test_transform(args): 11 | return Compose([ 12 | # TestFix(), 13 | ApplyAlbumentations( 14 | album.Compose([ 15 | album.Resize(args.image_size, args.image_size), 16 | album.Normalize( 17 | mean=[0.485, 0.456, 0.406], 18 | std=[0.229, 0.224, 0.225], 19 | ), 20 | ToTensorV2() 21 | ]) 22 | ) 23 | ]) 24 | 25 | 26 | class Compose(object): 27 | """Composes several transforms together.""" 28 | 29 | def __init__(self, transforms): 30 | self.transforms = transforms 31 | 32 | def __call__(self, data): 33 | for t in self.transforms: 34 | data = t(data) 35 | return data 36 | 37 | 38 | class ApplyAlbumentations(object): 39 | """Apply transforms from Albumentations.""" 40 | 41 | def __init__(self, a_transform): 42 | self.a_transform = a_transform 43 | 44 | def __call__(self, data): 45 | data['input'] = self.a_transform(image=data['input'])['image'] 46 | return data 47 | 48 | 49 | class TestFix(object): 50 | def __init__(self): 51 | self.a_transform = album.Compose([album.Transpose(p=1), album.VerticalFlip(p=1)]) 52 | 53 | def __call__(self, data): 54 | h, w, _ = data['input'].shape 55 | if h > w: 56 | data['input'] = self.a_transform(image=data['input'])['image'] 57 | return data 58 | 59 | 60 | class CropAugment(object): 61 | """Crop pixels from borders. """ 62 | def __init__(self, probability=0.5, crops=(5, 10, 15, 20)): 63 | self.probability = probability 64 | self.crops = crops 65 | 66 | def __call__(self, data): 67 | if random.random() < self.probability: 68 | img = data['input'] 69 | crop = random.choice(self.crops) 70 | img = img[crop:-crop, crop:-crop, :] 71 | data['input'] = img 72 | return data 73 | 74 | 75 | class RandomNoiseAugment(object): 76 | """Add random noise. 
""" 77 | def __init__(self, probability=0.5, frac=0.005): 78 | self.probability = probability 79 | self.frac = frac 80 | 81 | def __call__(self, data): 82 | if random.random() < self.probability: 83 | img = data['input'] 84 | max_val = int(round(1 / self.frac)) 85 | indices = np.random.randint(0, max_val, size=img.shape) 86 | img[indices == 0] = 0 87 | data['input'] = img 88 | return data 89 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bms 2 | channels: 3 | - fastai/label/test 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.6.9 9 | - pillow-simd 10 | 11 | - cudatoolkit=10.1 12 | - pytorch=1.7.1 13 | - torchvision 14 | 15 | - opencv 16 | - cython 17 | - scipy 18 | - scikit-learn 19 | - tqdm 20 | - pandas 21 | - pyyaml 22 | - scikit-image 23 | - rdkit 24 | #- sqlite 25 | - pip 26 | - pip: 27 | - tensorboardX 28 | - git+https://github.com/NVIDIA/apex.git@3ae89c754d945e407a6674aa2006d5a0e35d540e 29 | - albumentations 30 | - requests 31 | - resnest 32 | - torch_optimizer 33 | - git+https://github.com/rwightman/pytorch-image-models 34 | - efficientnet_pytorch 35 | - python-Levenshtein 36 | - fairseq==0.10.2 37 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | __author__ = 'Erdene-Ochir Tuguldur' 4 | 5 | import os 6 | import argparse 7 | 8 | from torch.nn.utils.rnn import pad_sequence 9 | from tqdm import * 10 | 11 | # project imports 12 | from datasets.bms import BMSTrainDataset, BMSTestDataset 13 | from models import get_model 14 | 15 | from datasets.transforms import * 16 | from misc.metrics import * 17 | 18 | import torch 19 | from torch.utils.data import DataLoader 20 | import torch.distributed as dist 21 | 22 | import warnings 23 | 24 | warnings.filterwarnings("ignore") 25 | 26 | 27 | seed = 1234 28 | np.random.seed(seed) 29 | random.seed(seed) 30 | os.environ['PYTHONHASHSEED'] = str(seed) 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | torch.backends.cudnn.deterministic = True 35 | # torch.backends.cudnn.benchmark = False 36 | 37 | 38 | def bms_collate(batch): 39 | inputs, labels, label_lengths, inchis = [], [], [], [] 40 | for b in batch: 41 | inputs.append(b['input']) 42 | labels.append(torch.LongTensor(b['label']).reshape(-1, 1)) 43 | label_lengths.append(torch.LongTensor([b['label_length']])) 44 | inchis.append(b['inchi']) 45 | labels = pad_sequence(labels, batch_first=True, padding_value=192) 46 | return { 47 | 'input': torch.stack(inputs), 48 | 'label': labels.squeeze(dim=-1), 49 | 'label_length': torch.stack(label_lengths).reshape(-1, 1), 50 | 'inchi': inchis 51 | } 52 | 53 | 54 | def load_model_and_data(add_args, validation=True): 55 | checkpoint_file_name = add_args.checkpoint 56 | print(checkpoint_file_name) 57 | checkpoint = torch.load(checkpoint_file_name, map_location='cpu') 58 | args = checkpoint.get('args', None) 59 | print("saved args", args) 60 | if 'metric' in checkpoint: 61 | print("metric %.2f%%" % checkpoint['metric']) 62 | if not hasattr(args, 'max_token'): 63 | args.max_token = 275 64 | if hasattr(add_args, 'test_fold') and add_args.test_fold is not None: 65 | print("\n\n\n\n\n\nWARNING: this is evil don't use it...!!!\n\n\n\n\n\n", args) 66 | args.fold = add_args.test_fold 67 | 
68 | model = get_model(args) 69 | state_dict = checkpoint['state_dict'] 70 | remove_module_keys = True 71 | if remove_module_keys: 72 | new_state_dict = {} 73 | for k, v in state_dict.items(): 74 | if k.startswith('module.'): 75 | new_state_dict[k[len('module.'):]] = v 76 | else: 77 | new_state_dict[k] = v 78 | state_dict = new_state_dict 79 | 80 | model.load_state_dict(state_dict) 81 | model.float() 82 | start_epoch = checkpoint.get('epoch', 0) 83 | global_step = checkpoint.get('global_step', 0) 84 | 85 | del checkpoint 86 | print("loaded checkpoint epoch=%d step=%d" % (start_epoch, global_step)) 87 | 88 | args.dataload_workers_nums = add_args.dataload_workers_nums 89 | 90 | test_transform = get_test_transform(args) 91 | 92 | if validation: 93 | test_dataset = BMSTrainDataset(fold=args.fold, mode='valid', transform=test_transform, dataset_size=0) 94 | test_batch_size = add_args.batch_size 95 | collate_fn = bms_collate 96 | else: 97 | test_transform = Compose([ 98 | TestFix(), 99 | test_transform 100 | ]) 101 | test_dataset = BMSTestDataset(transform=test_transform) 102 | test_batch_size = add_args.batch_size 103 | collate_fn = None 104 | 105 | test_data_sampler = None 106 | if add_args.distributed: 107 | test_data_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False) 108 | 109 | test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, 110 | collate_fn=collate_fn, 111 | num_workers=args.dataload_workers_nums, 112 | sampler=test_data_sampler, pin_memory=True) 113 | 114 | return model, args, test_dataset, test_batch_size, test_data_loader 115 | 116 | 117 | def evaluate(add_args): 118 | model, args, test_dataset, test_batch_size, test_data_loader = load_model_and_data(add_args, validation=True) 119 | 120 | model.eval() 121 | model.cuda() 122 | torch.set_grad_enabled(False) 123 | if add_args.mixed_precision: 124 | model = model.half() 125 | 126 | all_inchis = [] 127 | all_predictions = [] 128 | 129 | print(test_dataset.tokenizer.stoi[""]) 130 | 131 | pbar = tqdm(test_data_loader, unit="images", unit_scale=test_batch_size, mininterval=10) 132 | for batch in pbar: 133 | inputs = batch['input'].cuda() 134 | if add_args.mixed_precision: 135 | inputs = inputs.half() 136 | 137 | predictions = model(inputs, True, None, None, args.max_token, test_dataset.tokenizer) 138 | all_predictions += predictions 139 | all_inchis += batch['inchi'] 140 | 141 | print(all_inchis[0]) 142 | print(all_predictions[0]) 143 | metric = compute_metric(all_inchis, all_predictions) 144 | print(metric) 145 | 146 | if add_args.distributed: 147 | metric_tensor = torch.tensor(metric).cuda() 148 | rt = metric_tensor.clone() 149 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 150 | rt = rt / add_args.world_size 151 | metric = rt.item() 152 | 153 | # wait until everything finished 154 | dist.barrier() 155 | 156 | # import pickle 157 | # pickle.dump({ 158 | # 'inchis': all_inchis, 159 | # 'predictions': all_predictions 160 | # }, open('eval.pickle', 'wb')) 161 | 162 | return metric 163 | 164 | 165 | if __name__ == '__main__': 166 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 167 | parser.add_argument("--dataload-workers-nums", type=int, default=4, help='number of workers for dataloader') 168 | parser.add_argument("--test-fold", type=int, help="don't use it!!!") 169 | parser.add_argument("--batch-size", type=int, default=64, help='batch size') 170 | parser.add_argument("--local_rank", default=0, type=int) 171 | 
parser.add_argument('--cudnn-benchmark', action='store_true', help='enable CUDNN benchmark') 172 | parser.add_argument('--mixed-precision', action='store_true', help='mixed precision') 173 | parser.add_argument("checkpoint", type=str, help='a pretrained neural network model') 174 | main_args = parser.parse_args() 175 | 176 | main_args.distributed = False 177 | main_args.world_size = 1 178 | if 'WORLD_SIZE' in os.environ: 179 | main_args.distributed = int(os.environ['WORLD_SIZE']) > 1 180 | main_args.world_size = int(os.environ['WORLD_SIZE']) 181 | if main_args.distributed: 182 | torch.cuda.set_device(main_args.local_rank) 183 | dist.init_process_group(backend='nccl', init_method='env://') 184 | torch.backends.cudnn.benchmark = main_args.cudnn_benchmark 185 | 186 | metric = evaluate(main_args) 187 | if main_args.local_rank == 0: 188 | print("%s: %.2f" % ('metric', metric)) 189 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | __author__ = 'Erdene-Ochir Tuguldur' 4 | 5 | import os 6 | import glob 7 | import numpy as np 8 | import random 9 | import argparse 10 | from pathlib import Path 11 | from tqdm import * 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | # project imports 17 | from eval import load_model_and_data 18 | 19 | import warnings 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | 24 | seed = 1234 25 | np.random.seed(seed) 26 | random.seed(seed) 27 | os.environ['PYTHONHASHSEED'] = str(seed) 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed(seed) 30 | torch.cuda.manual_seed_all(seed) 31 | torch.backends.cudnn.deterministic = True 32 | # torch.backends.cudnn.benchmark = False 33 | 34 | 35 | def inference(add_args): 36 | model, args, test_dataset, test_batch_size, test_data_loader = load_model_and_data(add_args, validation=False) 37 | 38 | model.eval() 39 | model.cuda() 40 | torch.set_grad_enabled(False) 41 | if add_args.mixed_precision: 42 | model = model.half() 43 | 44 | print(test_dataset.tokenizer.stoi[""]) 45 | 46 | with Path("%d.csv" % add_args.local_rank).open('wt') as f: 47 | # f.write('image_id,InChI\n') 48 | pbar = tqdm(test_data_loader, unit="images", unit_scale=test_batch_size, mininterval=30) 49 | for batch in pbar: 50 | inputs, image_ids = batch['input'].cuda(), batch['image_id'] 51 | if add_args.mixed_precision: 52 | inputs = inputs.half() 53 | 54 | predictions = model(inputs, True, None, None, args.max_token, test_dataset.tokenizer) 55 | for image_id, p in zip(image_ids, predictions): 56 | f.write("%s,\"%s\"\n" % (image_id, p)) 57 | f.flush() 58 | 59 | if add_args.distributed: 60 | # wait until everything finished 61 | dist.barrier() 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 66 | parser.add_argument("--dataload-workers-nums", type=int, default=4, help='number of workers for dataloader') 67 | parser.add_argument("--batch-size", type=int, default=64, help='batch size') 68 | parser.add_argument("--local_rank", default=0, type=int) 69 | parser.add_argument('--cudnn-benchmark', action='store_true', help='enable CUDNN benchmark') 70 | parser.add_argument("--submission", type=str, default='submission.csv', help="submission output") 71 | parser.add_argument('--mixed-precision', action='store_true', help='mixed precision') 72 | parser.add_argument("checkpoint", type=str, 
help='a pretrained neural network model') 73 | main_args = parser.parse_args() 74 | 75 | main_args.distributed = False 76 | main_args.world_size = 1 77 | if 'WORLD_SIZE' in os.environ: 78 | main_args.distributed = int(os.environ['WORLD_SIZE']) > 1 79 | main_args.world_size = int(os.environ['WORLD_SIZE']) 80 | if main_args.distributed: 81 | torch.cuda.set_device(main_args.local_rank) 82 | dist.init_process_group(backend='nccl', init_method='env://') 83 | torch.backends.cudnn.benchmark = main_args.cudnn_benchmark 84 | 85 | inference(main_args) 86 | 87 | if main_args.local_rank == 0: 88 | image_ids = set() 89 | submission_file = Path(main_args.submission) 90 | submission_file.parent.mkdir(exist_ok=True) 91 | with submission_file.open('wt') as f: 92 | f.write('image_id,InChI\n') 93 | output_files = glob.glob("[0-9].csv") 94 | for output_file in sorted(output_files): 95 | print("merging %s..." % output_file) 96 | with Path(output_file).open('rt') as of: 97 | for line in of: 98 | image_id = line.split(',')[0] 99 | if image_id not in image_ids: 100 | image_ids.add(image_id) 101 | f.write(line) 102 | print("image ids: ", len(image_ids)) 103 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | 3 | import torch 4 | 5 | from losses.ls_loss import LabelSmoothingLoss, LabelSmoothingCrossEntropy 6 | 7 | LOSSES = ['ce'] 8 | 9 | 10 | def get_loss(args, tokenizer): 11 | if args.loss == 'ce': 12 | if args.label_smoothing > 0: 13 | # return LabelSmoothingLoss(args.label_smoothing, tgt_vocab_size=len(tokenizer), ignore_index=tokenizer.stoi["<pad>"]).cuda() 14 | return LabelSmoothingCrossEntropy(args.label_smoothing, ignore_index=tokenizer.stoi["<pad>"]).cuda() 15 | else: 16 | return torch.nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"]) 17 | else: 18 | raise RuntimeError("Unknown loss!") 19 | -------------------------------------------------------------------------------- /losses/ls_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # 7 | # https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/utils/loss.py 8 | # 9 | class LabelSmoothingLoss(nn.Module): 10 | """ 11 | With label smoothing, 12 | KL-divergence between q_{smoothed ground truth prob.}(w) 13 | and p_{prob. computed by model}(w) is minimized. 
14 | """ 15 | 16 | def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): 17 | assert 0.0 < label_smoothing <= 1.0 18 | self.ignore_index = ignore_index 19 | super(LabelSmoothingLoss, self).__init__() 20 | 21 | smoothing_value = label_smoothing / (tgt_vocab_size - 2) 22 | one_hot = torch.full((tgt_vocab_size,), smoothing_value) 23 | one_hot[self.ignore_index] = 0 24 | self.register_buffer('one_hot', one_hot.unsqueeze(0)) 25 | 26 | self.confidence = 1.0 - label_smoothing 27 | 28 | def forward(self, output, target): 29 | """ 30 | output (FloatTensor): batch_size x n_classes 31 | target (LongTensor): batch_size 32 | """ 33 | model_prob = self.one_hot.repeat(target.size(0), 1) 34 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence) 35 | model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) 36 | 37 | return F.kl_div(output, model_prob, reduction='sum') 38 | 39 | 40 | # 41 | # https://github.com/fastai/fastai/blob/8013797e05f0ae0d771d60ecf7cf524da591503c/fastai/layers.py#L300 42 | # 43 | class LabelSmoothingCrossEntropy(nn.Module): 44 | def __init__(self, eps: float = 0.1, ignore_index=-100, reduction='mean'): 45 | super().__init__() 46 | self.eps, self.reduction, self.ignore_index = eps, reduction, ignore_index 47 | print("label smoothing: ", self.eps, self.reduction, self.ignore_index) 48 | 49 | def forward(self, output, target): 50 | c = output.size()[-1] 51 | log_preds = F.log_softmax(output, dim=-1) 52 | if self.reduction == 'sum': 53 | loss = -log_preds.sum() 54 | else: 55 | loss = -log_preds.sum(dim=-1) 56 | if self.reduction == 'mean': loss = loss.mean() 57 | return loss * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction, ignore_index=self.ignore_index) 58 | -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tugstugi/pytorch-bms/ddd0979eb01a08df1c7dee3c435b661c19df38fc/misc/__init__.py -------------------------------------------------------------------------------- /misc/metrics.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | 3 | import numpy as np 4 | import Levenshtein 5 | 6 | 7 | def compute_metric(y_true, y_pred): 8 | scores = [] 9 | for true, pred in zip(y_true, y_pred): 10 | score = Levenshtein.distance(true, pred) 11 | scores.append(score) 12 | avg_score = np.mean(scores) 13 | return avg_score 14 | -------------------------------------------------------------------------------- /misc/sample_submission_with_length.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tugstugi/pytorch-bms/ddd0979eb01a08df1c7dee3c435b661c19df38fc/misc/sample_submission_with_length.csv.gz -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | 3 | import os 4 | import glob 5 | import torch 6 | 7 | 8 | def get_last_checkpoint_file_name(logdir): 9 | """Returns the last checkpoint file name in the given log dir path.""" 10 | # checkpoints = glob.glob(os.path.join(logdir, '*.pth')) 11 | # checkpoints.sort() 12 | # if len(checkpoints) == 0: 13 | # return None 14 | # return checkpoints[-1] 15 | checkpoint = os.path.join(logdir, 'last.pth') 16 | if 
os.path.exists(checkpoint): 17 | return checkpoint 18 | return None 19 | 20 | 21 | def load_checkpoint(checkpoint_file_name, model, optimizer, use_gpu=False, 22 | remove_module_keys=True, remove_decoder=False, remove_encoder=False, non_strict=False): 23 | """Loads the checkpoint into the given model and optimizer.""" 24 | checkpoint = torch.load(checkpoint_file_name, map_location='cpu' if not use_gpu else None) 25 | state_dict = checkpoint['state_dict'] 26 | if remove_module_keys or remove_decoder: 27 | new_state_dict = {} 28 | for k, v in state_dict.items(): 29 | if k.startswith('module.'): 30 | k = k[len('module.'):] 31 | new_state_dict[k] = v 32 | if remove_encoder and k.startswith('encoder'): 33 | del new_state_dict[k] 34 | if remove_decoder and k.startswith('decoder'): 35 | del new_state_dict[k] 36 | state_dict = new_state_dict 37 | 38 | model.load_state_dict(state_dict, strict=False if (remove_encoder or remove_decoder or non_strict) else True) 39 | model.float() 40 | if optimizer is not None: 41 | optimizer.load_state_dict(checkpoint['optimizer']) 42 | start_epoch = checkpoint.get('epoch', 0) 43 | global_step = checkpoint.get('global_step', 0) 44 | saved_args = checkpoint.get('args', None) 45 | del checkpoint 46 | print("loaded checkpoint epoch=%d step=%d" % (start_epoch, global_step)) 47 | return start_epoch, global_step, saved_args 48 | 49 | 50 | def save_checkpoint(logdir, epoch, global_step, model, optimizer, args, checkpoint_file_name): 51 | """Saves the training state into the given log dir path.""" 52 | # checkpoint_file_name = os.path.join(logdir, 'epoch-%04d.pth' % epoch) 53 | # print("saving the checkpoint file '%s'..." % checkpoint_file_name) 54 | checkpoint = { 55 | 'epoch': epoch + 1, 56 | 'global_step': global_step, 57 | 'state_dict': model.state_dict(), 58 | 'optimizer': optimizer.state_dict(), 59 | 'args': args 60 | } 61 | torch.save(checkpoint, checkpoint_file_name) 62 | del checkpoint 63 | 64 | 65 | def save_latest_checkpoint(logdir, epoch, global_step, model, optimizer, args): 66 | # checkpoint_file_name = os.path.join(logdir, 'last.pth') 67 | checkpoint_file_name = os.path.join(logdir, 'epoch-%d.pth' % epoch) 68 | if os.path.exists(checkpoint_file_name): 69 | os.remove(checkpoint_file_name) 70 | save_checkpoint(logdir, epoch, global_step, model, optimizer, args, checkpoint_file_name) 71 | 72 | 73 | def save_best_checkpoint(logdir, epoch, global_step, model, optimizer, args, metric, name): 74 | """Saves the training state into the given log dir path.""" 75 | checkpoint_file_name = os.path.join(logdir, 'best-%s.pth' % name) 76 | # print("saving the checkpoint file '%s'..." 
% checkpoint_file_name) 77 | checkpoint = { 78 | 'epoch': epoch + 1, 79 | 'global_step': global_step, 80 | 'state_dict': model.state_dict(), 81 | 'optimizer': optimizer.state_dict(), 82 | 'args': args, 83 | 'metric': metric 84 | } 85 | torch.save(checkpoint, checkpoint_file_name) 86 | del checkpoint 87 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | 3 | from .cait import Cait 4 | from .swin import Swin 5 | from .vit import Vit 6 | 7 | 8 | def get_model(args): 9 | name = args.model 10 | 11 | if name.lower().startswith('vit'): 12 | return Vit(name, args) 13 | elif name.lower().startswith('swin'): 14 | return Swin(name, args) 15 | elif name.lower().startswith('cait'): 16 | return Cait(name, args) 17 | else: 18 | raise RuntimeError("Unknown model! %s" % name) 19 | -------------------------------------------------------------------------------- /models/cait.py: -------------------------------------------------------------------------------- 1 | # 2 | # Original code: https://www.kaggle.com/c/bms-molecular-translation/discussion/231190 3 | # Updated for CAIT 4 | # 5 | 6 | import numpy as np 7 | import timm 8 | from torch.nn.utils.rnn import pack_padded_sequence 9 | 10 | from .fairseq_transformer import * 11 | from .fairseq_transformer import PositionEncode1D, TransformerDecode 12 | 13 | 14 | class CaitEncoder(nn.Module): 15 | def __init__(self, backbone, args): 16 | super().__init__() 17 | 18 | self.e = timm.create_model(backbone, pretrained=True) 19 | 20 | def forward(self, x): 21 | B = x.shape[0] 22 | x = self.e.patch_embed(x) 23 | 24 | cls_tokens = self.e.cls_token.expand(B, -1, -1) 25 | 26 | x = x + self.e.pos_embed 27 | x = self.e.pos_drop(x) 28 | 29 | for i, blk in enumerate(self.e.blocks): 30 | x = blk(x) 31 | 32 | for i, blk in enumerate(self.e.blocks_token_only): 33 | cls_tokens = blk(x, cls_tokens) 34 | 35 | x = torch.cat((cls_tokens, x), dim=1) 36 | 37 | x = self.e.norm(x) 38 | # print(x.size()) 39 | # print(x[:, 0].size()) 40 | return x # x[:, 0] 41 | 42 | 43 | class CaitDecoder(nn.Module): 44 | def __init__(self, args): 45 | super().__init__() 46 | 47 | self.vocab_size = args.vocab_size 48 | self.max_length = 300 # args.max_token 49 | self.embed_dim = args.embed_dim 50 | 51 | self.image_encode = nn.Identity() 52 | self.text_pos = PositionEncode1D(self.embed_dim, self.max_length) 53 | self.token_embed = nn.Embedding(self.vocab_size, self.embed_dim) 54 | self.text_decode = TransformerDecode(self.embed_dim, 55 | ff_dim=args.ff_dim, 56 | num_head=args.num_head, 57 | num_layer=args.num_layer) 58 | 59 | # --- 60 | self.logit = nn.Linear(self.embed_dim, self.vocab_size) 61 | 62 | # ---- 63 | # initialization 64 | self.token_embed.weight.data.uniform_(-0.1, 0.1) 65 | self.logit.bias.data.fill_(0) 66 | self.logit.weight.data.uniform_(-0.1, 0.1) 67 | 68 | @torch.jit.unused 69 | def forward(self, image_embed, token, length): 70 | device = image_embed.device 71 | # 16, 577, 768 72 | image_embed = self.image_encode(image_embed).permute(1, 0, 2).contiguous() 73 | # (T,N,E) expected 74 | 75 | text_embed = self.token_embed(token) 76 | text_embed = self.text_pos(text_embed).permute(1, 0, 2).contiguous() 77 | 78 | text_mask_max_length = length.max() # max_length 79 | text_mask = np.triu(np.ones((text_mask_max_length, text_mask_max_length)), k=1).astype(np.uint8) 80 | text_mask = 
torch.autograd.Variable(torch.from_numpy(text_mask) == 1).to(device) 81 | 82 | # ---- 83 | # mask based on length of token? 84 | # perturb mask as aug 85 | 86 | x = self.text_decode(text_embed, image_embed, text_mask) 87 | x = x.permute(1, 0, 2).contiguous() 88 | 89 | logit = self.logit(x) 90 | return logit 91 | 92 | @torch.jit.export 93 | def predict(self, image): 94 | STOI = { 95 | '<sos>': 190, 96 | '<eos>': 191, 97 | '<pad>': 192, 98 | } 99 | 100 | # --------------------------------- 101 | device = image.device 102 | batch_size = len(image) 103 | 104 | # image_embed = self.cnn(image) 105 | image_embed = self.image_encode(image).permute(1, 0, 2).contiguous() 106 | 107 | token = torch.full((batch_size, self.max_length), STOI['<pad>'], dtype=torch.long, device=device) 108 | text_pos = self.text_pos.pos 109 | token[:, 0] = STOI['<sos>'] 110 | 111 | # ------------------------------------- 112 | eos = STOI['<eos>'] 113 | pad = STOI['<pad>'] 114 | 115 | # incremental_state = {} 116 | incremental_state = torch.jit.annotate( 117 | Dict[str, Dict[str, Optional[Tensor]]], 118 | torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}), 119 | ) 120 | for t in range(self.max_length - 1): 121 | # last_token = token [:,:(t+1)] 122 | # text_embed = self.token_embed(last_token) 123 | # text_embed = self.text_pos(text_embed) #text_embed + text_pos[:,:(t+1)] # 124 | 125 | last_token = token[:, t] 126 | text_embed = self.token_embed(last_token) 127 | text_embed = text_embed + text_pos[:, t] # 128 | text_embed = text_embed.reshape(1, batch_size, self.embed_dim) 129 | 130 | x = self.text_decode.forward_one(text_embed, image_embed, incremental_state) 131 | x = x.reshape(batch_size, self.embed_dim) 132 | # print(incremental_state.keys()) 133 | 134 | l = self.logit(x) 135 | k = torch.argmax(l, -1) # predict max 136 | token[:, t + 1] = k 137 | if ((k == eos) | (k == pad)).all(): break 138 | 139 | predict = token[:, 1:] 140 | return predict 141 | 142 | 143 | class Cait(nn.Module): 144 | def __init__(self, backbone, args): 145 | super().__init__() 146 | self.encoder = CaitEncoder(backbone, args) 147 | self.decoder = CaitDecoder(args) 148 | 149 | def _forward(self, x, encoded_captions, caption_lengths): 150 | encoder_out = self.encoder(x) 151 | 152 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 153 | encoder_out = encoder_out[sort_ind] 154 | encoded_captions = encoded_captions[sort_ind] 155 | decode_lengths = (caption_lengths - 1).tolist() 156 | 157 | predictions = self.decoder(encoder_out, encoded_captions, caption_lengths) 158 | targets = encoded_captions[:, 1:] 159 | predictions = pack_padded_sequence(predictions, decode_lengths, batch_first=True).data 160 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 161 | # targets = encoded_captions 162 | return targets, predictions 163 | 164 | def _predict(self, x, max_length, tokenizer): 165 | encoder_out = self.encoder(x) 166 | predictions = self.decoder.predict(encoder_out) # , max_length, tokenizer) 167 | # predicted_sequence = torch.argmax(predictions.detach().cpu(), -1).numpy() 168 | predicted_captions = tokenizer.predict_captions(predictions.detach().cpu().numpy()) 169 | predicted_captions = ['InChI=1S/' + p for p in predicted_captions] 170 | return predicted_captions 171 | 172 | def forward(self, x, predict, encoded_captions, caption_lengths, max_length, tokenizer): 173 | if predict: 174 | return self._predict(x, max_length, tokenizer) 175 | else: 176 | return self._forward(x, encoded_captions, caption_lengths) 177 | 
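The SWIN and VIT wrappers below expose the same six-argument `forward` as `Cait` above: with `predict=False` the decoder is teacher-forced on the tokenized InChI and returns packed `(targets, logits)` for the loss, with `predict=True` it greedy-decodes with fairseq's incremental state and returns finished `InChI=1S/...` strings. A minimal usage sketch (not part of the repo; `images`, `labels`, `label_lengths`, `max_token`, `tokenizer` and `criterion` are assumed names matching what `bms_collate` in `eval.py` and `get_loss` in `losses/__init__.py` produce; the training-mode call is inferred from the `forward` signature since `train.py` is not shown here):
```python
import torch

# inference / validation: greedy decoding, returns a list of InChI strings
with torch.no_grad():
    inchis = model(images, True, None, None, max_token, tokenizer)

# training (teacher forcing): labels is the padded (B, T) token batch,
# label_lengths its (B, 1) lengths; the returned pair feeds the CE / label-smoothing loss
targets, logits = model(images, False, labels, label_lengths, max_token, tokenizer)
loss = criterion(logits, targets)
```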
-------------------------------------------------------------------------------- /models/fairseq_transformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Original code: https://www.kaggle.com/c/bms-molecular-translation/discussion/231190 3 | # 4 | 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | 9 | from typing import Tuple, Dict, Optional 10 | 11 | from fairseq import utils 12 | from fairseq.models import * 13 | from fairseq.modules import * 14 | 15 | 16 | # https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute 17 | from torch import Tensor 18 | 19 | 20 | class Namespace(object): 21 | def __init__(self, adict): 22 | self.__dict__.update(adict) 23 | 24 | 25 | # ------------------------------------------------------ 26 | # https://kazemnejad.com/blog/transformer_architecture_positional_encoding/ 27 | # https://stackoverflow.com/questions/46452020/sinusoidal-embedding-attention-is-all-you-need 28 | 29 | class PositionEncode1D(nn.Module): 30 | def __init__(self, dim, max_length): 31 | super().__init__() 32 | assert (dim % 2 == 0) 33 | self.max_length = max_length 34 | 35 | d = torch.exp(torch.arange(0., dim, 2) * (-math.log(10000.0) / dim)) 36 | position = torch.arange(0., max_length).unsqueeze(1) 37 | pos = torch.zeros(1, max_length, dim) 38 | pos[0, :, 0::2] = torch.sin(position * d) 39 | pos[0, :, 1::2] = torch.cos(position * d) 40 | self.register_buffer('pos', pos) 41 | 42 | def forward(self, x): 43 | batch_size, T, dim = x.shape 44 | x = x + self.pos[:, :T] 45 | return x 46 | 47 | 48 | # https://gitlab.maastrichtuniversity.nl/dsri-examples/dsri-pytorch-workspace/-/blob/c8a88cdeb8e1a0f3a2ccd3c6119f43743cbb01e9/examples/transformer/fairseq/models/transformer.py 49 | # https://github.com/pytorch/fairseq/issues/568 50 | # fairseq/fairseq/models/fairseq_encoder.py 51 | 52 | # https://github.com/pytorch/fairseq/blob/master/fairseq/modules/transformer_layer.py 53 | class TransformerEncode(FairseqEncoder): 54 | 55 | def __init__(self, dim, ff_dim, num_head, num_layer): 56 | super().__init__({}) 57 | # print('my TransformerEncode()') 58 | 59 | self.layer = nn.ModuleList([ 60 | TransformerEncoderLayer(Namespace({ 61 | 'encoder_embed_dim': dim, 62 | 'encoder_attention_heads': num_head, 63 | 'attention_dropout': 0.1, 64 | 'dropout': 0.1, 65 | 'encoder_normalize_before': True, 66 | 'encoder_ffn_embed_dim': ff_dim, 67 | })) for i in range(num_layer) 68 | ]) 69 | self.layer_norm = nn.LayerNorm(dim) 70 | 71 | def forward(self, x): # T x B x C 72 | # print('my TransformerEncode forward()') 73 | for layer in self.layer: 74 | x = layer(x) 75 | x = self.layer_norm(x) 76 | return x 77 | 78 | 79 | # https://mt.cs.upc.edu/2020/12/21/the-transformer-fairseq-edition/ 80 | # for debug 81 | # class TransformerDecode(FairseqDecoder): 82 | # def __init__(self, dim, ff_dim, num_head, num_layer): 83 | # super().__init__({}) 84 | # print('my TransformerDecode()') 85 | # 86 | # self.layer = nn.ModuleList([ 87 | # TransformerDecoderLayer(Namespace({ 88 | # 'decoder_embed_dim': dim, 89 | # 'decoder_attention_heads': num_head, 90 | # 'attention_dropout': 0.1, 91 | # 'dropout': 0.1, 92 | # 'decoder_normalize_before': True, 93 | # 'decoder_ffn_embed_dim': ff_dim, 94 | # })) for i in range(num_layer) 95 | # ]) 96 | # self.layer_norm = nn.LayerNorm(dim) 97 | # 98 | # 99 | # def forward(self, x, mem, x_mask):# T x B x C 100 | # print('my TransformerDecode forward()') 101 | # for layer in self.layer: 102 | # x = layer(x, mem, 
self_attn_mask=x_mask)[0] 103 | # x = self.layer_norm(x) 104 | # return x # T x B x C 105 | 106 | # https://fairseq.readthedocs.io/en/latest/tutorial_simple_lstm.html 107 | # see https://gitlab.maastrichtuniversity.nl/dsri-examples/dsri-pytorch-workspace/-/blob/c8a88cdeb8e1a0f3a2ccd3c6119f43743cbb01e9/examples/transformer/fairseq/models/transformer.py 108 | class TransformerDecode(FairseqIncrementalDecoder): 109 | def __init__(self, dim, ff_dim, num_head, num_layer): 110 | super().__init__({}) 111 | # print('my TransformerDecode()') 112 | 113 | self.layer = nn.ModuleList([ 114 | TransformerDecoderLayer(Namespace({ 115 | 'decoder_embed_dim': dim, 116 | 'decoder_attention_heads': num_head, 117 | 'attention_dropout': 0.1, 118 | 'dropout': 0.1, 119 | 'decoder_normalize_before': True, 120 | 'decoder_ffn_embed_dim': ff_dim, 121 | })) for i in range(num_layer) 122 | ]) 123 | self.layer_norm = nn.LayerNorm(dim) 124 | 125 | def forward(self, x, mem, x_mask): 126 | # print('my TransformerDecode forward()') 127 | for layer in self.layer: 128 | x = layer(x, mem, self_attn_mask=x_mask)[0] 129 | x = self.layer_norm(x) 130 | return x # T x B x C 131 | 132 | # def forward_one(self, x, mem, incremental_state): 133 | def forward_one(self, 134 | x: Tensor, 135 | mem: Tensor, 136 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] 137 | ) -> Tensor: 138 | x = x[-1:] 139 | for layer in self.layer: 140 | x = layer(x, mem, incremental_state=incremental_state)[0] 141 | x = self.layer_norm(x) 142 | return x 143 | -------------------------------------------------------------------------------- /models/swin.py: -------------------------------------------------------------------------------- 1 | # 2 | # Original code: https://www.kaggle.com/c/bms-molecular-translation/discussion/231190 3 | # Updated for SWIN 4 | # 5 | 6 | import numpy as np 7 | import timm 8 | from torch.nn.utils.rnn import pack_padded_sequence 9 | 10 | from .fairseq_transformer import * 11 | from .fairseq_transformer import PositionEncode1D, TransformerDecode 12 | from .vit import VitDecoder 13 | 14 | 15 | class SwinEncoder(nn.Module): 16 | def __init__(self, backbone, args): 17 | super().__init__() 18 | 19 | self.e = timm.create_model(backbone, pretrained=True) 20 | 21 | def forward(self, x): 22 | x = self.e.patch_embed(x) 23 | if self.e.absolute_pos_embed is not None: 24 | x = x + self.e.absolute_pos_embed 25 | x = self.e.pos_drop(x) 26 | x = self.e.layers(x) 27 | x = self.e.norm(x) 28 | return x 29 | 30 | 31 | class Swin(nn.Module): 32 | def __init__(self, backbone, args): 33 | super().__init__() 34 | self.encoder = SwinEncoder(backbone, args) 35 | self.decoder = VitDecoder(args) 36 | 37 | def _forward(self, x, encoded_captions, caption_lengths): 38 | encoder_out = self.encoder(x) 39 | 40 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 41 | encoder_out = encoder_out[sort_ind] 42 | encoded_captions = encoded_captions[sort_ind] 43 | decode_lengths = (caption_lengths - 1).tolist() 44 | 45 | predictions = self.decoder(encoder_out, encoded_captions, caption_lengths) 46 | targets = encoded_captions[:, 1:] 47 | predictions = pack_padded_sequence(predictions, decode_lengths, batch_first=True).data 48 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 49 | # targets = encoded_captions 50 | return targets, predictions 51 | 52 | def _predict(self, x, max_length, tokenizer): 53 | encoder_out = self.encoder(x) 54 | predictions = self.decoder.predict(encoder_out) # , 
max_length, tokenizer) 55 | # predicted_sequence = torch.argmax(predictions.detach().cpu(), -1).numpy() 56 | predicted_captions = tokenizer.predict_captions(predictions.detach().cpu().numpy()) 57 | predicted_captions = ['InChI=1S/' + p for p in predicted_captions] 58 | return predicted_captions 59 | 60 | def forward(self, x, predict, encoded_captions, caption_lengths, max_length, tokenizer): 61 | if predict: 62 | return self._predict(x, max_length, tokenizer) 63 | else: 64 | return self._forward(x, encoded_captions, caption_lengths) 65 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | # 2 | # Original code: https://www.kaggle.com/c/bms-molecular-translation/discussion/231190 3 | # Updated for VIT 4 | # 5 | 6 | import numpy as np 7 | import timm 8 | from torch.nn.utils.rnn import pack_padded_sequence 9 | 10 | from .fairseq_transformer import * 11 | from .fairseq_transformer import PositionEncode1D, TransformerDecode 12 | 13 | 14 | class VitEncoder(nn.Module): 15 | def __init__(self, backbone, args): 16 | super().__init__() 17 | 18 | self.e = timm.create_model(backbone, pretrained=True) 19 | 20 | def forward(self, x): 21 | x = self.e.patch_embed(x) 22 | cls_token = self.e.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 23 | if self.e.dist_token is None: 24 | x = torch.cat((cls_token, x), dim=1) 25 | else: 26 | x = torch.cat((cls_token, self.e.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 27 | x = self.e.pos_drop(x + self.e.pos_embed) 28 | x = self.e.blocks(x) 29 | x = self.e.norm(x) 30 | # tiny 16, 197, 192 31 | # small 16, 197, 384 32 | # base 16, 197, 768 33 | # base384 16, 577, 768 34 | return x 35 | 36 | 37 | class VitDecoder(nn.Module): 38 | def __init__(self, args): 39 | super().__init__() 40 | 41 | self.vocab_size = args.vocab_size 42 | self.max_length = 300 # args.max_token 43 | self.embed_dim = args.embed_dim 44 | 45 | self.image_encode = nn.Identity() 46 | self.text_pos = PositionEncode1D(self.embed_dim, self.max_length) 47 | self.token_embed = nn.Embedding(self.vocab_size, self.embed_dim) 48 | self.text_decode = TransformerDecode(self.embed_dim, 49 | ff_dim=args.ff_dim, 50 | num_head=args.num_head, 51 | num_layer=args.num_layer) 52 | 53 | # --- 54 | self.logit = nn.Linear(self.embed_dim, self.vocab_size) 55 | 56 | # ---- 57 | # initialization 58 | self.token_embed.weight.data.uniform_(-0.1, 0.1) 59 | self.logit.bias.data.fill_(0) 60 | self.logit.weight.data.uniform_(-0.1, 0.1) 61 | 62 | @torch.jit.unused 63 | def forward(self, image_embed, token, length): 64 | device = image_embed.device 65 | # 16, 577, 768 66 | image_embed = self.image_encode(image_embed).permute(1, 0, 2).contiguous() 67 | # (T,N,E) expected 68 | 69 | text_embed = self.token_embed(token) 70 | text_embed = self.text_pos(text_embed).permute(1, 0, 2).contiguous() 71 | 72 | text_mask_max_length = length.max() # max_length 73 | text_mask = np.triu(np.ones((text_mask_max_length, text_mask_max_length)), k=1).astype(np.uint8) 74 | text_mask = torch.autograd.Variable(torch.from_numpy(text_mask) == 1).to(device) 75 | 76 | # ---- 77 | # mask based on length of token? 
78 | # perturb mask as aug 79 | 80 | x = self.text_decode(text_embed, image_embed, text_mask) 81 | x = x.permute(1, 0, 2).contiguous() 82 | 83 | logit = self.logit(x) 84 | return logit 85 | 86 | @torch.jit.export 87 | def predict(self, image): 88 | STOI = { 89 | '<sos>': 190, 90 | '<eos>': 191, 91 | '<pad>': 192, 92 | } 93 | 94 | # --------------------------------- 95 | device = image.device 96 | batch_size = len(image) 97 | 98 | # image_embed = self.cnn(image) 99 | image_embed = self.image_encode(image).permute(1, 0, 2).contiguous() 100 | 101 | token = torch.full((batch_size, self.max_length), STOI['<pad>'], dtype=torch.long, device=device) 102 | text_pos = self.text_pos.pos 103 | token[:, 0] = STOI['<sos>'] 104 | 105 | # ------------------------------------- 106 | eos = STOI['<eos>'] 107 | pad = STOI['<pad>'] 108 | 109 | # incremental_state = {} 110 | incremental_state = torch.jit.annotate( 111 | Dict[str, Dict[str, Optional[Tensor]]], 112 | torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}), 113 | ) 114 | for t in range(self.max_length - 1): 115 | # last_token = token [:,:(t+1)] 116 | # text_embed = self.token_embed(last_token) 117 | # text_embed = self.text_pos(text_embed) #text_embed + text_pos[:,:(t+1)] # 118 | 119 | last_token = token[:, t] 120 | text_embed = self.token_embed(last_token) 121 | text_embed = text_embed + text_pos[:, t] # 122 | text_embed = text_embed.reshape(1, batch_size, self.embed_dim) 123 | 124 | x = self.text_decode.forward_one(text_embed, image_embed, incremental_state) 125 | x = x.reshape(batch_size, self.embed_dim) 126 | # print(incremental_state.keys()) 127 | 128 | l = self.logit(x) 129 | k = torch.argmax(l, -1) # predict max 130 | token[:, t + 1] = k 131 | if ((k == eos) | (k == pad)).all(): break 132 | 133 | predict = token[:, 1:] 134 | return predict 135 | 136 | 137 | class Vit(nn.Module): 138 | def __init__(self, backbone, args): 139 | super().__init__() 140 | self.encoder = VitEncoder(backbone, args) 141 | self.decoder = VitDecoder(args) 142 | 143 | def _forward(self, x, encoded_captions, caption_lengths): 144 | encoder_out = self.encoder(x) 145 | 146 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 147 | encoder_out = encoder_out[sort_ind] 148 | encoded_captions = encoded_captions[sort_ind] 149 | decode_lengths = (caption_lengths - 1).tolist() 150 | 151 | predictions = self.decoder(encoder_out, encoded_captions, caption_lengths) 152 | targets = encoded_captions[:, 1:] 153 | predictions = pack_padded_sequence(predictions, decode_lengths, batch_first=True).data 154 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 155 | # targets = encoded_captions 156 | return targets, predictions 157 | 158 | def _predict(self, x, max_length, tokenizer): 159 | encoder_out = self.encoder(x) 160 | predictions = self.decoder.predict(encoder_out) # , max_length, tokenizer) 161 | # predicted_sequence = torch.argmax(predictions.detach().cpu(), -1).numpy() 162 | predicted_captions = tokenizer.predict_captions(predictions.detach().cpu().numpy()) 163 | predicted_captions = ['InChI=1S/' + p for p in predicted_captions] 164 | return predicted_captions 165 | 166 | def forward(self, x, predict, encoded_captions, caption_lengths, max_length, tokenizer): 167 | if predict: 168 | return self._predict(x, max_length, tokenizer) 169 | else: 170 | return self._forward(x, encoded_captions, caption_lengths) 171 | -------------------------------------------------------------------------------- /normalize_inchis.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Original code: https://www.kaggle.com/nofreewill/normalize-your-predictions 3 | # 4 | 5 | from tqdm import tqdm 6 | from rdkit import Chem 7 | from rdkit import RDLogger 8 | 9 | RDLogger.DisableLog('rdApp.*') 10 | from pathlib import Path 11 | 12 | 13 | def normalize_inchi(inchi): 14 | try: 15 | mol = Chem.MolFromInchi(inchi) 16 | return inchi if (mol is None) else Chem.MolToInchi(mol) 17 | except: 18 | return inchi 19 | 20 | 21 | import sys 22 | 23 | # Segfault in rdkit taken care of, run it with: 24 | # while [ 1 ]; do python normalize_inchis.py && break; done 25 | if __name__ == '__main__': 26 | # Input & Output 27 | print(sys.argv) 28 | orig_path = Path(sys.argv[1]) 29 | norm_path = orig_path.with_name(orig_path.stem + '_norm.csv') 30 | 31 | # Do the job 32 | N = norm_path.read_text().count('\n') if norm_path.exists() else 0 33 | print(N, 'number of predictions already normalized') 34 | 35 | r = open(str(orig_path), 'r') 36 | w = open(str(norm_path), 'a', buffering=1) 37 | 38 | for _ in range(N): 39 | r.readline() 40 | line = r.readline() # this line is the header or is where it segfaulted last time 41 | w.write(line) 42 | 43 | pbar = tqdm() 44 | while True: 45 | line = r.readline() 46 | if not line: 47 | break # done 48 | image_id = line.split(',')[0] 49 | inchi = ','.join(line[:-1].split(',')[1:]).replace('"', '') 50 | inchi_norm = normalize_inchi(inchi) 51 | w.write(f'{image_id},"{inchi_norm}"\n') 52 | pbar.update(1) 53 | 54 | r.close() 55 | w.close() 56 | -------------------------------------------------------------------------------- /normalize_inchis.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | while [ 1 ]; do python normalize_inchis.py $1 && break; done 4 | -------------------------------------------------------------------------------- /optims/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | 3 | import torch 4 | from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR, ReduceLROnPlateau 5 | from torch_optimizer import RAdam, Lookahead 6 | 7 | 8 | def get_optimizer(args, model): 9 | if args.optim == 'sgd': 10 | optimizer = torch.optim.SGD([{'params': filter(lambda p: p.requires_grad, model.encoder.parameters()), 'lr': args.encoder_lr}, 11 | {'params': filter(lambda p: p.requires_grad, model.decoder.parameters()), 'lr': args.decoder_lr}], 12 | lr=args.decoder_lr, momentum=0.9, weight_decay=args.weight_decay) 13 | elif args.optim == 'adamw': 14 | optimizer = torch.optim.AdamW([{'params': filter(lambda p: p.requires_grad, model.encoder.parameters()), 'lr': args.encoder_lr}, 15 | {'params': filter(lambda p: p.requires_grad, model.decoder.parameters()), 'lr': args.decoder_lr}], 16 | lr=args.decoder_lr, weight_decay=args.weight_decay) 17 | elif args.optim == 'adam': 18 | # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay) 19 | optimizer = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, model.encoder.parameters()), 'lr': args.encoder_lr}, 20 | {'params': filter(lambda p: p.requires_grad, model.decoder.parameters()), 'lr': args.decoder_lr}], 21 | lr=args.decoder_lr, weight_decay=args.weight_decay, 22 | # betas=(args.beta1, args.beta2) 23 | ) 24 | elif args.optim == 'lookahead_radam': 25 | optimizer = Lookahead(RAdam([{'params': 
filter(lambda p: p.requires_grad, model.encoder.parameters()), 'lr': args.encoder_lr}, 26 | {'params': filter(lambda p: p.requires_grad, model.decoder.parameters()), 'lr': args.decoder_lr}], 27 | lr=args.decoder_lr, weight_decay=args.weight_decay), 28 | alpha=0.5, k=5) 29 | else: 30 | raise RuntimeError("Unknown optimizer!") 31 | return optimizer 32 | 33 | 34 | def get_lr_scheduler(args, optimizer): 35 | if args.lr_policy == 'step': 36 | step_size = round(args.max_epochs * args.lr_step_ratio[0]) 37 | print("step size", step_size) 38 | scheduler = StepLR(optimizer, step_size=step_size, gamma=args.lr_gamma) 39 | elif args.lr_policy == 'mstep': 40 | milestones = [round(args.max_epochs * step_ratio) for step_ratio in args.lr_step_ratio] 41 | print("milestones", milestones) 42 | scheduler = MultiStepLR(optimizer, milestones, gamma=args.lr_gamma) 43 | elif args.lr_policy == 'cosine': 44 | t_max = round(args.max_epochs * args.lr_step_ratio[0]) # / 2) 45 | scheduler = CosineAnnealingLR(optimizer, T_max=t_max, eta_min=args.min_lr) 46 | elif args.lr_policy == 'plateau': 47 | scheduler = ReduceLROnPlateau(optimizer, args.lr_plateau_mode, args.lr_gamma, args.lr_plateau_patience, threshold=0.001) 48 | else: 49 | scheduler = None 50 | return scheduler 51 | -------------------------------------------------------------------------------- /r09_create_images_from_allowed_inchi.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | __author__ = 'ZFTurbo: https://kaggle.com/zfturbo' 3 | 4 | import os 5 | import cv2 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | from rdkit import Chem 11 | from rdkit.Chem import Draw 12 | import albumentations as A 13 | 14 | 15 | def sp_noise(image): 16 | # https://gist.github.com/lucaswiman/1e877a164a69f78694f845eab45c381a 17 | output = image.copy() 18 | if len(image.shape) == 2: 19 | black = 0 20 | white = 255 21 | else: 22 | colorspace = image.shape[2] 23 | if colorspace == 3: # RGB 24 | black = np.array([0, 0, 0], dtype='uint8') 25 | white = np.array([255, 255, 255], dtype='uint8') 26 | else: # RGBA 27 | black = np.array([0, 0, 0, 255], dtype='uint8') 28 | white = np.array([255, 255, 255, 255], dtype='uint8') 29 | probs = np.random.random(image.shape[:2]) 30 | image[probs < .0002] = black 31 | image[probs > .9] = white 32 | return image 33 | 34 | 35 | def noisy_inchi(inchi, inchi_path, add_noise=True, crop_and_pad=True): 36 | mol = Chem.MolFromInchi(inchi) 37 | d = Draw.rdMolDraw2D.MolDraw2DCairo(640, 640) 38 | # https://www.kaggle.com/stainsby/improved-synthetic-data-for-bms-competition-v3 39 | d.drawOptions().useBWAtomPalette() 40 | d.drawOptions().bondLineWidth = 1 41 | d.drawOptions().additionalAtomLabelPadding = np.random.uniform(0, .2) 42 | d.DrawMolecule(mol) 43 | d.FinishDrawing() 44 | d.WriteDrawingText(inchi_path) 45 | if crop_and_pad: 46 | img = cv2.imread(inchi_path, cv2.IMREAD_GRAYSCALE) 47 | crop_rows = img[~np.all(img == 255, axis=1), :] 48 | img = crop_rows[:, ~np.all(crop_rows == 255, axis=0)] 49 | img = cv2.copyMakeBorder(img, 30, 30, 30, 30, cv2.BORDER_CONSTANT, value=255) 50 | # img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 51 | else: 52 | img = cv2.imread(inchi_path, cv2.IMREAD_GRAYSCALE) 53 | if add_noise: 54 | img = sp_noise(img) 55 | cv2.imwrite(inchi_path, img) 56 | return img 57 | 58 | 59 | def create_dataset(input_csv, out_dir): 60 | fix_transform = A.Compose([A.Transpose(p=1), A.VerticalFlip(p=1)]) 61 | 62 | s = pd.read_csv(input_csv) 63 | print(s) 64 | 
print(s['image_id'].values) 65 | 66 | # Create folders 67 | if not os.path.isdir(out_dir): 68 | os.mkdir(out_dir) 69 | unique_letters = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'] 70 | for u1 in unique_letters: 71 | if not os.path.isdir(out_dir + u1 + '/'): 72 | os.mkdir(out_dir + u1 + '/') 73 | for u2 in unique_letters: 74 | if not os.path.isdir(out_dir + u1 + '/' + u2 + '/'): 75 | os.mkdir(out_dir + u1 + '/' + u2 + '/') 76 | for u3 in unique_letters: 77 | if not os.path.isdir(out_dir + u1 + '/' + u2 + '/' + u3 + '/'): 78 | os.mkdir(out_dir + u1 + '/' + u2 + '/' + u3 + '/') 79 | 80 | image_ids = s['image_id'].values 81 | inchis = s['InChI'].values 82 | pbar = tqdm(range(len(image_ids))) 83 | full_size = 0 84 | for i in pbar: 85 | image_id = image_ids[i] 86 | inchi = inchis[i] 87 | out_file = out_dir + '{}/{}/{}/{}.png'.format(image_id[0], image_id[1], image_id[2], image_id) 88 | if not os.path.isfile(out_file): 89 | img2 = noisy_inchi(inchi, out_file) 90 | full_size += os.path.getsize(out_file) 91 | # print(img2.shape, out_file) 92 | # show_image(img2) 93 | pbar.set_postfix({'shape': img2.shape, 'size': full_size / (1024 * 1024)}) 94 | if 0: 95 | h, w = img2.shape 96 | if h > w: 97 | img2 = fix_transform(image=img2)['image'] 98 | cv2.imwrite(out_file, img2) 99 | 100 | 101 | def add_image_ids(s, out_csv): 102 | index = s.index.values 103 | image_ids = [] 104 | for ind in index: 105 | m = ind + 0x1000000000000 106 | val = hex(m).split('x')[-1][::-1] 107 | print(ind, val) 108 | image_ids.append(val) 109 | s['image_id'] = image_ids 110 | s.to_csv(out_csv, index=False) 111 | 112 | 113 | if __name__ == '__main__': 114 | input_csv = '/data/bms/extra_approved_InChIs.csv' 115 | input_csv_fixed = '/data/bms/extra_approved_InChIs_with_ids.csv' 116 | s = pd.read_csv(input_csv) 117 | 118 | 119 | def compute_length(col): 120 | def _compute_length(row): 121 | return len(row[col]) 122 | 123 | return _compute_length 124 | 125 | 126 | s['length'] = s.apply(compute_length('InChI'), axis=1) 127 | s = s.sort_values(by=['length'], ascending=False) 128 | 129 | # only longest 1m images 130 | s = s[:1000000] 131 | 132 | if not os.path.isfile(input_csv_fixed): 133 | add_image_ids(s, input_csv_fixed) 134 | output_dir = '/data/bms/extra_images/' 135 | create_dataset(input_csv_fixed, output_dir) 136 | # test_random_molecule_image(n=20, graphics=True) 137 | -------------------------------------------------------------------------------- /swa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Stochastic Weight Averaging (SWA) 3 | 4 | Averaging Weights Leads to Wider Optima and Better Generalization 5 | 6 | https://github.com/timgaripov/swa 7 | """ 8 | from pathlib import Path 9 | import warnings 10 | import argparse 11 | 12 | import torch 13 | from tqdm import tqdm 14 | 15 | from datasets.bms import BMSTrainDataset 16 | from datasets.transforms import get_test_transform, Compose 17 | from eval import bms_collate 18 | 19 | 20 | def moving_average(net1, net2, alpha=1.): 21 | for param1, param2 in zip(net1.parameters(), net2.parameters()): 22 | param1.data *= (1.0 - alpha) 23 | param1.data += param2.data * alpha 24 | 25 | 26 | def _check_bn(module, flag): 27 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 28 | flag[0] = True 29 | 30 | 31 | def check_bn(model): 32 | flag = [False] 33 | model.apply(lambda module: _check_bn(module, flag)) 34 | return flag[0] 35 | 36 | 37 | def reset_bn(module): 38 | if 
issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 39 | module.running_mean = torch.zeros_like(module.running_mean) 40 | module.running_var = torch.ones_like(module.running_var) 41 | 42 | 43 | def _get_momenta(module, momenta): 44 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 45 | momenta[module] = module.momentum 46 | 47 | 48 | def _set_momenta(module, momenta): 49 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 50 | module.momentum = momenta[module] 51 | 52 | 53 | def bn_update(loader, model, device): 54 | """ 55 | BatchNorm buffers update (if any). 56 | Performs 1 epochs to estimate buffers average using train dataset. 57 | :param dataset: train dataset for buffers average estimation. 58 | :param model: model being update 59 | :param jobs: jobs for dataloader 60 | :return: None 61 | """ 62 | if not check_bn(model): 63 | print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>') 64 | print('no bn in model?!') 65 | print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!') 66 | # return model 67 | 68 | model.train() 69 | momenta = {} 70 | model.apply(reset_bn) 71 | model.apply(lambda module: _get_momenta(module, momenta)) 72 | n = 0 73 | 74 | model = model.to(device) 75 | pbar = tqdm(loader, unit="samples", unit_scale=loader.batch_size) 76 | for sample in pbar: 77 | inputs, targets, target_lengths = sample['input'].to(device), sample['label'].to(device), sample['label_length'].to(device) 78 | 79 | inputs = inputs.to(device) 80 | b = inputs.size(0) 81 | 82 | momentum = b / (n + b) 83 | for module in momenta.keys(): 84 | module.momentum = momentum 85 | 86 | # model(inputs) 87 | # TODO: 88 | model(inputs, False, targets, target_lengths, 275, test_dataset.tokenizer) 89 | n += b 90 | 91 | model.apply(lambda module: _set_momenta(module, momenta)) 92 | return model 93 | 94 | 95 | if __name__ == '__main__': 96 | import argparse 97 | from pathlib import Path 98 | # project imports 99 | from models import get_model 100 | from misc.utils import load_checkpoint 101 | 102 | import torch 103 | from torch.utils.data import DataLoader 104 | 105 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 106 | parser.add_argument("--input", type=str, help='input directory') 107 | parser.add_argument("--output", type=str, default='swa_model.pth', help='output model file') 108 | 109 | parser.add_argument("--dataload-workers-nums", type=int, default=4, help='number of workers for dataloader') 110 | 111 | parser.add_argument("--image-size", default=224, type=int, help="image size") 112 | parser.add_argument("--batch-size", type=int, default=64, help='valid batch size') 113 | 114 | parser.add_argument('--bn-update', action='store_true', help='update batch norm') 115 | args = parser.parse_args() 116 | 117 | test_transform = get_test_transform(args) 118 | test_dataset = BMSTrainDataset(fold=0, mode='train', transform=test_transform, dataset_size=0) 119 | test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, 120 | collate_fn=bms_collate, 121 | num_workers=args.dataload_workers_nums, 122 | sampler=None, pin_memory=True, drop_last=True) 123 | 124 | if ',' in args.input: 125 | files = [] 126 | for f in args.input.split(','): 127 | assert f.endswith('.pth') 128 | assert Path(f).exists() 129 | files.append(f) 130 | else: 131 | directory = Path(args.input) 132 | files = [f for f in directory.iterdir() if f.suffix == ".pth"] 133 | assert (len(files) > 1) 134 | 135 | model_args = 
torch.load(files[0], map_location='cpu').get('args', None) 136 | 137 | 138 | def load_model(f): 139 | model = get_model(model_args) 140 | load_checkpoint(f, model, optimizer=None, use_gpu=True, remove_module_keys=True) 141 | model.float() 142 | return model.cuda() 143 | 144 | 145 | def save_model(model, f): 146 | torch.save({ 147 | 'epoch': -1, 148 | 'global_step': -1, 149 | 'state_dict': model.state_dict(), 150 | 'optimizer': {}, 151 | 'args': model_args 152 | }, f) 153 | 154 | 155 | net = load_model(files[0]) 156 | for i, f in enumerate(files[1:]): 157 | net2 = load_model(f) 158 | moving_average(net, net2, 1. / (i + 2)) 159 | 160 | if args.bn_update: 161 | with torch.no_grad(): 162 | net = bn_update(test_data_loader, net, torch.device('cuda')) 163 | 164 | save_model(net, args.output) 165 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Train for the Kaggle Bristol-Myers Squibb – Molecular Translation challenge: https://www.kaggle.com/c/bms-molecular-translation""" 3 | __author__ = 'Erdene-Ochir Tuguldur' 4 | 5 | import cv2 6 | import os 7 | import json 8 | import time 9 | import torch 10 | import numpy as np 11 | import random 12 | import argparse 13 | 14 | import warnings 15 | 16 | from torch.nn.utils.rnn import pad_sequence 17 | import torch.distributed as dist 18 | 19 | from datasets.bms import BMSTrainDataset, BMSPseudoDataset, BMSExtraDataset 20 | from misc.metrics import compute_metric 21 | 22 | warnings.filterwarnings("ignore") 23 | 24 | from tqdm import * 25 | 26 | # seed 27 | seed = 1234 28 | np.random.seed(seed) 29 | random.seed(seed) 30 | os.environ['PYTHONHASHSEED'] = str(seed) 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | torch.backends.cudnn.deterministic = True 35 | # torch.backends.cudnn.benchmark = False 36 | 37 | 38 | import apex 39 | from apex.parallel import DistributedDataParallel 40 | from apex import amp 41 | import albumentations as album 42 | 43 | from torch.utils.data import DataLoader, Subset, ConcatDataset 44 | 45 | from tensorboardX import SummaryWriter 46 | 47 | # project imports 48 | from datasets import * 49 | from losses import * 50 | from models import get_model 51 | from optims import get_optimizer, get_lr_scheduler 52 | from misc.utils import save_best_checkpoint, save_latest_checkpoint 53 | 54 | from datasets.transforms import Compose, ApplyAlbumentations, get_test_transform, CropAugment, RandomNoiseAugment 55 | from misc.utils import load_checkpoint 56 | 57 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 58 | parser.add_argument("--comment", type=str, default='', help='comment in tensorboard title') 59 | parser.add_argument("--logdir", type=str, default=None, help='log dir for tensorboard logs and checkpoints') 60 | parser.add_argument("--train-batch-size", type=int, default=64, help='train batch size') 61 | parser.add_argument("--valid-batch-size", type=int, default=64, help='train batch size') 62 | parser.add_argument("--dataload-workers-nums", type=int, default=4, help='number of workers for dataloader') 63 | parser.add_argument("--weight-decay", type=float, default=0, help='weight decay') 64 | parser.add_argument("--optim", default='adam', help='choices of optimization algorithms') 65 | parser.add_argument("--clip-grad-norm", type=float, default=0, help='clip gradient 
norm value') 66 | parser.add_argument("--encoder-lr", type=float, default=1e-3, help='encoder learning rate for optimization') 67 | parser.add_argument("--decoder-lr", type=float, default=1e-3, help='decoder learning rate for optimization') 68 | parser.add_argument("--min-lr", type=float, default=1e-12, help='minimal learning rate for optimization') 69 | parser.add_argument("--warm-start", type=str, help='warm start from a checkpoint') 70 | parser.add_argument("--lr-warmup-steps", type=int, default=1000, help='learning rate warmup steps') 71 | parser.add_argument("--lr-gamma", type=float, default=0.1, help='learning rate gamma for step scheduler') 72 | parser.add_argument("--lr-step-ratio", type=lambda s: [float(item) for item in s.split(',')], default=[0.4], 73 | help='learning rate step ratio for step scheduler') 74 | parser.add_argument("--lr-policy", choices=['cosine', 'mcosine', 'step', 'mstep', 'none', 'plateau'], default='none', 75 | help='learning rate scheduling policy') 76 | parser.add_argument("--lr-plateau_mode", choices=['min', 'max'], default='min', help='reduce on plateau mode') 77 | parser.add_argument("--lr-plateau_patience", type=int, default=10, help='reduce on plateau patience') 78 | parser.add_argument('--mixed-precision', action='store_true', help='enable mixed precision training') 79 | parser.add_argument('--amp-level', type=str, default='O2', help='amp level') 80 | parser.add_argument('--sync-bn', action='store_true', help='enable apex sync batch norm.') 81 | parser.add_argument('--cudnn-benchmark', action='store_true', help='enable CUDNN benchmark') 82 | parser.add_argument("--max-epochs", default=20, type=int, help="train epochs") 83 | parser.add_argument("--local_rank", default=0, type=int) 84 | 85 | parser.add_argument("--encoder-freeze", default=0, type=int, help="freeze first n encoder layers of encoder") 86 | parser.add_argument("--decoder-freeze", default=0, type=int, help="freeze first n encoder layers of decoder") 87 | parser.add_argument('--remove-decoder', action='store_true', help='remove decoder weights for warm start') 88 | parser.add_argument('--remove-encoder', action='store_true', help='remove encoder weights for warm start') 89 | parser.add_argument("--non-strict", action='store_true', help="non strict loading") 90 | 91 | parser.add_argument('--debug', action='store_true', help='visual debug') 92 | parser.add_argument('--cache', action='store_true', help='cache training data') 93 | parser.add_argument('--verbose', action='store_true', help='cache training data') 94 | parser.add_argument('--pipeline', action='store_true', help='make progress bar less verbose') 95 | 96 | parser.add_argument("--valid-epochs", default=1, type=int, help='validation at every valid-epochs') 97 | parser.add_argument('--reset-state', action='store_true', help='reset global steps and epochs to 0') 98 | 99 | parser.add_argument("--model", default='swin_base_patch4_window12_384', help='choices of neural network') 100 | parser.add_argument("--loss", default='ce', choices=LOSSES, help='choices of loss') 101 | parser.add_argument("--fold", default=0, type=int, help="data fold") 102 | parser.add_argument("--image-size", default=224, type=int, help="image size") 103 | parser.add_argument("--train-dataset-size", default=0, type=int, help="subsample train set") 104 | parser.add_argument("--valid-dataset-size", default=0, type=int, help="subsample validation set") 105 | parser.add_argument("--max-token", default=275, type=int, help="max token") 106 | 
parser.add_argument('--valid-dataset-non-sorted', action='store_true', help='reset global steps and epochs to 0') 107 | 108 | parser.add_argument("--pseudo", default=None, type=str, help="pseudo dataset") 109 | parser.add_argument("--pseudo-dataset-size", default=0, type=int, help="subsample pseudo dataset") 110 | 111 | parser.add_argument("--extra", action='store_true', help="use extra images dataset") 112 | parser.add_argument("--extra-dataset-size", default=0, type=int, help="subsample extra images dataset") 113 | 114 | parser.add_argument("--embed-dim", default=384, type=int, help="embedding dim") 115 | parser.add_argument("--vocab-size", default=193, type=int, help="vocab size") 116 | 117 | parser.add_argument("--label-smoothing", default=0.0, type=float, help="label smoothing alpha") 118 | 119 | parser.add_argument("--aug-rotate90-p", default=0.5, type=float, help="rotate probability") 120 | parser.add_argument("--aug-crop-p", default=0.0, type=float, help="border crop probability") 121 | parser.add_argument("--aug-noise-p", default=0.0, type=float, help="noise probability") 122 | 123 | parser.add_argument("--num-head", default=8, type=int, help="decoder num head") 124 | parser.add_argument("--num-layer", default=3, type=int, help="decoder num layer") 125 | parser.add_argument("--ff-dim", default=1024, type=int, help="decoder ff dim") 126 | 127 | args = parser.parse_args() 128 | 129 | args.distributed = False 130 | args.world_size = 1 131 | if 'WORLD_SIZE' in os.environ: 132 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 133 | args.world_size = int(os.environ['WORLD_SIZE']) 134 | if args.distributed: 135 | torch.cuda.set_device(args.local_rank) 136 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 137 | torch.backends.cudnn.benchmark = args.cudnn_benchmark 138 | 139 | train_transform = Compose([ 140 | CropAugment(probability=args.aug_crop_p, crops=list(range(5, 20))), 141 | ApplyAlbumentations(album.Compose([ 142 | # album.Resize(args.image_size, args.image_size), 143 | # album.HorizontalFlip(p=0.5), 144 | # album.VerticalFlip(p=0.5), 145 | album.RandomRotate90(p=args.aug_rotate90_p), 146 | 147 | # album.RandomScale(scale_limit=0.1, p=1), 148 | # album.Rotate(limit=15, border_mode=cv2.BORDER_CONSTANT, p=1), 149 | # album.PadIfNeeded(args.image_size, args.image_size, border_mode=cv2.BORDER_CONSTANT, p=1), 150 | # album.RandomCrop(args.image_size, args.image_size, p=1), 151 | # album.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.7), 152 | 153 | # album.CoarseDropout(p=0.3), 154 | # album.GaussNoise(p=0.3), 155 | # album.IAASharpen(p=0.1) 156 | ])), 157 | RandomNoiseAugment(probability=args.aug_noise_p, frac=0.001), 158 | get_test_transform(args) 159 | ]) 160 | valid_transform = get_test_transform(args) 161 | 162 | train_dataset = BMSTrainDataset(fold=args.fold, mode='train', transform=train_transform, dataset_size=args.train_dataset_size) 163 | valid_dataset = BMSTrainDataset(fold=args.fold, mode='valid', transform=valid_transform, dataset_size=args.valid_dataset_size, 164 | sort_valid=not args.valid_dataset_non_sorted) 165 | tokenizer = train_dataset.tokenizer 166 | 167 | pseudo_and_extra_datasets = [] 168 | if args.pseudo: 169 | pseudo_dataset = BMSPseudoDataset(pseudo_file=args.pseudo, transform=train_transform, dataset_size=args.pseudo_dataset_size) 170 | pseudo_and_extra_datasets.append(pseudo_dataset) 171 | 172 | if args.extra: 173 | extra_dataset = 
BMSExtraDataset(transform=train_transform, dataset_size=args.extra_dataset_size) 174 | pseudo_and_extra_datasets.append(extra_dataset) 175 | 176 | if len(pseudo_and_extra_datasets) > 0: 177 | train_dataset = ConcatDataset([train_dataset] + pseudo_and_extra_datasets) 178 | 179 | 180 | def bms_collate(batch): 181 | inputs, labels, label_lengths, inchis = [], [], [], [] 182 | for b in batch: 183 | inputs.append(b['input']) 184 | labels.append(torch.LongTensor(b['label']).reshape(-1, 1)) 185 | label_lengths.append(torch.LongTensor([b['label_length']])) 186 | inchis.append(b['inchi']) 187 | labels = pad_sequence(labels, batch_first=True, padding_value=tokenizer.stoi["<pad>"]) 188 | return { 189 | 'input': torch.stack(inputs), 190 | 'label': labels.squeeze(dim=-1), 191 | 'label_length': torch.stack(label_lengths).reshape(-1, 1), 192 | 'inchi': inchis 193 | } 194 | 195 | 196 | train_data_sampler, valid_data_sampler = None, None 197 | if args.distributed: 198 | train_data_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 199 | valid_data_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset, shuffle=False) 200 | train_data_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=(train_data_sampler is None), 201 | collate_fn=bms_collate, 202 | num_workers=args.dataload_workers_nums, 203 | sampler=train_data_sampler, pin_memory=True, drop_last=True) 204 | valid_data_loader = DataLoader(valid_dataset, batch_size=args.valid_batch_size, shuffle=False, 205 | collate_fn=bms_collate, 206 | num_workers=args.dataload_workers_nums, 207 | sampler=valid_data_sampler, pin_memory=True) 208 | 209 | model = get_model(args) 210 | if args.verbose and args.local_rank == 0: 211 | print(model) 212 | 213 | if args.warm_start: 214 | load_checkpoint(args.warm_start, model, optimizer=None, use_gpu=False, 215 | remove_encoder=args.remove_encoder, remove_decoder=args.remove_decoder, non_strict=args.non_strict) 216 | 217 | if args.sync_bn: 218 | model = apex.parallel.convert_syncbn_model(model) 219 | model = model.cuda() 220 | 221 | criterion = get_loss(args, tokenizer) 222 | 223 | if args.encoder_freeze != 0: 224 | idx = 0 225 | for idx, parameter in enumerate(model.encoder.parameters()): # enumerate(model.encoder[:args.encoder_freeze].parameters()): 226 | parameter.requires_grad = False 227 | if args.local_rank == 0: 228 | print("encoder frozen!") 229 | # print("freezing %i n layers of total %i encoder layers" % (idx + 1, len(model.encoder))) 230 | if args.decoder_freeze != 0: 231 | idx = 0 232 | for idx, parameter in enumerate(model.decoder.parameters()): # enumerate(model.decoder[:args.decoder_freeze].parameters()): 233 | parameter.requires_grad = False 234 | if args.local_rank == 0: 235 | print("decoder frozen!") 236 | 237 | optimizer = get_optimizer(args, model) 238 | 239 | total_steps = int(len(train_dataset) * args.max_epochs / (args.world_size * args.train_batch_size)) 240 | if args.local_rank == 0: 241 | print("total steps:", total_steps, " epoch steps:", int(total_steps / args.max_epochs)) 242 | 243 | lr_scheduler = get_lr_scheduler(args, optimizer) 244 | 245 | if args.mixed_precision: 246 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_level) # , min_loss_scale=1) 247 | if args.distributed: 248 | model = DistributedDataParallel(model) 249 | 250 | start_timestamp = int(time.time() * 1000) 251 | start_epoch = 1 252 | global_step = 0 253 | best_metric = 999 254 | best_loss = 1e9 255 | 256 | if args.logdir is None: 257 | logname = 
"%s_%s_wd%.0e" % (args.model, args.optim, args.weight_decay) 258 | if args.comment: 259 | logname = "%s_%s" % (logname, args.comment.replace(' ', '_')) 260 | logdir = os.path.join('logdir', logname) 261 | else: 262 | logdir = args.logdir 263 | 264 | writer = SummaryWriter(log_dir=logdir) 265 | if args.local_rank == 0: 266 | print(vars(args)) 267 | writer.add_text("hparams", json.dumps(vars(args), indent=4)) 268 | 269 | 270 | def get_lr(): 271 | return optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr'] 272 | 273 | 274 | def lr_warmup(step): 275 | if step < args.lr_warmup_steps: 276 | assert len(optimizer.param_groups) == 2 277 | 278 | encoder_lr = args.encoder_lr * (step + 1) / (args.lr_warmup_steps + 1) 279 | optimizer.param_groups[0]['lr'] = encoder_lr 280 | decoder_lr = args.decoder_lr * (step + 1) / (args.lr_warmup_steps + 1) 281 | optimizer.param_groups[1]['lr'] = decoder_lr 282 | elif step == args.lr_warmup_steps: 283 | optimizer.param_groups[0]['lr'] = args.encoder_lr 284 | optimizer.param_groups[1]['lr'] = args.decoder_lr 285 | 286 | 287 | def train(epoch, phase='train'): 288 | global global_step, best_metric 289 | 290 | lr_warmup(global_step) 291 | if args.local_rank == 0: 292 | lrs = get_lr() 293 | print("epoch %3d with encoder_lr=%.02e decoder_lr=%.02e" % (epoch, lrs[0], lrs[1])) 294 | 295 | if args.distributed: 296 | train_data_sampler.set_epoch(epoch) 297 | 298 | model.train() if phase == 'train' else model.eval() 299 | torch.set_grad_enabled(True) if phase == 'train' else torch.set_grad_enabled(False) 300 | data_loader = train_data_loader if phase == 'train' else valid_data_loader 301 | 302 | it = 0 303 | running_loss = 0.0 304 | all_inchis = [] 305 | all_predictions = [] 306 | 307 | pbar = None 308 | if args.local_rank == 0: 309 | batch_size = args.train_batch_size if phase == 'train' else args.valid_batch_size 310 | pbar = tqdm(data_loader, unit="images", unit_scale=batch_size, 311 | disable=False, mininterval=30.0 if args.pipeline else 0.1) 312 | 313 | for batch in data_loader if pbar is None else pbar: 314 | inputs, targets, target_lengths = batch['input'].cuda(), batch['label'].cuda(), batch['label_length'].cuda() 315 | 316 | if args.debug: 317 | print(batch['inchi']) 318 | print(inputs.size(), targets.size(), target_lengths.size()) 319 | img = inputs[0, 0, :, :].detach().cpu().numpy() 320 | import matplotlib.pyplot as plt 321 | plt.imshow(img) 322 | plt.show() 323 | 324 | _criterion = criterion 325 | 326 | loss = None 327 | if phase == 'train': 328 | targets, outputs = model(inputs, False, targets, target_lengths, args.max_token, tokenizer) 329 | loss = _criterion(outputs, targets).mean() 330 | else: 331 | predictions = model(inputs, True, targets, target_lengths, args.max_token, tokenizer) 332 | all_predictions += predictions 333 | all_inchis += batch['inchi'] 334 | 335 | if phase == 'train': 336 | lr_warmup(global_step) 337 | optimizer.zero_grad() 338 | 339 | if args.mixed_precision: 340 | with amp.scale_loss(loss, optimizer) as scaled_loss: 341 | scaled_loss.backward() 342 | else: 343 | loss.backward() 344 | if args.clip_grad_norm > 0: 345 | # clip gradient 346 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip_grad_norm) 347 | 348 | optimizer.step() 349 | 350 | # global step size is increased only in the train phase 351 | global_step += 1 352 | it += 1 353 | 354 | if loss is not None: 355 | loss = loss.item() 356 | running_loss += loss 357 | 358 | if args.local_rank == 0: 359 | if global_step % 5 == 1: 360 | if phase == 'train': 
361 | writer.add_scalar('%s/loss' % phase, loss, global_step) 362 | lrs = get_lr() 363 | writer.add_scalar('%s/encoder_lr' % phase, lrs[0], global_step) 364 | writer.add_scalar('%s/decoder_lr' % phase, lrs[1], global_step) 365 | 366 | # update the progress bar 367 | pbar.set_postfix({ 368 | 'loss': "%.05f" % (running_loss / it) 369 | }) 370 | 371 | if not args.pipeline: 372 | # update the progress bar 373 | pbar.set_postfix({ 374 | 'loss': "%.05f" % (running_loss / it) 375 | }) 376 | 377 | epoch_loss = running_loss / it 378 | metric = 999 379 | if phase == 'valid': 380 | metric = compute_metric(all_inchis, all_predictions) 381 | 382 | if args.distributed: 383 | metric_tensor = torch.tensor(metric).cuda() 384 | rt = metric_tensor.clone() 385 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 386 | rt = rt / args.world_size 387 | metric = rt.item() 388 | 389 | if args.local_rank == 0: 390 | writer.add_scalar('%s/epoch_loss' % phase, epoch_loss, epoch) 391 | 392 | if phase == 'valid': 393 | writer.add_scalar('%s/metric' % phase, metric, epoch) 394 | print(all_inchis[0]) 395 | print(all_predictions[0]) 396 | print("Metric: %f" % metric) 397 | 398 | writer.add_text('%s/prediction' % phase, 399 | 'truth: %s\npredicted: %s' % (all_inchis[0], all_predictions[0]), 400 | global_step if phase == 'train' else global_step + it) 401 | 402 | save_latest_checkpoint(logdir, epoch, global_step, model, optimizer, args) 403 | if metric <= best_metric: 404 | best_metric = metric 405 | print("\nBest metric: %.2f" % best_metric) 406 | save_best_checkpoint(logdir, epoch, global_step, model, optimizer, args, metric, 'metric') 407 | 408 | writer.flush() 409 | 410 | return epoch_loss 411 | 412 | 413 | since = time.time() 414 | epoch = start_epoch 415 | valid_epoch_loss = 1e6 416 | while True: 417 | train_epoch_loss = train(epoch, phase='train') 418 | 419 | if args.local_rank == 0: 420 | time_elapsed = time.time() - since 421 | time_str = 'total time elapsed: {:.0f}h {:.0f}m {:.0f}s '.format(time_elapsed // 3600, 422 | time_elapsed % 3600 // 60, 423 | time_elapsed % 60) 424 | print("train epoch loss %f, step=%d, %s" % (train_epoch_loss, global_step, time_str)) 425 | if epoch == args.max_epochs or epoch % args.valid_epochs == 0: 426 | valid_epoch_loss = train(epoch, phase='valid') 427 | if args.local_rank == 0: 428 | # print("valid epoch loss %f\n############\n" % valid_epoch_loss) 429 | print() 430 | if lr_scheduler is not None: 431 | if args.lr_policy == 'plateau': 432 | print("plateau") 433 | lr_scheduler.step(valid_epoch_loss) 434 | else: 435 | lr_scheduler.step() 436 | 437 | epoch += 1 438 | 439 | if epoch > args.max_epochs: 440 | break 441 | 442 | writer.close() 443 | --------------------------------------------------------------------------------
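The snippet below is a small, self-contained sketch (not part of the repo; the sequence length `T` is made up) of the upper-triangular mask that `VitDecoder.forward()` in `models/vit.py` builds with `np.triu(..., k=1)` before calling the fairseq decoder: entries above the diagonal flag future positions, which keeps the decoder causal during teacher-forced training.

```
import numpy as np
import torch

T = 4  # hypothetical max token length for this demo
mask = np.triu(np.ones((T, T)), k=1).astype(np.uint8)
print(mask)
# [[0 1 1 1]
#  [0 0 1 1]
#  [0 0 0 1]
#  [0 0 0 0]]

# As in VitDecoder.forward(), the mask is turned into a boolean tensor where
# True marks a (query, key) pair that should not attend, i.e. a future token.
text_mask = torch.from_numpy(mask) == 1
print(text_mask)
```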
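Likewise, here is a minimal sketch (hypothetical tensors, not repo code) of why `swa.py` averages checkpoints with `moving_average(net, net2, 1. / (i + 2))` inside its loop over checkpoint files: that recurrence is algebraically the plain arithmetic mean of all checkpoint weights, which is what stochastic weight averaging computes.

```
import torch

# Stand-ins for the parameters of several saved checkpoints.
checkpoints = [torch.randn(3, 3) for _ in range(4)]

avg = checkpoints[0].clone()
for i, w in enumerate(checkpoints[1:]):
    alpha = 1.0 / (i + 2)
    # Same update as moving_average(): param = (1 - alpha) * param + alpha * new_param
    avg.mul_(1.0 - alpha).add_(w, alpha=alpha)

# The incremental average equals the plain mean of all checkpoints.
assert torch.allclose(avg, torch.stack(checkpoints).mean(dim=0), atol=1e-6)
```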