├── .gitignore
├── DATA_PREP_README.md
├── DATA_README.md
├── LICENSE
├── README.md
├── code
│   ├── anchors.py
│   ├── dat_loader.py
│   ├── eval_script.py
│   ├── evaluator.py
│   ├── extended_config.py
│   ├── fpn_resnet.py
│   ├── loss.py
│   ├── main.py
│   ├── main_dist.py
│   ├── mdl.py
│   ├── ssd_vgg.py
│   └── utils.py
├── conda_env_zsg.yml
├── configs
│   ├── cfg.json
│   ├── cfg.yaml
│   └── ds_info.json
└── data
    ├── download_ann.sh
    ├── ds_prep_config.json
    ├── ds_prep_utils.py
    ├── flatten_train.py
    ├── prepare_c01_flickr_splits.py
    ├── prepare_flickr30k.py
    └── prepare_referit.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__/
2 | data/flickr30k
3 | data/referit
4 | data/vg_split
5 | data/flickr30k_c0
6 | data/flickr30k_c1
7 | tmp/
8 | weights/
--------------------------------------------------------------------------------
/DATA_PREP_README.md:
--------------------------------------------------------------------------------
1 | # Dataset Preparation
2 |
3 | Note that the following steps are required only if you want to prepare the annotations from the parent repositories. If you just want to run the model with the provided annotations, see [DATA_README.md](./DATA_README.md).
4 |
5 | Here I have outlined the steps to prepare the following datasets:
6 | - Flickr30k Entities
7 | - ReferIt
8 | - Unseen splits
9 |
10 | We convert the annotations for each dataset into a `.csv` file with the format:
11 | img_id, bbox, queries
12 |
13 | The project directory is $ROOT
14 |
15 | ## Flickr30k Entities
16 | Current directory is located at $FLICKR=/some_path/flickr30k
17 | 1. To get the Flickr30k images you need to fill out a form; the instructions can be found at http://shannon.cs.illinois.edu/DenotationGraph/. Un-tar the file and save it under $FLICKR/flickr30k_images
18 | 1. `git clone https://github.com/BryanPlummer/flickr30k_entities.git`. Unzip the annotations. At this point the directory should look like:
19 | ```
20 | $FLICKR
21 | |-- flickr30k_entities
22 | |-- Annotations
23 | |-- Sentences
24 | |-- test.txt
25 | |-- train.txt
26 | |-- val.txt
27 | |-- flickr30k_images
28 | |-- results.json
29 | ```
30 | 1. Make a symbolic link to $FLICKR here using `ln -s $FLICKR $ROOT/data/flickr30k`
31 | 1. Now we prepare the Flickr30k Entities dataset using:
32 | ```
33 | cd $ROOT
34 | python data/prepare_flickr30k.py
35 | ```
36 | 1. The above code does the following:
37 | + Converts the `.xml` annotations into a single `.json` file (it is easier to deal with dictionaries, and the annotations need to be read only once). It is saved in $FLICKR/all_ann.json
38 | + Creates train/val/test `.csv` files under $FLICKR/csvs/flickr_normal/{train/val/test}.csv
39 | 1. At this point the directory structure should look like this:
40 | ```
41 | $FLICKR
42 | |-- all_ann_2.json
43 | |-- all_annot_new.json
44 | |-- csv_dir
45 | |-- train.csv
46 | |-- val.csv
47 | |-- test.csv
48 | |-- flickr30k_entities
49 | |-- Annotations
50 | |-- Sentences
51 | |-- test.txt
52 | |-- train.txt
53 | |-- val.txt
54 | |-- flickr30k_images
55 | ```
56 |
57 | ## ReferIt (Refclef)
58 | Current directory is located at $REF=/some_path/referit
59 | 1. Follow the download links at https://github.com/lichengunc/refer to set up ReferIt (refclef). After downloading the images (the ImageClef subset) and the annotations, your folder structure should look like this:
60 | ```
61 | $REF
62 | |-- images
63 | |-- saiapr_tc12_images
64 | |-- refclef
65 | |-- instances.json
66 | |-- refs(berkeley).p
67 | |-- refs(unc).p
68 | ```
69 | 1. We use only the `berkeley` split to stay consistent with previous work.
70 | 1. Now we again convert the annotations to csv format. First create a symbolic link, then run `prepare_referit.py`:
71 | ```
72 | cd $ROOT/data
73 | ln -s $REF referit
74 | cd $ROOT
75 | python data/prepare_referit.py
76 | ```
77 |
78 | The final structure looks like:
79 |
80 | ```
81 | $REF
82 | |-- images
83 | |-- saiapr_tc12_images
84 | |-- refclef
85 | |-- instances.json
86 | |-- refs(berkeley).p
87 | |-- refs(unc).p
88 | |-- csv_dir
89 | |-- train.csv
90 | |-- val.csv
91 | |-- test.csv
92 | ```
93 |
94 | ## Unseen Splits
95 | (Coming soon!)
96 |
--------------------------------------------------------------------------------
/DATA_README.md:
--------------------------------------------------------------------------------
1 | # Dataset Loading
2 |
3 | Note that the following steps use annotations from the parent datasets converted into a fixed format (outlined below). For the steps to reproduce the annotations, see [DATA_PREP_README.md](./DATA_PREP_README.md).
4 |
5 | The project directory is $ROOT
6 |
7 | ## Setup the directories
8 | ```
9 | cd $ROOT/data
10 | bash download_ann.sh
11 | ```
12 |
13 | If you run into errors downloading the annotations, please use this drive link: https://drive.google.com/open?id=1oZ5llnA4btD9LSmnSB0GaZtTogskLwCe (a quick sanity check for the downloaded annotations is given at the end of this file).
14 |
15 | ## Image Download
16 |
17 | ### Flickr30k Entities
18 | Current directory is located at $FLICKR=/some_path/flickr30k
19 |
20 | 1. To get the Flickr30k images you need to fill out a form; the instructions can be found at http://shannon.cs.illinois.edu/DenotationGraph/. Un-tar the file and save it under $FLICKR/flickr30k_images
21 | 1. Make a symbolic link to the images using `ln -s $FLICKR/flickr30k_images $ROOT/data/flickr30k/flickr30k_images`
22 |
23 | ### ReferIt
24 | Current directory is located at $REF=/some_path/referit
25 |
26 | 1. Download the ImageClef subset for ReferIt from https://github.com/lichengunc/refer. Download link: http://bvisionweb1.cs.unc.edu/licheng/referit/data/images/saiapr_tc-12.zip
27 | 1. Unzip to $REF/images/saiapr_tc12_images
28 | 1. Make a symbolic link to $REF here using `ln -s $REF/images/saiapr_tc12_images $ROOT/data/referit/images/`
29 |
30 | ### Visual Genome
31 | Current directory is located at $VG=/some_path/visual_genome
32 |
33 | 1. See the download page for Visual Genome (https://visualgenome.org/api/v0/api_home.html). Download the two image files to $VG/VG_100K and $VG/VG_100K_2
34 | 1. Make a symbolic link to $VG using
35 | ```
36 | ln -s $VG/VG_100K $ROOT/data/visual_genome/
37 | ln -s $VG/VG_100K_2 $ROOT/data/visual_genome/
38 | ```
39 |
40 | The remaining annotations already exist, so everything should work out of the box.
41 |
42 | TODO:
43 | - [ ] Create a script to automate the above steps given the root directory (Flickr30k still needs to be done manually).
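
As a quick sanity check once `download_ann.sh` has run and the symbolic links are in place, you can load one of the prepared CSVs and confirm it has the `img_id`/`bbox`/`query` columns that `code/dat_loader.py` reads. This is a minimal sketch and assumes the ReferIt annotations ended up under `data/referit/csv_dir`:

```
import ast
import pandas as pd

df = pd.read_csv('data/referit/csv_dir/val.csv')
print(df.columns.tolist())                      # expect at least: img_id, bbox, query
row = df.iloc[0]
x1, y1, x2, y2 = ast.literal_eval(row['bbox'])  # bbox is stored as the string "[x1, y1, x2, y2]"
print(row['img_id'], (x1, y1, x2, y2), row['query'])
```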
44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Arka Sadhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zsgnet-pytorch 2 | This is the official repository for ICCV19 oral paper [Zero-Shot Grounding of Objects from Natural Language Queries](https://arxiv.org/abs/1908.07129). It contains the code and the datasets to reproduce the numbers for our model ZSGNet in the paper. 3 | 4 | The code has been refactored from the original implementation and now supports Distributed learning (see [pytorch docs](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel)) for significantly faster training (around 4x speedup from pytorch [Dataparallel](https://pytorch.org/docs/stable/nn.html#dataparallel)) 5 | 6 | The code is fairly easy to use and extendable for future work. Feel free to open an issue in case of queries. 7 | 8 | ## Install 9 | Requirements: 10 | - python>=3.6 11 | - pytorch>=1.1 12 | 13 | To use the same environment you can use `conda` and the environment file `conda_env_zsg.yml` file provided. Please refer to [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for details on installing `conda`. 14 | 15 | ``` 16 | MINICONDA_ROOT=[to your Miniconda/Anaconda root directory] 17 | conda env create -f conda_env_zsg.yml --prefix $MINICONDA_ROOT/envs/zsg_pyt 18 | conda activate zsg_pyt 19 | ``` 20 | 21 | ## Data Preparation 22 | Look at [DATA_README.md](./DATA_README.md) for quick start and [DATA_PREP_README.md](./DATA_PREP_README.md) for obtaining annotations from the parent datasets. 23 | 24 | ## Training 25 | Basic usage is `python code/main_dist.py "experiment_name" --arg1=val1 --arg2=val2` and the arg1, arg2 can be found in `configs/cfg.yaml`. This trains using the DataParallel mode. 26 | 27 | For distributed learning use `python -m torch.distributed.launch --nproc_per_node=$ngpus code/main_dist.py` instead. This trains using the DistributedDataParallel mode. 
(Also see the [caveat in using distributed training](#caveats-in-distributeddataparallel) below.)
28 |
29 | An example of training on the ReferIt dataset (note that you must have prepared the ReferIt dataset first) is:
30 |
31 | ```
32 | python code/main_dist.py "referit_try" --ds_to_use='refclef' --bs=16 --nw=4
33 | ```
34 |
35 | Similarly, for distributed training (set $ngpus to the number of GPUs):
36 | ```
37 | python -m torch.distributed.launch --nproc_per_node=$ngpus code/main_dist.py "referit_try" --ds_to_use='refclef' --bs=16 --nw=4
38 | ```
39 |
40 | ### Logging
41 | Logs are stored inside the `tmp/` directory. When you run the code with $exp_name, the following are stored:
42 | - `txt_logs/$exp_name.txt`: the config used and the training/validation losses after every epoch.
43 | - `models/$exp_name.pth`: the model, optimizer, scheduler, accuracy, and number of epochs and iterations completed. Only the best model up to the current epoch is stored.
44 | - `ext_logs/$exp_name.txt`: uses Python's `logging` module to store the printed `logger.debug` outputs. Mainly used for debugging.
45 | - `tb_logs/$exp_name`: still work in progress; right now it only creates a directory. Tensorboard support is planned.
46 | - `predictions`: the validation outputs of the current best model.
47 |
48 | ## Evaluation
49 | There are two ways to evaluate.
50 |
51 | 1. Validation metrics are already computed in the training loop. If you just want to evaluate on the validation or test set with a previously trained model ($exp_name), you can run:
52 | ```
53 | python code/main_dist.py $exp_name --ds_to_use='refclef' --resume=True --only_val=True --only_test=True
54 | ```
55 | or you can use a different experiment name and pass the `--resume_path` argument:
56 | ```
57 | python code/main_dist.py $exp_name --ds_to_use='refclef' --resume=True --resume_path='./tmp/models/referit_try.pth'
58 | ```
59 | After this, the logs will be available inside `tmp/txt_logs/$exp_name.txt`
60 |
61 | 2. If you have some other model, you can save its predictions in the following structure to a pickle file, say `predictions.pkl`:
62 | ```
63 | [
64 | {'id': annotation_id,
65 | 'pred_boxes': [x1,y1,x2,y2]},
66 | .
67 | .
68 | .
69 | ]
70 | ```
71 |
72 | Then you can evaluate with `code/eval_script.py`:
73 | ```
74 | python code/eval_script.py predictions_file gt_file
75 | ```
76 | For ReferIt this would be:
77 | ```
78 | python code/eval_script.py ./tmp/predictions/$exp_name/val_preds_$exp_name.pkl ./data/referit/csv_dir/val.csv
79 | ```
80 |
81 | ### Caveats in DistributedDataParallel
82 | When training with DDP there is no easy way to gather all the validation outputs into one process (that only works for tensors). Instead, one has to save the predictions of each process and then read them back and combine them in the main process. For simplicity the current implementation doesn't do this, so the validation results reported during training are slightly different from the actual results.
83 |
84 | To get the correct results, follow the steps in [Evaluation](#evaluation) as is (the key point is to **NOT** use `torch.distributed.launch` for evaluation). In other words, evaluating with plain DataParallel gives the correct numbers.
85 |
86 |
87 | ## Pre-trained Models
88 | The pre-trained models are available on [Google Drive](https://drive.google.com/file/d/1tFGm87vdbQUEX4PNfmgs-UQzthMO849t/view?usp=sharing)
89 |
90 | ## ToDo
91 | - [ ] Add colab demo.
92 | - [ ] Add installation guide.
93 | - [x] Pretrained models 94 | - [ ] Add hubconfig 95 | - [ ] Add tensorboard 96 | 97 | # Acknowledgements 98 | We thank: 99 | 1. [@yhenon](https://github.com/yhenon) for their repository on retina-net (https://github.com/yhenon/pytorch-retinanet). 100 | 1. [@amdegroot](https://github.com/amdegroot) for their repsository on ssd using vgg (https://github.com/amdegroot/ssd.pytorch) 101 | 1. [fastai](https://github.com/fastai/fastai) repository for helpful logging, anchor box generation and convolution functions. 102 | 1. [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark/) repository for many of the distributed utils and implementation of non-maxima suppression. 103 | 104 | # Citation 105 | 106 | If you find the code or dataset useful, please cite us: 107 | 108 | ``` 109 | @InProceedings{Sadhu_2019_ICCV, 110 | author = {Sadhu, Arka and Chen, Kan and Nevatia, Ram}, 111 | title = {Zero-Shot Grounding of Objects From Natural Language Queries}, 112 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 113 | month = {October}, 114 | year = {2019} 115 | } 116 | ``` 117 | 118 | -------------------------------------------------------------------------------- /code/anchors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates anchor based on the backbone 3 | Based on code from https://github.com/fastai/fastai_docs/blob/master/dev_nb/102a_coco.ipynb 4 | Author: Arka Sadhu 5 | """ 6 | import torch 7 | import numpy as np 8 | from torch import nn 9 | 10 | 11 | def cthw2tlbr(boxes): 12 | "Convert center/size format `boxes` to top/left bottom/right corners." 13 | top_left = boxes[..., :2] - boxes[..., 2:]/2 14 | bot_right = boxes[..., :2] + boxes[..., 2:]/2 15 | return torch.cat([top_left, bot_right], dim=-1) 16 | 17 | 18 | def tlbr2cthw(boxes): 19 | "Convert top/left bottom/right format `boxes` to center/size corners." 20 | center = (boxes[..., :2] + boxes[..., 2:])/2 21 | sizes = boxes[..., 2:] - boxes[..., :2] 22 | return torch.cat([center, sizes], dim=-1) 23 | 24 | 25 | def tlbr2tlhw(boxes): 26 | "Convert tl br format `boxes` to tl hw format" 27 | top_left = boxes[:, :2] 28 | height_width = boxes[:, 2:] - boxes[:, :2] 29 | return torch.cat([top_left, height_width], 1) 30 | 31 | 32 | def tlhw2tlbr(boxes): 33 | "Convert tl br format `boxes` to tl hw format" 34 | top_left = boxes[..., :2] 35 | bottom_right = boxes[..., 2:] + boxes[..., :2] 36 | return torch.cat([top_left, bottom_right], -1) 37 | 38 | 39 | def x1y1x2y2_to_y1x1y2x2(boxes): 40 | "Convert xy boxes to yx boxes and vice versa" 41 | box_tmp = boxes.clone() 42 | box_tmp[..., 0], box_tmp[..., 1] = boxes[..., 1], boxes[..., 0] 43 | box_tmp[..., 2], box_tmp[..., 3] = boxes[..., 3], boxes[..., 2] 44 | return box_tmp 45 | 46 | 47 | def create_grid(size, flatten=True): 48 | "Create a grid of a given `size`." 
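    # Note: the grid points generated below are cell centers in a normalized
    # [-1, 1] coordinate system, matching the box convention used elsewhere
    # in the repo (the data loader also rescales targets to [-1, 1]).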
49 | if isinstance(size, tuple): 50 | H, W = size 51 | else: 52 | H, W = size, size 53 | 54 | grid = torch.FloatTensor(H, W, 2) 55 | linear_points = torch.linspace(-1+1/W, 1-1/W, 56 | W) if W > 1 else torch.tensor([0.]) 57 | grid[:, :, 1] = torch.ger(torch.ones( 58 | H), linear_points).expand_as(grid[:, :, 0]) 59 | linear_points = torch.linspace(-1+1/H, 1-1/H, 60 | H) if H > 1 else torch.tensor([0.]) 61 | grid[:, :, 0] = torch.ger( 62 | linear_points, torch.ones(W)).expand_as(grid[:, :, 1]) 63 | return grid.view(-1, 2) if flatten else grid 64 | 65 | 66 | def create_anchors(sizes, ratios, scales, flatten=True, device=torch.device('cuda')): 67 | "Create anchor of `sizes`, `ratios` and `scales`." 68 | # device = torch.device('cuda') 69 | aspects = [[[s*np.sqrt(r), s*np.sqrt(1/r)] 70 | for s in scales] for r in ratios] 71 | aspects = torch.tensor(aspects).to(device).view(-1, 2) 72 | anchors = [] 73 | for h, w in sizes: 74 | if type(h) == torch.Tensor: 75 | h = int(h.item()) 76 | w = int(w.item()) 77 | 78 | sized_aspects = ( 79 | aspects * torch.tensor([2/h, 2/w]).to(device)).unsqueeze(0) 80 | base_grid = create_grid((h, w)).to(device).unsqueeze(1) 81 | n, a = base_grid.size(0), aspects.size(0) 82 | ancs = torch.cat([base_grid.expand(n, a, 2), 83 | sized_aspects.expand(n, a, 2)], 2) 84 | anchors.append(ancs.view(h, w, a, 4)) 85 | anchs = torch.cat([anc.view(-1, 4) 86 | for anc in anchors], 0) if flatten else anchors 87 | return cthw2tlbr(anchs) if flatten else anchors 88 | 89 | 90 | def intersection(anchors, targets): 91 | """ 92 | Compute the sizes of the intersections of `anchors` by `targets`. 93 | Assume both anchors and targets are in tl br format 94 | """ 95 | ancs, tgts = anchors, targets 96 | a, t = ancs.size(0), tgts.size(0) 97 | ancs, tgts = ancs.unsqueeze(1).expand( 98 | a, t, 4), tgts.unsqueeze(0).expand(a, t, 4) 99 | top_left_i = torch.max(ancs[..., :2], tgts[..., :2]) 100 | bot_right_i = torch.min(ancs[..., 2:], tgts[..., 2:]) 101 | 102 | sizes = torch.clamp(bot_right_i - top_left_i, min=0) 103 | return sizes[..., 0] * sizes[..., 1] 104 | 105 | 106 | def IoU_values(anchors, targets): 107 | """ 108 | Compute the IoU values of `anchors` by `targets`. 109 | Expects both in tlbr format 110 | """ 111 | inter = intersection(anchors, targets) 112 | ancs, tgts = tlbr2cthw(anchors), tlbr2cthw(targets) 113 | anc_sz, tgt_sz = ancs[:, 2] * \ 114 | ancs[:, 3], tgts[:, 2] * tgts[:, 3] 115 | union = anc_sz.unsqueeze(1) + tgt_sz.unsqueeze(0) - inter 116 | return inter/(union+1e-8) 117 | 118 | 119 | def simple_iou(box1, box2): 120 | """ 121 | Simple iou between box1 and box2 122 | """ 123 | def simple_inter(ancs, tgts): 124 | top_left_i = torch.max(ancs[..., :2], tgts[..., :2]) 125 | bot_right_i = torch.min(ancs[..., 2:], tgts[..., 2:]) 126 | sizes = torch.clamp(bot_right_i - top_left_i, min=0) 127 | return sizes[..., 0] * sizes[..., 1] 128 | 129 | inter = intersection(box1, box2) 130 | ancs, tgts = tlbr2tlhw(box1), tlbr2tlhw(box2) 131 | anc_sz, tgt_sz = ancs[:, 2] * \ 132 | ancs[:, 3], tgts[:, 2] * tgts[:, 3] 133 | union = anc_sz + tgt_sz - inter 134 | return inter / (union + 1e-8) 135 | 136 | 137 | def match_anchors(anchors, targets, match_thr=0.5, bkg_thr=0.4): 138 | """ 139 | Match `anchors` to targets. -1 is match to background, -2 is ignore. 
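    Anchors whose best IoU is above `match_thr` receive the index of that
    target, anchors whose best IoU is below `bkg_thr` are marked background
    (-1), and a second pass assigns each target to the single anchor it
    overlaps most, so every target keeps at least one anchor.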
140 | """ 141 | ious = IoU_values(anchors, targets) 142 | matches = anchors.new(anchors.size(0)).zero_().long() - 2 143 | vals, idxs = torch.max(ious, 1) 144 | matches[vals < bkg_thr] = -1 145 | matches[vals > match_thr] = idxs[vals > match_thr] 146 | # Overwrite matches with each target getting the anchor that has the max IoU. 147 | vals, idxs = torch.max(ious, 0) 148 | # If idxs contains repetition, this doesn't bug and only the last is considered. 149 | matches[idxs] = targets.new_tensor(list(range(targets.size(0)))).long() 150 | return matches 151 | 152 | 153 | def simple_match_anchors(anchors, targets, match_thr=0.4, bkg_thr=0.1): 154 | """ 155 | Match `anchors` to targets. -1 is match to background, -2 is ignore. 156 | Note here: 157 | anchors are fixed 158 | targets are from a batch 159 | """ 160 | # ious = IoU_values(anchors, targets) 161 | ious = IoU_values(targets, anchors) 162 | matches = ious.new(ious.shape).zero_().long() - 2 163 | matches[ious < bkg_thr] = -1 164 | matches[ious > match_thr] = 1 165 | return matches 166 | 167 | 168 | def bbox_to_reg_params(anchors, boxes): 169 | """ 170 | Converts boxes to corresponding reg params 171 | Assume both in rchw format 172 | """ 173 | boxes = tlbr2cthw(boxes) 174 | anchors = tlbr2cthw(anchors) 175 | anchors = anchors.expand(boxes.size(0), anchors.size(0), 4) 176 | boxes = boxes.unsqueeze(1) 177 | trc = (boxes[..., :2] - anchors[..., :2]) / (anchors[..., 2:] + 1e-8) 178 | thw = torch.log(boxes[..., 2:] / (anchors[..., 2:] + 1e-8)) 179 | return torch.cat((trc, thw), 2) 180 | 181 | 182 | def reg_params_to_bbox(anchors, boxes, std12=[1, 1]): 183 | """ 184 | Converts reg_params to corresponding boxes 185 | Assume anchors in r1c1r2c2 format 186 | Boxes in standard form r*, c*, h*, w* 187 | """ 188 | anc1 = anchors.clone() 189 | anc1 = tlbr2cthw(anc1) 190 | b1 = boxes[..., :2] * std12[0] 191 | a111 = anc1[..., 2:] * b1 + anc1[..., :2] 192 | 193 | b2 = boxes[..., 2:] * std12[1] 194 | a222 = torch.exp(b2) * anc1[..., 2:] 195 | af = torch.cat([a111, a222], dim=2) 196 | aft = cthw2tlbr(af) 197 | return aft 198 | -------------------------------------------------------------------------------- /code/dat_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from torch.utils.data.distributed import DistributedSampler 3 | from torchvision.transforms import functional as F 4 | import pandas as pd 5 | from utils import DataWrap 6 | import numpy as np 7 | from pathlib import Path 8 | import torch 9 | from tqdm import tqdm 10 | import re 11 | import PIL 12 | import json 13 | from dataclasses import dataclass 14 | from typing import Dict, List, Optional, Union, Any, Callable, Tuple 15 | import pickle 16 | import ast 17 | import logging 18 | from torchvision import transforms 19 | import spacy 20 | from extended_config import cfg as conf 21 | 22 | 23 | nlp = spacy.load('en_core_web_md') 24 | 25 | 26 | def pil2tensor(image, dtype: np.dtype): 27 | "Convert PIL style `image` array to torch style image tensor." 
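    # The transposes below rearrange the array from H x W x C (PIL/numpy
    # layout) into C x H x W, the layout expected for torch image tensors.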
28 | a = np.asarray(image) 29 | if a.ndim == 2: 30 | a = np.expand_dims(a, 2) 31 | a = np.transpose(a, (1, 0, 2)) 32 | a = np.transpose(a, (2, 1, 0)) 33 | return torch.from_numpy(a.astype(dtype, copy=False)) 34 | 35 | 36 | class NewDistributedSampler(DistributedSampler): 37 | """ 38 | Same as default distributed sampler of pytorch 39 | Just has another argument for shuffle 40 | Allows distributed in validation/testing as well 41 | """ 42 | 43 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 44 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 45 | self.shuffle = shuffle 46 | 47 | def __iter__(self): 48 | if self.shuffle: 49 | # deterministically shuffle based on epoch 50 | g = torch.Generator() 51 | g.manual_seed(self.epoch) 52 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 53 | else: 54 | indices = torch.arange(len(self.dataset)).tolist() 55 | 56 | # add extra samples to make it evenly divisible 57 | indices += indices[: (self.total_size - len(indices))] 58 | assert len(indices) == self.total_size 59 | 60 | # subsample 61 | offset = self.num_samples * self.rank 62 | indices = indices[offset: offset + self.num_samples] 63 | assert len(indices) == self.num_samples 64 | 65 | return iter(indices) 66 | 67 | 68 | class ImgQuDataset(Dataset): 69 | """ 70 | Any Grounding dataset. 71 | Args: 72 | train_file (string): CSV file with annotations 73 | The format should be: img_file, bbox, queries 74 | Can have same img_file on multiple lines 75 | """ 76 | 77 | def __init__(self, cfg, csv_file, ds_name, split_type='train'): 78 | self.cfg = cfg 79 | self.ann_file = csv_file 80 | self.ds_name = ds_name 81 | self.split_type = split_type 82 | 83 | # self.image_data = pd.read_csv(csv_file) 84 | self.image_data = self._read_annotations(csv_file) 85 | # self.image_data = self.image_data.iloc[:200] 86 | self.img_dir = Path(self.cfg.ds_info[self.ds_name]['img_dir']) 87 | self.phrase_len = 50 88 | self.item_getter = getattr(self, 'simple_item_getter') 89 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 90 | # std=[0.229, 0.224, 0.225]) 91 | 92 | def __len__(self): 93 | return len(self.image_data) 94 | 95 | def __getitem__(self, idx): 96 | return self.item_getter(idx) 97 | 98 | def simple_item_getter(self, idx): 99 | img_file, annot, q_chosen = self.load_annotations(idx) 100 | img = PIL.Image.open(img_file).convert('RGB') 101 | 102 | h, w = img.height, img.width 103 | 104 | q_chosen = q_chosen.strip() 105 | qtmp = nlp(str(q_chosen)) 106 | if len(qtmp) == 0: 107 | # logger.error('Empty string provided') 108 | raise NotImplementedError 109 | qlen = len(qtmp) 110 | q_chosen = q_chosen + ' PD'*(self.phrase_len - qlen) 111 | q_chosen_emb = nlp(q_chosen) 112 | if not len(q_chosen_emb) == self.phrase_len: 113 | q_chosen_emb = q_chosen_emb[:self.phrase_len] 114 | 115 | q_chosen_emb_vecs = np.array([q.vector for q in q_chosen_emb]) 116 | # qlen = len(q_chosen_emb_vecs) 117 | # Annot is in x1y1x2y2 format 118 | target = np.array(annot) 119 | # img = self.resize_fixed_transform(img) 120 | img = img.resize((self.cfg.resize_img[0], self.cfg.resize_img[1])) 121 | # Now target is in y1x1y2x2 format which is required by the model 122 | # The above is because the anchor format is created 123 | # in row, column format 124 | target = np.array([target[1], target[0], target[3], target[2]]) 125 | # Resize target to range 0-1 126 | target = np.array([ 127 | target[0] / h, target[1] / w, 128 | target[2] / h, target[3] / w 129 | ]) 130 | # Target in range -1 
to 1 131 | target = 2 * target - 1 132 | 133 | # img = self.img_transforms(img) 134 | # img = Image(pil2tensor(img, np.float_).float().div_(255)) 135 | img = pil2tensor(img, np.float_).float().div_(255) 136 | out = { 137 | 'img': img, 138 | 'idxs': torch.tensor(idx).long(), 139 | 'qvec': torch.from_numpy(q_chosen_emb_vecs), 140 | 'qlens': torch.tensor(qlen), 141 | 'annot': torch.from_numpy(target).float(), 142 | 'orig_annot': torch.tensor(annot).float(), 143 | 'img_size': torch.tensor([h, w]) 144 | } 145 | 146 | return out 147 | 148 | def load_annotations(self, idx): 149 | annotation_list = self.image_data.iloc[idx] 150 | img_file, x1, y1, x2, y2, queries = annotation_list 151 | img_file = self.img_dir / f'{img_file}' 152 | if isinstance(queries, list): 153 | query_chosen = np.random.choice(queries) 154 | else: 155 | assert isinstance(queries, str) 156 | query_chosen = queries 157 | if '_' in query_chosen: 158 | query_chosen = query_chosen.replace('_', ' ') 159 | # annotations = np.array([y1, x1, y2, x2]) 160 | annotations = np.array([x1, y1, x2, y2]) 161 | return img_file, annotations, query_chosen 162 | 163 | def _read_annotations(self, trn_file): 164 | trn_data = pd.read_csv(trn_file) 165 | trn_data['bbox'] = trn_data.bbox.apply( 166 | lambda x: ast.literal_eval(x)) 167 | sample = trn_data['query'].iloc[0] 168 | if sample[0] == '[': 169 | trn_data['query'] = trn_data['query'].apply( 170 | lambda x: ast.literal_eval(x)) 171 | 172 | trn_data['x1'] = trn_data.bbox.apply(lambda x: x[0]) 173 | trn_data['y1'] = trn_data.bbox.apply(lambda x: x[1]) 174 | trn_data['x2'] = trn_data.bbox.apply(lambda x: x[2]) 175 | trn_data['y2'] = trn_data.bbox.apply(lambda x: x[3]) 176 | if self.ds_name == 'flickr30k': 177 | trn_data = trn_data.assign( 178 | image_fpath=trn_data.img_id.apply(lambda x: f'{x}.jpg')) 179 | trn_df = trn_data[['image_fpath', 180 | 'x1', 'y1', 'x2', 'y2', 'query']] 181 | elif self.ds_name == 'refclef': 182 | trn_df = trn_data[['img_id', 183 | 'x1', 'y1', 'x2', 'y2', 'query']] 184 | return trn_df 185 | 186 | 187 | def collater(batch): 188 | qlens = torch.Tensor([i['qlens'] for i in batch]) 189 | max_qlen = int(qlens.max().item()) 190 | # query_vecs = [torch.Tensor(i['query'][:max_qlen]) for i in batch] 191 | out_dict = {} 192 | for k in batch[0]: 193 | out_dict[k] = torch.stack([b[k] for b in batch]).float() 194 | out_dict['qvec'] = out_dict['qvec'][:, :max_qlen] 195 | 196 | return out_dict 197 | 198 | 199 | def make_data_sampler(dataset, shuffle, distributed): 200 | if distributed: 201 | return NewDistributedSampler(dataset, shuffle=shuffle) 202 | if shuffle: 203 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 204 | else: 205 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 206 | return sampler 207 | 208 | 209 | def get_dataloader(cfg, dataset: Dataset, is_train: bool) -> DataLoader: 210 | is_distributed = cfg.do_dist 211 | images_per_gpu = cfg.bs 212 | if is_distributed: 213 | # DistributedDataParallel 214 | batch_size = images_per_gpu 215 | num_workers = cfg.nw 216 | else: 217 | # DataParallel 218 | batch_size = images_per_gpu * cfg.num_gpus 219 | num_workers = cfg.nw * cfg.num_gpus 220 | if is_train: 221 | shuffle = True 222 | else: 223 | shuffle = False if not is_distributed else True 224 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 225 | return DataLoader(dataset, batch_size=batch_size, 226 | sampler=sampler, drop_last=is_train, 227 | num_workers=num_workers, collate_fn=collater) 228 | 229 | 230 | def get_data(cfg): 231 | # Get which 
dataset to use 232 | ds_name = cfg.ds_to_use 233 | 234 | # Training file 235 | trn_csv_file = cfg.ds_info[ds_name]['trn_csv_file'] 236 | trn_ds = ImgQuDataset(cfg=cfg, csv_file=trn_csv_file, 237 | ds_name=ds_name, split_type='train') 238 | trn_dl = get_dataloader(cfg, trn_ds, is_train=True) 239 | 240 | # Validation file 241 | val_csv_file = cfg.ds_info[ds_name]['val_csv_file'] 242 | val_ds = ImgQuDataset(cfg=cfg, csv_file=val_csv_file, 243 | ds_name=ds_name, split_type='valid') 244 | val_dl = get_dataloader(cfg, val_ds, is_train=False) 245 | 246 | test_csv_file = cfg.ds_info[ds_name]['test_csv_file'] 247 | test_ds = ImgQuDataset(cfg=cfg, csv_file=test_csv_file, 248 | ds_name=ds_name, split_type='valid') 249 | test_dl = get_dataloader(cfg, test_ds, is_train=False) 250 | 251 | data = DataWrap(path=cfg.tmp_path, train_dl=trn_dl, valid_dl=val_dl, 252 | test_dl={'test0': test_dl}) 253 | return data 254 | 255 | 256 | if __name__ == '__main__': 257 | cfg = conf 258 | data = get_data(cfg, ds_name='refclef') 259 | -------------------------------------------------------------------------------- /code/eval_script.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple evaluation script. 3 | Requires outputs to be saved before in the following format 4 | [ 5 | { 6 | "id": , "pred_boxes": (x1y1x2y2), "pred_scores": int 7 | } 8 | ] 9 | """ 10 | from anchors import IoU_values 11 | import pickle 12 | import pandas as pd 13 | import ast 14 | import torch 15 | import fire 16 | from pathlib import Path 17 | 18 | 19 | def evaluate(pred_file, gt_file, **kwargs): 20 | acc_iou_thresh = kwargs.get('acc_iou_thresh', 0.5) 21 | pred_file = Path(pred_file) 22 | if not pred_file.exists(): 23 | assert 'num_gpus' in kwargs 24 | num_gpus = kwargs['num_gpus'] 25 | pred_files_to_use = [pred_file.parent / 26 | f'{r}_{pred_file.name}' for r in range(num_gpus)] 27 | assert all([p.exists() for p in pred_files_to_use]) 28 | out_preds = [] 29 | for pf in pred_files_to_use: 30 | tmp = pickle.load(open(pf, 'rb')) 31 | assert isinstance(tmp, list) 32 | out_preds += tmp 33 | pickle.dump(out_preds, pred_file.open('wb')) 34 | 35 | predictions = pickle.load(open(pred_file, 'rb')) 36 | gt_annot = pd.read_csv(gt_file) 37 | # gt_annot = gt_annot.iloc[:len(predictions)] 38 | gt_annot['bbox'] = gt_annot.bbox.apply(lambda x: ast.literal_eval(x)) 39 | 40 | # assert len(predictions) == len(gt_annot) 41 | corr = 0 42 | tot = 0 43 | inds_used = set() 44 | for p in predictions: 45 | ind = int(p['id']) 46 | if ind not in inds_used: 47 | inds_used.add(ind) 48 | annot = gt_annot.iloc[ind] 49 | gt_box = torch.tensor(annot.bbox) 50 | pred_box = torch.tensor(p['pred_boxes']) 51 | 52 | iou = IoU_values(pred_box[None, :], gt_box[None, :]) 53 | if iou > acc_iou_thresh: 54 | corr += 1 55 | tot += 1 56 | return corr/tot, corr, tot 57 | 58 | 59 | if __name__ == '__main__': 60 | fire.Fire(evaluate) 61 | -------------------------------------------------------------------------------- /code/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from anchors import (create_anchors, reg_params_to_bbox, 4 | IoU_values, x1y1x2y2_to_y1x1y2x2) 5 | from typing import Dict 6 | from functools import partial 7 | # from utils import reduce_dict 8 | 9 | 10 | def reshape(box, new_size): 11 | """ 12 | box: (N, 4) in y1x1y2x2 format 13 | new_size: (N, 2) stack of (h, w) 14 | """ 15 | box[:, :2] = new_size * box[:, :2] 16 | box[:, 2:] = new_size 
* box[:, 2:] 17 | return box 18 | 19 | 20 | class Evaluator(nn.Module): 21 | """ 22 | To get the accuracy. Operates at training time. 23 | """ 24 | 25 | def __init__(self, ratios, scales, cfg): 26 | super().__init__() 27 | self.cfg = cfg 28 | 29 | self.ratios = ratios 30 | self.scales = scales 31 | 32 | self.alpha = cfg['alpha'] 33 | self.gamma = cfg['gamma'] 34 | self.use_focal = cfg['use_focal'] 35 | self.use_softmax = cfg['use_softmax'] 36 | self.use_multi = cfg['use_multi'] 37 | 38 | self.lamb_reg = cfg['lamb_reg'] 39 | 40 | self.met_keys = ['Acc', 'MaxPos'] 41 | self.anchs = None 42 | self.get_anchors = partial( 43 | create_anchors, ratios=self.ratios, 44 | scales=self.scales, flatten=True) 45 | 46 | self.acc_iou_threshold = self.cfg['acc_iou_threshold'] 47 | 48 | def forward(self, out: Dict[str, torch.tensor], 49 | inp: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]: 50 | 51 | annot = inp['annot'] 52 | att_box = out['att_out'] 53 | reg_box = out['bbx_out'] 54 | feat_sizes = out['feat_sizes'] 55 | num_f_out = out['num_f_out'] 56 | 57 | device = att_box.device 58 | 59 | if len(num_f_out) > 1: 60 | num_f_out = int(num_f_out[0].item()) 61 | else: 62 | num_f_out = int(num_f_out.item()) 63 | 64 | feat_sizes = feat_sizes[:num_f_out, :] 65 | 66 | if self.anchs is None: 67 | feat_sizes = feat_sizes[:num_f_out, :] 68 | anchs = self.get_anchors(feat_sizes) 69 | anchs = anchs.to(device) 70 | self.anchs = anchs 71 | else: 72 | anchs = self.anchs 73 | 74 | att_box_sigmoid = torch.sigmoid(att_box).squeeze(-1) 75 | att_box_best, att_box_best_ids = att_box_sigmoid.max(1) 76 | # self.att_box_best = att_box_best 77 | 78 | ious1 = IoU_values(annot, anchs) 79 | gt_mask, expected_best_ids = ious1.max(1) 80 | 81 | actual_bbox = reg_params_to_bbox( 82 | anchs, reg_box) 83 | 84 | best_possible_result, _ = self.get_eval_result( 85 | actual_bbox, annot, expected_best_ids) 86 | 87 | msk = None 88 | actual_result, pred_boxes = self.get_eval_result( 89 | actual_bbox, annot, att_box_best_ids, msk) 90 | 91 | out_dict = {} 92 | out_dict['Acc'] = actual_result 93 | out_dict['MaxPos'] = best_possible_result 94 | out_dict['idxs'] = inp['idxs'] 95 | 96 | reshaped_boxes = x1y1x2y2_to_y1x1y2x2(reshape( 97 | (pred_boxes + 1)/2, (inp['img_size']))) 98 | out_dict['pred_boxes'] = reshaped_boxes 99 | out_dict['pred_scores'] = att_box_best 100 | # orig_annot = inp['orig_annot'] 101 | # Sanity check 102 | # iou1 = (torch.diag(IoU_values(reshaped_boxes, orig_annot)) 103 | # >= self.acc_iou_threshold).float().mean() 104 | # assert actual_result.item() == iou1.item() 105 | return out_dict 106 | # return reduce_dict(out_dict) 107 | 108 | def get_eval_result(self, actual_bbox, annot, ids_to_use, msk=None): 109 | best_boxes = torch.gather( 110 | actual_bbox, 1, ids_to_use.view(-1, 1, 1).expand(-1, 1, 4)) 111 | best_boxes = best_boxes.view(best_boxes.size(0), -1) 112 | if msk is not None: 113 | best_boxes[msk] = 0 114 | # self.best_boxes = best_boxes 115 | ious = torch.diag(IoU_values(best_boxes, annot)) 116 | # self.fin_results = ious 117 | return (ious >= self.acc_iou_threshold).float().mean(), best_boxes 118 | 119 | 120 | def get_default_eval(ratios, scales, cfg): 121 | return Evaluator(ratios, scales, cfg) 122 | -------------------------------------------------------------------------------- /code/extended_config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | import json 3 | from typing import Dict, Any 4 | 5 | ds_info = 
CN(json.load(open('./configs/ds_info.json'))) 6 | def_cfg = CN(json.load(open('./configs/cfg.json'))) 7 | 8 | cfg = CN(def_cfg) 9 | cfg.ds_info = CN(ds_info) 10 | 11 | # Device 12 | # setting default device 13 | cfg.device = 'cuda' 14 | 15 | # Training 16 | cfg.local_rank = 0 17 | cfg.do_dist = False 18 | 19 | # Testing 20 | cfg.only_val = False 21 | cfg.only_test = False 22 | 23 | key_maps = {} 24 | 25 | 26 | def create_from_dict(dct: Dict[str, Any], prefix: str, cfg: CN): 27 | """ 28 | Helper function to create yacs config from dictionary 29 | """ 30 | dct_cfg = CN(dct, new_allowed=True) 31 | prefix_list = prefix.split('.') 32 | d = cfg 33 | for pref in prefix_list[:-1]: 34 | assert isinstance(d, CN) 35 | if pref not in d: 36 | setattr(d, pref, CN()) 37 | d = d[pref] 38 | if hasattr(d, prefix_list[-1]): 39 | old_dct_cfg = d[prefix_list[-1]] 40 | dct_cfg.merge_from_other_cfg(old_dct_cfg) 41 | 42 | setattr(d, prefix_list[-1], dct_cfg) 43 | return cfg 44 | 45 | 46 | def update_from_dict(cfg: CN, dct: Dict[str, Any], 47 | key_maps: Dict[str, str] = None) -> CN: 48 | """ 49 | Given original CfgNode (cfg) and input dictionary allows changing 50 | the cfg with the updated dictionary values 51 | Optional key_maps argument which defines a mapping between 52 | same keys of the cfg node. Only used for convenience 53 | Adapted from: 54 | https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L219 55 | """ 56 | # Original cfg 57 | root = cfg 58 | 59 | # Change the input dictionary using keymaps 60 | # Now it is aligned with the cfg 61 | full_key_list = list(dct.keys()) 62 | for full_key in full_key_list: 63 | if full_key in key_maps: 64 | cfg[full_key] = dct[full_key] 65 | new_key = key_maps[full_key] 66 | dct[new_key] = dct.pop(full_key) 67 | 68 | # Convert the cfg using dictionary input 69 | for full_key, v in dct.items(): 70 | if root.key_is_deprecated(full_key): 71 | continue 72 | if root.key_is_renamed(full_key): 73 | root.raise_key_rename_error(full_key) 74 | key_list = full_key.split(".") 75 | d = cfg 76 | for subkey in key_list[:-1]: 77 | # Most important statement 78 | assert subkey in d, f'key {full_key} doesnot exist' 79 | d = d[subkey] 80 | 81 | subkey = key_list[-1] 82 | # Most important statement 83 | assert subkey in d, f'key {full_key} doesnot exist' 84 | 85 | value = cfg._decode_cfg_value(v) 86 | 87 | assert isinstance(value, type(d[subkey])) 88 | d[subkey] = value 89 | 90 | return cfg 91 | 92 | 93 | # def get_config_after_kwargs(cfg, kwargs: Dict[str, Any]): 94 | # ds_info = CN(json.load(open('./configs/ds_info.json'))) 95 | # def_cfg = CN(json.load(open('./configs/cfg.json'))) 96 | 97 | # upd_cfg = update_from_dict(def_cfg, kwargs, key_maps) 98 | 99 | # cfg_dict = { 100 | # 'ds_to_use': upd_cfg.ds_to_use, 101 | # 'mdl_to_use': upd_cfg.mdl_to_use, 102 | # 'lfn_to_use': upd_cfg.lfn_to_use, 103 | # 'efn_to_use': upd_cfg.efn_to_use, 104 | # 'opt_to_use': upd_cfg.opt_to_use, 105 | # 'sfn_to_use': upd_cfg.sfn_to_use 106 | # } 107 | 108 | # cfg = CN(cfg_dict) 109 | 110 | # cfg = create_from_dict(ds_info[cfg.ds_to_use], 'DS', cfg) 111 | -------------------------------------------------------------------------------- /code/fpn_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from the wonderful repository: https://github.com/yhenon/pytorch-retinanet/blob/master/model.py 3 | """ 4 | 5 | import torch.nn as nn 6 | import torch 7 | import torch.utils.model_zoo as model_zoo 8 | import torch.nn.functional as F 9 | import numpy as 
np 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | """ 28 | standard Basic block 29 | """ 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | """ 63 | Standard Bottleneck block 64 | """ 65 | expansion = 4 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn1 = nn.BatchNorm2d(planes) 71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 72 | padding=1, bias=False) 73 | self.bn2 = nn.BatchNorm2d(planes) 74 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 75 | self.bn3 = nn.BatchNorm2d(planes * 4) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | residual = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | def pad_out(k): 104 | "padding to have same size" 105 | return (k-1)//2 106 | 107 | 108 | class FPN_backbone(nn.Module): 109 | """ 110 | A different fpn, doubt it will work 111 | """ 112 | 113 | def __init__(self, inch_list, cfg, feat_size=256): 114 | super().__init__() 115 | 116 | # self.backbone = backbone 117 | 118 | # expects c3, c4, c5 channel dims 119 | self.inch_list = inch_list 120 | self.cfg = cfg 121 | c3_ch, c4_ch, c5_ch = self.inch_list 122 | self.feat_size = feat_size 123 | 124 | self.P7_2 = nn.Conv2d(in_channels=self.feat_size, 125 | out_channels=self.feat_size, stride=2, 126 | kernel_size=3, 127 | padding=1) 128 | self.P6 = nn.Conv2d(in_channels=c5_ch, 129 | out_channels=self.feat_size, 130 | kernel_size=3, stride=2, padding=pad_out(3)) 131 | self.P5_1 = nn.Conv2d(in_channels=c5_ch, 132 | out_channels=self.feat_size, 133 | kernel_size=1, padding=pad_out(1)) 134 | 135 | self.P5_2 = nn.Conv2d(in_channels=self.feat_size, 
out_channels=self.feat_size, 136 | kernel_size=3, padding=pad_out(3)) 137 | 138 | self.P4_1 = nn.Conv2d(in_channels=c4_ch, 139 | out_channels=self.feat_size, kernel_size=1, 140 | padding=pad_out(1)) 141 | 142 | self.P4_2 = nn.Conv2d(in_channels=self.feat_size, 143 | out_channels=self.feat_size, kernel_size=3, 144 | padding=pad_out(3)) 145 | 146 | self.P3_1 = nn.Conv2d(in_channels=c3_ch, 147 | out_channels=self.feat_size, kernel_size=1, 148 | padding=pad_out(1)) 149 | 150 | self.P3_2 = nn.Conv2d(in_channels=self.feat_size, 151 | out_channels=self.feat_size, kernel_size=3, 152 | padding=pad_out(3)) 153 | 154 | def forward(self, inp): 155 | # expects inp to be output of c3, c4, c5 156 | c3, c4, c5 = inp 157 | p51 = self.P5_1(c5) 158 | p5_out = self.P5_2(p51) 159 | 160 | # p5_up = F.interpolate(p51, scale_factor=2) 161 | p5_up = F.interpolate(p51, size=(c4.size(2), c4.size(3))) 162 | p41 = self.P4_1(c4) + p5_up 163 | p4_out = self.P4_2(p41) 164 | 165 | # p4_up = F.interpolate(p41, scale_factor=2) 166 | p4_up = F.interpolate(p41, size=(c3.size(2), c3.size(3))) 167 | p31 = self.P3_1(c3) + p4_up 168 | p3_out = self.P3_2(p31) 169 | 170 | p6_out = self.P6(c5) 171 | 172 | p7_out = self.P7_2(F.relu(p6_out)) 173 | if self.cfg['resize_img'] == [600, 600]: 174 | return [p4_out, p5_out, p6_out, p7_out] 175 | 176 | # p8_out = self.p8_gen(F.relu(p7_out)) 177 | p8_out = F.adaptive_avg_pool2d(p7_out, 1) 178 | return [p3_out, p4_out, p5_out, p6_out, p7_out, p8_out] 179 | 180 | 181 | class PyramidFeatures(nn.Module): 182 | """ 183 | Pyramid Features, especially for Resnet 184 | """ 185 | 186 | def __init__(self, C3_size, C4_size, C5_size, feature_size=256): 187 | super(PyramidFeatures, self).__init__() 188 | 189 | # upsample C5 to get P5 from the FPN paper 190 | self.P5_1 = nn.Conv2d(C5_size, feature_size, 191 | kernel_size=1, stride=1, padding=0) 192 | self.P5_upsampled = nn.Upsample(scale_factor=2, mode='nearest') 193 | self.P5_2 = nn.Conv2d(feature_size, feature_size, 194 | kernel_size=3, stride=1, padding=1) 195 | # add P5 elementwise to C4 196 | self.P4_1 = nn.Conv2d(C4_size, feature_size, 197 | kernel_size=1, stride=1, padding=0) 198 | self.P4_upsampled = nn.Upsample(scale_factor=2, mode='nearest') 199 | self.P4_2 = nn.Conv2d(feature_size, feature_size, 200 | kernel_size=3, stride=1, padding=1) 201 | # add P4 elementwise to C3 202 | self.P3_1 = nn.Conv2d(C3_size, feature_size, 203 | kernel_size=1, stride=1, padding=0) 204 | self.P3_2 = nn.Conv2d(feature_size, feature_size, 205 | kernel_size=3, stride=1, padding=1) 206 | # "P6 is obtained via a 3x3 stride-2 conv on C5" 207 | self.P6 = nn.Conv2d(C5_size, feature_size, 208 | kernel_size=3, stride=2, padding=1) 209 | # "P7 is computed by applying ReLU followed by a 3x3 stride-2 conv on P6" 210 | self.P7_1 = nn.ReLU() 211 | self.P7_2 = nn.Conv2d(feature_size, feature_size, 212 | kernel_size=3, stride=2, padding=1) 213 | 214 | def forward(self, inputs): 215 | """ 216 | Inputs should be from layer2,3,4 217 | """ 218 | C3, C4, C5 = inputs 219 | P5_x = self.P5_1(C5) 220 | P5_upsampled_x = self.P5_upsampled(P5_x) 221 | P5_x = self.P5_2(P5_x) 222 | 223 | P4_x = self.P4_1(C4) 224 | P4_x = P5_upsampled_x + P4_x 225 | P4_upsampled_x = self.P4_upsampled(P4_x) 226 | P4_x = self.P4_2(P4_x) 227 | P3_x = self.P3_1(C3) 228 | P3_x = P3_x + P4_upsampled_x 229 | P3_x = self.P3_2(P3_x) 230 | P6_x = self.P6(C5) 231 | P7_x = self.P7_1(P6_x) 232 | P7_x = self.P7_2(P7_x) 233 | return [P3_x, P4_x, P5_x, P6_x, P7_x] 234 | 235 | 236 | class ResNet(nn.Module): 237 | """ 238 | Basic 
Resnet Module 239 | """ 240 | 241 | def __init__(self, num_classes, block, layers): 242 | self.inplanes = 64 243 | super(ResNet, self).__init__() 244 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 245 | stride=2, padding=3, bias=False) 246 | self.bn1 = nn.BatchNorm2d(64) 247 | self.relu = nn.ReLU(inplace=True) 248 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 249 | self.layer1 = self._make_layer(block, 64, layers[0]) 250 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 251 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 252 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 253 | 254 | if block == BasicBlock: 255 | fpn_sizes = [self.layer2[layers[1]-1].conv2.out_channels, self.layer3[layers[2] - 256 | 1].conv2.out_channels, self.layer4[layers[3]-1].conv2.out_channels] 257 | elif block == Bottleneck: 258 | fpn_sizes = [self.layer2[layers[1]-1].conv3.out_channels, self.layer3[layers[2] - 259 | 1].conv3.out_channels, self.layer4[layers[3]-1].conv3.out_channels] 260 | 261 | self.freeze_bn() 262 | self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2]) 263 | 264 | for m in self.modules(): 265 | if isinstance(m, nn.Conv2d): 266 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 267 | m.weight.data.normal_(0, np.sqrt(2. / n)) 268 | elif isinstance(m, nn.BatchNorm2d): 269 | m.weight.data.fill_(1) 270 | m.bias.data.zero_() 271 | prior = 0.01 272 | 273 | def _make_layer(self, block, planes, blocks, stride=1): 274 | """ 275 | Convenience function to generate layers given blocks and 276 | channel dimensions 277 | """ 278 | downsample = None 279 | if stride != 1 or self.inplanes != planes * block.expansion: 280 | downsample = nn.Sequential( 281 | nn.Conv2d(self.inplanes, planes * block.expansion, 282 | kernel_size=1, stride=stride, bias=False), 283 | nn.BatchNorm2d(planes * block.expansion), 284 | ) 285 | layers = [] 286 | layers.append(block(self.inplanes, planes, stride, downsample)) 287 | self.inplanes = planes * block.expansion 288 | for i in range(1, blocks): 289 | layers.append(block(self.inplanes, planes)) 290 | return nn.Sequential(*layers) 291 | 292 | def freeze_bn(self): 293 | '''Freeze BatchNorm layers.''' 294 | for layer in self.modules(): 295 | if isinstance(layer, nn.BatchNorm2d): 296 | layer.eval() 297 | 298 | def forward(self, inputs): 299 | """ 300 | inputs should be images 301 | """ 302 | img_batch = inputs 303 | 304 | x = self.conv1(img_batch) 305 | x = self.bn1(x) 306 | x = self.relu(x) 307 | x = self.maxpool(x) 308 | x1 = self.layer1(x) 309 | x2 = self.layer2(x1) 310 | x3 = self.layer3(x2) 311 | x4 = self.layer4(x3) 312 | features = self.fpn([x2, x3, x4]) 313 | return features 314 | 315 | 316 | def resnet50(num_classes, pretrained=False, **kwargs): 317 | """Constructs a ResNet-50 model. 
318 | Args: 319 | pretrained (bool): If True, returns a model pre-trained on ImageNet 320 | """ 321 | model = ResNet(num_classes, Bottleneck, [3, 4, 6, 3], **kwargs) 322 | if pretrained: 323 | model.load_state_dict(model_zoo.load_url( 324 | model_urls['resnet50'], model_dir='.'), strict=False) 325 | return model 326 | -------------------------------------------------------------------------------- /code/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from anchors import (create_anchors, simple_match_anchors, 5 | bbox_to_reg_params, IoU_values) 6 | from typing import Dict 7 | from functools import partial 8 | # from utils import reduce_dict 9 | 10 | 11 | class ZSGLoss(nn.Module): 12 | """ 13 | Criterion to be minimized 14 | Requires the anchors to be used 15 | for loss computation 16 | """ 17 | 18 | def __init__(self, ratios, scales, cfg): 19 | super().__init__() 20 | self.cfg = cfg 21 | 22 | self.ratios = ratios 23 | self.scales = scales 24 | 25 | self.alpha = cfg['alpha'] 26 | self.gamma = cfg['gamma'] 27 | 28 | # Which loss fucntion to use 29 | self.use_focal = cfg['use_focal'] 30 | self.use_softmax = cfg['use_softmax'] 31 | self.use_multi = cfg['use_multi'] 32 | 33 | self.lamb_reg = cfg['lamb_reg'] 34 | 35 | self.loss_keys = ['loss', 'cls_ls', 'box_ls'] 36 | self.anchs = None 37 | self.get_anchors = partial( 38 | create_anchors, ratios=self.ratios, 39 | scales=self.scales, flatten=True) 40 | 41 | self.box_loss = nn.SmoothL1Loss(reduction='none') 42 | 43 | def forward(self, out: Dict[str, torch.tensor], 44 | inp: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]: 45 | """ 46 | inp: att_box, reg_box, feat_sizes 47 | annot: gt box (r1c1r2c2 form) 48 | """ 49 | annot = inp['annot'] 50 | att_box = out['att_out'] 51 | reg_box = out['bbx_out'] 52 | feat_sizes = out['feat_sizes'] 53 | num_f_out = out['num_f_out'] 54 | 55 | device = att_box.device 56 | 57 | # get the correct number of output features 58 | # in the case of DataParallel 59 | if len(num_f_out) > 1: 60 | num_f_out = int(num_f_out[0].item()) 61 | else: 62 | num_f_out = int(num_f_out.item()) 63 | 64 | # Computes Anchors only once since size is kept fixed 65 | # Needs to be changed in case size is not fixed 66 | if self.anchs is None: 67 | feat_sizes = feat_sizes[:num_f_out, :] 68 | anchs = self.get_anchors(feat_sizes) 69 | anchs = anchs.to(device) 70 | self.anchs = anchs 71 | else: 72 | anchs = self.anchs 73 | matches = simple_match_anchors( 74 | anchs, annot, match_thr=self.cfg['matching_threshold']) 75 | bbx_mask = (matches >= 0) 76 | ious1 = IoU_values(annot, anchs) 77 | _, msk = ious1.max(1) 78 | 79 | bbx_mask2 = torch.eye(anchs.size(0))[msk] 80 | bbx_mask2 = bbx_mask2 > 0 81 | bbx_mask2 = bbx_mask2.to(device) 82 | top1_mask = bbx_mask2 83 | 84 | if not self.use_multi: 85 | bbx_mask = bbx_mask2 86 | else: 87 | bbx_mask = bbx_mask | bbx_mask2 88 | 89 | # all clear 90 | gt_reg_params = bbox_to_reg_params(anchs, annot) 91 | box_l = self.box_loss(reg_box, gt_reg_params) 92 | # box_l_relv = box_l.sum(dim=2)[bbx_mask] 93 | box_l_relv = box_l.sum(dim=2) * bbx_mask.float() 94 | box_l_relv = box_l_relv.sum(dim=1) / bbx_mask.sum(dim=-1).float() 95 | box_loss = box_l_relv.mean() 96 | 97 | if box_loss.cpu() == torch.Tensor([float("Inf")]): 98 | # There is a likely bug with annot box 99 | # being very small 100 | import pdb 101 | pdb.set_trace() 102 | 103 | att_box = att_box.squeeze(-1) 104 | att_box_sigm = 
torch.sigmoid(att_box) 105 | 106 | if self.use_softmax: 107 | assert self.use_multi is False 108 | gt_ids = msk 109 | clas_loss = F.cross_entropy(att_box, gt_ids, reduction='none') 110 | else: 111 | if self.use_focal: 112 | encoded_tgt = bbx_mask.float() 113 | ps = att_box_sigm 114 | weights = encoded_tgt * (1-ps) + (1-encoded_tgt) * ps 115 | alphas = ((1-encoded_tgt) * self.alpha + 116 | encoded_tgt * (1-self.alpha)) 117 | weights.pow_(self.gamma).mul_(alphas) 118 | weights = weights.detach() 119 | else: 120 | weights = None 121 | 122 | clas_loss = F.binary_cross_entropy_with_logits( 123 | att_box, bbx_mask.float(), weight=weights, reduction='none') 124 | 125 | clas_loss = clas_loss.sum() / bbx_mask.sum() 126 | # clas_loss = clas_loss.sum() / clas_loss.size(0) 127 | 128 | if torch.isnan(box_loss) or torch.isnan(clas_loss): 129 | # print('Nan Loss') 130 | box_loss = box_loss.new_ones(box_loss.shape) * 0.01 131 | box_loss.requires_grad = True 132 | clas_loss = clas_loss.new_ones(clas_loss.shape) 133 | clas_loss.requires_grad = True 134 | 135 | out_loss = self.lamb_reg * box_loss + clas_loss 136 | # + self.lamb_rel * rel_loss 137 | out_dict = {} 138 | out_dict['loss'] = out_loss 139 | out_dict['cls_ls'] = clas_loss 140 | out_dict['box_ls'] = box_loss 141 | # out_dict['rel_ls'] = rel_loss 142 | 143 | return out_dict 144 | # return reduce_dict(out_dict) 145 | 146 | 147 | def get_default_loss(ratios, scales, cfg): 148 | return ZSGLoss(ratios, scales, cfg) 149 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | QNet: Main SetUp Code 3 | Author: Arka Sadhu 4 | """ 5 | import pandas as pd 6 | from dat_loader import get_data 7 | from mdl import get_default_net 8 | # from qnet_model import get_default_net 9 | from loss import get_default_loss 10 | import torch 11 | import fire 12 | from evaluator import Evaluator 13 | # from evaluate import Evaluator 14 | import json 15 | from functools import partial 16 | from torch.optim import Adam 17 | import numpy as np 18 | from tqdm import tqdm 19 | from utils import Learner 20 | # import logging 21 | from extended_config import cfg as conf 22 | 23 | # logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 24 | # datefmt='%m/%d/%Y %H:%M:%S', 25 | # level=logging.INFO) 26 | # logger = logging.getLogger(__name__) 27 | 28 | 29 | def sanity_check(learn): 30 | qnet = learn.model 31 | qlos = learn.loss_func 32 | db = learn.data 33 | x, y = next(iter(db.train_dl)) 34 | opt = torch.optim.Adam(qnet.parameters(), lr=2e-4) 35 | met = learn.metrics[0] 36 | for i in range(1000): 37 | opt.zero_grad() 38 | out = qnet(*x) 39 | loss = qlos(out, y) 40 | loss.backward() 41 | opt.step() 42 | met_out = met(out, y) 43 | # print(i, loss.item(), met_out.item()) 44 | print(i, qlos.box_loss_smooth.smooth.item(), qlos.cls_loss_smooth.smooth.item(), 45 | met_out.item(), met.best_possible_result.item()) 46 | 47 | 48 | def test_for_entities(learn): 49 | learn.model.eval() 50 | full_eval_lists = [] 51 | with torch.no_grad(): 52 | for xb, yb in tqdm(learn.data.test_dl): 53 | out = learn.model(*xb) 54 | evl = learn.metrics[0](out, yb) 55 | eval_list = learn.metrics[0].fin_results 56 | full_eval_lists.append(eval_list) 57 | full_eval_tensor = torch.cat(full_eval_lists, dim=0) 58 | # import pdb 59 | # pdb.set_trace() 60 | print(f'Acc: {full_eval_tensor.float().mean()}') 61 | full_eval_df = pd.DataFrame(full_eval_tensor.tolist()) 
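    # Write the per-sample evaluation results to disk so they can be
    # inspected or re-aggregated offline.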
62 | full_eval_df.to_csv('./vgg_flickr_eval_test.lst', 63 | header=False, index=False) 64 | 65 | 66 | def learner_init(uid, cfg): 67 | device_count = torch.cuda.device_count() 68 | device = torch.device('cuda') 69 | 70 | if type(cfg['ratios']) != list: 71 | ratios = eval(cfg['ratios'], {}) 72 | else: 73 | ratios = cfg['ratios'] 74 | if type(cfg['scales']) != list: 75 | scales = cfg['scale_factor'] * np.array(eval(cfg['scales'], {})) 76 | else: 77 | scales = cfg['scale_factor'] * np.array(cfg['scales']) 78 | 79 | num_anchors = len(ratios) * len(scales) 80 | qnet = get_default_net(num_anchors=num_anchors, cfg=cfg) 81 | qnet = qnet.to(device) 82 | qnet = torch.nn.DataParallel(qnet) 83 | 84 | qlos = get_default_loss( 85 | ratios, scales, cfg) 86 | qlos = qlos.to(device) 87 | qeval = Evaluator(ratios, scales, cfg) 88 | # db = get_data(bs=cfg['bs'] * device_count, nw=cfg['nw'], bsv=cfg['bsv'] * device_count, 89 | # nwv=cfg['nwv'], devices=cfg['devices'], do_tfms=cfg['do_tfms'], 90 | # cfg=cfg, data_cfg=data_cfg) 91 | # db = get_data(cfg, ds_name=cfg['ds_to_use']) 92 | db = get_data(cfg) 93 | opt_fn = partial(torch.optim.Adam, betas=(0.9, 0.99)) 94 | 95 | # Note: Currently using default optimizer 96 | learn = Learner(uid=uid, data=db, mdl=qnet, loss_fn=qlos, 97 | opt_fn=opt_fn, eval_fn=qeval, device=device, cfg=cfg) 98 | return learn 99 | 100 | 101 | def main(uid, del_existing=False, resume=True, **kwargs): 102 | # cfg = json.load(open('cfg.json')) 103 | cfg = conf 104 | cfg['resume'] = resume 105 | cfg['del_existing'] = del_existing 106 | cfg.update(kwargs) 107 | 108 | cfg.num_gpus = torch.cuda.device_count() 109 | # data_cfg = json.load(open('./ds_info.json')) 110 | # if cfg.do_dp: 111 | cfg.bs = cfg.bs * cfg.num_gpus 112 | cfg.nw = cfg.nw * cfg.num_gpus 113 | 114 | cfg.bsv = cfg.bsv * cfg.num_gpus 115 | cfg.nwv = cfg.nwv * cfg.num_gpus 116 | 117 | learn = learner_init(uid, cfg) 118 | if not (cfg['only_val'] or cfg['only_test']): 119 | learn.fit(epochs=int(cfg['epochs']), lr=cfg['lr']) 120 | else: 121 | print(cfg) 122 | if cfg['only_val']: 123 | # learn.testing(learn.data.valid_dl) 124 | val_loss, val_acc, preds = learn.validate(db=learn.data.valid_dl) 125 | for k in val_acc: 126 | print(val_acc[k]) 127 | # if isinstance(learn.data.test_dl, list): 128 | # for i, t in enumerate(learn.data.test_dl): 129 | # print('For dl ', i) 130 | # test_loss, test_acc = learn.validate(db=t) 131 | # for k in test_acc: 132 | # print(test_acc[k]) 133 | 134 | # else: 135 | # test_loss, test_acc = learn.validate(db=learn.data.test_dl) 136 | # for k in test_acc: 137 | # print(test_acc[k]) 138 | # learn.validate() 139 | # sanity_check(learn) 140 | 141 | 142 | if __name__ == '__main__': 143 | fire.Fire(main) 144 | -------------------------------------------------------------------------------- /code/main_dist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main file for distributed training 3 | """ 4 | from dat_loader import get_data 5 | from mdl import get_default_net 6 | from loss import get_default_loss 7 | from evaluator import get_default_eval 8 | from utils import Learner, synchronize 9 | 10 | import numpy as np 11 | import torch 12 | import fire 13 | from functools import partial 14 | 15 | from extended_config import (cfg as conf, key_maps, CN, update_from_dict) 16 | 17 | 18 | def learner_init(uid: str, cfg: CN) -> Learner: 19 | device = torch.device('cuda') 20 | data = get_data(cfg) 21 | 22 | # Ugly hack because I wanted ratios, scales 23 | # in fractional 
formats 24 | if type(cfg['ratios']) != list: 25 | ratios = eval(cfg['ratios'], {}) 26 | else: 27 | ratios = cfg['ratios'] 28 | if type(cfg['scales']) != list: 29 | scales = cfg['scale_factor'] * np.array(eval(cfg['scales'], {})) 30 | else: 31 | scales = cfg['scale_factor'] * np.array(cfg['scales']) 32 | 33 | num_anchors = len(ratios) * len(scales) 34 | mdl = get_default_net(num_anchors=num_anchors, cfg=cfg) 35 | mdl.to(device) 36 | if cfg.do_dist: 37 | mdl = torch.nn.parallel.DistributedDataParallel( 38 | mdl, device_ids=[cfg.local_rank], 39 | output_device=cfg.local_rank, broadcast_buffers=True, 40 | find_unused_parameters=True) 41 | elif not cfg.do_dist and cfg.num_gpus: 42 | # Use data parallel 43 | mdl = torch.nn.DataParallel(mdl) 44 | 45 | loss_fn = get_default_loss(ratios, scales, cfg) 46 | loss_fn.to(device) 47 | 48 | eval_fn = get_default_eval(ratios, scales, cfg) 49 | # eval_fn.to(device) 50 | opt_fn = partial(torch.optim.Adam, betas=(0.9, 0.99)) 51 | 52 | learn = Learner(uid=uid, data=data, mdl=mdl, loss_fn=loss_fn, 53 | opt_fn=opt_fn, eval_fn=eval_fn, device=device, cfg=cfg) 54 | return learn 55 | 56 | 57 | def main_dist(uid: str, **kwargs): 58 | """ 59 | uid is a unique identifier for the experiment name 60 | Can be kept same as a previous run, by default will start executing 61 | from latest saved model 62 | **kwargs: allows arbit arguments of cfg to be changed 63 | """ 64 | cfg = conf 65 | num_gpus = torch.cuda.device_count() 66 | cfg.num_gpus = num_gpus 67 | 68 | if num_gpus > 1: 69 | 70 | if 'local_rank' in kwargs: 71 | # We are doing distributed parallel 72 | cfg.do_dist = True 73 | torch.cuda.set_device(kwargs['local_rank']) 74 | torch.distributed.init_process_group( 75 | backend="nccl", init_method="env://" 76 | ) 77 | synchronize() 78 | else: 79 | # We are doing data parallel 80 | cfg.do_dist = False 81 | 82 | # Update the config file depending on the command line args 83 | cfg = update_from_dict(cfg, kwargs, key_maps) 84 | 85 | # Freeze the cfg, can no longer be changed 86 | cfg.freeze() 87 | # print(cfg) 88 | # Initialize learner 89 | learn = learner_init(uid, cfg) 90 | # Train or Test 91 | if not (cfg.only_val or cfg.only_test): 92 | learn.fit(epochs=cfg.epochs, lr=cfg.lr) 93 | else: 94 | if cfg.only_val: 95 | learn.testing(learn.data.valid_dl) 96 | if cfg.only_test: 97 | learn.testing(learn.data.test_dl) 98 | 99 | 100 | if __name__ == '__main__': 101 | fire.Fire(main_dist) 102 | -------------------------------------------------------------------------------- /code/mdl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model file for zsgnet 3 | Author: Arka Sadhu 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | # import torch.nn.functional as F 8 | import torchvision.models as tvm 9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 10 | from fpn_resnet import FPN_backbone 11 | from anchors import create_grid 12 | import ssd_vgg 13 | from typing import Dict, Any 14 | from extended_config import cfg as conf 15 | from dat_loader import get_data 16 | 17 | 18 | # conv2d, conv2d_relu are adapted from 19 | # https://github.com/fastai/fastai/blob/5c4cefdeaf11fdbbdf876dbe37134c118dca03ad/fastai/layers.py#L98 20 | def conv2d(ni: int, nf: int, ks: int = 3, stride: int = 1, 21 | padding: int = None, bias=False) -> nn.Conv2d: 22 | "Create and initialize `nn.Conv2d` layer. `padding` defaults to `ks//2`." 
23 | if padding is None: 24 | padding = ks//2 25 | return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, 26 | padding=padding, bias=bias) 27 | 28 | 29 | def conv2d_relu(ni: int, nf: int, ks: int = 3, stride: int = 1, padding: int = None, 30 | bn: bool = False, bias: bool = False) -> nn.Sequential: 31 | """ 32 | Create a `conv2d` layer with `nn.ReLU` activation 33 | and optional(`bn`) `nn.BatchNorm2d`: `ni` input, `nf` out 34 | filters, `ks` kernel, `stride`:stride, `padding`:padding, 35 | `bn`: batch normalization. 36 | """ 37 | layers = [conv2d(ni, nf, ks=ks, stride=stride, 38 | padding=padding, bias=bias), nn.ReLU(inplace=True)] 39 | if bn: 40 | layers.append(nn.BatchNorm2d(nf)) 41 | return nn.Sequential(*layers) 42 | 43 | 44 | class BackBone(nn.Module): 45 | """ 46 | A general purpose Backbone class. 47 | For a new network, need to redefine: 48 | --> encode_feats 49 | Optionally after_init 50 | """ 51 | 52 | def __init__(self, encoder: nn.Module, cfg: dict, out_chs=256): 53 | """ 54 | Make required forward hooks 55 | """ 56 | super().__init__() 57 | self.device = torch.device(cfg.device) 58 | self.encoder = encoder 59 | self.cfg = cfg 60 | self.out_chs = out_chs 61 | self.after_init() 62 | 63 | def after_init(self): 64 | pass 65 | 66 | def num_channels(self): 67 | raise NotImplementedError 68 | 69 | def concat_we(self, x, we, only_we=False, only_grid=False): 70 | """ 71 | Convenience function to concat we 72 | Expects x in the form B x C x H x W (one feature map) 73 | we: B x wdim (the language vector) 74 | Output: concatenated word embedding and grid centers 75 | """ 76 | # Both cannot be true 77 | assert not (only_we and only_grid) 78 | 79 | # Create the grid 80 | grid = create_grid((x.size(2), x.size(3)), 81 | flatten=False).to(self.device) 82 | grid = grid.permute(2, 0, 1).contiguous() 83 | 84 | # TODO: Slightly cleaner implementation? 85 | grid_tile = grid.view( 86 | 1, grid.size(0), grid.size(1), grid.size(2)).expand( 87 | we.size(0), grid.size(0), grid.size(1), grid.size(2)) 88 | 89 | # In case we only need the grid 90 | # Basically, don't use any image/language information 91 | if only_grid: 92 | return grid_tile 93 | 94 | # Expand word embeddings 95 | word_emb_tile = we.view( 96 | we.size(0), we.size(1), 1, 1).expand( 97 | we.size(0), we.size(1), x.size(2), x.size(3)) 98 | 99 | # In case performing image blind (requiring only language) 100 | if only_we: 101 | return word_emb_tile 102 | 103 | # Concatenate along the channel dimension 104 | return torch.cat((x, word_emb_tile, grid_tile), dim=1) 105 | 106 | def encode_feats(self, inp): 107 | return self.encoder(inp) 108 | 109 | def forward(self, inp, we=None, 110 | only_we=False, only_grid=False): 111 | """ 112 | expecting word embedding of shape B x WE. 
113 | If only image features are needed, don't 114 | provide any word embedding 115 | """ 116 | feats = self.encode_feats(inp) 117 | # If we want to do normalization of the features 118 | if self.cfg['do_norm']: 119 | feats = [ 120 | feat / feat.norm(dim=1).unsqueeze(1).expand(*feat.shape) 121 | for feat in feats 122 | ] 123 | 124 | # For language blind setting, can directly return the features 125 | if we is None: 126 | return feats 127 | 128 | if self.cfg['do_norm']: 129 | b, wdim = we.shape 130 | we = we / we.norm(dim=1).unsqueeze(1).expand(b, wdim) 131 | 132 | out = [self.concat_we( 133 | f, we, only_we=only_we, only_grid=only_grid) for f in feats] 134 | 135 | return out 136 | 137 | 138 | class RetinaBackBone(BackBone): 139 | def after_init(self): 140 | self.num_chs = self.num_channels() 141 | self.fpn = FPN_backbone(self.num_chs, self.cfg, feat_size=self.out_chs) 142 | 143 | def num_channels(self): 144 | return [self.encoder.layer2[-1].conv3.out_channels, 145 | self.encoder.layer3[-1].conv3.out_channels, 146 | self.encoder.layer4[-1].conv3.out_channels] 147 | 148 | def encode_feats(self, inp): 149 | x = self.encoder.conv1(inp) 150 | x = self.encoder.bn1(x) 151 | x = self.encoder.relu(x) 152 | x = self.encoder.maxpool(x) 153 | x1 = self.encoder.layer1(x) 154 | x2 = self.encoder.layer2(x1) 155 | x3 = self.encoder.layer3(x2) 156 | x4 = self.encoder.layer4(x3) 157 | 158 | feats = self.fpn([x2, x3, x4]) 159 | return feats 160 | 161 | 162 | class SSDBackBone(BackBone): 163 | """ 164 | ssd_vgg.py already implements encoder 165 | """ 166 | 167 | def encode_feats(self, inp): 168 | return self.encoder(inp) 169 | 170 | 171 | class ZSGNet(nn.Module): 172 | """ 173 | The main model 174 | Uses SSD like architecture but for Lang+Vision 175 | """ 176 | 177 | def __init__(self, backbone, n_anchors=1, final_bias=0., cfg=None): 178 | super().__init__() 179 | # assert isinstance(backbone, BackBone) 180 | self.backbone = backbone 181 | 182 | # Assume the output from each 183 | # component of backbone will have 256 channels 184 | self.device = torch.device(cfg.device) 185 | 186 | self.cfg = cfg 187 | 188 | # should be len(ratios) * len(scales) 189 | self.n_anchors = n_anchors 190 | 191 | self.emb_dim = cfg['emb_dim'] 192 | self.bid = cfg['use_bidirectional'] 193 | self.lstm_dim = cfg['lstm_dim'] 194 | 195 | # Calculate output dimension of LSTM 196 | self.lstm_out_dim = self.lstm_dim * (self.bid + 1) 197 | 198 | # Separate cases for language, image blind settings 199 | if self.cfg['use_lang'] and self.cfg['use_img']: 200 | self.start_dim_head = self.lstm_dim*(self.bid+1) + 256 + 2 201 | elif self.cfg['use_img'] and not self.cfg['use_lang']: 202 | # language blind 203 | self.start_dim_head = 256 204 | elif self.cfg['use_lang'] and not self.cfg['use_img']: 205 | # image blind 206 | self.start_dim_head = self.lstm_dim*(self.bid+1) 207 | else: 208 | # both image, lang blind 209 | self.start_dim_head = 2 210 | 211 | # If shared heads for classification, box regression 212 | # This is the config used in the paper 213 | if self.cfg['use_same_atb']: 214 | bias = torch.zeros(5 * self.n_anchors) 215 | bias[torch.arange(4, 5 * self.n_anchors, 5)] = -4 216 | self.att_reg_box = self._head_subnet( 217 | 5, self.n_anchors, final_bias=bias, 218 | start_dim_head=self.start_dim_head 219 | ) 220 | # This is not used. 
Kept for historical purposes 221 | else: 222 | self.att_box = self._head_subnet( 223 | 1, self.n_anchors, -4., start_dim_head=self.start_dim_head) 224 | self.reg_box = self._head_subnet( 225 | 4, self.n_anchors, start_dim_head=self.start_dim_head) 226 | 227 | self.lstm = nn.LSTM(self.emb_dim, self.lstm_dim, 228 | bidirectional=self.bid, batch_first=False) 229 | self.after_init() 230 | 231 | def after_init(self): 232 | "Placeholder if any child class needs something more" 233 | pass 234 | 235 | def _head_subnet(self, n_classes, n_anchors, final_bias=0., n_conv=4, chs=256, 236 | start_dim_head=256): 237 | """ 238 | Convenience function to create attention and regression heads 239 | """ 240 | layers = [conv2d_relu(start_dim_head, chs, bias=True)] 241 | layers += [conv2d_relu(chs, chs, bias=True) for _ in range(n_conv)] 242 | layers += [conv2d(chs, n_classes * n_anchors, bias=True)] 243 | layers[-1].bias.data.zero_().add_(final_bias) 244 | return nn.Sequential(*layers) 245 | 246 | def permute_correctly(self, inp, outc): 247 | """ 248 | Basically square box features are flattened 249 | """ 250 | # inp is features 251 | # B x C x H x W -> B x H x W x C 252 | out = inp.permute(0, 2, 3, 1).contiguous() 253 | out = out.view(out.size(0), -1, outc) 254 | return out 255 | 256 | def concat_we(self, x, we, append_grid_centers=True): 257 | """ 258 | Convenience function to concat we 259 | Expects x in the form B x C x H x W 260 | we: B x wdim 261 | """ 262 | b, wdim = we.shape 263 | we = we / we.norm(dim=1).unsqueeze(1).expand(b, wdim) 264 | word_emb_tile = we.view(we.size(0), we.size(1), 265 | 1, 1).expand(we.size(0), 266 | we.size(1), 267 | x.size(2), x.size(3)) 268 | 269 | if append_grid_centers: 270 | grid = create_grid((x.size(2), x.size(3)), 271 | flatten=False).to(self.device) 272 | grid = grid.permute(2, 0, 1).contiguous() 273 | grid_tile = grid.view(1, grid.size(0), grid.size(1), grid.size(2)).expand( 274 | we.size(0), grid.size(0), grid.size(1), grid.size(2)) 275 | 276 | return torch.cat((x, word_emb_tile, grid_tile), dim=1) 277 | return torch.cat((x, word_emb_tile), dim=1) 278 | 279 | def lstm_init_hidden(self, bs): 280 | """ 281 | Initialize the very first hidden state of LSTM 282 | Basically, the LSTM should be independent of this 283 | """ 284 | if not self.bid: 285 | hidden_a = torch.randn(1, bs, self.lstm_dim) 286 | hidden_b = torch.randn(1, bs, self.lstm_dim) 287 | else: 288 | hidden_a = torch.randn(2, bs, self.lstm_dim) 289 | hidden_b = torch.randn(2, bs, self.lstm_dim) 290 | 291 | hidden_a = hidden_a.to(self.device) 292 | hidden_b = hidden_b.to(self.device) 293 | 294 | return (hidden_a, hidden_b) 295 | 296 | def apply_lstm(self, word_embs, qlens, max_qlen, get_full_seq=False): 297 | """ 298 | Applies lstm function. 299 | word_embs: word embeddings, B x seq_len x 300 300 | qlen: length of the phrases 301 | Try not to fiddle with this function. 
302 | IT JUST WORKS 303 | """ 304 | # B x T x E 305 | bs, max_seq_len, emb_dim = word_embs.shape 306 | # bid x B x L 307 | self.hidden = self.lstm_init_hidden(bs) 308 | # B x 1, B x 1 309 | qlens1, perm_idx = qlens.sort(0, descending=True) 310 | # B x T x E (permuted) 311 | qtoks = word_embs[perm_idx] 312 | # T x B x E 313 | embeds = qtoks.permute(1, 0, 2).contiguous() 314 | # Packed Embeddings 315 | packed_embed_inp = pack_padded_sequence( 316 | embeds, lengths=qlens1, batch_first=False) 317 | # To ensure no pains with DataParallel 318 | # self.lstm.flatten_parameters() 319 | lstm_out1, (self.hidden, _) = self.lstm(packed_embed_inp, self.hidden) 320 | 321 | # T x B x L 322 | lstm_out, req_lens = pad_packed_sequence( 323 | lstm_out1, batch_first=False, total_length=max_qlen) 324 | 325 | # TODO: Simplify getting the last vector 326 | masks = (qlens1-1).view(1, -1, 1).expand(max_qlen, 327 | lstm_out.size(1), lstm_out.size(2)) 328 | qvec_sorted = lstm_out.gather(0, masks.long())[0] 329 | 330 | qvec_out = word_embs.new_zeros(qvec_sorted.shape) 331 | qvec_out[perm_idx] = qvec_sorted 332 | # if full sequence is needed for future work 333 | if get_full_seq: 334 | lstm_out_1 = lstm_out.transpose(1, 0).contiguous() 335 | return lstm_out_1 336 | return qvec_out.contiguous() 337 | 338 | def forward(self, inp: Dict[str, Any]): 339 | """ 340 | Forward method of the model 341 | inp0 : image to be used 342 | inp1 : word embeddings, B x seq_len x 300 343 | qlens: length of phrases 344 | 345 | The following is performed: 346 | 1. Get final hidden state features of lstm 347 | 2. Get image feature maps 348 | 3. Concatenate the two, specifically, copy lang features 349 | and append it to all the image feature maps, also append the 350 | grid centers. 351 | 4. Use the classification, regression head on this concatenated features 352 | The matching with groundtruth is done in loss function and evaluation 353 | """ 354 | inp0 = inp['img'] 355 | inp1 = inp['qvec'] 356 | qlens = inp['qlens'] 357 | max_qlen = int(qlens.max().item()) 358 | req_embs = inp1[:, :max_qlen, :].contiguous() 359 | 360 | req_emb = self.apply_lstm(req_embs, qlens, max_qlen) 361 | 362 | # image blind 363 | if self.cfg['use_lang'] and not self.cfg['use_img']: 364 | # feat_out = self.backbone(inp0) 365 | feat_out = self.backbone(inp0, req_emb, only_we=True) 366 | 367 | # language blind 368 | elif self.cfg['use_img'] and not self.cfg['use_lang']: 369 | feat_out = self.backbone(inp0) 370 | 371 | elif not self.cfg['use_img'] and not self.cfg['use_lang']: 372 | feat_out = self.backbone(inp0, req_emb, only_grid=True) 373 | # see full language + image (happens by default) 374 | else: 375 | feat_out = self.backbone(inp0, req_emb) 376 | 377 | # Strategy depending on shared head or not 378 | if self.cfg['use_same_atb']: 379 | att_bbx_out = torch.cat([self.permute_correctly( 380 | self.att_reg_box(feature), 5) for feature in feat_out], dim=1) 381 | att_out = att_bbx_out[..., [-1]] 382 | bbx_out = att_bbx_out[..., :-1] 383 | else: 384 | att_out = torch.cat( 385 | [self.permute_correctly(self.att_box(feature), 1) 386 | for feature in feat_out], dim=1) 387 | bbx_out = torch.cat( 388 | [self.permute_correctly(self.reg_box(feature), 4) 389 | for feature in feat_out], dim=1) 390 | 391 | feat_sizes = torch.tensor([[f.size(2), f.size(3)] 392 | for f in feat_out]).to(self.device) 393 | 394 | # Used mainly due to dataparallel consistency 395 | num_f_out = torch.tensor([len(feat_out)]).to(self.device) 396 | 397 | out_dict = {} 398 | out_dict['att_out'] = att_out 
399 | out_dict['bbx_out'] = bbx_out 400 | out_dict['feat_sizes'] = feat_sizes 401 | out_dict['num_f_out'] = num_f_out 402 | 403 | return out_dict 404 | 405 | 406 | def get_default_net(num_anchors=1, cfg=None): 407 | """ 408 | Constructs the network based on the config 409 | """ 410 | if cfg['mdl_to_use'] == 'retina': 411 | encoder = tvm.resnet50(True) 412 | backbone = RetinaBackBone(encoder, cfg) 413 | elif cfg['mdl_to_use'] == 'ssd_vgg': 414 | encoder = ssd_vgg.build_ssd('train', cfg=cfg) 415 | encoder.vgg.load_state_dict( 416 | torch.load('./weights/vgg16_reducedfc.pth')) 417 | print('loaded pretrained vgg backbone') 418 | backbone = SSDBackBone(encoder, cfg) 419 | # backbone = encoder 420 | 421 | zsg_net = ZSGNet(backbone, num_anchors, cfg=cfg) 422 | return zsg_net 423 | 424 | 425 | if __name__ == '__main__': 426 | # torch.manual_seed(0) 427 | cfg = conf 428 | cfg.mdl_to_use = 'ssd_vgg' 429 | cfg.ds_to_use = 'refclef' 430 | cfg.num_gpus = 1 431 | # cfg.device = 'cpu' 432 | device = torch.device(cfg.device) 433 | data = get_data(cfg) 434 | 435 | zsg_net = get_default_net(num_anchors=9, cfg=cfg) 436 | zsg_net.to(device) 437 | 438 | batch = next(iter(data.train_dl)) 439 | for k in batch: 440 | batch[k] = batch[k].to(device) 441 | out = zsg_net(batch) 442 | -------------------------------------------------------------------------------- /code/ssd_vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from the amazing repository: https://github.com/amdegroot/ssd.pytorch/blob/master/ssd.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | import torchvision.models as tvm 10 | import os 11 | 12 | 13 | class SSD(nn.Module): 14 | """Single Shot Multibox Architecture 15 | The network is composed of a base VGG network followed by the 16 | added multibox conv layers. Each multibox layer branches into 17 | 1) conv2d for class conf scores 18 | 2) conv2d for localization predictions 19 | 3) associated priorbox layer to produce default bounding 20 | boxes specific to the layer's feature map size. 21 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 22 | 23 | Args: 24 | phase: (string) Can be "test" or "train" 25 | size: input image size 26 | base: VGG16 layers for input, size of either 300 or 500 27 | extras: extra layers that feed to multibox loc and conf layers 28 | head: "multibox head" consists of loc and conf conv layers 29 | """ 30 | 31 | def __init__(self, phase, size, base, extras, head, num_classes, cfg=None): 32 | super(SSD, self).__init__() 33 | self.phase = phase 34 | self.num_classes = num_classes 35 | # self.cfg = (coco, voc)[num_classes == 21] 36 | # self.priorbox = PriorBox(self.cfg) 37 | # self.priors = Variable(self.priorbox.forward(), volatile=True) 38 | self.size = size 39 | self.cfg = cfg 40 | 41 | # SSD network 42 | self.vgg = nn.ModuleList(base) 43 | # self.vgg = tvm.vgg16(pretrained=True) 44 | # Layer learns to scale the l2 normalized features from conv4_3 45 | # self.L2Norm = L2Norm(512, 20) 46 | self.fproj1 = nn.Conv2d(512, 256, kernel_size=1) 47 | self.fproj2 = nn.Conv2d(1024, 256, kernel_size=1) 48 | self.fproj3 = nn.Conv2d(512, 256, kernel_size=1) 49 | self.extras = nn.ModuleList(extras) 50 | 51 | self.loc = nn.ModuleList(head[0]) 52 | self.conf = nn.ModuleList(head[1]) 53 | 54 | def forward(self, x): 55 | """Applies network layers and ops on input image(s) x. 56 | 57 | Args: 58 | x: input image or batch of images. 
Shape: [batch,3,300,300]. 59 | 60 | Return: 61 | Depending on phase: 62 | test: 63 | Variable(tensor) of output class label predictions, 64 | confidence score, and corresponding location predictions for 65 | each object detected. Shape: [batch,topk,7] 66 | 67 | train: 68 | list of concat outputs from: 69 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 70 | 2: localization layers, Shape: [batch,num_priors*4] 71 | 3: priorbox layers, Shape: [2,num_priors*4] 72 | """ 73 | sources = list() 74 | 75 | # apply vgg up to conv4_3 relu 76 | for k in range(23): 77 | x = self.vgg[k](x) 78 | 79 | # s = self.L2Norm(x) 80 | s = x / x.norm(dim=1, keepdim=True) 81 | sources.append(s) 82 | # print(f'Adding1 of dim {s.shape}') 83 | 84 | # apply vgg up to fc7 85 | for k in range(23, len(self.vgg)): 86 | x = self.vgg[k](x) 87 | sources.append(x) 88 | # print(f'Adding2 of dim {x.shape}') 89 | 90 | # apply extra layers and cache source layer outputs 91 | for k, v in enumerate(self.extras): 92 | x = F.relu(v(x), inplace=True) 93 | if k % 2 == 1: 94 | sources.append(x) 95 | # print(f'Adding3 of dim {x.shape}') 96 | 97 | out_sources = [self.fproj1(sources[0]), self.fproj2( 98 | sources[1]), self.fproj3(sources[2])] + sources[3:] 99 | if self.cfg['resize_img'][0] >= 600: 100 | # To Reduce the computation 101 | return out_sources[1:] 102 | return out_sources 103 | 104 | def load_weights(self, base_file): 105 | other, ext = os.path.splitext(base_file) 106 | if ext == '.pkl' or '.pth': 107 | print('Loading weights into state dict...') 108 | self.load_state_dict(torch.load(base_file, 109 | map_location=lambda storage, loc: storage)) 110 | print('Finished!') 111 | else: 112 | print('Sorry only .pth and .pkl files supported.') 113 | 114 | 115 | # This function is derived from torchvision VGG make_layers() 116 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 117 | def vgg(cfg, i, batch_norm=False): 118 | layers = [] 119 | in_channels = i 120 | for v in cfg: 121 | if v == 'M': 122 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 123 | elif v == 'C': 124 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 125 | else: 126 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 127 | if batch_norm: 128 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 129 | else: 130 | layers += [conv2d, nn.ReLU(inplace=True)] 131 | in_channels = v 132 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 133 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 134 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 135 | layers += [pool5, conv6, 136 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 137 | return layers 138 | 139 | 140 | def add_extras(cfg, i, batch_norm=False): 141 | # Extra layers added to VGG for feature scaling 142 | layers = [] 143 | in_channels = i 144 | flag = False 145 | for k, v in enumerate(cfg): 146 | if in_channels != 'S': 147 | if v == 'S': 148 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 149 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 150 | else: 151 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 152 | flag = not flag 153 | in_channels = v 154 | return layers 155 | 156 | 157 | def multibox(vgg, extra_layers, cfg, num_classes): 158 | loc_layers = [] 159 | conf_layers = [] 160 | vgg_source = [21, -2] 161 | for k, v in enumerate(vgg_source): 162 | loc_layers += [nn.Conv2d(vgg[v].out_channels, 163 | cfg[k] * 4, kernel_size=3, padding=1)] 164 | conf_layers += [nn.Conv2d(vgg[v].out_channels, 
165 | cfg[k] * num_classes, kernel_size=3, padding=1)] 166 | for k, v in enumerate(extra_layers[1::2], 2): 167 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k] 168 | * 4, kernel_size=3, padding=1)] 169 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k] 170 | * num_classes, kernel_size=3, padding=1)] 171 | return vgg, extra_layers, (loc_layers, conf_layers) 172 | 173 | 174 | base = { 175 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 176 | 512, 512, 512], 177 | '512': [], 178 | } 179 | extras = { 180 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], 181 | '512': [], 182 | } 183 | mbox = { 184 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location 185 | '512': [], 186 | } 187 | 188 | 189 | def build_ssd(phase, size=300, num_classes=21, cfg=None): 190 | if phase != "test" and phase != "train": 191 | print("ERROR: Phase: " + phase + " not recognized") 192 | return 193 | if size != 300: 194 | print("ERROR: You specified size " + repr(size) + ". However, " + 195 | "currently only SSD300 (size=300) is supported!") 196 | return 197 | base_, extras_, head_ = multibox(vgg(base[str(size)], 3), 198 | add_extras(extras[str(size)], 1024), 199 | mbox[str(size)], num_classes) 200 | return SSD(phase, size, base_, extras_, head_, num_classes, cfg=cfg) 201 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions 3 | """ 4 | from typing import Dict, List, Optional, Union, Any, Callable 5 | import torch 6 | from torch import nn 7 | from torch.utils.data import DataLoader 8 | from dataclasses import dataclass 9 | from pathlib import Path 10 | import sys 11 | import re 12 | import numpy as np 13 | from collections import Counter 14 | from tqdm import tqdm 15 | import time 16 | import shutil 17 | import json 18 | from fastprogress.fastprogress import master_bar, progress_bar 19 | import logging 20 | import pickle 21 | from torch.utils.tensorboard import SummaryWriter 22 | from torch import distributed as dist 23 | from torch.distributed import ReduceOp 24 | from yacs.config import CfgNode as CN 25 | 26 | 27 | def get_world_size(): 28 | if not dist.is_available(): 29 | return 1 30 | if not dist.is_initialized(): 31 | return 1 32 | return dist.get_world_size() 33 | 34 | 35 | def get_rank(): 36 | if not dist.is_available(): 37 | return 0 38 | if not dist.is_initialized(): 39 | return 0 40 | return dist.get_rank() 41 | 42 | 43 | def is_main_process(): 44 | return get_rank() == 0 45 | 46 | 47 | def synchronize(): 48 | """ 49 | Helper function to synchronize (barrier) among all processes when 50 | using distributed training 51 | """ 52 | if not dist.is_available(): 53 | return 54 | if not dist.is_initialized(): 55 | return 56 | world_size = dist.get_world_size() 57 | if world_size == 1: 58 | return 59 | dist.barrier() 60 | 61 | 62 | def reduce_dict(input_dict, average=False): 63 | """ 64 | Args: 65 | input_dict (dict): all the values will be reduced 66 | average (bool): whether to do average or sum 67 | Reduce the values in the dictionary from all processes so that process with rank 68 | 0 has the averaged results. Returns a dict with the same fields as 69 | input_dict, after reduction. 
70 | """ 71 | world_size = get_world_size() 72 | if world_size < 2: 73 | return input_dict 74 | with torch.no_grad(): 75 | names = [] 76 | values = [] 77 | # sort the keys so that they are consistent across processes 78 | for k in sorted(input_dict.keys()): 79 | names.append(k) 80 | values.append(input_dict[k]) 81 | values = torch.stack(values, dim=0) 82 | dist.reduce(values, dst=0) 83 | # if dist.get_rank() == 0: 84 | # only main process gets accumulated, so only divide by 85 | # world_size in this case 86 | # values /= world_size 87 | if average: 88 | values /= world_size 89 | reduced_dict = { 90 | k: v for k, v in zip(names, values)} 91 | return reduced_dict 92 | 93 | 94 | def reduce_dict_corr(input_dict, nums): 95 | world_size = get_world_size() 96 | if world_size < 2: 97 | return input_dict 98 | 99 | new_inp_dict = {k: v*nums for k, v in input_dict.items()} 100 | out_dict = reduce_dict(new_inp_dict) 101 | dist.reduce(nums, dst=0) 102 | if not is_main_process(): 103 | return out_dict 104 | out_dict_avg = {k: v / nums.item() for k, v in out_dict.items()} 105 | return out_dict_avg 106 | 107 | 108 | def exec_func_if_main_proc(func: Callable): 109 | def wrapper(*args, **kwargs): 110 | if is_main_process(): 111 | func(*args, **kwargs) 112 | return wrapper 113 | 114 | 115 | @dataclass 116 | class DataWrap: 117 | path: Union[str, Path] 118 | train_dl: DataLoader 119 | valid_dl: DataLoader 120 | test_dl: Optional[Union[DataLoader, Dict]] = None 121 | 122 | 123 | class SmoothenValue(): 124 | """ 125 | Create a smooth moving average for a value(loss, etc) using `beta`. 126 | Adapted from fastai(https://github.com/fastai/fastai) 127 | """ 128 | 129 | def __init__(self, beta: float): 130 | self.beta, self.n, self.mov_avg = beta, 0, 0 131 | self.smooth = 0 132 | 133 | def add_value(self, val: float) -> None: 134 | "Add `val` to calculate updated smoothed value." 135 | self.n += 1 136 | self.mov_avg = self.beta * \ 137 | self.mov_avg + (1 - self.beta) * val 138 | self.smooth = self.mov_avg / (1 - self.beta ** self.n) 139 | 140 | 141 | class SmoothenDict: 142 | "Converts list to dicts" 143 | 144 | def __init__(self, keys: List[str], val: int): 145 | self.keys = keys 146 | self.smooth_vals = {k: SmoothenValue(val) for k in keys} 147 | 148 | def add_value(self, val: Dict[str, torch.tensor]): 149 | for k in self.keys: 150 | self.smooth_vals[k].add_value(val[k].detach()) 151 | 152 | @property 153 | def smooth(self): 154 | return {k: self.smooth_vals[k].smooth for k in self.keys} 155 | 156 | @property 157 | def smooth1(self): 158 | return self.smooth_vals[self.keys[0]].smooth 159 | 160 | 161 | def compute_avg(inp: List, nums: torch.tensor) -> float: 162 | "Computes average given list of torch.tensor and numbers corresponding to them" 163 | return (torch.stack(inp) * nums).sum() / nums.sum() 164 | 165 | 166 | def compute_avg_dict(inp: Dict[str, List], 167 | nums: torch.tensor) -> Dict[str, float]: 168 | "Takes dict as input" 169 | out_dict = {} 170 | for k in inp: 171 | out_dict[k] = compute_avg(inp[k], nums) 172 | 173 | return out_dict 174 | 175 | 176 | def good_format_stats(names, stats) -> str: 177 | "Format stats before printing." 
178 | str_stats = [] 179 | for name, stat in zip(names, stats): 180 | t = str(stat) if isinstance(stat, int) else f'{stat.item():.4f}' 181 | t += ' ' * (len(name) - len(t)) 182 | str_stats.append(t) 183 | return ' '.join(str_stats) 184 | 185 | 186 | @dataclass 187 | class Learner: 188 | uid: str 189 | data: DataWrap 190 | mdl: nn.Module 191 | loss_fn: nn.Module 192 | cfg: Dict 193 | eval_fn: nn.Module 194 | opt_fn: Callable 195 | device: torch.device = torch.device('cuda') 196 | 197 | def __post_init__(self): 198 | "Setup log file, load model if required" 199 | 200 | # Get rank 201 | self.rank = get_rank() 202 | 203 | self.init_log_dirs() 204 | 205 | self.prepare_log_keys() 206 | 207 | self.prepare_log_file() 208 | 209 | self.logger = self.init_logger() 210 | 211 | # Set the number of iterations, epochs, best_met to 0. 212 | # Updated in loading if required 213 | self.num_it = 0 214 | self.num_epoch = 0 215 | self.best_met = 0 216 | 217 | # Resume if given a path 218 | if self.cfg['resume']: 219 | self.load_model_dict( 220 | resume_path=self.cfg['resume_path'], 221 | load_opt=self.cfg['load_opt']) 222 | 223 | # self.writer.add_text(tag='cfg', text_string=json.dumps(self.cfg), 224 | # global_step=self.num_epoch) 225 | 226 | def init_logger(self): 227 | logger = logging.getLogger(__name__) 228 | logger.setLevel(logging.DEBUG) 229 | if not is_main_process(): 230 | return logger 231 | ch = logging.StreamHandler(stream=sys.stdout) 232 | ch.setLevel(logging.INFO) 233 | formatter = logging.Formatter( 234 | "%(asctime)s %(name)s %(levelname)s: %(message)s") 235 | ch.setFormatter(formatter) 236 | logger.addHandler(ch) 237 | 238 | fh = logging.FileHandler(str(self.extra_logger_file)) 239 | fh.setLevel(logging.DEBUG) 240 | fh.setFormatter(formatter) 241 | logger.addHandler(fh) 242 | # logging.basicConfig( 243 | # filename=self.extra_logger_file, 244 | # filemode='a', 245 | # format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 246 | # datefmt='%m/%d/%Y %H:%M:%S', 247 | # level=logging.INFO 248 | # ) 249 | return logger 250 | 251 | def init_log_dirs(self): 252 | """ 253 | Convenience function to create the following: 254 | 1. Log dir to store log file in txt format 255 | 2. Extra Log dir to store the logger output 256 | 3. Tb log dir to store tensorboard files 257 | 4. Model dir to save the model files 258 | 5. Predictions dir to store the predictions of the saved model 259 | 6. 
[Optional] Can add some 3rd party logger 260 | """ 261 | # Saves the text logs 262 | self.txt_log_file = Path( 263 | self.data.path) / 'txt_logs' / f'{self.uid}.txt' 264 | 265 | # Saves the output of self.logger 266 | self.extra_logger_file = Path( 267 | self.data.path) / 'ext_logs' / f'{self.uid}.txt' 268 | 269 | # Saves SummaryWriter outputs 270 | self.tb_log_dir = Path(self.data.path) / 'tb_logs' / f'{self.uid}' 271 | 272 | # Saves the trained model 273 | self.model_file = Path(self.data.path) / 'models' / f'{self.uid}.pth' 274 | 275 | # Saves the output predictions 276 | self.predictions_dir = Path( 277 | self.data.path) / 'predictions' / f'{self.uid}' 278 | 279 | self.create_log_dirs() 280 | 281 | @exec_func_if_main_proc 282 | def create_log_dirs(self): 283 | """ 284 | Creates the directories initialized in init_log_dirs 285 | """ 286 | self.txt_log_file.parent.mkdir(exist_ok=True, parents=True) 287 | self.extra_logger_file.parent.mkdir(exist_ok=True) 288 | self.tb_log_dir.mkdir(exist_ok=True, parents=True) 289 | self.model_file.parent.mkdir(exist_ok=True) 290 | self.predictions_dir.mkdir(exist_ok=True, parents=True) 291 | 292 | def prepare_log_keys(self): 293 | """ 294 | Creates the relevant keys to be logged. 295 | Mainly used by the txt logger to output in a good format 296 | """ 297 | def _prepare_log_keys(keys_list: List[List[str]], 298 | prefix: List[str]) -> List[str]: 299 | """ 300 | Convenience function to create log keys 301 | keys_list: List[loss_keys, met_keys] 302 | prefix: List['trn', 'val'] 303 | """ 304 | log_keys = [] 305 | for keys in keys_list: 306 | for key in keys: 307 | log_keys += [f'{p}_{key}' for p in prefix] 308 | return log_keys 309 | 310 | self.loss_keys = self.loss_fn.loss_keys 311 | self.met_keys = self.eval_fn.met_keys 312 | 313 | # When writing Training and Validation together 314 | self.log_keys = ['epochs'] + _prepare_log_keys( 315 | [self.loss_keys, self.met_keys], 316 | ['trn', 'val'] 317 | ) 318 | 319 | self.val_log_keys = ['epochs'] + _prepare_log_keys( 320 | [self.loss_keys, self.met_keys], 321 | ['val'] 322 | ) 323 | 324 | self.test_log_keys = ['epochs'] + _prepare_log_keys( 325 | [self.met_keys], 326 | ['test'] 327 | ) 328 | 329 | @exec_func_if_main_proc 330 | def prepare_log_file(self): 331 | "Prepares the log files depending on arguments" 332 | f = self.txt_log_file.open('a') 333 | cfgtxt = json.dumps(self.cfg) 334 | f.write(cfgtxt) 335 | f.write('\n\n') 336 | f.write(' '.join(self.log_keys) + '\n') 337 | f.close() 338 | 339 | @exec_func_if_main_proc 340 | def update_log_file(self, towrite: str): 341 | "Updates the log files as and when required" 342 | with self.txt_log_file.open('a') as f: 343 | f.write(towrite + '\n') 344 | 345 | def get_predictions_list(self, predictions: Dict[str, List]) -> List[Dict]: 346 | "Converts dictionary of lists to list of dictionary" 347 | keys = list(predictions.keys()) 348 | num_preds = len(predictions[keys[0]]) 349 | out_list = [{k: predictions[k][ind] for k in keys} 350 | for ind in range(num_preds)] 351 | return out_list 352 | 353 | def validate(self, db: Optional[DataLoader] = None, 354 | mb=None) -> List[torch.tensor]: 355 | "Validation loop, done after every epoch" 356 | self.mdl.eval() 357 | if db is None: 358 | db = self.data.valid_dl 359 | 360 | predicted_box_dict_list = [] 361 | with torch.no_grad(): 362 | val_losses = {k: [] for k in self.loss_keys} 363 | eval_metrics = {k: [] for k in self.met_keys} 364 | nums = [] 365 | for batch in progress_bar(db, parent=mb): 366 | for b in batch.keys(): 367 | 
batch[b] = batch[b].to(self.device) 368 | out = self.mdl(batch) 369 | out_loss = self.loss_fn(out, batch) 370 | 371 | metric = self.eval_fn(out, batch) 372 | for k in self.loss_keys: 373 | val_losses[k].append(out_loss[k].detach()) 374 | for k in self.met_keys: 375 | eval_metrics[k].append(metric[k].detach()) 376 | nums.append(batch[next(iter(batch))].shape[0]) 377 | prediction_dict = { 378 | 'id': metric['idxs'].tolist(), 379 | 'pred_boxes': metric['pred_boxes'].tolist(), 380 | 'pred_scores': metric['pred_scores'].tolist() 381 | } 382 | predicted_box_dict_list += self.get_predictions_list( 383 | prediction_dict) 384 | nums = torch.tensor(nums).float().to(self.device) 385 | tot_nums = nums.sum() 386 | val_loss = compute_avg_dict(val_losses, nums) 387 | val_loss = reduce_dict_corr(val_loss, tot_nums) 388 | 389 | eval_metric = compute_avg_dict(eval_metrics, nums) 390 | eval_metric = reduce_dict_corr(eval_metric, tot_nums) 391 | return val_loss, eval_metric, predicted_box_dict_list 392 | 393 | def train_epoch(self, mb) -> List[torch.tensor]: 394 | "One epoch used for training" 395 | self.mdl.train() 396 | # trn_loss = SmoothenValue(0.9) 397 | trn_loss = SmoothenDict(self.loss_keys, 0.9) 398 | trn_acc = SmoothenDict(self.met_keys, 0.9) 399 | 400 | for batch_id, batch in enumerate(progress_bar(self.data.train_dl, parent=mb)): 401 | # for batch_id, batch in progress_bar(QueueIterator(batch_queue), parent=mb): 402 | # for batch_id, batch in QueueIterator(batch_queue): 403 | # Increment number of iterations 404 | self.num_it += 1 405 | for b in batch.keys(): 406 | batch[b] = batch[b].to(self.device) 407 | self.optimizer.zero_grad() 408 | out = self.mdl(batch) 409 | out_loss = self.loss_fn(out, batch) 410 | loss = out_loss[self.loss_keys[0]] 411 | loss = loss.mean() 412 | loss.backward() 413 | self.optimizer.step() 414 | metric = self.eval_fn(out, batch) 415 | 416 | # Returns original dictionary if not distributed parallel 417 | # loss_reduced = reduce_dict(out_loss, average=True) 418 | # metric_reduced = reduce_dict(metric, average=True) 419 | 420 | trn_loss.add_value(out_loss) 421 | trn_acc.add_value(metric) 422 | 423 | # self.writer.add_scalar( 424 | # tag='trn_loss', scalar_value=out_loss[self.loss_keys[0]], 425 | # global_step=self.num_it) 426 | comment_to_print = f'LossB {loss: .4f} | SmLossB {trn_loss.smooth1: .4f} | AccB {trn_acc.smooth1: .4f}' 427 | mb.child.comment = comment_to_print 428 | if self.num_it % 2 == 0: 429 | self.logger.debug(f'Num_it {self.num_it} {comment_to_print}') 430 | del out_loss 431 | del loss 432 | # print(f'Done {batch_id}') 433 | del batch 434 | self.optimizer.zero_grad() 435 | out_loss = reduce_dict(trn_loss.smooth, average=True) 436 | out_met = reduce_dict(trn_acc.smooth, average=True) 437 | # return trn_loss.smooth, trn_acc.smooth 438 | return out_loss, out_met 439 | 440 | def load_model_dict(self, resume_path: Optional[str] = None, load_opt: bool = False): 441 | "Load the model and/or optimizer" 442 | 443 | if resume_path == "": 444 | mfile = self.model_file 445 | else: 446 | mfile = Path(resume_path) 447 | 448 | if not mfile.exists(): 449 | self.logger.info( 450 | f'No existing model in {mfile}, starting from scratch') 451 | return 452 | try: 453 | checkpoint = torch.load(open(mfile, 'rb')) 454 | self.logger.info(f'Loaded model from {mfile} Correctly') 455 | except OSError as e: 456 | self.logger.error( 457 | f'Some problem with resume path: {resume_path}. 
Exception raised {e}') 458 | raise e 459 | if self.cfg['load_normally']: 460 | self.mdl.load_state_dict( 461 | checkpoint['model_state_dict'], strict=self.cfg['strict_load']) 462 | # else: 463 | # load_state_dict( 464 | # self.mdl, checkpoint['model_state_dict'] 465 | # ) 466 | # self.logger.info('Added model file correctly') 467 | if 'num_it' in checkpoint.keys(): 468 | self.num_it = checkpoint['num_it'] 469 | 470 | if 'num_epoch' in checkpoint.keys(): 471 | self.num_epoch = checkpoint['num_epoch'] 472 | 473 | if 'best_met' in checkpoint.keys(): 474 | self.best_met = checkpoint['best_met'] 475 | 476 | if load_opt: 477 | self.optimizer = self.prepare_optimizer() 478 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 479 | if 'scheduler_state_dict' in checkpoint: 480 | self.lr_scheduler = self.prepare_scheduler() 481 | self.lr_scheduler.load_state_dict( 482 | checkpoint['scheduler_state_dict']) 483 | 484 | @exec_func_if_main_proc 485 | def save_model_dict(self): 486 | "Save the model and optimizer" 487 | # if is_main_process(): 488 | checkpoint = { 489 | 'model_state_dict': self.mdl.state_dict(), 490 | 'optimizer_state_dict': self.optimizer.state_dict(), 491 | 'scheduler_state_dict': self.lr_scheduler.state_dict(), 492 | 'num_it': self.num_it, 493 | 'num_epoch': self.num_epoch, 494 | 'cfgtxt': json.dumps(self.cfg), 495 | 'best_met': self.best_met 496 | } 497 | torch.save(checkpoint, self.model_file.open('wb')) 498 | 499 | # @exec_func_if_main_proc 500 | def update_prediction_file(self, predictions, pred_file): 501 | rank = self.rank 502 | if self.cfg.do_dist: 503 | pred_file_to_use = pred_file.parent / f'{rank}_{pred_file.name}' 504 | pickle.dump(predictions, pred_file_to_use.open('wb')) 505 | if is_main_process() and self.cfg.do_dist: 506 | if pred_file.exists(): 507 | pred_file.unlink() 508 | else: 509 | pickle.dump(predictions, pred_file.open('wb')) 510 | # synchronize() 511 | # st_time = time.time() 512 | # self.rectify_predictions(pred_file) 513 | # end_time = time.time() 514 | # self.logger.info( 515 | # f'Updating prediction file took time {st_time - end_time}') 516 | 517 | @exec_func_if_main_proc 518 | def rectify_predictions(self, pred_file): 519 | world_size = get_world_size() 520 | pred_files_to_use = [pred_file.parent / 521 | f'{r}_{pred_file.name}' for r in range(world_size)] 522 | assert all([p.exists() for p in pred_files_to_use]) 523 | out_preds = [] 524 | for pf in pred_files_to_use: 525 | tmp = pickle.load(open(pf, 'rb')) 526 | assert isinstance(tmp, list) 527 | out_preds += tmp 528 | pickle.dump(out_preds, pred_file.open('wb')) 529 | 530 | def prepare_to_write( 531 | self, 532 | train_loss: Dict[str, torch.tensor], 533 | train_acc: Dict[str, torch.tensor], 534 | val_loss: Dict[str, torch.tensor] = None, 535 | val_acc: Dict[str, torch.tensor] = None, 536 | key_list: List[str] = None 537 | ) -> List[torch.tensor]: 538 | if key_list is None: 539 | key_list = self.log_keys 540 | 541 | epoch = self.num_epoch 542 | out_list = [epoch] 543 | 544 | for k in self.loss_keys: 545 | out_list += [train_loss[k]] 546 | if val_loss is not None: 547 | out_list += [val_loss[k]] 548 | 549 | for k in self.met_keys: 550 | out_list += [train_acc[k]] 551 | if val_acc is not None: 552 | out_list += [val_acc[k]] 553 | 554 | assert len(out_list) == len(key_list) 555 | return out_list 556 | 557 | @property 558 | def lr(self): 559 | return self.cfg['lr'] 560 | 561 | @property 562 | def epoch(self): 563 | return self.cfg['epochs'] 564 | 565 | @exec_func_if_main_proc 566 | def 
master_bar_write(self, mb, **kwargs): 567 | mb.write(**kwargs) 568 | 569 | def fit(self, epochs: int, lr: float, 570 | params_opt_dict: Optional[Dict] = None): 571 | "Main training loop" 572 | # Print logger at the start of the training loop 573 | self.logger.info(self.cfg) 574 | # Initialize the progress_bar 575 | mb = master_bar(range(epochs)) 576 | # Initialize optimizer 577 | # Prepare Optimizer may need to be re-written as per use 578 | self.optimizer = self.prepare_optimizer(params_opt_dict) 579 | # Initialize scheduler 580 | # Prepare scheduler may need to re-written as per use 581 | self.lr_scheduler = self.prepare_scheduler(self.optimizer) 582 | 583 | # Write the top row display 584 | # mb.write(self.log_keys, table=True) 585 | self.master_bar_write(mb, line=self.log_keys, table=True) 586 | exception = False 587 | met_to_use = None 588 | # Keep record of time until exit 589 | st_time = time.time() 590 | try: 591 | # Loop over epochs 592 | for epoch in mb: 593 | self.num_epoch += 1 594 | train_loss, train_acc = self.train_epoch(mb) 595 | 596 | valid_loss, valid_acc, predictions = self.validate( 597 | self.data.valid_dl, mb) 598 | 599 | valid_acc_to_use = valid_acc[self.met_keys[0]] 600 | # Depending on type 601 | self.scheduler_step(valid_acc_to_use) 602 | 603 | # Now only need main process 604 | # Decide to save or not 605 | met_to_use = valid_acc[self.met_keys[0]].cpu() 606 | if self.best_met < met_to_use: 607 | self.best_met = met_to_use 608 | self.save_model_dict() 609 | self.update_prediction_file( 610 | predictions, 611 | self.predictions_dir / f'val_preds_{self.uid}.pkl') 612 | 613 | # Prepare what all to write 614 | to_write = self.prepare_to_write( 615 | train_loss, train_acc, 616 | valid_loss, valid_acc 617 | ) 618 | 619 | # Display on terminal 620 | assert to_write is not None 621 | mb_write = [str(stat) if isinstance(stat, int) 622 | else f'{stat:.4f}' for stat in to_write] 623 | self.master_bar_write(mb, line=mb_write, table=True) 624 | 625 | # for k, record in zip(self.log_keys, to_write): 626 | # self.writer.add_scalar( 627 | # tag=k, scalar_value=record, global_step=self.num_epoch) 628 | # Update in the log file 629 | self.update_log_file( 630 | good_format_stats(self.log_keys, to_write)) 631 | 632 | except Exception as e: 633 | exception = e 634 | raise e 635 | finally: 636 | end_time = time.time() 637 | self.update_log_file( 638 | f'epochs done {epoch}. Exited due to exception {exception}. 
' 639 | f'Total time taken {end_time - st_time: 0.4f}\n\n' 640 | ) 641 | # Decide to save finally or not 642 | if met_to_use: 643 | if self.best_met < met_to_use: 644 | self.save_model_dict() 645 | 646 | def testing(self, db: Dict[str, DataLoader]): 647 | if isinstance(db, DataLoader): 648 | db = {'dl0': db} 649 | for dl_name, dl in tqdm(db.items(), total=len(db)): 650 | out_loss, out_acc, preds = self.validate(dl) 651 | 652 | log_keys = self.val_log_keys 653 | 654 | to_write = self.prepare_to_write( 655 | out_loss, out_acc, key_list=log_keys) 656 | header = ' '.join(log_keys) + '\n' 657 | self.update_log_file(header) 658 | self.update_log_file(good_format_stats( 659 | log_keys, to_write)) 660 | 661 | self.logger.info(header) 662 | self.logger.info(good_format_stats(log_keys, to_write)) 663 | 664 | self.update_prediction_file( 665 | preds, self.predictions_dir / f'{dl_name}_preds.pkl') 666 | 667 | def prepare_optimizer(self, params=None): 668 | "Prepare a normal optimizer" 669 | if not params: 670 | params = self.mdl.parameters() 671 | opt = self.opt_fn(params, lr=self.lr) 672 | return opt 673 | 674 | def prepare_scheduler(self, opt: torch.optim): 675 | "Prepares a LR scheduler on top of optimizer" 676 | self.sched_using_val_metric = self.cfg.use_reduce_lr_plateau 677 | if self.sched_using_val_metric: 678 | lr_sched = torch.optim.lr_scheduler.ReduceLROnPlateau( 679 | opt, factor=self.cfg.reduce_factor, patience=self.cfg.patience) 680 | else: 681 | lr_sched = torch.optim.lr_scheduler.LambdaLR( 682 | opt, lambda epoch: 1) 683 | 684 | return lr_sched 685 | 686 | def scheduler_step(self, val_metric): 687 | if self.sched_using_val_metric: 688 | self.lr_scheduler.step(val_metric) 689 | else: 690 | self.lr_scheduler.step() 691 | return 692 | 693 | def overfit_batch(self, epochs: int, lr: float): 694 | "Sanity check to see if model overfits on a batch" 695 | batch = next(iter(self.data.train_dl)) 696 | for b in batch.keys(): 697 | batch[b] = batch[b].to(self.device) 698 | self.mdl.train() 699 | opt = self.prepare_optimizer(epochs, lr) 700 | 701 | for i in range(1000): 702 | opt.zero_grad() 703 | out = self.mdl(batch) 704 | loss = self.loss_fn(out, batch) 705 | loss.backward() 706 | opt.step() 707 | met = self.eval_fn(out, batch) 708 | print(f'Iter {i} | loss {loss: 0.4f} | acc {met: 0.4f}') 709 | -------------------------------------------------------------------------------- /conda_env_zsg.yml: -------------------------------------------------------------------------------- 1 | name: pyt_new 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - asn1crypto=0.24.0=py37_1003 8 | - attrs=19.1.0=py_0 9 | - backcall=0.1.0=py_0 10 | - bleach=3.1.0=py_0 11 | - blinker=1.4=py_1 12 | - boto=2.49.0=py_0 13 | - boto3=1.9.158=py_0 14 | - bz2file=0.98=py_0 15 | - bzip2=1.0.6=h14c3975_1002 16 | - ca-certificates=2019.6.16=hecc5488_0 17 | - cairo=1.16.0=h18b612c_1001 18 | - certifi=2019.6.16=py37_1 19 | - cffi=1.12.3=py37h8022711_0 20 | - chardet=3.0.4=py37_1003 21 | - cryptography=2.6.1=py37h72c5cf5_0 22 | - cudatoolkit=9.0=h13b8566_0 23 | - cymem=2.0.2=py37hfd86e86_0 24 | - cython-blis=0.2.4=py37h516909a_0 25 | - dbus=1.13.6=he372182_0 26 | - decorator=4.4.0=py_0 27 | - defusedxml=0.5.0=py_1 28 | - docutils=0.14=py37_1001 29 | - entrypoints=0.3=py37_1000 30 | - expat=2.2.5=hf484d3e_1002 31 | - ffmpeg=4.1.3=h167e202_0 32 | - fontconfig=2.13.1=he4413a7_1000 33 | - freetype=2.10.0=he983fc9_0 34 | - gensim=3.7.3=py37he1b5a44_0 35 | - gettext=0.19.8.1=hc5be6a0_1002 36 | - 
giflib=5.1.9=h516909a_0 37 | - glib=2.58.3=hf63aee3_1001 38 | - gmp=6.1.2=hf484d3e_1000 39 | - gnutls=3.6.5=hd3a4fd2_1002 40 | - graphite2=1.3.13=hf484d3e_1000 41 | - gst-plugins-base=1.14.4=hdf3bae2_1001 42 | - gstreamer=1.14.4=h66beb1c_1001 43 | - harfbuzz=2.4.0=h37c48d4_1 44 | - hdf5=1.10.5=nompi_h3c11f04_1100 45 | - icu=58.2=hf484d3e_1000 46 | - idna=2.8=py37_1000 47 | - intel-openmp=2019.3=199 48 | - ipykernel=5.1.0=py37h24bf2e0_1002 49 | - ipython=7.5.0=py37h24bf2e0_0 50 | - ipython_genutils=0.2.0=py_1 51 | - ipywidgets=7.4.2=py_0 52 | - jasper=1.900.1=h07fcdf6_1006 53 | - jedi=0.13.3=py37_0 54 | - jinja2=2.10.1=py_0 55 | - jmespath=0.9.4=py_0 56 | - joblib=0.13.2=py_0 57 | - jpeg=9c=h14c3975_1001 58 | - jsonschema=3.0.0a3=py37_1000 59 | - jupyter=1.0.0=py_2 60 | - jupyter_client=5.2.4=py_3 61 | - jupyter_console=6.0.0=py_0 62 | - jupyter_core=4.4.0=py_0 63 | - lame=3.100=h14c3975_1001 64 | - libblas=3.8.0=8_openblas 65 | - libcblas=3.8.0=8_openblas 66 | - libffi=3.2.1=he1b5a44_1006 67 | - libgcc-ng=8.2.0=hdf63c60_1 68 | - libgfortran-ng=7.3.0=hdf63c60_0 69 | - libiconv=1.15=h516909a_1005 70 | - liblapack=3.8.0=8_openblas 71 | - liblapacke=3.8.0=8_openblas 72 | - libpng=1.6.37=hed695b0_0 73 | - libsodium=1.0.16=h14c3975_1001 74 | - libstdcxx-ng=8.2.0=hdf63c60_1 75 | - libtiff=4.0.10=h648cc4a_1001 76 | - libuuid=2.32.1=h14c3975_1000 77 | - libwebp=1.0.2=h576950b_1 78 | - libxcb=1.13=h14c3975_1002 79 | - libxml2=2.9.9=h13577e0_0 80 | - llvmlite=0.29.0=py37hfd453ef_1 81 | - markupsafe=1.1.1=py37h14c3975_0 82 | - matplotlib-base=3.1.0=py37hfd891ef_1 83 | - mistune=0.8.4=py37h14c3975_1000 84 | - mkl=2019.3=199 85 | - murmurhash=1.0.0=py37hf484d3e_0 86 | - nbconvert=5.5.0=py_0 87 | - nbformat=4.4.0=py_1 88 | - ncurses=6.1=hf484d3e_1002 89 | - nettle=3.4.1=h1bed415_1002 90 | - ninja=1.9.0=h6bb024c_0 91 | - nltk=3.2.5=py_0 92 | - notebook=5.7.8=py37_0 93 | - numba=0.45.1=py37hb3f55d8_0 94 | - numpy=1.16.3=py37he5ce36f_0 95 | - oauthlib=3.0.1=py_0 96 | - olefile=0.46=py_0 97 | - openblas=0.3.6=h6e990d7_1 98 | - opencv=4.1.1=py37hd64ca61_0 99 | - openh264=1.8.0=hdbcaa40_1000 100 | - openssl=1.1.1c=h516909a_0 101 | - pandas=0.24.2=py37hf484d3e_0 102 | - pandoc=2.7.2=0 103 | - pandocfilters=1.4.2=py_1 104 | - parso=0.4.0=py_0 105 | - pcre=8.41=hf484d3e_1003 106 | - pexpect=4.7.0=py37_0 107 | - pickleshare=0.7.5=py37_1000 108 | - pillow=6.1.0=py37h6b7be26_1 109 | - pip=19.1=py37_0 110 | - pixman=0.38.0=h516909a_1003 111 | - plac=0.9.6=py_1 112 | - preshed=2.0.1=py37he6710b0_0 113 | - prometheus_client=0.6.0=py_0 114 | - prompt_toolkit=2.0.9=py_0 115 | - pthread-stubs=0.4=h14c3975_1001 116 | - ptyprocess=0.6.0=py_1001 117 | - pycparser=2.19=py37_1 118 | - pygments=2.4.0=py_0 119 | - pyjwt=1.7.1=py_0 120 | - pyopenssl=19.0.0=py37_0 121 | - pyqt=5.9.2=py37hcca6a23_0 122 | - pyrsistent=0.15.1=py37h516909a_0 123 | - pysocks=1.7.0=py37_0 124 | - python=3.7.3=h5b0a415_0 125 | - python-crfsuite=0.9.6=py37h6bb024c_1000 126 | - python-dateutil=2.8.0=py_0 127 | - pytorch=1.1.0=py3.7_cuda9.0.176_cudnn7.5.1_0 128 | - pytz=2019.1=py_0 129 | - pyzmq=18.0.1=py37hc4ba49a_1 130 | - qt=5.9.7=h52cfd70_1 131 | - qtconsole=4.4.4=py_0 132 | - readline=7.0=hf8c457e_1001 133 | - requests=2.22.0=py37_0 134 | - requests-oauthlib=1.2.0=py_0 135 | - s3transfer=0.2.0=py37_0 136 | - scikit-learn=0.21.2=py37h627018c_0 137 | - scipy=1.3.0=py37h921218d_0 138 | - send2trash=1.5.0=py_0 139 | - setuptools=41.0.1=py37_0 140 | - sip=4.19.8=py37hf484d3e_1000 141 | - six=1.12.0=py37_1000 142 | - smart_open=1.8.3=py_0 143 | - 
spacy=2.1.4=py37hc9558a2_0 144 | - sqlite=3.26.0=h67949de_1001 145 | - srsly=0.0.7=py37he1b5a44_0 146 | - terminado=0.8.2=py37_0 147 | - testpath=0.4.2=py_1001 148 | - thinc=7.0.4=py37hc9558a2_0 149 | - tk=8.6.9=h84994c4_1001 150 | - torchvision=0.2.2=py_3 151 | - tornado=6.0.2=py37h516909a_0 152 | - traitlets=4.3.2=py37_1000 153 | - twython=3.7.0=py_0 154 | - urllib3=1.24.3=py37_0 155 | - wasabi=0.2.2=py_0 156 | - wcwidth=0.1.7=py_1 157 | - webencodings=0.5.1=py_1 158 | - wheel=0.33.2=py37_0 159 | - widgetsnbextension=3.4.2=py37_1000 160 | - x264=1!152.20180806=h14c3975_0 161 | - xorg-kbproto=1.0.7=h14c3975_1002 162 | - xorg-libice=1.0.10=h516909a_0 163 | - xorg-libsm=1.2.3=h84519dc_1000 164 | - xorg-libx11=1.6.8=h516909a_0 165 | - xorg-libxau=1.0.9=h14c3975_0 166 | - xorg-libxdmcp=1.1.3=h516909a_0 167 | - xorg-libxext=1.3.4=h516909a_0 168 | - xorg-libxrender=0.9.10=h516909a_1002 169 | - xorg-renderproto=0.11.1=h14c3975_1002 170 | - xorg-xextproto=7.3.0=h14c3975_1002 171 | - xorg-xproto=7.0.31=h14c3975_1007 172 | - xz=5.2.4=h14c3975_1001 173 | - zeromq=4.3.1=hf484d3e_1000 174 | - zlib=1.2.11=h14c3975_1004 175 | - pip: 176 | - absl-py==0.7.1 177 | - alabaster==0.7.12 178 | - allennlp==0.8.5-unreleased 179 | - apex==0.1 180 | - appdirs==1.4.3 181 | - atomicwrites==1.3.0 182 | - autopep8==1.4.4 183 | - awscli==1.16.210 184 | - babel==2.7.0 185 | - black==19.3b0 186 | - blessings==1.7 187 | - botocore==1.12.200 188 | - click==7.0 189 | - colorama==0.3.9 190 | - conllu==0.11 191 | - cycler==0.10.0 192 | - cython==0.29.7 193 | - editdistance==0.5.3 194 | - en-core-web-md==2.1.0 195 | - en-core-web-sm==2.1.0 196 | - fastprogress==0.1.21 197 | - fire==0.1.3 198 | - flake8==3.7.7 199 | - flaky==3.6.0 200 | - flask==1.1.1 201 | - flask-cors==3.0.8 202 | - ftfy==5.5.1 203 | - future==0.17.1 204 | - gevent==1.4.0 205 | - gpustat==0.6.0 206 | - greenlet==0.4.15 207 | - grpcio==1.20.1 208 | - h5py==2.9.0 209 | - imagesize==1.1.0 210 | - importlib-metadata==0.19 211 | - importmagic==0.1.7 212 | - isodate==0.6.0 213 | - itsdangerous==1.1.0 214 | - jsonnet==0.13.0 215 | - jsonpickle==1.2 216 | - kiwisolver==1.1.0 217 | - markdown==3.1 218 | - matplotlib==3.1.0rc2 219 | - mccabe==0.6.1 220 | - more-itertools==7.2.0 221 | - munch==2.3.2 222 | - numpydoc==0.9.1 223 | - nvidia-dali==0.9.1 224 | - nvidia-ml-py3==7.352.0 225 | - opencv-python==4.1.0.25 226 | - overrides==1.9 227 | - packaging==19.1 228 | - parsimonious==0.8.1 229 | - pillow-simd==5.3.0.post1 230 | - pluggy==0.12.0 231 | - protobuf==3.7.1 232 | - psutil==5.6.3 233 | - py==1.8.0 234 | - pyasn1==0.4.6 235 | - pycocotools==2.0 236 | - pycodestyle==2.5.0 237 | - pyflakes==2.1.1 238 | - pyparsing==2.4.0 239 | - pytest==5.0.1 240 | - pytorch-pretrained-bert==0.6.2 241 | - pyyaml==5.1 242 | - rdflib==4.2.2 243 | - regex==2019.6.8 244 | - responses==0.10.6 245 | - rsa==3.4.2 246 | - snowballstemmer==1.9.0 247 | - sphinx==2.1.2 248 | - sphinxcontrib-applehelp==1.0.1 249 | - sphinxcontrib-devhelp==1.0.1 250 | - sphinxcontrib-htmlhelp==1.0.2 251 | - sphinxcontrib-jsmath==1.0.1 252 | - sphinxcontrib-qthelp==1.0.2 253 | - sphinxcontrib-serializinghtml==1.1.3 254 | - sqlparse==0.3.0 255 | - tb-nightly==1.14.0a20190510 256 | - tensorboard==1.13.1 257 | - tensorboardx==1.8 258 | - toml==0.10.0 259 | - torchtext==0.4.0 260 | - tqdm==4.31.1 261 | - unidecode==1.1.1 262 | - werkzeug==0.15.2 263 | - word2number==1.1 264 | - yacs==0.1.6 265 | - yapf==0.27.0 266 | - zipp==0.5.2 267 | prefix: /scratch/arka/miniconda3/envs/pyt_new 268 | 269 | 
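
The `ratios` and `scales` fields in `configs/cfg.json` / `configs/cfg.yaml` below are stored as Python-style strings. `learner_init` (in `code/main.py` and `code/main_dist.py`) evaluates them and derives the anchor count as `len(ratios) * len(scales)`. A minimal standalone sketch of that conversion, mirroring those `eval` calls and using only the default values from `configs/cfg.json`:

```python
# Minimal sketch: how the string-valued "ratios"/"scales" fields in
# configs/cfg.json become numeric anchor parameters. Mirrors the eval
# calls in learner_init (code/main.py / code/main_dist.py).
import numpy as np

cfg = {
    "ratios": "[1/2, 1, 2]",             # defaults from configs/cfg.json
    "scales": "[1, 2**(1/3), 2**(2/3)]",
    "scale_factor": 4,
}

ratios = eval(cfg["ratios"], {})         # [0.5, 1, 2]
scales = cfg["scale_factor"] * np.array(eval(cfg["scales"], {}))
num_anchors = len(ratios) * len(scales)  # 3 * 3 = 9 anchors per location

print(ratios, scales, num_anchors)
```

This is why `code/mdl.py` builds the network with `num_anchors=9` in its `__main__` block.
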
-------------------------------------------------------------------------------- /configs/cfg.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds_to_use": "vg_split", 3 | "bs": 16, 4 | "nw": 4, 5 | "bsv": 16, 6 | "nwv": 4, 7 | "lr": 1e-4, 8 | "devices": 0, 9 | "opt_fn": "Adam", 10 | "opt_fn_params": { 11 | "betas": [0.9, 0.99] 12 | }, 13 | "do_norm": false, 14 | "use_same_atb": true, 15 | "mdl_to_use": "retina", 16 | "resize_img": [300, 300], 17 | "tmp_path": "./tmp", 18 | "use_multi": true, 19 | "use_focal": true, 20 | "use_softmax": false, 21 | "alpha": 0.25, 22 | "gamma": 2, 23 | "ratios": "[1/2, 1, 2]", 24 | "scales": "[1, 2**(1/3), 2**(2/3)]", 25 | "scale_factor": 4, 26 | "emb_dim": 300, 27 | "matching_threshold": 0.6, 28 | "epochs": 10, 29 | "use_bidirectional": true, 30 | "lstm_dim": 128, 31 | "use_reduce_lr_plateau": true, 32 | "patience": 2, 33 | "reduce_factor": 0.1, 34 | "lamb_reg": 1, 35 | "resume_path": "", 36 | "resume": true, 37 | "load_opt": false, 38 | "strict_load": true, 39 | "load_normally": true, 40 | "acc_iou_threshold": 0.5, 41 | "use_lang": true, 42 | "use_img": true 43 | } 44 | -------------------------------------------------------------------------------- /configs/cfg.yaml: -------------------------------------------------------------------------------- 1 | acc_iou_threshold: 0.5 2 | alpha: 0.25 3 | bs: 16 4 | bsv: 16 5 | devices: 0 6 | do_norm: false 7 | ds_to_use: vg_split 8 | emb_dim: 300 9 | epochs: 10 10 | gamma: 2 11 | lamb_reg: 1 12 | load_normally: true 13 | load_opt: false 14 | lr: 0.0001 15 | lstm_dim: 128 16 | matching_threshold: 0.6 17 | mdl_to_use: retina 18 | nw: 4 19 | nwv: 4 20 | opt_fn: Adam 21 | opt_fn_params: 22 | betas: 23 | - 0.9 24 | - 0.99 25 | patience: 2 26 | ratios: '[1/2, 1, 2]' 27 | reduce_factor: 0.1 28 | resize_img: 29 | - 300 30 | - 300 31 | resume: true 32 | resume_path: '' 33 | scale_factor: 4 34 | scales: '[1, 2**(1/3), 2**(2/3)]' 35 | strict_load: true 36 | tmp_path: ./tmp 37 | use_bidirectional: true 38 | use_focal: true 39 | use_img: true 40 | use_lang: true 41 | use_multi: true 42 | use_reduce_lr_plateau: true 43 | use_same_atb: true 44 | use_softmax: false 45 | -------------------------------------------------------------------------------- /configs/ds_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "flickr30k": { 3 | "data_dir": "./data/flickr30k", 4 | "img_dir": "./data/flickr30k/flickr30k_images", 5 | "trn_csv_file": "./data/flickr30k/csv_dir/train_flat.csv", 6 | "val_csv_file": "./data/flickr30k/csv_dir/val.csv", 7 | "test_csv_file": "./data/flickr30k/csv_dir/test.csv" 8 | }, 9 | "refclef": { 10 | "data_dir": "./data/referit/refclef", 11 | "img_dir": "./data/referit/saiapr_tc12_images", 12 | "trn_csv_file": "./data/referit/csv_dir/train_flat.csv", 13 | "val_csv_file": "./data/referit/csv_dir/val.csv", 14 | "test_csv_file": "./data/referit/csv_dir/test.csv" 15 | }, 16 | "flickr30k_c0": { 17 | "data_dir": "./data/flickr30k", 18 | "img_dir": "./data/flickr30k/flickr30k_images", 19 | "trn_csv_file": "./data/flickr30k_c0/csv_dir/train.csv", 20 | "val_csv_file": "./data/flickr30k_c0/csv_dir/val.csv", 21 | "test_csv_file": "./data/flickr30k_c0/csv_dir/test.csv" 22 | }, 23 | "flickr30k_c1": { 24 | "data_dir": "./data/flickr30k", 25 | "img_dir": "./data/flickr30k/flickr30k_images", 26 | "trn_csv_file": "./data/flickr30k_c1/csv_dir/train.csv", 27 | "val_csv_file": "./data/flickr30k_c1/csv_dir/val.csv", 28 | "test_csv_file": 
"./data/flickr30k_c1/csv_dir/test.csv" 29 | }, 30 | "vg_split_c2": { 31 | "data_dir": "/scratch/arka/Ark_git_files/visual_genome/vg_split", 32 | "img_dir": "/scratch/arka/Ark_git_files/visual_genome", 33 | "trn_csv_file": "./data/vg_split_c2/csv_dir/train.csv", 34 | "val_csv_file": "./data/vg_split_c2/csv_dir/val.csv", 35 | "test_csv_file": "./data/vg_split_c2/csv_dir/test.csv" 36 | }, 37 | "vg_split_c3": { 38 | "data_dir": "/scratch/arka/Ark_git_files/visual_genome/vg_split", 39 | "img_dir": "/scratch/arka/Ark_git_files/visual_genome", 40 | "trn_csv_file": "./data/vg_split_c3/csv_dir/train.csv", 41 | "val_csv_file": "./data/vg_split_c3/csv_dir/val.csv", 42 | "test_csv_file": "./data/vg_split_c3/csv_dir/test.csv" 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /data/download_ann.sh: -------------------------------------------------------------------------------- 1 | # Script to download processed annotations 2 | wget --header="Host: doc-14-5s-docs.googleusercontent.com" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/76.0.3809.87 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3" --header="Accept-Language: en-US,en;q=0.9" --header="Referer: https://drive.google.com/drive/u/2/folders/1whckadqTaQhw5sZ_dIqQZhtaxjL8Lzts" --header="Cookie: AUTH_5duq3i08p8lb90k7bjdi6mv4mkuct380_nonce=8t7cdn4f717ui" --header="Connection: keep-alive" "https://doc-14-5s-docs.googleusercontent.com/docs/securesc/ijsjbo7j24dg1fshcvqn84h4ucoebo1g/n68kl1tmithfmprktsbr8v1d47n2rs47/1566432000000/16497152722325373235/16497152722325373235/1oZ5llnA4btD9LSmnSB0GaZtTogskLwCe?e=download&h=00885983406768461781&nonce=8t7cdn4f717ui&user=16497152722325373235&hash=qrljff7cmfnp6tur08smia9l8ifbmji0" -O "ds_csv_ann.zip" -c 3 | 4 | unzip ds_csv_ann.zip 5 | mv ds_csv_ann/* . 
6 | rmdir ds_csv_ann 7 | rm ds_csv_ann.zip 8 | -------------------------------------------------------------------------------- /data/ds_prep_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "flickr30k": { 3 | "root": "./data/flickr30k", 4 | "ann_path": "./data/flickr30k/flickr30k_entities/Annotations", 5 | "sen_path": "./data/flickr30k/flickr30k_entities/Sentences", 6 | "trn_img_ids": "./data/flickr30k/flickr30k_entities/train.txt", 7 | "val_img_ids": "./data/flickr30k/flickr30k_entities/val.txt", 8 | "test_img_ids": "./data/flickr30k/flickr30k_entities/test.txt" 9 | }, 10 | "refclef": { 11 | "root": "./data/referit" 12 | }, 13 | "flickr_unseen_words": { 14 | "root": "./data/flickr30k/flickr_unseen_words" 15 | }, 16 | "vg_split": { 17 | "data_dir": "/scratch/arka/Ark_git_files/visual_genome/vg_split", 18 | "img_dir": "/scratch/arka/Ark_git_files/visual_genome" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /data/ds_prep_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to prepare dataset in a common format 3 | """ 4 | from typing import List, Dict, Union, Any 5 | from yacs.config import CfgNode as CN 6 | from pathlib import Path 7 | from abc import ABC, abstractmethod 8 | import pandas as pd 9 | import numpy as np 10 | from dataclasses import dataclass 11 | import json 12 | from tqdm import tqdm 13 | import copy 14 | 15 | Fpath = Union[Path, str] 16 | Cft = Union[Dict, CN] 17 | DF = pd.DataFrame 18 | ID = Union[int, str] 19 | 20 | 21 | def union_of_rects(rects): 22 | """ 23 | Calculates union of two rectangular boxes 24 | Assumes both rects of form N x [xmin, ymin, xmax, ymax] 25 | """ 26 | xA = np.min(rects[:, 0]) 27 | yA = np.min(rects[:, 1]) 28 | xB = np.max(rects[:, 2]) 29 | yB = np.max(rects[:, 3]) 30 | return np.array([xA, yA, xB, yB], dtype=np.int32) 31 | 32 | 33 | @dataclass 34 | class BaseCSVPrepare(ABC): 35 | """ 36 | Abstract class to prepare CSV files 37 | to be used for data loading 38 | ds_root: Path to root directory of the dataset. 39 | This can be a symbolic path as well 40 | """ 41 | ds_prep_cfg: Cft 42 | 43 | def __post_init__(self): 44 | """ 45 | Initializes stuff from the dataset preparation 46 | configuration. 47 | """ 48 | # Convert to CN type if not already 49 | if isinstance(self.ds_prep_cfg, dict): 50 | self.ds_prep_cfg = CN(self.ds_prep_cfg) 51 | 52 | # Set the dataset root (resolve symbolic links) 53 | self.ds_root = Path(self.ds_prep_cfg.root).resolve() 54 | self.ann_file = self.ds_root / 'all_annot_new.json' 55 | self.csv_root = self.ds_root / 'csv_dir' 56 | self.csv_root.mkdir(exist_ok=True) 57 | self.after_init() 58 | 59 | def after_init(self): 60 | pass 61 | 62 | def load_annotations(self) -> DF: 63 | if self.ann_file.exists(): 64 | return pd.DataFrame(json.load(open(self.ann_file))) 65 | else: 66 | output = self.get_annotations() 67 | assert isinstance(output, (dict, list)) 68 | json.dump(output, self.ann_file.open('w')) 69 | return pd.DataFrame(output) 70 | 71 | @abstractmethod 72 | def get_annotations(self): 73 | """ 74 | Getting the annotations, specific to dataset. 
75 | The output should be of the format: 76 | output_annot = List[grnd_dict] 77 | grnd_dict = {'bbox': [x1,y1,x2,y2], 'img_id': img_id, 78 | 'queries': List[query], 'entity_name': optional, 'full_sentence': optional} 79 | bbox should be in x1y1x2y2 format 80 | """ 81 | raise NotImplementedError 82 | 83 | @abstractmethod 84 | def get_trn_val_test_ids(self, output_annot=None) -> List[Any]: 85 | """ 86 | Obtain training, validation and testing ids. 87 | Depends on the dataset 88 | """ 89 | raise NotImplementedError 90 | 91 | def get_dfmask_from_ids(self, ids: List[Any], annots: DF): 92 | ids_set = set(ids) 93 | return annots.img_id.apply(lambda x: x in ids_set) 94 | 95 | def get_df_from_ids(self, ids: List[Any], annots: DF, split_type='val'): 96 | """ 97 | Return the df with ids. Basically for train we can directly return. 98 | For validation, testing each query is separate row. 99 | """ 100 | 101 | msk1 = self.get_dfmask_from_ids(ids, annots) 102 | annots_to_use = annots[msk1] 103 | if split_type == 'train': 104 | return annots_to_use 105 | else: 106 | out_dict_list = [] 107 | for ind, row in tqdm(annots_to_use.iterrows(), 108 | total=len(annots_to_use)): 109 | for query in row['query']: 110 | out_dict = copy.deepcopy(row) 111 | out_dict['query'] = query 112 | out_dict_list.append(out_dict) 113 | return pd.DataFrame(out_dict_list) 114 | 115 | def save_annot_to_format(self): 116 | """ 117 | Saves the annotations to the following csv format 118 | img_name,x1,y1,x2,y2,query(ies) 119 | """ 120 | output_annot = self.load_annotations() 121 | trn_ids, val_ids, test_ids = self.get_trn_val_test_ids(output_annot) 122 | output_annot = output_annot[['img_id', 'bbox', 'query']] 123 | trn_df = self.get_df_from_ids( 124 | trn_ids, output_annot, split_type='train') 125 | trn_df.to_csv(self.csv_root / 'train.csv', index=False, header=True) 126 | 127 | val_df = self.get_df_from_ids(val_ids, output_annot) 128 | val_df.to_csv(self.csv_root / 'val.csv', index=False, header=True) 129 | 130 | if test_ids is not None: 131 | test_df = self.get_df_from_ids(test_ids, output_annot) 132 | test_df.to_csv(self.csv_root / 'test.csv', 133 | index=False, header=True) 134 | -------------------------------------------------------------------------------- /data/flatten_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converts input csv file of type: img_name, bbox, List[queries] 3 | to output csv file of type: img_name, bbox, query 4 | """ 5 | import fire 6 | import pandas as pd 7 | from tqdm import tqdm 8 | import copy 9 | import ast 10 | 11 | 12 | def converter(inp_csv, out_csv): 13 | inp_df = pd.read_csv(inp_csv) 14 | inp_df['query'] = inp_df['query'].apply( 15 | lambda x: ast.literal_eval(x)) 16 | 17 | inp_df = inp_df.to_dict(orient='records') 18 | out_list = [] 19 | for row in tqdm(inp_df): 20 | queries = row.pop('query') 21 | for query in queries: 22 | out_dict = copy.deepcopy(row) 23 | out_dict['query'] = query 24 | out_list.append(out_dict) 25 | 26 | out_df = pd.DataFrame(out_list) 27 | out_df.to_csv(out_csv, index=False, header=True) 28 | 29 | 30 | if __name__ == '__main__': 31 | fire.Fire(converter) 32 | -------------------------------------------------------------------------------- /data/prepare_c01_flickr_splits.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creating zero-shot splits from Flickr Annotations 3 | """ 4 | from typing import Dict, List, Any 5 | from ds_prep_utils import Cft, ID, BaseCSVPrepare, DF 6 
| from dataclasses import dataclass 7 | from yacs.config import CfgNode as CN 8 | from pathlib import Path 9 | import json 10 | import pandas as pd 11 | import spacy 12 | from tqdm import tqdm 13 | from collections import Counter 14 | import numpy as np 15 | import copy 16 | import pickle 17 | nlp = spacy.load("en_core_web_sm") 18 | 19 | np.random.seed(5) 20 | 21 | 22 | class FlickrUnseenWordsCSVPrepare(BaseCSVPrepare): 23 | 24 | def after_init(self): 25 | self.flickr_ann_file = self.ds_root.parent / 'all_annot_new.json' 26 | self.flickr_ann = None 27 | self.load_annotations() 28 | 29 | def load_annotations(self): 30 | if self.flickr_ann is None: 31 | self.flickr_ann = json.load(open(self.flickr_ann_file)) 32 | return pd.DataFrame(self.flickr_ann) 33 | 34 | def get_annotations(self): 35 | return 36 | 37 | def get_query_word_list(self): 38 | self.query_word_lemma_file = self.ds_root / 'query_word_lemma_counter.json' 39 | if not self.query_word_lemma_file.exists(): 40 | query_word_list = [] 41 | for ind, grnd_dict in enumerate(tqdm(self.flickr_ann)): 42 | queries = grnd_dict['query'] 43 | for query in queries: 44 | tmp_query = nlp(query) 45 | query_word_list += [t.lemma_ for t in tmp_query] 46 | query_word_counter = Counter(query_word_list) 47 | json.dump(query_word_counter, open( 48 | self.query_word_lemma_file, 'w')) 49 | return Counter(json.load(open(self.query_word_lemma_file))) 50 | 51 | def create_exclude_include_list(self): 52 | self.exclude_include_list_file = self.ds_root / 'inc_exc_word_list.json' 53 | if not self.exclude_include_list_file.exists(): 54 | self.load_annotations() 55 | queries_lemma_count = self.get_query_word_list() 56 | 57 | # create include list 58 | qmost_common = queries_lemma_count.most_common(500) 59 | include_list = [q[0] for q in qmost_common] 60 | 61 | # exclude list 62 | remaining_list = [ 63 | r for r in queries_lemma_count if r not in set(include_list)] 64 | 65 | to_include_prob = 0.7 66 | num_to_incl = int(to_include_prob * len(remaining_list)) 67 | 68 | id_list = np.random.permutation(len(remaining_list)) 69 | to_include = id_list[:num_to_incl] 70 | to_exclude = id_list[num_to_incl:] 71 | 72 | include_list += [remaining_list[t] for t in to_include] 73 | exclude_list = [remaining_list[t] for t in to_exclude] 74 | 75 | out_dict = {'exclude_list': exclude_list, 76 | 'include_list': include_list} 77 | json.dump(out_dict, self.exclude_include_list_file.open('w')) 78 | return json.load(self.exclude_include_list_file.open('r')) 79 | 80 | def get_trn_val_test_ids(self, output_annot=None): 81 | inc_excl_lists = self.create_exclude_include_list() 82 | incl_set = inc_excl_lists['include_list'] 83 | excl_set = inc_excl_lists['exclude_list'] 84 | 85 | test_ids_file = self.ds_root / 'test_ids.pkl' 86 | new_output_annot_file = self.ds_root / 'test_output_annot.pkl' 87 | if not test_ids_file.exists(): 88 | test_ids = [] 89 | new_output_annot = [] 90 | for ind, grnd_dict in enumerate(tqdm(self.flickr_ann)): 91 | queries = grnd_dict['query'] 92 | qs_to_use = [] 93 | for query in queries: 94 | tmp_query = nlp(query) 95 | last_idx = -1 96 | qu = tmp_query[last_idx] 97 | while not len(qu.text) > 1: 98 | print('why', qu.text) 99 | try: 100 | last_idx -= 1 101 | qu = tmp_query[last_idx] 102 | except IndexError: 103 | print('noope') 104 | break 105 | if not (qu.lemma_ in incl_set): 106 | assert qu.lemma_ in excl_set 107 | qs_to_use.append(query) 108 | if len(qs_to_use) > 0: 109 | qs_to_use = list(set(qs_to_use)) 110 | grnd_dict1 = copy.deepcopy(grnd_dict) 111 | 
grnd_dict1['query'] = qs_to_use 112 | grnd_dict1['split_type'] = 'test' 113 | new_output_annot.append(grnd_dict1) 114 | test_ids.append(grnd_dict1['img_id']) 115 | pickle.dump(test_ids, test_ids_file.open('wb')) 116 | pickle.dump(new_output_annot, new_output_annot_file.open('wb')) 117 | test_ids = pickle.load(test_ids_file.open('rb')) 118 | new_output_annot = pickle.load(new_output_annot_file.open('rb')) 119 | 120 | flickr_df = pd.DataFrame(self.flickr_ann) 121 | all_ids = set(list(flickr_df.img_id)) 122 | trn_val_ids = list(all_ids - set(test_ids)) 123 | 124 | to_include_prob = 0.8 125 | num_to_incl = int(to_include_prob * len(trn_val_ids)) 126 | 127 | id_list = np.random.permutation(len(trn_val_ids)) 128 | trids = id_list[:num_to_incl] 129 | vlids = id_list[num_to_incl:] 130 | 131 | trn_ids = [trn_val_ids[trid] for trid in trids] 132 | val_ids = [trn_val_ids[vlid] for vlid in vlids] 133 | 134 | for ind, grnd_dict in enumerate(tqdm(self.flickr_ann)): 135 | if grnd_dict['img_id'] in trn_val_ids: 136 | queries = grnd_dict['query'] 137 | # if not all([nlp(q)[-1].lemma_ in incl_set for q in queries]): 138 | # continue 139 | new_output_annot.append(grnd_dict) 140 | 141 | return trn_ids, val_ids, test_ids, pd.DataFrame(new_output_annot) 142 | 143 | def save_annot_to_format(self): 144 | """ 145 | Saves the annotations to the following csv format 146 | img_name,x1,y1,x2,y2,query(ies) 147 | """ 148 | output_annot = self.load_annotations() 149 | trn_ids, val_ids, test_ids, output_annot = self.get_trn_val_test_ids( 150 | output_annot) 151 | output_annot = output_annot[['img_id', 'bbox', 'query']] 152 | trn_df = self.get_df_from_ids( 153 | trn_ids, output_annot, split_type='train') 154 | trn_df.to_csv(self.csv_root / 'train.csv', index=False, header=True) 155 | 156 | val_df = self.get_df_from_ids(val_ids, output_annot) 157 | val_df.to_csv(self.csv_root / 'val.csv', index=False, header=True) 158 | 159 | if test_ids is not None: 160 | test_df = self.get_df_from_ids(test_ids, output_annot) 161 | test_df.to_csv(self.csv_root / 'test.csv', 162 | index=False, header=True) 163 | 164 | 165 | if __name__ == '__main__': 166 | ds_cfg = json.load(open('./data/ds_prep_config.json')) 167 | fl0 = FlickrUnseenWordsCSVPrepare(ds_cfg['flickr_unseen_words']) 168 | # fl0.create_exclude_include_list() 169 | fl0.save_annot_to_format() 170 | -------------------------------------------------------------------------------- /data/prepare_flickr30k.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create CSV file annotations for Flickr30k 3 | """ 4 | 5 | import pandas as pd 6 | from xml.etree import ElementTree as et 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | import json 10 | import re 11 | from collections import defaultdict 12 | import numpy as np 13 | from ds_prep_utils import Cft, ID, BaseCSVPrepare, union_of_rects 14 | 15 | 16 | class Flickr_one_img_info: 17 | def __init__(self, ds_cfg: Cft, img_id: ID, results_out): 18 | self.img_id = img_id 19 | self.rw = results_out 20 | ann_path = Path(ds_cfg.ann_path) 21 | sen_path = Path(ds_cfg.sen_path) 22 | self.ann_fname = ann_path / f'{img_id}.xml' 23 | self.sen_fname = sen_path / f'{img_id}.txt' 24 | self.ann_file = et.parse(self.ann_fname).getroot() 25 | # list of lists 26 | self.cid_bbox_dict = defaultdict(list) 27 | self.cid_sc_nobnd = dict() 28 | self.cid_entity_dict = dict() 29 | self.cid_text_dict = defaultdict(list) 30 | 31 | self.p1 = re.compile(r'\[.*?\]') 32 | self.p2 = 
re.compile(r'\[/EN#(\d*)/(\w*)\s(.*)\]') 33 | self.p3 = re.compile(r'\[/EN#(\d*)/(\w*)/(\w*)\s(.*)\]') 34 | 35 | self.cid_dict = dict() 36 | self.get_full_ann() 37 | 38 | def get_img_dim(self): 39 | tmp = self.ann_file.find('size') 40 | self.img_w = int(tmp.find('width').text) 41 | self.img_h = int(tmp.find('height').text) 42 | self.img_depth = int(tmp.find('depth').text) 43 | 44 | def get_bbox(self, bbx): 45 | xmin = int(bbx.find('xmin').text) 46 | ymin = int(bbx.find('ymin').text) 47 | xmax = int(bbx.find('xmax').text) 48 | ymax = int(bbx.find('ymax').text) 49 | return [xmin, ymin, xmax, ymax] 50 | 51 | def get_ann(self): 52 | assert str(self.img_id) == self.ann_file.findall( 53 | 'filename')[0].text[:-4] 54 | tmp = self.ann_file.findall('object') 55 | 56 | for o in tmp: 57 | nlist = o.findall('name') 58 | bbox = o.find('bndbox') 59 | sc = o.find('scene') 60 | nbnd = o.find('nobndbox') 61 | 62 | if bbox is not None: 63 | for n in nlist: 64 | bnd_bbox = self.get_bbox(bbox) 65 | self.cid_bbox_dict[int(n.text)].append(bnd_bbox) 66 | if sc is None: 67 | sc = -1 68 | else: 69 | sc = int(sc.text) 70 | if nbnd is None: 71 | nbnd = -1 72 | else: 73 | nbnd = int(nbnd.text) 74 | 75 | for n in nlist: 76 | self.cid_sc_nobnd[int(n.text)] = {'scene': sc, 'nobnd': nbnd} 77 | return 78 | 79 | def get_sen_ann(self): 80 | """ 81 | Some complicated logic 82 | """ 83 | with self.sen_fname.open('r') as g: 84 | for l in g.readlines(): 85 | tmp = self.p1.findall(l) 86 | for t in tmp: 87 | tmp2 = self.p2.findall(t) 88 | if len(tmp2) == 0: 89 | tmp2 = self.p3.findall(t) 90 | if len(tmp2) != 1: 91 | tmp2 = self.p3.findall(t) 92 | assert len(tmp2) == 1 93 | tmp3 = tmp2[0] 94 | if int(tmp3[0]) != 0: 95 | self.cid_entity_dict[int(tmp3[0])] = [ 96 | tmp3[1], tmp3[2]] 97 | self.cid_text_dict[int(tmp3[0])].append(tmp3[3]) 98 | 99 | tmp3 = tmp2[0] 100 | if int(tmp3[0]) != 0: 101 | self.cid_entity_dict[int(tmp3[0])] = tmp3[1] 102 | self.cid_text_dict[int(tmp3[0])].append(tmp3[-1]) 103 | return 104 | 105 | def get_full_ann(self): 106 | self.get_ann() 107 | self.get_sen_ann() 108 | rw = self.rw 109 | for k in self.cid_entity_dict.keys(): 110 | bbox = self.cid_bbox_dict[k] 111 | if len(bbox) == 0: 112 | continue 113 | if len(bbox) > 1: 114 | bbox = union_of_rects(np.array(bbox)).tolist() 115 | else: 116 | bbox = bbox[0] 117 | tmpl = {'bbox': bbox, 118 | 'scene': self.cid_sc_nobnd[k]['scene'], 119 | 'nobnd': self.cid_sc_nobnd[k]['nobnd'], 120 | 'entity': self.cid_entity_dict[k], 121 | 'query': self.cid_text_dict[k], 122 | 'full_txt': rw[str(self.img_id)]} 123 | self.cid_dict[k] = tmpl 124 | return self.cid_dict 125 | 126 | 127 | class FlickrCSVPrepare(BaseCSVPrepare): 128 | def get_annotations(self): 129 | results_out = self.ds_root / 'results.json' 130 | rw = json.load(results_out.open('r')) 131 | full_img_cid_dict = {} 132 | for img_id in tqdm(rw.keys(), total=len(rw)): 133 | f = Flickr_one_img_info(self.ds_prep_cfg, img_id, rw) 134 | full_img_cid_dict[f.img_id] = f.cid_dict 135 | json.dump(full_img_cid_dict, open( 136 | self.ds_root / 'all_ann_2.json', 'w')) 137 | 138 | out_dict_list = [] 139 | for img_id in tqdm(full_img_cid_dict): 140 | for cid in full_img_cid_dict[img_id]: 141 | out_dict = full_img_cid_dict[img_id][cid] 142 | out_dict['img_id'] = img_id 143 | out_dict_list.append(out_dict) 144 | 145 | return out_dict_list 146 | 147 | def get_trn_val_test_ids(self, output_annot=None): 148 | trn_ids = list(pd.read_csv( 149 | self.ds_prep_cfg.trn_img_ids, header=None)[0]) 150 | val_ids = list(pd.read_csv( 151 | 
self.ds_prep_cfg.val_img_ids, header=None)[0]) 152 | test_ids = list(pd.read_csv( 153 | self.ds_prep_cfg.test_img_ids, header=None)[0]) 154 | 155 | return trn_ids, val_ids, test_ids 156 | 157 | 158 | if __name__ == '__main__': 159 | ds_cfg = json.load(open('./data/ds_prep_config.json')) 160 | fl = FlickrCSVPrepare(ds_cfg['flickr30k']) 161 | fl.save_annot_to_format() 162 | -------------------------------------------------------------------------------- /data/prepare_referit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python file to create required csvs for ReferIt datasets 3 | Author: Arka Sadhu 4 | Adapted from https://github.com/lichengunc/refer 5 | """ 6 | 7 | import pandas as pd 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | import json 11 | from collections import defaultdict 12 | import numpy as np 13 | from typing import Dict, List, Any 14 | from ds_prep_utils import Cft, ID, BaseCSVPrepare, DF 15 | import pickle 16 | 17 | 18 | class ReferItCSVPrepare(BaseCSVPrepare): 19 | """ 20 | Data preparation class 21 | """ 22 | 23 | def after_init(self): 24 | self.splitBy = 'berkeley' 25 | self.data_dir = self.ds_root / 'refclef' 26 | self.ref_ann_file = self.data_dir / f'refs({self.splitBy}).p' 27 | self.ref_instance_file = self.data_dir / 'instances.json' 28 | 29 | self.ref_ann = pickle.load(self.ref_ann_file.open('rb')) 30 | self.ref_inst = json.load(self.ref_instance_file.open('r')) 31 | self.ref_inst_ann = self.ref_inst['annotations'] 32 | 33 | def get_annotations(self): 34 | self.instance_dict_by_ann_id = { 35 | v['id']: ind for ind, v in enumerate(self.ref_inst_ann)} 36 | out_dict_list = [] 37 | for rj in self.ref_ann: 38 | spl = rj['split'] 39 | sents = rj['sentences'] 40 | ann_id = rj['ann_id'] 41 | inst_bbox = self.ref_inst_ann[self.instance_dict_by_ann_id[ann_id]]['bbox'] 42 | # Saving in [x0, y0, x1, y1] format 43 | inst_bbox = [inst_bbox[0], inst_bbox[1], 44 | inst_bbox[2] + inst_bbox[0], inst_bbox[3]+inst_bbox[1]] 45 | 46 | sents = [s['raw'] for s in sents] 47 | sents = [t.strip().lower() for t in sents] 48 | out_dict = {} 49 | out_dict['img_id'] = f"{rj['image_id']}.jpg" 50 | out_dict['bbox'] = inst_bbox 51 | out_dict['split'] = spl 52 | out_dict['query'] = sents 53 | out_dict_list.append(out_dict) 54 | return out_dict_list 55 | 56 | def get_dfmask_from_ids(self, ids: List[Any], annots: DF): 57 | return ids 58 | 59 | def get_trn_val_test_ids(self, output_annot: DF): 60 | trn_ids_mask = output_annot.split.apply(lambda x: x == 'train') 61 | val_ids_mask = output_annot.split.apply(lambda x: x == 'val') 62 | test_ids_mask = output_annot.split.apply(lambda x: x == 'test') 63 | return trn_ids_mask, val_ids_mask, test_ids_mask 64 | 65 | 66 | if __name__ == '__main__': 67 | ds_cfg = json.load(open('./data/ds_prep_config.json')) 68 | ref = ReferItCSVPrepare(ds_cfg['refclef']) 69 | ref.save_annot_to_format() 70 | --------------------------------------------------------------------------------
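Taken together, the scripts above write per-dataset `csv_dir/{train,val,test}.csv` files, and `flatten_train.py` expands the training split from one row per box (with a list of queries) to one row per box-query pair, which is the `train_flat.csv` that `configs/ds_info.json` points to. A minimal sketch of the flow, assuming the dataset symlinks described in the data READMEs are already in place and commands are run from $ROOT (exact csv paths are taken from `configs/ds_info.json`):

```
# Build the base csvs (each script reads data/ds_prep_config.json)
python data/prepare_flickr30k.py
python data/prepare_referit.py

# flatten_train.py is driven by python-fire, so the input/output csvs are positional args
python data/flatten_train.py ./data/flickr30k/csv_dir/train.csv ./data/flickr30k/csv_dir/train_flat.csv
python data/flatten_train.py ./data/referit/csv_dir/train.csv ./data/referit/csv_dir/train_flat.csv

# The zero-shot unseen-words split follows the same pattern and writes its csvs
# under ./data/flickr30k/flickr_unseen_words/csv_dir
python data/prepare_c01_flickr_splits.py
```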