├── README.md
├── help.py
├── models
│   ├── models.py
│   └── r2gen.py
├── modules
│   ├── __init__.py
│   ├── att_model.py
│   ├── base_cmn.py
│   ├── base_cmn.py.bak
│   ├── caption_model.py
│   ├── dataloaders.py
│   ├── datasets.py
│   ├── encoder_decoder.py
│   ├── loss.py
│   ├── metrics.py
│   ├── optimizers.py
│   ├── rewards.py
│   ├── tokenizers.py
│   ├── trainer.py
│   ├── trainer_base.py
│   ├── trainer_rl.py
│   ├── utils.py
│   └── visual_extractor.py
├── plot.sh
├── pycocoevalcap
│   ├── README.md
│   ├── __init__.py
│   ├── bleu
│   │   ├── LICENSE
│   │   ├── __init__.py
│   │   ├── bleu.py
│   │   └── bleu_scorer.py
│   ├── cider
│   │   ├── __init__.py
│   │   ├── cider.py
│   │   └── cider_scorer.py
│   ├── eval.py
│   ├── license.txt
│   ├── meteor
│   │   ├── __init__.py
│   │   ├── meteor-1.5.jar
│   │   └── meteor.py
│   ├── rouge
│   │   ├── __init__.py
│   │   └── rouge.py
│   └── tokenizer
│       ├── __init__.py
│       ├── ptbtokenizer.py
│       └── stanford-corenlp-3.4.1.jar
├── run.slurm
├── run_base.sh
├── scripts
│   ├── iu_xray
│   │   └── run_rl.sh
│   └── mimic_cxr
│       └── run_rl.sh
├── train.py
├── train_base.py
├── train_rl.py
└── train_rl_base.py
/README.md:
--------------------------------------------------------------------------------
1 | # R2GenRL
2 | The implementation of our ACL-2022 paper [Reinforced Cross-modal Alignment for Radiology Report Generation](https://aclanthology.org/2022.findings-acl.38/)
3 | 
4 | ## Citation
5 | 
6 | ```
7 | @inproceedings{qin-song-2022-reinforced,
8 |     title = "Reinforced Cross-modal Alignment for Radiology Report Generation",
9 |     author = "Qin, Han and Song, Yan",
10 |     booktitle = "Findings of the Association for Computational Linguistics: ACL 2022",
11 |     month = may,
12 |     year = "2022",
13 |     address = "Dublin, Ireland",
14 |     pages = "448--458",
15 | }
16 | ```
17 | 
18 | ## Requirements
19 | Our code works with the following environment.
20 | - `torch==1.5.1`
21 | - `torchvision==0.6.1`
22 | - `opencv-python==4.4.0.42`
23 | 
24 | Clone the evaluation tools from [this repository](https://github.com/salaniz/pycocoevalcap).
25 | 
26 | ## Datasets
27 | We use two datasets (`IU X-Ray` and `MIMIC-CXR`) in our paper.
28 | 
29 | For `IU X-Ray`, you can download the dataset from [here](https://openi.nlm.nih.gov/) and then put the files in `data/iu_xray`.
30 | 
31 | For `MIMIC-CXR`, you can download the dataset from [here](https://physionet.org/content/mimic-cxr/2.0.0/) and then put the files in `data/mimic_cxr`.
32 | 
33 | 
34 | ## Running
35 | For `IU X-Ray`,
36 | * `bash scripts/iu_xray/run.sh` to train the `Base+cmn` model on `IU X-Ray`.
37 | * `bash scripts/iu_xray/run_rl.sh` to train the `Base+cmn+rl` model on `IU X-Ray`.
38 | 
39 | For `MIMIC-CXR`,
40 | * `bash scripts/mimic_cxr/run.sh` to train the `Base+cmn` model on `MIMIC-CXR`.
41 | * `bash scripts/mimic_cxr/run_rl.sh` to train the `Base+cmn+rl` model on `MIMIC-CXR`.
42 | 
43 | ## Attention Plots
44 | 
45 | Change the ```path``` variable (line 183) in ```help.py``` to the image that you wish to plot and then run the script ```plot.sh```.
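For example, `help.py` currently points at a MIMIC-CXR study; to plot a different image, edit that line to any image under your `data/` directory (the path below is only a placeholder, not a file shipped with this repository):

```
# help.py, around line 183
path = 'data/mimic_cxr/images/<subject>/<study>/<dicom_id>.jpg'
```

Running ```plot.sh``` then generates a report for that image and saves one attention heatmap per generated word (plus a `colorbar.png`) into the `test/` output directory.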
46 | 47 | -------------------------------------------------------------------------------- /help.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import matplotlib as mpl 6 | from matplotlib import cm 7 | import matplotlib.pyplot as plt 8 | from PIL import Image 9 | from torchvision import transforms 10 | from models.r2gen import R2GenModel 11 | from modules.tokenizers import Tokenizer 12 | 13 | parula_data = [[0.2081, 0.1663, 0.5292], 14 | [0.2116238095, 0.1897809524, 0.5776761905], 15 | [0.212252381, 0.2137714286, 0.6269714286], 16 | [0.2081, 0.2386, 0.6770857143], 17 | [0.1959047619, 0.2644571429, 0.7279], 18 | [0.1707285714, 0.2919380952, 0.779247619], 19 | [0.1252714286, 0.3242428571, 0.8302714286], 20 | [0.0591333333, 0.3598333333, 0.8683333333], 21 | [0.0116952381, 0.3875095238, 0.8819571429], 22 | [0.0059571429, 0.4086142857, 0.8828428571], 23 | [0.0165142857, 0.4266, 0.8786333333], 24 | [0.032852381, 0.4430428571, 0.8719571429], 25 | [0.0498142857, 0.4585714286, 0.8640571429], 26 | [0.0629333333, 0.4736904762, 0.8554380952], 27 | [0.0722666667, 0.4886666667, 0.8467], 28 | [0.0779428571, 0.5039857143, 0.8383714286], 29 | [0.079347619, 0.5200238095, 0.8311809524], 30 | [0.0749428571, 0.5375428571, 0.8262714286], 31 | [0.0640571429, 0.5569857143, 0.8239571429], 32 | [0.0487714286, 0.5772238095, 0.8228285714], 33 | [0.0343428571, 0.5965809524, 0.819852381], 34 | [0.0265, 0.6137, 0.8135], 35 | [0.0238904762, 0.6286619048, 0.8037619048], 36 | [0.0230904762, 0.6417857143, 0.7912666667], 37 | [0.0227714286, 0.6534857143, 0.7767571429], 38 | [0.0266619048, 0.6641952381, 0.7607190476], 39 | [0.0383714286, 0.6742714286, 0.743552381], 40 | [0.0589714286, 0.6837571429, 0.7253857143], 41 | [0.0843, 0.6928333333, 0.7061666667], 42 | [0.1132952381, 0.7015, 0.6858571429], 43 | [0.1452714286, 0.7097571429, 0.6646285714], 44 | [0.1801333333, 0.7176571429, 0.6424333333], 45 | [0.2178285714, 0.7250428571, 0.6192619048], 46 | [0.2586428571, 0.7317142857, 0.5954285714], 47 | [0.3021714286, 0.7376047619, 0.5711857143], 48 | [0.3481666667, 0.7424333333, 0.5472666667], 49 | [0.3952571429, 0.7459, 0.5244428571], 50 | [0.4420095238, 0.7480809524, 0.5033142857], 51 | [0.4871238095, 0.7490619048, 0.4839761905], 52 | [0.5300285714, 0.7491142857, 0.4661142857], 53 | [0.5708571429, 0.7485190476, 0.4493904762], 54 | [0.609852381, 0.7473142857, 0.4336857143], 55 | [0.6473, 0.7456, 0.4188], 56 | [0.6834190476, 0.7434761905, 0.4044333333], 57 | [0.7184095238, 0.7411333333, 0.3904761905], 58 | [0.7524857143, 0.7384, 0.3768142857], 59 | [0.7858428571, 0.7355666667, 0.3632714286], 60 | [0.8185047619, 0.7327333333, 0.3497904762], 61 | [0.8506571429, 0.7299, 0.3360285714], 62 | [0.8824333333, 0.7274333333, 0.3217], 63 | [0.9139333333, 0.7257857143, 0.3062761905], 64 | [0.9449571429, 0.7261142857, 0.2886428571], 65 | [0.9738952381, 0.7313952381, 0.266647619], 66 | [0.9937714286, 0.7454571429, 0.240347619], 67 | [0.9990428571, 0.7653142857, 0.2164142857], 68 | [0.9955333333, 0.7860571429, 0.196652381], 69 | [0.988, 0.8066, 0.1793666667], 70 | [0.9788571429, 0.8271428571, 0.1633142857], 71 | [0.9697, 0.8481380952, 0.147452381], 72 | [0.9625857143, 0.8705142857, 0.1309], 73 | [0.9588714286, 0.8949, 0.1132428571], 74 | [0.9598238095, 0.9218333333, 0.0948380952], 75 | [0.9661, 0.9514428571, 0.0755333333], 76 | [0.9763, 0.9831, 0.0538]] 77 | 78 | def parse_agrs(): 79 | parser = argparse.ArgumentParser() 80 | 81 | # 
Data input settings
82 |     parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the images.')
83 |     parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the annotation file.')
84 | 
85 |     # Data loader settings
86 |     parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr'], help='the dataset to be used.')
87 |     parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
88 |     parser.add_argument('--threshold', type=int, default=3, help='the cut-off frequency for the words.')
89 |     parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
90 |     parser.add_argument('--batch_size', type=int, default=16, help='the number of samples in a batch.')
91 | 
92 |     # Model settings (for visual extractor)
93 |     parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
94 |     parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor.')
95 |     parser.add_argument('--num_labels', type=int, default=14, help='the size of the label set.')
96 | 
97 |     # Model settings (for Transformer)
98 |     parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
99 |     parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
100 |     parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
101 |     parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
102 |     parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
103 |     parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
104 |     parser.add_argument('--logit_layers', type=int, default=1, help='the number of logit layers.')
105 |     parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
106 |     parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
107 |     parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
108 |     parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
109 |     parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')
110 |     # for Relational Memory
111 |     parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.')
112 |     parser.add_argument('--rm_num_heads', type=int, default=8, help='the number of heads in rm.')
113 |     parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.')
114 | 
115 |     # Sample related
116 |     parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample method used to sample a report.')
117 |     parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
118 |     parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
119 |     parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
120 |     parser.add_argument('--group_size', type=int, default=1, help='the group size.')
121 |     parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
122 |     parser.add_argument('--decoding_constraint', type=int, default=0, help='whether to use the decoding constraint.')
123 |     parser.add_argument('--block_trigrams', type=int, default=1, help='whether to block repeated trigrams.')
124 | 
125 |     # Trainer settings
126 |     parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
127 |     parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
128 |     parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
129 |     parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments.')
130 |     parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).')
131 |     parser.add_argument('--save_period', type=int, default=1, help='the saving period (in epochs).')
132 |     parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to maximize or minimize the monitored metric.')
133 |     parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
134 |     parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')
135 | 
136 |     # Optimization
137 |     parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
138 |     parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
139 |     parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
140 |     parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
141 |     parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the betas for the Adam optimizer.')
142 |     parser.add_argument('--adam_eps', type=float, default=1e-9, help='the epsilon for the Adam optimizer.')
143 |     parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use AMSGrad.')
144 | 
145 |     # Learning Rate Scheduler
146 |     parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
147 |     parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
148 |     parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')
149 | 
150 |     # Others
151 |     parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
152 |     parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')
153 | 
154 |     args = parser.parse_args()
155 |     return args
156 | 
157 | def subsequent_mask(size):
158 |     "Mask out subsequent positions."
159 | attn_shape = (1, size, size) 160 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 161 | return torch.from_numpy(subsequent_mask) == 0 162 | 163 | def normalize(tensor): 164 | lwr = tensor.min() 165 | upr = tensor.max() 166 | diff = upr - lwr 167 | return (tensor - lwr) / diff 168 | 169 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 170 | 171 | args = parse_agrs() 172 | tokenizer = Tokenizer(args) 173 | model = R2GenModel(args, tokenizer) 174 | 175 | ckpt = torch.load('results/mimic_cxr/rl_base_seed_9458/model_best.pth') 176 | print(ckpt['epoch']) 177 | model.load_state_dict(ckpt['state_dict']) 178 | for p in model.parameters(): 179 | p.requires_grad = False 180 | model.to(device) 181 | model.eval() 182 | 183 | path = 'data/mimic_cxr/images/p12/p12991634/s50848641/d26b0ded-85fec1e6-be2f1ead-87e3adcb-80a1b20e.jpg' 184 | img = Image.open(path).convert('RGB') 185 | arr = np.array(img) / 255 186 | transform = transforms.Compose([ 187 | transforms.Resize((224, 224)), 188 | transforms.ToTensor(), 189 | transforms.Normalize((0.485, 0.456, 0.406), 190 | (0.229, 0.224, 0.225))]) 191 | img = transform(img).to(device) 192 | 193 | output, _ = model(img.unsqueeze(0), mode='sample') 194 | report = model.tokenizer.decode_batch(output.cpu().numpy())[0] 195 | print(report) 196 | 197 | report_id = tokenizer(report) 198 | print(report_id) 199 | 200 | interpolate = transforms.Resize(arr.shape[0]) 201 | 202 | viridis = cm.get_cmap('jet', 1000) 203 | cmap = viridis 204 | cmap_list = [cmap(i) for i in range(cmap.N)] 205 | cmap_list = cmap_list[int(0.2*cmap.N):] 206 | cmap = mpl.colors.LinearSegmentedColormap.from_list('mcm', parula_data, cmap.N) 207 | norm = mpl.colors.Normalize(vmin=0, vmax=1) 208 | plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap)) 209 | plt.savefig('colorbar.png') 210 | plt.close() 211 | output_dir = 'test' 212 | if not os.path.exists(output_dir): 213 | os.mkdir(output_dir) 214 | 215 | result_dict = {} 216 | targets = torch.LongTensor(report_id[:-1]).unsqueeze(0).to(device) 217 | model(img.unsqueeze(0), targets, mode='train') 218 | attn = model.encoder_decoder.model.decoder.layers[0].src_attn.attn 219 | attn_score = attn.squeeze() 220 | diff = torch.max(attn_score, dim=-1)[0] - torch.min(attn_score, dim=-1)[0] 221 | index = torch.argmax(diff, dim=0) 222 | attn_score = attn_score.view((model.encoder_decoder.model.decoder.layers[0].src_attn.h, -1, 7, 7)) 223 | for i, word_id in enumerate(report_id[1:]): 224 | if word_id == 0: break 225 | word = tokenizer.idx2token[word_id] 226 | score = attn_score[index[i], i].unsqueeze(0) 227 | heatmap_score = interpolate(score).squeeze(0).cpu().numpy() 228 | result_dict[word] = heatmap_score 229 | heatmap_score = normalize(heatmap_score) 230 | colormap = cmap(heatmap_score)[:, :, :3] 231 | # import pdb; pdb.set_trace() 232 | plt.imsave('test.png', colormap) 233 | colormap = 0.5*colormap + 0.5*arr 234 | path = os.path.join(output_dir, word+'.png') 235 | plt.imsave(path, colormap) 236 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from modules.visual_extractor import VisualExtractor 6 | from modules.base_cmn import BaseCMN 7 | 8 | 9 | class BaseCMNModel(nn.Module): 10 | def __init__(self, args, tokenizer): 11 | super(BaseCMNModel, self).__init__() 12 | self.args = args 13 | 
self.tokenizer = tokenizer 14 | self.visual_extractor = VisualExtractor(args) 15 | self.encoder_decoder = BaseCMN(args, tokenizer) 16 | if args.dataset_name == 'iu_xray': 17 | self.forward = self.forward_iu_xray 18 | else: 19 | self.forward = self.forward_mimic_cxr 20 | 21 | def __str__(self): 22 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 23 | params = sum([np.prod(p.size()) for p in model_parameters]) 24 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 25 | 26 | def forward_iu_xray(self, images, targets=None, mode='train', update_opts={}): 27 | att_feats_0, fc_feats_0 = self.visual_extractor(images[:, 0]) 28 | att_feats_1, fc_feats_1 = self.visual_extractor(images[:, 1]) 29 | fc_feats = torch.cat((fc_feats_0, fc_feats_1), dim=1) 30 | att_feats = torch.cat((att_feats_0, att_feats_1), dim=1) 31 | if mode == 'train': 32 | output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward') 33 | return output 34 | elif mode == 'sample': 35 | output, output_probs = self.encoder_decoder(fc_feats, att_feats, mode='sample', update_opts=update_opts) 36 | return output, output_probs 37 | else: 38 | raise ValueError 39 | # return output 40 | 41 | def forward_mimic_cxr(self, images, targets=None, mode='train', update_opts={}): 42 | att_feats, fc_feats = self.visual_extractor(images) 43 | if mode == 'train': 44 | output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward') 45 | return output 46 | elif mode == 'sample': 47 | output, output_probs = self.encoder_decoder(fc_feats, att_feats, mode='sample', update_opts=update_opts) 48 | return output, output_probs 49 | else: 50 | raise ValueError 51 | # return output 52 | -------------------------------------------------------------------------------- /models/r2gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from modules.visual_extractor import VisualExtractor 6 | from modules.encoder_decoder import EncoderDecoder 7 | 8 | 9 | class R2GenModel(nn.Module): 10 | def __init__(self, args, tokenizer): 11 | super(R2GenModel, self).__init__() 12 | self.args = args 13 | self.tokenizer = tokenizer 14 | self.visual_extractor = VisualExtractor(args) 15 | self.encoder_decoder = EncoderDecoder(args, tokenizer) 16 | if args.dataset_name == 'iu_xray': 17 | self.forward = self.forward_iu_xray 18 | else: 19 | self.forward = self.forward_mimic_cxr 20 | 21 | def __str__(self): 22 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 23 | params = sum([np.prod(p.size()) for p in model_parameters]) 24 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 25 | 26 | def forward_iu_xray(self, images, targets=None, mode='train', update_opts={}): 27 | att_feats_0, fc_feats_0 = self.visual_extractor(images[:, 0]) 28 | att_feats_1, fc_feats_1 = self.visual_extractor(images[:, 1]) 29 | fc_feats = torch.cat((fc_feats_0, fc_feats_1), dim=1) 30 | att_feats = torch.cat((att_feats_0, att_feats_1), dim=1) 31 | if mode == 'train': 32 | output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward') 33 | return output 34 | elif mode == 'sample': 35 | output, output_probs = self.encoder_decoder(fc_feats, att_feats, mode='sample', update_opts=update_opts) 36 | return output, output_probs 37 | else: 38 | raise ValueError 39 | 40 | def forward_mimic_cxr(self, images, targets=None, mode='train', update_opts={}): 41 | att_feats, fc_feats = 
self.visual_extractor(images) 42 | if mode == 'train': 43 | output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward') 44 | return output 45 | elif mode == 'sample': 46 | output, output_probs = self.encoder_decoder(fc_feats, att_feats, mode='sample', update_opts=update_opts) 47 | return output, output_probs 48 | else: 49 | raise ValueError 50 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2GenRL/214b4dcfdde5752d2c49e774780528519ca740fd/modules/__init__.py -------------------------------------------------------------------------------- /modules/att_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence 9 | 10 | import modules.utils as utils 11 | from modules.caption_model import CaptionModel 12 | 13 | 14 | def sort_pack_padded_sequence(input, lengths): 15 | sorted_lengths, indices = torch.sort(lengths, descending=True) 16 | tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True) 17 | inv_ix = indices.clone() 18 | inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix) 19 | return tmp, inv_ix 20 | 21 | 22 | def pad_unsort_packed_sequence(input, inv_ix): 23 | tmp, _ = pad_packed_sequence(input, batch_first=True) 24 | tmp = tmp[inv_ix] 25 | return tmp 26 | 27 | 28 | def pack_wrapper(module, att_feats, att_masks): 29 | if att_masks is not None: 30 | packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1)) 31 | return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) 32 | else: 33 | return module(att_feats) 34 | 35 | 36 | class AttModel(CaptionModel): 37 | def __init__(self, args, tokenizer): 38 | super(AttModel, self).__init__() 39 | self.args = args 40 | self.tokenizer = tokenizer 41 | self.vocab_size = len(tokenizer.idx2token) 42 | self.input_encoding_size = args.d_model 43 | self.rnn_size = args.d_ff 44 | self.num_layers = args.num_layers 45 | self.drop_prob_lm = args.drop_prob_lm 46 | self.max_seq_length = args.max_seq_length 47 | self.att_feat_size = args.d_vf 48 | self.att_hid_size = args.d_model 49 | 50 | self.bos_idx = args.bos_idx 51 | self.eos_idx = args.eos_idx 52 | self.pad_idx = args.pad_idx 53 | 54 | self.use_bn = args.use_bn 55 | 56 | self.embed = lambda x: x 57 | self.fc_embed = lambda x: x 58 | self.att_embed = nn.Sequential(*( 59 | ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) + 60 | (nn.Linear(self.att_feat_size, self.input_encoding_size), 61 | nn.ReLU(), 62 | nn.Dropout(self.drop_prob_lm)) + 63 | ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ()))) 64 | 65 | def clip_att(self, att_feats, att_masks): 66 | # Clip the length of att_masks and att_feats to the maximum length 67 | if att_masks is not None: 68 | max_len = att_masks.data.long().sum(1).max() 69 | att_feats = att_feats[:, :max_len].contiguous() 70 | att_masks = att_masks[:, :max_len].contiguous() 71 | return att_feats, att_masks 72 | 73 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 74 | att_feats, att_masks = self.clip_att(att_feats, att_masks) 75 | 76 
| # embed fc and att feats 77 | fc_feats = self.fc_embed(fc_feats) 78 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) 79 | 80 | # Project the attention feats first to reduce memory and computation comsumptions. 81 | p_att_feats = self.ctx2att(att_feats) 82 | 83 | return fc_feats, att_feats, p_att_feats, att_masks 84 | 85 | def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1): 86 | # 'it' contains a word index 87 | xt = self.embed(it) 88 | 89 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks) 90 | if output_logsoftmax: 91 | logprobs = F.log_softmax(self.logit(output), dim=1) 92 | else: 93 | logprobs = self.logit(output) 94 | 95 | return logprobs, state 96 | 97 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 98 | beam_size = opt.get('beam_size', 10) 99 | group_size = opt.get('group_size', 1) 100 | sample_n = opt.get('sample_n', 10) 101 | # when sample_n == beam_size then each beam is a sample. 102 | assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' 103 | batch_size = fc_feats.size(0) 104 | 105 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 106 | 107 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 108 | seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long) 109 | seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1) 110 | # lets process every image independently for now, for simplicity 111 | 112 | self.done_beams = [[] for _ in range(batch_size)] 113 | 114 | state = self.init_hidden(batch_size) 115 | 116 | # first step, feed bos 117 | it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) 118 | logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) 119 | 120 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size, 121 | [p_fc_feats, p_att_feats, 122 | pp_att_feats, p_att_masks] 123 | ) 124 | self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt) 125 | for k in range(batch_size): 126 | if sample_n == beam_size: 127 | for _n in range(sample_n): 128 | seq_len = self.done_beams[k][_n]['seq'].shape[0] 129 | seq[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['seq'] 130 | seqLogprobs[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['logps'] 131 | else: 132 | seq_len = self.done_beams[k][0]['seq'].shape[0] 133 | seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 134 | seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps'] 135 | # return the samples and their log likelihoods 136 | return seq, seqLogprobs 137 | 138 | def _sample(self, fc_feats, att_feats, att_masks=None, update_opts={}): 139 | opt = self.args.__dict__ 140 | opt.update(**update_opts) 141 | 142 | sample_method = opt.get('sample_method', 'greedy') 143 | beam_size = opt.get('beam_size', 1) 144 | temperature = opt.get('temperature', 1.0) 145 | sample_n = int(opt.get('sample_n', 1)) 146 | group_size = opt.get('group_size', 1) 147 | output_logsoftmax = opt.get('output_logsoftmax', 1) 148 | decoding_constraint = opt.get('decoding_constraint', 0) 149 | block_trigrams = 
opt.get('block_trigrams', 0) 150 | if beam_size > 1 and sample_method in ['greedy', 'beam_search']: 151 | return self._sample_beam(fc_feats, att_feats, att_masks, opt) 152 | if group_size > 1: 153 | return self._diverse_sample(fc_feats, att_feats, att_masks, opt) 154 | 155 | batch_size = fc_feats.size(0) 156 | state = self.init_hidden(batch_size * sample_n) 157 | 158 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 159 | 160 | if sample_n > 1: 161 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n, 162 | [p_fc_feats, p_att_feats, 163 | pp_att_feats, p_att_masks] 164 | ) 165 | 166 | trigrams = [] # will be a list of batch_size dictionaries 167 | 168 | seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long) 169 | seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1) 170 | for t in range(self.max_seq_length + 1): 171 | if t == 0: # input 172 | it = fc_feats.new_full([batch_size * sample_n], self.bos_idx, dtype=torch.long) 173 | 174 | logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, 175 | output_logsoftmax=output_logsoftmax) 176 | 177 | if decoding_constraint and t > 0: 178 | tmp = logprobs.new_zeros(logprobs.size()) 179 | tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf')) 180 | logprobs = logprobs + tmp 181 | 182 | # Mess with trigrams 183 | # Copy from https://github.com/lukemelas/image-paragraph-captioning 184 | if block_trigrams and t >= 3: 185 | # Store trigram generated at last step 186 | prev_two_batch = seq[:, t - 3:t - 1] 187 | for i in range(batch_size): # = seq.size(0) 188 | prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) 189 | current = seq[i][t - 1] 190 | if t == 3: # initialize 191 | trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} 192 | elif t > 3: 193 | if prev_two in trigrams[i]: # add to list 194 | trigrams[i][prev_two].append(current) 195 | else: # create list 196 | trigrams[i][prev_two] = [current] 197 | # Block used trigrams at next step 198 | prev_two_batch = seq[:, t - 2:t] 199 | mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size 200 | for i in range(batch_size): 201 | prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) 202 | if prev_two in trigrams[i]: 203 | for j in trigrams[i][prev_two]: 204 | mask[i, j] += 1 205 | # Apply mask to log probs 206 | # logprobs = logprobs - (mask * 1e9) 207 | alpha = 2.0 # = 4 208 | logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) 209 | 210 | # sample the next word 211 | if t == self.max_seq_length: # skip if we achieve maximum length 212 | break 213 | it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature) 214 | 215 | # stop when all finished 216 | if t == 0: 217 | unfinished = it != self.eos_idx 218 | else: 219 | it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0 220 | logprobs = logprobs * unfinished.unsqueeze(1).float() 221 | unfinished = unfinished * (it != self.eos_idx) 222 | seq[:, t] = it 223 | seqLogprobs[:, t] = logprobs 224 | # quit loop if all sequences have finished 225 | if unfinished.sum() == 0: 226 | break 227 | 228 | return seq, seqLogprobs 229 | 230 | def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}): 231 | 232 | sample_method = 
opt.get('sample_method', 'greedy') 233 | beam_size = opt.get('beam_size', 1) 234 | temperature = opt.get('temperature', 1.0) 235 | group_size = opt.get('group_size', 1) 236 | diversity_lambda = opt.get('diversity_lambda', 0.5) 237 | decoding_constraint = opt.get('decoding_constraint', 0) 238 | block_trigrams = opt.get('block_trigrams', 0) 239 | 240 | batch_size = fc_feats.size(0) 241 | state = self.init_hidden(batch_size) 242 | 243 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 244 | 245 | trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries 246 | 247 | seq_table = [fc_feats.new_full((batch_size, self.max_seq_length), self.pad_idx, dtype=torch.long) for _ in 248 | range(group_size)] 249 | seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.max_seq_length) for _ in range(group_size)] 250 | state_table = [self.init_hidden(batch_size) for _ in range(group_size)] 251 | 252 | for tt in range(self.max_seq_length + group_size): 253 | for divm in range(group_size): 254 | t = tt - divm 255 | seq = seq_table[divm] 256 | seqLogprobs = seqLogprobs_table[divm] 257 | trigrams = trigrams_table[divm] 258 | if t >= 0 and t <= self.max_seq_length - 1: 259 | if t == 0: # input 260 | it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) 261 | else: 262 | it = seq[:, t - 1] # changed 263 | 264 | logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, 265 | p_att_masks, state_table[divm]) # changed 266 | logprobs = F.log_softmax(logprobs / temperature, dim=-1) 267 | 268 | # Add diversity 269 | if divm > 0: 270 | unaug_logprobs = logprobs.clone() 271 | for prev_choice in range(divm): 272 | prev_decisions = seq_table[prev_choice][:, t] 273 | logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda 274 | 275 | if decoding_constraint and t > 0: 276 | tmp = logprobs.new_zeros(logprobs.size()) 277 | tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf')) 278 | logprobs = logprobs + tmp 279 | 280 | # Mess with trigrams 281 | if block_trigrams and t >= 3: 282 | # Store trigram generated at last step 283 | prev_two_batch = seq[:, t - 3:t - 1] 284 | for i in range(batch_size): # = seq.size(0) 285 | prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) 286 | current = seq[i][t - 1] 287 | if t == 3: # initialize 288 | trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} 289 | elif t > 3: 290 | if prev_two in trigrams[i]: # add to list 291 | trigrams[i][prev_two].append(current) 292 | else: # create list 293 | trigrams[i][prev_two] = [current] 294 | # Block used trigrams at next step 295 | prev_two_batch = seq[:, t - 2:t] 296 | mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size 297 | for i in range(batch_size): 298 | prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) 299 | if prev_two in trigrams[i]: 300 | for j in trigrams[i][prev_two]: 301 | mask[i, j] += 1 302 | # Apply mask to log probs 303 | # logprobs = logprobs - (mask * 1e9) 304 | alpha = 2.0 # = 4 305 | logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) 306 | 307 | it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1) 308 | 309 | # stop when all finished 310 | if t == 0: 311 | unfinished = it != self.eos_idx 312 | else: 313 | unfinished = seq[:, t - 1] != self.pad_idx & seq[:, t - 1] != self.eos_idx 314 | it[~unfinished] = 
self.pad_idx 315 | unfinished = unfinished & (it != self.eos_idx) # changed 316 | seq[:, t] = it 317 | seqLogprobs[:, t] = sampleLogprobs.view(-1) 318 | 319 | return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 320 | 1).reshape( 321 | batch_size * group_size, -1) 322 | -------------------------------------------------------------------------------- /modules/base_cmn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import copy 6 | import math 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .att_model import pack_wrapper, AttModel 14 | 15 | 16 | def clones(module, N): 17 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 18 | 19 | 20 | def subsequent_mask(size): 21 | attn_shape = (1, size, size) 22 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 23 | return torch.from_numpy(subsequent_mask) == 0 24 | 25 | 26 | def attention(query, key, value, mask=None, dropout=None): 27 | d_k = query.size(-1) 28 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 29 | if mask is not None: 30 | scores = scores.masked_fill(mask == 0, float('-inf')) 31 | p_attn = F.softmax(scores, dim=-1) 32 | if dropout is not None: 33 | p_attn = dropout(p_attn) 34 | return torch.matmul(p_attn, value), p_attn 35 | 36 | 37 | def memory_querying_responding(query, key, value, mask=None, dropout=None, topk=32): 38 | d_k = query.size(-1) 39 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 40 | if mask is not None: 41 | scores = scores.masked_fill(mask == 0, float('-inf')) 42 | selected_scores, idx = scores.topk(topk) 43 | dummy_value = value.unsqueeze(2).expand(idx.size(0), idx.size(1), idx.size(2), value.size(-2), value.size(-1)) 44 | dummy_idx = idx.unsqueeze(-1).expand(idx.size(0), idx.size(1), idx.size(2), idx.size(3), value.size(-1)) 45 | selected_value = torch.gather(dummy_value, 3, dummy_idx) 46 | p_attn = F.softmax(selected_scores, dim=-1) 47 | if dropout is not None: 48 | p_attn = dropout(p_attn) 49 | return torch.matmul(p_attn.unsqueeze(3), selected_value).squeeze(3), p_attn 50 | 51 | 52 | class Transformer(nn.Module): 53 | def __init__(self, encoder, decoder, src_embed, tgt_embed, cmn, norm): 54 | super(Transformer, self).__init__() 55 | self.encoder = encoder 56 | self.decoder = decoder 57 | self.src_embed = src_embed 58 | self.tgt_embed = tgt_embed 59 | self.cmn = cmn 60 | self.norm = norm 61 | 62 | def forward(self, src, tgt, src_mask, tgt_mask, memory_matrix): 63 | return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask, memory_matrix=memory_matrix) 64 | 65 | def encode(self, src, src_mask): 66 | return self.encoder(self.src_embed(src), src_mask) 67 | 68 | def decode(self, memory, src_mask, tgt, tgt_mask, past=None, memory_matrix=None): 69 | embeddings = self.tgt_embed(tgt) 70 | mask = (tgt.data > 0) 71 | mask[:, 0] += True 72 | embeddings = pack_wrapper(self.norm, embeddings, mask) 73 | 74 | # Memory querying and responding for textual features 75 | dummy_memory_matrix = memory_matrix.unsqueeze(0).expand(embeddings.size(0), memory_matrix.size(0), 76 | memory_matrix.size(1)) 77 | responses = self.cmn(embeddings, dummy_memory_matrix, dummy_memory_matrix) 78 | embeddings = embeddings + responses 79 | # Memory querying and responding for 
textual features 80 | 81 | return self.decoder(embeddings, memory, src_mask, tgt_mask, past=past) 82 | 83 | 84 | class Encoder(nn.Module): 85 | def __init__(self, layer, N): 86 | super(Encoder, self).__init__() 87 | self.layers = clones(layer, N) 88 | self.norm = LayerNorm(layer.size) 89 | 90 | def forward(self, x, mask): 91 | for layer in self.layers: 92 | x = layer(x, mask) 93 | return self.norm(x) 94 | 95 | 96 | class LayerNorm(nn.Module): 97 | def __init__(self, features, eps=1e-6): 98 | super(LayerNorm, self).__init__() 99 | self.a_2 = nn.Parameter(torch.ones(features)) 100 | self.b_2 = nn.Parameter(torch.zeros(features)) 101 | self.eps = eps 102 | 103 | def forward(self, x): 104 | mean = x.mean(-1, keepdim=True) 105 | std = x.std(-1, keepdim=True) 106 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 107 | 108 | 109 | class SublayerConnection(nn.Module): 110 | def __init__(self, size, dropout): 111 | super(SublayerConnection, self).__init__() 112 | self.norm = LayerNorm(size) 113 | self.dropout = nn.Dropout(dropout) 114 | 115 | def forward(self, x, sublayer): 116 | _x = sublayer(self.norm(x)) 117 | if type(_x) is tuple: 118 | return x + self.dropout(_x[0]), _x[1] 119 | return x + self.dropout(_x) 120 | 121 | 122 | class EncoderLayer(nn.Module): 123 | def __init__(self, size, self_attn, feed_forward, dropout): 124 | super(EncoderLayer, self).__init__() 125 | self.self_attn = self_attn 126 | self.feed_forward = feed_forward 127 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 128 | self.size = size 129 | 130 | def forward(self, x, mask): 131 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 132 | return self.sublayer[1](x, self.feed_forward) 133 | 134 | 135 | class Decoder(nn.Module): 136 | def __init__(self, layer, N): 137 | super(Decoder, self).__init__() 138 | self.layers = clones(layer, N) 139 | self.norm = LayerNorm(layer.size) 140 | 141 | def forward(self, x, memory, src_mask, tgt_mask, past=None): 142 | if past is not None: 143 | present = [[], []] 144 | x = x[:, -1:] 145 | tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None 146 | past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0))) 147 | else: 148 | past = [None] * len(self.layers) 149 | for i, (layer, layer_past) in enumerate(zip(self.layers, past)): 150 | x = layer(x, memory, src_mask, tgt_mask, 151 | layer_past) 152 | if layer_past is not None: 153 | present[0].append(x[1][0]) 154 | present[1].append(x[1][1]) 155 | x = x[0] 156 | if past[0] is None: 157 | return self.norm(x) 158 | else: 159 | return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)] 160 | 161 | 162 | class DecoderLayer(nn.Module): 163 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 164 | super(DecoderLayer, self).__init__() 165 | self.size = size 166 | self.self_attn = self_attn 167 | self.src_attn = src_attn 168 | self.feed_forward = feed_forward 169 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 170 | 171 | def forward(self, x, memory, src_mask, tgt_mask, layer_past=None): 172 | m = memory 173 | if layer_past is None: 174 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 175 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 176 | return self.sublayer[2](x, self.feed_forward) 177 | else: 178 | present = [None, None] 179 | x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0])) 180 | x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, 
src_mask, layer_past[1])) 181 | return self.sublayer[2](x, self.feed_forward), present 182 | 183 | 184 | class MultiThreadMemory(nn.Module): 185 | def __init__(self, h, d_model, dropout=0.1, topk=32): 186 | super(MultiThreadMemory, self).__init__() 187 | assert d_model % h == 0 188 | self.d_k = d_model // h 189 | self.h = h 190 | self.linears = clones(nn.Linear(d_model, d_model), 4) 191 | self.attn = None 192 | self.dropout = nn.Dropout(p=dropout) 193 | self.topk = topk 194 | 195 | def forward(self, query, key, value, mask=None, layer_past=None): 196 | if mask is not None: 197 | mask = mask.unsqueeze(1) 198 | nbatches = query.size(0) 199 | 200 | if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: 201 | query = self.linears[0](query) 202 | key, value = layer_past[0], layer_past[1] 203 | present = torch.stack([key, value]) 204 | else: 205 | query, key, value = \ 206 | [l(x) for l, x in zip(self.linears, (query, key, value))] 207 | if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1): 208 | past_key, past_value = layer_past[0], layer_past[1] 209 | key = torch.cat((past_key, key), dim=1) 210 | value = torch.cat((past_value, value), dim=1) 211 | present = torch.stack([key, value]) 212 | 213 | query, key, value = \ 214 | [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 215 | for x in [query, key, value]] 216 | 217 | x, self.attn = memory_querying_responding(query, key, value, mask=mask, dropout=self.dropout, topk=self.topk) 218 | 219 | x = x.transpose(1, 2).contiguous() \ 220 | .view(nbatches, -1, self.h * self.d_k) 221 | if layer_past is not None: 222 | return self.linears[-1](x), present 223 | else: 224 | return self.linears[-1](x) 225 | 226 | 227 | class MultiHeadedAttention(nn.Module): 228 | def __init__(self, h, d_model, dropout=0.1): 229 | super(MultiHeadedAttention, self).__init__() 230 | assert d_model % h == 0 231 | self.d_k = d_model // h 232 | self.h = h 233 | self.linears = clones(nn.Linear(d_model, d_model), 4) 234 | self.attn = None 235 | self.dropout = nn.Dropout(p=dropout) 236 | 237 | def forward(self, query, key, value, mask=None, layer_past=None): 238 | if mask is not None: 239 | mask = mask.unsqueeze(1) 240 | nbatches = query.size(0) 241 | if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: 242 | query = self.linears[0](query) 243 | key, value = layer_past[0], layer_past[1] 244 | present = torch.stack([key, value]) 245 | else: 246 | query, key, value = \ 247 | [l(x) for l, x in zip(self.linears, (query, key, value))] 248 | 249 | if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1): 250 | past_key, past_value = layer_past[0], layer_past[1] 251 | key = torch.cat((past_key, key), dim=1) 252 | value = torch.cat((past_value, value), dim=1) 253 | present = torch.stack([key, value]) 254 | 255 | query, key, value = \ 256 | [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 257 | for x in [query, key, value]] 258 | 259 | x, self.attn = attention(query, key, value, mask=mask, 260 | dropout=self.dropout) 261 | x = x.transpose(1, 2).contiguous() \ 262 | .view(nbatches, -1, self.h * self.d_k) 263 | if layer_past is not None: 264 | return self.linears[-1](x), present 265 | else: 266 | return self.linears[-1](x) 267 | 268 | 269 | class PositionwiseFeedForward(nn.Module): 270 | def __init__(self, d_model, d_ff, dropout=0.1): 271 | super(PositionwiseFeedForward, self).__init__() 272 | self.w_1 = nn.Linear(d_model, d_ff) 273 | self.w_2 = nn.Linear(d_ff, d_model) 274 | self.dropout = 
nn.Dropout(dropout) 275 | 276 | def forward(self, x): 277 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 278 | 279 | 280 | class Embeddings(nn.Module): 281 | def __init__(self, d_model, vocab): 282 | super(Embeddings, self).__init__() 283 | self.lut = nn.Embedding(vocab, d_model) 284 | self.d_model = d_model 285 | 286 | def forward(self, x): 287 | return self.lut(x) * math.sqrt(self.d_model) 288 | 289 | 290 | class PositionalEncoding(nn.Module): 291 | def __init__(self, d_model, dropout, max_len=5000): 292 | super(PositionalEncoding, self).__init__() 293 | self.dropout = nn.Dropout(p=dropout) 294 | 295 | pe = torch.zeros(max_len, d_model) 296 | position = torch.arange(0, max_len).unsqueeze(1).float() 297 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 298 | -(math.log(10000.0) / d_model)) 299 | pe[:, 0::2] = torch.sin(position * div_term) 300 | pe[:, 1::2] = torch.cos(position * div_term) 301 | pe = pe.unsqueeze(0) 302 | self.register_buffer('pe', pe) 303 | 304 | def forward(self, x): 305 | x = x + self.pe[:, :x.size(1)] 306 | return self.dropout(x) 307 | 308 | 309 | class BaseCMN(AttModel): 310 | 311 | def make_model(self, tgt_vocab, cmn, norm_txt): 312 | c = copy.deepcopy 313 | attn = MultiHeadedAttention(self.num_heads, self.d_model) 314 | ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout) 315 | position = PositionalEncoding(self.d_model, self.dropout) 316 | model = Transformer( 317 | Encoder(EncoderLayer(self.d_model, c(attn), c(ff), self.dropout), self.num_layers), 318 | Decoder(DecoderLayer(self.d_model, c(attn), c(attn), c(ff), self.dropout), self.num_layers), 319 | nn.Sequential(c(position)), 320 | nn.Sequential(Embeddings(self.d_model, tgt_vocab), c(position)), cmn, norm_txt) 321 | for p in model.parameters(): 322 | if p.dim() > 1: 323 | nn.init.xavier_uniform_(p) 324 | return model 325 | 326 | def __init__(self, args, tokenizer): 327 | super(BaseCMN, self).__init__(args, tokenizer) 328 | self.args = args 329 | self.num_layers = args.num_layers 330 | self.d_model = args.d_model 331 | self.d_ff = args.d_ff 332 | self.num_heads = args.num_heads 333 | self.dropout = args.dropout 334 | self.topk = args.topk 335 | 336 | tgt_vocab = self.vocab_size + 1 337 | 338 | self.cmn = MultiThreadMemory(args.num_heads, args.d_model, topk=args.topk) 339 | self.norm_vis = nn.BatchNorm1d(args.d_model) 340 | self.norm_txt = nn.BatchNorm1d(args.d_model) 341 | 342 | self.model = self.make_model(tgt_vocab, self.cmn, self.norm_txt) 343 | self.logit = nn.Linear(args.d_model, tgt_vocab) 344 | 345 | self.memory_matrix = nn.Parameter(torch.randn((args.cmm_size, args.cmm_dim)), requires_grad=True) 346 | 347 | def init_hidden(self, bsz): 348 | return [] 349 | 350 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 351 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) 352 | memory = self.model.encode(att_feats, att_masks) 353 | 354 | return fc_feats[..., :1], att_feats[..., :1], memory, att_masks 355 | 356 | def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): 357 | att_feats, att_masks = self.clip_att(att_feats, att_masks) 358 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) 359 | 360 | if att_masks is None: 361 | att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) 362 | 363 | # Memory querying and responding for visual features 364 | dummy_memory_matrix = self.memory_matrix.unsqueeze(0).expand(att_feats.size(0), self.memory_matrix.size(0), 365 | 
self.memory_matrix.size(1)) 366 | att_feats = pack_wrapper(self.norm_vis, att_feats, att_masks) 367 | responses = self.cmn(att_feats, dummy_memory_matrix, dummy_memory_matrix) 368 | att_feats = att_feats + responses 369 | # Memory querying and responding for visual features 370 | 371 | att_masks = att_masks.unsqueeze(-2) 372 | if seq is not None: 373 | seq = seq[:, :-1] 374 | seq_mask = (seq.data > 0) 375 | seq_mask[:, 0] += True 376 | 377 | seq_mask = seq_mask.unsqueeze(-2) 378 | seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) 379 | else: 380 | seq_mask = None 381 | 382 | return att_feats, seq, att_masks, seq_mask 383 | 384 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 385 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) 386 | out = self.model(att_feats, seq, att_masks, seq_mask, memory_matrix=self.memory_matrix) 387 | outputs = F.log_softmax(self.logit(out), dim=-1) 388 | 389 | return outputs 390 | 391 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): 392 | if len(state) == 0: 393 | ys = it.unsqueeze(1) 394 | past = [fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model), 395 | fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model)] 396 | else: 397 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) 398 | past = state[1:] 399 | out, past = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device), past=past, 400 | memory_matrix=self.memory_matrix) 401 | return out[:, -1], [ys.unsqueeze(0)] + past 402 | -------------------------------------------------------------------------------- /modules/base_cmn.py.bak: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import copy 6 | import math 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .att_model import pack_wrapper, AttModel 14 | 15 | 16 | def clones(module, N): 17 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 18 | 19 | 20 | def subsequent_mask(size): 21 | attn_shape = (1, size, size) 22 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 23 | return torch.from_numpy(subsequent_mask) == 0 24 | 25 | 26 | def attention(query, key, value, mask=None, dropout=None): 27 | d_k = query.size(-1) 28 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 29 | / math.sqrt(d_k) 30 | if mask is not None: 31 | scores = scores.masked_fill(mask == 0, float('-inf')) 32 | p_attn = F.softmax(scores, dim=-1) 33 | if dropout is not None: 34 | p_attn = dropout(p_attn) 35 | return torch.matmul(p_attn, value), p_attn 36 | 37 | 38 | def memory_querying_responding(query, key, value, mask=None, dropout=None, topk=32): 39 | d_k = query.size(-1) 40 | 41 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 42 | / math.sqrt(d_k) 43 | if mask is not None: 44 | scores = scores.masked_fill(mask == 0, float('-inf')) 45 | selected_scores, idx = scores.topk(topk) 46 | dummy_value = value.unsqueeze(2).expand(idx.size(0), idx.size(1), idx.size(2), value.size(-2), value.size(-1)) 47 | dummy_idx = idx.unsqueeze(-1).expand(idx.size(0), idx.size(1), idx.size(2), idx.size(3), value.size(-1)) 48 | selected_value = torch.gather(dummy_value, 3, dummy_idx) 49 | p_attn = F.softmax(selected_scores, dim=-1) 50 | if dropout is not None: 51 
| p_attn = dropout(p_attn) 52 | return torch.matmul(p_attn.unsqueeze(3), selected_value).squeeze(3), p_attn 53 | 54 | 55 | class MultiThreadMemory(nn.Module): 56 | def __init__(self, h, d_model, dropout=0.1, topk=32): 57 | super(MultiThreadMemory, self).__init__() 58 | assert d_model % h == 0 59 | # We assume d_v always equals d_k 60 | self.d_k = d_model // h 61 | self.h = h 62 | self.linears = clones(nn.Linear(d_model, d_model), 4) 63 | self.attn = None 64 | self.dropout = nn.Dropout(p=dropout) 65 | self.topk = topk 66 | 67 | def forward(self, query, key, value, mask=None, layer_past=None): 68 | if mask is not None: 69 | # Same mask applied to all h heads. 70 | mask = mask.unsqueeze(1) 71 | nbatches = query.size(0) 72 | 73 | # The past works differently here. For self attn, the query and key be updated incrementailly 74 | # For src_attn the past is fixed. 75 | 76 | # For src_attn, when the layer past is ready 77 | if layer_past is not None and layer_past.shape[2] == key.shape[ 78 | 1] > 1: # suppose memory size always greater than 1 79 | query = self.linears[0](query) 80 | key, value = layer_past[0], layer_past[1] 81 | present = torch.stack([key, value]) 82 | else: 83 | # 1) Do all the linear projections in batch from d_model => h x d_k 84 | query, key, value = \ 85 | [l(x) for l, x in zip(self.linears, (query, key, value))] 86 | 87 | # self attn + past OR the first time step of src attn 88 | if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1): 89 | past_key, past_value = layer_past[0], layer_past[1] 90 | key = torch.cat((past_key, key), dim=1) 91 | value = torch.cat((past_value, value), dim=1) 92 | present = torch.stack([key, value]) 93 | 94 | query, key, value = \ 95 | [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 96 | for x in [query, key, value]] 97 | 98 | # 2) Apply attention on all the projected vectors in batch. 99 | x, self.attn = memory_querying_responding(query, key, value, mask=mask, dropout=self.dropout, topk=self.topk) 100 | 101 | # 3) "Concat" using a view and apply a final linear. 
102 | x = x.transpose(1, 2).contiguous() \ 103 | .view(nbatches, -1, self.h * self.d_k) 104 | if layer_past is not None: 105 | return self.linears[-1](x), present 106 | else: 107 | return self.linears[-1](x) 108 | 109 | 110 | class Transformer(nn.Module): 111 | def __init__(self, encoder, decoder, src_embed, tgt_embed, cmn): 112 | super(Transformer, self).__init__() 113 | self.encoder = encoder 114 | self.decoder = decoder 115 | self.src_embed = src_embed 116 | self.tgt_embed = tgt_embed 117 | self.cmn = cmn 118 | 119 | def forward(self, src, tgt, src_mask, tgt_mask, memory_matrix): 120 | return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask, memory_matrix=memory_matrix) 121 | 122 | def encode(self, src, src_mask): 123 | return self.encoder(self.src_embed(src), src_mask) 124 | 125 | def decode(self, memory, src_mask, tgt, tgt_mask, past=None, memory_matrix=None): 126 | embeddings = self.tgt_embed(tgt) 127 | seq_mask = (tgt.data > 0) 128 | seq_mask[:, 0] += True 129 | 130 | dummy_memory_matrix = memory_matrix.unsqueeze(0).expand(embeddings.size(0), memory_matrix.size(0), 131 | memory_matrix.size(1)) 132 | responses = self.cmn(embeddings, dummy_memory_matrix, dummy_memory_matrix) 133 | 134 | embeddings = embeddings + responses 135 | return self.decoder(embeddings, memory, src_mask, tgt_mask, past=past) 136 | 137 | 138 | class Encoder(nn.Module): 139 | def __init__(self, layer, N): 140 | super(Encoder, self).__init__() 141 | self.layers = clones(layer, N) 142 | self.norm = LayerNorm(layer.size) 143 | 144 | def forward(self, x, mask): 145 | for layer in self.layers: 146 | x = layer(x, mask) 147 | return self.norm(x) 148 | 149 | 150 | class LayerNorm(nn.Module): 151 | def __init__(self, features, eps=1e-6): 152 | super(LayerNorm, self).__init__() 153 | self.a_2 = nn.Parameter(torch.ones(features)) 154 | self.b_2 = nn.Parameter(torch.zeros(features)) 155 | self.eps = eps 156 | 157 | def forward(self, x): 158 | mean = x.mean(-1, keepdim=True) 159 | std = x.std(-1, keepdim=True) 160 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 161 | 162 | 163 | class SublayerConnection(nn.Module): 164 | def __init__(self, size, dropout): 165 | super(SublayerConnection, self).__init__() 166 | self.norm = LayerNorm(size) 167 | self.dropout = nn.Dropout(dropout) 168 | 169 | def forward(self, x, sublayer): 170 | _x = sublayer(self.norm(x)) 171 | if type(_x) is tuple: # for multi-head attention that returns past 172 | return x + self.dropout(_x[0]), _x[1] 173 | return x + self.dropout(_x) 174 | 175 | 176 | class EncoderLayer(nn.Module): 177 | def __init__(self, size, self_attn, feed_forward, dropout): 178 | super(EncoderLayer, self).__init__() 179 | self.self_attn = self_attn 180 | self.feed_forward = feed_forward 181 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 182 | self.size = size 183 | 184 | def forward(self, x, mask): 185 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 186 | return self.sublayer[1](x, self.feed_forward) 187 | 188 | 189 | class Decoder(nn.Module): 190 | def __init__(self, layer, N): 191 | super(Decoder, self).__init__() 192 | self.layers = clones(layer, N) 193 | self.norm = LayerNorm(layer.size) 194 | 195 | def forward(self, x, memory, src_mask, tgt_mask, past=None): 196 | if past is not None: 197 | present = [[], []] 198 | x = x[:, -1:] 199 | tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None 200 | past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0))) 201 | else: 202 | past = [None] * len(self.layers) 
203 | for i, (layer, layer_past) in enumerate(zip(self.layers, past)): 204 | x = layer(x, memory, src_mask, tgt_mask, 205 | layer_past) 206 | if layer_past is not None: 207 | present[0].append(x[1][0]) 208 | present[1].append(x[1][1]) 209 | x = x[0] 210 | if past[0] is None: 211 | return self.norm(x) 212 | else: 213 | return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)] 214 | 215 | 216 | class DecoderLayer(nn.Module): 217 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 218 | super(DecoderLayer, self).__init__() 219 | self.size = size 220 | self.self_attn = self_attn 221 | self.src_attn = src_attn 222 | self.feed_forward = feed_forward 223 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 224 | 225 | def forward(self, x, memory, src_mask, tgt_mask, layer_past=None): 226 | m = memory 227 | if layer_past is None: 228 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 229 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 230 | return self.sublayer[2](x, self.feed_forward) 231 | else: 232 | present = [None, None] 233 | x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0])) 234 | x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1])) 235 | return self.sublayer[2](x, self.feed_forward), present 236 | 237 | 238 | class MultiHeadedAttention(nn.Module): 239 | def __init__(self, h, d_model, dropout=0.1): 240 | super(MultiHeadedAttention, self).__init__() 241 | assert d_model % h == 0 242 | # We assume d_v always equals d_k 243 | self.d_k = d_model // h 244 | self.h = h 245 | self.linears = clones(nn.Linear(d_model, d_model), 4) 246 | self.attn = None 247 | self.dropout = nn.Dropout(p=dropout) 248 | 249 | def forward(self, query, key, value, mask=None, layer_past=None): 250 | if mask is not None: 251 | # Same mask applied to all h heads. 252 | mask = mask.unsqueeze(1) 253 | nbatches = query.size(0) 254 | 255 | # The past works differently here. For self attn, the query and key be updated incrementailly 256 | # For src_attn the past is fixed. 257 | 258 | # For src_attn, when the layer past is ready 259 | if layer_past is not None and layer_past.shape[2] == key.shape[ 260 | 1] > 1: # suppose memory size always greater than 1 261 | query = self.linears[0](query) 262 | key, value = layer_past[0], layer_past[1] 263 | present = torch.stack([key, value]) 264 | else: 265 | # 1) Do all the linear projections in batch from d_model => h x d_k 266 | query, key, value = \ 267 | [l(x) for l, x in zip(self.linears, (query, key, value))] 268 | 269 | # self attn + past OR the first time step of src attn 270 | if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1): 271 | past_key, past_value = layer_past[0], layer_past[1] 272 | key = torch.cat((past_key, key), dim=1) 273 | value = torch.cat((past_value, value), dim=1) 274 | present = torch.stack([key, value]) 275 | 276 | query, key, value = \ 277 | [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 278 | for x in [query, key, value]] 279 | 280 | # 2) Apply attention on all the projected vectors in batch. 281 | x, self.attn = attention(query, key, value, mask=mask, 282 | dropout=self.dropout) 283 | # 3) "Concat" using a view and apply a final linear. 
284 | x = x.transpose(1, 2).contiguous() \ 285 | .view(nbatches, -1, self.h * self.d_k) 286 | if layer_past is not None: 287 | return self.linears[-1](x), present 288 | else: 289 | return self.linears[-1](x) 290 | 291 | 292 | class PositionwiseFeedForward(nn.Module): 293 | def __init__(self, d_model, d_ff, dropout=0.1): 294 | super(PositionwiseFeedForward, self).__init__() 295 | self.w_1 = nn.Linear(d_model, d_ff) 296 | self.w_2 = nn.Linear(d_ff, d_model) 297 | self.dropout = nn.Dropout(dropout) 298 | 299 | def forward(self, x): 300 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 301 | 302 | 303 | class Embeddings(nn.Module): 304 | def __init__(self, d_model, vocab): 305 | super(Embeddings, self).__init__() 306 | self.lut = nn.Embedding(vocab, d_model) 307 | self.d_model = d_model 308 | 309 | def forward(self, x): 310 | return self.lut(x) * math.sqrt(self.d_model) 311 | 312 | 313 | class PositionalEncoding(nn.Module): 314 | def __init__(self, d_model, dropout, max_len=5000): 315 | super(PositionalEncoding, self).__init__() 316 | self.dropout = nn.Dropout(p=dropout) 317 | 318 | # Compute the positional encodings once in log space. 319 | pe = torch.zeros(max_len, d_model) 320 | position = torch.arange(0, max_len).unsqueeze(1).float() 321 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 322 | -(math.log(10000.0) / d_model)) 323 | pe[:, 0::2] = torch.sin(position * div_term) 324 | pe[:, 1::2] = torch.cos(position * div_term) 325 | pe = pe.unsqueeze(0) 326 | self.register_buffer('pe', pe) 327 | 328 | def forward(self, x): 329 | x = x + self.pe[:, :x.size(1)] 330 | return self.dropout(x) 331 | 332 | 333 | class BaseCMN(AttModel): 334 | 335 | def make_model(self, tgt_vocab, cmn): 336 | c = copy.deepcopy 337 | attn = MultiHeadedAttention(self.num_heads, self.d_model) 338 | ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout) 339 | position = PositionalEncoding(self.d_model, self.dropout) 340 | model = Transformer( 341 | Encoder(EncoderLayer(self.d_model, c(attn), c(ff), self.dropout), self.num_layers), 342 | Decoder(DecoderLayer(self.d_model, c(attn), c(attn), c(ff), self.dropout), self.num_layers), 343 | nn.Sequential(c(position)), 344 | nn.Sequential(Embeddings(self.d_model, tgt_vocab), c(position)), cmn) 345 | for p in model.parameters(): 346 | if p.dim() > 1: 347 | nn.init.xavier_uniform_(p) 348 | return model 349 | 350 | def __init__(self, args, tokenizer): 351 | super(BaseCMN, self).__init__(args, tokenizer) 352 | self.args = args 353 | self.num_layers = args.num_layers 354 | self.d_model = args.d_model 355 | self.d_ff = args.d_ff 356 | self.num_heads = args.num_heads 357 | self.dropout = args.dropout 358 | self.topk = args.topk 359 | 360 | tgt_vocab = self.vocab_size + 1 361 | 362 | self.cmn = MultiThreadMemory(args.num_heads, args.d_model, topk=args.topk) 363 | self.model = self.make_model(tgt_vocab, self.cmn) 364 | self.logit = nn.Linear(args.d_model, tgt_vocab) 365 | 366 | self.memory_matrix = nn.Parameter(torch.randn((args.cmm_size, args.cmm_dim)), requires_grad=True) 367 | 368 | def init_hidden(self, bsz): 369 | return [] 370 | 371 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 372 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) 373 | memory = self.model.encode(att_feats, att_masks) 374 | 375 | return fc_feats[..., :1], att_feats[..., :1], memory, att_masks 376 | 377 | def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): 378 | att_feats, att_masks = self.clip_att(att_feats, 
att_masks) 379 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) # (bs, 49, 512) 380 | 381 | if att_masks is None: 382 | att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) 383 | 384 | dummy_memory_matrix = self.memory_matrix.unsqueeze(0).expand(att_feats.size(0), self.memory_matrix.size(0), 385 | self.memory_matrix.size(1)) 386 | responses = self.cmn(att_feats, dummy_memory_matrix, dummy_memory_matrix) + att_feats 387 | att_masks = att_masks.unsqueeze(-2) 388 | 389 | if seq is not None: 390 | seq = seq[:, :-1] 391 | seq_mask = (seq.data > 0) 392 | seq_mask[:, 0] += True 393 | 394 | seq_mask = seq_mask.unsqueeze(-2) 395 | seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) 396 | else: 397 | seq_mask = None 398 | att_feats = att_feats + responses 399 | 400 | return att_feats, seq, att_masks, seq_mask 401 | 402 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 403 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) 404 | out = self.model(att_feats, seq, att_masks, seq_mask, memory_matrix=self.memory_matrix) 405 | outputs = F.log_softmax(self.logit(out), dim=-1) 406 | 407 | return outputs 408 | 409 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): 410 | if len(state) == 0: 411 | ys = it.unsqueeze(1) 412 | past = [fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model), 413 | fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model)] 414 | else: 415 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) 416 | past = state[1:] 417 | out, past = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device), past=past, 418 | memory_matrix=self.memory_matrix) 419 | return out[:, -1], [ys.unsqueeze(0)] + past 420 | -------------------------------------------------------------------------------- /modules/dataloaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchvision import transforms 4 | from torch.utils.data import DataLoader 5 | from .datasets import IuxrayMultiImageDataset, MimiccxrSingleImageDataset 6 | 7 | 8 | class R2DataLoader(DataLoader): 9 | def __init__(self, args, tokenizer, split, shuffle): 10 | self.args = args 11 | self.dataset_name = args.dataset_name 12 | self.batch_size = args.batch_size 13 | self.shuffle = shuffle 14 | self.num_workers = args.num_workers 15 | self.tokenizer = tokenizer 16 | self.split = split 17 | 18 | if split == 'train': 19 | self.transform = transforms.Compose([ 20 | transforms.Resize(256), 21 | transforms.RandomCrop(224), 22 | transforms.RandomHorizontalFlip(), 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.485, 0.456, 0.406), 25 | (0.229, 0.224, 0.225))]) 26 | else: 27 | self.transform = transforms.Compose([ 28 | transforms.Resize((224, 224)), 29 | transforms.ToTensor(), 30 | transforms.Normalize((0.485, 0.456, 0.406), 31 | (0.229, 0.224, 0.225))]) 32 | 33 | if self.dataset_name == 'iu_xray': 34 | self.dataset = IuxrayMultiImageDataset(self.args, self.tokenizer, self.split, transform=self.transform) 35 | else: 36 | self.dataset = MimiccxrSingleImageDataset(self.args, self.tokenizer, self.split, transform=self.transform) 37 | 38 | self.init_kwargs = { 39 | 'dataset': self.dataset, 40 | 'batch_size': self.batch_size, 41 | 'shuffle': self.shuffle, 42 | 'collate_fn': self.collate_fn, 43 | 'num_workers': self.num_workers 44 | } 45 | super().__init__(**self.init_kwargs) 
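`R2DataLoader` wires the dataset, the split-specific transforms and the padding collate function (below) together from the experiment arguments. A minimal usage sketch follows; the argument fields match the ones read by `Tokenizer`, `R2DataLoader` and `BaseDataset`, but the values and the `annotation.json` file name are placeholders rather than the paper's settings.

```python
import argparse
from modules.tokenizers import Tokenizer
from modules.dataloaders import R2DataLoader

# Placeholder arguments: only the fields accessed by Tokenizer / R2DataLoader / BaseDataset are set.
args = argparse.Namespace(
    dataset_name='iu_xray',
    image_dir='data/iu_xray/images/',
    ann_path='data/iu_xray/annotation.json',   # assumed annotation file under data/iu_xray
    threshold=3,
    max_seq_length=60,
    batch_size=16,
    num_workers=2,
)

tokenizer = Tokenizer(args)
train_loader = R2DataLoader(args, tokenizer, split='train', shuffle=True)

# Each batch: (image ids, stacked images, padded report ids, report masks) -- see collate_fn below.
image_ids, images, reports_ids, reports_masks = next(iter(train_loader))
print(images.shape, reports_ids.shape)  # e.g. (16, 2, 3, 224, 224) for IU X-Ray's two views
```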
46 | 47 | @staticmethod 48 | def collate_fn(data): 49 | image_id_batch, image_batch, report_ids_batch, report_masks_batch, seq_lengths_batch = zip(*data) 50 | image_batch = torch.stack(image_batch, 0) 51 | max_seq_length = max(seq_lengths_batch) 52 | 53 | target_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int) 54 | target_masks_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int) 55 | 56 | for i, report_ids in enumerate(report_ids_batch): 57 | target_batch[i, :len(report_ids)] = report_ids 58 | 59 | for i, report_masks in enumerate(report_masks_batch): 60 | target_masks_batch[i, :len(report_masks)] = report_masks 61 | 62 | return image_id_batch, image_batch, torch.LongTensor(target_batch), torch.FloatTensor(target_masks_batch) 63 | 64 | -------------------------------------------------------------------------------- /modules/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class BaseDataset(Dataset): 9 | def __init__(self, args, tokenizer, split, transform=None): 10 | self.image_dir = args.image_dir 11 | self.ann_path = args.ann_path 12 | self.max_seq_length = args.max_seq_length 13 | self.split = split 14 | self.tokenizer = tokenizer 15 | self.transform = transform 16 | self.ann = json.loads(open(self.ann_path, 'r').read()) 17 | # self.label_set = self.ann['label_set'] 18 | self.examples = self.ann[self.split] 19 | for i in range(len(self.examples)): 20 | self.examples[i]['ids'] = tokenizer(self.examples[i]['report'])[:self.max_seq_length] 21 | self.examples[i]['mask'] = [1] * len(self.examples[i]['ids']) 22 | # self.examples[i]['label'] = [0 if i == -1 else i for i in self.examples[i]['label']] 23 | 24 | def __len__(self): 25 | return len(self.examples) 26 | 27 | 28 | class IuxrayMultiImageDataset(BaseDataset): 29 | def __getitem__(self, idx): 30 | example = self.examples[idx] 31 | image_id = example['id'] 32 | image_path = example['image_path'] 33 | image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') 34 | image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB') 35 | if self.transform is not None: 36 | image_1 = self.transform(image_1) 37 | image_2 = self.transform(image_2) 38 | image = torch.stack((image_1, image_2), 0) 39 | report_ids = example['ids'] 40 | report_masks = example['mask'] 41 | # report_label = example['label'] 42 | seq_length = len(report_ids) 43 | sample = (image_id, image, report_ids, report_masks, seq_length) 44 | return sample 45 | 46 | 47 | class MimiccxrSingleImageDataset(BaseDataset): 48 | def __getitem__(self, idx): 49 | example = self.examples[idx] 50 | image_id = example['id'] 51 | image_path = example['image_path'] 52 | image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') 53 | image_id = os.path.join(self.image_dir, image_path[0]) 54 | if self.transform is not None: 55 | image = self.transform(image) 56 | report_ids = example['ids'] 57 | report_masks = example['mask'] 58 | # report_label = example['label'] 59 | seq_length = len(report_ids) 60 | sample = (image_id, image, report_ids, report_masks, seq_length) 61 | return sample 62 | -------------------------------------------------------------------------------- /modules/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from 
__future__ import division 3 | from __future__ import print_function 4 | 5 | import copy 6 | import math 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .att_model import pack_wrapper, AttModel 14 | 15 | 16 | def clones(module, N): 17 | "Produce N identical layers." 18 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 19 | 20 | 21 | def subsequent_mask(size): 22 | "Mask out subsequent positions." 23 | attn_shape = (1, size, size) 24 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 25 | return torch.from_numpy(subsequent_mask) == 0 26 | 27 | 28 | def attention(query, key, value, mask=None, dropout=None): 29 | "Compute 'Scaled Dot Product Attention'" 30 | d_k = query.size(-1) 31 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 32 | / math.sqrt(d_k) 33 | if mask is not None: 34 | scores = scores.masked_fill(mask == 0, float('-inf')) 35 | p_attn = F.softmax(scores, dim=-1) 36 | if dropout is not None: 37 | p_attn = dropout(p_attn) 38 | return torch.matmul(p_attn, value), p_attn 39 | 40 | 41 | class Transformer(nn.Module): 42 | """ 43 | A standard Encoder-Decoder architecture. Base for this and many 44 | other models. 45 | """ 46 | 47 | def __init__(self, encoder, decoder, src_embed, tgt_embed): 48 | super(Transformer, self).__init__() 49 | self.encoder = encoder 50 | self.decoder = decoder 51 | self.src_embed = src_embed 52 | self.tgt_embed = tgt_embed 53 | 54 | def forward(self, src, tgt, src_mask, tgt_mask): 55 | "Take in and process masked src and target sequences." 56 | return self.decode(self.encode(src, src_mask), src_mask, 57 | tgt, tgt_mask) 58 | 59 | def encode(self, src, src_mask): 60 | return self.encoder(self.src_embed(src), src_mask) 61 | 62 | def decode(self, memory, src_mask, tgt, tgt_mask, past=None): 63 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask, past=past) 64 | 65 | 66 | class Encoder(nn.Module): 67 | "Core encoder is a stack of N layers" 68 | 69 | def __init__(self, layer, N): 70 | super(Encoder, self).__init__() 71 | self.layers = clones(layer, N) 72 | self.norm = LayerNorm(layer.size) 73 | 74 | def forward(self, x, mask): 75 | "Pass the input (and mask) through each layer in turn." 76 | for layer in self.layers: 77 | x = layer(x, mask) 78 | return self.norm(x) 79 | 80 | 81 | class LayerNorm(nn.Module): 82 | "Construct a layernorm module (See citation for details)." 83 | 84 | def __init__(self, features, eps=1e-6): 85 | super(LayerNorm, self).__init__() 86 | self.a_2 = nn.Parameter(torch.ones(features)) 87 | self.b_2 = nn.Parameter(torch.zeros(features)) 88 | self.eps = eps 89 | 90 | def forward(self, x): 91 | mean = x.mean(-1, keepdim=True) 92 | std = x.std(-1, keepdim=True) 93 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 94 | 95 | 96 | class SublayerConnection(nn.Module): 97 | """ 98 | A residual connection followed by a layer norm. 99 | Note for code simplicity the norm is first as opposed to last. 100 | """ 101 | 102 | def __init__(self, size, dropout): 103 | super(SublayerConnection, self).__init__() 104 | self.norm = LayerNorm(size) 105 | self.dropout = nn.Dropout(dropout) 106 | 107 | def forward(self, x, sublayer): 108 | "Apply residual connection to any sublayer with the same size." 
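`subsequent_mask`, defined near the top of this file, produces the causal mask that `_prepare_feature_forward` later combines with the padding mask. A quick stand-alone check of its output (the function body is copied from above; the size is a toy value):

```python
import numpy as np
import torch

def subsequent_mask(size):
    "Mask out subsequent positions (copy of the helper above)."
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0

print(subsequent_mask(4)[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
# Row t lists the positions token t may attend to: itself and everything before it.
```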
109 | _x = sublayer(self.norm(x)) 110 | if type(_x) is tuple: # for multi-head attention that returns past 111 | return x + self.dropout(_x[0]), _x[1] 112 | return x + self.dropout(_x) 113 | 114 | 115 | class EncoderLayer(nn.Module): 116 | "Encoder is made up of self-attn and feed forward (defined below)" 117 | 118 | def __init__(self, size, self_attn, feed_forward, dropout): 119 | super(EncoderLayer, self).__init__() 120 | self.self_attn = self_attn 121 | self.feed_forward = feed_forward 122 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 123 | self.size = size 124 | 125 | def forward(self, x, mask): 126 | "Follow Figure 1 (left) for connections." 127 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 128 | return self.sublayer[1](x, self.feed_forward) 129 | 130 | 131 | class Decoder(nn.Module): 132 | "Generic N layer decoder with masking." 133 | 134 | def __init__(self, layer, N): 135 | super(Decoder, self).__init__() 136 | self.layers = clones(layer, N) 137 | self.norm = LayerNorm(layer.size) 138 | 139 | def forward(self, x, memory, src_mask, tgt_mask, past=None): 140 | if past is not None: 141 | present = [[], []] 142 | x = x[:, -1:] 143 | tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None 144 | past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0))) 145 | else: 146 | past = [None] * len(self.layers) 147 | for i, (layer, layer_past) in enumerate(zip(self.layers, past)): 148 | x = layer(x, memory, src_mask, tgt_mask, 149 | layer_past) 150 | if layer_past is not None: 151 | present[0].append(x[1][0]) 152 | present[1].append(x[1][1]) 153 | x = x[0] 154 | if past[0] is None: 155 | return self.norm(x) 156 | else: 157 | return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)] 158 | 159 | 160 | class DecoderLayer(nn.Module): 161 | "Decoder is made of self-attn, src-attn, and feed forward (defined below)" 162 | 163 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 164 | super(DecoderLayer, self).__init__() 165 | self.size = size 166 | self.self_attn = self_attn 167 | self.src_attn = src_attn 168 | self.feed_forward = feed_forward 169 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 170 | 171 | def forward(self, x, memory, src_mask, tgt_mask, layer_past=None): 172 | "Follow Figure 1 (right) for connections." 173 | m = memory 174 | if layer_past is None: 175 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 176 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 177 | return self.sublayer[2](x, self.feed_forward) 178 | else: 179 | present = [None, None] 180 | x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0])) 181 | x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1])) 182 | return self.sublayer[2](x, self.feed_forward), present 183 | 184 | 185 | class MultiHeadedAttention(nn.Module): 186 | def __init__(self, h, d_model, dropout=0.1): 187 | "Take in model size and number of heads." 188 | super(MultiHeadedAttention, self).__init__() 189 | assert d_model % h == 0 190 | # We assume d_v always equals d_k 191 | self.d_k = d_model // h 192 | self.h = h 193 | self.linears = clones(nn.Linear(d_model, d_model), 4) 194 | self.attn = None 195 | self.dropout = nn.Dropout(p=dropout) 196 | 197 | def forward(self, query, key, value, mask=None, layer_past=None): 198 | "Implements Figure 2" 199 | if mask is not None: 200 | # Same mask applied to all h heads. 
201 | mask = mask.unsqueeze(1) 202 | nbatches = query.size(0) 203 | 204 | # The past works differently here. For self attn, the query and key be updated incrementailly 205 | # For src_attn the past is fixed. 206 | 207 | # For src_attn, when the layer past is ready 208 | if layer_past is not None and layer_past.shape[2] == key.shape[ 209 | 1] > 1: # suppose memory size always greater than 1 210 | query = self.linears[0](query) 211 | key, value = layer_past[0], layer_past[1] 212 | present = torch.stack([key, value]) 213 | else: 214 | # 1) Do all the linear projections in batch from d_model => h x d_k 215 | query, key, value = \ 216 | [l(x) for l, x in zip(self.linears, (query, key, value))] 217 | 218 | # self attn + past OR the first time step of src attn 219 | if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1): 220 | past_key, past_value = layer_past[0], layer_past[1] 221 | key = torch.cat((past_key, key), dim=1) 222 | value = torch.cat((past_value, value), dim=1) 223 | present = torch.stack([key, value]) 224 | 225 | query, key, value = \ 226 | [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 227 | for x in [query, key, value]] 228 | 229 | # 2) Apply attention on all the projected vectors in batch. 230 | x, self.attn = attention(query, key, value, mask=mask, 231 | dropout=self.dropout) 232 | 233 | # 3) "Concat" using a view and apply a final linear. 234 | x = x.transpose(1, 2).contiguous() \ 235 | .view(nbatches, -1, self.h * self.d_k) 236 | if layer_past is not None: 237 | return self.linears[-1](x), present 238 | else: 239 | return self.linears[-1](x) 240 | 241 | 242 | class PositionwiseFeedForward(nn.Module): 243 | "Implements FFN equation." 244 | 245 | def __init__(self, d_model, d_ff, dropout=0.1): 246 | super(PositionwiseFeedForward, self).__init__() 247 | self.w_1 = nn.Linear(d_model, d_ff) 248 | self.w_2 = nn.Linear(d_ff, d_model) 249 | self.dropout = nn.Dropout(dropout) 250 | 251 | def forward(self, x): 252 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 253 | 254 | 255 | class Embeddings(nn.Module): 256 | def __init__(self, d_model, vocab): 257 | super(Embeddings, self).__init__() 258 | self.lut = nn.Embedding(vocab, d_model) 259 | self.d_model = d_model 260 | 261 | def forward(self, x): 262 | return self.lut(x) * math.sqrt(self.d_model) 263 | 264 | 265 | class PositionalEncoding(nn.Module): 266 | "Implement the PE function." 267 | 268 | def __init__(self, d_model, dropout, max_len=5000): 269 | super(PositionalEncoding, self).__init__() 270 | self.dropout = nn.Dropout(p=dropout) 271 | 272 | # Compute the positional encodings once in log space. 
273 | pe = torch.zeros(max_len, d_model) 274 | position = torch.arange(0, max_len).unsqueeze(1).float() 275 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 276 | -(math.log(10000.0) / d_model)) 277 | pe[:, 0::2] = torch.sin(position * div_term) 278 | pe[:, 1::2] = torch.cos(position * div_term) 279 | pe = pe.unsqueeze(0) 280 | self.register_buffer('pe', pe) 281 | 282 | def forward(self, x): 283 | x = x + self.pe[:, :x.size(1)] 284 | return self.dropout(x) 285 | 286 | 287 | class EncoderDecoder(AttModel): 288 | 289 | def make_model(self, tgt_vocab): 290 | c = copy.deepcopy 291 | attn = MultiHeadedAttention(self.num_heads, self.d_model) 292 | ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout) 293 | position = PositionalEncoding(self.d_model, self.dropout) 294 | model = Transformer( 295 | Encoder(EncoderLayer(self.d_model, c(attn), c(ff), self.dropout), self.num_layers), 296 | Decoder( 297 | DecoderLayer(self.d_model, c(attn), c(attn), c(ff), self.dropout), 298 | self.num_layers), 299 | lambda x: x, 300 | nn.Sequential(Embeddings(self.d_model, tgt_vocab), c(position))) 301 | for p in model.parameters(): 302 | if p.dim() > 1: 303 | nn.init.xavier_uniform_(p) 304 | return model 305 | 306 | def __init__(self, args, tokenizer): 307 | super(EncoderDecoder, self).__init__(args, tokenizer) 308 | self.args = args 309 | self.num_layers = args.num_layers 310 | self.d_model = args.d_model 311 | self.d_ff = args.d_ff 312 | self.num_heads = args.num_heads 313 | self.dropout = args.dropout 314 | 315 | tgt_vocab = self.vocab_size + 1 316 | 317 | self.model = self.make_model(tgt_vocab) 318 | self.logit = nn.Linear(args.d_model, tgt_vocab) 319 | 320 | def init_hidden(self, bsz): 321 | return [] 322 | 323 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 324 | 325 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) 326 | memory = self.model.encode(att_feats, att_masks) 327 | 328 | return fc_feats[..., :1], att_feats[..., :1], memory, att_masks 329 | 330 | def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): 331 | att_feats, att_masks = self.clip_att(att_feats, att_masks) 332 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) 333 | 334 | if att_masks is None: 335 | att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) 336 | att_masks = att_masks.unsqueeze(-2) 337 | 338 | if seq is not None: 339 | # crop the last one 340 | seq = seq[:, :-1] 341 | seq_mask = (seq.data > 0) 342 | seq_mask[:, 0] += True 343 | 344 | seq_mask = seq_mask.unsqueeze(-2) 345 | seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) 346 | else: 347 | seq_mask = None 348 | 349 | return att_feats, seq, att_masks, seq_mask 350 | 351 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 352 | 353 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) 354 | out = self.model(att_feats, seq, att_masks, seq_mask) 355 | outputs = F.log_softmax(self.logit(out), dim=-1) 356 | 357 | return outputs 358 | 359 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): 360 | """ 361 | state is the precomputed key/value. 
N_dec x seq_len x d_model 362 | Note: due to the layer norm, it's not equivalant to stateless, 363 | but it seems behaving similar 364 | """ 365 | # state is tokens + past 366 | if len(state) == 0: 367 | ys = it.unsqueeze(1) 368 | # basically empty state, just to let it know to return past 369 | # The second dim has to be batch_size, for beam search purpose 370 | past = [fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model), # self 371 | fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model)] # src 372 | # 2 for self attn, 2 for src attn 373 | else: 374 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) 375 | past = state[1:] 376 | out, past = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device), past=past) 377 | return out[:, -1], [ys.unsqueeze(0)] + past 378 | -------------------------------------------------------------------------------- /modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LanguageModelCriterion(nn.Module): 6 | def __init__(self): 7 | super(LanguageModelCriterion, self).__init__() 8 | 9 | def forward(self, input, target, mask): 10 | # truncate to the same size 11 | target = target[:, :input.size(1)] 12 | mask = mask[:, :input.size(1)] 13 | output = -input.gather(2, target.long().unsqueeze(2)).squeeze(2) * mask 14 | output = torch.sum(output) / torch.sum(mask) 15 | return output 16 | 17 | 18 | class LossWrapper(nn.Module): 19 | def __init__(self): 20 | super(LossWrapper, self).__init__() 21 | self.criterion = LanguageModelCriterion() 22 | self.criterion_mlc = nn.BCELoss() 23 | 24 | def forward(self, output, output_mlc, reports_ids, reports_masks, label): 25 | loss = self.criterion(output, reports_ids[:, 1:], reports_masks[:, 1:]).mean() 26 | loss_mlc = self.criterion_mlc(output_mlc, label) 27 | return loss + loss_mlc 28 | 29 | 30 | def compute_loss(output, reports_ids, reports_masks): 31 | criterion = LanguageModelCriterion() 32 | loss = criterion(output, reports_ids[:, 1:], reports_masks[:, 1:]).mean() 33 | return loss 34 | 35 | 36 | class RewardCriterion(nn.Module): 37 | def __init__(self): 38 | super(RewardCriterion, self).__init__() 39 | 40 | def forward(self, input, seq, reward): 41 | input = input.gather(2, seq.unsqueeze(2)).squeeze(2) 42 | 43 | input = input.reshape(-1) 44 | reward = reward.reshape(-1) 45 | mask = (seq > 0).to(input) 46 | mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1) 47 | output = - input * reward * mask 48 | output = torch.sum(output) / torch.sum(mask) 49 | 50 | return output 51 | -------------------------------------------------------------------------------- /modules/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score 2 | from pycocoevalcap.bleu.bleu import Bleu 3 | from pycocoevalcap.meteor import Meteor 4 | from pycocoevalcap.rouge import Rouge 5 | 6 | 7 | def compute_scores(gts, res): 8 | """ 9 | Performs the MS COCO evaluation using the Python 3 implementation (https://github.com/salaniz/pycocoevalcap) 10 | 11 | :param gts: Dictionary with the image ids and their gold captions, 12 | :param res: Dictionary with the image ids ant their generated captions 13 | :print: Evaluation score (the mean of the scores of all the instances) for each measure 14 | """ 15 | 16 | # Set up scorers 17 | 
scorers = [ 18 | (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]), 19 | (Meteor(), "METEOR"), 20 | (Rouge(), "ROUGE_L") 21 | ] 22 | eval_res = {} 23 | # Compute score for each metric 24 | for scorer, method in scorers: 25 | try: 26 | score, scores = scorer.compute_score(gts, res, verbose=0) 27 | except TypeError: 28 | score, scores = scorer.compute_score(gts, res) 29 | if type(method) == list: 30 | for sc, m in zip(score, method): 31 | eval_res[m] = sc 32 | else: 33 | eval_res[method] = score 34 | return eval_res 35 | 36 | 37 | def compute_mlc(gt, pred, label_set): 38 | res_mlc = {} 39 | avg_aucroc = 0 40 | for i, label in enumerate(label_set): 41 | res_mlc['AUCROC_' + label] = roc_auc_score(gt[:, i], pred[:, i]) 42 | avg_aucroc += res_mlc['AUCROC_' + label] 43 | res_mlc['AVG_AUCROC'] = avg_aucroc / len(label_set) 44 | 45 | res_mlc['F1_MACRO'] = f1_score(gt, pred, average="macro") 46 | res_mlc['F1_MICRO'] = f1_score(gt, pred, average="micro") 47 | res_mlc['RECALL_MACRO'] = recall_score(gt, pred, average="macro") 48 | res_mlc['RECALL_MICRO'] = recall_score(gt, pred, average="micro") 49 | res_mlc['PRECISION_MACRO'] = precision_score(gt, pred, average="macro") 50 | res_mlc['PRECISION_MICRO'] = precision_score(gt, pred, average="micro") 51 | 52 | return res_mlc 53 | 54 | 55 | class MetricWrapper(object): 56 | def __init__(self, label_set): 57 | self.label_set = label_set 58 | 59 | def __call__(self, gts, res, gts_mlc, res_mlc): 60 | eval_res = compute_scores(gts, res) 61 | eval_res_mlc = compute_mlc(gts_mlc, res_mlc, self.label_set) 62 | 63 | eval_res.update(**eval_res_mlc) 64 | return eval_res 65 | -------------------------------------------------------------------------------- /modules/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | 4 | 5 | def build_optimizer(args, model): 6 | ve_params = list(map(id, model.visual_extractor.parameters())) 7 | ed_params = filter(lambda x: id(x) not in ve_params, model.parameters()) 8 | optimizer = getattr(torch.optim, args.optim)( 9 | [{'params': model.visual_extractor.parameters(), 'lr': args.lr_ve}, 10 | {'params': ed_params, 'lr': args.lr_ed}], 11 | betas=args.adam_betas, 12 | eps=args.adam_eps, 13 | weight_decay=args.weight_decay, 14 | amsgrad=args.amsgrad 15 | ) 16 | return optimizer 17 | 18 | 19 | def build_lr_scheduler(args, optimizer): 20 | lr_scheduler = getattr(torch.optim.lr_scheduler, args.lr_scheduler)(optimizer, args.step_size, args.gamma) 21 | return lr_scheduler 22 | 23 | 24 | def set_lr(optimizer, lr): 25 | for group in optimizer.param_groups: 26 | group['lr'] = lr 27 | 28 | 29 | def get_lr(optimizer): 30 | for group in optimizer.param_groups: 31 | return group['lr'] 32 | 33 | 34 | class NoamOpt(object): 35 | "Optim wrapper that implements rate." 
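`compute_scores` above expects both arguments as dictionaries mapping an example id to a single-element list of report strings, which is exactly how the trainers build them from decoded outputs. A small usage sketch; the reports are made-up examples, and METEOR additionally needs Java available for the bundled `meteor-1.5.jar`:

```python
from modules.metrics import compute_scores

gts = {0: ['the heart size is normal .'],
       1: ['no acute cardiopulmonary abnormality .']}
res = {0: ['the heart is normal in size .'],
       1: ['no acute cardiopulmonary process .']}

scores = compute_scores(gts, res)
print(scores)  # {'BLEU_1': ..., 'BLEU_4': ..., 'METEOR': ..., 'ROUGE_L': ...}
```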
36 | 37 | def __init__(self, model_size, factor, warmup, optimizer): 38 | self.optimizer = optimizer 39 | self._step = 0 40 | self.warmup = warmup 41 | self.factor = factor 42 | self.model_size = model_size 43 | self._rate = 0 44 | 45 | def step(self): 46 | "Update parameters and rate" 47 | self._step += 1 48 | rate = self.rate() 49 | for p in self.optimizer.param_groups: 50 | p['lr'] = rate 51 | self._rate = rate 52 | self.optimizer.step() 53 | 54 | def rate(self, step=None): 55 | "Implement `lrate` above" 56 | if step is None: 57 | step = self._step 58 | return self.factor * \ 59 | (self.model_size ** (-0.5) * 60 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 61 | 62 | def __getattr__(self, name): 63 | return getattr(self.optimizer, name) 64 | 65 | def state_dict(self): 66 | state_dict = self.optimizer.state_dict() 67 | state_dict['_step'] = self._step 68 | return state_dict 69 | 70 | def load_state_dict(self, state_dict): 71 | if '_step' in state_dict: 72 | self._step = state_dict['_step'] 73 | del state_dict['_step'] 74 | self.optimizer.load_state_dict(state_dict) 75 | 76 | 77 | def get_std_opt(model, optim_func='adam', factor=1, warmup=2000): 78 | optim_func = dict(Adam=torch.optim.Adam, 79 | AdamW=torch.optim.AdamW)[optim_func] 80 | return NoamOpt(model.d_model, factor, warmup, 81 | optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 82 | 83 | 84 | def build_noamopt_optimizer(args, model): 85 | ve_optimizer = getattr(torch.optim, args.optim)( 86 | model.visual_extractor.parameters(), 87 | lr=0, 88 | betas=args.adam_betas, 89 | eps=args.adam_eps, 90 | weight_decay=args.weight_decay, 91 | amsgrad=args.amsgrad 92 | ) 93 | ed_optimizer = get_std_opt(model.encoder_decoder, optim_func=args.optim, factor=args.noamopt_factor, warmup=args.noamopt_warmup) 94 | return ve_optimizer, ed_optimizer 95 | 96 | 97 | class ReduceLROnPlateau(object): 98 | "Optim wrapper that implements rate." 
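`NoamOpt.rate` above is the standard Transformer warmup schedule, `lr = factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)`: the learning rate rises linearly for `warmup` steps and then decays with the inverse square root of the step. A stand-alone check with illustrative values (`get_std_opt` defaults to `factor=1`, `warmup=2000`; `model_size=512` is only an example here):

```python
def noam_rate(step, model_size=512, factor=1.0, warmup=2000):
    # Same formula as NoamOpt.rate above.
    return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

for step in (1, 500, 2000, 8000, 32000):
    print(step, noam_rate(step))
# step 1:     ~5e-07    (linear warmup region)
# step 2000:  ~9.9e-04  (peak, reached at the warmup boundary)
# step 8000:  ~4.9e-04  (inverse-square-root decay afterwards)
```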
99 | 100 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, 101 | threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08): 102 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode=mode, factor=factor, 103 | patience=patience, verbose=verbose, threshold=threshold, 104 | threshold_mode=threshold_mode, cooldown=cooldown, 105 | min_lr=min_lr, eps=eps) 106 | self.optimizer = optimizer 107 | self.current_lr = get_lr(optimizer) 108 | 109 | def step(self): 110 | "Update parameters and rate" 111 | self.optimizer.step() 112 | 113 | def scheduler_step(self, val): 114 | self.scheduler.step(val) 115 | self.current_lr = get_lr(self.optimizer) 116 | 117 | def state_dict(self): 118 | return {'current_lr': self.current_lr, 119 | 'scheduler_state_dict': self.scheduler.state_dict(), 120 | 'optimizer_state_dict': self.optimizer.state_dict()} 121 | 122 | def load_state_dict(self, state_dict): 123 | if 'current_lr' not in state_dict: 124 | # it's normal optimizer 125 | self.optimizer.load_state_dict(state_dict) 126 | set_lr(self.optimizer, self.current_lr) # use the lr fromt the option 127 | else: 128 | # it's a schduler 129 | self.current_lr = state_dict['current_lr'] 130 | self.scheduler.load_state_dict(state_dict['scheduler_state_dict']) 131 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) 132 | # current_lr is actually useless in this case 133 | 134 | def rate(self, step=None): 135 | "Implement `lrate` above" 136 | if step is None: 137 | step = self._step 138 | return self.factor * \ 139 | (self.model_size ** (-0.5) * 140 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 141 | 142 | def __getattr__(self, name): 143 | return getattr(self.optimizer, name) 144 | 145 | 146 | def build_plateau_optimizer(args, model): 147 | ve_optimizer = getattr(torch.optim, args.optim)( 148 | model.visual_extractor.parameters(), 149 | lr=args.lr_ve, 150 | betas=args.adam_betas, 151 | eps=args.adam_eps, 152 | weight_decay=args.weight_decay, 153 | amsgrad=args.amsgrad 154 | ) 155 | ve_optimizer = ReduceLROnPlateau(ve_optimizer, 156 | factor=args.reduce_on_plateau_factor, 157 | patience=args.reduce_on_plateau_patience) 158 | ed_optimizer = getattr(torch.optim, args.optim)( 159 | model.encoder_decoder.parameters(), 160 | lr=args.lr_ed, 161 | betas=args.adam_betas, 162 | eps=args.adam_eps, 163 | weight_decay=args.weight_decay, 164 | amsgrad=args.amsgrad 165 | ) 166 | ed_optimizer = ReduceLROnPlateau(ed_optimizer, 167 | factor=args.reduce_on_plateau_factor, 168 | patience=args.reduce_on_plateau_patience) 169 | 170 | return ve_optimizer, ed_optimizer -------------------------------------------------------------------------------- /modules/rewards.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from collections import OrderedDict 6 | 7 | import numpy as np 8 | 9 | import logging 10 | 11 | from pycocoevalcap.bleu.bleu import Bleu 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 14 | datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | Bleu_scorer = None 18 | 19 | 20 | def init_scorer(): 21 | global Bleu_scorer 22 | Bleu_scorer = Bleu_scorer or Bleu(4) 23 | 24 | 25 | def array_to_str(arr): 26 | out = '' 27 | for i in range(len(arr)): 28 | out += str(arr[i]) + ' ' 29 | if arr[i] == 0: 
30 | break 31 | return out.strip() 32 | 33 | 34 | def get_self_critical_reward(greedy_res, data_gts, gen_result): 35 | batch_size = len(data_gts) 36 | gen_result_size = gen_result.shape[0] 37 | seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img 38 | assert greedy_res.shape[0] == batch_size 39 | 40 | res = OrderedDict() 41 | gen_result = gen_result.data.cpu().numpy() 42 | greedy_res = greedy_res.data.cpu().numpy() 43 | for i in range(gen_result_size): 44 | res[i] = [array_to_str(gen_result[i])] 45 | for i in range(batch_size): 46 | res[gen_result_size + i] = [array_to_str(greedy_res[i])] 47 | 48 | gts = OrderedDict() 49 | data_gts = data_gts.cpu().numpy() 50 | for i in range(len(data_gts)): 51 | gts[i] = [array_to_str(data_gts[i])] 52 | res_ = [{'image_id': i, 'caption': res[i]} for i in range(len(res))] 53 | res__ = {i: res[i] for i in range(len(res_))} 54 | gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)} 55 | gts_.update({i + gen_result_size: gts[i] for i in range(batch_size)}) 56 | _, bleu_scores = Bleu_scorer.compute_score(gts_, res__, verbose = 0) 57 | bleu_scores = np.array(bleu_scores[3]) 58 | # logger.info('Bleu scores: {:.4f}.'.format(_[3])) 59 | scores = bleu_scores 60 | 61 | scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis] 62 | scores = scores.reshape(gen_result_size) 63 | 64 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) 65 | 66 | return rewards 67 | -------------------------------------------------------------------------------- /modules/tokenizers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | from collections import Counter 5 | 6 | 7 | class Tokenizer(object): 8 | def __init__(self, args): 9 | self.ann_path = args.ann_path 10 | self.threshold = args.threshold 11 | self.dataset_name = args.dataset_name 12 | if self.dataset_name == 'iu_xray': 13 | self.clean_report = self.clean_report_iu_xray 14 | else: 15 | self.clean_report = self.clean_report_mimic_cxr 16 | self.ann = json.loads(open(self.ann_path, 'r').read()) 17 | self.token2idx, self.idx2token = self.create_vocabulary() 18 | self.save_to_file(self.token2idx, self.idx2token) 19 | 20 | def save_to_file(self, token2idx, idx2token): 21 | with open(os.path.join('data', self.dataset_name, 'token2idx.json'), 'w') as f: 22 | json_str = json.dumps(token2idx) 23 | f.write(json_str) 24 | 25 | with open(os.path.join('data', self.dataset_name, 'idx2token.json'), 'w') as f: 26 | json_str = json.dumps(idx2token) 27 | f.write(json_str) 28 | 29 | def create_vocabulary(self): 30 | total_tokens = [] 31 | 32 | for example in self.ann['train']: 33 | tokens = self.clean_report(example['report']).split() 34 | for token in tokens: 35 | total_tokens.append(token) 36 | 37 | counter = Counter(total_tokens) 38 | vocab = [k for k, v in counter.items() if v >= self.threshold] + [''] 39 | vocab.sort() 40 | token2idx, idx2token = {}, {} 41 | for idx, token in enumerate(vocab): 42 | token2idx[token] = idx + 1 43 | idx2token[idx + 1] = token 44 | return token2idx, idx2token 45 | 46 | def clean_report_iu_xray(self, report): 47 | report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \ 48 | .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \ 49 | .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. 
') \ 50 | .strip().lower().split('. ') 51 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', ''). 52 | replace('\\', '').replace("'", '').strip().lower()) 53 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] 54 | report = ' . '.join(tokens) + ' .' 55 | return report 56 | 57 | def clean_report_mimic_cxr(self, report): 58 | report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \ 59 | .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \ 60 | .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \ 61 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \ 62 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \ 63 | .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \ 64 | .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ 65 | .strip().lower().split('. ') 66 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '') 67 | .replace('\\', '').replace("'", '').strip().lower()) 68 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] 69 | report = ' . '.join(tokens) + ' .' 70 | return report 71 | 72 | def get_token_by_id(self, id): 73 | return self.idx2token[id] 74 | 75 | def get_id_by_token(self, token): 76 | if token not in self.token2idx: 77 | return self.token2idx[''] 78 | return self.token2idx[token] 79 | 80 | def get_vocab_size(self): 81 | return len(self.token2idx) 82 | 83 | def __call__(self, report): 84 | tokens = self.clean_report(report).split() 85 | ids = [] 86 | for token in tokens: 87 | ids.append(self.get_id_by_token(token)) 88 | ids = [0] + ids + [0] 89 | return ids 90 | 91 | def decode(self, ids): 92 | txt = '' 93 | for i, idx in enumerate(ids): 94 | if idx > 0: 95 | if i >= 1: 96 | txt += ' ' 97 | txt += self.idx2token[idx] 98 | else: 99 | break 100 | return txt 101 | 102 | def decode_batch(self, ids_batch): 103 | out = [] 104 | for ids in ids_batch: 105 | out.append(self.decode(ids)) 106 | return out 107 | -------------------------------------------------------------------------------- /modules/trainer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import time 5 | from abc import abstractmethod 6 | 7 | import pandas as pd 8 | import torch 9 | from numpy import inf 10 | 11 | from modules.optimizers import set_lr, get_lr 12 | 13 | 14 | class BaseTrainer(object): 15 | def __init__(self, model, criterion, metric_ftns, ve_optimizer, ed_optimizer, args): 16 | self.args = args 17 | 18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 19 | datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 20 | self.logger = logging.getLogger(__name__) 21 | 22 | # setup GPU device if available, move model into configured device 23 | self.device, device_ids = self._prepare_device(args.n_gpu) 24 | self.model = model.to(self.device) 25 | if len(device_ids) > 1: 26 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 27 | 28 | self.criterion = criterion 29 | self.metric_ftns = metric_ftns 30 | self.ve_optimizer = ve_optimizer 31 | self.ed_optimizer = ed_optimizer 32 | 33 | self.epochs = self.args.epochs 34 | 
self.save_period = self.args.save_period 35 | 36 | self.mnt_mode = args.monitor_mode 37 | self.mnt_metric = 'val_' + args.monitor_metric 38 | self.mnt_metric_test = 'test_' + args.monitor_metric 39 | assert self.mnt_mode in ['min', 'max'] 40 | 41 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 42 | self.early_stop = getattr(self.args, 'early_stop', inf) 43 | 44 | self.start_epoch = 1 45 | self.checkpoint_dir = args.save_dir + '_seed_' + str(args.seed) 46 | 47 | self.best_recorder = {'val': {self.mnt_metric: self.mnt_best}, 48 | 'test': {self.mnt_metric_test: self.mnt_best}} 49 | 50 | if not os.path.exists(self.checkpoint_dir): 51 | os.makedirs(self.checkpoint_dir) 52 | 53 | if args.resume is not None: 54 | self._resume_checkpoint(args.resume) 55 | 56 | @abstractmethod 57 | def _train_epoch(self, epoch): 58 | raise NotImplementedError 59 | 60 | def train(self): 61 | not_improved_count = 0 62 | for epoch in range(self.start_epoch, self.epochs + 1): 63 | result = self._train_epoch(epoch) 64 | 65 | # save logged informations into log dict 66 | log = {'epoch': epoch} 67 | log.update(result) 68 | self._record_best(log) 69 | self._print_to_file(log) 70 | 71 | # print logged informations to the screen 72 | for key, value in log.items(): 73 | self.logger.info('\t{:15s}: {}'.format(str(key), value)) 74 | 75 | # evaluate model performance according to configured metric, save best checkpoint as model_best 76 | best = False 77 | if self.mnt_mode != 'off': 78 | try: 79 | # check whether model performance improved or not, according to specified metric(mnt_metric) 80 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 81 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 82 | except KeyError: 83 | self.logger.warning( 84 | "Warning: Metric '{}' is not found. " "Model performance monitoring is disabled.".format( 85 | self.mnt_metric)) 86 | self.mnt_mode = 'off' 87 | improved = False 88 | 89 | if improved: 90 | self.mnt_best = log[self.mnt_metric] 91 | not_improved_count = 0 92 | best = True 93 | else: 94 | not_improved_count += 1 95 | 96 | if not_improved_count > self.early_stop: 97 | self.logger.info("Validation performance didn\'t improve for {} epochs. 
" "Training stops.".format( 98 | self.early_stop)) 99 | break 100 | 101 | if epoch % self.save_period == 0: 102 | self._save_checkpoint(epoch, save_best=best) 103 | 104 | def _record_best(self, log): 105 | improved_val = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.best_recorder['val'][ 106 | self.mnt_metric]) or \ 107 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.best_recorder['val'][self.mnt_metric]) 108 | if improved_val: 109 | self.best_recorder['val'].update(log) 110 | 111 | improved_test = (self.mnt_mode == 'min' and log[self.mnt_metric_test] <= self.best_recorder['test'][ 112 | self.mnt_metric_test]) or \ 113 | (self.mnt_mode == 'max' and log[self.mnt_metric_test] >= self.best_recorder['test'][ 114 | self.mnt_metric_test]) 115 | if improved_test: 116 | self.best_recorder['test'].update(log) 117 | 118 | def _print_to_file(self, log): 119 | crt_time = time.asctime(time.localtime(time.time())) 120 | log['time'] = crt_time 121 | log['seed'] = self.args.seed 122 | log['best_model_from'] = 'train' 123 | 124 | if not os.path.exists(self.args.record_dir): 125 | os.makedirs(self.args.record_dir) 126 | record_path = os.path.join(self.args.record_dir, self.args.dataset_name + '_seed_' + str(self.args.seed) + '.csv') 127 | if not os.path.exists(record_path): 128 | record_table = pd.DataFrame() 129 | else: 130 | record_table = pd.read_csv(record_path) 131 | tmp_log = copy.deepcopy(log) 132 | tmp_log.update(**self.args.__dict__) 133 | record_table = record_table.append(tmp_log, ignore_index=True) 134 | record_table.to_csv(record_path, index=False) 135 | 136 | def _print_best(self): 137 | self.logger.info('Best results (w.r.t {}) in validation set:'.format(self.args.monitor_metric)) 138 | for key, value in self.best_recorder['val'].items(): 139 | self.logger.info('\t{:15s}: {}'.format(str(key), value)) 140 | 141 | self.logger.info('Best results (w.r.t {}) in test set:'.format(self.args.monitor_metric)) 142 | for key, value in self.best_recorder['test'].items(): 143 | self.logger.info('\t{:15s}: {}'.format(str(key), value)) 144 | 145 | def _get_learning_rate(self): 146 | lrs = list() 147 | lrs.append(self.ve_optimizer.state_dict()['param_groups'][0]['lr']) 148 | lrs.append(self.ed_optimizer.state_dict()['param_groups'][0]['lr']) 149 | 150 | return {'lr_visual_extractor': lrs[0], 'lr_encoder_decoder': lrs[1]} 151 | 152 | def _prepare_device(self, n_gpu_use): 153 | n_gpu = torch.cuda.device_count() 154 | if n_gpu_use > 0 and n_gpu == 0: 155 | self.logger.warning( 156 | "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.") 157 | n_gpu_use = 0 158 | if n_gpu_use > n_gpu: 159 | self.logger.warning( 160 | "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format( 161 | n_gpu_use, n_gpu)) 162 | n_gpu_use = n_gpu 163 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 164 | list_ids = list(range(n_gpu_use)) 165 | return device, list_ids 166 | 167 | def _save_checkpoint(self, epoch, save_best=False): 168 | state = { 169 | 'epoch': epoch, 170 | 'state_dict': self.model.state_dict(), 171 | 've_optimizer': self.ve_optimizer.state_dict(), 172 | 'ed_optimizer': self.ed_optimizer.state_dict(), 173 | 'monitor_best': self.mnt_best 174 | } 175 | filename = os.path.join(self.checkpoint_dir, 'current_checkpoint.pth') 176 | torch.save(state, filename) 177 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 178 | if save_best: 179 | best_path = os.path.join(self.checkpoint_dir, 
'model_best.pth') 180 | torch.save(state, best_path) 181 | self.logger.info("Saving current best: model_best.pth ...") 182 | 183 | def _resume_checkpoint(self, resume_path): 184 | resume_path = str(resume_path) 185 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 186 | checkpoint = torch.load(resume_path) 187 | self.start_epoch = checkpoint['epoch'] + 1 188 | self.mnt_best = checkpoint['monitor_best'] 189 | self.model.load_state_dict(checkpoint['state_dict']) 190 | self.ve_optimizer.load_state_dict(checkpoint['ve_optimizer']) 191 | self.ed_optimizer.load_state_dict(checkpoint['ed_optimizer']) 192 | 193 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) 194 | 195 | 196 | class Trainer(BaseTrainer): 197 | def __init__(self, model, criterion, metric_ftns, ve_optimizer, ed_optimizer, args, train_dataloader, 198 | val_dataloader, test_dataloader): 199 | super(Trainer, self).__init__(model, criterion, metric_ftns, ve_optimizer, ed_optimizer, args) 200 | # self.lr_scheduler = lr_scheduler 201 | self.train_dataloader = train_dataloader 202 | self.val_dataloader = val_dataloader 203 | self.test_dataloader = test_dataloader 204 | 205 | def _set_lr_ve(self, iteration): 206 | # if iteration < self.args.noamopt_warmup: 207 | # current_lr = self.args.lr_ve * (iteration + 1) / self.args.noamopt_warmup 208 | # set_lr(self.ve_optimizer, current_lr) 209 | current_lr_ed = get_lr(self.ed_optimizer) 210 | current_lr_ve = current_lr_ed * 0.1 211 | set_lr(self.ve_optimizer, current_lr_ve) 212 | 213 | def _set_lr_ed(self, iteration): 214 | if iteration < self.args.noamopt_warmup: 215 | current_lr = self.args.lr_ed * (iteration + 1) / self.args.noamopt_warmup 216 | set_lr(self.ed_optimizer, current_lr) 217 | 218 | def _train_epoch(self, epoch): 219 | 220 | self.logger.info('[{}/{}] Start to train in the training set.'.format(epoch, self.epochs)) 221 | train_loss = 0 222 | self.model.train() 223 | for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.train_dataloader): 224 | 225 | iteration = batch_idx + (epoch - 1) * len(self.train_dataloader) 226 | # self._set_lr_ed(iteration) 227 | self._set_lr_ve(iteration) 228 | 229 | images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(self.device), \ 230 | reports_masks.to(self.device) 231 | output = self.model(images, reports_ids, mode='train') 232 | loss = self.criterion(output, reports_ids, reports_masks) 233 | train_loss += loss.item() 234 | self.ve_optimizer.zero_grad() 235 | self.ed_optimizer.zero_grad() 236 | loss.backward() 237 | self.ve_optimizer.step() 238 | self.ed_optimizer.step() 239 | if batch_idx % self.args.log_period == 0: 240 | lrs = self._get_learning_rate() 241 | self.logger.info('[{}/{}] Step: {}/{}, Training Loss: {:.5f}, LR (ve): {:.5f}, LR (ed): {:5f}.' 
242 | .format(epoch, self.epochs, batch_idx, len(self.train_dataloader), 243 | train_loss / (batch_idx + 1), lrs['lr_visual_extractor'], 244 | lrs['lr_encoder_decoder'])) 245 | 246 | log = {'train_loss': train_loss / len(self.train_dataloader)} 247 | 248 | self.logger.info('[{}/{}] Start to evaluate in the validation set.'.format(epoch, self.epochs)) 249 | self.model.eval() 250 | with torch.no_grad(): 251 | val_loss = 0 252 | val_gts, val_res = [], [] 253 | for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.val_dataloader): 254 | images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( 255 | self.device), reports_masks.to(self.device) 256 | 257 | # ****** Compute Loss ****** 258 | images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(self.device), \ 259 | reports_masks.to(self.device) 260 | output = self.model(images, reports_ids, mode='train') 261 | loss = self.criterion(output, reports_ids, reports_masks) 262 | val_loss += loss.item() 263 | # ****** Compute Loss ****** 264 | 265 | output, _ = self.model(images, mode='sample') 266 | reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) 267 | ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) 268 | val_res.extend(reports) 269 | val_gts.extend(ground_truths) 270 | 271 | for id, re, gt in zip(images_id, reports, ground_truths): 272 | print(id) 273 | print('[Generated]: {}'.format(re)) 274 | print('[Ground Truth]: {}'.format(gt)) 275 | 276 | val_met = self.metric_ftns({i: [gt] for i, gt in enumerate(val_gts)}, 277 | {i: [re] for i, re in enumerate(val_res)}) 278 | log.update(**{'val_' + k: v for k, v in val_met.items()}) 279 | log.update(**{'val_loss': val_loss / len(self.val_dataloader)}) 280 | 281 | self.logger.info('[{}/{}] Start to evaluate in the test set.'.format(epoch, self.epochs)) 282 | self.model.eval() 283 | with torch.no_grad(): 284 | test_gts, test_res = [], [] 285 | for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.test_dataloader): 286 | images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( 287 | self.device), reports_masks.to(self.device) 288 | output, _ = self.model(images, mode='sample') 289 | reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) 290 | ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) 291 | test_res.extend(reports) 292 | test_gts.extend(ground_truths) 293 | 294 | for id, re, gt in zip(images_id, reports, ground_truths): 295 | print(id) 296 | print('[Generated]: {}'.format(re)) 297 | print('[Ground Truth]: {}'.format(gt)) 298 | 299 | test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)}, 300 | {i: [re] for i, re in enumerate(test_res)}) 301 | log.update(**{'test_' + k: v for k, v in test_met.items()}) 302 | 303 | log.update(**self._get_learning_rate()) 304 | # self.lr_scheduler.step() 305 | 306 | return log 307 | -------------------------------------------------------------------------------- /modules/trainer_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from abc import abstractmethod 4 | 5 | import time 6 | import torch 7 | import pandas as pd 8 | import numpy as np 9 | from numpy import inf 10 | 11 | 12 | class BaseTrainer(object): 13 | def __init__(self, model, criterion, metric_ftns, optimizer, args): 14 | self.args = args 15 | 16 | # setup GPU device if available, move model into configured 
device 17 | self.device, device_ids = self._prepare_device(args.n_gpu) 18 | self.model = model.to(self.device) 19 | if len(device_ids) > 1: 20 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 21 | 22 | self.criterion = criterion 23 | self.metric_ftns = metric_ftns 24 | self.optimizer = optimizer 25 | 26 | self.epochs = self.args.epochs 27 | self.save_period = self.args.save_period 28 | 29 | self.mnt_mode = args.monitor_mode 30 | self.mnt_metric = 'val_' + args.monitor_metric 31 | self.mnt_metric_test = 'test_' + args.monitor_metric 32 | assert self.mnt_mode in ['min', 'max'] 33 | 34 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 35 | self.early_stop = getattr(self.args, 'early_stop', inf) 36 | 37 | self.start_epoch = 1 38 | self.checkpoint_dir = args.save_dir 39 | 40 | if not os.path.exists(self.checkpoint_dir): 41 | os.makedirs(self.checkpoint_dir) 42 | 43 | if args.resume is not None: 44 | self._resume_checkpoint(args.resume) 45 | 46 | self.best_recorder = {'val': {self.mnt_metric: self.mnt_best}, 47 | 'test': {self.mnt_metric_test: self.mnt_best}} 48 | 49 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 50 | datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 51 | self.logger = logging.getLogger(__name__) 52 | 53 | @abstractmethod 54 | def _train_epoch(self, epoch): 55 | raise NotImplementedError 56 | 57 | def train(self): 58 | not_improved_count = 0 59 | for epoch in range(self.start_epoch, self.epochs + 1): 60 | result = self._train_epoch(epoch) 61 | 62 | # save logged informations into log dict 63 | log = {'epoch': epoch} 64 | log.update(result) 65 | self._record_best(log) 66 | self._print_to_file(log) 67 | 68 | # print logged informations to the screen 69 | for key, value in log.items(): 70 | self.logger.info('\t{:15s}: {}'.format(str(key), value)) 71 | 72 | # evaluate model performance according to configured metric, save best checkpoint as model_best 73 | best = False 74 | if self.mnt_mode != 'off': 75 | try: 76 | # check whether model performance improved or not, according to specified metric(mnt_metric) 77 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 78 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 79 | except KeyError: 80 | self.logger.warning("Warning: Metric '{}' is not found. " "Model performance monitoring is disabled.".format( 81 | self.mnt_metric)) 82 | self.mnt_mode = 'off' 83 | improved = False 84 | 85 | if improved: 86 | self.mnt_best = log[self.mnt_metric] 87 | not_improved_count = 0 88 | best = True 89 | else: 90 | not_improved_count += 1 91 | 92 | if not_improved_count > self.early_stop: 93 | self.logger.info("Validation performance didn\'t improve for {} epochs. 
" "Training stops.".format( 94 | self.early_stop)) 95 | break 96 | 97 | if epoch % self.save_period == 0: 98 | self._save_checkpoint(epoch, save_best=best) 99 | self._print_best() 100 | self._print_best_to_file() 101 | 102 | def _print_to_file(self, log): 103 | crt_time = time.asctime(time.localtime(time.time())) 104 | log['time'] = crt_time 105 | log['seed'] = self.args.seed 106 | log['best_model_from'] = 'train' 107 | 108 | if not os.path.exists(self.args.record_dir): 109 | os.makedirs(self.args.record_dir) 110 | record_path = os.path.join(self.args.record_dir, self.args.dataset_name+'.csv') 111 | if not os.path.exists(record_path): 112 | record_table = pd.DataFrame() 113 | else: 114 | record_table = pd.read_csv(record_path) 115 | record_table = record_table.append(log, ignore_index=True) 116 | record_table.to_csv(record_path, index=False) 117 | 118 | def _print_best_to_file(self): 119 | crt_time = time.asctime(time.localtime(time.time())) 120 | self.best_recorder['val']['time'] = crt_time 121 | self.best_recorder['test']['time'] = crt_time 122 | self.best_recorder['val']['seed'] = self.args.seed 123 | self.best_recorder['test']['seed'] = self.args.seed 124 | self.best_recorder['val']['best_model_from'] = 'val' 125 | self.best_recorder['test']['best_model_from'] = 'test' 126 | 127 | if not os.path.exists(self.args.record_dir): 128 | os.makedirs(self.args.record_dir) 129 | record_path = os.path.join(self.args.record_dir, self.args.dataset_name+'.csv') 130 | if not os.path.exists(record_path): 131 | record_table = pd.DataFrame() 132 | else: 133 | record_table = pd.read_csv(record_path) 134 | record_table = record_table.append(self.best_recorder['val'], ignore_index=True) 135 | record_table = record_table.append(self.best_recorder['test'], ignore_index=True) 136 | record_table.to_csv(record_path, index=False) 137 | 138 | def _prepare_device(self, n_gpu_use): 139 | n_gpu = torch.cuda.device_count() 140 | if n_gpu_use > 0 and n_gpu == 0: 141 | self.logger.warning("Warning: There\'s no GPU available on this machine," "training will be performed on CPU.") 142 | n_gpu_use = 0 143 | if n_gpu_use > n_gpu: 144 | self.logger.warning( 145 | "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format( 146 | n_gpu_use, n_gpu)) 147 | n_gpu_use = n_gpu 148 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 149 | list_ids = list(range(n_gpu_use)) 150 | return device, list_ids 151 | 152 | def _save_checkpoint(self, epoch, save_best=False): 153 | state = { 154 | 'epoch': epoch, 155 | 'state_dict': self.model.state_dict(), 156 | 'optimizer': self.optimizer.state_dict(), 157 | 'monitor_best': self.mnt_best 158 | } 159 | filename = os.path.join(self.checkpoint_dir, 'current_checkpoint.pth') 160 | torch.save(state, filename) 161 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 162 | if save_best: 163 | best_path = os.path.join(self.checkpoint_dir, 'model_best.pth') 164 | torch.save(state, best_path) 165 | self.logger.info("Saving current best: model_best.pth ...") 166 | 167 | def _resume_checkpoint(self, resume_path): 168 | resume_path = str(resume_path) 169 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 170 | checkpoint = torch.load(resume_path) 171 | self.start_epoch = checkpoint['epoch'] + 1 172 | self.mnt_best = checkpoint['monitor_best'] 173 | self.model.load_state_dict(checkpoint['state_dict']) 174 | self.optimizer.load_state_dict(checkpoint['optimizer']) 175 | 176 | self.logger.info("Checkpoint loaded. 
Resume training from epoch {}".format(self.start_epoch)) 177 | 178 | def _record_best(self, log): 179 | improved_val = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.best_recorder['val'][ 180 | self.mnt_metric]) or \ 181 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.best_recorder['val'][self.mnt_metric]) 182 | if improved_val: 183 | self.best_recorder['val'].update(log) 184 | 185 | improved_test = (self.mnt_mode == 'min' and log[self.mnt_metric_test] <= self.best_recorder['test'][ 186 | self.mnt_metric_test]) or \ 187 | (self.mnt_mode == 'max' and log[self.mnt_metric_test] >= self.best_recorder['test'][ 188 | self.mnt_metric_test]) 189 | if improved_test: 190 | self.best_recorder['test'].update(log) 191 | 192 | def _print_best(self): 193 | self.logger.info('Best results (w.r.t {}) in validation set:'.format(self.args.monitor_metric)) 194 | for key, value in self.best_recorder['val'].items(): 195 | self.logger.info('\t{:15s}: {}'.format(str(key), value)) 196 | 197 | self.logger.info('Best results (w.r.t {}) in test set:'.format(self.args.monitor_metric)) 198 | for key, value in self.best_recorder['test'].items(): 199 | self.logger.info('\t{:15s}: {}'.format(str(key), value)) 200 | 201 | 202 | class Trainer(BaseTrainer): 203 | def __init__(self, model, criterion, metric_ftns, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, 204 | test_dataloader): 205 | super(Trainer, self).__init__(model, criterion, metric_ftns, optimizer, args) 206 | self.lr_scheduler = lr_scheduler 207 | self.train_dataloader = train_dataloader 208 | self.val_dataloader = val_dataloader 209 | self.test_dataloader = test_dataloader 210 | 211 | def _train_epoch(self, epoch): 212 | 213 | self.logger.info('[{}/{}] Start to train in the training set.'.format(epoch, self.epochs)) 214 | train_loss = 0 215 | self.model.train() 216 | for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.train_dataloader): 217 | images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(self.device),\ 218 | reports_masks.to(self.device) 219 | output = self.model(images, reports_ids, mode='train') 220 | loss = self.criterion(output, reports_ids, reports_masks) 221 | train_loss += loss.item() 222 | self.optimizer.zero_grad() 223 | loss.backward() 224 | torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.1) 225 | self.optimizer.step() 226 | if batch_idx % self.args.log_period == 0: 227 | self.logger.info('[{}/{}] Step: {}/{}, Training Loss: {:.4f}.'.format(epoch, self.epochs, batch_idx, len(self.train_dataloader), train_loss / (batch_idx+1))) 228 | log = {'train_loss': train_loss / len(self.train_dataloader)} 229 | 230 | self.logger.info('[{}/{}] Start to evaluate in the validation set.'.format(epoch, self.epochs)) 231 | self.model.eval() 232 | with torch.no_grad(): 233 | val_gts, val_res = [], [] 234 | for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.val_dataloader): 235 | images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( 236 | self.device), reports_masks.to(self.device) 237 | output, _ = self.model(images, mode='sample') 238 | reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) 239 | ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) 240 | # import pdb; pdb.set_trace() 241 | val_res.extend(reports) 242 | val_gts.extend(ground_truths) 243 | val_met = self.metric_ftns({i: [gt] for i, gt in enumerate(val_gts)}, 244 | {i: [re] for i, re in enumerate(val_res)}) 
245 | log.update(**{'val_' + k: v for k, v in val_met.items()}) 246 | 247 | self.logger.info('[{}/{}] Start to evaluate in the test set.'.format(epoch, self.epochs)) 248 | self.model.eval() 249 | with torch.no_grad(): 250 | test_gts, test_res = [], [] 251 | for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.test_dataloader): 252 | images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( 253 | self.device), reports_masks.to(self.device) 254 | output, _ = self.model(images, mode='sample') 255 | reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) 256 | ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) 257 | test_res.extend(reports) 258 | test_gts.extend(ground_truths) 259 | 260 | test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)}, 261 | {i: [re] for i, re in enumerate(test_res)}) 262 | log.update(**{'test_' + k: v for k, v in test_met.items()}) 263 | 264 | self.lr_scheduler.step() 265 | 266 | return log 267 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def penalty_builder(penalty_config): 5 | if penalty_config == '': 6 | return lambda x, y: y 7 | pen_type, alpha = penalty_config.split('_') 8 | alpha = float(alpha) 9 | if pen_type == 'wu': 10 | return lambda x, y: length_wu(x, y, alpha) 11 | if pen_type == 'avg': 12 | return lambda x, y: length_average(x, y, alpha) 13 | 14 | 15 | def length_wu(length, logprobs, alpha=0.): 16 | """ 17 | NMT length re-ranking score from 18 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 19 | """ 20 | 21 | modifier = (((5 + length) ** alpha) / 22 | ((5 + 1) ** alpha)) 23 | return logprobs / modifier 24 | 25 | 26 | def length_average(length, logprobs, alpha=0.): 27 | """ 28 | Returns the average probability of tokens in a sequence. 29 | """ 30 | return logprobs / length 31 | 32 | 33 | def split_tensors(n, x): 34 | if torch.is_tensor(x): 35 | assert x.shape[0] % n == 0 36 | x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1) 37 | elif type(x) is list or type(x) is tuple: 38 | x = [split_tensors(n, _) for _ in x] 39 | elif x is None: 40 | x = [None] * n 41 | return x 42 | 43 | 44 | def repeat_tensors(n, x): 45 | """ 46 | For a tensor of size Bx..., we repeat it n times, and make it Bnx... 47 | For collections, do nested repeat 48 | """ 49 | if torch.is_tensor(x): 50 | x = x.unsqueeze(1) # Bx1x... 51 | x = x.expand(-1, n, *([-1] * len(x.shape[2:]))) # Bxnx... 52 | x = x.reshape(x.shape[0] * n, *x.shape[2:]) # Bnx... 
53 | elif type(x) is list or type(x) is tuple: 54 | x = [repeat_tensors(n, _) for _ in x] 55 | return x 56 | -------------------------------------------------------------------------------- /modules/visual_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | class VisualExtractor(nn.Module): 7 | def __init__(self, args): 8 | super(VisualExtractor, self).__init__() 9 | self.visual_extractor = args.visual_extractor 10 | self.pretrained = args.visual_extractor_pretrained 11 | model = getattr(models, self.visual_extractor)(pretrained=self.pretrained) 12 | modules = list(model.children())[:-2] 13 | self.model = nn.Sequential(*modules) 14 | self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0) 15 | 16 | def forward(self, images): 17 | patch_feats = self.model(images) 18 | avg_feats = self.avg_fnt(patch_feats).squeeze().reshape(-1, patch_feats.size(1)) 19 | batch_size, feat_size, _, _ = patch_feats.shape 20 | patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1) 21 | return patch_feats, avg_feats 22 | -------------------------------------------------------------------------------- /plot.sh: -------------------------------------------------------------------------------- 1 | python help.py \ 2 | --image_dir data/mimic_cxr/images/ \ 3 | --ann_path data/mimic_cxr/annotation.json \ 4 | --dataset_name mimic_cxr \ 5 | --max_seq_length 100 \ 6 | --threshold 10 \ 7 | --batch_size 16 \ 8 | --epochs 30 \ 9 | --step_size 1 \ 10 | --gamma 0.8 -------------------------------------------------------------------------------- /pycocoevalcap/README.md: -------------------------------------------------------------------------------- 1 | Microsoft COCO Caption Evaluation Tools
2 | --- 3 | 4 | Modified the code to work with Python 3.
5 | 6 | ### Requirements 7 | * Python 3.x 8 | * Java 1.8 9 | * pycocotools 10 | 11 | --- 12 | 13 | ### Tested on 14 | * Windows 10, Python 3.5. 15 | 16 | --- 17 | ### To fix Windows JVM memory error:
18 | Add the following under System Variables:
19 |     Variable name : _JAVA_OPTIONS
20 |     Variable value : -Xmx1024M
21 | 22 | --- 23 | Original code : https://github.com/tylin/coco-caption
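
As a quick smoke test, the individual scorers can also be driven directly with small dictionaries of tokenized sentences (a minimal sketch run from the project root; the example reports below are invented):

```python
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.rouge.rouge import Rouge

# Both scorers take {id: [sentence, ...]} dicts; each prediction list must hold exactly one string.
gts = {0: ['the heart is normal in size .'], 1: ['no acute cardiopulmonary process .']}
res = {0: ['the heart size is normal .'],    1: ['no acute cardiopulmonary process .']}

bleu, _ = Bleu(4).compute_score(gts, res)    # [BLEU-1, BLEU-2, BLEU-3, BLEU-4]
rouge, _ = Rouge().compute_score(gts, res)   # mean ROUGE-L over all ids
print(bleu, rouge)
```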
24 | -------------------------------------------------------------------------------- /pycocoevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' -------------------------------------------------------------------------------- /pycocoevalcap/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /pycocoevalcap/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' -------------------------------------------------------------------------------- /pycocoevalcap/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | # Last modified : Wed 22 May 2019 08:10:00 PM EDT 12 | # By Sabarish Sivanath 13 | # To support Python 3 14 | 15 | from .bleu_scorer import BleuScorer 16 | 17 | 18 | class Bleu: 19 | def __init__(self, n=4): 20 | # default compute Blue score up to 4 21 | self._n = n 22 | self._hypo_for_image = {} 23 | self.ref_for_image = {} 24 | 25 | def compute_score(self, gts, res, score_option = 'closest', verbose = 1): 26 | ''' 27 | Inputs: 28 | gts - ground truths 29 | res - predictions 30 | score_option - {shortest, closest, average} 31 | verbose - 1 or 0 32 | Outputs: 33 | Blue scores 34 | ''' 35 | assert(gts.keys() == res.keys()) 36 | imgIds = gts.keys() 37 | 38 | bleu_scorer = BleuScorer(n=self._n) 39 | for id in imgIds: 40 | hypo = res[id] 41 | ref = gts[id] 42 | 43 | # Sanity check. 
44 | assert(type(hypo) is list) 45 | assert(len(hypo) == 1) 46 | assert(type(ref) is list) 47 | #assert(len(ref) >= 1) 48 | 49 | bleu_scorer += (hypo[0], ref) 50 | 51 | score, scores = bleu_scorer.compute_score(option = score_option, verbose =verbose) 52 | 53 | # return (bleu, bleu_info) 54 | return score, scores 55 | 56 | def method(self): 57 | return "Bleu" 58 | -------------------------------------------------------------------------------- /pycocoevalcap/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | # bleu_scorer.py 2 | # David Chiang 3 | 4 | # Copyright (c) 2004-2006 University of Maryland. All rights 5 | # reserved. Do not redistribute without permission from the 6 | # author. Not for commercial use. 7 | 8 | # Modified by: 9 | # Hao Fang 10 | # Tsung-Yi Lin 11 | 12 | # Last modified : Wed 22 May 2019 08:10:00 PM EDT 13 | # By Sabarish Sivanath 14 | # To support Python 3 15 | 16 | '''Provides: 17 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 18 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 19 | ''' 20 | 21 | import copy 22 | import sys, math, re 23 | from collections import defaultdict 24 | 25 | def precook(s, n=4, out=False): 26 | """Takes a string as input and returns an object that can be given to 27 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 28 | can take string arguments as well.""" 29 | words = s.split() 30 | counts = defaultdict(int) 31 | for k in range(1,n+1): 32 | for i in range(len(words)-k+1): 33 | ngram = tuple(words[i:i+k]) 34 | counts[ngram] += 1 35 | return (len(words), counts) 36 | 37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 38 | '''Takes a list of reference sentences for a single segment 39 | and returns an object that encapsulates everything that BLEU 40 | needs to know about them.''' 41 | 42 | reflen = [] 43 | maxcounts = {} 44 | for ref in refs: 45 | rl, counts = precook(ref, n) 46 | reflen.append(rl) 47 | for (ngram,count) in counts.items(): 48 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 49 | 50 | # Calculate effective reference sentence length. 51 | if eff == "shortest": 52 | reflen = min(reflen) 53 | elif eff == "average": 54 | reflen = float(sum(reflen))/len(reflen) 55 | 56 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 57 | 58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 59 | 60 | return (reflen, maxcounts) 61 | 62 | def cook_test(test, refs , eff=None, n=4): 63 | '''Takes a test sentence and returns an object that 64 | encapsulates everything that BLEU needs to know about it.''' 65 | 66 | reflen = refs[0] 67 | refmaxcounts = refs[1] 68 | 69 | testlen, counts = precook(test, n, True) 70 | 71 | result = {} 72 | 73 | # Calculate effective reference sentence length. 74 | 75 | if eff == "closest": 76 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 77 | else: ## i.e., "average" or "shortest" or None 78 | result["reflen"] = reflen 79 | 80 | result["testlen"] = testlen 81 | 82 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 83 | 84 | result['correct'] = [0]*n 85 | for (ngram, count) in counts.items(): 86 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 87 | 88 | return result 89 | 90 | class BleuScorer(object): 91 | """Bleu scorer. 
92 | """ 93 | 94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 95 | # special_reflen is used in oracle (proportional effective ref len for a node). 96 | 97 | def copy(self): 98 | ''' copy the refs.''' 99 | new = BleuScorer(n=self.n) 100 | new.ctest = copy.copy(self.ctest) 101 | new.crefs = copy.copy(self.crefs) 102 | new._score = None 103 | return new 104 | 105 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 106 | ''' singular instance ''' 107 | 108 | self.n = n 109 | self.crefs = [] 110 | self.ctest = [] 111 | self.cook_append(test, refs) 112 | self.special_reflen = special_reflen 113 | 114 | def cook_append(self, test, refs): 115 | '''called by constructor and __iadd__ to avoid creating new instances.''' 116 | 117 | if refs is not None: 118 | self.crefs.append(cook_refs(refs)) 119 | if test is not None: 120 | cooked_test = cook_test(test, self.crefs[-1]) 121 | self.ctest.append(cooked_test) ## N.B.: -1 122 | else: 123 | self.ctest.append(None) # lens of crefs and ctest have to match 124 | 125 | self._score = None ## need to recompute 126 | 127 | def ratio(self, option=None): 128 | self.compute_score(option=option) 129 | return self._ratio 130 | 131 | def score_ratio(self, option=None): 132 | '''return (bleu, len_ratio) pair''' 133 | return (self.fscore(option=option), self.ratio(option=option)) 134 | 135 | def score_ratio_str(self, option=None): 136 | return "%.4f (%.2f)" % self.score_ratio(option) 137 | 138 | def reflen(self, option=None): 139 | self.compute_score(option=option) 140 | return self._reflen 141 | 142 | def testlen(self, option=None): 143 | self.compute_score(option=option) 144 | return self._testlen 145 | 146 | def retest(self, new_test): 147 | if type(new_test) is str: 148 | new_test = [new_test] 149 | assert len(new_test) == len(self.crefs), new_test 150 | self.ctest = [] 151 | for t, rs in zip(new_test, self.crefs): 152 | self.ctest.append(cook_test(t, rs)) 153 | self._score = None 154 | 155 | return self 156 | 157 | def rescore(self, new_test): 158 | ''' replace test(s) with new test(s), and returns the new score.''' 159 | 160 | return self.retest(new_test).compute_score() 161 | 162 | def size(self): 163 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 164 | return len(self.crefs) 165 | 166 | def __iadd__(self, other): 167 | '''add an instance (e.g., from another sentence).''' 168 | 169 | if type(other) is tuple: 170 | ## avoid creating new BleuScorer instances 171 | self.cook_append(other[0], other[1]) 172 | else: 173 | assert self.compatible(other), "incompatible BLEUs." 
174 | self.ctest.extend(other.ctest) 175 | self.crefs.extend(other.crefs) 176 | self._score = None ## need to recompute 177 | 178 | return self 179 | 180 | def compatible(self, other): 181 | return isinstance(other, BleuScorer) and self.n == other.n 182 | 183 | def single_reflen(self, option="average"): 184 | return self._single_reflen(self.crefs[0][0], option) 185 | 186 | def _single_reflen(self, reflens, option=None, testlen=None): 187 | 188 | if option == "shortest": 189 | reflen = min(reflens) 190 | elif option == "average": 191 | reflen = float(sum(reflens))/len(reflens) 192 | elif option == "closest": 193 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 194 | else: 195 | assert False, "unsupported reflen option %s" % option 196 | 197 | return reflen 198 | 199 | def recompute_score(self, option=None, verbose=0): 200 | self._score = None 201 | return self.compute_score(option, verbose) 202 | 203 | def compute_score(self, option=None, verbose=0): 204 | n = self.n 205 | small = 1e-9 206 | tiny = 1e-15 ## so that if guess is 0 still return 0 207 | bleu_list = [[] for _ in range(n)] 208 | 209 | if self._score is not None: 210 | return self._score 211 | 212 | if option is None: 213 | option = "average" if len(self.crefs) == 1 else "closest" 214 | 215 | self._testlen = 0 216 | self._reflen = 0 217 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 218 | 219 | # for each sentence 220 | for comps in self.ctest: 221 | testlen = comps['testlen'] 222 | self._testlen += testlen 223 | 224 | if self.special_reflen is None: ## need computation 225 | reflen = self._single_reflen(comps['reflen'], option, testlen) 226 | else: 227 | reflen = self.special_reflen 228 | 229 | self._reflen += reflen 230 | 231 | for key in ['guess','correct']: 232 | for k in range(n): 233 | totalcomps[key][k] += comps[key][k] 234 | 235 | # append per image bleu score 236 | bleu = 1. 237 | for k in range(n): 238 | bleu *= (float(comps['correct'][k]) + tiny) \ 239 | /(float(comps['guess'][k]) + small) 240 | bleu_list[k].append(bleu ** (1./(k+1))) 241 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 242 | if ratio < 1: 243 | for k in range(n): 244 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 245 | 246 | if verbose > 1: 247 | print(comps, reflen) 248 | 249 | totalcomps['reflen'] = self._reflen 250 | totalcomps['testlen'] = self._testlen 251 | 252 | bleus = [] 253 | bleu = 1. 
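# Corpus-level BLEU-k: multiply the smoothed n-gram precisions accumulated in totalcomps,
# take the k-th root, then apply the brevity penalty below when the hypotheses are shorter
# than the references (ratio < 1).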
254 | for k in range(n): 255 | bleu *= float(totalcomps['correct'][k] + tiny) \ 256 | / (totalcomps['guess'][k] + small) 257 | bleus.append(bleu ** (1./(k+1))) 258 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 259 | if ratio < 1: 260 | for k in range(n): 261 | bleus[k] *= math.exp(1 - 1/ratio) 262 | 263 | if verbose > 0: 264 | print(totalcomps) 265 | print("ratio:", ratio) 266 | 267 | self._score = bleus 268 | return self._score, bleu_list 269 | -------------------------------------------------------------------------------- /pycocoevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /pycocoevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | 11 | from .cider_scorer import CiderScorer 12 | import pdb 13 | 14 | class Cider: 15 | """ 16 | Main Class to compute the CIDEr metric 17 | 18 | """ 19 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 20 | # set cider to sum over 1 to 4-grams 21 | self._n = n 22 | # set the standard deviation parameter for gaussian penalty 23 | self._sigma = sigma 24 | 25 | def compute_score(self, gts, res): 26 | """ 27 | Main function to compute CIDEr score 28 | :param hypo_for_image (dict) : dictionary with key and value 29 | ref_for_image (dict) : dictionary with key and value 30 | :return: cider (float) : computed CIDEr score for the corpus 31 | """ 32 | 33 | assert(gts.keys() == res.keys()) 34 | imgIds = gts.keys() 35 | 36 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 37 | 38 | for id in imgIds: 39 | hypo = res[id] 40 | ref = gts[id] 41 | 42 | # Sanity check. 43 | assert(type(hypo) is list) 44 | assert(len(hypo) == 1) 45 | assert(type(ref) is list) 46 | assert(len(ref) > 0) 47 | 48 | cider_scorer += (hypo[0], ref) 49 | 50 | (score, scores) = cider_scorer.compute_score() 51 | 52 | return score, scores 53 | 54 | def method(self): 55 | return "CIDEr" -------------------------------------------------------------------------------- /pycocoevalcap/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | 6 | # Last modified : Wed 22 May 2019 08:10:00 PM EDT 7 | # By Sabarish Sivanath 8 | # To support Python 3 9 | 10 | import copy 11 | from collections import defaultdict 12 | import numpy as np 13 | import pdb 14 | import math 15 | 16 | def precook(s, n=4, out=False): 17 | """ 18 | Takes a string as input and returns an object that can be given to 19 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 20 | can take string arguments as well. 
21 | :param s: string : sentence to be converted into ngrams 22 | :param n: int : number of ngrams for which representation is calculated 23 | :return: term frequency vector for occuring ngrams 24 | """ 25 | words = s.split() 26 | counts = defaultdict(int) 27 | for k in range(1,n+1): 28 | for i in range(len(words)-k+1): 29 | ngram = tuple(words[i:i+k]) 30 | counts[ngram] += 1 31 | return counts 32 | 33 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 34 | '''Takes a list of reference sentences for a single segment 35 | and returns an object that encapsulates everything that BLEU 36 | needs to know about them. 37 | :param refs: list of string : reference sentences for some image 38 | :param n: int : number of ngrams for which (ngram) representation is calculated 39 | :return: result (list of dict) 40 | ''' 41 | return [precook(ref, n) for ref in refs] 42 | 43 | def cook_test(test, n=4): 44 | '''Takes a test sentence and returns an object that 45 | encapsulates everything that BLEU needs to know about it. 46 | :param test: list of string : hypothesis sentence for some image 47 | :param n: int : number of ngrams for which (ngram) representation is calculated 48 | :return: result (dict) 49 | ''' 50 | return precook(test, n, True) 51 | 52 | class CiderScorer(object): 53 | """CIDEr scorer. 54 | """ 55 | 56 | def copy(self): 57 | ''' copy the refs.''' 58 | new = CiderScorer(n=self.n) 59 | new.ctest = copy.copy(self.ctest) 60 | new.crefs = copy.copy(self.crefs) 61 | return new 62 | 63 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 64 | ''' singular instance ''' 65 | self.n = n 66 | self.sigma = sigma 67 | self.crefs = [] 68 | self.ctest = [] 69 | self.document_frequency = defaultdict(float) 70 | self.cook_append(test, refs) 71 | self.ref_len = None 72 | 73 | def cook_append(self, test, refs): 74 | '''called by constructor and __iadd__ to avoid creating new instances.''' 75 | 76 | if refs is not None: 77 | self.crefs.append(cook_refs(refs)) 78 | if test is not None: 79 | self.ctest.append(cook_test(test)) ## N.B.: -1 80 | else: 81 | self.ctest.append(None) # lens of crefs and ctest have to match 82 | 83 | def size(self): 84 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 85 | return len(self.crefs) 86 | 87 | def __iadd__(self, other): 88 | '''add an instance (e.g., from another sentence).''' 89 | 90 | if type(other) is tuple: 91 | ## avoid creating new CiderScorer instances 92 | self.cook_append(other[0], other[1]) 93 | else: 94 | self.ctest.extend(other.ctest) 95 | self.crefs.extend(other.crefs) 96 | 97 | return self 98 | def compute_doc_freq(self): 99 | ''' 100 | Compute term frequency for reference data. 101 | This will be used to compute idf (inverse document frequency later) 102 | The term frequency is stored in the object 103 | :return: None 104 | ''' 105 | for refs in self.crefs: 106 | # refs, k ref captions of one image 107 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 108 | self.document_frequency[ngram] += 1 109 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 110 | 111 | def compute_cider(self): 112 | def counts2vec(cnts): 113 | """ 114 | Function maps counts of ngram to vector of tfidf weights. 115 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 116 | The n-th entry of array denotes length of n-grams. 
117 | :param cnts: 118 | :return: vec (array of dict), norm (array of float), length (int) 119 | """ 120 | vec = [defaultdict(float) for _ in range(self.n)] 121 | length = 0 122 | norm = [0.0 for _ in range(self.n)] 123 | for (ngram,term_freq) in cnts.items(): 124 | # give word count 1 if it doesn't appear in reference corpus 125 | df = np.log(max(1.0, self.document_frequency[ngram])) 126 | # ngram index 127 | n = len(ngram)-1 128 | # tf (term_freq) * idf (precomputed idf) for n-grams 129 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 130 | # compute norm for the vector. the norm will be used for computing similarity 131 | norm[n] += pow(vec[n][ngram], 2) 132 | 133 | if n == 1: 134 | length += term_freq 135 | norm = [np.sqrt(n) for n in norm] 136 | return vec, norm, length 137 | 138 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 139 | ''' 140 | Compute the cosine similarity of two vectors. 141 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 142 | :param vec_ref: array of dictionary for vector corresponding to reference 143 | :param norm_hyp: array of float for vector corresponding to hypothesis 144 | :param norm_ref: array of float for vector corresponding to reference 145 | :param length_hyp: int containing length of hypothesis 146 | :param length_ref: int containing length of reference 147 | :return: array of score for each n-grams cosine similarity 148 | ''' 149 | delta = float(length_hyp - length_ref) 150 | # measure consine similarity 151 | val = np.array([0.0 for _ in range(self.n)]) 152 | for n in range(self.n): 153 | # ngram 154 | for (ngram,count) in vec_hyp[n].items(): 155 | # vrama91 : added clipping 156 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 157 | 158 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 159 | val[n] /= (norm_hyp[n]*norm_ref[n]) 160 | 161 | assert(not math.isnan(val[n])) 162 | # vrama91: added a length based gaussian penalty 163 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 164 | return val 165 | 166 | # compute log reference length 167 | self.ref_len = np.log(float(len(self.crefs))) 168 | 169 | scores = [] 170 | for test, refs in zip(self.ctest, self.crefs): 171 | # compute vector for test captions 172 | vec, norm, length = counts2vec(test) 173 | # compute vector for ref captions 174 | score = np.array([0.0 for _ in range(self.n)]) 175 | for ref in refs: 176 | vec_ref, norm_ref, length_ref = counts2vec(ref) 177 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 178 | # change by vrama91 - mean of ngram scores, instead of sum 179 | score_avg = np.mean(score) 180 | # divide by number of references 181 | score_avg /= len(refs) 182 | # multiply score by 10 183 | score_avg *= 10.0 184 | # append score of an image to the score list 185 | scores.append(score_avg) 186 | return scores 187 | 188 | def compute_score(self, option=None, verbose=0): 189 | # compute idf 190 | self.compute_doc_freq() 191 | # assert to check document frequency 192 | assert(len(self.ctest) >= max(self.document_frequency.values())) 193 | # compute cider score 194 | score = self.compute_cider() 195 | # debug 196 | # print score 197 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /pycocoevalcap/eval.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | from .tokenizer.ptbtokenizer import PTBTokenizer 3 | from .bleu.bleu import Bleu 4 | from 
.meteor.meteor import Meteor 5 | from .rouge.rouge import Rouge 6 | from .cider.cider import Cider 7 | 8 | class COCOEvalCap: 9 | def __init__(self, coco, cocoRes): 10 | self.evalImgs = [] 11 | self.eval = {} 12 | self.imgToEval = {} 13 | self.coco = coco 14 | self.cocoRes = cocoRes 15 | self.params = {'image_id': cocoRes.getImgIds()} 16 | 17 | def evaluate(self): 18 | imgIds = self.params['image_id'] 19 | # imgIds = self.coco.getImgIds() 20 | gts = {} 21 | res = {} 22 | for imgId in imgIds: 23 | gts[imgId] = self.coco.imgToAnns[imgId] 24 | res[imgId] = self.cocoRes.imgToAnns[imgId] 25 | 26 | # ================================================= 27 | # Set up scorers 28 | # ================================================= 29 | print('tokenization...') 30 | tokenizer = PTBTokenizer() 31 | gts = tokenizer.tokenize(gts) 32 | res = tokenizer.tokenize(res) 33 | 34 | # ================================================= 35 | # Set up scorers 36 | # ================================================= 37 | print('setting up scorers...') 38 | scorers = [ 39 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 40 | (Meteor(),"METEOR"), 41 | (Rouge(), "ROUGE_L"), 42 | (Cider(), "CIDEr") 43 | ] 44 | 45 | # ================================================= 46 | # Compute scores 47 | # ================================================= 48 | eval = {} 49 | for scorer, method in scorers: 50 | print('computing %s score...'%(scorer.method())) 51 | score, scores = scorer.compute_score(gts, res) 52 | if type(method) == list: 53 | for sc, scs, m in zip(score, scores, method): 54 | self.setEval(sc, m) 55 | self.setImgToEvalImgs(scs, imgIds, m) 56 | print("%s: %0.3f"%(m, sc)) 57 | else: 58 | self.setEval(score, method) 59 | self.setImgToEvalImgs(scores, imgIds, method) 60 | print("%s: %0.3f"%(method, score)) 61 | self.setEvalImgs() 62 | 63 | def setEval(self, score, method): 64 | self.eval[method] = score 65 | 66 | def setImgToEvalImgs(self, scores, imgIds, method): 67 | for imgId, score in zip(imgIds, scores): 68 | if not imgId in self.imgToEval: 69 | self.imgToEval[imgId] = {} 70 | self.imgToEval[imgId]["image_id"] = imgId 71 | self.imgToEval[imgId][method] = score 72 | 73 | def setEvalImgs(self): 74 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] 75 | -------------------------------------------------------------------------------- /pycocoevalcap/license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | 24 | The views and conclusions contained in the software and documentation are those 25 | of the authors and should not be interpreted as representing official policies, 26 | either expressed or implied, of the FreeBSD Project. -------------------------------------------------------------------------------- /pycocoevalcap/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | from .meteor import * -------------------------------------------------------------------------------- /pycocoevalcap/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2GenRL/214b4dcfdde5752d2c49e774780528519ca740fd/pycocoevalcap/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /pycocoevalcap/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | # Last modified : Wed 22 May 2019 08:10:00 PM EDT 7 | # By Sabarish Sivanath 8 | # To support Python 3 9 | 10 | import os 11 | import sys 12 | import subprocess 13 | import threading 14 | 15 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 
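# The wrapper below keeps a single long-lived Java process
# (java -jar -Xmx2G meteor-1.5.jar - - -stdio -l en -norm) and streams scoring requests
# over its stdin/stdout, so a working `java` on PATH is required.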
16 | METEOR_JAR = 'meteor-1.5.jar' 17 | # print METEOR_JAR 18 | 19 | class Meteor: 20 | 21 | def __init__(self): 22 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 23 | '-', '-', '-stdio', '-l', 'en', '-norm'] 24 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 25 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 26 | stdin=subprocess.PIPE, \ 27 | stdout=subprocess.PIPE, \ 28 | stderr=subprocess.PIPE, 29 | universal_newlines = True, 30 | bufsize = 1) 31 | # Used to guarantee thread safety 32 | self.lock = threading.Lock() 33 | 34 | def compute_score(self, gts, res): 35 | assert(gts.keys() == res.keys()) 36 | imgIds = gts.keys() 37 | scores = [] 38 | 39 | eval_line = 'EVAL' 40 | self.lock.acquire() 41 | for i in imgIds: 42 | assert(len(res[i]) == 1) 43 | stat = self._stat(res[i][0], gts[i]) 44 | eval_line += ' ||| {}'.format(stat) 45 | 46 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 47 | for i in range(0,len(imgIds)): 48 | scores.append(float(self.meteor_p.stdout.readline().strip())) 49 | score = float(self.meteor_p.stdout.readline().strip()) 50 | self.lock.release() 51 | 52 | return score, scores 53 | 54 | def method(self): 55 | return "METEOR" 56 | 57 | def _stat(self, hypothesis_str, reference_list): 58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 59 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 61 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 62 | return self.meteor_p.stdout.readline().strip() 63 | 64 | def _score(self, hypothesis_str, reference_list): 65 | self.lock.acquire() 66 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 67 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 68 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 69 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 70 | stats = self.meteor_p.stdout.readline().strip() 71 | eval_line = 'EVAL ||| {}'.format(stats) 72 | # EVAL ||| stats 73 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 74 | score = float(self.meteor_p.stdout.readline().strip()) 75 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 76 | # thanks for Andrej for pointing this out 77 | score = float(self.meteor_p.stdout.readline().strip()) 78 | self.lock.release() 79 | return score 80 | 81 | def __del__(self): 82 | self.lock.acquire() 83 | self.meteor_p.stdin.close() 84 | self.meteor_p.kill() 85 | self.meteor_p.wait() 86 | self.lock.release() 87 | -------------------------------------------------------------------------------- /pycocoevalcap/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | from .rouge import * -------------------------------------------------------------------------------- /pycocoevalcap/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : 
shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 
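# Note: these assertions only inspect the last (hypo, ref) pair left over from the loop above.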
96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /pycocoevalcap/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /pycocoevalcap/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import sys 13 | import subprocess 14 | import tempfile 15 | import itertools 16 | 17 | 18 | # Last modified : Wed 22 May 2019 08:10:00 PM EDT 19 | # By Sabarish Sivanath 20 | # To support Python 3 21 | 22 | # path to the stanford corenlp jar 23 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 24 | 25 | # punctuations to be removed from the sentences 26 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 27 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 28 | 29 | class PTBTokenizer: 30 | """Python wrapper of Stanford PTBTokenizer""" 31 | 32 | def tokenize(self, captions_for_image): 33 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 34 | 'edu.stanford.nlp.process.PTBTokenizer', \ 35 | '-preserveLines', '-lowerCase'] 36 | 37 | # ====================================================== 38 | # prepare data for PTB Tokenizer 39 | # ====================================================== 40 | final_tokenized_captions_for_image = {} 41 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 42 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 43 | 44 | # ====================================================== 45 | # save sentences to temporary file 46 | # ====================================================== 47 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 48 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 49 | tmp_file.write(sentences.encode('utf-8')) 50 | tmp_file.close() 51 | 52 | # ====================================================== 53 | # tokenize sentence 54 | # ====================================================== 55 | cmd.append(os.path.basename(tmp_file.name)) 56 | p_tokenizer = subprocess.Popen(cmd, 57 | cwd=path_to_jar_dirname, 58 | stdout=subprocess.PIPE, 59 | universal_newlines = True, 60 | bufsize = 1) 61 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 62 | lines = token_lines.split('\n') 63 | # remove temp file 64 | os.remove(tmp_file.name) 65 | 66 | # ====================================================== 67 | # create dictionary for tokenized captions 68 | # ====================================================== 69 | for k, line in zip(image_id, lines): 70 | if not k in final_tokenized_captions_for_image: 71 | final_tokenized_captions_for_image[k] = [] 72 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 73 | if w not in PUNCTUATIONS]) 74 | 
final_tokenized_captions_for_image[k].append(tokenized_caption) 75 | 76 | return final_tokenized_captions_for_image 77 | -------------------------------------------------------------------------------- /pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synlp/R2GenRL/214b4dcfdde5752d2c49e774780528519ca740fd/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /run.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -J cmm+rl 3 | #SBATCH -p p-V100 4 | #SBATCH -N 1 5 | #SBATCH --ntasks-per-node=2 6 | #SBATCH --gres=gpu:1 7 | 8 | source /cm/shared/apps/anaconda3/etc/profile.d/conda.sh 9 | 10 | conda activate a100 11 | 12 | seed=32103 13 | # seed=${RANDOM} 14 | noamopt_warmup=1000 15 | 16 | # python train.py \ 17 | # --image_dir data/mimic_cxr/images/ \ 18 | # --ann_path data/mimic_cxr/annotation.json \ 19 | # --dataset_name mimic_cxr \ 20 | # --max_seq_length 100 \ 21 | # --threshold 10 \ 22 | # --batch_size 16 \ 23 | # --epochs 30 \ 24 | # --save_dir results/mimic_cxr/base_cmn_${seed} \ 25 | # --step_size 1 \ 26 | # --gamma 0.8 \ 27 | # --seed ${seed} \ 28 | # --topk 32 \ 29 | # --noamopt_warmup ${noamopt_warmup} 30 | 31 | RESUME=results/mimic_cxr/base_cmn_seed_${seed} 32 | 33 | seed=${RANDOM} 34 | noamopt_warmup=1000 35 | save_dir=results/mimic_cxr/rl_cmn_seed_${seed} 36 | 37 | python train_rl.py \ 38 | --image_dir data/mimic_cxr/images/ \ 39 | --ann_path data/mimic_cxr/annotation.json \ 40 | --dataset_name mimic_cxr \ 41 | --max_seq_length 100 \ 42 | --threshold 10 \ 43 | --batch_size 6 \ 44 | --epochs 50 \ 45 | --save_dir ${save_dir} \ 46 | --step_size 1 \ 47 | --gamma 0.8 \ 48 | --seed ${seed} \ 49 | --topk 32 \ 50 | --sc_eval_period 3000 \ 51 | --resume ${RESUME}/model_best.pth -------------------------------------------------------------------------------- /run_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -J base+rl 3 | #SBATCH -p p-V100 4 | #SBATCH -N 1 5 | #SBATCH --ntasks-per-node=2 6 | #SBATCH --gres=gpu:1 7 | 8 | source /cm/shared/apps/anaconda3/etc/profile.d/conda.sh 9 | 10 | conda activate a100 11 | 12 | seed=23838 13 | # seed=${RANDOM} 14 | noamopt_warmup=1000 15 | 16 | # python train_base.py \ 17 | # --image_dir data/mimic_cxr/images/ \ 18 | # --ann_path data/mimic_cxr/annotation.json \ 19 | # --dataset_name mimic_cxr \ 20 | # --max_seq_length 100 \ 21 | # --threshold 10 \ 22 | # --batch_size 16 \ 23 | # --epochs 30 \ 24 | # --save_dir results/mimic_cxr/base_seed_${seed} \ 25 | # --step_size 1 \ 26 | # --gamma 0.8 \ 27 | # --seed ${seed} 28 | 29 | RESUME=results/mimic_cxr/base_seed_${seed} 30 | 31 | seed=${RANDOM} 32 | noamopt_warmup=1000 33 | save_dir=results/mimic_cxr/rl_base_seed_${seed} 34 | echo "seed ${seed}" 35 | 36 | python train_rl_base.py \ 37 | --image_dir data/mimic_cxr/images/ \ 38 | --ann_path data/mimic_cxr/annotation.json \ 39 | --dataset_name mimic_cxr \ 40 | --max_seq_length 100 \ 41 | --threshold 10 \ 42 | --batch_size 6 \ 43 | --epochs 50 \ 44 | --save_dir ${save_dir} \ 45 | --step_size 1 \ 46 | --gamma 0.8 \ 47 | --seed ${seed} \ 48 | --topk 32 \ 49 | --sc_eval_period 3000 \ 50 | --resume ${RESUME}/current_checkpoint.pth 51 | -------------------------------------------------------------------------------- 
/scripts/iu_xray/run_rl.sh: -------------------------------------------------------------------------------- 1 | seed=${RANDOM} 2 | noamopt_warmup=1000 3 | 4 | RESUME=${1} 5 | 6 | python train_rl.py \ 7 | --image_dir data/iu_xray/images/ \ 8 | --ann_path data/iu_xray/annotation.json \ 9 | --dataset_name iu_xray \ 10 | --max_seq_length 60 \ 11 | --threshold 3 \ 12 | --batch_size 10 \ 13 | --epochs 200 \ 14 | --save_dir ${RESUME} \ 15 | --step_size 1 \ 16 | --gamma 0.8 \ 17 | --seed ${seed} \ 18 | --topk 32 \ 19 | --beam_size 3 \ 20 | --log_period 100 \ 21 | --resume ${RESUME}/model_best.pth 22 | -------------------------------------------------------------------------------- /scripts/mimic_cxr/run_rl.sh: -------------------------------------------------------------------------------- 1 | seed=${RANDOM} 2 | 3 | mkdir -p results/mimic_cxr/base_cmn_rl/ 4 | mkdir -p records/mimic_cxr/base_cmn_rl/ 5 | 6 | python train_rl.py \ 7 | --image_dir data/mimic_cxr/images/ \ 8 | --ann_path data/mimic_cxr/annotation.json \ 9 | --dataset_name mimic_cxr \ 10 | --max_seq_length 100 \ 11 | --threshold 10 \ 12 | --batch_size 6 \ 13 | --epochs 50 \ 14 | --save_dir results/mimic_cxr/base_cmn_rl/ \ 15 | --record_dir records/mimic_cxr/base_cmn_rl/ \ 16 | --step_size 1 \ 17 | --gamma 0.8 \ 18 | --seed ${seed} \ 19 | --topk 32 \ 20 | --sc_eval_period 3000 21 | 22 | # python train_rl.py \ 23 | # --image_dir data/mimic_cxr/images/ \ 24 | # --ann_path data/mimic_cxr/annotation.json \ 25 | # --dataset_name mimic_cxr \ 26 | # --max_seq_length 100 \ 27 | # --threshold 10 \ 28 | # --batch_size 6 \ 29 | # --epochs 50 \ 30 | # --save_dir results/mimic_cxr/base_cmn_rl/ \ 31 | # --record_dir records/mimic_cxr/base_cmn_rl/ \ 32 | # --step_size 1 \ 33 | # --gamma 0.8 \ 34 | # --seed ${seed} \ 35 | # --topk 32 \ 36 | # --sc_eval_period 3000 \ 37 | # --resume results/mimic_cxr/mimic_cxr_0.8_1_16_5e-5_1e-4_3_3_32_2048_512_30799/current_checkpoint.pth 38 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | from modules.tokenizers import Tokenizer 5 | from modules.dataloaders import R2DataLoader 6 | from modules.metrics import compute_scores 7 | from modules.optimizers import build_optimizer, build_lr_scheduler, build_noamopt_optimizer 8 | from modules.trainer import Trainer 9 | from modules.loss import compute_loss 10 | from models.models import BaseCMNModel 11 | 12 | 13 | def parse_agrs(): 14 | parser = argparse.ArgumentParser() 15 | 16 | # Data input settings 17 | parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the data.') 18 | parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the directory containing the data.') 19 | 20 | # Data loader settings 21 | parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr'], help='the dataset to be used.') 22 | parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.') 23 | parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') 24 | parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') 25 | parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch') 26 | 27 | # 
Model settings (for visual extractor) 28 | parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.') 29 | parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor') 30 | parser.add_argument('--num_labels', type=int, default=14, help='the size of the label set') 31 | 32 | # Model settings (for Transformer) 33 | parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.') 34 | parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.') 35 | parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.') 36 | parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.') 37 | parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.') 38 | parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.') 39 | parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') 40 | parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.') 41 | parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.') 42 | parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.') 43 | parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.') 44 | parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.') 45 | # for Cross-modal Memory 46 | parser.add_argument('--topk', type=int, default=32, help='the number of k.') 47 | parser.add_argument('--cmm_size', type=int, default=2048, help='the size of the cmm.') 48 | parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of the cmm.') 49 | 50 | # Sample related 51 | parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.') 52 | parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.') 53 | parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.') 54 | parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.') 55 | parser.add_argument('--group_size', type=int, default=1, help='the group size.') 56 | parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.') 57 | parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.') 58 | parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.') 59 | 60 | # Trainer settings 61 | parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') 62 | parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.') 63 | parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.') 64 | parser.add_argument('--record_dir', type=str, default='records_acl/', help='the path to save the results of experiments.') 65 | parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).') 66 | parser.add_argument('--save_period', type=int, default=1, help='the saving period (in epochs).') 67 | parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 
'max'], help='whether to max or min the metric.') 68 | parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') 69 | parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') 70 | 71 | # Optimization 72 | parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') 73 | parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.') 74 | parser.add_argument('--lr_ed', type=float, default=7e-4, help='the learning rate for the remaining parameters.') 75 | parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') 76 | parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the weight decay.') 77 | parser.add_argument('--adam_eps', type=float, default=1e-9, help='the weight decay.') 78 | parser.add_argument('--amsgrad', type=bool, default=True, help='.') 79 | parser.add_argument('--noamopt_warmup', type=int, default=5000, help='.') 80 | parser.add_argument('--noamopt_factor', type=int, default=1, help='.') 81 | 82 | # Learning Rate Scheduler 83 | parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') 84 | parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') 85 | parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') 86 | 87 | # Others 88 | parser.add_argument('--seed', type=int, default=9233, help='.') 89 | parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') 90 | 91 | args = parser.parse_args() 92 | return args 93 | 94 | 95 | def main(): 96 | # parse arguments 97 | args = parse_agrs() 98 | 99 | # fix random seeds 100 | torch.manual_seed(args.seed) 101 | torch.backends.cudnn.deterministic = True 102 | torch.backends.cudnn.benchmark = False 103 | np.random.seed(args.seed) 104 | 105 | # create tokenizer 106 | tokenizer = Tokenizer(args) 107 | 108 | # create data loader 109 | train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True) 110 | val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False) 111 | test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False) 112 | 113 | # build model architecture 114 | model = BaseCMNModel(args, tokenizer) 115 | 116 | # get function handles of loss and metrics 117 | criterion = compute_loss 118 | metrics = compute_scores 119 | 120 | # build optimizer, learning rate scheduler 121 | ve_optimizer, ed_optimizer = build_noamopt_optimizer(args, model) 122 | # lr_scheduler = build_lr_scheduler(args, optimizer) 123 | 124 | # build trainer and start to train 125 | trainer = Trainer(model, criterion, metrics, ve_optimizer, ed_optimizer, args, train_dataloader, val_dataloader, test_dataloader) 126 | trainer.train() 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /train_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | from modules.tokenizers import Tokenizer 5 | from modules.dataloaders import R2DataLoader 6 | from modules.metrics import compute_scores 7 | from modules.optimizers import build_optimizer, build_lr_scheduler 8 | from modules.trainer_base import Trainer 9 | from modules.loss import 
compute_loss 10 | from models.r2gen import R2GenModel 11 | 12 | 13 | def parse_agrs(): 14 | parser = argparse.ArgumentParser() 15 | 16 | # Data input settings 17 | parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the data.') 18 | parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the directory containing the data.') 19 | 20 | # Data loader settings 21 | parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr'], help='the dataset to be used.') 22 | parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.') 23 | parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') 24 | parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') 25 | parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch') 26 | 27 | # Model settings (for visual extractor) 28 | parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.') 29 | parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor') 30 | parser.add_argument('--num_labels', type=int, default=14, help='the size of the label set') 31 | 32 | # Model settings (for Transformer) 33 | parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.') 34 | parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.') 35 | parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.') 36 | parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.') 37 | parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.') 38 | parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.') 39 | parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') 40 | parser.add_argument('--bos_idx', type=int, default=0, help='the index of .') 41 | parser.add_argument('--eos_idx', type=int, default=0, help='the index of .') 42 | parser.add_argument('--pad_idx', type=int, default=0, help='the index of .') 43 | parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.') 44 | parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.') 45 | # for Relational Memory 46 | parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.') 47 | parser.add_argument('--rm_num_heads', type=int, default=8, help='the numebr of heads in rm.') 48 | parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.') 49 | 50 | # Sample related 51 | parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.') 52 | parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.') 53 | parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.') 54 | parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.') 55 | parser.add_argument('--group_size', type=int, default=1, help='the group size.') 56 | 
parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.') 57 | parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.') 58 | parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.') 59 | 60 | # Trainer settings 61 | parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') 62 | parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.') 63 | parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the patch to save the models.') 64 | parser.add_argument('--record_dir', type=str, default='records/', help='the patch to save the results of experiments.') 65 | parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).') 66 | parser.add_argument('--save_period', type=int, default=1, help='the saving period (in epochs).') 67 | parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') 68 | parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') 69 | parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') 70 | 71 | # Optimization 72 | parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') 73 | parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.') 74 | parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.') 75 | parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') 76 | parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the weight decay.') 77 | parser.add_argument('--adam_eps', type=float, default=1e-9, help='the weight decay.') 78 | parser.add_argument('--amsgrad', type=bool, default=True, help='.') 79 | 80 | # Learning Rate Scheduler 81 | parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') 82 | parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') 83 | parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') 84 | 85 | # Others 86 | parser.add_argument('--seed', type=int, default=9233, help='.') 87 | parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') 88 | 89 | args = parser.parse_args() 90 | return args 91 | 92 | 93 | def main(): 94 | # parse arguments 95 | args = parse_agrs() 96 | 97 | # fix random seeds 98 | torch.manual_seed(args.seed) 99 | torch.backends.cudnn.deterministic = True 100 | torch.backends.cudnn.benchmark = False 101 | np.random.seed(args.seed) 102 | 103 | # create tokenizer 104 | tokenizer = Tokenizer(args) 105 | 106 | # create data loader 107 | train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True) 108 | val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False) 109 | test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False) 110 | 111 | # build model architecture 112 | model = R2GenModel(args, tokenizer) 113 | 114 | # get function handles of loss and metrics 115 | criterion = compute_loss 116 | metrics = compute_scores 117 | 118 | # build optimizer, learning rate scheduler 119 | 
optimizer = build_optimizer(args, model) 120 | lr_scheduler = build_lr_scheduler(args, optimizer) 121 | 122 | # build trainer and start to train 123 | trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, test_dataloader) 124 | trainer.train() 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /train_rl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | from modules.tokenizers import Tokenizer 5 | from modules.dataloaders import R2DataLoader 6 | from modules.metrics import compute_scores 7 | from modules.optimizers import build_optimizer, build_lr_scheduler, build_noamopt_optimizer, build_plateau_optimizer 8 | from modules.trainer_rl import Trainer 9 | from modules.loss import compute_loss, RewardCriterion 10 | from models.models import BaseCMNModel 11 | 12 | 13 | def parse_agrs(): 14 | parser = argparse.ArgumentParser() 15 | 16 | # Data input settings 17 | parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the data.') 18 | parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the directory containing the data.') 19 | 20 | # Data loader settings 21 | parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr'], help='the dataset to be used.') 22 | parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.') 23 | parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') 24 | parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') 25 | parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch') 26 | 27 | # Model settings (for visual extractor) 28 | parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.') 29 | parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor') 30 | parser.add_argument('--num_labels', type=int, default=14, help='the size of the label set') 31 | 32 | # Model settings (for Transformer) 33 | parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.') 34 | parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.') 35 | parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.') 36 | parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.') 37 | parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.') 38 | parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.') 39 | parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') 40 | parser.add_argument('--bos_idx', type=int, default=0, help='the index of .') 41 | parser.add_argument('--eos_idx', type=int, default=0, help='the index of .') 42 | parser.add_argument('--pad_idx', type=int, default=0, help='the index of .') 43 | parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.') 44 | parser.add_argument('--drop_prob_lm', 
type=float, default=0.5, help='the dropout rate of the output layer.') 45 | # for Cross-modal Memory 46 | parser.add_argument('--topk', type=int, default=32, help='the number of k.') 47 | parser.add_argument('--cmm_size', type=int, default=2048, help='the numebr of cmm size.') 48 | parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of cmm dimension.') 49 | 50 | # Sample related 51 | parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.') 52 | parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.') 53 | parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.') 54 | parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.') 55 | parser.add_argument('--group_size', type=int, default=1, help='the group size.') 56 | parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.') 57 | parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.') 58 | parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.') 59 | 60 | # Trainer settings 61 | parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') 62 | parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.') 63 | parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the patch to save the models.') 64 | parser.add_argument('--record_dir', type=str, default='records_acl/', help='the patch to save the results of experiments.') 65 | parser.add_argument('--log_period', type=int, default=10, help='the logging interval (in batches).') 66 | parser.add_argument('--save_period', type=int, default=1, help='the saving period (in epochs).') 67 | parser.add_argument('--sc_eval_period', type=int, default=10000, help='the saving period (in epochs).') 68 | parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') 69 | parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') 70 | parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') 71 | 72 | # Optimization 73 | parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') 74 | parser.add_argument('--lr_ve', type=float, default=1e-6, help='the learning rate for the visual extractor.') 75 | parser.add_argument('--lr_ed', type=float, default=1e-5, help='the learning rate for the remaining parameters.') 76 | parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') 77 | parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the weight decay.') 78 | parser.add_argument('--adam_eps', type=float, default=1e-9, help='the weight decay.') 79 | parser.add_argument('--amsgrad', type=bool, default=True, help='.') 80 | parser.add_argument('--noamopt_warmup', type=int, default=5000, help='.') 81 | parser.add_argument('--noamopt_factor', type=int, default=1, help='.') 82 | 83 | parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5, help='') 84 | parser.add_argument('--reduce_on_plateau_patience', type=int, default=3, help='') 85 | 86 | # Learning Rate Scheduler 87 | parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the 
learning rate scheduler.') 88 | parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') 89 | parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') 90 | 91 | # Self-Critical Training 92 | parser.add_argument('--train_sample_n', type=int, default=1, help='The reward weight from cider') 93 | parser.add_argument('--train_sample_method', type=str, default='sample', help='') 94 | parser.add_argument('--train_beam_size', type=int, default=1, help='') 95 | parser.add_argument('--sc_sample_method', type=str, default='greedy', help='') 96 | parser.add_argument('--sc_beam_size', type=int, default=1, help='') 97 | 98 | # Others 99 | parser.add_argument('--seed', type=int, default=9233, help='.') 100 | parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') 101 | 102 | args = parser.parse_args() 103 | return args 104 | 105 | 106 | def main(): 107 | # parse arguments 108 | args = parse_agrs() 109 | 110 | # fix random seeds 111 | torch.manual_seed(args.seed) 112 | torch.backends.cudnn.deterministic = True 113 | torch.backends.cudnn.benchmark = False 114 | np.random.seed(args.seed) 115 | 116 | # create tokenizer 117 | tokenizer = Tokenizer(args) 118 | 119 | # create data loader 120 | train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True) 121 | val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False) 122 | test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False) 123 | 124 | # build model architecture 125 | model = BaseCMNModel(args, tokenizer) 126 | 127 | # get function handles of loss and metrics 128 | criterion = RewardCriterion() 129 | metrics = compute_scores 130 | 131 | # build optimizer, learning rate scheduler 132 | ve_optimizer, ed_optimizer = build_plateau_optimizer(args, model) 133 | # lr_scheduler = build_lr_scheduler(args, optimizer) 134 | 135 | # build trainer and start to train 136 | trainer = Trainer(model, criterion, metrics, ve_optimizer, ed_optimizer, args, train_dataloader, val_dataloader, test_dataloader) 137 | trainer.train() 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /train_rl_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | from modules.tokenizers import Tokenizer 5 | from modules.dataloaders import R2DataLoader 6 | from modules.metrics import compute_scores 7 | from modules.optimizers import build_optimizer, build_lr_scheduler, build_noamopt_optimizer, build_plateau_optimizer 8 | from modules.trainer_rl import Trainer 9 | from modules.loss import compute_loss, RewardCriterion 10 | from models.r2gen import R2GenModel 11 | 12 | 13 | def parse_agrs(): 14 | parser = argparse.ArgumentParser() 15 | 16 | # Data input settings 17 | parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the data.') 18 | parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the directory containing the data.') 19 | 20 | # Data loader settings 21 | parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr'], help='the dataset to be used.') 22 | parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of 
the reports.') 23 | parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') 24 | parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') 25 | parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch') 26 | 27 | # Model settings (for visual extractor) 28 | parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.') 29 | parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor') 30 | parser.add_argument('--num_labels', type=int, default=14, help='the size of the label set') 31 | 32 | # Model settings (for Transformer) 33 | parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.') 34 | parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.') 35 | parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.') 36 | parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.') 37 | parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.') 38 | parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.') 39 | parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') 40 | parser.add_argument('--bos_idx', type=int, default=0, help='the index of .') 41 | parser.add_argument('--eos_idx', type=int, default=0, help='the index of .') 42 | parser.add_argument('--pad_idx', type=int, default=0, help='the index of .') 43 | parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.') 44 | parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.') 45 | # for Cross-modal Memory 46 | parser.add_argument('--topk', type=int, default=32, help='the number of k.') 47 | parser.add_argument('--cmm_size', type=int, default=2048, help='the numebr of cmm size.') 48 | parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of cmm dimension.') 49 | 50 | # Sample related 51 | parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.') 52 | parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.') 53 | parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.') 54 | parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.') 55 | parser.add_argument('--group_size', type=int, default=1, help='the group size.') 56 | parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.') 57 | parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.') 58 | parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.') 59 | 60 | # Trainer settings 61 | parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') 62 | parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.') 63 | parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the patch to save the models.') 64 | parser.add_argument('--record_dir', type=str, default='records_acl/', 
help='the patch to save the results of experiments.') 65 | parser.add_argument('--log_period', type=int, default=10, help='the logging interval (in batches).') 66 | parser.add_argument('--save_period', type=int, default=1, help='the saving period (in epochs).') 67 | parser.add_argument('--sc_eval_period', type=int, default=10000, help='the saving period (in epochs).') 68 | parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') 69 | parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') 70 | parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') 71 | 72 | # Optimization 73 | parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') 74 | parser.add_argument('--lr_ve', type=float, default=1e-6, help='the learning rate for the visual extractor.') 75 | parser.add_argument('--lr_ed', type=float, default=1e-5, help='the learning rate for the remaining parameters.') 76 | parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') 77 | parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the weight decay.') 78 | parser.add_argument('--adam_eps', type=float, default=1e-9, help='the weight decay.') 79 | parser.add_argument('--amsgrad', type=bool, default=True, help='.') 80 | parser.add_argument('--noamopt_warmup', type=int, default=5000, help='.') 81 | parser.add_argument('--noamopt_factor', type=int, default=1, help='.') 82 | 83 | parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5, help='') 84 | parser.add_argument('--reduce_on_plateau_patience', type=int, default=3, help='') 85 | 86 | # Learning Rate Scheduler 87 | parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') 88 | parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') 89 | parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') 90 | 91 | # Self-Critical Training 92 | parser.add_argument('--train_sample_n', type=int, default=1, help='The reward weight from cider') 93 | parser.add_argument('--train_sample_method', type=str, default='sample', help='') 94 | parser.add_argument('--train_beam_size', type=int, default=1, help='') 95 | parser.add_argument('--sc_sample_method', type=str, default='greedy', help='') 96 | parser.add_argument('--sc_beam_size', type=int, default=1, help='') 97 | 98 | # Others 99 | parser.add_argument('--seed', type=int, default=9233, help='.') 100 | parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') 101 | 102 | args = parser.parse_args() 103 | return args 104 | 105 | 106 | def main(): 107 | # parse arguments 108 | args = parse_agrs() 109 | 110 | # fix random seeds 111 | torch.manual_seed(args.seed) 112 | torch.backends.cudnn.deterministic = True 113 | torch.backends.cudnn.benchmark = False 114 | np.random.seed(args.seed) 115 | 116 | # create tokenizer 117 | tokenizer = Tokenizer(args) 118 | 119 | # create data loader 120 | train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True) 121 | val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False) 122 | test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False) 123 | 124 | # build model architecture 125 | model = R2GenModel(args, 
tokenizer) 126 | 127 | # get function handles of loss and metrics 128 | criterion = RewardCriterion() 129 | metrics = compute_scores 130 | 131 | # build optimizer, learning rate scheduler 132 | ve_optimizer, ed_optimizer = build_plateau_optimizer(args, model) 133 | # lr_scheduler = build_lr_scheduler(args, optimizer) 134 | 135 | # build trainer and start to train 136 | trainer = Trainer(model, criterion, metrics, ve_optimizer, ed_optimizer, args, train_dataloader, val_dataloader, test_dataloader) 137 | trainer.train() 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | --------------------------------------------------------------------------------
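train_rl.py and train_rl_base.py above swap the cross-entropy loss for RewardCriterion (imported from modules/loss.py) and sample with train_sample_method='sample' against a greedy baseline (sc_sample_method='greedy'). The sketch below is not that class; it is a minimal, self-contained illustration of the standard self-critical objective those arguments point to, with tensor shapes and the shared pad/eos index 0 assumed purely for the example.

```python
import torch
import torch.nn as nn


class SelfCriticalLoss(nn.Module):
    """Reward-weighted log-likelihood for self-critical sequence training (illustrative sketch only)."""

    def forward(self, sample_log_probs, sampled_seq, advantage):
        # sample_log_probs: (batch, seq_len) log-probabilities of the sampled tokens
        # sampled_seq:      (batch, seq_len) sampled token ids; 0 is assumed to be pad/eos
        # advantage:        (batch, seq_len) reward(sample) - reward(greedy), repeated over time steps
        mask = (sampled_seq > 0).float()
        # Shift the mask right by one step so the first end-of-sequence token
        # still receives gradient (eos shares index 0 with pad in this sketch).
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], dim=1)
        # Policy-gradient surrogate: maximize advantage-weighted log-likelihood.
        loss = -(sample_log_probs * advantage * mask).sum() / mask.sum()
        return loss


if __name__ == "__main__":
    torch.manual_seed(0)
    log_probs = torch.randn(2, 5).log_softmax(dim=-1)            # stand-in log-probabilities
    seqs = torch.tensor([[5, 8, 2, 0, 0], [3, 9, 4, 7, 0]])      # two sampled reports, padded with 0
    adv = torch.tensor([[0.2] * 5, [-0.1] * 5])                  # e.g. BLEU(sample) - BLEU(greedy)
    print(SelfCriticalLoss()(log_probs, seqs, adv))
```

In practice the advantage would be a sequence-level score such as BLEU_4 of the sampled report minus that of the greedy report, broadcast across time steps, so samples that beat the baseline are reinforced and the rest are suppressed.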