├── README.md ├── allennlp-requirements.txt ├── config.py ├── data └── data.txt ├── dataloaders ├── bert_field.py ├── box_utils.py ├── cocoontology.json ├── mask_utils.py └── vcr.py ├── multiatt └── default.json ├── pic ├── fig1.png ├── fig2.png └── fig3.png ├── poster └── poster.pdf ├── train └── train.py └── utils ├── detector.py ├── newdetector.py ├── pytorch_misc.py ├── smalldetector.py └── testdetector.py /README.md: -------------------------------------------------------------------------------- 1 | # CCN 2 | ## Connective Cognition Network for Directional Visual Commonsense Reasoning (NeurIPS 2019) 3 | 4 | ![Method](https://github.com/AmingWu/CCN/blob/master/pic/fig1.png?raw=true "Illustration of our method") 5 | Visual commonsense reasoning (VCR) has been introduced to boost research on cognition-level visual understanding, i.e., a thorough understanding of the correlated details of a scene together with inference over related commonsense knowledge. We propose a connective cognition network (CCN) that dynamically reorganizes visual neuron connectivity, contextualized by the meaning of questions and answers. Our method mainly consists of three parts: visual neuron connectivity, contextualized connectivity, and directional connectivity. 6 | 7 | ![Framework](https://github.com/AmingWu/CCN/blob/master/pic/fig2.png?raw=true "Illustration of our framework") 8 | 9 | The goal of visual neuron connectivity is to obtain a global representation of an image, which is helpful for a thorough understanding of visual content. It mainly includes visual element connectivity and the computation of both conditional centers and GraphVLAD. 10 | 11 | ![Visual Neuron Connectivity](https://github.com/AmingWu/CCN/blob/master/pic/fig3.png?raw=true "Illustration of Visual Neuron Connectivity") 12 | 13 | ## Setting Up and Data Preparation 14 | We used PyTorch 1.1.0, Python 3.6, and CUDA 9.0 for this project. Before using this code, download the VCR dataset from https://visualcommonsense.com/ and follow the setup steps at https://github.com/rowanz/r2c/ to set up the running environment. 15 | 16 | ## Training and Validation 17 | export CUDA_VISIBLE_DEVICES=0,1,2 18 | python train.py -params multiatt/default.json -folder saves/flagship_answer 19 | 20 | ## Citation 21 | ```bibtex 22 | @incollection{NIPS2019_8804, 23 | title = {Connective Cognition Network for Directional Visual Commonsense Reasoning}, 24 | author = {Wu, Aming and Zhu, Linchao and Han, Yahong and Yang, Yi}, 25 | booktitle = {Advances in Neural Information Processing Systems 32}, 26 | url = {http://papers.nips.cc/paper/8804-connective-cognition-network-for-directional-visual-commonsense-reasoning.pdf} 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /allennlp-requirements.txt: -------------------------------------------------------------------------------- 1 | # Library dependencies for the python code. You need to install these with 2 | # `pip install -r requirements.txt` before you can run this. 3 | # NOTE: all essential packages must be placed under a section named 'ESSENTIAL ...' 4 | # so that the script `./scripts/check_requirements_and_setup.py` can find them. 5 | 6 | #### ESSENTIAL LIBRARIES FOR MAIN FUNCTIONALITY #### 7 | 8 | # Parameter parsing (but not on Windows). 9 | jsonnet==0.10.0 ; sys.platform != 'win32' 10 | 11 | # Adds an @overrides decorator for better documentation and error checking when using subclasses.
12 | overrides 13 | 14 | # Used by some old code. We moved away from it because it's too slow, but some old code still 15 | # imports this. 16 | nltk 17 | 18 | # Pin msgpack because the newer version introduces an incompatibility with spaCy 19 | # Get rid of this if we ever unpin spacy 20 | msgpack>=0.5.6,<0.6.0 21 | 22 | # Mainly used for the faster tokenizer. 23 | spacy>=2.0,<2.1 24 | 25 | # Used by span prediction models. 26 | numpy 27 | 28 | # Used for reading configuration info out of numpy-style docstrings. 29 | numpydoc==0.8.0 30 | 31 | # Used in coreference resolution evaluation metrics. 32 | scipy 33 | scikit-learn 34 | 35 | # Write logs for training visualisation with the Tensorboard application 36 | # Install the Tensorboard application separately (part of tensorflow) to view them. 37 | tensorboardX==1.2 38 | 39 | # Required by torch.utils.ffi 40 | cffi==1.11.5 41 | 42 | # aws commandline tools for running on Docker remotely. 43 | # second requirement is to get botocore < 1.11, to avoid the below bug 44 | awscli>=1.11.91 45 | 46 | # Accessing files from S3 directly. 47 | boto3 48 | 49 | # REST interface for models 50 | flask==1.0.2 51 | flask-cors==3.0.7 52 | gevent==1.3.6 53 | 54 | # Used by semantic parsing code to strip diacritics from unicode strings. 55 | unidecode 56 | 57 | # Used by semantic parsing code to parse SQL 58 | parsimonious==0.8.0 59 | 60 | # Used by semantic parsing code to format and postprocess SQL 61 | sqlparse==0.2.4 62 | 63 | # For text normalization 64 | ftfy 65 | 66 | # To use the BERT model 67 | pytorch-pretrained-bert==0.3.0 68 | 69 | #### ESSENTIAL LIBRARIES USED IN SCRIPTS #### 70 | 71 | # Plot graphs for learning rate finder 72 | matplotlib==2.2.3 73 | 74 | # Used for downloading datasets over HTTP 75 | requests>=2.18 76 | 77 | # progress bars in data cleaning scripts 78 | tqdm>=4.19 79 | 80 | # In SQuAD eval script, we use this to see if we likely have some tokenization problem. 81 | editdistance 82 | 83 | # For pretrained model weights 84 | h5py 85 | 86 | # For timezone utilities 87 | pytz==2017.3 88 | 89 | # Reads Universal Dependencies files. 90 | conllu==0.11 91 | 92 | #### ESSENTIAL TESTING-RELATED PACKAGES #### 93 | 94 | # We'll use pytest to run our tests; this isn't really necessary to run the code, but it is to run 95 | # the tests. With this here, you can run the tests with `py.test` from the base directory. 96 | pytest 97 | 98 | # Allows marking tests as flaky, to be rerun if they fail 99 | flaky 100 | 101 | # Required to mock out `requests` calls 102 | responses>=0.7 103 | 104 | # For mocking s3. 105 | moto==1.3.4 106 | 107 | #### TESTING-RELATED PACKAGES #### 108 | 109 | # Checks style, syntax, and other useful errors. 110 | pylint==1.8.1 111 | 112 | # Tutorial notebooks 113 | # see: https://github.com/jupyter/jupyter/issues/370 for ipykernel 114 | ipykernel<5.0.0 115 | jupyter 116 | 117 | # Static type checking 118 | mypy==0.521 119 | 120 | # Allows generation of coverage reports with pytest. 121 | pytest-cov 122 | 123 | # Allows codecov to generate coverage reports 124 | coverage 125 | codecov 126 | 127 | # Required to run sanic tests 128 | aiohttp 129 | 130 | #### DOC-RELATED PACKAGES #### 131 | 132 | # Builds our documentation. 133 | sphinx==1.5.3 134 | 135 | # Watches the documentation directory and rebuilds on changes. 136 | sphinx-autobuild 137 | 138 | # doc theme 139 | sphinx_rtd_theme 140 | 141 | # Only used to convert our readme to reStructuredText on Pypi. 
142 | pypandoc 143 | 144 | # Pypi uploads 145 | twine==1.11.0 146 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | USE_IMAGENET_PRETRAINED = True # otherwise use detectron, but that doesnt seem to work?!? 3 | 4 | # Change these to match where your annotations and images are 5 | VCR_IMAGES_DIR = os.path.join(os.path.dirname(__file__), 'data', 'vcr1images') 6 | VCR_ANNOTS_DIR = os.path.join(os.path.dirname(__file__), 'data') 7 | 8 | if not os.path.exists(VCR_IMAGES_DIR): 9 | raise ValueError("Update config.py with where you saved VCR images to.") -------------------------------------------------------------------------------- /data/data.txt: -------------------------------------------------------------------------------- 1 | The data is downloaded from the link https://visualcommonsense.com/ 2 | -------------------------------------------------------------------------------- /dataloaders/bert_field.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | import textwrap 3 | 4 | from overrides import overrides 5 | from spacy.tokens import Token as SpacyToken 6 | import torch 7 | 8 | from allennlp.common.checks import ConfigurationError 9 | from allennlp.data.fields.sequence_field import SequenceField 10 | from allennlp.data.tokenizers.token import Token 11 | from allennlp.data.token_indexers.token_indexer import TokenIndexer, TokenType 12 | from allennlp.data.vocabulary import Vocabulary 13 | from allennlp.nn import util 14 | import numpy 15 | TokenList = List[TokenType] # pylint: disable=invalid-name 16 | 17 | 18 | # This will work for anything really 19 | class BertField(SequenceField[Dict[str, torch.Tensor]]): 20 | """ 21 | A class representing an array, which could have arbitrary dimensions. 22 | A batch of these arrays are padded to the max dimension length in the batch 23 | for each dimension. 
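Each token is paired with a row of precomputed BERT embeddings; positions beyond the sequence length are filled with ``padding_value``.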
24 | """ 25 | def __init__(self, tokens: List[Token], embs: numpy.ndarray, padding_value: int = 0, 26 | token_indexers=None) -> None: 27 | self.tokens = tokens 28 | self.embs = embs 29 | self.padding_value = padding_value 30 | 31 | if len(self.tokens) != self.embs.shape[0]: 32 | raise ValueError("The tokens you passed into the BERTField, {} " 33 | "aren't the same size as the embeddings of shape {}".format(self.tokens, self.embs.shape)) 34 | assert len(self.tokens) == self.embs.shape[0] 35 | 36 | @overrides 37 | def sequence_length(self) -> int: 38 | return len(self.tokens) 39 | 40 | 41 | @overrides 42 | def get_padding_lengths(self) -> Dict[str, int]: 43 | return {'num_tokens': self.sequence_length()} 44 | 45 | @overrides 46 | def as_tensor(self, padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]: 47 | num_tokens = padding_lengths['num_tokens'] 48 | 49 | new_arr = numpy.ones((num_tokens, self.embs.shape[1]), 50 | dtype=numpy.float32) * self.padding_value 51 | new_arr[:self.sequence_length()] = self.embs 52 | 53 | tensor = torch.from_numpy(new_arr) 54 | return {'bert': tensor} 55 | 56 | @overrides 57 | def empty_field(self): 58 | return BertField([], numpy.array([], dtype="float32"),padding_value=self.padding_value) 59 | 60 | @overrides 61 | def batch_tensors(self, tensor_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: 62 | # pylint: disable=no-self-use 63 | # This is creating a dict of {token_indexer_key: batch_tensor} for each token indexer used 64 | # to index this field. 65 | return util.batch_tensor_dicts(tensor_list) 66 | 67 | 68 | def __str__(self) -> str: 69 | return f"BertField: {self.tokens} and {self.embs.shape}." 70 | -------------------------------------------------------------------------------- /dataloaders/box_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import scipy 5 | import warnings 6 | from torchvision.datasets.folder import default_loader 7 | from torchvision.transforms import functional 8 | from config import USE_IMAGENET_PRETRAINED 9 | 10 | 11 | ##### Image 12 | def load_image(img_fn): 13 | """Load the specified image and return a [H,W,3] Numpy array. 14 | """ 15 | return default_loader(img_fn) 16 | # # Load image 17 | # image = skimage.io.imread(img_fn) 18 | # # If grayscale. Convert to RGB for consistency. 19 | # if image.ndim != 3: 20 | # image = skimage.color.gray2rgb(image) 21 | # # If has an alpha channel, remove it for consistency 22 | # if image.shape[-1] == 4: 23 | # image = image[..., :3] 24 | # return image 25 | 26 | 27 | # Let's do 16x9 28 | # Two common resolutions: 16x9 and 16/6 -> go to 16x8 as that's simple 29 | # let's say width is 576. for neural motifs it was 576*576 pixels so 331776. here we have 2x*x = 331776-> 408 base 30 | # so the best thing that's divisible by 4 is 384. that's 31 | def resize_image(image, desired_width=768, desired_height=384, random_pad=False): 32 | """Resizes an image keeping the aspect ratio mostly unchanged. 33 | 34 | Returns: 35 | image: the resized image 36 | window: (x1, y1, x2, y2). If max_dim is provided, padding might 37 | be inserted in the returned image. If so, this window is the 38 | coordinates of the image part of the full image (excluding 39 | the padding). The x2, y2 pixels are not included. 
40 | scale: The scale factor used to resize the image 41 | padding: Padding added to the image [left, top, right, bottom] 42 | """ 43 | # Default window (x1, y1, x2, y2) and default scale == 1. 44 | w, h = image.size 45 | 46 | width_scale = desired_width / w 47 | height_scale = desired_height / h 48 | scale = min(width_scale, height_scale) 49 | 50 | # Resize image using bilinear interpolation 51 | if scale != 1: 52 | image = functional.resize(image, (round(h * scale), round(w * scale))) 53 | w, h = image.size 54 | y_pad = desired_height - h 55 | x_pad = desired_width - w 56 | top_pad = random.randint(0, y_pad) if random_pad else y_pad // 2 57 | left_pad = random.randint(0, x_pad) if random_pad else x_pad // 2 58 | 59 | padding = (left_pad, top_pad, x_pad - left_pad, y_pad - top_pad) 60 | assert all([x >= 0 for x in padding]) 61 | image = functional.pad(image, padding) 62 | window = [left_pad, top_pad, w + left_pad, h + top_pad] 63 | 64 | return image, window, scale, padding 65 | 66 | 67 | if USE_IMAGENET_PRETRAINED: 68 | def to_tensor_and_normalize(image): 69 | return functional.normalize(functional.to_tensor(image), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 70 | else: 71 | # For COCO pretrained 72 | def to_tensor_and_normalize(image): 73 | tensor255 = functional.to_tensor(image) * 255 74 | return functional.normalize(tensor255, mean=(102.9801, 115.9465, 122.7717), std=(1, 1, 1)) -------------------------------------------------------------------------------- /dataloaders/cocoontology.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": { 3 | "name": "person", 4 | "supercategory": "person" 5 | }, 6 | "2": { 7 | "name": "bicycle", 8 | "supercategory": "vehicle" 9 | }, 10 | "3": { 11 | "name": "car", 12 | "supercategory": "vehicle" 13 | }, 14 | "4": { 15 | "name": "motorcycle", 16 | "supercategory": "vehicle" 17 | }, 18 | "5": { 19 | "name": "airplane", 20 | "supercategory": "vehicle" 21 | }, 22 | "6": { 23 | "name": "bus", 24 | "supercategory": "vehicle" 25 | }, 26 | "7": { 27 | "name": "train", 28 | "supercategory": "vehicle" 29 | }, 30 | "8": { 31 | "name": "truck", 32 | "supercategory": "vehicle" 33 | }, 34 | "9": { 35 | "name": "boat", 36 | "supercategory": "vehicle" 37 | }, 38 | "10": { 39 | "name": "trafficlight", 40 | "supercategory": "furniture" 41 | }, 42 | "11": { 43 | "name": "firehydrant", 44 | "supercategory": "furniture" 45 | }, 46 | "13": { 47 | "name": "stopsign", 48 | "supercategory": "furniture" 49 | }, 50 | "14": { 51 | "name": "parkingmeter", 52 | "supercategory": "furniture" 53 | }, 54 | "15": { 55 | "name": "bench", 56 | "supercategory": "furniture" 57 | }, 58 | "16": { 59 | "name": "bird", 60 | "supercategory": "animal" 61 | }, 62 | "17": { 63 | "name": "cat", 64 | "supercategory": "animal" 65 | }, 66 | "18": { 67 | "name": "dog", 68 | "supercategory": "animal" 69 | }, 70 | "19": { 71 | "name": "horse", 72 | "supercategory": "animal" 73 | }, 74 | "20": { 75 | "name": "sheep", 76 | "supercategory": "animal" 77 | }, 78 | "21": { 79 | "name": "cow", 80 | "supercategory": "animal" 81 | }, 82 | "22": { 83 | "name": "elephant", 84 | "supercategory": "animal" 85 | }, 86 | "23": { 87 | "name": "bear", 88 | "supercategory": "animal" 89 | }, 90 | "24": { 91 | "name": "zebra", 92 | "supercategory": "animal" 93 | }, 94 | "25": { 95 | "name": "giraffe", 96 | "supercategory": "animal" 97 | }, 98 | "27": { 99 | "name": "backpack", 100 | "supercategory": "accessory" 101 | }, 102 | "28": { 103 | "name": "umbrella", 104 | 
"supercategory": "accessory" 105 | }, 106 | "31": { 107 | "name": "handbag", 108 | "supercategory": "accessory" 109 | }, 110 | "32": { 111 | "name": "tie", 112 | "supercategory": "accessory" 113 | }, 114 | "33": { 115 | "name": "suitcase", 116 | "supercategory": "accessory" 117 | }, 118 | "34": { 119 | "name": "frisbee", 120 | "supercategory": "object" 121 | }, 122 | "35": { 123 | "name": "skis", 124 | "supercategory": "object" 125 | }, 126 | "36": { 127 | "name": "snowboard", 128 | "supercategory": "object" 129 | }, 130 | "37": { 131 | "name": "sportsball", 132 | "supercategory": "object" 133 | }, 134 | "38": { 135 | "name": "kite", 136 | "supercategory": "object" 137 | }, 138 | "39": { 139 | "name": "baseballbat", 140 | "supercategory": "object" 141 | }, 142 | "40": { 143 | "name": "baseballglove", 144 | "supercategory": "object" 145 | }, 146 | "41": { 147 | "name": "skateboard", 148 | "supercategory": "object" 149 | }, 150 | "42": { 151 | "name": "surfboard", 152 | "supercategory": "object" 153 | }, 154 | "43": { 155 | "name": "tennisracket", 156 | "supercategory": "object" 157 | }, 158 | "44": { 159 | "name": "bottle", 160 | "supercategory": "object" 161 | }, 162 | "46": { 163 | "name": "wineglass", 164 | "supercategory": "object" 165 | }, 166 | "47": { 167 | "name": "cup", 168 | "supercategory": "object" 169 | }, 170 | "48": { 171 | "name": "fork", 172 | "supercategory": "object" 173 | }, 174 | "49": { 175 | "name": "knife", 176 | "supercategory": "object" 177 | }, 178 | "50": { 179 | "name": "spoon", 180 | "supercategory": "object" 181 | }, 182 | "51": { 183 | "name": "bowl", 184 | "supercategory": "object" 185 | }, 186 | "52": { 187 | "name": "banana", 188 | "supercategory": "food" 189 | }, 190 | "53": { 191 | "name": "apple", 192 | "supercategory": "food" 193 | }, 194 | "54": { 195 | "name": "sandwich", 196 | "supercategory": "food" 197 | }, 198 | "55": { 199 | "name": "orange", 200 | "supercategory": "food" 201 | }, 202 | "56": { 203 | "name": "broccoli", 204 | "supercategory": "food" 205 | }, 206 | "57": { 207 | "name": "carrot", 208 | "supercategory": "food" 209 | }, 210 | "58": { 211 | "name": "hotdog", 212 | "supercategory": "food" 213 | }, 214 | "59": { 215 | "name": "pizza", 216 | "supercategory": "food" 217 | }, 218 | "60": { 219 | "name": "donut", 220 | "supercategory": "food" 221 | }, 222 | "61": { 223 | "name": "cake", 224 | "supercategory": "food" 225 | }, 226 | "62": { 227 | "name": "chair", 228 | "supercategory": "furniture" 229 | }, 230 | "63": { 231 | "name": "couch", 232 | "supercategory": "furniture" 233 | }, 234 | "64": { 235 | "name": "pottedplant", 236 | "supercategory": "furniture" 237 | }, 238 | "65": { 239 | "name": "bed", 240 | "supercategory": "furniture" 241 | }, 242 | "67": { 243 | "name": "diningtable", 244 | "supercategory": "furniture" 245 | }, 246 | "70": { 247 | "name": "toilet", 248 | "supercategory": "furniture" 249 | }, 250 | "72": { 251 | "name": "tv", 252 | "supercategory": "object" 253 | }, 254 | "73": { 255 | "name": "laptop", 256 | "supercategory": "object" 257 | }, 258 | "74": { 259 | "name": "mouse", 260 | "supercategory": "object" 261 | }, 262 | "75": { 263 | "name": "remote", 264 | "supercategory": "object" 265 | }, 266 | "76": { 267 | "name": "keyboard", 268 | "supercategory": "object" 269 | }, 270 | "77": { 271 | "name": "cellphone", 272 | "supercategory": "object" 273 | }, 274 | "78": { 275 | "name": "microwave", 276 | "supercategory": "object" 277 | }, 278 | "79": { 279 | "name": "oven", 280 | "supercategory": "object" 281 | }, 282 | 
"80": { 283 | "name": "toaster", 284 | "supercategory": "object" 285 | }, 286 | "81": { 287 | "name": "sink", 288 | "supercategory": "object" 289 | }, 290 | "82": { 291 | "name": "refrigerator", 292 | "supercategory": "object" 293 | }, 294 | "84": { 295 | "name": "book", 296 | "supercategory": "object" 297 | }, 298 | "85": { 299 | "name": "clock", 300 | "supercategory": "object" 301 | }, 302 | "86": { 303 | "name": "vase", 304 | "supercategory": "object" 305 | }, 306 | "87": { 307 | "name": "scissors", 308 | "supercategory": "object" 309 | }, 310 | "88": { 311 | "name": "teddybear", 312 | "supercategory": "object" 313 | }, 314 | "89": { 315 | "name": "hairdrier", 316 | "supercategory": "object" 317 | }, 318 | "90": { 319 | "name": "toothbrush", 320 | "supercategory": "object" 321 | } 322 | } -------------------------------------------------------------------------------- /dataloaders/mask_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | from matplotlib import path 4 | matplotlib.use('agg') 5 | 6 | 7 | def _spaced_points(low, high,n): 8 | """ We want n points between low and high, but we don't want them to touch either side""" 9 | padding = (high-low)/(n*2) 10 | return np.linspace(low + padding, high-padding, num=n) 11 | 12 | def make_mask(mask_size, box, polygons_list): 13 | """ 14 | Mask size: int about how big mask will be 15 | box: [x1, y1, x2, y2, conf.] 16 | polygons_list: List of polygons that go inside the box 17 | """ 18 | mask = np.zeros((mask_size, mask_size), dtype=np.bool) 19 | 20 | xy = np.meshgrid(_spaced_points(box[0], box[2], n=mask_size), 21 | _spaced_points(box[1], box[3], n=mask_size)) 22 | xy_flat = np.stack(xy, 2).reshape((-1, 2)) 23 | 24 | for polygon in polygons_list: 25 | polygon_path = path.Path(polygon) 26 | mask |= polygon_path.contains_points(xy_flat).reshape((mask_size, mask_size)) 27 | return mask.astype(np.float32) 28 | # 29 | #from matplotlib import pyplot as plt 30 | # 31 | # 32 | #with open('XdtbL0dP0X0@44.json', 'r') as f: 33 | # metadata = json.load(f) 34 | #from time import time 35 | #s = time() 36 | #for i in range(100): 37 | # mask = make_mask(14, metadata['boxes'][3], metadata['segms'][3]) 38 | #print("Elapsed {:3f}s".format(time()-s)) 39 | #plt.imshow(mask) -------------------------------------------------------------------------------- /dataloaders/vcr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloaders for VCR 3 | """ 4 | import json 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | from allennlp.data.dataset import Batch 10 | from allennlp.data.fields import TextField, ListField, LabelField, SequenceLabelField, ArrayField, MetadataField 11 | from allennlp.data.instance import Instance 12 | from allennlp.data.token_indexers import ELMoTokenCharactersIndexer 13 | from allennlp.data.tokenizers import Token 14 | from allennlp.data.vocabulary import Vocabulary 15 | from allennlp.nn.util import get_text_field_mask 16 | from torch.utils.data import Dataset 17 | from dataloaders.box_utils import load_image, resize_image, to_tensor_and_normalize 18 | from dataloaders.mask_utils import make_mask 19 | from dataloaders.bert_field import BertField 20 | import h5py 21 | from copy import deepcopy 22 | from config import VCR_IMAGES_DIR, VCR_ANNOTS_DIR 23 | 24 | GENDER_NEUTRAL_NAMES = ['Casey', 'Riley', 'Jessie', 'Jackie', 'Avery', 'Jaime', 'Peyton', 'Kerry', 'Jody', 'Kendall', 25 | 'Peyton', 'Skyler', 
'Frankie', 'Pat', 'Quinn'] 26 | 27 | 28 | # Here's an example jsonl 29 | # { 30 | # "movie": "3015_CHARLIE_ST_CLOUD", 31 | # "objects": ["person", "person", "person", "car"], 32 | # "interesting_scores": [0], 33 | # "answer_likelihood": "possible", 34 | # "img_fn": "lsmdc_3015_CHARLIE_ST_CLOUD/3015_CHARLIE_ST_CLOUD_00.23.57.935-00.24.00.783@0.jpg", 35 | # "metadata_fn": "lsmdc_3015_CHARLIE_ST_CLOUD/3015_CHARLIE_ST_CLOUD_00.23.57.935-00.24.00.783@0.json", 36 | # "answer_orig": "No she does not", 37 | # "question_orig": "Does 3 feel comfortable?", 38 | # "rationale_orig": "She is standing with her arms crossed and looks disturbed", 39 | # "question": ["Does", [2], "feel", "comfortable", "?"], 40 | # "answer_match_iter": [3, 0, 2, 1], 41 | # "answer_sources": [3287, 0, 10184, 2260], 42 | # "answer_choices": [ 43 | # ["Yes", "because", "the", "person", "sitting", "next", "to", "her", "is", "smiling", "."], 44 | # ["No", "she", "does", "not", "."], 45 | # ["Yes", ",", "she", "is", "wearing", "something", "with", "thin", "straps", "."], 46 | # ["Yes", ",", "she", "is", "cold", "."]], 47 | # "answer_label": 1, 48 | # "rationale_choices": [ 49 | # ["There", "is", "snow", "on", "the", "ground", ",", "and", 50 | # "she", "is", "wearing", "a", "coat", "and", "hate", "."], 51 | # ["She", "is", "standing", "with", "her", "arms", "crossed", "and", "looks", "disturbed", "."], 52 | # ["She", "is", "sitting", "very", "rigidly", "and", "tensely", "on", "the", "edge", "of", "the", 53 | # "bed", ".", "her", "posture", "is", "not", "relaxed", "and", "her", "face", "looks", "serious", "."], 54 | # [[2], "is", "laying", "in", "bed", "but", "not", "sleeping", ".", 55 | # "she", "looks", "sad", "and", "is", "curled", "into", "a", "ball", "."]], 56 | # "rationale_sources": [1921, 0, 9750, 25743], 57 | # "rationale_match_iter": [3, 0, 2, 1], 58 | # "rationale_label": 1, 59 | # "img_id": "train-0", 60 | # "question_number": 0, 61 | # "annot_id": "train-0", 62 | # "match_fold": "train-0", 63 | # "match_index": 0, 64 | # } 65 | 66 | def _fix_tokenization(tokenized_sent, bert_embs, old_det_to_new_ind, obj_to_type, token_indexers, pad_ind=-1): 67 | """ 68 | Turn a detection list into what we want: some text, as well as some tags. 69 | :param tokenized_sent: Tokenized sentence with detections collapsed to a list. 70 | :param old_det_to_new_ind: Mapping of the old ID -> new ID (which will be used as the tag) 71 | :param obj_to_type: [person, person, pottedplant] indexed by the old labels 72 | :return: tokenized sentence 73 | """ 74 | 75 | new_tokenization_with_tags = [] 76 | for tok in tokenized_sent: 77 | if isinstance(tok, list): 78 | for int_name in tok: 79 | obj_type = obj_to_type[int_name] 80 | new_ind = old_det_to_new_ind[int_name] 81 | if new_ind < 0: 82 | raise ValueError("Oh no, the new index is negative! that means it's invalid. 
{} {}".format( 83 | tokenized_sent, old_det_to_new_ind 84 | )) 85 | text_to_use = GENDER_NEUTRAL_NAMES[ 86 | new_ind % len(GENDER_NEUTRAL_NAMES)] if obj_type == 'person' else obj_type 87 | new_tokenization_with_tags.append((text_to_use, new_ind)) 88 | else: 89 | new_tokenization_with_tags.append((tok, pad_ind)) 90 | 91 | text_field = BertField([Token(x[0]) for x in new_tokenization_with_tags], 92 | bert_embs, 93 | padding_value=0) 94 | tags = SequenceLabelField([x[1] for x in new_tokenization_with_tags], text_field) 95 | return text_field, tags 96 | 97 | 98 | class VCR(Dataset): 99 | def __init__(self, split, mode, only_use_relevant_dets=True, add_image_as_a_box=True, embs_to_load='bert_da', 100 | conditioned_answer_choice=0): 101 | """ 102 | 103 | :param split: train, val, or test 104 | :param mode: answer or rationale 105 | :param only_use_relevant_dets: True, if we will only use the detections mentioned in the question and answer. 106 | False, if we should use all detections. 107 | :param add_image_as_a_box: True to add the image in as an additional 'detection'. It'll go first in the list 108 | of objects. 109 | :param embs_to_load: Which precomputed embeddings to load. 110 | :param conditioned_answer_choice: If you're in test mode, the answer labels aren't provided, which could be 111 | a problem for the QA->R task. Pass in 'conditioned_answer_choice=i' 112 | to always condition on the i-th answer. 113 | """ 114 | self.split = split 115 | self.mode = mode 116 | self.only_use_relevant_dets = only_use_relevant_dets 117 | print("Only relevant dets" if only_use_relevant_dets else "Using all detections", flush=True) 118 | 119 | self.add_image_as_a_box = add_image_as_a_box 120 | self.conditioned_answer_choice = conditioned_answer_choice 121 | 122 | with open(os.path.join(VCR_ANNOTS_DIR, '{}.jsonl'.format(split)), 'r') as f: 123 | self.items = [json.loads(s) for s in f] 124 | 125 | if split not in ('test', 'train', 'val'): 126 | raise ValueError("Mode must be in test, train, or val. Supplied {}".format(mode)) 127 | 128 | if mode not in ('answer', 'rationale'): 129 | raise ValueError("split must be answer or rationale") 130 | 131 | self.token_indexers = {'elmo': ELMoTokenCharactersIndexer()} 132 | self.vocab = Vocabulary() 133 | 134 | with open(os.path.join(os.path.dirname(VCR_ANNOTS_DIR), 'dataloaders', 'cocoontology.json'), 'r') as f: 135 | coco = json.load(f) 136 | self.coco_objects = ['__background__'] + [x['name'] for k, x in sorted(coco.items(), key=lambda x: int(x[0]))] 137 | self.coco_obj_to_ind = {o: i for i, o in enumerate(self.coco_objects)} 138 | 139 | self.embs_to_load = embs_to_load 140 | self.h5fn = os.path.join(VCR_ANNOTS_DIR, f'{self.embs_to_load}_{self.mode}_{self.split}.h5') 141 | print("Loading embeddings from {}".format(self.h5fn), flush=True) 142 | 143 | @property 144 | def is_train(self): 145 | return self.split == 'train' 146 | 147 | @classmethod 148 | def splits(cls, **kwargs): 149 | """ Helper method to generate splits of the dataset""" 150 | kwargs_copy = {x: y for x, y in kwargs.items()} 151 | if 'mode' not in kwargs: 152 | kwargs_copy['mode'] = 'answer' 153 | train = cls(split='train', **kwargs_copy) 154 | val = cls(split='val', **kwargs_copy) 155 | test = cls(split='test', **kwargs_copy) 156 | return train, val, test 157 | 158 | @classmethod 159 | def eval_splits(cls, **kwargs): 160 | """ Helper method to generate splits of the dataset. 
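Returns one test dataset in answer mode plus four test datasets in rationale mode, each conditioned on a different answer choice.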
Use this for testing, because it will 161 | condition on everything.""" 162 | for forbidden_key in ['mode', 'split', 'conditioned_answer_choice']: 163 | if forbidden_key in kwargs: 164 | raise ValueError(f"don't supply {forbidden_key} to eval_splits()") 165 | 166 | stuff_to_return = [cls(split='test', mode='answer', **kwargs)] + [ 167 | cls(split='test', mode='rationale', conditioned_answer_choice=i, **kwargs) for i in range(4)] 168 | return tuple(stuff_to_return) 169 | 170 | def __len__(self): 171 | return len(self.items) 172 | 173 | def _get_dets_to_use(self, item): 174 | """ 175 | We might want to use fewer detectiosn so lets do so. 176 | :param item: 177 | :param question: 178 | :param answer_choices: 179 | :return: 180 | """ 181 | # Load questions and answers 182 | question = item['question'] 183 | answer_choices = item['{}_choices'.format(self.mode)] 184 | 185 | if self.only_use_relevant_dets: 186 | dets2use = np.zeros(len(item['objects']), dtype=bool) 187 | people = np.array([x == 'person' for x in item['objects']], dtype=bool) 188 | for sent in answer_choices + [question]: 189 | for possibly_det_list in sent: 190 | if isinstance(possibly_det_list, list): 191 | for tag in possibly_det_list: 192 | if tag >= 0 and tag < len(item['objects']): # sanity check 193 | dets2use[tag] = True 194 | elif possibly_det_list.lower() in ('everyone', 'everyones'): 195 | dets2use |= people 196 | if not dets2use.any(): 197 | dets2use |= people 198 | else: 199 | dets2use = np.ones(len(item['objects']), dtype=bool) 200 | 201 | # we will use these detections 202 | dets2use = np.where(dets2use)[0] 203 | 204 | old_det_to_new_ind = np.zeros(len(item['objects']), dtype=np.int32) - 1 205 | old_det_to_new_ind[dets2use] = np.arange(dets2use.shape[0], dtype=np.int32) 206 | 207 | # If we add the image as an extra box then the 0th will be the image. 208 | if self.add_image_as_a_box: 209 | old_det_to_new_ind[dets2use] += 1 210 | old_det_to_new_ind = old_det_to_new_ind.tolist() 211 | return dets2use, old_det_to_new_ind 212 | 213 | def __getitem__(self, index): 214 | # if self.split == 'test': 215 | # raise ValueError("blind test mode not supported quite yet") 216 | item = deepcopy(self.items[index]) 217 | 218 | ################################################################### 219 | # Load questions and answers 220 | if self.mode == 'rationale': 221 | conditioned_label = item['answer_label'] if self.split != 'test' else self.conditioned_answer_choice 222 | item['question'] += item['answer_choices'][conditioned_label] 223 | 224 | answer_choices = item['{}_choices'.format(self.mode)] 225 | dets2use, old_det_to_new_ind = self._get_dets_to_use(item) 226 | 227 | ################################################################### 228 | # Load in BERT. We'll get contextual representations of the context and the answer choices 229 | # grp_items = {k: np.array(v, dtype=np.float16) for k, v in self.get_h5_group(index).items()} 230 | with h5py.File(self.h5fn, 'r') as h5: 231 | grp_items = {k: np.array(v, dtype=np.float16) for k, v in h5[str(index)].items()} 232 | 233 | # Essentially we need to condition on the right answer choice here, if we're doing QA->R. 
We will always 234 | # condition on the `conditioned_answer_choice.` 235 | condition_key = self.conditioned_answer_choice if self.split == "test" and self.mode == "rationale" else "" 236 | 237 | instance_dict = {} 238 | if 'endingonly' not in self.embs_to_load: 239 | questions_tokenized, question_tags = zip(*[_fix_tokenization( 240 | item['question'], 241 | grp_items[f'ctx_{self.mode}{condition_key}{i}'], 242 | old_det_to_new_ind, 243 | item['objects'], 244 | token_indexers=self.token_indexers, 245 | pad_ind=0 if self.add_image_as_a_box else -1 246 | ) for i in range(4)]) 247 | instance_dict['question'] = ListField(questions_tokenized) 248 | instance_dict['question_tags'] = ListField(question_tags) 249 | 250 | answers_tokenized, answer_tags = zip(*[_fix_tokenization( 251 | answer, 252 | grp_items[f'answer_{self.mode}{condition_key}{i}'], 253 | old_det_to_new_ind, 254 | item['objects'], 255 | token_indexers=self.token_indexers, 256 | pad_ind=0 if self.add_image_as_a_box else -1 257 | ) for i, answer in enumerate(answer_choices)]) 258 | 259 | instance_dict['answers'] = ListField(answers_tokenized) 260 | instance_dict['answer_tags'] = ListField(answer_tags) 261 | if self.split != 'test': 262 | instance_dict['label'] = LabelField(item['{}_label'.format(self.mode)], skip_indexing=True) 263 | instance_dict['metadata'] = MetadataField({'annot_id': item['annot_id'], 'ind': index, 'movie': item['movie'], 264 | 'img_fn': item['img_fn'], 265 | 'question_number': item['question_number']}) 266 | 267 | ################################################################### 268 | # Load image now and rescale it. Might have to subtract the mean and whatnot here too. 269 | image = load_image(os.path.join(VCR_IMAGES_DIR, item['img_fn'])) 270 | image, window, img_scale, padding = resize_image(image, random_pad=self.is_train) 271 | image = to_tensor_and_normalize(image) 272 | c, h, w = image.shape 273 | 274 | ################################################################### 275 | # Load boxes. 276 | with open(os.path.join(VCR_IMAGES_DIR, item['metadata_fn']), 'r') as f: 277 | metadata = json.load(f) 278 | 279 | # [nobj, 14, 14] 280 | segms = np.stack([make_mask(mask_size=14, box=metadata['boxes'][i], polygons_list=metadata['segms'][i]) 281 | for i in dets2use]) 282 | 283 | # Chop off the final dimension, that's the confidence 284 | boxes = np.array(metadata['boxes'])[dets2use, :-1] 285 | # Possibly rescale them if necessary 286 | boxes *= img_scale 287 | boxes[:, :2] += np.array(padding[:2])[None] 288 | boxes[:, 2:] += np.array(padding[:2])[None] 289 | obj_labels = [self.coco_obj_to_ind[item['objects'][i]] for i in dets2use.tolist()] 290 | if self.add_image_as_a_box: 291 | boxes = np.row_stack((window, boxes)) 292 | segms = np.concatenate((np.ones((1, 14, 14), dtype=np.float32), segms), 0) 293 | obj_labels = [self.coco_obj_to_ind['__background__']] + obj_labels 294 | 295 | instance_dict['segms'] = ArrayField(segms, padding_value=0) 296 | instance_dict['objects'] = ListField([LabelField(x, skip_indexing=True) for x in obj_labels]) 297 | 298 | if not np.all((boxes[:, 0] >= 0.) & (boxes[:, 0] < boxes[:, 2])): 299 | import ipdb 300 | ipdb.set_trace() 301 | assert np.all((boxes[:, 1] >= 0.) 
& (boxes[:, 1] < boxes[:, 3])) 302 | assert np.all((boxes[:, 2] <= w)) 303 | assert np.all((boxes[:, 3] <= h)) 304 | instance_dict['boxes'] = ArrayField(boxes, padding_value=-1) 305 | 306 | instance = Instance(instance_dict) 307 | instance.index_fields(self.vocab) 308 | return image, instance 309 | 310 | 311 | def collate_fn(data, to_gpu=False): 312 | """Creates mini-batch tensors 313 | """ 314 | images, instances = zip(*data) 315 | images = torch.stack(images, 0) 316 | batch = Batch(instances) 317 | td = batch.as_tensor_dict() 318 | if 'question' in td: 319 | td['question_mask'] = get_text_field_mask(td['question'], num_wrapping_dims=1) 320 | td['question_tags'][td['question_mask'] == 0] = -2 # Padding 321 | 322 | td['answer_mask'] = get_text_field_mask(td['answers'], num_wrapping_dims=1) 323 | td['answer_tags'][td['answer_mask'] == 0] = -2 324 | 325 | td['box_mask'] = torch.all(td['boxes'] >= 0, -1).long() 326 | td['images'] = images 327 | 328 | # Deprecated 329 | # if to_gpu: 330 | # for k in td: 331 | # if k != 'metadata': 332 | # td[k] = {k2: v.cuda(non_blocking=True) for k2, v in td[k].items()} if isinstance(td[k], dict) else td[k].cuda( 333 | # non_blocking=True) 334 | 335 | # # No nested dicts 336 | # for k in sorted(td.keys()): 337 | # if isinstance(td[k], dict): 338 | # for k2 in sorted(td[k].keys()): 339 | # td['{}_{}'.format(k, k2)] = td[k].pop(k2) 340 | # td.pop(k) 341 | 342 | return td 343 | 344 | 345 | class VCRLoader(torch.utils.data.DataLoader): 346 | """ 347 | Iterates through the data, filtering out None, 348 | but also loads everything as a (cuda) variable 349 | """ 350 | 351 | @classmethod 352 | def from_dataset(cls, data, batch_size=3, num_workers=6, num_gpus=3, **kwargs): 353 | loader = cls( 354 | dataset=data, 355 | batch_size=batch_size * num_gpus, 356 | shuffle=data.is_train, 357 | num_workers=num_workers, 358 | collate_fn=lambda x: collate_fn(x, to_gpu=False), 359 | drop_last=data.is_train, 360 | pin_memory=False, 361 | **kwargs, 362 | ) 363 | return loader 364 | 365 | # You could use this for debugging maybe 366 | # if __name__ == '__main__': 367 | # train, val, test = VCR.splits() 368 | # for i in range(len(train)): 369 | # res = train[i] 370 | # print("done with {}".format(i)) 371 | -------------------------------------------------------------------------------- /multiatt/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "vswag" 4 | }, 5 | "model": { 6 | "type": "MultiHopAttentionQA", 7 | "span_encoder": { 8 | "type": "lstm", 9 | "input_size": 1280, 10 | "hidden_size": 256, 11 | "num_layers": 1, 12 | "bidirectional": true 13 | }, 14 | "reasoning_encoder": { 15 | "type": "lstm", 16 | "input_size": 1536, 17 | "hidden_size": 256, 18 | "num_layers": 2, 19 | "bidirectional": true 20 | }, 21 | "hidden_dim_maxpool": 1024, 22 | "input_dropout": 0.3, 23 | "pool_question": true, 24 | "pool_answer": true, 25 | "initializer": [ 26 | [".*final_mlp.*weight", {"type": "xavier_uniform"}], 27 | [".*final_mlp.*bias", {"type": "zero"}], 28 | [".*weight_ih.*", {"type": "xavier_uniform"}], 29 | [".*weight_hh.*", {"type": "orthogonal"}], 30 | [".*bias_ih.*", {"type": "zero"}], 31 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}]] 32 | }, 33 | "trainer": { 34 | "optimizer": { 35 | "type": "adam", 36 | "lr": 0.0002, 37 | "weight_decay": 0.0001 38 | }, 39 | "validation_metric": "+accuracy", 40 | "num_serialized_models_to_keep": 2, 41 | "num_epochs": 40, 42 | "grad_norm": 1.0, 43 | "patience": 3, 44 | 
"cuda_device": 0, 45 | "learning_rate_scheduler": { 46 | "type": "reduce_on_plateau", 47 | "factor": 0.5, 48 | "mode": "max", 49 | "patience": 1, 50 | "verbose": true, 51 | "cooldown": 2 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /pic/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmingWu/CCN/2254049e5b090952406f14fbf4a63b4e30e99a17/pic/fig1.png -------------------------------------------------------------------------------- /pic/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmingWu/CCN/2254049e5b090952406f14fbf4a63b4e30e99a17/pic/fig2.png -------------------------------------------------------------------------------- /pic/fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmingWu/CCN/2254049e5b090952406f14fbf4a63b4e30e99a17/pic/fig3.png -------------------------------------------------------------------------------- /poster/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmingWu/CCN/2254049e5b090952406f14fbf4a63b4e30e99a17/poster/poster.pdf -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script. Should be pretty adaptable to whatever. 3 | """ 4 | import argparse 5 | import os 6 | import shutil 7 | 8 | import multiprocessing 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | from allennlp.common.params import Params 13 | from allennlp.training.learning_rate_schedulers import LearningRateScheduler 14 | from allennlp.training.optimizers import Optimizer 15 | from torch.nn import DataParallel 16 | from torch.nn.modules import BatchNorm2d 17 | from tqdm import tqdm 18 | 19 | from dataloaders.vcr import VCR, VCRLoader 20 | from utils.pytorch_misc import time_batch, save_checkpoint, clip_grad_norm, \ 21 | restore_checkpoint, print_para, restore_best_checkpoint 22 | 23 | import logging 24 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', level=logging.DEBUG) 25 | import pdb 26 | 27 | # This is needed to make the imports work 28 | #from allennlp.models import Model 29 | #import models 30 | 31 | from typing import Dict, List, Any 32 | import torch 33 | import torch.nn.functional as F 34 | import torch.nn.parallel 35 | from allennlp.data.vocabulary import Vocabulary 36 | from allennlp.models.model import Model 37 | from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder, FeedForward, InputVariationalDropout, TimeDistributed 38 | from allennlp.training.metrics import CategoricalAccuracy 39 | from allennlp.modules.matrix_attention import BilinearMatrixAttention 40 | from utils.newdetector import SimpleDetector 41 | from allennlp.nn.util import masked_softmax, weighted_sum, replace_masked_values 42 | from allennlp.nn import InitializerApplicator 43 | @Model.register("MultiHopAttentionQA") 44 | class AttentionQA(Model): 45 | def __init__(self, 46 | vocab: Vocabulary, 47 | span_encoder: Seq2SeqEncoder, 48 | reasoning_encoder: Seq2SeqEncoder, 49 | input_dropout: float = 0.5, 50 | hidden_dim_maxpool: int = 512, 51 | class_embs: bool=True, 52 | num_cluster: int = 32, 53 | reasoning_use_obj: bool=True, 54 | reasoning_use_answer: 
bool=True, 55 | reasoning_use_question: bool=True, 56 | pool_reasoning: bool = True, 57 | pool_answer: bool = True, 58 | pool_question: bool = False, 59 | initializer: InitializerApplicator = InitializerApplicator(), 60 | ): 61 | super(AttentionQA, self).__init__(vocab) 62 | 63 | self.detector = SimpleDetector(pretrained=True, average_pool=True, semantic=class_embs, final_dim=512) 64 | ################################################################################################### 65 | 66 | self.rnn_input_dropout = TimeDistributed(InputVariationalDropout(input_dropout)) if input_dropout > 0 else None 67 | 68 | self.span_encoder = TimeDistributed(span_encoder) 69 | self.reasoning_encoder = TimeDistributed(reasoning_encoder) 70 | 71 | self.qnode_graph11 = torch.nn.Sequential( 72 | torch.nn.Conv2d(512*3, 512, 1), 73 | torch.nn.Sigmoid(), 74 | ) 75 | self.qnode_graph12 = torch.nn.Sequential( 76 | torch.nn.Conv2d(512*3, 512, 1), 77 | torch.nn.Tanh(), 78 | ) 79 | 80 | self.anode_graph11 = torch.nn.Sequential( 81 | torch.nn.Conv2d(512*3, 512, 1), 82 | torch.nn.Sigmoid(), 83 | ) 84 | self.anode_graph12 = torch.nn.Sequential( 85 | torch.nn.Conv2d(512*3, 512, 1), 86 | torch.nn.Tanh(), 87 | ) 88 | 89 | self.scene_reps = torch.nn.Sequential( 90 | torch.nn.Conv2d(1024, 512, 1), 91 | torch.nn.MaxPool2d(2, stride=2) 92 | ) 93 | 94 | self.netvlad = torch.nn.Sequential( 95 | torch.nn.Conv2d(512, 32, 1, bias=True) 96 | ) 97 | 98 | self.vladchannel = torch.nn.Sequential( 99 | torch.nn.Linear(1024, 512) 100 | ) 101 | 102 | self.scene_graph0 = torch.nn.Sequential( 103 | torch.nn.Conv2d(512, 512, 1), 104 | torch.nn.Sigmoid(), 105 | ) 106 | self.scene_graph1 = torch.nn.Sequential( 107 | torch.nn.Conv2d(512, 512, 1), 108 | torch.nn.Tanh(), 109 | ) 110 | 111 | self.img_graph0 = torch.nn.Sequential( 112 | torch.nn.Conv2d(512, 512, 1), 113 | torch.nn.Sigmoid(), 114 | ) 115 | self.img_graph1 = torch.nn.Sequential( 116 | torch.nn.Conv2d(512, 512, 1), 117 | torch.nn.Tanh(), 118 | ) 119 | 120 | self.obj_graph0 = torch.nn.Sequential( 121 | torch.nn.Conv2d(512, 512, 1), 122 | torch.nn.Sigmoid(), 123 | ) 124 | self.obj_graph1 = torch.nn.Sequential( 125 | torch.nn.Conv2d(512, 512, 1), 126 | torch.nn.Tanh(), 127 | ) 128 | 129 | self.dropout = torch.nn.Sequential( 130 | torch.nn.Dropout(input_dropout, inplace=False), 131 | ) 132 | 133 | self.centroids = torch.nn.Parameter(torch.rand(num_cluster, 512)) 134 | self.num_cluster = num_cluster 135 | 136 | self.gama = torch.nn.Sequential( 137 | torch.nn.Conv2d(1024, 512, 1), 138 | torch.nn.ReLU(inplace=False), 139 | torch.nn.Dropout(p=0.5), 140 | torch.nn.Conv2d(512, 512, 1), 141 | ) 142 | 143 | self.fusion_conv = torch.nn.Sequential( 144 | torch.nn.Conv2d(512, 512, 1), 145 | ) 146 | 147 | self.question_conv = torch.nn.Sequential( 148 | torch.nn.Conv2d(512, 512, 1), 149 | ) 150 | 151 | self.answer_conv = torch.nn.Sequential( 152 | torch.nn.Conv2d(512, 512, 1), 153 | ) 154 | 155 | self.reason_conv = torch.nn.Sequential( 156 | torch.nn.Conv2d(512*2, 512, 1), 157 | ) 158 | self.reasoning1 = torch.nn.Sequential( 159 | torch.nn.Conv2d(512*2, 512, 1), 160 | torch.nn.Sigmoid(), 161 | ) 162 | self.reasoning2 = torch.nn.Sequential( 163 | torch.nn.Conv2d(512*2, 512, 1), 164 | torch.nn.Tanh(), 165 | ) 166 | 167 | self.final_mlp = torch.nn.Sequential( 168 | torch.nn.Dropout(input_dropout, inplace=False), 169 | torch.nn.Linear(1024+512, 1024), 170 | torch.nn.ReLU(inplace=True), 171 | torch.nn.Dropout(input_dropout, inplace=False), 172 | torch.nn.Linear(1024, 1), 173 | ) 174 | self._accuracy = 
CategoricalAccuracy() 175 | self._loss = torch.nn.CrossEntropyLoss() 176 | initializer(self) 177 | 178 | def _collect_obj_reps(self, span_tags, object_reps): 179 | """ 180 | Collect span-level object representations 181 | :param span_tags: [batch_size, ..leading_dims.., L] 182 | :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] 183 | :return: 184 | """ 185 | span_tags_fixed = torch.clamp(span_tags, min=0) # In case there were masked values here 186 | row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape) 187 | row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None] 188 | 189 | # Add extra diminsions to the row broadcaster so it matches row_id 190 | leading_dims = len(span_tags.shape) - 2 191 | for i in range(leading_dims): 192 | row_id_broadcaster = row_id_broadcaster[..., None] 193 | row_id += row_id_broadcaster 194 | return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1) 195 | 196 | def embed_span(self, span, span_tags, span_mask, object_reps): 197 | """ 198 | :param span: Thing that will get embed and turned into [batch_size, ..leading_dims.., L, word_dim] 199 | :param span_tags: [batch_size, ..leading_dims.., L] 200 | :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] 201 | :param span_mask: [batch_size, ..leading_dims.., span_mask 202 | :return: 203 | """ 204 | retrieved_feats = self._collect_obj_reps(span_tags, object_reps) 205 | 206 | span_rep = torch.cat((span['bert'], retrieved_feats), -1) 207 | # add recurrent dropout here 208 | if self.rnn_input_dropout: 209 | span_rep = self.rnn_input_dropout(span_rep) 210 | 211 | return self.span_encoder(span_rep, span_mask), retrieved_feats 212 | 213 | def img_graph(self, img_feats): 214 | N, C, H, W = img_feats.shape[0:] 215 | x = img_feats.view(N, C, -1).permute(0,2,1) 216 | num = x.shape[1] 217 | diag = torch.ones(num) 218 | diag = torch.diag(diag) 219 | scene_graph = torch.matmul(x, x.permute(0,2,1)) 220 | 221 | ones = torch.ones(num, num) 222 | diag_zero = ones - diag 223 | scene_graph = masked_softmax(scene_graph, diag_zero[None,...].cuda(), dim=1) + diag[None,...].cuda() 224 | 225 | scene_value = torch.matmul(scene_graph, x)[:,:,None] 226 | scene_conv1 = self.img_graph0(scene_value.permute(0,3,1,2)) 227 | scene_conv2 = self.img_graph1(scene_value.permute(0,3,1,2)) 228 | scene_conv = scene_conv1 * scene_conv2 229 | scene_c = scene_conv.squeeze(-1) 230 | scene = scene_c.view(N, scene_c.shape[1], H, W) 231 | return scene 232 | 233 | def obj_graph(self, obj_reps): 234 | B, N, C = obj_reps.shape[0:] 235 | x = obj_reps 236 | diag = torch.ones(N) 237 | diag = torch.diag(diag) 238 | object_graph = torch.matmul(x, x.permute(0,2,1)) 239 | 240 | ones = torch.ones(N, N) 241 | diag_zero = ones - diag 242 | object_graph = masked_softmax(object_graph, diag_zero[None,...].cuda(), dim=1) + diag[None,...].cuda() 243 | 244 | obj_value = torch.matmul(object_graph, x)[:,:,None] 245 | obj_conv1 = self.obj_graph0(obj_value.permute(0,3,1,2)) 246 | obj_conv2 = self.obj_graph1(obj_value.permute(0,3,1,2)) 247 | obj_conv = obj_conv1 * obj_conv2 248 | obj_c = obj_conv.squeeze(-1).permute(0,2,1) 249 | return obj_c 250 | 251 | def vlad(self, scene, q_final): 252 | x = scene 253 | N, C, W, H = x.shape[0:] 254 | 255 | q_rep = q_final[:,:,None,None].repeat(1,1,W,H) 256 | s_q = torch.cat([scene, q_rep], 1) 257 | gama = self.gama(s_q) 258 | gama = gama.permute(0,2,3,1).reshape(N, W*H, C).permute(0, 2, 1) 259 | 260 | x = F.normalize(x, p=2, dim=1) 261 | 
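# NetVLAD-style aggregation: the 1x1 conv below produces per-location soft-assignment logits over the num_cluster centroids.
# After the softmax, these assignments weight the residuals between the local features and the conditional centers (centroids modulated by the question-dependent gama term) before they are summed over spatial locations into the VLAD descriptor.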
soft_assign = self.netvlad(x) 262 | 263 | soft_assign = F.softmax(soft_assign, dim=1) 264 | soft_assign = soft_assign.view(soft_assign.shape[0], soft_assign.shape[1], -1) 265 | 266 | x_flatten = x.view(N, C, -1) 267 | 268 | x1 = x_flatten.expand(self.num_cluster, -1, -1, -1).permute(1, 0, 2, 3) 269 | x2 = self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) 270 | 271 | gama_new = gama.expand(self.num_cluster, -1, -1, -1).permute(1, 0, 2, 3) 272 | x2 = gama_new * x2 273 | 274 | residual = x1 - x2 275 | residual = residual * soft_assign.unsqueeze(2) 276 | vlad = residual.sum(dim=-1) 277 | 278 | newvlad = torch.max(x2, 3, keepdim=False)[0] 279 | 280 | vlad = F.normalize(vlad, p=2, dim=2) #intra-normalization 281 | vlad = vlad.view(x.size(0), -1) # flatten 282 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize 283 | 284 | vlad = vlad.view(vlad.shape[0], self.num_cluster, C) 285 | vlad = torch.cat((vlad, newvlad), -1) 286 | vlad = self.vladchannel(vlad) 287 | return vlad 288 | 289 | def reason_answer(self, reps): 290 | src = reps.view(reps.shape[0] * reps.shape[1], reps.shape[2], reps.shape[3]) 291 | x1 = self.reason_conv(reps.permute(0,3,1,2)).permute(0,2,3,1) 292 | B, N, L, C = x1.shape[0:] 293 | x = x1.reshape(B*N, L, C) 294 | diag = torch.ones(L) 295 | diag = torch.diag(diag) 296 | graph = torch.matmul(x, x.permute(0,2,1)) 297 | direct = torch.sign(graph) 298 | length = torch.abs(graph) 299 | 300 | ones = torch.ones(L, L) 301 | diag_zero = ones - diag 302 | length = masked_softmax(length, diag_zero[None,...].cuda(), dim=1) 303 | 304 | direct_graph = direct * length + diag[None,...].cuda() 305 | 306 | direct_value = torch.matmul(direct_graph, src)[:,:,None] 307 | direct_conv1 = self.reasoning1(direct_value.permute(0,3,1,2)) 308 | direct_conv2 = self.reasoning2(direct_value.permute(0,3,1,2)) 309 | 310 | direct_conv = direct_conv1 * direct_conv2 311 | 312 | result = direct_conv.squeeze(-1).permute(0,2,1) 313 | result = result.view(B, N, direct_conv.shape[2], direct_conv.shape[1]) 314 | return result 315 | 316 | def forward(self, 317 | images: torch.Tensor, 318 | objects: torch.LongTensor, 319 | segms: torch.Tensor, 320 | boxes: torch.Tensor, 321 | box_mask: torch.LongTensor, 322 | question: Dict[str, torch.Tensor], 323 | question_tags: torch.LongTensor, 324 | question_mask: torch.LongTensor, 325 | answers: Dict[str, torch.Tensor], 326 | answer_tags: torch.LongTensor, 327 | answer_mask: torch.LongTensor, 328 | metadata: List[Dict[str, Any]] = None, 329 | label: torch.LongTensor = None) -> Dict[str, torch.Tensor]: 330 | """ 331 | :param images: [batch_size, 3, im_height, im_width] 332 | :param objects: [batch_size, max_num_objects] Padded objects 333 | :param boxes: [batch_size, max_num_objects, 4] Padded boxes 334 | :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK 335 | :param question: AllenNLP representation of the question. [batch_size, num_answers, seq_length] 336 | :param question_tags: detection label for each item in the Q [batch_size, num_answers, seq_length] 337 | :param question_mask: Mask for the Q [batch_size, num_answers, seq_length] 338 | :param answers: AllenNLP representation of the answer. 
[batch_size, num_answers, seq_length] 339 | :param answer_tags: A detection label for each item in the A [batch_size, num_answers, seq_length] 340 | :param answer_mask: Mask for the As [batch_size, num_answers, seq_length] 341 | :param metadata: Ignore, this is about which dataset item we're on 342 | :param label: Optional, which item is valid 343 | :return: shit 344 | """ 345 | # Trim off boxes that are too long. this is an issue b/c dataparallel, it'll pad more zeros that are 346 | # not needed 347 | 348 | max_len = int(box_mask.sum(1).max().item()) 349 | objects = objects[:, :max_len] 350 | box_mask = box_mask[:, :max_len] 351 | boxes = boxes[:, :max_len] 352 | segms = segms[:, :max_len] 353 | 354 | for tag_type, the_tags in (('question', question_tags), ('answer', answer_tags)): 355 | if int(the_tags.max()) > max_len: 356 | raise ValueError("Oh no! {}_tags has maximum of {} but objects is of dim {}. Values are\n{}".format( 357 | tag_type, int(the_tags.max()), objects.shape, the_tags 358 | )) 359 | 360 | obj_reps, img_feats = self.detector(images=images, boxes=boxes, box_mask=box_mask, classes=objects, segms=segms) 361 | 362 | # Now get the question representations 363 | obj_feas = self.obj_graph(obj_reps['obj_reps']) 364 | q_rep, q_obj_reps = self.embed_span(question, question_tags, question_mask, obj_feas) 365 | a_rep, a_obj_reps = self.embed_span(answers, answer_tags, answer_mask, obj_feas) 366 | 367 | MASK = torch.sum(question_mask, 2, keepdim=True) - 1 368 | onehot = torch.cuda.LongTensor(MASK.size(0), q_rep.size(2), MASK.size(1)).zero_() 369 | target = onehot.scatter_(1, MASK.long().permute(0,2,1), 1).permute(0,2,1).float() 370 | q_final = torch.sum(q_rep * target[...,None], 2, keepdim=False) 371 | q_final = torch.mean(q_final, 1, keepdim=False) ## 32 * 512 372 | 373 | ############################################################## 374 | 375 | scene = self.scene_reps(img_feats) 376 | imggraph = self.img_graph(scene) 377 | scene = self.vlad(imggraph, q_final) 378 | 379 | diag = torch.ones(self.num_cluster) 380 | diag = torch.diag(diag) 381 | scene_graph = torch.matmul(scene, scene.permute(0,2,1)) 382 | 383 | ones = torch.ones(self.num_cluster, self.num_cluster) 384 | diag_zero = ones - diag 385 | scene_graph = masked_softmax(scene_graph, diag_zero[None,...].cuda(), dim=1) + diag[None,...].cuda() 386 | 387 | scene_value = torch.matmul(scene_graph, scene)[:,:,None] 388 | scene_conv1 = self.scene_graph0(scene_value.permute(0,3,1,2)) 389 | scene_conv2 = self.scene_graph1(scene_value.permute(0,3,1,2)) 390 | scene_conv = scene_conv1 * scene_conv2 391 | scene_c = scene_conv.squeeze(-1) 392 | scene_c = self.dropout(scene_c) 393 | scene = scene_c 394 | 395 | #################################### 396 | # Perform scene attention by question 397 | q_rep1 = self.question_conv(q_rep.permute(0,3,1,2)).permute(0,2,3,1) 398 | scene_attend = torch.matmul(q_rep1, scene_c[:,None]) 399 | scene_attention_weights = masked_softmax(scene_attend, question_mask[...,None]) 400 | scene_qo = torch.einsum('bnao,bod->bnad', (scene_attention_weights, scene.permute(0,2,1))) 401 | 402 | obj_attend = torch.matmul(q_rep, obj_feas.permute(0,2,1)[:,None]) 403 | obj_attention_weights = masked_softmax(obj_attend, question_mask[...,None]) 404 | obj_qo = torch.einsum('bnao,bod->bnad', (obj_attention_weights, obj_feas)) 405 | 406 | question_node = torch.cat([scene_qo, obj_qo, q_rep], -1) 407 | 408 | #question_node = self.dropout(question_node) 409 | 410 | ## question first layer 411 | question_diag = 
torch.diag_embed(question_mask.view(question_mask.shape[0]*question_mask.shape[1], question_mask.shape[2])) 412 | ques_diag = question_diag.view(question_mask.shape[0], question_mask.shape[1], question_mask.shape[2], question_mask.shape[2]).float() 413 | first_graph = torch.matmul(question_node, question_node.permute(0,1,3,2)) 414 | first_mask = torch.matmul(question_mask[...,None].float(), question_mask[:,:,None,:].float()) 415 | first_graph = masked_softmax(first_graph, first_mask, dim=2) + ques_diag 416 | 417 | first_conv = torch.matmul(first_graph, question_node) 418 | first_conv1 = self.qnode_graph11(first_conv.permute(0,3,1,2)) 419 | first_conv2 = self.qnode_graph12(first_conv.permute(0,3,1,2)) 420 | first_conv = torch.mul(first_conv1, first_conv2).permute(0,2,3,1) 421 | 422 | question_first_tree = first_conv 423 | 424 | question_third_tree = first_conv + q_rep + obj_qo 425 | 426 | # Perform scene attention by answer 427 | a_rep1 = self.answer_conv(a_rep.permute(0,3,1,2)).permute(0,2,3,1) 428 | scene_attend = torch.matmul(a_rep1, scene_c[:,None]) 429 | scene_attention_weights = masked_softmax(scene_attend, answer_mask[...,None]) 430 | scene_ao = torch.einsum('bnao,bod->bnad', (scene_attention_weights, scene.permute(0,2,1))) 431 | 432 | obj_attend = torch.matmul(a_rep, obj_feas.permute(0,2,1)[:,None]) 433 | obj_attention_weights = masked_softmax(obj_attend, answer_mask[...,None]) 434 | obj_ao = torch.einsum('bnao,bod->bnad', (obj_attention_weights, obj_feas)) 435 | 436 | answer_node = torch.cat([scene_ao, obj_ao, a_rep], -1) 437 | 438 | #answer_node = self.dropout(answer_node) 439 | 440 | ## answer first layer 441 | answer_diag = torch.diag_embed(answer_mask.view(answer_mask.shape[0]*answer_mask.shape[1], answer_mask.shape[2])) 442 | ans_diag = answer_diag.view(answer_mask.shape[0], answer_mask.shape[1], answer_mask.shape[2], answer_mask.shape[2]).float() 443 | first_graph = torch.matmul(answer_node, answer_node.permute(0,1,3,2)) 444 | first_mask = torch.matmul(answer_mask[...,None].float(), answer_mask[:,:,None,:].float()) 445 | first_graph = masked_softmax(first_graph, first_mask, dim=2) + ans_diag 446 | 447 | first_conv = torch.matmul(first_graph, answer_node) 448 | first_conv1 = self.anode_graph11(first_conv.permute(0,3,1,2)) 449 | first_conv2 = self.anode_graph12(first_conv.permute(0,3,1,2)) 450 | first_conv = torch.mul(first_conv1, first_conv2).permute(0,2,3,1) 451 | 452 | answer_first_tree = first_conv 453 | 454 | answer_third_tree = first_conv + a_rep + scene_ao 455 | 456 | # question and answer fusion 457 | question_third_tree1 = self.fusion_conv(question_third_tree.permute(0,3,1,2)).permute(0,2,3,1) 458 | qa_tree_similarity = torch.matmul(question_third_tree1, a_rep.permute(0,1,3,2)) 459 | qa_tree_attention_weights = masked_softmax(qa_tree_similarity, question_mask[...,None], dim=2) 460 | attended_tree_q = torch.einsum('bnqa,bnqd->bnad', (qa_tree_attention_weights, question_third_tree)) 461 | 462 | things_to_pool = torch.cat([attended_tree_q, answer_third_tree], -1) 463 | 464 | things_to_pool = self.reason_answer(things_to_pool) + attended_tree_q + answer_third_tree 465 | things_to_pool = torch.cat([things_to_pool, a_rep, obj_ao], -1) 466 | pooled_rep = replace_masked_values(things_to_pool,answer_mask[...,None], -1e7).max(2)[0] 467 | 468 | logits = self.final_mlp(pooled_rep).squeeze(2) 469 | 470 | ########################################### 471 | 472 | class_probabilities = F.softmax(logits, dim=-1) 473 | 474 | output_dict = {"label_logits": logits, "label_probs": 
class_probabilities, 475 | 'cnn_regularization_loss': obj_reps['cnn_regularization_loss'], 476 | # Uncomment to visualize attention, if you want 477 | # 'qa_attention_weights': qa_attention_weights, 478 | # 'atoo_attention_weights': atoo_attention_weights, 479 | } 480 | if label is not None: 481 | loss = self._loss(logits, label.long().view(-1)) 482 | self._accuracy(logits, label) 483 | output_dict["loss"] = loss[None] 484 | 485 | return output_dict 486 | 487 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 488 | return {'accuracy': self._accuracy.get_metric(reset)} 489 | 490 | ################################# 491 | ################################# 492 | ######## Data loading stuff 493 | ################################# 494 | ################################# 495 | 496 | parser = argparse.ArgumentParser(description='train') 497 | parser.add_argument( 498 | '-params', 499 | dest='params', 500 | help='Params location', 501 | type=str, 502 | ) 503 | parser.add_argument( 504 | '-rationale', 505 | action="store_true", 506 | help='use rationale', 507 | ) 508 | parser.add_argument( 509 | '-folder', 510 | dest='folder', 511 | help='folder location', 512 | type=str, 513 | ) 514 | parser.add_argument( 515 | '-no_tqdm', 516 | dest='no_tqdm', 517 | action='store_true', 518 | ) 519 | 520 | args = parser.parse_args() 521 | 522 | params = Params.from_file(args.params) 523 | train, val, test = VCR.splits(mode='rationale' if args.rationale else 'answer', 524 | embs_to_load=params['dataset_reader'].get('embs', 'bert_da'), 525 | only_use_relevant_dets=params['dataset_reader'].get('only_use_relevant_dets', True)) 526 | NUM_GPUS = torch.cuda.device_count() 527 | NUM_CPUS = multiprocessing.cpu_count() 528 | if NUM_GPUS == 0: 529 | raise ValueError("you need gpus!") 530 | 531 | def _to_gpu(td): 532 | if NUM_GPUS > 1: 533 | return td 534 | for k in td: 535 | if k != 'metadata': 536 | td[k] = {k2: v.cuda(non_blocking=True) for k2, v in td[k].items()} if isinstance(td[k], dict) else td[k].cuda( 537 | non_blocking=True) 538 | return td 539 | num_workers = (4 * NUM_GPUS if NUM_CPUS == 32 else 2*NUM_GPUS)-1 540 | print(f"Using {num_workers} workers out of {NUM_CPUS} possible", flush=True) 541 | loader_params = {'batch_size': 96 // NUM_GPUS, 'num_gpus':NUM_GPUS, 'num_workers':num_workers} 542 | train_loader = VCRLoader.from_dataset(train, **loader_params) 543 | val_loader = VCRLoader.from_dataset(val, **loader_params) 544 | test_loader = VCRLoader.from_dataset(test, **loader_params) 545 | 546 | ARGS_RESET_EVERY = 100 547 | print("Loading {} for {}".format(params['model'].get('type', 'WTF?'), 'rationales' if args.rationale else 'answer'), flush=True) 548 | model = Model.from_params(vocab=train.vocab, params=params['model']) 549 | for submodule in model.detector.backbone.modules(): 550 | if isinstance(submodule, BatchNorm2d): 551 | submodule.track_running_stats = False 552 | for p in submodule.parameters(): 553 | p.requires_grad = False 554 | 555 | model = DataParallel(model).cuda() if NUM_GPUS > 1 else model.cuda() 556 | optimizer = Optimizer.from_params([x for x in model.named_parameters() if x[1].requires_grad], 557 | params['trainer']['optimizer']) 558 | 559 | lr_scheduler_params = params['trainer'].pop("learning_rate_scheduler", None) 560 | scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params) if lr_scheduler_params else None 561 | 562 | if os.path.exists(args.folder): 563 | print("Found folder! 
restoring", flush=True) 564 | start_epoch, val_metric_per_epoch = restore_checkpoint(model, optimizer, serialization_dir=args.folder, 565 | learning_rate_scheduler=scheduler) 566 | else: 567 | print("Making directories") 568 | os.makedirs(args.folder, exist_ok=True) 569 | start_epoch, val_metric_per_epoch = 0, [] 570 | shutil.copy2(args.params, args.folder) 571 | 572 | valrecord = open('path', 'w') 573 | param_shapes = print_para(model) 574 | num_batches = 0 575 | for epoch_num in range(start_epoch, params['trainer']['num_epochs'] + start_epoch): 576 | train_results = [] 577 | norms = [] 578 | model.train() 579 | for b, (time_per_batch, batch) in enumerate(time_batch(train_loader if args.no_tqdm else tqdm(train_loader), reset_every=ARGS_RESET_EVERY)): 580 | batch = _to_gpu(batch) 581 | optimizer.zero_grad() 582 | output_dict = model(**batch) 583 | loss = output_dict['loss'].mean() + output_dict['cnn_regularization_loss'].mean() 584 | loss.backward() 585 | 586 | num_batches += 1 587 | if scheduler: 588 | scheduler.step_batch(num_batches) 589 | 590 | norms.append( 591 | clip_grad_norm(model.named_parameters(), max_norm=params['trainer']['grad_norm'], clip=True, verbose=False) 592 | ) 593 | optimizer.step() 594 | 595 | train_results.append(pd.Series({'loss': output_dict['loss'].mean().item(), 596 | 'crl': output_dict['cnn_regularization_loss'].mean().item(), 597 | 'accuracy': (model.module if NUM_GPUS > 1 else model).get_metrics( 598 | reset=(b % ARGS_RESET_EVERY) == 0)[ 599 | 'accuracy'], 600 | 'sec_per_batch': time_per_batch, 601 | 'hr_per_epoch': len(train_loader) * time_per_batch / 3600, 602 | })) 603 | if b % ARGS_RESET_EVERY == 0 and b > 0: 604 | norms_df = pd.DataFrame(pd.DataFrame(norms[-ARGS_RESET_EVERY:]).mean(), columns=['norm']).join( 605 | param_shapes[['shape', 'size']]).sort_values('norm', ascending=False) 606 | 607 | print("e{:2d}b{:5d}/{:5d}. 
norms: \n{}\nsumm:\n{}\n~~~~~~~~~~~~~~~~~~\n".format( 608 | epoch_num, b, len(train_loader), 609 | norms_df.to_string(formatters={'norm': '{:.2f}'.format}), 610 | pd.DataFrame(train_results[-ARGS_RESET_EVERY:]).mean(), 611 | ), flush=True) 612 | 613 | if len(val_metric_per_epoch): 614 | print("epoch:", epoch_num, "valacc:", val_metric_per_epoch[-1]) 615 | 616 | print("---\nTRAIN EPOCH {:2d}:\n{}\n----".format(epoch_num, pd.DataFrame(train_results).mean())) 617 | val_probs = [] 618 | val_labels = [] 619 | val_loss_sum = 0.0 620 | 621 | model.eval() 622 | for b, (time_per_batch, batch) in enumerate(time_batch(val_loader)): 623 | with torch.no_grad(): 624 | batch = _to_gpu(batch) 625 | output_dict = model(**batch) 626 | val_probs.append(output_dict['label_probs'].detach().cpu().numpy()) 627 | val_labels.append(batch['label'].detach().cpu().numpy()) 628 | val_loss_sum += output_dict['loss'].mean().item() * batch['label'].shape[0] 629 | 630 | val_labels = np.concatenate(val_labels, 0) 631 | val_probs = np.concatenate(val_probs, 0) 632 | val_loss_avg = val_loss_sum / val_labels.shape[0] 633 | 634 | val_metric_per_epoch.append(float(np.mean(val_labels == val_probs.argmax(1)))) 635 | if scheduler: 636 | scheduler.step(val_metric_per_epoch[-1], epoch_num) 637 | 638 | print("Val epoch {} has acc {:.3f} and loss {:.3f}".format(epoch_num, val_metric_per_epoch[-1], val_loss_avg), 639 | flush=True) 640 | model_txt = open('path', 'a+') 641 | model_txt.write('epoch:' + str(epoch_num) + ' ' + 'valacc:' + str(val_metric_per_epoch[-1]) + "\n") 642 | model_txt.close() 643 | save_checkpoint(model, optimizer, args.folder, epoch_num, val_metric_per_epoch, 644 | is_best=int(np.argmax(val_metric_per_epoch)) == (len(val_metric_per_epoch) - 1) or int(np.argmax(val_metric_per_epoch)) == (len(val_metric_per_epoch) - 2)) 645 | 646 | print("STOPPING. now running the best model on the validation set", flush=True) 647 | # Load best 648 | restore_best_checkpoint(model, args.folder) 649 | model.eval() 650 | val_probs = [] 651 | val_labels = [] 652 | for b, (time_per_batch, batch) in enumerate(time_batch(val_loader)): 653 | with torch.no_grad(): 654 | batch = _to_gpu(batch) 655 | output_dict = model(**batch) 656 | val_probs.append(output_dict['label_probs'].detach().cpu().numpy()) 657 | val_labels.append(batch['label'].detach().cpu().numpy()) 658 | val_labels = np.concatenate(val_labels, 0) 659 | val_probs = np.concatenate(val_probs, 0) 660 | acc = float(np.mean(val_labels == val_probs.argmax(1))) 661 | print("Final val accuracy is {:.3f}".format(acc)) 662 | np.save(os.path.join(args.folder, f'valpreds.npy'), val_probs) 663 | -------------------------------------------------------------------------------- /utils/detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | ok so I lied. 
it's not a detector, it's the resnet backbone 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | from torchvision.models import resnet 9 | 10 | from utils.pytorch_misc import Flattener 11 | from torchvision.layers import ROIAlign 12 | import torch.utils.model_zoo as model_zoo 13 | from config import USE_IMAGENET_PRETRAINED 14 | from utils.pytorch_misc import pad_sequence 15 | from torch.nn import functional as F 16 | 17 | 18 | def _load_resnet(pretrained=True): 19 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 20 | backbone = resnet.resnet50(pretrained=False) 21 | if pretrained: 22 | backbone.load_state_dict(model_zoo.load_url( 23 | 'https://s3.us-west-2.amazonaws.com/ai2-rowanz/resnet50-e13db6895d81.th')) 24 | for i in range(2, 4): 25 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 26 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 27 | return backbone 28 | 29 | 30 | def _load_resnet_imagenet(pretrained=True): 31 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 32 | backbone = resnet.resnet50(pretrained=pretrained) 33 | for i in range(2, 4): 34 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 35 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 36 | # use stride 1 for the last conv4 layer (same as tf-faster-rcnn) 37 | backbone.layer4[0].conv2.stride = (1, 1) 38 | backbone.layer4[0].downsample[0].stride = (1, 1) 39 | 40 | # # Make batchnorm more sensible 41 | # for submodule in backbone.modules(): 42 | # if isinstance(submodule, torch.nn.BatchNorm2d): 43 | # submodule.momentum = 0.01 44 | 45 | return backbone 46 | 47 | 48 | class SimpleDetector(nn.Module): 49 | def __init__(self, pretrained=True, average_pool=True, semantic=True, final_dim=1024): 50 | """ 51 | :param average_pool: whether or not to average pool the representations 52 | :param pretrained: Whether we need to load from scratch 53 | :param semantic: Whether or not we want to introduce the mask and the class label early on (default Yes) 54 | """ 55 | super(SimpleDetector, self).__init__() 56 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 57 | backbone = _load_resnet_imagenet(pretrained=pretrained) if USE_IMAGENET_PRETRAINED else _load_resnet( 58 | pretrained=pretrained) 59 | 60 | self.backbone = nn.Sequential( 61 | backbone.conv1, 62 | backbone.bn1, 63 | backbone.relu, 64 | backbone.maxpool, 65 | backbone.layer1, 66 | backbone.layer2, 67 | backbone.layer3, 68 | # backbone.layer4 69 | ) 70 | self.roi_align = ROIAlign((7, 7) if USE_IMAGENET_PRETRAINED else (14, 14), 71 | spatial_scale=1 / 16, sampling_ratio=0) 72 | 73 | if semantic: 74 | self.mask_dims = 32 75 | self.object_embed = torch.nn.Embedding(num_embeddings=81, embedding_dim=128) 76 | self.mask_upsample = torch.nn.Conv2d(1, self.mask_dims, kernel_size=3, 77 | stride=2 if USE_IMAGENET_PRETRAINED else 1, 78 | padding=1, bias=True) 79 | else: 80 | self.object_embed = None 81 | self.mask_upsample = None 82 | 83 | after_roi_align = [backbone.layer4] 84 | self.final_dim = final_dim 85 | if average_pool: 86 | after_roi_align += [nn.AvgPool2d(7, stride=1), Flattener()] 87 | 88 | self.after_roi_align = torch.nn.Sequential(*after_roi_align) 89 | 90 | self.obj_downsample = torch.nn.Sequential( 91 | torch.nn.Dropout(p=0.1), 92 | torch.nn.Linear(2048 + (128 if semantic else 0), final_dim), 93 | torch.nn.ReLU(inplace=True), 94 | ) 95 | 
self.regularizing_predictor = torch.nn.Linear(2048, 81) 96 | 97 | def forward(self, 98 | images: torch.Tensor, 99 | boxes: torch.Tensor, 100 | box_mask: torch.LongTensor, 101 | classes: torch.Tensor = None, 102 | segms: torch.Tensor = None, 103 | ): 104 | """ 105 | :param images: [batch_size, 3, im_height, im_width] 106 | :param boxes: [batch_size, max_num_objects, 4] Padded boxes 107 | :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK 108 | :return: object reps [batch_size, max_num_objects, dim] 109 | """ 110 | # [batch_size, 2048, im_height // 32, im_width // 32 111 | img_feats = self.backbone(images) 112 | box_inds = box_mask.nonzero() 113 | assert box_inds.shape[0] > 0 114 | rois = torch.cat(( 115 | box_inds[:, 0, None].type(boxes.dtype), 116 | boxes[box_inds[:, 0], box_inds[:, 1]], 117 | ), 1) 118 | 119 | # Object class and segmentation representations 120 | roi_align_res = self.roi_align(img_feats, rois) 121 | if self.mask_upsample is not None: 122 | assert segms is not None 123 | segms_indexed = segms[box_inds[:, 0], None, box_inds[:, 1]] - 0.5 124 | roi_align_res[:, :self.mask_dims] += self.mask_upsample(segms_indexed) 125 | 126 | 127 | post_roialign = self.after_roi_align(roi_align_res) 128 | 129 | # Add some regularization, encouraging the model to keep giving decent enough predictions 130 | obj_logits = self.regularizing_predictor(post_roialign) 131 | obj_labels = classes[box_inds[:, 0], box_inds[:, 1]] 132 | cnn_regularization = F.cross_entropy(obj_logits, obj_labels, size_average=True)[None] 133 | 134 | feats_to_downsample = post_roialign if self.object_embed is None else torch.cat((post_roialign, self.object_embed(obj_labels)), -1) 135 | roi_aligned_feats = self.obj_downsample(feats_to_downsample) 136 | 137 | # Reshape into a padded sequence - this is expensive and annoying but easier to implement and debug... 138 | obj_reps = pad_sequence(roi_aligned_feats, box_mask.sum(1).tolist()) 139 | return { 140 | 'obj_reps_raw': post_roialign, 141 | 'obj_reps': obj_reps, 142 | 'obj_logits': obj_logits, 143 | 'obj_labels': obj_labels, 144 | 'cnn_regularization_loss': cnn_regularization 145 | } 146 | -------------------------------------------------------------------------------- /utils/newdetector.py: -------------------------------------------------------------------------------- 1 | """ 2 | ok so I lied. 
it's not a detector, it's the resnet backbone 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | from torchvision.models import resnet 9 | 10 | from utils.pytorch_misc import Flattener 11 | from torchvision.layers import ROIAlign 12 | import torch.utils.model_zoo as model_zoo 13 | from config import USE_IMAGENET_PRETRAINED 14 | from utils.pytorch_misc import pad_sequence 15 | from torch.nn import functional as F 16 | 17 | 18 | def _load_resnet(pretrained=True): 19 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 20 | backbone = resnet.resnet50(pretrained=False) 21 | if pretrained: 22 | backbone.load_state_dict(model_zoo.load_url( 23 | 'https://s3.us-west-2.amazonaws.com/ai2-rowanz/resnet50-e13db6895d81.th')) 24 | for i in range(2, 4): 25 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 26 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 27 | return backbone 28 | 29 | 30 | def _load_resnet_imagenet(pretrained=True): 31 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 32 | backbone = resnet.resnet50(pretrained=pretrained) 33 | for i in range(2, 4): 34 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 35 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 36 | # use stride 1 for the last conv4 layer (same as tf-faster-rcnn) 37 | backbone.layer4[0].conv2.stride = (1, 1) 38 | backbone.layer4[0].downsample[0].stride = (1, 1) 39 | 40 | # # Make batchnorm more sensible 41 | # for submodule in backbone.modules(): 42 | # if isinstance(submodule, torch.nn.BatchNorm2d): 43 | # submodule.momentum = 0.01 44 | 45 | return backbone 46 | 47 | 48 | class SimpleDetector(nn.Module): 49 | def __init__(self, pretrained=True, average_pool=True, semantic=True, final_dim=1024): 50 | """ 51 | :param average_pool: whether or not to average pool the representations 52 | :param pretrained: Whether we need to load from scratch 53 | :param semantic: Whether or not we want to introduce the mask and the class label early on (default Yes) 54 | """ 55 | super(SimpleDetector, self).__init__() 56 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 57 | backbone = _load_resnet_imagenet(pretrained=pretrained) if USE_IMAGENET_PRETRAINED else _load_resnet( 58 | pretrained=pretrained) 59 | 60 | self.backbone = nn.Sequential( 61 | backbone.conv1, 62 | backbone.bn1, 63 | backbone.relu, 64 | backbone.maxpool, 65 | backbone.layer1, 66 | backbone.layer2, 67 | backbone.layer3, 68 | # backbone.layer4 69 | ) 70 | self.roi_align = ROIAlign((7, 7) if USE_IMAGENET_PRETRAINED else (14, 14), 71 | spatial_scale=1 / 16, sampling_ratio=0) 72 | 73 | if semantic: 74 | self.mask_dims = 32 75 | self.object_embed = torch.nn.Embedding(num_embeddings=81, embedding_dim=128) 76 | self.mask_upsample = torch.nn.Conv2d(1, self.mask_dims, kernel_size=3, 77 | stride=2 if USE_IMAGENET_PRETRAINED else 1, 78 | padding=1, bias=True) 79 | else: 80 | self.object_embed = None 81 | self.mask_upsample = None 82 | 83 | after_roi_align = [backbone.layer4] 84 | self.final_dim = final_dim 85 | if average_pool: 86 | after_roi_align += [nn.AvgPool2d(7, stride=1), Flattener()] 87 | 88 | self.after_roi_align = torch.nn.Sequential(*after_roi_align) 89 | 90 | self.obj_downsample = torch.nn.Sequential( 91 | torch.nn.Dropout(p=0.1), 92 | torch.nn.Linear(2048 + (128 if semantic else 0), final_dim), 93 | torch.nn.ReLU(inplace=True), 94 | ) 95 | 
self.regularizing_predictor = torch.nn.Linear(2048, 81) 96 | 97 | def forward(self, 98 | images: torch.Tensor, 99 | boxes: torch.Tensor, 100 | box_mask: torch.LongTensor, 101 | classes: torch.Tensor = None, 102 | segms: torch.Tensor = None, 103 | ): 104 | """ 105 | :param images: [batch_size, 3, im_height, im_width] 106 | :param boxes: [batch_size, max_num_objects, 4] Padded boxes 107 | :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK 108 | :return: object reps [batch_size, max_num_objects, dim] 109 | """ 110 | # [batch_size, 2048, im_height // 32, im_width // 32 111 | img_feats = self.backbone(images) 112 | src_feats = img_feats 113 | box_inds = box_mask.nonzero() 114 | assert box_inds.shape[0] > 0 115 | rois = torch.cat(( 116 | box_inds[:, 0, None].type(boxes.dtype), 117 | boxes[box_inds[:, 0], box_inds[:, 1]], 118 | ), 1) 119 | 120 | # Object class and segmentation representations 121 | roi_align_res = self.roi_align(img_feats, rois) 122 | if self.mask_upsample is not None: 123 | assert segms is not None 124 | segms_indexed = segms[box_inds[:, 0], None, box_inds[:, 1]] - 0.5 125 | roi_align_res[:, :self.mask_dims] += self.mask_upsample(segms_indexed) 126 | 127 | 128 | post_roialign = self.after_roi_align(roi_align_res) 129 | 130 | # Add some regularization, encouraging the model to keep giving decent enough predictions 131 | obj_logits = self.regularizing_predictor(post_roialign) 132 | obj_labels = classes[box_inds[:, 0], box_inds[:, 1]] 133 | cnn_regularization = F.cross_entropy(obj_logits, obj_labels, size_average=True)[None] 134 | 135 | feats_to_downsample = post_roialign if self.object_embed is None else torch.cat((post_roialign, self.object_embed(obj_labels)), -1) 136 | roi_aligned_feats = self.obj_downsample(feats_to_downsample) 137 | 138 | # Reshape into a padded sequence - this is expensive and annoying but easier to implement and debug... 
139 | obj_reps = pad_sequence(roi_aligned_feats, box_mask.sum(1).tolist()) 140 | return { 141 | 'obj_reps_raw': post_roialign, 142 | 'obj_reps': obj_reps, 143 | 'obj_logits': obj_logits, 144 | 'obj_labels': obj_labels, 145 | 'cnn_regularization_loss': cnn_regularization 146 | }, src_feats 147 | -------------------------------------------------------------------------------- /utils/pytorch_misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Question relevance model 3 | """ 4 | 5 | # Make stuff 6 | import os 7 | import re 8 | import shutil 9 | import time 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 15 | from allennlp.nn.util import device_mapping 16 | from allennlp.training.trainer import move_optimizer_to_cuda 17 | from torch.nn import DataParallel 18 | 19 | 20 | def time_batch(gen, reset_every=100): 21 | """ 22 | Gets timing info for a batch 23 | :param gen: 24 | :param reset_every: How often we'll reset 25 | :return: 26 | """ 27 | start = time.time() 28 | start_t = 0 29 | for i, item in enumerate(gen): 30 | time_per_batch = (time.time() - start) / (i + 1 - start_t) 31 | yield time_per_batch, item 32 | if i % reset_every == 0: 33 | start = time.time() 34 | start_t = i 35 | 36 | 37 | class Flattener(torch.nn.Module): 38 | def __init__(self): 39 | """ 40 | Flattens last 3 dimensions to make it only batch size, -1 41 | """ 42 | super(Flattener, self).__init__() 43 | 44 | def forward(self, x): 45 | return x.view(x.size(0), -1) 46 | 47 | 48 | def pad_sequence(sequence, lengths): 49 | """ 50 | :param sequence: [\sum b, .....] sequence 51 | :param lengths: [b1, b2, b3...] that sum to \sum b 52 | :return: [len(lengths), maxlen(b), .....] tensor 53 | """ 54 | output = sequence.new_zeros(len(lengths), max(lengths), *sequence.shape[1:]) 55 | start = 0 56 | for i, diff in enumerate(lengths): 57 | if diff > 0: 58 | output[i, :diff] = sequence[start:(start + diff)] 59 | start += diff 60 | return output 61 | 62 | 63 | def extra_leading_dim_in_sequence(f, x, mask): 64 | return f(x.view(-1, *x.shape[2:]), mask.view(-1, mask.shape[2])).view(*x.shape[:3], -1) 65 | 66 | 67 | def clip_grad_norm(named_parameters, max_norm, clip=True, verbose=False): 68 | """Clips gradient norm of an iterable of parameters. 69 | 70 | The norm is computed over all gradients together, as if they were 71 | concatenated into a single vector. Gradients are modified in-place. 72 | 73 | Arguments: 74 | parameters (Iterable[Variable]): an iterable of Variables that will have 75 | gradients normalized 76 | max_norm (float or int): max norm of the gradients 77 | 78 | Returns: 79 | Total norm of the parameters (viewed as a single vector). 80 | """ 81 | max_norm = float(max_norm) 82 | parameters = [(n, p) for n, p in named_parameters if p.grad is not None] 83 | total_norm = 0 84 | param_to_norm = {} 85 | param_to_shape = {} 86 | for n, p in parameters: 87 | param_norm = p.grad.data.norm(2) 88 | total_norm += param_norm ** 2 89 | param_to_norm[n] = param_norm 90 | param_to_shape[n] = tuple(p.size()) 91 | if np.isnan(param_norm.item()): 92 | raise ValueError("the param {} was null.".format(n)) 93 | 94 | total_norm = total_norm ** (1. 
/ 2) 95 | clip_coef = max_norm / (total_norm + 1e-6) 96 | if clip_coef.item() < 1 and clip: 97 | for n, p in parameters: 98 | p.grad.data.mul_(clip_coef) 99 | 100 | if verbose: 101 | print('---Total norm {:.3f} clip coef {:.3f}-----------------'.format(total_norm, clip_coef)) 102 | for name, norm in sorted(param_to_norm.items(), key=lambda x: -x[1]): 103 | print("{:<60s}: {:.3f}, ({}: {})".format(name, norm, np.prod(param_to_shape[name]), param_to_shape[name])) 104 | print('-------------------------------', flush=True) 105 | 106 | return pd.Series({name: norm.item() for name, norm in param_to_norm.items()}) 107 | 108 | 109 | def find_latest_checkpoint(serialization_dir): 110 | """ 111 | Return the location of the latest model and training state files. 112 | If there isn't a valid checkpoint then return None. 113 | """ 114 | have_checkpoint = (serialization_dir is not None and 115 | any("model_state_epoch_" in x for x in os.listdir(serialization_dir))) 116 | 117 | if not have_checkpoint: 118 | return None 119 | 120 | serialization_files = os.listdir(serialization_dir) 121 | model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x] 122 | # Get the last checkpoint file. Epochs are specified as either an 123 | # int (for end of epoch files) or with epoch and timestamp for 124 | # within epoch checkpoints, e.g. 5.2018-02-02-15-33-42 125 | found_epochs = [ 126 | # pylint: disable=anomalous-backslash-in-string 127 | re.search("model_state_epoch_([0-9\.\-]+)\.th", x).group(1) 128 | for x in model_checkpoints 129 | ] 130 | int_epochs = [] 131 | for epoch in found_epochs: 132 | pieces = epoch.split('.') 133 | if len(pieces) == 1: 134 | # Just a single epoch without timestamp 135 | int_epochs.append([int(pieces[0]), 0]) 136 | else: 137 | # has a timestamp 138 | int_epochs.append([int(pieces[0]), pieces[1]]) 139 | last_epoch = sorted(int_epochs, reverse=True)[0] 140 | if last_epoch[1] == 0: 141 | epoch_to_load = str(last_epoch[0]) 142 | else: 143 | epoch_to_load = '{0}.{1}'.format(last_epoch[0], last_epoch[1]) 144 | 145 | model_path = os.path.join(serialization_dir, 146 | "model_state_epoch_{}.th".format(epoch_to_load)) 147 | training_state_path = os.path.join(serialization_dir, 148 | "training_state_epoch_{}.th".format(epoch_to_load)) 149 | return model_path, training_state_path 150 | 151 | 152 | def save_checkpoint(model, optimizer, serialization_dir, epoch, val_metric_per_epoch, is_best=None, 153 | learning_rate_scheduler=None) -> None: 154 | """ 155 | Saves a checkpoint of the model to self._serialization_dir. 156 | Is a no-op if self._serialization_dir is None. 157 | Parameters 158 | ---------- 159 | epoch : Union[int, str], required. 160 | The epoch of training. If the checkpoint is saved in the middle 161 | of an epoch, the parameter is a string with the epoch and timestamp. 162 | is_best: bool, optional (default = None) 163 | A flag which causes the model weights at the given epoch to 164 | be copied to a "best.th" file. The value of this flag should 165 | be based on some validation metric computed by your model. 
166 | """ 167 | if serialization_dir is not None: 168 | model_path = os.path.join(serialization_dir, "model_state_epoch_{}.th".format(epoch)) 169 | model_state = model.module.state_dict() if isinstance(model, DataParallel) else model.state_dict() 170 | torch.save(model_state, model_path) 171 | 172 | training_state = {'epoch': epoch, 173 | 'val_metric_per_epoch': val_metric_per_epoch, 174 | 'optimizer': optimizer.state_dict() 175 | } 176 | if learning_rate_scheduler is not None: 177 | training_state["learning_rate_scheduler"] = \ 178 | learning_rate_scheduler.lr_scheduler.state_dict() 179 | training_path = os.path.join(serialization_dir, 180 | "training_state_epoch_{}.th".format(epoch)) 181 | torch.save(training_state, training_path) 182 | if is_best: 183 | print("Best validation performance so far. Copying weights to '{}/best.th'.".format(serialization_dir)) 184 | shutil.copyfile(model_path, os.path.join(serialization_dir, "best.th")) 185 | 186 | 187 | def restore_best_checkpoint(model, serialization_dir): 188 | fn = os.path.join(serialization_dir, 'best.th') 189 | model_state = torch.load(fn, map_location=device_mapping(-1)) 190 | assert os.path.exists(fn) 191 | if isinstance(model, DataParallel): 192 | model.module.load_state_dict(model_state) 193 | else: 194 | model.load_state_dict(model_state) 195 | 196 | 197 | def restore_checkpoint(model, optimizer, serialization_dir, learning_rate_scheduler=None): 198 | """ 199 | Restores a model from a serialization_dir to the last saved checkpoint. 200 | This includes an epoch count and optimizer state, which is serialized separately 201 | from model parameters. This function should only be used to continue training - 202 | if you wish to load a model for inference/load parts of a model into a new 203 | computation graph, you should use the native Pytorch functions: 204 | `` model.load_state_dict(torch.load("/path/to/model/weights.th"))`` 205 | If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights, 206 | this function will do nothing and return 0. 207 | Returns 208 | ------- 209 | epoch: int 210 | The epoch at which to resume training, which should be one after the epoch 211 | in the saved training state. 212 | """ 213 | latest_checkpoint = find_latest_checkpoint(serialization_dir) 214 | 215 | if latest_checkpoint is None: 216 | # No checkpoint to restore, start at 0 217 | return 0, [] 218 | 219 | model_path, training_state_path = latest_checkpoint 220 | 221 | # Load the parameters onto CPU, then transfer to GPU. 222 | # This avoids potential OOM on GPU for large models that 223 | # load parameters onto GPU then make a new GPU copy into the parameter 224 | # buffer. The GPU transfer happens implicitly in load_state_dict. 225 | model_state = torch.load(model_path, map_location=device_mapping(-1)) 226 | training_state = torch.load(training_state_path, map_location=device_mapping(-1)) 227 | if isinstance(model, DataParallel): 228 | model.module.load_state_dict(model_state) 229 | else: 230 | model.load_state_dict(model_state) 231 | 232 | # idk this is always bad luck for me 233 | optimizer.load_state_dict(training_state["optimizer"]) 234 | 235 | if learning_rate_scheduler is not None and "learning_rate_scheduler" in training_state: 236 | learning_rate_scheduler.lr_scheduler.load_state_dict( 237 | training_state["learning_rate_scheduler"]) 238 | move_optimizer_to_cuda(optimizer) 239 | 240 | # We didn't used to save `validation_metric_per_epoch`, so we can't assume 241 | # that it's part of the trainer state. 
If it's not there, an empty list is all 242 | # we can do. 243 | if "val_metric_per_epoch" not in training_state: 244 | print("trainer state `val_metric_per_epoch` not found, using empty list") 245 | val_metric_per_epoch = [] 246 | else: 247 | val_metric_per_epoch = training_state["val_metric_per_epoch"] 248 | 249 | if isinstance(training_state["epoch"], int): 250 | epoch_to_return = training_state["epoch"] + 1 251 | else: 252 | epoch_to_return = int(training_state["epoch"].split('.')[0]) + 1 253 | return epoch_to_return, val_metric_per_epoch 254 | 255 | 256 | def detokenize(array, vocab): 257 | """ 258 | Given an array of ints, we'll turn this into a string or a list of strings. 259 | :param array: possibly multidimensional numpy array 260 | :return: 261 | """ 262 | if array.ndim > 1: 263 | return [detokenize(x, vocab) for x in array] 264 | tokenized = [vocab.get_token_from_index(v) for v in array] 265 | return ' '.join([x for x in tokenized if x not in (vocab._padding_token, START_SYMBOL, END_SYMBOL)]) 266 | 267 | 268 | def print_para(model): 269 | """ 270 | Prints parameters of a model 271 | :param model: the model whose parameters will be printed 272 | :return: a DataFrame with the shape, size, and requires_grad flag of each parameter 273 | """ 274 | st = {} 275 | total_params = 0 276 | total_params_training = 0 277 | for p_name, p in model.named_parameters(): 278 | # if not ('bias' in p_name.split('.')[-1] or 'bn' in p_name.split('.')[-1]): 279 | st[p_name] = ([str(x) for x in p.size()], np.prod(p.size()), p.requires_grad) 280 | total_params += np.prod(p.size()) 281 | if p.requires_grad: 282 | total_params_training += np.prod(p.size()) 283 | pd.set_option('display.max_columns', None) 284 | shapes_df = pd.DataFrame([(p_name, '[{}]'.format(','.join(size)), prod, p_req_grad) 285 | for p_name, (size, prod, p_req_grad) in sorted(st.items(), key=lambda x: -x[1][1])], 286 | columns=['name', 'shape', 'size', 'requires_grad']).set_index('name') 287 | 288 | print('\n {:.1f}M total parameters. {:.1f}M training \n ----- \n {} \n ----'.format(total_params / 1000000.0, 289 | total_params_training / 1000000.0, 290 | shapes_df.to_string()), 291 | flush=True) 292 | return shapes_df 293 | 294 | 295 | def batch_index_iterator(len_l, batch_size, skip_end=True): 296 | """ 297 | Provides indices that iterate over a list 298 | :param len_l: length of the sequence that we will 299 | iterate over 300 | :param batch_size: size of each batch 301 | :param skip_end: if true, don't iterate over the last batch 302 | :return: A generator that returns (start, end) tuples 303 | as it goes through all batches 304 | """ 305 | iterate_until = len_l 306 | if skip_end: 307 | iterate_until = (len_l // batch_size) * batch_size 308 | 309 | for b_start in range(0, iterate_until, batch_size): 310 | yield (b_start, min(b_start + batch_size, len_l)) 311 | 312 | 313 | def batch_iterator(seq, batch_size, skip_end=True): 314 | for b_start, b_end in batch_index_iterator(len(seq), batch_size, skip_end=skip_end): 315 | yield seq[b_start:b_end] 316 | -------------------------------------------------------------------------------- /utils/smalldetector.py: -------------------------------------------------------------------------------- 1 | """ 2 | ok so I lied.
it's not a detector, it's the resnet backbone 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | from torchvision.models import resnet 9 | 10 | from utils.pytorch_misc import Flattener 11 | from torchvision.layers import ROIAlign 12 | import torch.utils.model_zoo as model_zoo 13 | from config import USE_IMAGENET_PRETRAINED 14 | from utils.pytorch_misc import pad_sequence 15 | from torch.nn import functional as F 16 | 17 | 18 | def _load_resnet(pretrained=True): 19 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 20 | backbone = resnet.resnet50(pretrained=False) 21 | if pretrained: 22 | backbone.load_state_dict(model_zoo.load_url( 23 | 'https://s3.us-west-2.amazonaws.com/ai2-rowanz/resnet50-e13db6895d81.th')) 24 | for i in range(2, 4): 25 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 26 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 27 | return backbone 28 | 29 | 30 | def _load_resnet_imagenet(pretrained=True): 31 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 32 | backbone = resnet.resnet50(pretrained=pretrained) 33 | for i in range(2, 4): 34 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 35 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 36 | # use stride 1 for the last conv4 layer (same as tf-faster-rcnn) 37 | backbone.layer4[0].conv2.stride = (1, 1) 38 | backbone.layer4[0].downsample[0].stride = (1, 1) 39 | 40 | # # Make batchnorm more sensible 41 | # for submodule in backbone.modules(): 42 | # if isinstance(submodule, torch.nn.BatchNorm2d): 43 | # submodule.momentum = 0.01 44 | 45 | return backbone 46 | 47 | 48 | class SimpleDetector(nn.Module): 49 | def __init__(self, pretrained=True, average_pool=True, semantic=True, final_dim=1024): 50 | """ 51 | :param average_pool: whether or not to average pool the representations 52 | :param pretrained: Whether we need to load from scratch 53 | :param semantic: Whether or not we want to introduce the mask and the class label early on (default Yes) 54 | """ 55 | super(SimpleDetector, self).__init__() 56 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 57 | backbone = _load_resnet_imagenet(pretrained=pretrained) if USE_IMAGENET_PRETRAINED else _load_resnet( 58 | pretrained=pretrained) 59 | 60 | self.backbone = nn.Sequential( 61 | backbone.conv1, 62 | backbone.bn1, 63 | backbone.relu, 64 | backbone.maxpool, 65 | backbone.layer1, 66 | backbone.layer2, 67 | backbone.layer3, 68 | # backbone.layer4 69 | ) 70 | 71 | self.newbackbone = nn.Sequential( 72 | backbone.conv1, 73 | backbone.bn1, 74 | backbone.relu, 75 | backbone.maxpool, 76 | backbone.layer1, 77 | backbone.layer2, 78 | backbone.layer3, 79 | backbone.layer4 80 | ) 81 | 82 | self.roi_align = ROIAlign((7, 7) if USE_IMAGENET_PRETRAINED else (14, 14), 83 | spatial_scale=1 / 16, sampling_ratio=0) 84 | 85 | if semantic: 86 | self.mask_dims = 32 87 | self.object_embed = torch.nn.Embedding(num_embeddings=81, embedding_dim=128) 88 | self.mask_upsample = torch.nn.Conv2d(1, self.mask_dims, kernel_size=3, 89 | stride=2 if USE_IMAGENET_PRETRAINED else 1, 90 | padding=1, bias=True) 91 | else: 92 | self.object_embed = None 93 | self.mask_upsample = None 94 | 95 | after_roi_align = [backbone.layer4] 96 | self.final_dim = final_dim 97 | if average_pool: 98 | after_roi_align += [nn.AvgPool2d(7, stride=1), Flattener()] 99 | 100 | self.after_roi_align = 
torch.nn.Sequential(*after_roi_align) 101 | 102 | self.obj_downsample = torch.nn.Sequential( 103 | torch.nn.Dropout(p=0.1), 104 | torch.nn.Linear(2048 + (128 if semantic else 0), final_dim), 105 | torch.nn.ReLU(inplace=True), 106 | ) 107 | self.regularizing_predictor = torch.nn.Linear(2048, 81) 108 | 109 | def forward(self, 110 | images: torch.Tensor, 111 | boxes: torch.Tensor, 112 | box_mask: torch.LongTensor, 113 | classes: torch.Tensor = None, 114 | segms: torch.Tensor = None, 115 | ): 116 | """ 117 | :param images: [batch_size, 3, im_height, im_width] 118 | :param boxes: [batch_size, max_num_objects, 4] Padded boxes 119 | :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK 120 | :return: object reps [batch_size, max_num_objects, dim] 121 | """ 122 | # [batch_size, 2048, im_height // 32, im_width // 32 123 | img_feats = self.backbone(images) 124 | src_feats = self.newbackbone(images) 125 | box_inds = box_mask.nonzero() 126 | assert box_inds.shape[0] > 0 127 | rois = torch.cat(( 128 | box_inds[:, 0, None].type(boxes.dtype), 129 | boxes[box_inds[:, 0], box_inds[:, 1]], 130 | ), 1) 131 | 132 | # Object class and segmentation representations 133 | roi_align_res = self.roi_align(img_feats, rois) 134 | if self.mask_upsample is not None: 135 | assert segms is not None 136 | segms_indexed = segms[box_inds[:, 0], None, box_inds[:, 1]] - 0.5 137 | roi_align_res[:, :self.mask_dims] += self.mask_upsample(segms_indexed) 138 | 139 | 140 | post_roialign = self.after_roi_align(roi_align_res) 141 | 142 | # Add some regularization, encouraging the model to keep giving decent enough predictions 143 | obj_logits = self.regularizing_predictor(post_roialign) 144 | obj_labels = classes[box_inds[:, 0], box_inds[:, 1]] 145 | cnn_regularization = F.cross_entropy(obj_logits, obj_labels, size_average=True)[None] 146 | 147 | feats_to_downsample = post_roialign if self.object_embed is None else torch.cat((post_roialign, self.object_embed(obj_labels)), -1) 148 | roi_aligned_feats = self.obj_downsample(feats_to_downsample) 149 | 150 | # Reshape into a padded sequence - this is expensive and annoying but easier to implement and debug... 151 | obj_reps = pad_sequence(roi_aligned_feats, box_mask.sum(1).tolist()) 152 | return { 153 | 'obj_reps_raw': post_roialign, 154 | 'obj_reps': obj_reps, 155 | 'obj_logits': obj_logits, 156 | 'obj_labels': obj_labels, 157 | 'cnn_regularization_loss': cnn_regularization 158 | }, src_feats 159 | -------------------------------------------------------------------------------- /utils/testdetector.py: -------------------------------------------------------------------------------- 1 | """ 2 | ok so I lied. 
it's not a detector, it's the resnet backbone 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | from torchvision.models import resnet 9 | 10 | from utils.pytorch_misc import Flattener 11 | from torchvision.layers import ROIAlign 12 | import torch.utils.model_zoo as model_zoo 13 | from config import USE_IMAGENET_PRETRAINED 14 | from utils.pytorch_misc import pad_sequence 15 | from torch.nn import functional as F 16 | 17 | 18 | def _load_resnet(pretrained=True): 19 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 20 | backbone = resnet.resnet50(pretrained=False) 21 | if pretrained: 22 | backbone.load_state_dict(model_zoo.load_url( 23 | 'https://s3.us-west-2.amazonaws.com/ai2-rowanz/resnet50-e13db6895d81.th')) 24 | for i in range(2, 4): 25 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 26 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 27 | return backbone 28 | 29 | 30 | def _load_resnet_imagenet(pretrained=True): 31 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 32 | backbone = resnet.resnet50(pretrained=pretrained) 33 | for i in range(2, 4): 34 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 35 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 36 | # use stride 1 for the last conv4 layer (same as tf-faster-rcnn) 37 | backbone.layer4[0].conv2.stride = (1, 1) 38 | backbone.layer4[0].downsample[0].stride = (1, 1) 39 | 40 | # # Make batchnorm more sensible 41 | # for submodule in backbone.modules(): 42 | # if isinstance(submodule, torch.nn.BatchNorm2d): 43 | # submodule.momentum = 0.01 44 | 45 | return backbone 46 | 47 | 48 | class SimpleDetector(nn.Module): 49 | def __init__(self, pretrained=True, average_pool=True, semantic=True, final_dim=1024): 50 | """ 51 | :param average_pool: whether or not to average pool the representations 52 | :param pretrained: Whether we need to load from scratch 53 | :param semantic: Whether or not we want to introduce the mask and the class label early on (default Yes) 54 | """ 55 | super(SimpleDetector, self).__init__() 56 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 57 | backbone = _load_resnet_imagenet(pretrained=pretrained) if USE_IMAGENET_PRETRAINED else _load_resnet( 58 | pretrained=pretrained) 59 | 60 | self.backbone = nn.Sequential( 61 | backbone.conv1, 62 | backbone.bn1, 63 | backbone.relu, 64 | backbone.maxpool, 65 | backbone.layer1, 66 | backbone.layer2, 67 | backbone.layer3, 68 | # backbone.layer4 69 | ) 70 | 71 | self.backbone1 = nn.Sequential( 72 | backbone.conv1, 73 | backbone.bn1, 74 | backbone.relu, 75 | backbone.maxpool, 76 | backbone.layer1, 77 | backbone.layer2, 78 | backbone.layer3, 79 | backbone.layer4 80 | ) 81 | 82 | self.roi_align = ROIAlign((7, 7) if USE_IMAGENET_PRETRAINED else (14, 14), 83 | spatial_scale=1 / 16, sampling_ratio=0) 84 | 85 | if semantic: 86 | self.mask_dims = 32 87 | self.object_embed = torch.nn.Embedding(num_embeddings=81, embedding_dim=128) 88 | self.mask_upsample = torch.nn.Conv2d(1, self.mask_dims, kernel_size=3, 89 | stride=2 if USE_IMAGENET_PRETRAINED else 1, 90 | padding=1, bias=True) 91 | else: 92 | self.object_embed = None 93 | self.mask_upsample = None 94 | 95 | after_roi_align = [backbone.layer4] 96 | self.final_dim = final_dim 97 | if average_pool: 98 | after_roi_align += [nn.AvgPool2d(7, stride=1), Flattener()] 99 | 100 | self.after_roi_align = 
torch.nn.Sequential(*after_roi_align) 101 | 102 | self.obj_downsample = torch.nn.Sequential( 103 | torch.nn.Dropout(p=0.1), 104 | torch.nn.Linear(2048 + (128 if semantic else 0), final_dim), 105 | torch.nn.ReLU(inplace=True), 106 | ) 107 | self.regularizing_predictor = torch.nn.Linear(2048, 81) 108 | 109 | def forward(self, 110 | images: torch.Tensor, 111 | boxes: torch.Tensor, 112 | box_mask: torch.LongTensor, 113 | classes: torch.Tensor = None, 114 | segms: torch.Tensor = None, 115 | ): 116 | """ 117 | :param images: [batch_size, 3, im_height, im_width] 118 | :param boxes: [batch_size, max_num_objects, 4] Padded boxes 119 | :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK 120 | :return: object reps [batch_size, max_num_objects, dim] 121 | """ 122 | # [batch_size, 2048, im_height // 32, im_width // 32 123 | img_feats = self.backbone(images) 124 | img_feats1 = self.backbone1(images) 125 | src_feats = img_feats 126 | src_feats1 = img_feats1 127 | box_inds = box_mask.nonzero() 128 | assert box_inds.shape[0] > 0 129 | rois = torch.cat(( 130 | box_inds[:, 0, None].type(boxes.dtype), 131 | boxes[box_inds[:, 0], box_inds[:, 1]], 132 | ), 1) 133 | 134 | # Object class and segmentation representations 135 | roi_align_res = self.roi_align(img_feats, rois) 136 | if self.mask_upsample is not None: 137 | assert segms is not None 138 | segms_indexed = segms[box_inds[:, 0], None, box_inds[:, 1]] - 0.5 139 | roi_align_res[:, :self.mask_dims] += self.mask_upsample(segms_indexed) 140 | 141 | 142 | post_roialign = self.after_roi_align(roi_align_res) 143 | 144 | # Add some regularization, encouraging the model to keep giving decent enough predictions 145 | obj_logits = self.regularizing_predictor(post_roialign) 146 | obj_labels = classes[box_inds[:, 0], box_inds[:, 1]] 147 | cnn_regularization = F.cross_entropy(obj_logits, obj_labels, size_average=True)[None] 148 | 149 | feats_to_downsample = post_roialign if self.object_embed is None else torch.cat((post_roialign, self.object_embed(obj_labels)), -1) 150 | roi_aligned_feats = self.obj_downsample(feats_to_downsample) 151 | 152 | # Reshape into a padded sequence - this is expensive and annoying but easier to implement and debug... 153 | obj_reps = pad_sequence(roi_aligned_feats, box_mask.sum(1).tolist()) 154 | return { 155 | 'obj_reps_raw': post_roialign, 156 | 'obj_reps': obj_reps, 157 | 'obj_logits': obj_logits, 158 | 'obj_labels': obj_labels, 159 | 'cnn_regularization_loss': cnn_regularization 160 | }, src_feats, src_feats1 161 | --------------------------------------------------------------------------------
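
A note on the scene-graph step in train.py: after GraphVLAD produces the cluster features, the model builds a fully connected graph over the clusters, normalizes the off-diagonal similarities with a masked softmax, and adds the identity back so each cluster also keeps its own feature before the message-passing matmul. The sketch below reproduces that pattern in plain PyTorch; `masked_softmax` here is a simplified stand-in for the AllenNLP utility the training script actually imports, and the batch size, `num_cluster`, and feature width are made-up illustration values.

```python
# Plain-PyTorch sketch of the "fully connected graph over clusters, masked
# softmax over off-diagonal similarities, identity added back" pattern used
# for scene_graph in train.py. masked_softmax below is a simplified stand-in
# for the AllenNLP utility the script imports; shapes are illustration values.
import torch

def masked_softmax(scores, mask, dim=1):
    # Entries where mask == 0 get a large negative score, i.e. roughly zero weight.
    return torch.softmax(scores.masked_fill(mask == 0, -1e9), dim=dim)

batch, num_cluster, feat_dim = 2, 4, 8
scene = torch.randn(batch, num_cluster, feat_dim)          # GraphVLAD cluster features

diag = torch.eye(num_cluster)
scene_graph = torch.matmul(scene, scene.transpose(1, 2))   # pairwise cluster similarities
diag_zero = torch.ones(num_cluster, num_cluster) - diag    # mask out self-connections
scene_graph = masked_softmax(scene_graph, diag_zero[None], dim=1) + diag[None]

scene_value = torch.matmul(scene_graph, scene)             # one message-passing step
print(scene_value.shape)                                   # torch.Size([2, 4, 8])
```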
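
The `pad_sequence` helper in utils/pytorch_misc.py is what turns the flat, ROI-aligned feature tensor back into a per-image padded tensor inside `SimpleDetector.forward`. Below is a minimal sketch of its behaviour, assuming the repository root is on `PYTHONPATH` and the pinned dependencies are installed (the module imports AllenNLP at import time); the feature width and lengths are invented for the example.

```python
# Minimal sketch of pad_sequence from utils/pytorch_misc.py.
# Assumes the repo root is importable and the pinned AllenNLP version is
# installed, since utils.pytorch_misc imports it at module level.
import torch
from utils.pytorch_misc import pad_sequence

# Three images with 2, 3, and 1 detected objects, each a 1024-d ROI feature
# (the same shape obj_downsample produces in SimpleDetector.forward).
feats = torch.randn(6, 1024)   # [sum(lengths), dim]
lengths = [2, 3, 1]            # corresponds to box_mask.sum(1).tolist()

obj_reps = pad_sequence(feats, lengths)
print(obj_reps.shape)          # torch.Size([3, 3, 1024]) -> [batch, max_objects, dim]

# Rows beyond each image's true object count stay zero, mirroring box_mask padding.
assert torch.all(obj_reps[0, 2:] == 0)
assert torch.all(obj_reps[2, 1:] == 0)
```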
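
Likewise, `clip_grad_norm` in utils/pytorch_misc.py computes the global gradient norm, rescales the gradients in place when that norm exceeds `max_norm`, and returns a pandas Series of per-parameter norms, which is what the training loop accumulates into `norms`. A small illustrative use under the same import assumptions; the toy model and `max_norm` value are made up.

```python
# Illustrative use of clip_grad_norm from utils/pytorch_misc.py.
# The tiny model and max_norm value below are made up for the example.
import torch
from utils.pytorch_misc import clip_grad_norm

model = torch.nn.Linear(8, 4)
loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()

# Returns a pd.Series mapping parameter name -> gradient norm; when the total
# norm exceeds max_norm, every gradient is rescaled in place by the same factor.
norms = clip_grad_norm(model.named_parameters(), max_norm=1.0, clip=True, verbose=False)
print(norms)
```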