├── ucl ├── __init__.py ├── losses.py ├── loader.py ├── resnet.py ├── datasets_gb.py ├── datasets.py └── builder_gb.py ├── scripts ├── evaluate_gb.sh ├── finetuning_gb.sh ├── pretrain_gb.sh ├── pretrain_ucl_butterfly.sh └── eval_ucl_pocus.py ├── .gitignore ├── requirements.txt ├── README.md ├── utils ├── util.py └── flops_counter.py ├── dataloader.py ├── LICENSE ├── lincls_bin.py └── train_ucl.py /ucl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/evaluate_gb.sh: -------------------------------------------------------------------------------- 1 | python lincls_bin.py \ 2 | -a resnet50 \ 3 | --train_list $1 \ 4 | --val_list $2 \ 5 | --world-size 1 \ 6 | --num_classes 3 \ 7 | --batch-size 64 \ 8 | --pretrained $3 \ 9 | --fc_type 2 \ 10 | --lr 0.003 \ 11 | --cos_lr \ 12 | --wd 0.0005 \ 13 | --epochs 30 \ 14 | --gpu $4 \ 15 | --evaluate \ -------------------------------------------------------------------------------- /scripts/finetuning_gb.sh: -------------------------------------------------------------------------------- 1 | python lincls_bin.py \ 2 | -a resnet50 \ 3 | --train_list $1 \ 4 | --val_list $2 \ 5 | --world-size 1 \ 6 | --num_classes 3 \ 7 | --batch-size 64 \ 8 | --pretrained $3 \ 9 | --save-dir $4 \ 10 | --fc_type 2 \ 11 | --lr 0.003 \ 12 | --cos_lr \ 13 | --wd 0.0005 \ 14 | --epochs 30 \ 15 | --gpu $5 \ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.npy 3 | *.json 4 | *.tar 5 | cycle_contrast/__pycache__ 6 | runs/ 7 | scripts/__pycache__ 8 | utils/__pycache__ 9 | __pycache__ 10 | .neptune 11 | output 12 | cls_output 13 | data 14 | GB_Video 15 | .ipynb_checkpoints 16 | *.ipynb 17 | Data_32 18 | models 19 | finetune_gb.sh 20 | checker.sh 21 | lincls_bin_CT.py 22 | dataloader_CT.py 23 | finetune_gb_CT.sh 24 | pretrain_ucl_gb.sh 25 | run_classifier.sh 26 | main_lincls.py 27 | scripts/finetune_test_gb.sh -------------------------------------------------------------------------------- /scripts/pretrain_gb.sh: -------------------------------------------------------------------------------- 1 | python train_ucl.py \ 2 | --moco-random-video-frame-as-pos --multiprocessing-distributed --soft-nn \ 3 | --cycle-back-cls --mlp --aug-plus --cycle-back-candidates \ 4 | --cycle-back-cls-video-as-pos --sep-head \ 5 | --pretrained-models --exp_name="UCL_GBC" \ 6 | --learning-rate=0.003 \ 7 | --cos \ 8 | --dist-url="tcp://localhost:10001" \ 9 | --moco-t=0.07 \ 10 | --soft-nn-t=0.07 \ 11 | --soft-nn-topk-support \ 12 | --negatives \ 13 | --intranegs-only-two \ 14 | --soft-nn-loss-weight=0.1 \ 15 | --cross-neg-topk-mining \ 16 | --cross-neg-topk-support-size=4 \ 17 | --num-var=32 \ 18 | --anchor-reverse-cross \ 19 | --single-loss-intra-inter --single-loss-ncap-support-size=4 --num-negatives=3 \ 20 | --qcap-include --cosine-curriculum $1 -------------------------------------------------------------------------------- /scripts/pretrain_ucl_butterfly.sh: -------------------------------------------------------------------------------- 1 | python train_ucl.py \ 2 | --moco-random-video-frame-as-pos --multiprocessing-distributed --soft-nn \ 3 | --cycle-back-cls --mlp --aug-plus --cycle-back-candidates \ 4 | --cycle-back-cls-video-as-pos --sep-head \ 5 | --pretrained-models --exp_name="UCL_Butterfly" \ 6 | --learning-rate=0.003 
\ 7 | --cos \ 8 | --dist-url="tcp://localhost:10001" \ 9 | --moco-t=0.07 \ 10 | --soft-nn-t=0.07 \ 11 | --soft-nn-topk-support \ 12 | --negatives \ 13 | --intranegs-only-two \ 14 | --soft-nn-loss-weight=0.1 \ 15 | --cross-neg-topk-mining \ 16 | --cross-neg-topk-support-size=4 \ 17 | --num-var=22 \ 18 | --num-gpu=2 \ 19 | --anchor-reverse-cross \ 20 | --single-loss-intra-inter --single-loss-ncap-support-size=2 --num-negatives=3 \ 21 | --qcap-include --cosine-curriculum $1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | arrow==1.2.2 3 | astor==0.8.1 4 | atari-py==0.2.9 5 | attrs==21.4.0 6 | cached-property==1.5.2 7 | cachetools==4.2.4 8 | certifi==2021.10.8 9 | charset-normalizer==2.0.11 10 | click==8.0.3 11 | cloudpickle==1.2.2 12 | cycler==0.11.0 13 | Cython==0.29.27 14 | easydict==1.9 15 | fonttools==4.29.1 16 | fqdn==1.5.1 17 | future==0.18.2 18 | gast==0.2.2 19 | gitdb==4.0.9 20 | google-pasta==0.2.0 21 | grpcio==1.43.0 22 | h5py==3.6.0 23 | idna==3.3 24 | imageio==2.9.0 25 | importlib-metadata==4.10.1 26 | importlib-resources==5.4.0 27 | isoduration==20.11.0 28 | jmespath==0.10.0 29 | joblib==1.1.0 30 | jsonpointer==2.2 31 | jsonref==0.2 32 | jsonschema==4.4.0 33 | Keras-Applications==1.0.8 34 | Keras-Preprocessing==1.1.2 35 | kiwisolver==1.3.2 36 | Markdown==3.3.6 37 | matplotlib==3.5.1 38 | monotonic==1.6 39 | msgpack==1.0.3 40 | networkx==2.6.3 41 | numpy==1.21.5 42 | opencv-python==4.5.5.62 43 | opt-einsum==3.3.0 44 | packaging==21.3 45 | pandas==1.3.5 46 | Pillow==6.2.1 47 | protobuf==3.19.4 48 | psutil==5.9.0 49 | pyasn1==0.4.8 50 | pyasn1-modules==0.2.8 51 | pyglet==1.3.2 52 | PyJWT==2.3.0 53 | pyparsing==3.0.7 54 | pyrsistent==0.18.1 55 | python-dateutil==2.8.2 56 | pytz==2021.3 57 | PyWavelets==1.2.0 58 | PyYAML==6.0 59 | rfc3339-validator==0.1.4 60 | rfc3987==1.3.8 61 | rsa==4.8 62 | s3transfer==0.5.1 63 | scikit-image==0.19.1 64 | scikit-learn==1.0.2 65 | scipy==1.4.1 66 | seaborn==0.11.2 67 | simplejson==3.17.6 68 | six==1.16.0 69 | smmap==5.0.0 70 | tabulate==0.8.9 71 | tensorboard==2.1.1 72 | tensorboardX==2.4.1 73 | termcolor==1.1.0 74 | threadpoolctl==3.1.0 75 | tifffile==2021.11.2 76 | torch==1.3.0 77 | torchvision==0.4.1 78 | tqdm==4.62.3 79 | typing_extensions==4.0.1 80 | uri-template==1.1.0 81 | urllib3==1.26.8 82 | webcolors==1.11.1 83 | Werkzeug==2.0.3 84 | wrapt==1.13.3 85 | yacs==0.1.8 86 | zipp==3.7.0 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Unsupervised Contrastive Learning of Image Representations from Ultrasound Videos with Hard Negative Mining 2 | 3 | This is the official implementation of the UCL-method based on Hard Negative Mining, introduced in our paper [https://arxiv.org/abs/2207.13148](https://arxiv.org/abs/2207.13148). 4 | 5 | ### Installation 6 | 7 | Our code is tested on Python 3.7 and Pytorch 1.3.0, please install the environment via 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | Note: To run our code you'll need multiple GPUs this code is not designed to run on single GPU systems. 14 | 15 | ### Our Datasets 16 | 1. Video Dataset: To get our Video Dataset(GBUSV) follow the instructions [here](https://gbc-iitd.github.io/data/gbusv). 17 | 2. 
Image Dataset: To get our Image Dataset (GBCU), follow the instructions [here](https://gbc-iitd.github.io/data/gbcu). 18 | 19 | ### Model Zoo 20 | 21 | We provide models pretrained on the video data for 50 epochs. A minimal sketch of how to load these checkpoints outside the provided scripts is given at the end of this README. 22 | 23 | | Method | Downstream Task | Backbone | Pre-trained model | Fine-tuned model (on a single val split) | 24 | |---------------|:--------:|:--------:|:-----------------:|:-------------------------:| 25 | | UCL | Gallbladder Cancer Detection | ResNet-50 | [pretrain ckpt](https://drive.google.com/file/d/1nu4-WtuUj7VIV4vyKmoz9M0Tw2GXvS9P/view?usp=sharing) | [finetune ckpt](https://drive.google.com/file/d/1H9Abh9YvKkIUe38opbAPWhDEt51XJxTd/view) | 26 | | UCL | COVID-19 Detection | ResNet-50 | [pretrain ckpt](https://drive.google.com/file/d/1giXcf52tD2zUmQuC_DXBXdCXXnS3xFIm/view?usp=sharing) | - | 27 | 28 | 29 | ### Run on Custom Dataset 30 | #### Data preparation 31 | 32 | The directory structure should be as follows: 33 | ``` 34 | Data_64 35 | │ ├── Video_1 36 | │ │ ├── 00000.jpg 37 | │ │ ├── 00001.jpg 38 | │ │ ├── 00002.jpg 39 | │ │ ├── ..... 40 | │ ├── Video_2 41 | │ │ ├── 00000.jpg 42 | │ │ ├── 00001.jpg 43 | │ │ ├── 00002.jpg 44 | │ │ ├── .... 45 | ``` 46 | 47 | #### Note on flags in pre-training script 48 | The pre-training script exposes several flags that are worth experimenting with if the default configuration does not give good results: 49 | 1. moco-t: the temperature used in the loss calculation 50 | 2. soft-nn-t: the temperature used to compute the cross-video negative 51 | 3. single-loss-ncap-support-size: the number of cross-video negatives combined to form the cross-video negative used in the loss function 52 | 4. num-negatives: the number of intra-video negatives to be used 53 | 54 | ### Experiments on GB Datasets 55 | 56 | #### Unsupervised Contrastive Pretrain 57 | ``` 58 | bash scripts/pretrain_gb.sh <data-dir> 59 | ``` 60 | #### Fine-tune for Downstream Classification 61 | ``` 62 | bash scripts/finetuning_gb.sh <train-list> <val-list> <pretrained-ckpt> <save-dir> <gpu-id> 63 | ``` 64 | #### Evaluate Classifier 65 | ``` 66 | bash scripts/evaluate_gb.sh <train-list> <val-list> <checkpoint> <gpu-id> 67 | ``` 68 | ### Experiments on POCUS and Butterfly Dataset 69 | 70 | #### Getting Butterfly and POCUS Dataset 71 | Please follow the instructions in the [USCL Repo](https://github.com/983632847/USCL) to get these datasets. 72 | 73 | #### Unsupervised Contrastive Pretrain 74 | 1. After downloading the Butterfly dataset, make sure the frames in each video follow the naming convention described in the Data Preparation section of this README. 75 | 2. Run the command: 76 | ``` 77 | bash scripts/pretrain_ucl_butterfly.sh <data-dir> 78 | ``` 79 | 80 | #### Fine-Tuning and Evaluating Classifier 81 | For this we follow the evaluation mechanism proposed in [USCL](https://link.springer.com/chapter/10.1007/978-3-030-87237-3_60); the script modified for our proposed model can be found [here](scripts/eval_ucl_pocus.py). 82 | 83 | Steps to run this script: 84 | 1. Set up the environment and code from the [USCL repo](https://github.com/983632847/USCL). 85 | 2. Place [our evaluation script](scripts/eval_ucl_pocus.py) in the eval_pretrained_model directory of the USCL repo. 86 | 3. Run the command: 87 | ``` 88 | python eval_ucl_pocus.py --path <pretrained-ckpt> --gpu <gpu-id> 89 | ``` 90 | 91 | ## Acknowledgements 92 | The codebase is based on [CycleContrast](https://github.com/happywu/CycleContrast) and [MoCo](https://github.com/facebookresearch/moco).
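### Appendix: Loading the Pre-trained Checkpoints

The snippet below is a minimal sketch of how to initialise a torchvision ResNet-50 from one of the Model Zoo checkpoints outside the provided scripts. It assumes the checkpoint follows the MoCo-style layout used in this codebase (weights stored under state_dict with a module.encoder_q. prefix, as checked in utils/util.py); the file name and num_classes below are placeholders for illustration only.

```
import torch
import torchvision.models as models

# Placeholder values for illustration; substitute your own paths/settings.
ckpt_path = "ucl_pretrained.pth.tar"
num_classes = 3  # e.g. the 3-class gallbladder task

model = models.resnet50(num_classes=num_classes)

checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint["state_dict"]

# Keep only the query-encoder weights and strip the MoCo prefix.
# The projection head (encoder_q.fc) is dropped so the classifier
# can be re-initialised for the downstream task.
prefix = "module.encoder_q."
backbone_weights = {
    k[len(prefix):]: v
    for k, v in state_dict.items()
    if k.startswith(prefix) and not k.startswith(prefix + "fc")
}

msg = model.load_state_dict(backbone_weights, strict=False)
# Only the freshly initialised fc weights should be reported as missing.
print("Missing keys:", msg.missing_keys)
```

If anything other than the final fc layer shows up as missing, the checkpoint layout probably differs from the assumption above; adjust the prefix accordingly. For the full fine-tuning pipeline, use lincls_bin.py via scripts/finetuning_gb.sh as described earlier.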
93 | -------------------------------------------------------------------------------- /ucl/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def regression_loss(logits, labels, num_steps, steps, seq_lens, loss_type, 5 | normalize_indices, variance_lambda, huber_delta): 6 | """Loss function based on regressing to the correct indices. 7 | In the paper, this is called Cycle-back Regression. There are 3 variants 8 | of this loss: 9 | i) regression_mse: MSE of the predicted indices and ground truth indices. 10 | ii) regression_mse_var: MSE of the predicted indices that takes into account 11 | the variance of the similarities. This is important when the rate at which 12 | sequences go through different phases changes a lot. The variance scaling 13 | allows dynamic weighting of the MSE loss based on the similarities. 14 | iii) regression_huber: Huber loss between the predicted indices and ground 15 | truth indices. 16 | Args: 17 | logits: Tensor, Pre-softmax similarity scores after cycling back to the 18 | starting sequence. 19 | labels: Tensor, One hot labels containing the ground truth. The index where 20 | the cycle started is 1. 21 | num_steps: Integer, Number of steps in the sequence embeddings. 22 | steps: Tensor, step indices/frame indices of the embeddings of the shape 23 | [N, T] where N is the batch size, T is the number of the timesteps. 24 | seq_lens: Tensor, Lengths of the sequences from which the sampling was done. 25 | This can provide additional temporal information to the alignment loss. 26 | loss_type: String, This specifies the kind of regression loss function. 27 | Currently supported loss functions: regression_mse, regression_mse_var, 28 | regression_huber. 29 | normalize_indices: Boolean, If True, normalizes indices by sequence lengths. 30 | Useful for ensuring numerical instabilities don't arise as sequence 31 | indices can be large numbers. 32 | variance_lambda: Float, Weight of the variance of the similarity 33 | predictions while cycling back. If this is high then the low variance 34 | similarities are preferred by the loss while making this term low results 35 | in high variance of the similarities (more uniform/random matching). 36 | huber_delta: float, Huber delta described in tf.keras.losses.huber_loss. 37 | Returns: 38 | loss: Tensor, A scalar loss calculated using a variant of regression. 39 | """ 40 | # Just to be safe, we stop gradients from labels as we are generating labels. 
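    # Summary of the computation below (all variables are defined further down):
    #   beta      = softmax(logits)                     soft alignment over the sequence
    #   true_time = steps gathered at the label index   ground-truth (normalized) frame position
    #   pred_time = sum(steps * beta)                   expected (soft) frame position
    # regression_mse:      loss = mean((true_time - pred_time)^2)
    # regression_mse_var:  loss = mean(exp(-log_var) * (true_time - pred_time)^2 + variance_lambda * log_var),
    #                      with log_var = log(sum((steps - pred_time)^2 * beta))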
41 | labels = labels.detach() 42 | steps = steps.detach() 43 | 44 | # print('steps', steps.shape) 45 | # print('seq_lens', seq_lens) 46 | # print('labels', labels) 47 | # always normalize to 0~1 48 | # print('steps', steps.float()) 49 | # print('seq len', seq_lens.shape) 50 | # tile_seq_lens = seq_lens.clone().expand(-1, num_steps) 51 | 52 | # NOTICES: have to use repeat than expand, otherwise strange behavior would occur 53 | tile_seq_lens = seq_lens.repeat(1, num_steps) 54 | # print('normalized steps, torch div', torch.div(steps.float(), tile_seq_lens)) 55 | steps = steps.float() / tile_seq_lens 56 | # print('tile', tile_seq_lens.shape, steps.shape, tile_seq_lens) 57 | # print('normalized steps', steps) 58 | # print(steps, seq_lens.shape, steps.shape) 59 | # print('be', logits.shape, labels.shape) 60 | 61 | beta = F.softmax(logits, dim=1) 62 | # true_time = torch.sum(steps * labels, dim=1) 63 | true_time = torch.gather(steps, dim=1, index=labels.view(-1, 1)).squeeze(-1) 64 | pred_time = torch.sum(steps * beta, dim=1) 65 | # print(true_time, pred_time, true_time.shape, pred_time.shape) 66 | # print('beta', beta) 67 | # print('pred shape', true_time.shape, pred_time.shape) 68 | # print('true, pred time', true_time[:5], pred_time[:5]) 69 | # print('true, pred time', true_time, pred_time) 70 | 71 | if loss_type in ['regression_mse', 'regression_mse_var']: 72 | if 'var' in loss_type: 73 | # Variance aware regression. 74 | # pred_time_tiled = pred_time.view(-1, 1).expand(-1, num_steps) 75 | pred_time_tiled = pred_time.view(-1, 1).repeat(1, num_steps) 76 | pred_time_variance = torch.sum((steps - pred_time_tiled) ** 2 * beta, dim=1) 77 | 78 | # Using log of variance as it is numerically stabler. 79 | pred_time_log_var = torch.log(pred_time_variance) 80 | squared_error = (true_time - pred_time) ** 2 81 | # print('squared_error', torch.mean(squared_error)) 82 | # mean_error = torch.exp(-pred_time_log_var) * squared_error 83 | # print(squared_error[:5], (pred_time_log_var * variance_lambda)[:5]) 84 | return torch.mean(torch.exp(-pred_time_log_var) * squared_error 85 | + variance_lambda * pred_time_log_var) 86 | 87 | else: 88 | # squared_error = (true_time - pred_time) ** 2 89 | # print('squared_error', torch.mean(squared_error), true_time - pred_time) 90 | return torch.mean((true_time - pred_time) ** 2) 91 | else: 92 | raise ValueError('Unsupported regression loss %s. Supported losses are: ' 93 | 'regression_mse, regresstion_mse_var and regression_huber.' 94 | % loss_type) 95 | -------------------------------------------------------------------------------- /ucl/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from PIL import ImageFilter 3 | import random 4 | 5 | import torch 6 | 7 | class CropsTransform: 8 | """Take two random crops of one image as the query and key.""" 9 | 10 | def __init__(self, base_transform): 11 | self.base_transform = base_transform 12 | 13 | def __call__(self, xq, xk, xn): 14 | q = self.base_transform(xq) 15 | k = self.base_transform(xk) 16 | n = [self.base_transform(y) for y in xn] 17 | return [q, k, n] 18 | 19 | class TwoCropsTransform: 20 | """Take two random crops of one image as the query and key.""" 21 | 22 | def __init__(self, base_transform): 23 | self.base_transform = base_transform 24 | 25 | def __call__(self, x): 26 | q = self.base_transform(x) 27 | k = self.base_transform(x) 28 | return [q, k] 29 | 30 | 31 | class MultiCropsTransform: 32 | """Take two random crops of one image as the query and key.""" 33 | 34 | def __init__(self, base_transform, num_crops=2): 35 | self.base_transform = base_transform 36 | self.num_crops = num_crops 37 | 38 | def __call__(self, x): 39 | multi_crops = [self.base_transform(x) for _i in range(self.num_crops)] 40 | return multi_crops 41 | 42 | 43 | class GaussianBlur(object): 44 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 45 | 46 | def __init__(self, sigma=[.1, 2.]): 47 | self.sigma = sigma 48 | 49 | def __call__(self, x): 50 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 51 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 52 | return x 53 | 54 | 55 | class data_prefetcher(): 56 | def __init__(self, loader, return_all_video_frames=True): 57 | self.loader = iter(loader) 58 | self.stream = torch.cuda.Stream() 59 | self.return_all_video_frames = return_all_video_frames 60 | self.preload() 61 | 62 | def preload(self): 63 | try: 64 | if self.return_all_video_frames: 65 | self.next_input, self.next_target, self.next_indices, self.next_video_frames = next(self.loader) 66 | else: 67 | self.next_input, self.next_target, self.next_indices = next(self.loader) 68 | except StopIteration: 69 | self.next_input = None 70 | self.next_target = None 71 | self.next_indices = None 72 | self.next_video_frames = None 73 | return 74 | # if record_stream() doesn't work, another option is to make sure device inputs are created 75 | # on the main stream. 76 | # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda') 77 | # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda') 78 | # Need to make sure the memory allocated for next_* is not still in use by the main stream 79 | # at the time we start copying to next_*: 80 | # self.stream.wait_stream(torch.cuda.current_stream()) 81 | with torch.cuda.stream(self.stream): 82 | self.next_input[0] = self.next_input[0].cuda(non_blocking=True) 83 | self.next_input[1] = self.next_input[1].cuda(non_blocking=True) 84 | self.next_target = self.next_target.cuda(non_blocking=True) 85 | self.next_indices = self.next_indices.cuda(non_blocking=True) 86 | if self.return_all_video_frames: 87 | self.next_video_frames = self.next_video_frames.cuda(non_blocking=True) 88 | # more code for the alternative if record_stream() doesn't work: 89 | # copy_ will record the use of the pinned source tensor in this side stream. 
90 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 91 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 92 | # self.next_input = self.next_input_gpu 93 | # self.next_target = self.next_target_gpu 94 | 95 | 96 | def next(self): 97 | torch.cuda.current_stream().wait_stream(self.stream) 98 | input = self.next_input 99 | target = self.next_target 100 | indices = self.next_indices 101 | if self.return_all_video_frames: 102 | video_frames = self.next_video_frames 103 | if input is not None: 104 | input[0].record_stream(torch.cuda.current_stream()) 105 | if input is not None: 106 | input[1].record_stream(torch.cuda.current_stream()) 107 | if target is not None: 108 | target.record_stream(torch.cuda.current_stream()) 109 | if indices is not None: 110 | indices.record_stream(torch.cuda.current_stream()) 111 | if self.return_all_video_frames and video_frames is not None: 112 | video_frames.record_stream(torch.cuda.current_stream()) 113 | self.preload() 114 | if self.return_all_video_frames: 115 | return input, target, indices, video_frames 116 | else: 117 | return input, target, indices 118 | -------------------------------------------------------------------------------- /ucl/resnet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) by contributors 3 | # Licensed under the MIT License. 4 | # Written by Haiping Wu 5 | # ------------------------------------------------------------------------------ 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | from torchvision.models.utils import load_state_dict_from_url 10 | from torchvision.models.resnet import model_urls 11 | from torchvision.models.resnet import BasicBlock, Bottleneck 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50'] 14 | 15 | 16 | class ResNet(torchvision.models.resnet.ResNet): 17 | def __init__(self, block, layers, **kwargs): 18 | if 'dropout_rate' in kwargs: 19 | dropout_rate = kwargs['dropout_rate'] 20 | del kwargs['dropout_rate'] 21 | else: 22 | dropout_rate = 0.0 23 | if 'return_inter' in kwargs: 24 | self.return_inter = kwargs['return_inter'] 25 | del kwargs['return_inter'] 26 | else: 27 | self.return_inter = False 28 | if 'return_res4' in kwargs: 29 | self.return_res4 = kwargs['return_res4'] 30 | del kwargs['return_res4'] 31 | else: 32 | self.return_res4 = False 33 | super(ResNet, self).__init__(block, layers, **kwargs) 34 | self.dropout = nn.Dropout(p=dropout_rate) 35 | 36 | if self.return_res4: 37 | del self.layer4 38 | self.fc = nn.Linear(256, self.fc.weight.shape[0]) 39 | 40 | def _forward_impl(self, x): 41 | # See note [TorchScript super()] 42 | x = self.conv1(x) 43 | x = self.bn1(x) 44 | x = self.relu(x) 45 | x = self.maxpool(x) 46 | 47 | x = self.layer1(x) 48 | x = self.layer2(x) 49 | x = self.layer3(x) 50 | 51 | if not self.return_res4: 52 | x = self.layer4(x) 53 | 54 | x = self.avgpool(x) 55 | feat = torch.flatten(x, 1) 56 | x = self.dropout(feat) 57 | x = self.fc(x) 58 | 59 | if self.return_inter: 60 | return x, feat 61 | else: 62 | return x 63 | 64 | def forward(self, x): 65 | return self._forward_impl(x) 66 | 67 | 68 | class VideoClassifier(nn.Module): 69 | def __init__(self, base, num_classes, **kwargs): 70 | super(VideoClassifier, self).__init__() 71 | self.base = base 72 | # self.fc = nn.Linear(128 * 20, num_classes) 73 | self.fc = nn.Linear(128, num_classes) 74 | 75 | self.fc.weight.data.normal_(mean=0.0, 
std=0.01) 76 | self.fc.bias.data.zero_() 77 | 78 | def forward(self, x): 79 | # x: batch_size x num_frames x c x h x w 80 | bs, num_frames, c, h, w = x.shape 81 | # print(x.shape) 82 | x = x.view(bs * num_frames, c, h, w) 83 | x = self.base(x) 84 | 85 | x = x.view(bs, num_frames, -1) 86 | x = x.mean(dim=1) 87 | 88 | # x = x.view(bs, -1) 89 | 90 | 91 | x = self.fc(x) 92 | return x 93 | 94 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 95 | arbit_keyword_arg = {} 96 | 97 | for param_model in kwargs.items(): 98 | name_param_model = param_model[0] 99 | if name_param_model != 'num_classes': 100 | arbit_keyword_arg[name_param_model]=kwargs[name_param_model] 101 | 102 | model = ResNet(block, layers, **arbit_keyword_arg) 103 | if pretrained: 104 | state_dict = load_state_dict_from_url(model_urls[arch], 105 | progress=progress) 106 | model.load_state_dict(state_dict) 107 | model.fc = nn.Linear(model.fc.weight.shape[1],kwargs['num_classes']) 108 | return model 109 | 110 | 111 | def resnet18(pretrained=False, progress=True, **kwargs): 112 | r"""ResNet-18 model from 113 | `"Deep Residual Learning for Image Recognition" `_ 114 | 115 | Args: 116 | pretrained (bool): If True, returns a model pre-trained on ImageNet 117 | progress (bool): If True, displays a progress bar of the download to stderr 118 | """ 119 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 120 | **kwargs) 121 | 122 | 123 | def resnet34(pretrained=False, progress=True, **kwargs): 124 | r"""ResNet-34 model from 125 | `"Deep Residual Learning for Image Recognition" `_ 126 | 127 | Args: 128 | pretrained (bool): If True, returns a model pre-trained on ImageNet 129 | progress (bool): If True, displays a progress bar of the download to stderr 130 | """ 131 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 132 | **kwargs) 133 | 134 | 135 | def resnet50(pretrained=False, progress=True, **kwargs): 136 | r"""ResNet-50 model from 137 | `"Deep Residual Learning for Image Recognition" `_ 138 | 139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | progress (bool): If True, displays a progress bar of the download to stderr 142 | """ 143 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 144 | **kwargs) 145 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) by contributors 3 | # Licensed under the MIT License. 
4 | # Written by Haiping Wu 5 | # ------------------------------------------------------------------------------ 6 | import os 7 | import shutil 8 | import math 9 | 10 | import numpy as np 11 | import cv2 12 | import torch 13 | import torchvision 14 | import matplotlib.pyplot as plt 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | def __init__(self, name, fmt=':f'): 19 | self.name = name 20 | self.fmt = fmt 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | def __str__(self): 36 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 37 | return fmtstr.format(**self.__dict__) 38 | 39 | 40 | class ProgressMeter(object): 41 | def __init__(self, num_batches, meters, prefix=""): 42 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 43 | self.meters = meters 44 | self.prefix = prefix 45 | 46 | def display(self, batch): 47 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 48 | entries += [str(meter) for meter in self.meters] 49 | print('\t'.join(entries)) 50 | 51 | def _get_batch_fmtstr(self, num_batches): 52 | num_digits = len(str(num_batches // 1)) 53 | fmt = '{:' + str(num_digits) + 'd}' 54 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 55 | 56 | 57 | def adjust_learning_rate(optimizer, epoch, args): 58 | """Decay the learning rate based on schedule""" 59 | lr = args.lr 60 | for milestone in args.schedule: 61 | lr *= 0.1 if epoch >= milestone else 1. 62 | for param_group in optimizer.param_groups: 63 | param_group['lr'] = lr 64 | 65 | 66 | def accuracy(output, target, topk=(1,)): 67 | """Computes the accuracy over the k top predictions for the specified values of k""" 68 | with torch.no_grad(): 69 | maxk = max(topk) 70 | batch_size = target.size(0) 71 | 72 | _, pred = output.topk(maxk, 1, True, True) 73 | pred = pred.t() 74 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 75 | 76 | res = [] 77 | for k in topk: 78 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 79 | res.append(correct_k.mul_(100.0 / batch_size)) 80 | return res 81 | 82 | 83 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', path=None): 84 | if dir is not None: 85 | filename = os.path.join(path, filename) 86 | torch.save(state, filename) 87 | if is_best: 88 | shutil.copyfile(filename, 'model_best.pth.tar') 89 | 90 | 91 | def sanity_check(state_dict, pretrained_weights): 92 | """ 93 | Linear classifier should not change any weights other than the linear layer. 94 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 95 | """ 96 | print("=> loading '{}' for sanity check".format(pretrained_weights)) 97 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 98 | state_dict_pre = checkpoint['state_dict'] 99 | 100 | for k in list(state_dict.keys()): 101 | # only ignore fc layer 102 | if 'fc.weight' in k or 'fc.bias' in k: 103 | continue 104 | 105 | # name in pretrained model 106 | k_pre = 'module.encoder_q.' + k[len('module.'):] \ 107 | if k.startswith('module.') else 'module.encoder_q.' 
+ k 108 | 109 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ 110 | '{} is changed in linear classifier training.'.format(k) 111 | 112 | print("=> sanity check passed.") 113 | 114 | 115 | def save_batch_image(batch_image, 116 | file_name, nrow=8, padding=2): 117 | ''' 118 | batch_image: [batch_size, channel, height, width] 119 | batch_joints: [batch_size, num_joints, 3], 120 | batch_joints_vis: [batch_size, num_joints, 1], 121 | } 122 | ''' 123 | grid = torchvision.utils.make_grid(batch_image, nrow, padding, True) 124 | ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() 125 | ndarr = ndarr.copy() 126 | 127 | nmaps = batch_image.size(0) 128 | xmaps = min(nrow, nmaps) 129 | ymaps = int(math.ceil(float(nmaps) / xmaps)) 130 | height = int(batch_image.size(2) + padding) 131 | width = int(batch_image.size(3) + padding) 132 | k = 0 133 | cv2.imwrite(file_name, ndarr) 134 | 135 | 136 | def save_debug_images(input, prefix='debug'): 137 | save_batch_image( 138 | input, 139 | '{}_gt.jpg'.format(prefix) 140 | ) 141 | 142 | 143 | def matplotlib_imshow(img, one_channel=False): 144 | if one_channel: 145 | img = img.mean(dim=0) 146 | # print('max', img.max()) 147 | img = img / 2 + 0.5 # unnormalize 148 | # print('after max', img.max()) 149 | img = img / img.max() 150 | npimg = img.numpy() 151 | if one_channel: 152 | plt.imshow(npimg, cmap="Greys") 153 | else: 154 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 155 | 156 | 157 | def plot_knn_images(images, probs, labels, q_label): 158 | fig = plt.figure() 159 | # figsize=(120, 480)) 160 | num_images = len(images) 161 | ax = fig.add_subplot(1, num_images, 1, xticks=[], yticks=[]) 162 | matplotlib_imshow(images[0]) 163 | ax.set_title("Query, class {}".format(q_label)) 164 | for idx in np.arange(1, num_images): 165 | ax = fig.add_subplot(1, num_images, idx+1, xticks=[], yticks=[]) 166 | matplotlib_imshow(images[idx]) 167 | ax.set_title("{0:.1f}%\n(label: {1})".format( 168 | probs[idx-1] * 100.0, 169 | labels[idx-1]), 170 | color=("green" if q_label==labels[idx-1].item() else "red")) 171 | return fig 172 | 173 | 174 | def setup_for_distributed(is_master): 175 | """ 176 | This function disables printing when not in master process 177 | """ 178 | import builtins as __builtin__ 179 | builtin_print = __builtin__.print 180 | 181 | def print(*args, **kwargs): 182 | force = kwargs.pop('force', False) 183 | if is_master or force: 184 | builtin_print(*args, **kwargs) 185 | 186 | __builtin__.print = print 187 | 188 | 189 | def init_distributed_mode(args): 190 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 191 | args.rank = int(os.environ["RANK"]) 192 | args.world_size = int(os.environ['WORLD_SIZE']) 193 | args.gpu = int(os.environ['LOCAL_RANK']) 194 | elif 'SLURM_PROCID' in os.environ: 195 | args.rank = int(os.environ['SLURM_PROCID']) 196 | args.gpu = args.rank % torch.cuda.device_count() 197 | else: 198 | print('Not using distributed mode') 199 | args.distributed = False 200 | return 201 | 202 | # import socket 203 | # ip = socket.gethostbyname(socket.gethostname()) 204 | # port = find_free_port() 205 | # args.dist_url = "tcp://{}:{}".format(ip, port) 206 | # jobid = os.environ["SLURM_JOBID"] 207 | # hostfile = "dist_url." 
+ jobid + ".txt" 208 | # with open(hostfile, "w") as f: 209 | # f.write(args.dist_url) 210 | # print("dist-url:{} at PROCID {} / {}".format(args.dist_url, args.rank, args.world_size)) 211 | 212 | args.distributed = True 213 | 214 | torch.cuda.set_device(args.gpu) 215 | args.dist_backend = 'nccl' 216 | print('| distributed init (rank {}, local rank {}, gpu {}): {}'.format( 217 | args.rank, args.local_rank, args.gpu, args.dist_url), flush=True) 218 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 219 | world_size=args.world_size, rank=args.rank) 220 | # torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url) 221 | torch.distributed.barrier() 222 | 223 | setup_for_distributed(args.rank == 0) 224 | 225 | def is_dist_avail_and_initialized(): 226 | if not torch.distributed.is_available(): 227 | return False 228 | if not torch.distributed.is_initialized(): 229 | return False 230 | return True 231 | 232 | def get_world_size(): 233 | if not is_dist_avail_and_initialized(): 234 | return 1 235 | return torch.distributed.get_world_size() 236 | 237 | def get_rank(): 238 | if not is_dist_avail_and_initialized(): 239 | return 0 240 | return torch.distributed.get_rank() 241 | 242 | 243 | def is_main_process(): 244 | return get_rank() == 0 245 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | """ 4 | Read images and corresponding labels. 5 | """ 6 | import cv2 7 | import json 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | import torchvision.transforms as T 11 | from PIL import Image, ImageFilter 12 | import os 13 | 14 | 15 | class GbUsgDataSet(Dataset): 16 | def __init__(self, data_dir, image_list_file, bin_classify=False, to_blur=False, sigma=0, transform=None): 17 | """ 18 | Args: 19 | data_dir: path to image directory. 20 | image_list_file: path to the file containing images 21 | with corresponding labels. 22 | transform: optional transform to be applied on a sample. 
23 | """ 24 | image_names = [] 25 | labels = [] 26 | with open(image_list_file, "r") as f: 27 | for line in f: 28 | items = line.split(",") 29 | image_name= items[0] 30 | label = int(items[1]) 31 | image_name = os.path.join(data_dir, image_name) 32 | image_names.append(image_name) 33 | if bin_classify: 34 | label = 0 if label in [0, 1] else 1 35 | labels.append(label) 36 | self.image_names = image_names 37 | self.labels = labels 38 | self.transform = transform 39 | self.to_blur = to_blur 40 | self.sigma = sigma 41 | 42 | def __getitem__(self, index): 43 | """ 44 | Args: 45 | index: the index of item 46 | 47 | Returns: 48 | image and its labels 49 | """ 50 | image_name = self.image_names[index] 51 | image = Image.open(image_name).convert('RGB') 52 | #image = cv2.imread(image_name) 53 | if self.to_blur: 54 | image = image.filter(ImageFilter.GaussianBlur(self.sigma)) 55 | label = self.labels[index] 56 | if self.transform is not None: 57 | image = self.transform(image) 58 | label = torch.as_tensor(label, dtype=torch.int64) 59 | return image, label, image_name 60 | 61 | def __len__(self): 62 | return len(self.image_names) 63 | 64 | 65 | def crop_image(image, box, p=0.1): 66 | x1, y1, x2, y2 = box 67 | x1 = (1-p)*x1 68 | y1 = (1-p)*y1 69 | x2 = (1+p)*x2 70 | y2 = (1+p)*y2 71 | cropped_img = image.crop((x1,y1,x2,y2)) 72 | return cropped_img 73 | 74 | 75 | class GbDataSet(Dataset): 76 | def __init__(self, data_dir, image_list_file, df, p=0.1, train=True, transform=None): 77 | """ 78 | Args: 79 | data_dir: path to image directory. 80 | image_list_file: path to the file containing images 81 | with corresponding labels. 82 | transform: optional transform to be applied on a sample. 83 | """ 84 | file_names = [] 85 | image_names = [] 86 | labels = [] 87 | with open(image_list_file, "r") as f: 88 | for line in f: 89 | items = line.split(",") 90 | image_name= items[0] 91 | label = int(items[1]) 92 | file_names.append(image_name) 93 | image_name = os.path.join(data_dir, image_name) 94 | image_names.append(image_name) 95 | labels.append(label) 96 | self.image_names = image_names 97 | self.file_names = file_names 98 | self.labels = labels 99 | self.transform = transform 100 | self.df = df 101 | self.p = p 102 | self.train = train 103 | 104 | def __getitem__(self, index): 105 | """ 106 | Args: 107 | index: the index of item 108 | 109 | Returns: 110 | image and its labels 111 | """ 112 | image_name = self.image_names[index] 113 | file_name = self.file_names[index] 114 | image = Image.open(image_name).convert('RGB') 115 | if self.transform is not None: 116 | x = self.transform(image) 117 | label = self.labels[index] 118 | z, label = self.get_crop_image(image, label, file_name) 119 | return x, z, label, file_name 120 | 121 | def get_crop_image(self, image, label, file_name): 122 | """ Get ROI cropped images 123 | """ 124 | label = torch.as_tensor(label, dtype=torch.int64) 125 | if self.train: 126 | image = crop_image(image, self.df[file_name]["Gold"], self.p) 127 | if self.transform is not None: 128 | img = self.transform(image) 129 | else: 130 | if self.transform is not None: 131 | orig = self.transform(image) 132 | num_objs = len(self.df[file_name]["Boxes"]) 133 | imgs = [] 134 | labels = [] 135 | for i in range(num_objs): 136 | bbs = self.df[file_name]["Boxes"][i] 137 | crop_img = crop_image(image, bbs, self.p) 138 | if self.transform: 139 | crop_img = self.transform(crop_img) 140 | imgs.append(crop_img) 141 | labels.append(label) 142 | if num_objs == 0: 143 | img = orig.unsqueeze(0) 144 | label = 
label.unsqueeze(0) 145 | else: 146 | img = torch.stack(imgs, 0) 147 | label = torch.stack(labels, 0) 148 | return img, label 149 | 150 | def __len__(self): 151 | return len(self.image_names) 152 | 153 | 154 | class GbUsgRoiTrainDataSet(Dataset): 155 | def __init__(self, data_dir, image_list_file, df, to_blur=False, sigma=0, transform=None): 156 | """ 157 | Args: 158 | data_dir: path to image directory. 159 | image_list_file: path to the file containing images 160 | with corresponding labels. 161 | transform: optional transform to be applied on a sample. 162 | """ 163 | image_names = [] 164 | labels = [] 165 | file_names = [] 166 | with open(image_list_file, "r") as f: 167 | for line in f: 168 | items = line.split(",") 169 | image_name= items[0] 170 | label = int(items[1]) 171 | file_names.append(image_name) 172 | image_name = os.path.join(data_dir, image_name) 173 | image_names.append(image_name) 174 | labels.append(label) 175 | self.image_names = image_names 176 | self.file_names = file_names 177 | self.labels = labels 178 | self.transform = transform 179 | self.df = df 180 | self.to_blur = to_blur 181 | self.sigma = sigma 182 | 183 | def __getitem__(self, index): 184 | """ 185 | Args: 186 | index: the index of item 187 | 188 | Returns: 189 | image and its labels 190 | """ 191 | image_name = self.image_names[index] 192 | file_name = self.file_names[index] 193 | image = Image.open(image_name).convert('RGB') 194 | label = self.labels[index] 195 | if self.to_blur: 196 | image = image.filter(ImageFilter.GaussianBlur(self.sigma)) 197 | image = crop_image(image, self.df[file_name]["Gold"]) 198 | if self.transform is not None: 199 | image = self.transform(image) 200 | label = torch.as_tensor(label, dtype=torch.int64) 201 | return image, label, image_name 202 | 203 | def __len__(self): 204 | return len(self.image_names) 205 | 206 | 207 | class GbUsgRoiTestDataSet(Dataset): 208 | def __init__(self, data_dir, image_list_file, df, to_blur=False, sigma=0, transform=None): 209 | """ 210 | Args: 211 | data_dir: path to image directory. 212 | image_list_file: path to the file containing images 213 | with corresponding labels. 214 | transform: optional transform to be applied on a sample. 
215 | """ 216 | image_names = [] 217 | labels = [] 218 | file_names = [] 219 | with open(image_list_file, "r") as f: 220 | for line in f: 221 | items = line.split(",") 222 | image_name= items[0] 223 | label = int(items[1]) 224 | file_names.append(image_name) 225 | image_name = os.path.join(data_dir, image_name) 226 | image_names.append(image_name) 227 | labels.append(label) 228 | self.image_names = image_names 229 | self.file_names = file_names 230 | self.labels = labels 231 | self.transform = transform 232 | self.df = df 233 | self.to_blur = to_blur 234 | self.sigma = sigma 235 | 236 | def __getitem__(self, index): 237 | """ 238 | Args: 239 | index: the index of item 240 | 241 | Returns: 242 | image and its labels 243 | """ 244 | image_name = self.image_names[index] 245 | file_name = self.file_names[index] 246 | image = Image.open(image_name).convert('RGB') 247 | label = self.labels[index] 248 | if self.to_blur: 249 | image = image.filter(ImageFilter.GaussianBlur(self.sigma)) 250 | #orig = crop_image(image, self.df[file_name]["Gold"]) 251 | if self.transform is not None: 252 | orig = self.transform(image) 253 | label = torch.as_tensor(label, dtype=torch.int64) 254 | 255 | num_objs = len(self.df[file_name]["Boxes"]) 256 | imgs = [] 257 | labels = [] 258 | for i in range(num_objs): 259 | bbs = self.df[file_name]["Boxes"][i] 260 | crop_img = crop_image(image, bbs, 0.1) 261 | if self.transform: 262 | crop_img = self.transform(crop_img) 263 | imgs.append(crop_img) 264 | labels.append(label) 265 | if num_objs == 0: 266 | img = orig.unsqueeze(0) 267 | label = label.unsqueeze(0) 268 | else: 269 | img = torch.stack(imgs, 0) 270 | label = torch.stack(labels, 0) 271 | return img, label, image_name 272 | 273 | def __len__(self): 274 | return len(self.image_names) 275 | 276 | 277 | if __name__ == "__main__": 278 | with open("data/res_new.json", "r") as f: 279 | df = json.load(f) 280 | ds = GbUsgRoiTrainDataSet(data_dir="data/gb_imgs", image_list_file="data/cls_split/val.txt",\ 281 | df=df, transform = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])) 282 | dl = DataLoader(dataset=ds, batch_size=4, shuffle=False) 283 | for img, label, img_name in dl: 284 | print(label, img.size(),img_name) 285 | ds = GbUsgRoiTestDataSet(data_dir="data/gb_imgs", image_list_file="data/cls_split/val.txt",\ 286 | df=df, transform = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])) 287 | dl = DataLoader(dataset=ds, batch_size=1, shuffle=False) 288 | for img, label, img_name in dl: 289 | print(label, img.size(),img_name) 290 | -------------------------------------------------------------------------------- /ucl/datasets_gb.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) by contributors 3 | # Licensed under the MIT License. 
4 | # Written by Haiping Wu 5 | # ------------------------------------------------------------------------------ 6 | import os 7 | import glob 8 | import random 9 | 10 | from PIL import Image 11 | import numpy as np 12 | from skimage import color 13 | import scipy.io as sio 14 | import tqdm 15 | import pickle 16 | from collections import OrderedDict, defaultdict 17 | 18 | import torch 19 | import torchvision.datasets as datasets 20 | import torchvision.transforms as transforms 21 | from torch import multiprocessing 22 | from torch.multiprocessing import Pool 23 | import cv2 24 | import math 25 | 26 | 27 | import ucl.loader 28 | 29 | 30 | def pil_loader(path): 31 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 32 | with open(path, 'rb') as f: 33 | img = Image.open(f) 34 | return img.convert('RGB') 35 | 36 | 37 | def accimage_loader(path): 38 | import accimage 39 | try: 40 | return accimage.Image(path) 41 | except IOError: 42 | # Potentially a decoding problem, fall back to PIL.Image 43 | return pil_loader(path) 44 | 45 | 46 | def default_loader(path): 47 | from torchvision import get_image_backend 48 | if get_image_backend() == 'accimage': 49 | return accimage_loader(path) 50 | else: 51 | return pil_loader(path) 52 | 53 | class ImageNetVal(torch.utils.data.Dataset): 54 | # the class name and idx do not necessarily follows the standard one 55 | def __init__(self, root, class_names, class_to_idx, extensions=None, transform=None, 56 | target_transform=None, is_valid_file=None): 57 | # super(ImageNetVal, self).__init__(root, transform=transform, 58 | # target_transform=target_transform) 59 | self.transform = transform 60 | self.target_transform = target_transform 61 | self.root = root 62 | samples = self._make_dataset(class_names, class_to_idx) 63 | if len(samples) == 0: 64 | raise (RuntimeError("Found 0 samples")) 65 | 66 | self.loader = default_loader 67 | 68 | self.classes = list(class_names) 69 | self.class_to_idx = class_to_idx 70 | self.samples = samples 71 | self.targets = [s[1] for s in samples] 72 | 73 | def _make_dataset(self, class_names, class_to_idx): 74 | meta_file = os.path.join(self.root, 'meta_clsloc.mat') 75 | meta = sio.loadmat(meta_file, squeeze_me=True)['synsets'] 76 | idcs, wnids, classes = list(zip(*meta))[:3] 77 | classes = [tuple(clss.split(', ')) for clss in classes] 78 | idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} 79 | wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} 80 | 81 | annot_file = os.path.join(self.root, 'ILSVRC2012_validation_ground_truth.txt') 82 | with open(annot_file, 'r') as f: 83 | val_idcs = f.readlines() 84 | val_idcs = [int(val_idx) for val_idx in val_idcs] 85 | pattern = os.path.join(self.root, 'ILSVRC2012_val_%08d.JPEG') 86 | samples = [] 87 | for i in range(50000): 88 | # filter class names needed 89 | gt_wnid = idx_to_wnid[val_idcs[i]] 90 | if gt_wnid in class_names: 91 | samples.append([pattern%(i+1), class_to_idx[gt_wnid]]) 92 | return samples 93 | 94 | def __getitem__(self, index): 95 | """ 96 | Args: 97 | index (int): Index 98 | Returns: 99 | tuple: (sample, target) where target is class_index of the target class. 
100 | """ 101 | path, target = self.samples[index] 102 | sample = self.loader(path) 103 | if self.transform is not None: 104 | sample = self.transform(sample) 105 | if self.target_transform is not None: 106 | target = self.target_transform(target) 107 | 108 | return sample, target 109 | 110 | def __len__(self): 111 | return len(self.samples) 112 | 113 | 114 | 115 | class ImageFolderInstance(torch.utils.data.Dataset): 116 | """Folder datasets which returns the index of the image as well 117 | """ 118 | def __init__(self, root, transform=None, target_transform=None): 119 | self.dataset = datasets.ImageFolder(root, transform) 120 | 121 | def __getitem__(self, index): 122 | data, target = self.dataset[index] 123 | return data, target, index 124 | 125 | def __len__(self): 126 | return len(self.dataset) 127 | 128 | SAMPLE_NAME = "AA/AA2pFq9pFTA_000001.jpg" 129 | LEN_SAMPLE_NAME = len(SAMPLE_NAME) 130 | LEN_VID_NAME = len("AA2pFq9pFTA") 131 | LEN_NUM_NAME = len("000001") 132 | LEN_CLIP_NAME = len("0001") 133 | 134 | 135 | class GbVideoDataset(torch.utils.data.Dataset): 136 | """Folder datasets which returns the index of the image as well 137 | """ 138 | def __init__(self, root, transform=None, target_transform=None, \ 139 | data_split='train', return_all_video_frames=False, \ 140 | num_of_sampled_frames=4, return_same_frame_indicator=False, \ 141 | pos_dists=[3] , neg_dists=[0.2, 0.4], num_neg_samples=2, \ 142 | return_neg_frame=True): 143 | 144 | self.root = root 145 | self.transform = transform 146 | self.two_crops_transform = ucl.loader.TwoCropsTransform(transform) 147 | #self.crops_transform = cycle_contrast.loader.CropsTransform(transform) 148 | self.target_transform = target_transform ## None 149 | self.data_split = data_split 150 | self.return_all_video_frames = return_all_video_frames ## True 151 | self.num_of_sampled_frames = num_of_sampled_frames ##1 152 | self.return_same_frame_indicator = return_same_frame_indicator ## False 153 | self.return_neg_frame = return_neg_frame ## True 154 | self.pos_dists = pos_dists 155 | self.neg_dists = neg_dists 156 | self.num_neg_samples = num_neg_samples 157 | 158 | self._get_annotations() 159 | 160 | self.loader = default_loader 161 | 162 | @staticmethod 163 | def get_video_name(name): 164 | return name.split("/")[-2] 165 | 166 | @staticmethod 167 | def get_frame_id(name): 168 | return int(name.split("/")[-1][:-4]) 169 | 170 | def get_image_paths(self): 171 | print('path ############', self.data_basepath) 172 | return sorted(list(tqdm.tqdm(glob.iglob(os.path.join(self.data_basepath, "*/*.jpg"))))) 173 | 174 | def get_image_name(self, key: str, ind: int): 175 | return os.path.join(self.data_split_path, key, "%05d.jpg" % ind) 176 | 177 | def video_id_frame_id_split(self, name): 178 | return self.get_video_name(name), self.get_frame_id(name) 179 | 180 | def _get_single_frame(self, path_key, ind): 181 | return self.transform(self.loader(self.get_image_name(path_key, ind))) 182 | 183 | def _get_annotations(self): 184 | self.data_basepath = self.root 185 | self.data_split_path = os.path.join(self.data_basepath) 186 | pickle_path = os.path.join(self.data_basepath, self.data_split+ "_names.pkl") 187 | 188 | if not os.path.exists(pickle_path): 189 | print('creat new cache') 190 | images = self.get_image_paths() 191 | path_info = OrderedDict() 192 | video_names = sorted([self.video_id_frame_id_split(name) for name in images]) 193 | for vid_id, ind in video_names: 194 | if vid_id not in path_info: 195 | path_info[vid_id] = [] 196 | path_info[vid_id].append(ind) 
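            # Sort the (video_id, frame_ids) pairs and cache them to disk so later runs
            # can reuse this index instead of re-globbing every frame.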
197 | path_info = sorted([(key, val) for key, val in path_info.items()]) 198 | os.makedirs(self.data_split_path, exist_ok=True) 199 | pickle.dump(path_info, open(pickle_path, "wb")) 200 | self.path_info = pickle.load(open(pickle_path, "rb")) 201 | num_frames = int(np.sum([len(p_info[1]) for p_info in self.path_info])) 202 | print("Num for %s videos %d frames %d" % (self.data_split, len(self.path_info), num_frames)) 203 | 204 | def __getitem__(self, index): 205 | path_key, frame_ids = self.path_info[index] 206 | target = index 207 | ## index is the video number 208 | ## ind is the frame number 209 | num_frames = len(frame_ids) 210 | pos_dists = self.pos_dists 211 | neg_dists = self.neg_dists 212 | low = int(math.ceil(num_frames*0.2)) 213 | high = int(math.ceil(num_frames*0.5)) 214 | anchor_frame = np.random.randint(low, high) 215 | 216 | pos_indices = [elem for elem in range(max(0, anchor_frame-pos_dists[0]), \ 217 | min(num_frames, anchor_frame+pos_dists[0]+1))] 218 | pos_indices.remove(anchor_frame) 219 | 220 | left_low = max(0, min(anchor_frame-int(neg_dists[1]*num_frames), \ 221 | anchor_frame-pos_dists[0]-3)) 222 | left_high = max(0, min(anchor_frame-int(neg_dists[0]*num_frames), \ 223 | anchor_frame-pos_dists[0]-1)) 224 | right_low = min(num_frames, max(anchor_frame+int(neg_dists[0]*num_frames), \ 225 | anchor_frame+pos_dists[0]+1)) 226 | right_high = min(num_frames, max(anchor_frame+int(math.ceil(neg_dists[1]*num_frames)), \ 227 | anchor_frame+pos_dists[0]+3)) 228 | neg_indices = [elem for elem in range(left_low, left_high)] \ 229 | + [elem for elem in range(right_low, right_high)] 230 | 231 | pos_ind = np.random.choice(pos_indices) 232 | if self.return_neg_frame: 233 | #ind = anchor_frame + np.random.randint(self.pos_dists[0], self.pos_dists[1]+1) 234 | #neg_ = np.random.randint(int(self.neg_dists[0]*num_frames), \ 235 | # int(self.neg_dists[1]*num_frames)+1, size=self.num_neg_samples) 236 | #neg_inds = [(anchor_frame+d)%num_frames for d in neg_] 237 | neg_inds = np.random.choice(neg_indices, size=self.num_neg_samples) 238 | 239 | ## ind gets a random frame out of allthe frames 240 | 241 | q_img = self.loader(self.get_image_name(path_key, anchor_frame)) 242 | k_img = self.loader(self.get_image_name(path_key, pos_ind)) 243 | n_imgs = [self.loader(self.get_image_name(path_key, idx)) for idx in neg_inds] 244 | 245 | ## Loading the image at the chosen random index 246 | if self.transform is not None: 247 | sample = self.two_crops_transform(q_img) # sample is [q, k, [n1, n2]] 248 | #q_frame = sample[1] #self.transform(q_img) ## for the neighbor set 249 | k_frame = self.transform(k_img) 250 | n_frames = [self.transform(n) for n in n_imgs] 251 | video_frames = [k_frame] + n_frames 252 | video_frames = torch.stack(video_frames, dim=0) 253 | ## sample --> two augs of q 254 | ## video_frames --> augs of [k, n1, n2] 255 | return sample, target, index, video_frames 256 | 257 | elif self.return_all_video_frames: 258 | video_frames = [self.loader(self.get_image_name(path_key, _ind)) 259 | for _ind in frame_ids if _ind != ind] 260 | if self.transform is not None: 261 | video_frames = [self.transform(video_frame) for video_frame in video_frames] 262 | video_frames = torch.stack([self.transform(image), *video_frames], dim=0) 263 | return sample, target, index, video_frames 264 | 265 | else: 266 | return sample, target, index 267 | 268 | def __len__(self): 269 | # path_info: dictionary; video_name, frame_ids in video 270 | return len(self.path_info) 271 | 272 | 273 | def parse_file(dataset_adr, 
categories): 274 | dataset = [] 275 | with open(dataset_adr) as f: 276 | for line in f: 277 | line = line[:-1].split("/") 278 | category = "/".join(line[2:-1]) 279 | file_name = "/".join(line[2:]) 280 | if not category in categories: 281 | continue 282 | dataset.append([file_name, category]) 283 | return dataset 284 | 285 | 286 | def get_class_names(path): 287 | classes = [] 288 | with open(path) as f: 289 | for line in f: 290 | categ = "/".join(line[:-1].split("/")[2:]) 291 | classes.append(categ) 292 | class_dic = {classes[i]: i for i in range(len(classes))} 293 | return class_dic 294 | 295 | -------------------------------------------------------------------------------- /ucl/datasets.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) by contributors 3 | # Licensed under the MIT License. 4 | # Written by Haiping Wu 5 | # ------------------------------------------------------------------------------ 6 | import os 7 | import glob 8 | import random 9 | 10 | from PIL import Image 11 | import numpy as np 12 | from skimage import color 13 | import scipy.io as sio 14 | import tqdm 15 | import pickle 16 | from collections import OrderedDict, defaultdict 17 | 18 | import torch 19 | import torchvision.datasets as datasets 20 | import torchvision.transforms as transforms 21 | from torch import multiprocessing 22 | from torch.multiprocessing import Pool 23 | import cv2 24 | 25 | 26 | import ucl.loader 27 | 28 | 29 | def pil_loader(path): 30 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 31 | with open(path, 'rb') as f: 32 | img = Image.open(f) 33 | return img.convert('RGB') 34 | 35 | 36 | def accimage_loader(path): 37 | import accimage 38 | try: 39 | return accimage.Image(path) 40 | except IOError: 41 | # Potentially a decoding problem, fall back to PIL.Image 42 | return pil_loader(path) 43 | 44 | 45 | def default_loader(path): 46 | from torchvision import get_image_backend 47 | if get_image_backend() == 'accimage': 48 | return accimage_loader(path) 49 | else: 50 | return pil_loader(path) 51 | 52 | class ImageNetVal(torch.utils.data.Dataset): 53 | # the class name and idx do not necessarily follows the standard one 54 | def __init__(self, root, class_names, class_to_idx, extensions=None, transform=None, 55 | target_transform=None, is_valid_file=None): 56 | # super(ImageNetVal, self).__init__(root, transform=transform, 57 | # target_transform=target_transform) 58 | self.transform = transform 59 | self.target_transform = target_transform 60 | self.root = root 61 | samples = self._make_dataset(class_names, class_to_idx) 62 | if len(samples) == 0: 63 | raise (RuntimeError("Found 0 samples")) 64 | 65 | self.loader = default_loader 66 | 67 | self.classes = list(class_names) 68 | self.class_to_idx = class_to_idx 69 | self.samples = samples 70 | self.targets = [s[1] for s in samples] 71 | 72 | def _make_dataset(self, class_names, class_to_idx): 73 | meta_file = os.path.join(self.root, 'meta_clsloc.mat') 74 | meta = sio.loadmat(meta_file, squeeze_me=True)['synsets'] 75 | idcs, wnids, classes = list(zip(*meta))[:3] 76 | classes = [tuple(clss.split(', ')) for clss in classes] 77 | idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} 78 | wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} 79 | 80 | annot_file = os.path.join(self.root, 'ILSVRC2012_validation_ground_truth.txt') 81 | with 
open(annot_file, 'r') as f: 82 | val_idcs = f.readlines() 83 | val_idcs = [int(val_idx) for val_idx in val_idcs] 84 | pattern = os.path.join(self.root, 'ILSVRC2012_val_%08d.JPEG') 85 | samples = [] 86 | for i in range(50000): 87 | # filter class names needed 88 | gt_wnid = idx_to_wnid[val_idcs[i]] 89 | if gt_wnid in class_names: 90 | samples.append([pattern%(i+1), class_to_idx[gt_wnid]]) 91 | return samples 92 | 93 | def __getitem__(self, index): 94 | """ 95 | Args: 96 | index (int): Index 97 | Returns: 98 | tuple: (sample, target) where target is class_index of the target class. 99 | """ 100 | path, target = self.samples[index] 101 | sample = self.loader(path) 102 | if self.transform is not None: 103 | sample = self.transform(sample) 104 | if self.target_transform is not None: 105 | target = self.target_transform(target) 106 | 107 | return sample, target 108 | 109 | def __len__(self): 110 | return len(self.samples) 111 | 112 | 113 | 114 | class ImageFolderInstance(torch.utils.data.Dataset): 115 | """Folder datasets which returns the index of the image as well 116 | """ 117 | def __init__(self, root, transform=None, target_transform=None): 118 | self.dataset = datasets.ImageFolder(root, transform) 119 | 120 | def __getitem__(self, index): 121 | data, target = self.dataset[index] 122 | return data, target, index 123 | 124 | def __len__(self): 125 | return len(self.dataset) 126 | 127 | SAMPLE_NAME = "AA/AA2pFq9pFTA_000001.jpg" 128 | LEN_SAMPLE_NAME = len(SAMPLE_NAME) 129 | LEN_VID_NAME = len("AA2pFq9pFTA") 130 | LEN_NUM_NAME = len("000001") 131 | LEN_CLIP_NAME = len("0001") 132 | 133 | class R2V2Dataset(torch.utils.data.Dataset): 134 | """Folder datasets which returns the index of the image as well 135 | """ 136 | def __init__(self, root, transform=None, target_transform=None, data_split='train', return_all_video_frames=False, 137 | num_of_sampled_frames=-1, return_same_frame_indicator=False, return_neg_frame=False): 138 | import pandas as pd 139 | self.root = root 140 | self.transform = transform 141 | self.two_crops_transform = ucl.loader.TwoCropsTransform(transform) 142 | self.target_transform = target_transform ## None 143 | self.data_split = data_split 144 | self.return_all_video_frames = return_all_video_frames ## True 145 | self.num_of_sampled_frames = num_of_sampled_frames ##1 146 | self.return_same_frame_indicator = return_same_frame_indicator ## False 147 | self.return_neg_frame = return_neg_frame ## False 148 | 149 | self._get_annotations() 150 | 151 | self.loader = default_loader 152 | 153 | @staticmethod 154 | def get_video_name(name): 155 | return name.split("/")[-2] 156 | 157 | @staticmethod 158 | def get_frame_id(name): 159 | return int(name.split("/")[-1][:-4]) 160 | 161 | def get_image_paths(self): 162 | print('path ############', self.data_basepath) 163 | return sorted(list(tqdm.tqdm(glob.iglob(os.path.join(self.data_basepath, "*/*.jpg"))))) 164 | 165 | def get_image_name(self, key: str, ind: int): 166 | return os.path.join(self.data_split_path, key, "%05d.jpg" % ind) 167 | 168 | def video_id_frame_id_split(self, name): 169 | return self.get_video_name(name), self.get_frame_id(name) 170 | 171 | def _get_single_frame(self, path_key, ind): 172 | return self.transform(self.loader(self.get_image_name(path_key, ind))) 173 | 174 | def _get_annotations(self): 175 | self.data_basepath = self.root 176 | self.data_split_path = os.path.join(self.data_basepath) 177 | pickle_path = os.path.join(self.data_basepath, self.data_split+ "_names.pkl") 178 | 179 | if not 
os.path.exists(pickle_path): 180 | print('creat new cache') 181 | images = self.get_image_paths() 182 | path_info = OrderedDict() 183 | video_names = sorted([self.video_id_frame_id_split(name) for name in images]) 184 | for vid_id, ind in video_names: 185 | if vid_id not in path_info: 186 | path_info[vid_id] = [] 187 | path_info[vid_id].append(ind) 188 | path_info = sorted([(key, val) for key, val in path_info.items()]) 189 | os.makedirs(self.data_split_path, exist_ok=True) 190 | pickle.dump(path_info, open(pickle_path, "wb")) 191 | self.path_info = pickle.load(open(pickle_path, "rb")) 192 | num_frames = int(np.sum([len(p_info[1]) for p_info in self.path_info])) 193 | print("Num for %s videos %d frames %d" % (self.data_split, len(self.path_info), num_frames)) 194 | 195 | def __getitem__(self, index): 196 | 197 | path_key, frame_ids = self.path_info[index] 198 | target = index 199 | 200 | ## index is the video number 201 | ## ind is the frame number 202 | 203 | if self.return_neg_frame: 204 | ind = np.random.choice(frame_ids, 2, False) 205 | ind, neg_ind = ind 206 | else: 207 | ind = np.random.choice(frame_ids, 1) 208 | 209 | ## ind gets a random frame out of allthe frames 210 | 211 | path = self.get_image_name(path_key, ind) 212 | image = self.loader(path) 213 | 214 | 215 | key_q = index 216 | #key_q = 1e6 * index + ind[0] 217 | #""" 218 | #image.save("Original_withoutaug_"+str(index)+".jpg") 219 | #""" 220 | 221 | ## Loading the image at the chosen random index 222 | 223 | if self.transform is not None: 224 | sample = self.two_crops_transform(image) 225 | 226 | """ 227 | img_temp = sample[0] 228 | img_temp = img_temp.detach().cpu().numpy() 229 | img_temp = np.swapaxes(img_temp,0,2) 230 | img_temp = np.swapaxes(img_temp,0,1) 231 | img_temp *= 255 232 | cv2.imwrite("Original_withaug_"+str(index)+".jpg",img_temp) 233 | 234 | img_temp = sample[1] 235 | img_temp = img_temp.detach().cpu().numpy() 236 | img_temp = np.swapaxes(img_temp,0,2) 237 | img_temp = np.swapaxes(img_temp,0,1) 238 | img_temp *= 255 239 | cv2.imwrite("Original_withaug2_"+str(index)+".jpg",img_temp) 240 | """ 241 | 242 | 243 | key_k= 0 244 | 245 | if self.target_transform is not None: 246 | target = self.target_transform(target) 247 | 248 | if self.num_of_sampled_frames != -1: 249 | sampled_frame_ids = np.random.choice(frame_ids, self.num_of_sampled_frames, False) 250 | video_frames = [self.loader(self.get_image_name(path_key, _ind)) 251 | for _ind in sampled_frame_ids] 252 | 253 | key_k = index 254 | #key_k = 1e6 * index + sampled_frame_ids[0] 255 | #""" 256 | #for video_frame_temp in video_frames: 257 | # video_frame_temp.save("PosPair_withoutaug_"+str(index)+".jpg") 258 | #""" 259 | 260 | 261 | if self.transform is not None: 262 | video_frames = [self.transform(video_frame) for video_frame in video_frames] 263 | 264 | """ 265 | for video_frame_temp in video_frames: 266 | 267 | video_frame_temp = video_frame_temp.detach().cpu().numpy() 268 | video_frame_temp = np.swapaxes(video_frame_temp,0,2) 269 | video_frame_temp = np.swapaxes(video_frame_temp,0,1) 270 | video_frame_temp *= 255 271 | cv2.imwrite("PosPair_withaug_"+str(index)+".jpg",video_frame_temp) 272 | """ 273 | 274 | 275 | if self.return_neg_frame: 276 | if neg_ind not in sampled_frame_ids: 277 | video_frames_neg = self.loader(self.get_image_name(path_key, neg_ind)) 278 | video_frames_neg = self.transform(video_frames_neg) 279 | video_frames.append(video_frames_neg) 280 | else: 281 | idx = np.where(sampled_frame_ids == neg_ind)[0][0] 282 | 
video_frames.append(video_frames[idx]) 283 | 284 | video_frames = torch.stack(video_frames, dim=0) 285 | 286 | metadata_q = torch.tensor([key_q],dtype=torch.int64) 287 | metadata_k = torch.tensor([key_k],dtype=torch.int64) 288 | #metadata = torch.from_numpy(metadata) 289 | 290 | return sample, target, index, video_frames,metadata_q,metadata_k 291 | elif self.return_all_video_frames: 292 | video_frames = [self.loader(self.get_image_name(path_key, _ind)) 293 | for _ind in frame_ids if _ind != ind] 294 | if self.transform is not None: 295 | video_frames = [self.transform(video_frame) for video_frame in video_frames] 296 | video_frames = torch.stack([self.transform(image), *video_frames], dim=0) 297 | return sample, target, index, video_frames 298 | else: 299 | return sample, target, index 300 | 301 | def __len__(self): 302 | # path_info: dictionary; video_name, frame_ids in video 303 | return len(self.path_info) 304 | 305 | def parse_file(dataset_adr, categories): 306 | dataset = [] 307 | with open(dataset_adr) as f: 308 | for line in f: 309 | line = line[:-1].split("/") 310 | category = "/".join(line[2:-1]) 311 | file_name = "/".join(line[2:]) 312 | if not category in categories: 313 | continue 314 | dataset.append([file_name, category]) 315 | return dataset 316 | 317 | 318 | def get_class_names(path): 319 | classes = [] 320 | with open(path) as f: 321 | for line in f: 322 | categ = "/".join(line[:-1].split("/")[2:]) 323 | classes.append(categ) 324 | class_dic = {classes[i]: i for i in range(len(classes))} 325 | return class_dic 326 | 327 | -------------------------------------------------------------------------------- /scripts/eval_ucl_pocus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import argparse 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | import torchvision.transforms as transforms 11 | import torch.optim as optim 12 | 13 | from tools.my_dataset import COVIDDataset 14 | from resnet_uscl import ResNetUSCL 15 | 16 | 17 | import argparse 18 | import builtins 19 | import os 20 | import random 21 | import shutil 22 | import time 23 | import datetime 24 | import warnings 25 | import copy 26 | 27 | import torch 28 | import torch.nn as nn 29 | import torch.nn.parallel 30 | import torch.backends.cudnn as cudnn 31 | import torch.distributed as dist 32 | import torch.optim 33 | import torch.multiprocessing as mp 34 | import torch.utils.data 35 | import torch.utils.data.distributed 36 | import torchvision.transforms as transforms 37 | import torchvision.datasets as datasets 38 | import torchvision.models as models 39 | 40 | 41 | model_names = sorted(name for name in models.__dict__ 42 | if name.islower() and not name.startswith("__") 43 | and callable(models.__dict__[name])) 44 | 45 | 46 | apex_support = False 47 | try: 48 | sys.path.append('./apex') 49 | from apex import amp 50 | print("Apex on, run on mixed precision.") 51 | apex_support = True 52 | except: 53 | print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex") 54 | apex_support = False 55 | 56 | """ 57 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 58 | print("\nRunning on:", device) 59 | 60 | if device == 'cuda': 61 | device_name = torch.cuda.get_device_name() 62 | print("The device name is:", device_name) 63 | cap = torch.cuda.get_device_capability(device=None) 64 | print("The capability of this device is:", 
cap, '\n') 65 | """ 66 | 67 | def set_seed(seed=1): 68 | random.seed(seed) 69 | np.random.seed(seed) 70 | torch.manual_seed(seed) 71 | torch.cuda.manual_seed(seed) 72 | 73 | def main(device,split): 74 | # ============================ step 1/5 data ============================ 75 | # transforms 76 | train_transform = transforms.Compose([ 77 | transforms.Resize((224, 224)), 78 | transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0), ratio=(0.8, 1.25)), 79 | transforms.RandomHorizontalFlip(), 80 | transforms.ToTensor(), 81 | transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.25,0.25,0.25]) 82 | ]) 83 | 84 | valid_transform = transforms.Compose([ 85 | transforms.Resize((224, 224)), 86 | transforms.ToTensor(), 87 | transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.25,0.25,0.25]) 88 | ]) 89 | 90 | # MyDataset 91 | train_data = COVIDDataset(data_dir=data_dir, train=True, transform=train_transform) 92 | valid_data = COVIDDataset(data_dir=data_dir, train=False, transform=valid_transform) 93 | 94 | # DataLoder 95 | train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) 96 | valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE) 97 | 98 | # ============================ step 2/5 model ============================ 99 | 100 | if pretrained: 101 | print('\nThe ImageNet pretrained parameters are loaded.') 102 | else: 103 | print('\nThe ImageNet pretrained parameters are not loaded.') 104 | 105 | num_classes = 3 106 | 107 | net = models.__dict__["resnet50"](pretrained=pretrained) 108 | num_ftrs = net.fc.in_features 109 | 110 | net.fc = nn.Linear(num_ftrs, num_classes) 111 | # init the fc layer 112 | net.fc.weight.data.normal_(mean=0.0, std=0.01) 113 | net.fc.bias.data.zero_() 114 | 115 | num_layers=16 116 | 117 | # load from pre-trained, before DistributedDataParallel constructor 118 | if selfsup: 119 | #print("=> loading checkpoint '{}'".format(args.pretrained)) 120 | checkpoint = torch.load(state_dict_path, map_location="cpu") 121 | 122 | # rename cycle_contrast pre-trained keys 123 | state_dict = checkpoint['state_dict'] 124 | for k in list(state_dict.keys()): 125 | # retain only encoder_q up to before the embedding layer 126 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 127 | # remove prefix 128 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 129 | #print(k[len("module.encoder_q."):], k) 130 | 131 | if k.startswith('module.target_encoder.net') and not k.startswith('module.target_encoder.net.fc'): 132 | # remove prefix 133 | state_dict[k[len("module.target_encoder.net."):]] = state_dict[k] 134 | 135 | # delete renamed or unused k 136 | del state_dict[k] 137 | 138 | args.start_epoch = 0 139 | msg = net.load_state_dict(state_dict, strict=False) 140 | # assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 141 | 142 | print("=> loaded pre-trained model '{}'".format(state_dict_path)) 143 | 144 | else: 145 | print('\nThe self-supervised trained parameters are not loaded.\n') 146 | 147 | 148 | 149 | # frozen all convolutional layers 150 | # for param in net.parameters(): 151 | # param.requires_grad = False 152 | 153 | # fine-tune last 3 layers 154 | for name, param in net.named_parameters(): 155 | if not name.startswith('layer4.1'): 156 | param.requires_grad = False 157 | 158 | # add a classifier for linear evaluation 159 | num_ftrs = net.fc.in_features 160 | net.fc = nn.Linear(num_ftrs, 3) 161 | #net.fc = nn.Linear(3, 3) 162 | 163 | for name, param in net.named_parameters(): 164 | print(name, '\t', 'requires_grad=', 
param.requires_grad) 165 | 166 | net.to(device) 167 | print(net) 168 | 169 | # ============================ step 3/5 loss function ============================ 170 | criterion = nn.CrossEntropyLoss() # choose loss function 171 | 172 | # ============================ step 4/5 optimizer ============================ 173 | optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=weight_decay) # choose optimizer 174 | """ 175 | optimizer = torch.optim.SGD(net.parameters(), LR, 176 | momentum=0.9, 177 | weight_decay=weight_decay) 178 | """ 179 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 180 | T_max=MAX_EPOCH, 181 | eta_min=0, 182 | last_epoch=-1) # set learning rate decay strategy 183 | 184 | 185 | # ============================ step 5/5 training ============================ 186 | print('\nTraining start!\n') 187 | start = time.time() 188 | train_curve = list() 189 | valid_curve = list() 190 | max_acc = 0. 191 | reached = 0 # which epoch reached the max accuracy 192 | 193 | # the statistics of classification result: classification_results[true][pred] 194 | classification_results = [[0, 0, 0], [0, 0, 0], [0, 0, 0]] 195 | best_classification_results = None 196 | 197 | if apex_support and fp16_precision: 198 | net, optimizer = amp.initialize(net, optimizer, 199 | opt_level='O2', 200 | keep_batchnorm_fp32=True) 201 | for epoch in range(MAX_EPOCH): 202 | 203 | loss_mean = 0. 204 | correct = 0. 205 | total = 0. 206 | 207 | net.train() 208 | for i, data in enumerate(train_loader): 209 | 210 | # forward 211 | inputs, labels = data 212 | inputs = inputs.to(device) 213 | labels = labels.to(device) 214 | outputs = net(inputs) 215 | 216 | # backward 217 | optimizer.zero_grad() 218 | loss = criterion(outputs, labels) 219 | if apex_support and fp16_precision: 220 | with amp.scale_loss(loss, optimizer) as scaled_loss: 221 | scaled_loss.backward() 222 | else: 223 | loss.backward() 224 | 225 | # update weights 226 | optimizer.step() 227 | 228 | _, predicted = torch.max(outputs.data, 1) 229 | total += labels.size(0) 230 | correct += (predicted == labels).cpu().squeeze().sum().numpy() 231 | 232 | # print training information 233 | loss_mean += loss.item() 234 | train_curve.append(loss.item()) 235 | if (i+1) % log_interval == 0: 236 | loss_mean = loss_mean / log_interval 237 | print("\nTraining:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format( 238 | epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total)) 239 | loss_mean = 0. 240 | 241 | print('Learning rate this epoch:', scheduler.get_last_lr()[0]) 242 | scheduler.step() # updata learning rate 243 | 244 | # validate the model 245 | if (epoch+1) % val_interval == 0: 246 | 247 | correct_val = 0. 248 | total_val = 0. 249 | loss_val = 0. 
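# Validation pass: the network is switched to eval mode and run under torch.no_grad();
# per-batch predictions are accumulated into classification_results, which is indexed as
# classification_results[true_label][predicted_label], so row sums give per-class support
# and column sums give per-class predicted counts. Whenever validation accuracy improves,
# the current confusion matrix is snapshotted as best_classification_results and the model
# weights are saved. Note that the per-class "accuracy" printed after training is the
# column-normalized value (i.e. precision), while "recall" is row-normalized.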
250 | net.eval() 251 | with torch.no_grad(): 252 | for j, data in enumerate(valid_loader): 253 | inputs, labels = data 254 | inputs = inputs.to(device) 255 | labels = labels.to(device) 256 | outputs = net(inputs) 257 | loss = criterion(outputs, labels) 258 | 259 | _, predicted = torch.max(outputs.data, 1) 260 | total_val += labels.size(0) 261 | correct_val += (predicted == labels).cpu().squeeze().sum().numpy() 262 | for k in range(len(predicted)): 263 | classification_results[labels[k]][predicted[k]] += 1 # "label" is regarded as "predicted" 264 | 265 | loss_val += loss.item() 266 | 267 | acc = correct_val / total_val 268 | if acc > max_acc: # record best accuracy 269 | max_acc = acc 270 | reached = epoch 271 | best_classification_results = classification_results 272 | torch.save(net.state_dict(),'/home/somanshu/scratch/POCUS_ours_'+str(split)+'.pth.tar') 273 | classification_results = [[0, 0, 0], [0, 0, 0], [0, 0, 0]] 274 | valid_curve.append(loss_val/valid_loader.__len__()) 275 | print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format( 276 | epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, acc)) 277 | 278 | 279 | print('\nTraining finish, the time consumption of {} epochs is {}s\n'.format(MAX_EPOCH, round(time.time() - start))) 280 | print('The max validation accuracy is: {:.2%}, reached at epoch {}.\n'.format(max_acc, reached)) 281 | 282 | print('\nThe best prediction results of the dataset:') 283 | print('Class 0 predicted as class 0:', best_classification_results[0][0]) 284 | print('Class 0 predicted as class 1:', best_classification_results[0][1]) 285 | print('Class 0 predicted as class 2:', best_classification_results[0][2]) 286 | print('Class 1 predicted as class 0:', best_classification_results[1][0]) 287 | print('Class 1 predicted as class 1:', best_classification_results[1][1]) 288 | print('Class 1 predicted as class 2:', best_classification_results[1][2]) 289 | print('Class 2 predicted as class 0:', best_classification_results[2][0]) 290 | print('Class 2 predicted as class 1:', best_classification_results[2][1]) 291 | print('Class 2 predicted as class 2:', best_classification_results[2][2]) 292 | 293 | acc0 = best_classification_results[0][0] / sum(best_classification_results[i][0] for i in range(3)) 294 | recall0 = best_classification_results[0][0] / sum(best_classification_results[0]) 295 | print('\nClass 0 accuracy:', acc0) 296 | print('Class 0 recall:', recall0) 297 | print('Class 0 F1:', 2 * acc0 * recall0 / (acc0 + recall0)) 298 | 299 | acc1 = best_classification_results[1][1] / sum(best_classification_results[i][1] for i in range(3)) 300 | recall1 = best_classification_results[1][1] / sum(best_classification_results[1]) 301 | print('\nClass 1 accuracy:', acc1) 302 | print('Class 1 recall:', recall1) 303 | print('Class 1 F1:', 2 * acc1 * recall1 / (acc1 + recall1)) 304 | 305 | acc2 = best_classification_results[2][2] / sum(best_classification_results[i][2] for i in range(3)) 306 | recall2 = best_classification_results[2][2] / sum(best_classification_results[2]) 307 | print('\nClass 2 accuracy:', acc2) 308 | print('Class 2 recall:', recall2) 309 | print('Class 2 F1:', 2 * acc2 * recall2 / (acc2 + recall2)) 310 | 311 | return best_classification_results 312 | 313 | 314 | if __name__ == '__main__': 315 | 316 | parser = argparse.ArgumentParser(description='linear evaluation') 317 | parser.add_argument('-p', '--path', default='checkpoint', help='folder of ckpt') 318 | parser.add_argument('-g', '--gpu', default=2,type=int, help='GPU ID 
to use') 319 | parser.add_argument('-c', '--ckpt', default='59', help='ckpt to use') 320 | args = parser.parse_args() 321 | 322 | set_seed(1) # random seed 323 | 324 | # parameters 325 | MAX_EPOCH = 100 # default = 100 326 | BATCH_SIZE = 32 # default = 32 327 | LR = 0.01 # default = 0.01 328 | weight_decay = 1e-4 # default = 1e-4 329 | log_interval = 10 330 | val_interval = 1 331 | base_path = "./eval_pretrained_model/" 332 | state_dict_path = args.path 333 | device = args.gpu 334 | 335 | if torch.cuda.is_available(): 336 | device = "cuda:"+str(device) 337 | else: 338 | device = "cpu" 339 | 340 | print(device) 341 | state_dict_path = os.path.join("/home/somanshu/scratch/cyclecontrast/output",args.path, args.ckpt) 342 | print('State dict path:', state_dict_path) 343 | fp16_precision = True 344 | pretrained = True 345 | selfsup = True 346 | 347 | # save result 348 | """ 349 | save_dir = os.path.join('result') 350 | if not os.path.exists(save_dir): 351 | os.makedirs(save_dir) 352 | resultfile = save_dir + '/my_result.txt' 353 | """ 354 | 355 | print(os.getcwd()) 356 | print(os.path.exists(state_dict_path)) 357 | #print(os.path.exists(resultfile)) 358 | #print(os.path.isdir()) 359 | if (os.path.exists(state_dict_path)): 360 | confusion_matrix = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) 361 | for i in range(1, 6): 362 | print('\n' + '='*20 + 'The training of fold {} start.'.format(i) + '='*20) 363 | data_dir = "./scratch/pocus/covid_data{}.pkl".format(i) 364 | best_classification_results = main(device,i) 365 | confusion_matrix = confusion_matrix + np.array(best_classification_results) 366 | 367 | print('\nThe confusion matrix is:') 368 | print(confusion_matrix) 369 | print('\nThe precision of class 0 is:', confusion_matrix[0,0] / sum(confusion_matrix[:,0])) 370 | print('The precision of class 1 is:', confusion_matrix[1,1] / sum(confusion_matrix[:,1])) 371 | print('The precision of class 2 is:', confusion_matrix[2,2] / sum(confusion_matrix[:,2])) 372 | print('\nThe recall of class 0 is:', confusion_matrix[0,0] / sum(confusion_matrix[0])) 373 | print('The recall of class 1 is:', confusion_matrix[1,1] / sum(confusion_matrix[1])) 374 | print('The recall of class 2 is:', confusion_matrix[2,2] / sum(confusion_matrix[2])) 375 | 376 | print("****************************") 377 | print("*****************") 378 | 379 | print('\nTotal acc is:', (confusion_matrix[0,0]+confusion_matrix[1,1]+confusion_matrix[2,2])/confusion_matrix.sum()) 380 | 381 | 382 | print('\nCOVID acc is:',(confusion_matrix[0][0]/np.sum(confusion_matrix[0]))) 383 | 384 | print('\nPneumonia acc is:',(confusion_matrix[1][1]/np.sum(confusion_matrix[1]))) 385 | 386 | print('\nNormal acc is:',(confusion_matrix[2][2]/np.sum(confusion_matrix[2]))) 387 | 388 | print("****************************") 389 | print("*****************") 390 | 391 | -------------------------------------------------------------------------------- /utils/flops_counter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2019 Sovrasov V. - All Rights Reserved 3 | * You may use, distribute and modify this code under the 4 | * terms of the MIT license. 5 | * You should have received a copy of the MIT license with 6 | * this file. 
If not visit https://opensource.org/licenses/MIT 7 | ''' 8 | 9 | import sys 10 | from functools import partial 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | def get_model_complexity_info(model, input_res, 18 | print_per_layer_stat=True, 19 | as_strings=True, 20 | input_constructor=None, ost=sys.stdout, 21 | verbose=False, ignore_modules=[], 22 | custom_modules_hooks={}): 23 | assert type(input_res) is tuple 24 | assert len(input_res) >= 1 25 | assert isinstance(model, nn.Module) 26 | global CUSTOM_MODULES_MAPPING 27 | CUSTOM_MODULES_MAPPING = custom_modules_hooks 28 | flops_model = add_flops_counting_methods(model) 29 | flops_model.eval() 30 | flops_model.start_flops_count(ost=ost, verbose=verbose, 31 | ignore_list=ignore_modules) 32 | if input_constructor: 33 | input = input_constructor(input_res) 34 | _ = flops_model(**input) 35 | else: 36 | try: 37 | batch = torch.ones(()).new_empty((1, *input_res), 38 | dtype=next(flops_model.parameters()).dtype, 39 | device=next(flops_model.parameters()).device) 40 | except StopIteration: 41 | batch = torch.ones(()).new_empty((1, *input_res)) 42 | 43 | _ = flops_model(batch) 44 | 45 | flops_count, params_count = flops_model.compute_average_flops_cost() 46 | if print_per_layer_stat: 47 | print_model_with_flops(flops_model, flops_count, params_count, ost=ost) 48 | flops_model.stop_flops_count() 49 | CUSTOM_MODULES_MAPPING = {} 50 | 51 | if as_strings: 52 | return flops_to_string(flops_count), params_to_string(params_count) 53 | 54 | return flops_count, params_count 55 | 56 | 57 | def flops_to_string(flops, units='GMac', precision=2): 58 | if units is None: 59 | if flops // 10**9 > 0: 60 | return str(round(flops / 10.**9, precision)) + ' GMac' 61 | elif flops // 10**6 > 0: 62 | return str(round(flops / 10.**6, precision)) + ' MMac' 63 | elif flops // 10**3 > 0: 64 | return str(round(flops / 10.**3, precision)) + ' KMac' 65 | else: 66 | return str(flops) + ' Mac' 67 | else: 68 | if units == 'GMac': 69 | return str(round(flops / 10.**9, precision)) + ' ' + units 70 | elif units == 'MMac': 71 | return str(round(flops / 10.**6, precision)) + ' ' + units 72 | elif units == 'KMac': 73 | return str(round(flops / 10.**3, precision)) + ' ' + units 74 | else: 75 | return str(flops) + ' Mac' 76 | 77 | 78 | def params_to_string(params_num, units=None, precision=2): 79 | if units is None: 80 | if params_num // 10 ** 6 > 0: 81 | return str(round(params_num / 10 ** 6, 2)) + ' M' 82 | elif params_num // 10 ** 3: 83 | return str(round(params_num / 10 ** 3, 2)) + ' k' 84 | else: 85 | return str(params_num) 86 | else: 87 | if units == 'M': 88 | return str(round(params_num / 10.**6, precision)) + ' ' + units 89 | elif units == 'K': 90 | return str(round(params_num / 10.**3, precision)) + ' ' + units 91 | else: 92 | return str(params_num) 93 | 94 | 95 | def accumulate_flops(self): 96 | if is_supported_instance(self): 97 | return self.__flops__ 98 | else: 99 | sum = 0 100 | for m in self.children(): 101 | sum += m.accumulate_flops() 102 | return sum 103 | 104 | 105 | def print_model_with_flops(model, total_flops, total_params, units='GMac', 106 | precision=3, ost=sys.stdout): 107 | 108 | def accumulate_params(self): 109 | if is_supported_instance(self): 110 | return self.__params__ 111 | else: 112 | sum = 0 113 | for m in self.children(): 114 | sum += m.accumulate_params() 115 | return sum 116 | 117 | def flops_repr(self): 118 | accumulated_params_num = self.accumulate_params() 119 | accumulated_flops_cost = self.accumulate_flops() / 
model.__batch_counter__ 120 | return ', '.join([params_to_string(accumulated_params_num, 121 | units='M', precision=precision), 122 | '{:.3%} Params'.format(accumulated_params_num / total_params), 123 | flops_to_string(accumulated_flops_cost, 124 | units=units, precision=precision), 125 | '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), 126 | self.original_extra_repr()]) 127 | 128 | def add_extra_repr(m): 129 | m.accumulate_flops = accumulate_flops.__get__(m) 130 | m.accumulate_params = accumulate_params.__get__(m) 131 | flops_extra_repr = flops_repr.__get__(m) 132 | if m.extra_repr != flops_extra_repr: 133 | m.original_extra_repr = m.extra_repr 134 | m.extra_repr = flops_extra_repr 135 | assert m.extra_repr != m.original_extra_repr 136 | 137 | def del_extra_repr(m): 138 | if hasattr(m, 'original_extra_repr'): 139 | m.extra_repr = m.original_extra_repr 140 | del m.original_extra_repr 141 | if hasattr(m, 'accumulate_flops'): 142 | del m.accumulate_flops 143 | 144 | model.apply(add_extra_repr) 145 | print(repr(model), file=ost) 146 | model.apply(del_extra_repr) 147 | 148 | 149 | def get_model_parameters_number(model): 150 | params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 151 | return params_num 152 | 153 | 154 | def add_flops_counting_methods(net_main_module): 155 | # adding additional methods to the existing module object, 156 | # this is done this way so that each function has access to self object 157 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 158 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 159 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 160 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( 161 | net_main_module) 162 | 163 | net_main_module.reset_flops_count() 164 | 165 | return net_main_module 166 | 167 | 168 | def compute_average_flops_cost(self): 169 | """ 170 | A method that will be available after add_flops_counting_methods() is called 171 | on a desired net object. 172 | 173 | Returns current mean flops consumption per image. 174 | 175 | """ 176 | 177 | for m in self.modules(): 178 | m.accumulate_flops = accumulate_flops.__get__(m) 179 | 180 | flops_sum = self.accumulate_flops() 181 | 182 | for m in self.modules(): 183 | if hasattr(m, 'accumulate_flops'): 184 | del m.accumulate_flops 185 | 186 | params_sum = get_model_parameters_number(self) 187 | return flops_sum / self.__batch_counter__, params_sum 188 | 189 | 190 | def start_flops_count(self, **kwargs): 191 | """ 192 | A method that will be available after add_flops_counting_methods() is called 193 | on a desired net object. 194 | 195 | Activates the computation of mean flops consumption per image. 196 | Call it before you run the network. 
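    A minimal sketch of the direct-call sequence that get_model_complexity_info()
    wraps for you (the model variable and input shape here are placeholders):

        model = add_flops_counting_methods(model)
        model.eval()
        model.start_flops_count(ost=sys.stdout, verbose=False, ignore_list=[])
        _ = model(torch.randn(1, 3, 224, 224))
        flops, params = model.compute_average_flops_cost()
        model.stop_flops_count()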
197 | 198 | """ 199 | add_batch_counter_hook_function(self) 200 | 201 | seen_types = set() 202 | 203 | def add_flops_counter_hook_function(module, ost, verbose, ignore_list): 204 | if type(module) in ignore_list: 205 | seen_types.add(type(module)) 206 | if is_supported_instance(module): 207 | module.__params__ = 0 208 | elif is_supported_instance(module): 209 | if hasattr(module, '__flops_handle__'): 210 | return 211 | if type(module) in CUSTOM_MODULES_MAPPING: 212 | handle = module.register_forward_hook( 213 | CUSTOM_MODULES_MAPPING[type(module)]) 214 | else: 215 | handle = module.register_forward_hook(MODULES_MAPPING[type(module)]) 216 | module.__flops_handle__ = handle 217 | seen_types.add(type(module)) 218 | else: 219 | if verbose and not type(module) in (nn.Sequential, nn.ModuleList) and \ 220 | not type(module) in seen_types: 221 | print('Warning: module ' + type(module).__name__ + 222 | ' is treated as a zero-op.', file=ost) 223 | seen_types.add(type(module)) 224 | 225 | self.apply(partial(add_flops_counter_hook_function, **kwargs)) 226 | 227 | 228 | def stop_flops_count(self): 229 | """ 230 | A method that will be available after add_flops_counting_methods() is called 231 | on a desired net object. 232 | 233 | Stops computing the mean flops consumption per image. 234 | Call whenever you want to pause the computation. 235 | 236 | """ 237 | remove_batch_counter_hook_function(self) 238 | self.apply(remove_flops_counter_hook_function) 239 | 240 | 241 | def reset_flops_count(self): 242 | """ 243 | A method that will be available after add_flops_counting_methods() is called 244 | on a desired net object. 245 | 246 | Resets statistics computed so far. 247 | 248 | """ 249 | add_batch_counter_variables_or_reset(self) 250 | self.apply(add_flops_counter_variable_or_reset) 251 | 252 | 253 | # ---- Internal functions 254 | def empty_flops_counter_hook(module, input, output): 255 | module.__flops__ += 0 256 | 257 | 258 | def upsample_flops_counter_hook(module, input, output): 259 | output_size = output[0] 260 | batch_size = output_size.shape[0] 261 | output_elements_count = batch_size 262 | for val in output_size.shape[1:]: 263 | output_elements_count *= val 264 | module.__flops__ += int(output_elements_count) 265 | 266 | 267 | def relu_flops_counter_hook(module, input, output): 268 | active_elements_count = output.numel() 269 | module.__flops__ += int(active_elements_count) 270 | 271 | 272 | def linear_flops_counter_hook(module, input, output): 273 | input = input[0] 274 | # pytorch checks dimensions, so here we don't care much 275 | output_last_dim = output.shape[-1] 276 | bias_flops = output_last_dim if module.bias is not None else 0 277 | module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops) 278 | 279 | 280 | def pool_flops_counter_hook(module, input, output): 281 | input = input[0] 282 | module.__flops__ += int(np.prod(input.shape)) 283 | 284 | 285 | def bn_flops_counter_hook(module, input, output): 286 | input = input[0] 287 | 288 | batch_flops = np.prod(input.shape) 289 | if module.affine: 290 | batch_flops *= 2 291 | module.__flops__ += int(batch_flops) 292 | 293 | 294 | def conv_flops_counter_hook(conv_module, input, output): 295 | # Can have multiple inputs, getting the first one 296 | input = input[0] 297 | 298 | batch_size = input.shape[0] 299 | output_dims = list(output.shape[2:]) 300 | 301 | kernel_dims = list(conv_module.kernel_size) 302 | in_channels = conv_module.in_channels 303 | out_channels = conv_module.out_channels 304 | groups = conv_module.groups 
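# Cost model used below: a (grouped) convolution performs
#   prod(kernel_dims) * in_channels * (out_channels // groups)
# multiply-accumulates per output position, so the total is that quantity times
# batch_size * prod(output_dims); a bias term, if present, adds out_channels
# operations per output position on top of that.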
305 | 306 | filters_per_channel = out_channels // groups 307 | conv_per_position_flops = int(np.prod(kernel_dims)) * \ 308 | in_channels * filters_per_channel 309 | 310 | active_elements_count = batch_size * int(np.prod(output_dims)) 311 | 312 | overall_conv_flops = conv_per_position_flops * active_elements_count 313 | 314 | bias_flops = 0 315 | 316 | if conv_module.bias is not None: 317 | 318 | bias_flops = out_channels * active_elements_count 319 | 320 | overall_flops = overall_conv_flops + bias_flops 321 | 322 | conv_module.__flops__ += int(overall_flops) 323 | 324 | 325 | def batch_counter_hook(module, input, output): 326 | batch_size = 1 327 | if len(input) > 0: 328 | # Can have multiple inputs, getting the first one 329 | input = input[0] 330 | batch_size = len(input) 331 | else: 332 | pass 333 | print('Warning! No positional inputs found for a module,' 334 | ' assuming batch size is 1.') 335 | module.__batch_counter__ += batch_size 336 | 337 | 338 | def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): 339 | # matrix matrix mult ih state and internal state 340 | flops += w_ih.shape[0]*w_ih.shape[1] 341 | # matrix matrix mult hh state and internal state 342 | flops += w_hh.shape[0]*w_hh.shape[1] 343 | if isinstance(rnn_module, (nn.RNN, nn.RNNCell)): 344 | # add both operations 345 | flops += rnn_module.hidden_size 346 | elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)): 347 | # hadamard of r 348 | flops += rnn_module.hidden_size 349 | # adding operations from both states 350 | flops += rnn_module.hidden_size*3 351 | # last two hadamard product and add 352 | flops += rnn_module.hidden_size*3 353 | elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)): 354 | # adding operations from both states 355 | flops += rnn_module.hidden_size*4 356 | # two hadamard product and add for C state 357 | flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size 358 | # final hadamard 359 | flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size 360 | return flops 361 | 362 | 363 | def rnn_flops_counter_hook(rnn_module, input, output): 364 | """ 365 | Takes into account batch goes at first position, contrary 366 | to pytorch common rule (but actually it doesn't matter). 
367 | IF sigmoid and tanh are made hard, only a comparison FLOPS should be accurate 368 | """ 369 | flops = 0 370 | # input is a tuple containing a sequence to process and (optionally) hidden state 371 | inp = input[0] 372 | batch_size = inp.shape[0] 373 | seq_length = inp.shape[1] 374 | num_layers = rnn_module.num_layers 375 | 376 | for i in range(num_layers): 377 | w_ih = rnn_module.__getattr__('weight_ih_l' + str(i)) 378 | w_hh = rnn_module.__getattr__('weight_hh_l' + str(i)) 379 | if i == 0: 380 | input_size = rnn_module.input_size 381 | else: 382 | input_size = rnn_module.hidden_size 383 | flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size) 384 | if rnn_module.bias: 385 | b_ih = rnn_module.__getattr__('bias_ih_l' + str(i)) 386 | b_hh = rnn_module.__getattr__('bias_hh_l' + str(i)) 387 | flops += b_ih.shape[0] + b_hh.shape[0] 388 | 389 | flops *= batch_size 390 | flops *= seq_length 391 | if rnn_module.bidirectional: 392 | flops *= 2 393 | rnn_module.__flops__ += int(flops) 394 | 395 | 396 | def rnn_cell_flops_counter_hook(rnn_cell_module, input, output): 397 | flops = 0 398 | inp = input[0] 399 | batch_size = inp.shape[0] 400 | w_ih = rnn_cell_module.__getattr__('weight_ih') 401 | w_hh = rnn_cell_module.__getattr__('weight_hh') 402 | input_size = inp.shape[1] 403 | flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size) 404 | if rnn_cell_module.bias: 405 | b_ih = rnn_cell_module.__getattr__('bias_ih') 406 | b_hh = rnn_cell_module.__getattr__('bias_hh') 407 | flops += b_ih.shape[0] + b_hh.shape[0] 408 | 409 | flops *= batch_size 410 | rnn_cell_module.__flops__ += int(flops) 411 | 412 | 413 | def add_batch_counter_variables_or_reset(module): 414 | 415 | module.__batch_counter__ = 0 416 | 417 | 418 | def add_batch_counter_hook_function(module): 419 | if hasattr(module, '__batch_counter_handle__'): 420 | return 421 | 422 | handle = module.register_forward_hook(batch_counter_hook) 423 | module.__batch_counter_handle__ = handle 424 | 425 | 426 | def remove_batch_counter_hook_function(module): 427 | if hasattr(module, '__batch_counter_handle__'): 428 | module.__batch_counter_handle__.remove() 429 | del module.__batch_counter_handle__ 430 | 431 | 432 | def add_flops_counter_variable_or_reset(module): 433 | if is_supported_instance(module): 434 | if hasattr(module, '__flops__') or hasattr(module, '__params__'): 435 | print('Warning: variables __flops__ or __params__ are already ' 436 | 'defined for the module' + type(module).__name__ + 437 | ' ptflops can affect your code!') 438 | module.__flops__ = 0 439 | module.__params__ = get_model_parameters_number(module) 440 | 441 | 442 | CUSTOM_MODULES_MAPPING = {} 443 | 444 | MODULES_MAPPING = { 445 | # convolutions 446 | nn.Conv1d: conv_flops_counter_hook, 447 | nn.Conv2d: conv_flops_counter_hook, 448 | nn.Conv3d: conv_flops_counter_hook, 449 | # activations 450 | nn.ReLU: relu_flops_counter_hook, 451 | nn.PReLU: relu_flops_counter_hook, 452 | nn.ELU: relu_flops_counter_hook, 453 | nn.LeakyReLU: relu_flops_counter_hook, 454 | nn.ReLU6: relu_flops_counter_hook, 455 | # poolings 456 | nn.MaxPool1d: pool_flops_counter_hook, 457 | nn.AvgPool1d: pool_flops_counter_hook, 458 | nn.AvgPool2d: pool_flops_counter_hook, 459 | nn.MaxPool2d: pool_flops_counter_hook, 460 | nn.MaxPool3d: pool_flops_counter_hook, 461 | nn.AvgPool3d: pool_flops_counter_hook, 462 | nn.AdaptiveMaxPool1d: pool_flops_counter_hook, 463 | nn.AdaptiveAvgPool1d: pool_flops_counter_hook, 464 | nn.AdaptiveMaxPool2d: pool_flops_counter_hook, 465 | nn.AdaptiveAvgPool2d: 
pool_flops_counter_hook, 466 | nn.AdaptiveMaxPool3d: pool_flops_counter_hook, 467 | nn.AdaptiveAvgPool3d: pool_flops_counter_hook, 468 | # BNs 469 | nn.BatchNorm1d: bn_flops_counter_hook, 470 | nn.BatchNorm2d: bn_flops_counter_hook, 471 | nn.BatchNorm3d: bn_flops_counter_hook, 472 | # FC 473 | nn.Linear: linear_flops_counter_hook, 474 | # Upscale 475 | nn.Upsample: upsample_flops_counter_hook, 476 | # Deconvolution 477 | nn.ConvTranspose1d: conv_flops_counter_hook, 478 | nn.ConvTranspose2d: conv_flops_counter_hook, 479 | nn.ConvTranspose3d: conv_flops_counter_hook, 480 | # RNN 481 | nn.RNN: rnn_flops_counter_hook, 482 | nn.GRU: rnn_flops_counter_hook, 483 | nn.LSTM: rnn_flops_counter_hook, 484 | nn.RNNCell: rnn_cell_flops_counter_hook, 485 | nn.LSTMCell: rnn_cell_flops_counter_hook, 486 | nn.GRUCell: rnn_cell_flops_counter_hook 487 | } 488 | 489 | 490 | def is_supported_instance(module): 491 | if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING: 492 | return True 493 | return False 494 | 495 | 496 | def remove_flops_counter_hook_function(module): 497 | if is_supported_instance(module): 498 | if hasattr(module, '__flops_handle__'): 499 | module.__flops_handle__.remove() 500 | del module.__flops_handle__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. 
The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. 
indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. 
The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. 
Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /lincls_bin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Just to set git 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | import argparse 5 | import builtins 6 | import os 7 | import random 8 | import shutil 9 | import time 10 | import datetime 11 | import warnings 12 | import copy 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.backends.cudnn as cudnn 18 | import torch.distributed as dist 19 | import torch.optim 20 | import torch.multiprocessing as mp 21 | import torch.utils.data 22 | import torch.utils.data.distributed 23 | import torchvision.transforms as transforms 24 | import torchvision.datasets as datasets 25 | import torchvision.models as models 26 | 27 | import numpy as np 28 | from sklearn.metrics import accuracy_score, confusion_matrix, matthews_corrcoef 29 | 30 | from dataloader import GbUsgDataSet 31 | 32 | from ucl.datasets import ImageNetVal 33 | 34 | #import neptune.new as neptune 35 | 36 | 37 | model_names = sorted(name for name in models.__dict__ 38 | if name.islower() and not name.startswith("__") 39 | and callable(models.__dict__[name])) 40 | 41 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 42 | #parser.add_argument('data', metavar='DIR', 43 | # help='path to dataset') 44 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 45 | choices=model_names, 46 | help='model architecture: ' + 47 | ' | '.join(model_names) + 48 | ' (default: resnet50)') 49 | parser.add_argument('--img_dir', dest="img_dir", default="data/gb_imgs") 50 | parser.add_argument('--train_list', dest="train_list", default="data/cls_split/train.txt") 51 | parser.add_argument('--val_list', dest="val_list", default="data/cls_split/val.txt") 52 | parser.add_argument('--step', dest="step", default=1, type=int) 53 | parser.add_argument('--warmup', dest="warmup", default=5, type=int) 54 | parser.add_argument('--gradual_unfreeze', action='store_true') 55 | parser.add_argument('--cos_lr', action='store_true') 56 | 57 | parser.add_argument('--num_classes', default=2, type=int, metavar='NC', 58 | help='number of output classes for finetuning') 59 | parser.add_argument('--fc_type', default=1, type=int, metavar='FC', 60 | help='fc_layer type') 61 | parser.add_argument('--last_layer', action="store_true", 62 | help='whether only last layer is trainable') 63 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 64 | help='number of data loading workers (default: 32)') 65 | parser.add_argument('--epochs', default=50, type=int, 
metavar='N', 66 | help='number of total epochs to run') 67 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 68 | help='manual epoch number (useful on restarts)') 69 | parser.add_argument('-b', '--batch-size', default=256, type=int, 70 | metavar='N', 71 | help='mini-batch size (default: 256), this is the total ' 72 | 'batch size of all GPUs on the current node when ' 73 | 'using Data Parallel or Distributed Data Parallel') 74 | parser.add_argument('--lr', '--learning-rate', default=30., type=float, 75 | metavar='LR', help='initial learning rate', dest='lr') 76 | parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int, 77 | help='learning rate schedule (when to drop lr by a ratio)') 78 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 79 | help='momentum') 80 | parser.add_argument('--wd', '--weight-decay', default=0., type=float, 81 | metavar='W', help='weight decay (default: 0.)', 82 | dest='weight_decay') 83 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 84 | help='path to latest checkpoint (default: none)') 85 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 86 | help='evaluate model on validation set') 87 | parser.add_argument('--world-size', default=-1, type=int, 88 | help='number of nodes for distributed training') 89 | parser.add_argument('--rank', default=-1, type=int, 90 | help='node rank for distributed training') 91 | # parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 92 | # help='url used to set up distributed training') 93 | parser.add_argument('--dist-url', default='env://', type=str, 94 | help='url used to set up distributed training') 95 | parser.add_argument('--dist-backend', default='nccl', type=str, 96 | help='distributed backend') 97 | parser.add_argument('--seed', default=None, type=int, 98 | help='seed for initializing training. ') 99 | parser.add_argument('--gpu', default=None, type=int, 100 | help='GPU id to use.') 101 | parser.add_argument('--multiprocessing-distributed', action='store_true', 102 | help='Use multi-processing distributed training to launch ' 103 | 'N processes per node, which has N GPUs. This is the ' 104 | 'fastest way to use PyTorch for either single node or ' 105 | 'multi node data parallel training') 106 | 107 | parser.add_argument('--pretrained', default='', type=str, 108 | help='path to cycle_contrast pretrained checkpoint') 109 | 110 | parser.add_argument('--dataset', default='gbc', type=str) 111 | parser.add_argument('--save-dir', default='', type=str) 112 | 113 | parser.add_argument('--eval-interval', default=1, type=int) 114 | 115 | best_acc1 = 0 116 | 117 | 118 | def main(): 119 | args = parser.parse_args() 120 | 121 | if args.seed is not None: 122 | random.seed(args.seed) 123 | torch.manual_seed(args.seed) 124 | cudnn.deterministic = True 125 | warnings.warn('You have chosen to seed training. ' 126 | 'This will turn on the CUDNN deterministic setting, ' 127 | 'which can slow down your training considerably! ' 128 | 'You may see unexpected behavior when restarting ' 129 | 'from checkpoints.') 130 | 131 | #if args.gpu is not None: 132 | # warnings.warn('You have chosen a specific GPU. 
This will completely ' 133 | # 'disable data parallelism.') 134 | 135 | if args.dist_url == "env://" and args.world_size == -1: 136 | args.world_size = int(os.environ["WORLD_SIZE"]) 137 | 138 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 139 | 140 | ngpus_per_node = torch.cuda.device_count() 141 | if args.multiprocessing_distributed: 142 | # Since we have ngpus_per_node processes per node, the total world_size 143 | # needs to be adjusted accordingly 144 | args.world_size = ngpus_per_node * args.world_size 145 | # Use torch.multiprocessing.spawn to launch distributed processes: the 146 | # main_worker process function 147 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 148 | else: 149 | # Simply call main_worker function 150 | main_worker(args.gpu, ngpus_per_node, args) 151 | 152 | 153 | def main_worker(gpu, ngpus_per_node, args): 154 | global best_acc1 155 | args.gpu = gpu 156 | 157 | # suppress printing if not master 158 | if args.multiprocessing_distributed and args.gpu != 0: 159 | def print_pass(*args): 160 | pass 161 | builtins.print = print_pass 162 | 163 | #if args.gpu is not None: 164 | # print("Use GPU: {} for training".format(args.gpu)) 165 | 166 | if args.distributed: 167 | if args.dist_url == "env://" and args.rank == -1: 168 | args.rank = int(os.environ["RANK"]) 169 | if args.multiprocessing_distributed: 170 | # For multiprocessing distributed training, rank needs to be the 171 | # global rank among all the processes 172 | args.rank = args.rank * ngpus_per_node + gpu 173 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 174 | world_size=args.world_size, rank=args.rank) 175 | # create model 176 | #print("=> creating model '{}'".format(args.arch)) 177 | 178 | num_classes = args.num_classes 179 | model = models.__dict__[args.arch](pretrained=True) 180 | num_ftrs = model.fc.in_features 181 | if args.last_layer or args.gradual_unfreeze: 182 | for name, param in model.named_parameters(): 183 | param.requires_grad = False 184 | 185 | if args.fc_type==1: 186 | model.fc = nn.Linear(num_ftrs, num_classes) 187 | # init the fc layer 188 | model.fc.weight.data.normal_(mean=0.0, std=0.01) 189 | model.fc.bias.data.zero_() 190 | else: 191 | model.fc = nn.Sequential( 192 | nn.Linear(num_ftrs, 256), 193 | nn.ReLU(inplace=True), 194 | nn.Dropout(0.4), 195 | nn.Linear(256, num_classes) 196 | ) 197 | # init the fc layer 198 | for i in [0, 3]: 199 | model.fc[i].weight.data.normal_(mean=0.0, std=0.01) 200 | model.fc[i].bias.data.zero_() 201 | if args.arch == "resnet50": 202 | num_layers=16 203 | 204 | # load from pre-trained, before DistributedDataParallel constructor 205 | if args.pretrained: 206 | if os.path.isfile(args.pretrained): 207 | #print("=> loading checkpoint '{}'".format(args.pretrained)) 208 | checkpoint = torch.load(args.pretrained, map_location="cpu") 209 | 210 | if not args.evaluate: 211 | # rename cycle_contrast pre-trained keys 212 | state_dict = checkpoint['state_dict'] 213 | for k in list(state_dict.keys()): 214 | # retain only encoder_q up to before the embedding layer 215 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 216 | # remove prefix 217 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 218 | #print(k[len("module.encoder_q."):], k) 219 | 220 | if k.startswith('module.target_encoder.net') and not k.startswith('module.target_encoder.net.fc'): 221 | # remove prefix 222 | state_dict[k[len("module.target_encoder.net."):]] = state_dict[k] 223 | 
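# Illustrative mapping, using a hypothetical key name: a backbone entry such as
# 'module.encoder_q.layer4.1.conv2.weight' gains a stripped copy under
# 'layer4.1.conv2.weight', matching the torchvision ResNet naming, whereas
# projection-head keys ('module.encoder_q.fc.*') get no copy. Every original
# key is then deleted below, so load_state_dict(strict=False) only matches
# the renamed backbone weights.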
224 | # delete renamed or unused k 225 | del state_dict[k] 226 | 227 | args.start_epoch = 0 228 | msg = model.load_state_dict(state_dict, strict=False) 229 | 230 | else: 231 | state_dict = checkpoint['state_dict'] 232 | args.start_epoch = 0 233 | msg = model.load_state_dict(state_dict, strict=False) 234 | 235 | else: 236 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 237 | 238 | if args.distributed: 239 | # For multiprocessing distributed, DistributedDataParallel constructor 240 | # should always set the single device scope, otherwise, 241 | # DistributedDataParallel will use all available devices. 242 | if args.gpu is not None: 243 | torch.cuda.set_device(args.gpu) 244 | model.cuda(args.gpu) 245 | # When using a single GPU per process and per 246 | # DistributedDataParallel, we need to divide the batch size 247 | # ourselves based on the total number of GPUs we have 248 | args.batch_size = int(args.batch_size / ngpus_per_node) 249 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 250 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 251 | else: 252 | model.cuda() 253 | # DistributedDataParallel will divide and allocate batch_size to all 254 | # available GPUs if device_ids are not set 255 | model = torch.nn.parallel.DistributedDataParallel(model) 256 | elif args.gpu is not None: 257 | torch.cuda.set_device(args.gpu) 258 | model = model.cuda(args.gpu) 259 | else: 260 | # DataParallel will divide and allocate batch_size to all available GPUs 261 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 262 | model.features = torch.nn.DataParallel(model.features) 263 | model.cuda() 264 | else: 265 | model = torch.nn.DataParallel(model).cuda() 266 | 267 | # define loss function (criterion) and optimizer 268 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 269 | 270 | # optimize only the linear classifier 271 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 272 | #assert len(parameters) == 2*int(args.fc_type) # fc.weight, fc.bias 273 | optimizer = torch.optim.SGD(parameters, args.lr, 274 | momentum=args.momentum, 275 | weight_decay=args.weight_decay) 276 | if args.cos_lr: 277 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6) 278 | 279 | # optionally resume from a checkpoint 280 | if args.resume: 281 | if os.path.isfile(args.resume): 282 | print("=> loading checkpoint '{}'".format(args.resume)) 283 | if args.gpu is None: 284 | checkpoint = torch.load(args.resume) 285 | else: 286 | # Map model to be loaded to specified single gpu. 
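# e.g. passing --gpu 0 makes loc == 'cuda:0', so tensors saved from any other
# device in the checkpoint are remapped onto the selected GPU when loaded.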
287 | loc = 'cuda:{}'.format(args.gpu) 288 | checkpoint = torch.load(args.resume, map_location=loc) 289 | args.start_epoch = checkpoint['epoch'] 290 | best_acc = checkpoint['metrics'][0] 291 | if args.gpu is not None: 292 | # best_acc1 may be from a checkpoint from a different GPU 293 | best_acc = best_acc.to(args.gpu) 294 | model.load_state_dict(checkpoint['state_dict'], strict=False) 295 | optimizer.load_state_dict(checkpoint['optimizer']) 296 | print("=> loaded checkpoint '{}' (epoch {})" 297 | .format(args.resume, checkpoint['epoch'])) 298 | else: 299 | print("=> no checkpoint found at '{}'".format(args.resume)) 300 | 301 | #cudnn.benchmark = True 302 | 303 | # Data loader 304 | normalize = transforms.Normalize([0.485, 0.456, 0.406], 305 | [0.229, 0.224, 0.225]) 306 | 307 | train_dataset = GbUsgDataSet(data_dir=args.img_dir, 308 | image_list_file=args.train_list, 309 | #df=df, 310 | #train=True, 311 | bin_classify=(args.num_classes==2), 312 | transform=transforms.Compose([ 313 | #transforms.Resize((224,224)), 314 | transforms.Resize(224), 315 | transforms.RandomCrop(224), 316 | transforms.ToTensor(), 317 | normalize, 318 | ])) 319 | 320 | #train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, 321 | # shuffle=True, num_workers=0) 322 | 323 | val_dataset = GbUsgDataSet(data_dir=args.img_dir, 324 | image_list_file=args.val_list, 325 | #df=df, 326 | #train=True, 327 | bin_classify=(args.num_classes==2), 328 | transform=transforms.Compose([ 329 | #transforms.Resize((224,224)), 330 | transforms.Resize(224), 331 | transforms.CenterCrop(224), 332 | transforms.ToTensor(), 333 | normalize, 334 | ])) 335 | 336 | #val_loader = DataLoader(dataset=val_dataset, batch_size=1, 337 | # shuffle=False, num_workers=0) 338 | if args.distributed: 339 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 340 | else: 341 | train_sampler = None 342 | 343 | train_loader = torch.utils.data.DataLoader( 344 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 345 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 346 | 347 | val_loader = torch.utils.data.DataLoader( 348 | val_dataset, batch_size=1, shuffle=False, 349 | num_workers=args.workers, pin_memory=True) 350 | 351 | if args.evaluate: 352 | y_true, y_pred = validate(val_loader, model, criterion, args) 353 | cfm = confusion_matrix(y_true, y_pred) 354 | acc = accuracy_score(y_true, y_pred) 355 | if args.num_classes == 2: 356 | spec = cfm[0][0]/np.sum(cfm[0]) 357 | sens = cfm[1][1]/np.sum(cfm[1]) 358 | print("%.4f %.4f %.4f"%(acc, spec, sens)) 359 | else: 360 | spec = (cfm[0][0]+cfm[0][1]+cfm[1][0]+cfm[1][1])/(np.sum(cfm[0])+np.sum(cfm[1])) 361 | sens = cfm[2][2]/np.sum(cfm[2]) 362 | acc_2 = (cfm[0][0]+cfm[0][1]+cfm[1][0]+cfm[1][1]+cfm[2][2])/np.sum(cfm) 363 | print("%.4f %.4f %.4f %.4f"%(acc_2, spec, sens, acc)) 364 | 365 | print(cfm) 366 | return 367 | 368 | """ 369 | run = neptune.init( 370 | project="sbasu276/cycle-contrast", 371 | api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIzNjgzZTk5Yi0xNmFlLTQ4YTAtODBhZS0xOGRmNzdlMTFhMmEifQ==", 372 | mode="offline", 373 | ) # your credentials 374 | 375 | params = { 376 | "learning_rate": args.lr, 377 | "weight decay": args.weight_decay, 378 | "batch size": args.batch_size, 379 | "fc_type": args.fc_type, 380 | "last layer": args.last_layer, 381 | "save_dir": args.save_dir, 382 | "gradual unfreeze": args.gradual_unfreeze, 383 | "pretrained model path": 
args.pretrained} 384 | run["parameters"] = params 385 | """ 386 | 387 | os.makedirs(args.save_dir, exist_ok=True) 388 | #start_time = time.time() 389 | best_f1, best_acc = 0, 0 390 | is_best = True 391 | for epoch in range(args.start_epoch, args.epochs): 392 | if args.distributed: 393 | train_sampler.set_epoch(epoch) 394 | #adjust_learning_rate(optimizer, epoch, args) 395 | 396 | # train for one epoch 397 | epoch_loss = train_one_epoch(train_loader, model, criterion, optimizer, epoch, args) 398 | 399 | if args.gradual_unfreeze: 400 | if epoch >= args.warmup: 401 | if epoch % args.step == 0: 402 | unfreeze_layers(model, epoch-args.warmup, args.step) #, num_layers) 403 | #run["Train/Loss"].log(epoch_loss) 404 | 405 | # evaluate on validation set 406 | if epoch % args.eval_interval == args.eval_interval - 1: 407 | y_true, y_pred = validate(val_loader, model, criterion, args) 408 | run=None 409 | acc, spec, sens, cfm = log_stats(run, y_true, y_pred, args) 410 | f1 = 2*(spec*sens)/(spec+sens) 411 | 412 | # remember best mcc and save checkpoint 413 | #is_best = acc > best_acc 414 | is_best = f1 > best_f1 415 | best_f1 = max(f1, best_f1) 416 | #if args.num_classes == 2: 417 | # print("Epoch: %s\t Acc: %.4f\t Spec: %.4f\t Sens: %.4f\t Loss: %.4f"%(epoch, acc, spec, sens, epoch_loss)) 418 | #else: 419 | # acc_2 = (cfm[0][0]+cfm[0][1]+cfm[1][0]+cfm[1][1]+cfm[2][2])/np.sum(cfm) 420 | # print("Epoch: %s\t Acc-2: %.4f\t Spec: %.4f\t Sens: %.4f\t Acc-3: %.4f\t Loss: %.4f"%(epoch, acc_2, spec, sens, acc, epoch_loss)) 421 | #best_mcc = max(mcc, best_mcc) 422 | if is_best: 423 | best_cfm = copy.deepcopy(cfm) 424 | best_epoch = epoch 425 | 426 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 427 | and args.rank % ngpus_per_node == 0): 428 | save_checkpoint({ 429 | 'epoch': epoch + 1, 430 | 'arch': args.arch, 431 | 'state_dict': model.state_dict(), 432 | 'metrics': [acc, spec, sens], 433 | 'optimizer' : optimizer.state_dict(), 434 | }, is_best, filename='lincls_ep_%s.pth.tar'%(epoch), path=args.save_dir) 435 | 436 | if epoch>=args.warmup and args.cos_lr: 437 | lr_scheduler.step() 438 | 439 | #print('best cfm\n', best_cfm) 440 | cfm = best_cfm 441 | spec = (cfm[0][0]+cfm[0][1]+cfm[1][0]+cfm[1][1])/(np.sum(cfm[0])+np.sum(cfm[1])) 442 | sens = cfm[2][2]/np.sum(cfm[2]) 443 | acc_2 = (cfm[0][0]+cfm[0][1]+cfm[1][0]+cfm[1][1]+cfm[2][2])/np.sum(cfm) 444 | acc = (cfm[0][0]+cfm[1][1]+cfm[2][2])/np.sum(cfm) 445 | print("%s %.4f %.4f %.4f %.4f"%(best_epoch, acc_2, spec, sens, acc)) 446 | 447 | 448 | def unfreeze_layers(model, epoch, steps, num_layers=16): 449 | LAYERS = [ 450 | ("layer4", 2), 451 | ("layer4", 1), 452 | ("layer4", 0), 453 | ("layer3", 5), 454 | ("layer3", 4), 455 | ("layer3", 3), 456 | ("layer3", 2), 457 | ("layer3", 1), 458 | ("layer3", 0), 459 | ("layer2", 3), 460 | ("layer2", 2), 461 | ("layer2", 1), 462 | ("layer2", 0), 463 | ("layer1", 2), 464 | ("layer1", 1), 465 | ("layer1", 0) 466 | ] 467 | for j in range(num_layers): 468 | if epoch >= j*steps and epoch < (j+1)*steps: 469 | layer, conv = LAYERS[j] 470 | layername = "%s.%s"%(layer, conv) 471 | for name, params in model.named_parameters(): 472 | if layername in name and params.requires_grad == False: 473 | params.requires_grad = True 474 | if epoch == num_layers*steps: 475 | for name, params in model.named_parameters(): 476 | if params.requires_grad == False: 477 | params.requires_grad = True 478 | 479 | 480 | def train_one_epoch(train_loader, model, criterion, optimizer, epoch, args): 481 | #model.eval() 482 | 
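# The loader yields (image_batch, label_batch, extra) triples; the third element
# returned by GbUsgDataSet is ignored here. The function returns the loss
# averaged over all batches of the epoch.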
running_loss = 0.0 483 | for i, (images, target, _) in enumerate(train_loader): 484 | if args.gpu is not None: 485 | images = images.cuda(args.gpu, non_blocking=True) 486 | target = target.cuda(args.gpu, non_blocking=True) 487 | # compute output 488 | output = model(images) 489 | loss = criterion(output, target) 490 | 491 | # compute gradient and do SGD step 492 | optimizer.zero_grad() 493 | loss.backward() 494 | optimizer.step() 495 | 496 | running_loss += loss.data.item() 497 | 498 | return running_loss/len(train_loader) 499 | 500 | def mcc_score(cfm): 501 | tp = cfm[1][1] 502 | tn = cfm[0][0] 503 | fp = cfm[0][1] 504 | fn = cfm[1][0] 505 | numer = (tp*tn)-(fp*fn) 506 | denom = np.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) 507 | return numer/denom 508 | 509 | 510 | def validate(val_loader, model, criterion, args): 511 | model.eval() 512 | with torch.no_grad(): 513 | y_true, y_pred = [], [] 514 | for i, (images, target, _) in enumerate(val_loader): 515 | if args.gpu is not None: 516 | images = images.cuda(args.gpu, non_blocking=True) 517 | target = target.cuda(args.gpu, non_blocking=True) 518 | 519 | # compute output 520 | output = model(images) 521 | _, pred = torch.max(output, dim=1) 522 | 523 | #loss = criterion(output, target) 524 | y_true.append(target.tolist()[0]) 525 | y_pred.append(pred.item()) 526 | 527 | y_true = np.array(y_true) 528 | y_pred = np.array(y_pred) 529 | 530 | return y_true, y_pred 531 | 532 | def log_stats(logobj, y_true, y_pred, args, label="Eval"): 533 | acc = accuracy_score(y_true, y_pred) 534 | cfm = confusion_matrix(y_true, y_pred) 535 | if args.num_classes == 2: 536 | spec = cfm[0][0]/np.sum(cfm[0]) 537 | sens = cfm[1][1]/np.sum(cfm[1]) 538 | else: 539 | spec = (cfm[0][0]+cfm[0][1]+cfm[1][0]+cfm[1][1])/(np.sum(cfm[0])+np.sum(cfm[1])) 540 | sens = cfm[2][2]/np.sum(cfm[2]) 541 | acc_2 = (cfm[0][0]+cfm[0][1]+cfm[1][0]+cfm[1][1]+cfm[2][2])/np.sum(cfm) 542 | #logobj["%s/Acc-2cls"%label].log(acc_2) 543 | #mcc = mcc_score(cfm) 544 | 545 | #logobj["%s/Accuracy"%label].log(acc) 546 | #logobj["%s/MCC"%label].log(mcc) 547 | #logobj["%s/Specificity"%label].log(spec) 548 | #logobj["%s/Sensitivity"%label].log(sens) 549 | 550 | return acc, spec, sens, cfm 551 | 552 | 553 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', path=None): 554 | if path is not None: 555 | filename = os.path.join(path, filename) 556 | torch.save(state, filename) 557 | if is_best: 558 | shutil.copyfile(filename, 'model_best_lincls.pth.tar') 559 | 560 | 561 | def sanity_check(state_dict, pretrained_weights): 562 | """ 563 | Linear classifier should not change any weights other than the linear layer. 564 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 565 | """ 566 | #print("=> loading '{}' for sanity check".format(pretrained_weights)) 567 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 568 | state_dict_pre = checkpoint['state_dict'] 569 | 570 | for k in list(state_dict.keys()): 571 | # only ignore fc layer 572 | if 'fc.weight' in k or 'fc.bias' in k: 573 | continue 574 | 575 | # name in pretrained model 576 | k_pre = 'module.encoder_q.' + k[len('module.'):] \ 577 | if k.startswith('module.') else 'module.encoder_q.' 
+ k 578 | 579 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ 580 | '{} is changed in linear classifier training.'.format(k) 581 | 582 | #print("=> sanity check passed.") 583 | 584 | def adjust_learning_rate(optimizer, epoch, args): 585 | """Decay the learning rate based on schedule""" 586 | lr = args.lr 587 | for milestone in args.schedule: 588 | lr *= 0.1 if epoch >= milestone else 1. 589 | for param_group in optimizer.param_groups: 590 | param_group['lr'] = lr 591 | 592 | 593 | if __name__ == '__main__': 594 | main() 595 | -------------------------------------------------------------------------------- /train_ucl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins 3 | import datetime 4 | import math 5 | import os 6 | import random 7 | import shutil 8 | import time 9 | import warnings 10 | import json 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | import torch.optim 18 | import torch.multiprocessing as mp 19 | import torch.utils.data 20 | import torch.utils.data.distributed 21 | import torchvision.transforms as transforms 22 | import torchvision.models as models 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | import ucl.loader 26 | import ucl.builder_gb 27 | from ucl.datasets import ImageFolderInstance 28 | from ucl.datasets import R2V2Dataset 29 | from ucl.datasets_gb import GbVideoDataset 30 | from utils.util import is_main_process 31 | #import neptune.new as neptune 32 | 33 | model_names = sorted(name for name in models.__dict__ 34 | if name.islower() and not name.startswith("__") 35 | and callable(models.__dict__[name])) 36 | 37 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 38 | parser.add_argument('data', metavar='DIR', 39 | help='path to dataset') 40 | parser.add_argument('--dataset', type=str, default='gbc', 41 | help='dataset name') 42 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 43 | choices=model_names, 44 | help='model architecture: ' + 45 | ' | '.join(model_names) + 46 | ' (default: resnet50)') 47 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 48 | help='number of data loading workers (default: 32)') 49 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 50 | help='number of total epochs to run') 51 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 52 | help='manual epoch number (useful on restarts)') 53 | parser.add_argument('-b', '--batch-size', default=512, type=int, 54 | metavar='N', 55 | help='mini-batch size (default: 512), this is the total ' 56 | 'batch size of all GPUs on the current node when ' 57 | 'using Data Parallel or Distributed Data Parallel') 58 | parser.add_argument('--lr', '--learning-rate', default=0.06, type=float, 59 | metavar='LR', help='initial learning rate', dest='lr') 60 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, 61 | help='learning rate schedule (when to drop lr by 10x)') 62 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 63 | help='momentum of SGD solver') 64 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 65 | metavar='W', help='weight decay (default: 1e-4)', 66 | dest='weight_decay') 67 | parser.add_argument('-p', '--print-freq', default=10, type=int, 68 | metavar='N', help='print frequency (default: 10)') 69 | 
parser.add_argument('--save-freq', default=10, type=int, 70 | help='save frequency (default: 10)') 71 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 72 | help='path to latest checkpoint (default: none)') 73 | parser.add_argument('--pretrained', default='', type=str, metavar='PATH') 74 | parser.add_argument('--world-size', default=1, type=int, 75 | help='number of nodes for distributed training') 76 | parser.add_argument('--rank', default=0, type=int, 77 | help='node rank for distributed training') 78 | parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str, 79 | help='url used to set up distributed training') 80 | parser.add_argument('--dist-backend', default='nccl', type=str, 81 | help='distributed backend') 82 | parser.add_argument('--seed', default=None, type=int, 83 | help='seed for initializing training. ') 84 | parser.add_argument('--gpu', default=None, type=int, 85 | help='GPU id to use.') 86 | parser.add_argument('--multiprocessing-distributed', action='store_true', 87 | help='Use multi-processing distributed training to launch ' 88 | 'N processes per node, which has N GPUs. This is the ' 89 | 'fastest way to use PyTorch for either single node or ' 90 | 'multi node data parallel training') 91 | 92 | # moco specific configs: 93 | parser.add_argument('--moco-dim', default=128, type=int, 94 | help='feature dimension (default: 128)') 95 | parser.add_argument('--moco-k', default=65536, type=int, 96 | help='queue size; number of negative keys (default: 65536)') 97 | parser.add_argument('--cycle-k', default=81920, type=int, 98 | help='cycle queue size; number of negative keys (default: 65536)') 99 | 100 | parser.add_argument('--negative-N', default=32, type=int, 101 | help='Negatives buffer size (default: 32)') 102 | parser.add_argument('--moco-m', default=0.999, type=float, 103 | help='moco momentum of updating key encoder (default: 0.999)') 104 | parser.add_argument('--moco-t', default=0.07, type=float, 105 | help='softmax temperature (default: 0.07)') 106 | 107 | # options for cycle_contrast v2 108 | parser.add_argument('--mlp', action='store_true', 109 | help='use mlp head') 110 | parser.add_argument('--aug-plus', action='store_true', 111 | help='use moco v2 data augmentation') 112 | parser.add_argument('--cos', action='store_true', 113 | help='use cosine lr schedule') 114 | 115 | # options for cycle contrastive 116 | parser.add_argument('--soft-nn', action='store_true') 117 | parser.add_argument('--soft-nn-loss-weight', default=0.1, type=float) 118 | parser.add_argument('--moco-loss-weight', default=1., type=float) 119 | parser.add_argument('--soft-nn-support', default=16384, type=int) 120 | parser.add_argument('--sep-head', action='store_true') 121 | parser.add_argument('--cycle-neg-only', dest='cycle_neg_only', action='store_true') 122 | parser.add_argument('--no-cycle-neg-only', dest='cycle_neg_only', action='store_false') 123 | parser.add_argument('--soft-nn-t', default=-1., type=float) 124 | parser.add_argument('--cycle-back-cls', action='store_true') 125 | parser.add_argument('--cycle-back-cls-video-as-pos', action='store_true') 126 | 127 | parser.add_argument('--resizecropsize', default=0.2, type=float) 128 | 129 | parser.add_argument('--cycle-back-candidates', action='store_true') 130 | 131 | parser.add_argument('--num-classes', default=100, type=int) 132 | parser.add_argument('--moco-random-video-frame-as-pos', action='store_true') 133 | parser.add_argument('--detach-target', action='store_true') 134 | 
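# Note: several of these options are overridden unconditionally in main() below:
# the batch size, --moco-k and --cycle-k are re-derived from --num-var, the
# negatives buffer from --num-var * --num-negatives, --soft-nn-support is forced
# to 4 and --epochs to 60, regardless of the values passed on the command line.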
parser.add_argument('--multi-crops', default=2, type=int) 135 | parser.add_argument('--num-of-sampled-frames', default=1, type=int) 136 | parser.set_defaults(cycle_neg_only=True) 137 | parser.add_argument('--save-dir', default='../../../scratch/cyclecontrast/output', type=str) 138 | parser.add_argument('--soft-nn-topk-support', action='store_true') 139 | parser.add_argument('--exp_name',default='integrating_changes',type=str) 140 | parser.add_argument('--adam', action='store_true') 141 | parser.add_argument('--pretrained-models',action='store_true') 142 | parser.add_argument('--constant-lr',action='store_true') 143 | parser.add_argument('--negatives',action='store_true') 144 | parser.add_argument('--intranegs-only-two',action='store_true') 145 | parser.add_argument('--convex-combo-loss',action='store_true') 146 | parser.add_argument('--cross-neg-topk-mining',action='store_true') 147 | parser.add_argument('--cross-neg-topk-support-size',default=4, type=int) 148 | parser.add_argument('--anchor-reverse-cross',action='store_true') 149 | parser.add_argument('--single-loss-intra-inter',action='store_true') 150 | parser.add_argument('--qcap-include',action='store_true') 151 | parser.add_argument('--cosine-curriculum',action='store_true') 152 | parser.add_argument('--cosine-clipping',action='store_true') 153 | parser.add_argument('--mean-neighbors',action='store_true') 154 | parser.add_argument('--single-loss-ncap-support-size',default=4, type=int) 155 | parser.add_argument('--num-negatives',default=2, type=int) 156 | parser.add_argument('--num-var',default=32, type=int) 157 | parser.add_argument('--local_rank', type=int) 158 | parser.add_argument('--num-gpu',default=4, type=int) 159 | 160 | def main(): 161 | args = parser.parse_args() 162 | 163 | if os.path.isdir(os.path.join(args.save_dir,args.exp_name)): 164 | shutil.rmtree(os.path.join(args.save_dir,args.exp_name),ignore_errors=True) 165 | os.mkdir(os.path.join(args.save_dir,args.exp_name)) 166 | ##################### 167 | num_var = args.num_var 168 | args.batch_size = num_var 169 | ##################### 170 | args.cycle_k = num_var 171 | args.moco_k = num_var 172 | args.soft_nn_support = 4 173 | args.negative_N = args.batch_size * args.num_negatives 174 | args.epochs = 60 175 | ##################### 176 | 177 | if args.seed is not None: 178 | random.seed(args.seed) 179 | torch.manual_seed(args.seed) 180 | cudnn.deterministic = True 181 | warnings.warn('You have chosen to seed training. ' 182 | 'This will turn on the CUDNN deterministic setting, ' 183 | 'which can slow down your training considerably! ' 184 | 'You may see unexpected behavior when restarting ' 185 | 'from checkpoints.') 186 | 187 | if args.gpu is not None: 188 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 189 | 'disable data parallelism.') 190 | 191 | if args.dist_url == "env://" and args.world_size == -1: 192 | args.world_size = int(os.environ["WORLD_SIZE"]) 193 | 194 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 195 | 196 | ngpus_per_node = torch.cuda.device_count() 197 | ngpus_per_node = args.num_gpu 198 | if args.multiprocessing_distributed: 199 | # Since we have ngpus_per_node processes per node, the total world_size 200 | # needs to be adjusted accordingly 201 | args.world_size = ngpus_per_node * args.world_size 202 | # Use torch.multiprocessing.spawn to launch distributed processes: the 203 | # main_worker process function 204 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 205 | else: 206 | # Simply call main_worker function 207 | main_worker(args.gpu, ngpus_per_node, args) 208 | 209 | 210 | def main_worker(gpu, ngpus_per_node, args): 211 | args.gpu = gpu 212 | if args.gpu !=0: 213 | args.gpu +=3 214 | 215 | # suppress printing if not master 216 | if args.multiprocessing_distributed and args.gpu != 0: 217 | def print_pass(*args): 218 | pass 219 | builtins.print = print_pass 220 | 221 | if args.gpu is not None: 222 | print("Use GPU: {} for training".format(args.gpu)) 223 | 224 | if args.distributed: 225 | if args.dist_url == "env://" and args.rank == -1: 226 | args.rank = int(os.environ["RANK"]) 227 | if args.multiprocessing_distributed: 228 | # For multiprocessing distributed training, rank needs to be the 229 | # global rank among all the processes 230 | args.rank = args.rank * ngpus_per_node + gpu 231 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 232 | world_size=args.world_size, rank=args.rank) 233 | print('rank', dist.get_rank()) 234 | 235 | # init_distributed_mode(args) 236 | #print(args) 237 | 238 | # create model 239 | 240 | print("=> creating model '{}'".format(args.arch)) 241 | model = ucl.builder_gb.CycleContrast( 242 | args.arch, args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp, soft_nn=args.soft_nn, 243 | soft_nn_support=args.soft_nn_support, 244 | sep_head=args.sep_head, 245 | cycle_neg_only=args.cycle_neg_only, 246 | soft_nn_T=args.soft_nn_t, 247 | cycle_back_cls=args.cycle_back_cls, 248 | cycle_back_cls_video_as_pos=args.cycle_back_cls_video_as_pos, 249 | moco_random_video_frame_as_pos=args.moco_random_video_frame_as_pos, 250 | cycle_K=args.cycle_k, 251 | pretrained_on_imagenet=args.pretrained_models, 252 | soft_nn_topk_support= args.soft_nn_topk_support, 253 | negative_use = args.negatives, 254 | neg_queue_size=args.negative_N, 255 | intranegs_only_two = args.intranegs_only_two, 256 | cross_neg_topk_mining = args.cross_neg_topk_mining, 257 | cross_neg_topk_support_size = args.cross_neg_topk_support_size, 258 | anchor_reverse_cross = args.anchor_reverse_cross, 259 | single_loss_intra_inter = args.single_loss_intra_inter, 260 | single_loss_ncap_support_size = args.single_loss_ncap_support_size, 261 | qcap_include = args.qcap_include, 262 | tsne_name = args.exp_name, 263 | mean_neighbors = args.mean_neighbors 264 | ) 265 | print(model) 266 | 267 | if args.gpu == 0: 268 | writer = SummaryWriter(args.save_dir) 269 | else: 270 | writer = None 271 | 272 | if args.distributed: 273 | # For multiprocessing distributed, DistributedDataParallel constructor 274 | # should always set the single device scope, otherwise, 275 | # DistributedDataParallel will use all available devices. 
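# With the overrides in main() (batch_size == --num-var, default 32) and the
# default --num-gpu of 4, each DDP process is left with a per-GPU batch of 8
# after the division below.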
276 | if args.gpu is not None: 277 | torch.cuda.set_device(args.gpu) 278 | model.cuda(args.gpu) 279 | # When using a single GPU per process and per 280 | # DistributedDataParallel, we need to divide the batch size 281 | # ourselves based on the total number of GPUs we have 282 | args.batch_size = int(args.batch_size / ngpus_per_node) 283 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 284 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True) 285 | else: 286 | model.cuda() 287 | # DistributedDataParallel will divide and allocate batch_size to all 288 | # available GPUs if device_ids are not set 289 | model = torch.nn.parallel.DistributedDataParallel(model) 290 | elif args.gpu is not None: 291 | torch.cuda.set_device(args.gpu) 292 | model = model.cuda(args.gpu) 293 | # comment out the following line for debugging 294 | # raise NotImplementedError("Only DistributedDataParallel is supported.") 295 | else: 296 | # AllGather implementation (batch shuffle, queue update, etc.) in 297 | # this code only supports DistributedDataParallel. 298 | raise NotImplementedError("Only DistributedDataParallel is supported.") 299 | 300 | if args.cycle_back_cls: 301 | criterion = [nn.CrossEntropyLoss().cuda(args.gpu), 302 | nn.CrossEntropyLoss().cuda(args.gpu)] 303 | else: 304 | criterion = [nn.CrossEntropyLoss().cuda(args.gpu), 305 | nn.MSELoss().cuda(args.gpu)] 306 | 307 | 308 | if args.adam: 309 | optimizer = torch.optim.Adam(model.parameters(),args.lr) 310 | else: 311 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 312 | momentum=args.momentum, 313 | weight_decay=args.weight_decay) 314 | #print(optimizer) 315 | 316 | # optionally resume from a checkpoint 317 | if args.resume: 318 | if os.path.isfile(args.resume): 319 | print("=> loading checkpoint '{}'".format(args.resume)) 320 | if args.gpu is None: 321 | checkpoint = torch.load(args.resume) 322 | else: 323 | # Map model to be loaded to specified single gpu. 
324 | loc = 'cuda:{}'.format(args.gpu) 325 | checkpoint = torch.load(args.resume, map_location=loc) 326 | args.start_epoch = checkpoint['epoch'] 327 | msg = model.load_state_dict(checkpoint['state_dict'], strict=True) 328 | optimizer.load_state_dict(checkpoint['optimizer']) 329 | print("=> loaded checkpoint '{}' (epoch {}), {}" 330 | .format(args.resume, checkpoint['epoch'], msg)) 331 | else: 332 | print("=> no checkpoint found at '{}'".format(args.resume)) 333 | 334 | cudnn.benchmark = True 335 | 336 | # Data loading code 337 | traindir = os.path.join(args.data) 338 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 339 | std=[0.229, 0.224, 0.225]) 340 | resizecropsize = args.resizecropsize 341 | 342 | augmentation = [ 343 | transforms.Resize(224), 344 | transforms.RandomCrop(224), 345 | transforms.ToTensor(), 346 | normalize 347 | ] 348 | 349 | 350 | 351 | print("Start training") 352 | 353 | """ 354 | Curriculum = [Epoch_Number_till_which_this_will_happen, Negs_low, Negs_High, Only_use_qcap] 355 | """ 356 | curriculum = [[10,0.2,0.4,True],[30,0.2,0.4,False],[50,0.1,0.15,False],[60,0.03,0.07,False]] 357 | """ 358 | curriculum = [[20,0.2,0.4,False],[40,0.1,0.15,False],[50,0.03,0.07,False]] 359 | """ 360 | curriculum_ptr = 0 361 | 362 | if args.cosine_curriculum: 363 | low_n = 0.2 364 | high_n = 0.4 365 | 366 | if args.cosine_clipping: 367 | min_low_n = 0.03 368 | min_high_n = 0.07 369 | 370 | start_time = time.time() 371 | for epoch in range(args.start_epoch, args.epochs): 372 | 373 | """ Getting the correct Dataloader as per epoch requirements""" 374 | 375 | 376 | if epoch==curriculum[curriculum_ptr][0]: 377 | curriculum_ptr+=1 378 | 379 | if args.dataset == 'r2v2': 380 | train_dataset = R2V2Dataset(args.data, 381 | transforms.Compose(augmentation), 382 | return_all_video_frames=args.cycle_back_candidates 383 | or args.moco_random_video_frame_as_pos, 384 | num_of_sampled_frames=args.num_of_sampled_frames, 385 | ) 386 | paths_ds_tsne = train_dataset.path_info 387 | paths_ds_dict ={} 388 | for i in range(len(paths_ds_tsne)): 389 | paths_ds_dict[i]=paths_ds_tsne[i][0] 390 | with open('vid_index_mapping_tsne.json', 'w') as fp: 391 | json.dump(paths_ds_dict, fp) 392 | 393 | elif args.dataset == 'gbc': 394 | if not args.cosine_curriculum: 395 | train_dataset = GbVideoDataset(args.data, 396 | transforms.Compose(augmentation),neg_dists=[curriculum[curriculum_ptr][1], curriculum[curriculum_ptr][2]], num_neg_samples=args.num_negatives) 397 | 398 | else: 399 | if epoch >= curriculum[1][0]: 400 | epochs_done = curriculum[1][0] 401 | low_n , high_n = adjust_negs(low_n,high_n,epoch,args,epochs_done) 402 | if args.cosine_clipping: 403 | low_n = max(min_low_n, low_n) 404 | high_n = max(min_high_n, high_n) 405 | 406 | print(low_n,high_n) 407 | train_dataset = GbVideoDataset(args.data, 408 | transforms.Compose(augmentation),neg_dists=[low_n, high_n], num_neg_samples=args.num_negatives) 409 | 410 | else: 411 | crops_transform = ucl.loader.TwoCropsTransform(transforms.Compose(augmentation)) 412 | train_dataset = ImageFolderInstance(traindir, 413 | crops_transform) 414 | print('class name to idx', train_dataset.dataset.class_to_idx) 415 | 416 | if args.distributed: 417 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 418 | else: 419 | train_sampler = None 420 | 421 | train_loader = torch.utils.data.DataLoader( 422 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 423 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, 
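# The dataset and loader are rebuilt every epoch so neg_dists tracks the
# curriculum stage: without --cosine-curriculum that is [0.2, 0.4] for epochs
# 0-29, [0.1, 0.15] from epoch 30 and [0.03, 0.07] from epoch 50; with
# --cosine-curriculum the [0.2, 0.4] bounds decay via adjust_negs from epoch 30.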
drop_last=False) 424 | 425 | 426 | ep_start_time = time.time() 427 | if args.distributed: 428 | train_sampler.set_epoch(epoch) 429 | if not args.adam: 430 | if not args.constant_lr: 431 | adjust_learning_rate(optimizer, epoch, args) 432 | 433 | # train for one epoch 434 | if args.single_loss_intra_inter: 435 | total_loss_n,loss_moco_n, top1_n = train(train_loader, model, criterion, optimizer, epoch, args, only_qcap= curriculum[curriculum_ptr][3], writer=writer) 436 | else: 437 | if not args.intranegs_only_two: 438 | total_loss_n,loss_moco_n, loss_softnn_n, top1_n,top5_n = train(train_loader, model, criterion, optimizer, epoch, args,writer) 439 | else: 440 | total_loss_n,loss_moco_n, loss_softnn_n, top1_n = train(train_loader, model, criterion, optimizer, epoch, args,writer) 441 | 442 | if args.gpu==0: 443 | learning_rate_neptune = 0 444 | for param_group in optimizer.param_groups: 445 | learning_rate_neptune=param_group['lr'] 446 | ##print(learning_rate_neptune) 447 | 448 | #print(optimizer.state_dict()) 449 | 450 | 451 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 452 | and args.rank % ngpus_per_node == 0): 453 | if epoch % args.save_freq == args.save_freq - 1 and is_main_process(): 454 | save_checkpoint({ 455 | 'epoch': epoch + 1, 456 | 'arch': args.arch, 457 | 'state_dict': model.state_dict(), 458 | 'optimizer': optimizer.state_dict(), 459 | }, is_best=False, filename='checkpoint_{:04d}'.format(epoch)+"_"+args.exp_name+".pth.tar", path=os.path.join(args.save_dir,args.exp_name)) 460 | epoch_time_str = str(datetime.timedelta(seconds=int(time.time() - ep_start_time))) 461 | print('Train Epoch {} time {}'.format(epoch, epoch_time_str)) 462 | 463 | total_time = time.time() - start_time 464 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 465 | print('Training time {}'.format(total_time_str)) 466 | 467 | 468 | def train(train_loader, model, criterion, optimizer, epoch, args,only_qcap=False, writer=None): 469 | batch_time = AverageMeter('Time', ':6.3f') 470 | data_time = AverageMeter('Data', ':6.3f') 471 | losses = AverageMeter('Loss', ':.4e') 472 | top1 = AverageMeter('Acc@1', ':6.2f') 473 | top5 = AverageMeter('Acc@5', ':6.2f') 474 | softnn_losses = AverageMeter('Cycle Loss', ':.4e') 475 | total_loss = AverageMeter('Total Loss', ':.4e') 476 | 477 | if args.single_loss_intra_inter: 478 | log_stats = [batch_time, data_time, losses, top1] 479 | else: 480 | if not args.intranegs_only_two: 481 | log_stats = [batch_time, data_time, losses, softnn_losses, top1, top5] 482 | 483 | else: 484 | log_stats = [batch_time, data_time, losses, softnn_losses, top1] 485 | 486 | progress = ProgressMeter( 487 | len(train_loader), 488 | log_stats, 489 | prefix="Epoch: [{}]".format(epoch)) 490 | 491 | # switch to train mode 492 | 493 | model.train() 494 | 495 | end = time.time() 496 | 497 | for i, data_pack in enumerate(train_loader): 498 | 499 | if len(data_pack) == 3: 500 | images, cls_labels, indices = data_pack 501 | video_frames = None 502 | elif len(data_pack) == 4: 503 | images, cls_labels, indices, video_frames = data_pack 504 | elif len(data_pack) == 5: 505 | images, cls_labels, indices, video_frames, is_same_frame = data_pack 506 | elif len(data_pack) ==6: 507 | images, cls_labels, indices, video_frames,metadata_q,metadata_can = data_pack 508 | else: 509 | assert False, 'unsupported data pack of len {}'.format(len(data_pack)) 510 | data_time.update(time.time() - end) 511 | 512 | if args.negatives: 513 | cls_candidates = video_frames[:,0].unsqueeze(1) 
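# Shape assumption, inferred from the indexing here: video_frames is
# (N, 1 + num_negatives, C, H, W); slot 0 is kept as the cycle-back candidate
# and slots 1..num_negatives are stacked along dim 1 to form im_negs.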
514 | 515 | im_negs = None 516 | for itr_var in range(args.num_negatives): 517 | if itr_var == 0: 518 | im_negs = video_frames[:,itr_var+1].unsqueeze(1) 519 | else: 520 | im_negs = torch.cat((im_negs,video_frames[:,itr_var+1].unsqueeze(1)),axis=1) 521 | 522 | 523 | outputs = model(im_q=images[0], im_k=images[1], cls_labels=cls_labels, indices=indices, 524 | cls_candidates=cls_candidates,im_negs=im_negs,tsne=(epoch==args.epochs-1),only_qcap=only_qcap) 525 | else: 526 | outputs = model(im_q=images[0], im_k=images[1], cls_labels=cls_labels, indices=indices, 527 | cls_candidates=video_frames,tsne=(epoch==args.epochs-1)) 528 | 529 | if args.soft_nn: 530 | if not args.single_loss_intra_inter: 531 | output, target, softnn_feat, q_feat, meta = outputs 532 | else: 533 | output, target, meta = outputs 534 | else: 535 | output, target, meta = outputs 536 | 537 | loss_moco = criterion[0](output, target) 538 | 539 | if args.soft_nn: 540 | """ 541 | if not args.cycle_back_k and args.detach_target: 542 | softnn_feat = softnn_feat.detach() 543 | """ 544 | if not args.single_loss_intra_inter: 545 | loss_softnn = criterion[1](softnn_feat, q_feat) 546 | if args.moco_loss_weight == 0: 547 | loss_moco = loss_moco.detach() 548 | if args.soft_nn_loss_weight == 0: 549 | loss_softnn = loss_softnn.detach() 550 | if not args.convex_combo_loss: 551 | loss = loss_moco * args.moco_loss_weight + loss_softnn * args.soft_nn_loss_weight 552 | else: 553 | loss = loss_moco * (1-args.soft_nn_loss_weight) + loss_softnn * args.soft_nn_loss_weight 554 | else: 555 | if args.moco_loss_weight == 0: 556 | loss_moco = loss_moco.detach() 557 | 558 | loss = loss_moco 559 | else: 560 | if args.moco_loss_weight == 0: 561 | loss_moco = loss_moco.detach() 562 | loss = loss_moco * args.moco_loss_weight 563 | 564 | losses.update(loss_moco.item(), images[0].size(0)) 565 | total_loss.update(loss.item(),images[0].size(0)) 566 | 567 | if args.soft_nn: 568 | if not args.single_loss_intra_inter: 569 | softnn_losses.update(loss_softnn.item(), images[0].size(0)) 570 | 571 | if not args.intranegs_only_two: 572 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 573 | top1.update(acc1[0].cpu().item(), images[0].size(0)) 574 | top5.update(acc5[0].cpu().item(), images[0].size(0)) 575 | else: 576 | acc1 = accuracy(output, target, topk=(1,)) 577 | top1.update(acc1[0].cpu().item(), images[0].size(0)) 578 | 579 | # compute gradient and do SGD step 580 | optimizer.zero_grad() 581 | loss.backward() 582 | optimizer.step() 583 | 584 | # measure elapsed time 585 | batch_time.update(time.time() - end) 586 | end = time.time() 587 | 588 | if i % args.print_freq == 0: 589 | progress.display(i) 590 | 591 | if args.single_loss_intra_inter: 592 | return total_loss,losses,top1 593 | else: 594 | if not args.intranegs_only_two: 595 | return total_loss,losses,softnn_losses,top1, top5 596 | else: 597 | return total_loss,losses,softnn_losses,top1 598 | 599 | 600 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', path=None): 601 | if path is not None: 602 | filename = os.path.join(path, filename) 603 | torch.save(state, filename) 604 | if is_best: 605 | shutil.copyfile(filename, 'model_best.pth.tar') 606 | 607 | 608 | class AverageMeter(object): 609 | """Computes and stores the average and current value""" 610 | def __init__(self, name, fmt=':f'): 611 | self.name = name 612 | self.fmt = fmt 613 | self.reset() 614 | 615 | def reset(self): 616 | self.val = 0 617 | self.avg = 0 618 | self.sum = 0 619 | self.count = 0 620 | 621 | def update(self, val, n=1): 622 
| self.val = val 623 | self.sum += val * n 624 | self.count += n 625 | self.avg = self.sum / self.count 626 | 627 | def __str__(self): 628 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 629 | return fmtstr.format(**self.__dict__) 630 | 631 | 632 | class ProgressMeter(object): 633 | def __init__(self, num_batches, meters, prefix=""): 634 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 635 | self.meters = meters 636 | self.prefix = prefix 637 | 638 | def display(self, batch): 639 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 640 | entries += [str(meter) for meter in self.meters] 641 | print('\t'.join(entries)) 642 | 643 | def _get_batch_fmtstr(self, num_batches): 644 | num_digits = len(str(num_batches // 1)) 645 | fmt = '{:' + str(num_digits) + 'd}' 646 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 647 | 648 | 649 | def adjust_learning_rate(optimizer, epoch, args): 650 | """Decay the learning rate based on schedule""" 651 | lr = args.lr 652 | if args.cos: # cosine lr schedule 653 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 654 | else: # stepwise lr schedule 655 | for milestone in args.schedule: 656 | lr *= 0.1 if epoch >= milestone else 1. 657 | for param_group in optimizer.param_groups: 658 | param_group['lr'] = lr 659 | 660 | def adjust_negs(low_n,high_n,epoch,args,epochs_done): 661 | epochs = args.epochs 662 | low_n *= 0.5 * (1. + math.cos(math.pi * (epoch-epochs_done) / (epochs-epochs_done))) 663 | high_n *= 0.5 * (1. + math.cos(math.pi * (epoch-epochs_done) / (epochs-epochs_done))) 664 | return low_n, high_n 665 | 666 | def accuracy(output, target, topk=(1,)): 667 | """Computes the accuracy over the k top predictions for the specified values of k""" 668 | with torch.no_grad(): 669 | maxk = max(topk) 670 | batch_size = target.size(0) 671 | 672 | _, pred = output.topk(maxk, 1, True, True) 673 | pred = pred.t() 674 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 675 | 676 | res = [] 677 | for k in topk: 678 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 679 | res.append(correct_k.mul_(100.0 / batch_size)) 680 | return res 681 | 682 | 683 | def accuracy_nn(output, target, topk=(1,)): 684 | """Computes the accuracy over the k top predictions for the specified values of k""" 685 | with torch.no_grad(): 686 | maxk = max(topk) 687 | batch_size = target.size(0) 688 | 689 | _, pred = output.topk(maxk, 1, True, True) 690 | pred = pred.t() 691 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 692 | 693 | res = [] 694 | for k in topk: 695 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 696 | res.append(correct_k.mul_(100.0 / batch_size)) 697 | return res 698 | 699 | if __name__ == '__main__': 700 | main() 701 | -------------------------------------------------------------------------------- /ucl/builder_gb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import ucl.resnet as models 5 | import cv2 6 | import numpy as np 7 | from sklearn.manifold import TSNE 8 | from sklearn.decomposition import PCA 9 | import pandas as pd 10 | import matplotlib.pyplot as plt 11 | import seaborn as sns 12 | 13 | 14 | class CycleContrast(nn.Module): 15 | """ 16 | Build a CycleContrast model with: a query encoder, a key encoder, and a queue 17 | https://arxiv.org/abs/2105.06463 18 | """ 19 | def __init__(self, arch, dim=128, 20 | K=65536, m=0.999, 21 | T=0.07, mlp=False, 22 | 
soft_nn=False, soft_nn_support=-1, 23 | sep_head=False, cycle_neg_only=True, 24 | soft_nn_T=-1., 25 | cycle_back_cls=False, 26 | cycle_back_cls_video_as_pos=False, 27 | cycle_K=None, 28 | moco_random_video_frame_as_pos=False, 29 | pretrained_on_imagenet=False, 30 | soft_nn_topk_support=True, 31 | negative_use = True, 32 | neg_queue_size=32, 33 | intranegs_only_two=True, 34 | cross_neg_topk_mining= True, 35 | cross_neg_topk_support_size = 4, 36 | anchor_reverse_cross = True, 37 | single_loss_intra_inter= True, 38 | single_loss_ncap_support_size = 4, 39 | qcap_include = False, 40 | tsne_name="", 41 | mean_neighbors= False 42 | ): 43 | super(CycleContrast, self).__init__() 44 | 45 | self.K = K ## 65536 size of queue 46 | self.tsne_name = tsne_name 47 | if cycle_K is not None: 48 | self.cycle_K = cycle_K ## 81920 49 | else: 50 | self.cycle_K = K 51 | self.m = m ## 0.999 52 | self.T = T ## 0.07 53 | 54 | if soft_nn_T != -1: 55 | self.soft_nn_T = soft_nn_T 56 | else: 57 | self.soft_nn_T = T ## soft_nn_T gets the value of T for us [0.07] 58 | 59 | self.soft_nn = soft_nn ## True 60 | self.soft_nn_support = soft_nn_support ## 16384 61 | self.soft_nn_topk_support = soft_nn_topk_support 62 | self.sep_head = sep_head ## True 63 | self.cycle_back_cls = cycle_back_cls ## True 64 | self.cycle_neg_only = cycle_neg_only ## True 65 | self.cycle_back_cls_video_as_pos = cycle_back_cls_video_as_pos ## True 66 | self.moco_random_video_frame_as_pos = moco_random_video_frame_as_pos ## True 67 | self.neg_queue_size = neg_queue_size 68 | self.negative_use = negative_use 69 | self.intranegs_only_two = intranegs_only_two 70 | self.cross_neg_topk_mining = cross_neg_topk_mining 71 | self.cross_neg_topk_support_size = cross_neg_topk_support_size 72 | self.anchor_reverse_cross = anchor_reverse_cross 73 | self.single_loss_intra_inter = single_loss_intra_inter 74 | self.single_loss_ncap_support_size = single_loss_ncap_support_size 75 | self.qcap_include = qcap_include 76 | self.mean_neighbors = mean_neighbors 77 | # create the encoders 78 | # num_classes is the output fc dimension 79 | self.encoder_q = models.__dict__[arch](pretrained=pretrained_on_imagenet, progress= pretrained_on_imagenet,num_classes=dim, 80 | return_inter=True) 81 | self.encoder_k = models.__dict__[arch](pretrained=pretrained_on_imagenet, progress= pretrained_on_imagenet,num_classes=dim, 82 | return_inter=sep_head) 83 | 84 | ## return inter is True because they use a seperate MLP Head 85 | 86 | # sep head 87 | if self.sep_head: 88 | dim_mlp = self.encoder_q.fc.weight.shape[1] 89 | self.q_cycle_fc = nn.Linear(dim_mlp, dim) 90 | self.k_cycle_fc = nn.Linear(dim_mlp, dim) 91 | 92 | if mlp: # hack: brute-force replacement 93 | dim_mlp = self.encoder_q.fc.weight.shape[1] 94 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 95 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 96 | 97 | if self.sep_head: 98 | dim_mlp = self.q_cycle_fc.weight.shape[1] 99 | self.q_cycle_fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.q_cycle_fc) 100 | self.k_cycle_fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.k_cycle_fc) 101 | 102 | ## This code above defines the head/fc layer for us 103 | 104 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 105 | param_k.data.copy_(param_q.data) # initialize 106 | param_k.requires_grad = False # not update by gradient 107 | 108 | if self.sep_head: 109 | for param_q, param_k in 
zip(self.q_cycle_fc.parameters(), self.k_cycle_fc.parameters()): 110 | param_k.data.copy_(param_q.data) # initialize 111 | param_k.requires_grad = False # not update by gradient 112 | # create the queue 113 | self.register_buffer("queue", torch.randn(dim, K)) ## initialize the queue by random K vectors 114 | self.queue = nn.functional.normalize(self.queue, dim=0) 115 | 116 | if self.sep_head: 117 | self.register_buffer("queue_cycle", torch.randn(dim, self.cycle_K)) 118 | self.queue_cycle = nn.functional.normalize(self.queue_cycle, dim=0) 119 | self.register_buffer("queue_cycle_ptr", torch.zeros(1, dtype=torch.long)) 120 | 121 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 122 | 123 | self.register_buffer("queue_labels", torch.zeros(1, K, dtype=torch.long)) 124 | 125 | self.register_buffer("queue_indices", torch.zeros(1, K, dtype=torch.long)) 126 | 127 | ######### 128 | self.register_buffer("neg_vecs_labels",torch.zeros(1, self.cycle_K,dtype=torch.long)) 129 | self.register_buffer("label_ptr", torch.zeros(1, dtype=torch.long)) 130 | ######### 131 | 132 | self.register_buffer("negatives",torch.zeros(dim, self.neg_queue_size)) 133 | self.register_buffer("negative_queue_ptr", torch.zeros(1, dtype=torch.long)) 134 | 135 | @torch.no_grad() 136 | def _init_encoder_k(self): 137 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 138 | param_k.data.copy_(param_q.data) # initialize 139 | 140 | @torch.no_grad() 141 | def _momentum_update_key_encoder(self): 142 | """ 143 | Momentum update of the key encoder 144 | """ 145 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 146 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 147 | 148 | if self.sep_head: 149 | for param_q, param_k in zip(self.q_cycle_fc.parameters(), self.k_cycle_fc.parameters()): 150 | param_k.data = param_k.data * self.m + param_q.data * (1. 
- self.m) 151 | 152 | @torch.no_grad() 153 | def _dequeue_and_enqueue(self, keys, keys_cycle=None, labels=None, indices=None, cluster=None,metadata_q_tsne=None): 154 | # gather keys before updating queue 155 | keys = concat_all_gather(keys) 156 | if keys_cycle is not None: 157 | keys_cycle = concat_all_gather(keys_cycle) 158 | if labels is not None: 159 | labels = concat_all_gather(labels) 160 | if indices is not None: 161 | indices = concat_all_gather(indices) 162 | if cluster is not None: 163 | cluster = concat_all_gather(cluster) 164 | 165 | batch_size = keys.shape[0] 166 | 167 | ptr = int(self.queue_ptr[0]) 168 | # if some gpu of the machine is dead, K might not able to be divided by batch size 169 | assert self.K % batch_size == 0, 'K {} can not be divided by batch size {}'.format(self.K, batch_size) # for simplicity 170 | 171 | if self.sep_head: 172 | cycle_ptr = int(self.queue_cycle_ptr[0]) 173 | # if some gpu of the machine is dead, K might not able to be divided by batch size 174 | assert self.cycle_K % batch_size == 0, \ 175 | 'K {} can not be divided by batch size {}'.format(self.cycle_K, 176 | batch_size) # for simplicity 177 | # replace the keys at ptr (dequeue and enqueue) 178 | try: 179 | self.queue[:, ptr:ptr + batch_size] = keys.T 180 | except Exception: 181 | print('enqueue size', ptr, batch_size, self.K) 182 | enqueue_size = min(self.K-ptr, batch_size) 183 | self.queue[:, ptr:ptr + batch_size] = keys[:enqueue_size].T 184 | try: 185 | if self.sep_head and keys_cycle is not None: 186 | self.queue_cycle[:, cycle_ptr:cycle_ptr + batch_size] = keys_cycle.T 187 | except Exception: 188 | print('enqueue size', ptr, keys_cycle.shape[0], self.K) 189 | enqueue_size = min(self.K - ptr, keys_cycle.shape[0]) 190 | self.queue_cycle[:, ptr:ptr + keys_cycle.shape[0]] = keys_cycle[:enqueue_size].T 191 | try: 192 | if labels is not None: 193 | self.queue_labels[:, ptr:ptr + batch_size] = labels 194 | if indices is not None: 195 | self.queue_indices[:, ptr:ptr + batch_size] = indices 196 | if cluster is not None: 197 | self.queue_cluster[:, ptr:ptr + batch_size] = cluster.T 198 | except Exception: 199 | enqueue_size = min(self.K-ptr, batch_size) 200 | if labels is not None: 201 | self.queue_labels[:, ptr:ptr + batch_size] = labels[:enqueue_size] 202 | if indices is not None: 203 | self.queue_indices[:, ptr:ptr + batch_size] = indices[:enqueue_size] 204 | if cluster is not None: 205 | self.queue_cluster[:, ptr:ptr + batch_size] = cluster[:enqueue_size].T 206 | 207 | 208 | ############################# 209 | label_ptr = self.label_ptr[0] 210 | try: 211 | if self.sep_head and metadata_q_tsne is not None: 212 | self.neg_vecs_labels[:,label_ptr:label_ptr + batch_size] = metadata_q_tsne.T 213 | except Exception: 214 | print('error') 215 | ############################## 216 | 217 | ptr = (ptr + batch_size) % self.K # move pointer 218 | assert ptr < self.K, 'ptr: {}, batch_size: {}, K: {}'.format(ptr, batch_size, self.K) 219 | 220 | self.queue_ptr[0] = ptr 221 | 222 | if self.sep_head: 223 | cycle_ptr = (cycle_ptr + batch_size) % self.cycle_K # move pointer 224 | assert cycle_ptr < self.cycle_K, 'cycle ptr: {}, batch_size: {}, cycle K: {}'.format( 225 | cycle_ptr, batch_size, self.cycle_K) 226 | 227 | self.queue_cycle_ptr[0] = cycle_ptr 228 | 229 | ################################ 230 | if self.sep_head: 231 | label_ptr = (label_ptr + batch_size) % self.cycle_K # move pointer 232 | assert label_ptr < self.cycle_K, 'label ptr: {}, batch_size: {}, cycle K: {}'.format( 233 | cycle_ptr, batch_size, 
    @torch.no_grad()
    def _dequeue_and_enqueue_negs(self, negatives=None):

        # gather keys before updating queue
        if negatives is not None:
            negatives = concat_all_gather(negatives)

        neg_queue_ptr = int(self.negative_queue_ptr[0])
        num_negs = negatives.shape[0]
        try:
            if negatives is not None:
                if (neg_queue_ptr + num_negs) <= self.neg_queue_size:
                    self.negatives[:, neg_queue_ptr:neg_queue_ptr + num_negs] = negatives.T
                elif num_negs > self.neg_queue_size:
                    # more negatives than the queue can hold: keep the first neg_queue_size of them
                    num_negs = self.neg_queue_size
                    self.negatives[:, :] = negatives.T[:, 0:num_negs]
                else:
                    # wrap around the end of the ring buffer
                    diff = (neg_queue_ptr + num_negs - self.neg_queue_size)
                    fit = self.neg_queue_size - neg_queue_ptr
                    self.negatives[:, neg_queue_ptr:neg_queue_ptr + fit] = negatives.T[:, 0:fit]
                    self.negatives[:, 0:diff] = negatives.T[:, fit:fit + diff]

        except Exception:
            print('error while enqueuing negatives')

        if negatives is not None:
            neg_queue_ptr = (neg_queue_ptr + num_negs) % self.neg_queue_size  # move pointer
            assert neg_queue_ptr < self.neg_queue_size, 'Neg ptr: {}, Num Negs: {}, Neg Queue Size: {}'.format(
                neg_queue_ptr, num_negs, self.neg_queue_size)

            self.negative_queue_ptr[0] = neg_queue_ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle
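
    # --- Illustrative note (added; not part of the original code) ---
    # _batch_shuffle_ddp and _batch_unshuffle_ddp (below) are used as a pair around the
    # key encoder so that BatchNorm statistics are computed on a globally shuffled batch
    # ("shuffle BN", as in MoCo), which prevents information leakage through per-GPU
    # batch statistics. Conceptually, in the single-process case:
    #     idx_shuffle   = torch.randperm(n)
    #     idx_unshuffle = torch.argsort(idx_shuffle)
    #     assert torch.equal(x[idx_shuffle][idx_unshuffle], x)
    # i.e. argsort of the shuffle permutation restores the original ordering.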
    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k=None,
                cls_labels=None, indices=None,
                cls_candidates=None, im_negs=None, tsne=False, only_qcap=False
                ):

        # im_negs arrives with a per-sample negative axis, e.g. 2x2x3x224x224;
        # flatten it into a plain batch of negatives, e.g. 4x3x224x224
        if im_negs is not None:
            im_negs = im_negs.flatten(0, 1)

        if self.negative_use is not None:
            assert im_negs is not None, "No negatives found even when negative flag activated"
            unnormalize_negs = self.encoder_q(im_negs)  # queries: NxC

            unnormalize_negs, negs_avgpool = unnormalize_negs
            if self.sep_head:
                negs_cycle = self.q_cycle_fc(negs_avgpool)
                negs_cycle = nn.functional.normalize(negs_cycle, dim=1)
            negs = nn.functional.normalize(unnormalize_negs, dim=1)
            negs_curr_gpu = negs.clone()

            self._dequeue_and_enqueue_negs(negs)

        # compute query features
        unnormalize_q = self.encoder_q(im_q)  # queries: NxC

        unnormalize_q, q_avgpool = unnormalize_q
        if self.sep_head:
            q_cycle = self.q_cycle_fc(q_avgpool)
            q_cycle = nn.functional.normalize(q_cycle, dim=1)
        q = nn.functional.normalize(unnormalize_q, dim=1)
        # print(q.shape)
        meta = {}

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            if (self.cycle_back_cls or self.moco_random_video_frame_as_pos) and cls_candidates is not None:
                n, k_can = cls_candidates.shape[:2]
                concat_im = torch.cat((im_k, cls_candidates.flatten(0, 1)), dim=0)

                im_k, idx_unshuffle = self._batch_shuffle_ddp(concat_im)
                k = self.encoder_k(im_k)  # keys: NxC

                if self.sep_head:
                    k, k_avgpool = k
                    k_cycle = self.k_cycle_fc(k_avgpool)
                    k_cycle = nn.functional.normalize(k_cycle, dim=1)
                    k_cycle = self._batch_unshuffle_ddp(k_cycle, idx_unshuffle)
                    k_cycle, can_cycle = torch.split(k_cycle, [n, n * k_can], dim=0)
                    can_cycle = can_cycle.view(n, k_can, -1)
                k = nn.functional.normalize(k, dim=1)

                # undo shuffle
                k = self._batch_unshuffle_ddp(k, idx_unshuffle)
                k, can = torch.split(k, [n, n * k_can], dim=0)
                can = can.view(n, k_can, -1)

            else:
                im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
                k = self.encoder_k(im_k)  # keys: NxC
                if self.sep_head:
                    k, k_avgpool = k
                    k_cycle = self.k_cycle_fc(k_avgpool)
                    k_cycle = nn.functional.normalize(k_cycle, dim=1)
                    k_cycle = self._batch_unshuffle_ddp(k_cycle, idx_unshuffle)
                k = nn.functional.normalize(k, dim=1)
                # undo shuffle
                k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos =
torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 406 | # negative logits: NxK 407 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 408 | 409 | # logits: Nx(1+K) 410 | logits = torch.cat([l_pos, l_neg], dim=1) 411 | 412 | # apply temperature 413 | logits /= self.T 414 | 415 | if self.sep_head: 416 | l_pos_cycle = torch.einsum('nc,nc->n', [q_cycle, k_cycle]).unsqueeze(-1) 417 | # negative logits: NxK 418 | l_neg_cycle = torch.einsum('nc,ck->nk', [q_cycle, self.queue_cycle.clone().detach()]) 419 | else: 420 | l_pos_cycle = l_pos 421 | l_neg_cycle = l_neg 422 | 423 | 424 | if self.soft_nn: 425 | if self.cycle_neg_only: 426 | """ In case we want topk based sampling""" 427 | if self.soft_nn_topk_support: 428 | nn_embs, perm_indices,sampled_indices_tsne,wts_tsne = \ 429 | self.cycle_back_topk_queue_without_self(l_neg_cycle, self.queue_cycle \ 430 | if self.sep_head else self.queue,can.device) 431 | else: 432 | nn_embs, perm_indices,sampled_indices_tsne,wts_tsne = \ 433 | self.cycle_back_queue_without_self(l_neg_cycle, self.queue_cycle \ 434 | if self.sep_head else self.queue,can.device) 435 | else: 436 | pos_neigbor = q_cycle if self.sep_head else q 437 | nn_embs, perm_indices = \ 438 | self.cycle_back_queue_with_self(l_pos_cycle, l_neg_cycle, 439 | pos_neigbor, 440 | self.queue_cycle if self.sep_head else self.queue, 441 | meta=meta, 442 | q=q, k=k 443 | ) 444 | 445 | nn_embs = F.normalize(nn_embs, dim=1) 446 | 447 | 448 | 449 | sampled_indices_logits = sampled_indices_tsne.clone() 450 | 451 | # labels: positive key indicators 452 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 453 | 454 | if self.cycle_back_cls: 455 | assert cls_candidates is not None, "cls candidates can not be None" 456 | """ In case we want topk based sampling""" 457 | if self.negative_use: 458 | if not self.cross_neg_topk_mining: 459 | back_logits, cycle_back_labels = \ 460 | self.get_logits_cycle_back_cls_video_as_pos_negs(nn_embs, 461 | can_cycle if self.sep_head else can, 462 | self.negatives) 463 | else: 464 | if not self.anchor_reverse_cross: 465 | back_logits, cycle_back_labels = \ 466 | self.get_logits_cycle_back_cls_video_as_pos_topk_negs(nn_embs, 467 | can_cycle if self.sep_head else can, 468 | self.negatives) 469 | else: 470 | back_logits, cycle_back_labels = \ 471 | self.get_logits_cycle_back_cls_video_as_pos_topk_negs_anchor_rev(nn_embs, 472 | can_cycle if self.sep_head else can, 473 | self.negatives) 474 | else: 475 | if self.soft_nn_topk_support: 476 | back_logits, cycle_back_labels = \ 477 | self.get_logits_topk_cycle_back_cls_video_as_pos(nn_embs, 478 | can_cycle if self.sep_head else can, 479 | self.queue_cycle if self.sep_head else self.queue, 480 | sampled_indices_logits, 481 | indices=indices) 482 | else: 483 | back_logits, cycle_back_labels = \ 484 | self.get_logits_cycle_back_cls_video_as_pos(nn_embs, 485 | can_cycle if self.sep_head else can, 486 | self.queue_cycle if self.sep_head else self.queue, 487 | perm_indices, 488 | indices=indices) 489 | 490 | # random sample one frame from the same video of im_q / im_k as positive, i.e. 
intra-video objective 491 | if self.moco_random_video_frame_as_pos: 492 | if self.negative_use: 493 | if not self.intranegs_only_two: 494 | """ Using N complete """ 495 | # can: n x k x 128 496 | pos_indices = torch.randint(high=can.shape[1], size=(q.shape[0],), device=can.device) 497 | pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can.shape[2]) 498 | 499 | logits_neg = torch.matmul(q, self.negatives.clone().detach()) 500 | logits_pos = torch.einsum('nc, nc->n', [q, torch.gather(can, 1, pos_indices).squeeze()]).unsqueeze(-1) 501 | logits = torch.cat([logits_pos, logits_neg], dim=1) 502 | logits /= self.T 503 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 504 | 505 | if self.intranegs_only_two and (not self.qcap_include): 506 | """ Using n1,n2 .... """ 507 | pos_indices = torch.randint(high=can.shape[1], size=(q.shape[0],), device=can.device) 508 | pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can.shape[2]) 509 | 510 | logits_neg = None 511 | num_samples_curr_gpu = q.shape[0] 512 | num_negs_curr_gpu = negs_curr_gpu.shape[0] 513 | negs_per_q = int(num_negs_curr_gpu/num_samples_curr_gpu) 514 | 515 | for loop_variable in range(num_samples_curr_gpu): 516 | q_sample = q[loop_variable].unsqueeze(0) ## 1,128 517 | 518 | if loop_variable ==0: 519 | negatives_to_be_used = negs_curr_gpu[:loop_variable+negs_per_q].T ## 128x2 520 | logits_neg = torch.matmul(q_sample,negatives_to_be_used) 521 | 522 | else: 523 | negatives_to_be_used = negs_curr_gpu[(loop_variable*negs_per_q):(loop_variable*negs_per_q)+negs_per_q].T ## 128x2 524 | logits_curr_sample = torch.matmul(q_sample,negatives_to_be_used) 525 | logits_neg = torch.cat((logits_neg,logits_curr_sample),axis=0) ## (loop_variable+1)x 2 526 | 527 | logits_pos = torch.einsum('nc, nc->n', [q, torch.gather(can, 1, pos_indices).squeeze()]).unsqueeze(-1) ## 2,128 x 2,128 kind of multiplication 528 | logits = torch.cat([logits_pos, logits_neg], dim=1) 529 | logits /= self.T 530 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 531 | 532 | 533 | if self.single_loss_intra_inter and self.intranegs_only_two and self.qcap_include and only_qcap: 534 | """ Using only n_cap""" 535 | 536 | pos_indices = torch.randint(high=can.shape[1], size=(q.shape[0],), device=can.device) 537 | pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can.shape[2]) 538 | N_set = self.negatives.detach().clone() 539 | 540 | logits_neg = None 541 | num_samples_curr_gpu = q.shape[0] 542 | num_negs_curr_gpu = negs_curr_gpu.shape[0] 543 | negs_per_q = int(num_negs_curr_gpu/num_samples_curr_gpu) 544 | 545 | for loop_variable in range(num_samples_curr_gpu): 546 | q_sample = q[loop_variable].unsqueeze(0) ## 1,128 547 | 548 | if loop_variable ==0: 549 | negatives_to_be_used = negs_curr_gpu[:loop_variable+negs_per_q].T ## 128x2 550 | 551 | sim_wts = torch.matmul(q_sample,N_set) 552 | (wts_selected,index_used) = sim_wts.topk(self.single_loss_ncap_support_size,dim=1) 553 | wts_selected = wts_selected/self.soft_nn_T 554 | wts_selected = torch.nn.functional.softmax(wts_selected,dim=1) 555 | selected_negs = N_set[:,index_used[0]] 556 | n_cap = torch.matmul(wts_selected,selected_negs.T) 557 | if self.mean_neighbors: 558 | n_cap = torch.mean(N_set,1,True).T 559 | logits_neg = torch.matmul(q_sample,n_cap.T) 560 | 561 | else: 562 | negatives_to_be_used = negs_curr_gpu[(loop_variable*negs_per_q):(loop_variable*negs_per_q)+negs_per_q].T ## 128x2 563 | 564 | sim_wts = torch.matmul(q_sample,N_set) 565 | 
(wts_selected,index_used) = sim_wts.topk(self.single_loss_ncap_support_size,dim=1) 566 | wts_selected = wts_selected/self.soft_nn_T 567 | wts_selected = torch.nn.functional.softmax(wts_selected,dim=1) 568 | selected_negs = N_set[:,index_used[0]] 569 | n_cap = torch.matmul(wts_selected,selected_negs.T) 570 | if self.mean_neighbors: 571 | n_cap = torch.mean(N_set,1,True).T 572 | logits_curr_sample = torch.matmul(q_sample,n_cap.T) 573 | logits_neg = torch.cat((logits_neg,logits_curr_sample),axis=0) ## (loop_variable+1)x 2 574 | 575 | logits_pos = torch.einsum('nc, nc->n', [q, torch.gather(can, 1, pos_indices).squeeze()]).unsqueeze(-1) ## 2,128 x 2,128 kind of multiplication 576 | logits = torch.cat([logits_pos, logits_neg], dim=1) 577 | logits /= self.T 578 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 579 | 580 | 581 | 582 | if self.single_loss_intra_inter and self.intranegs_only_two and self.qcap_include and not(only_qcap): 583 | """ Using n1,n2.... with n_cap""" 584 | 585 | pos_indices = torch.randint(high=can.shape[1], size=(q.shape[0],), device=can.device) 586 | pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can.shape[2]) 587 | N_set = self.negatives.detach().clone() 588 | 589 | logits_neg = None 590 | num_samples_curr_gpu = q.shape[0] 591 | num_negs_curr_gpu = negs_curr_gpu.shape[0] 592 | negs_per_q = int(num_negs_curr_gpu/num_samples_curr_gpu) 593 | 594 | for loop_variable in range(num_samples_curr_gpu): 595 | q_sample = q[loop_variable].unsqueeze(0) ## 1,128 596 | 597 | if loop_variable ==0: 598 | negatives_to_be_used = negs_curr_gpu[:loop_variable+negs_per_q].T ## 128x2 599 | 600 | sim_wts = torch.matmul(q_sample,N_set) 601 | (wts_selected,index_used) = sim_wts.topk(self.single_loss_ncap_support_size,dim=1) 602 | wts_selected = wts_selected/self.soft_nn_T 603 | wts_selected = torch.nn.functional.softmax(wts_selected,dim=1) 604 | selected_negs = N_set[:,index_used[0]] 605 | n_cap = torch.matmul(wts_selected,selected_negs.T) 606 | if self.mean_neighbors: 607 | n_cap = torch.mean(N_set,1,True).T 608 | negatives_to_be_used = torch.cat((negatives_to_be_used,n_cap.T),axis=1) 609 | 610 | logits_neg = torch.matmul(q_sample,negatives_to_be_used) 611 | 612 | else: 613 | negatives_to_be_used = negs_curr_gpu[(loop_variable*negs_per_q):(loop_variable*negs_per_q)+negs_per_q].T ## 128x2 614 | 615 | sim_wts = torch.matmul(q_sample,N_set) 616 | (wts_selected,index_used) = sim_wts.topk(self.single_loss_ncap_support_size,dim=1) 617 | wts_selected = wts_selected/self.soft_nn_T 618 | wts_selected = torch.nn.functional.softmax(wts_selected,dim=1) 619 | selected_negs = N_set[:,index_used[0]] 620 | n_cap = torch.matmul(wts_selected,selected_negs.T) 621 | if self.mean_neighbors: 622 | n_cap = torch.mean(N_set,1,True).T 623 | negatives_to_be_used = torch.cat((negatives_to_be_used,n_cap.T),axis=1) 624 | 625 | logits_curr_sample = torch.matmul(q_sample,negatives_to_be_used) 626 | logits_neg = torch.cat((logits_neg,logits_curr_sample),axis=0) ## (loop_variable+1)x 2 627 | 628 | logits_pos = torch.einsum('nc, nc->n', [q, torch.gather(can, 1, pos_indices).squeeze()]).unsqueeze(-1) ## 2,128 x 2,128 kind of multiplication 629 | logits = torch.cat([logits_pos, logits_neg], dim=1) 630 | logits /= self.T 631 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 632 | 633 | else: 634 | pos_indices = torch.randint(high=can.shape[1], size=(q.shape[0],), device=can.device) 635 | pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can.shape[2]) 636 | 637 | 
                logits_neg = torch.matmul(q, self.queue_cycle.clone().detach())
                logits_pos = torch.einsum('nc, nc->n', [q, torch.gather(can, 1, pos_indices).squeeze()]).unsqueeze(-1)
                logits = torch.cat([logits_pos, logits_neg], dim=1)
                logits /= self.T
                labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        self._dequeue_and_enqueue(k,
                                  keys_cycle=k_cycle if self.sep_head else None,
                                  labels=cls_labels, indices=indices
                                  )

        if self.soft_nn:
            if not self.single_loss_intra_inter:
                return logits, labels, back_logits, cycle_back_labels, meta
            else:
                return logits, labels, meta
        else:
            return logits, labels, meta

    def cycle_back_queue_without_self(self, l_neg_cycle, queue_cycle, device):
        if self.soft_nn_support != -1:
            perm_indices = torch.randperm(self.cycle_K if self.sep_head else self.K)
            sampled_indices = perm_indices[:self.soft_nn_support]
            sampled_l_neg = l_neg_cycle[:, sampled_indices]
            weights = nn.functional.softmax(sampled_l_neg / self.soft_nn_T, dim=1)
            nn_embs = torch.matmul(weights, queue_cycle.clone().detach().transpose(1, 0)[sampled_indices])
        else:
            weights = nn.functional.softmax(l_neg_cycle / self.soft_nn_T, dim=1)
            nn_embs = torch.matmul(weights, queue_cycle.clone().detach().transpose(1, 0))
            perm_indices = None
            # when every queue entry is used there is no sub-sampling; keep the
            # book-keeping indices as the full range so the lines below do not fail
            sampled_indices = torch.arange(l_neg_cycle.shape[1])

        sampled_indices_ret = sampled_indices.to(device)
        sampled_indices_ret = torch.reshape(sampled_indices_ret, (1, sampled_indices_ret.shape[0]))
        sampled_indices_ret = torch.cat((sampled_indices_ret, sampled_indices_ret), 0)
        return nn_embs, perm_indices, sampled_indices_ret, weights

    def cycle_back_topk_queue_without_self(self, l_neg_cycle, queue_cycle, device):
        assert self.soft_nn_support != -1  # TODO: what to do when soft_nn_support == -1, i.e. all elements
        # are to be used for the soft nn calculation

        perm_indices = None

        ## Starting with all the vectors at first
        sampled_l_neg = l_neg_cycle
        weights = nn.functional.softmax(sampled_l_neg / self.soft_nn_T, dim=1)
        (weights, indices) = weights.topk(self.soft_nn_support, dim=1)

        weights = weights.unsqueeze(1)  ## Num_samples x 1 x self.soft_nn_support
        """ Getting the correct weights for each sample in the batch based on topk """
        num_samples = weights.shape[0]
        queue_cycle_wts = None

        for i in range(num_samples):
            if i == 0:
                queue_cycle_wts = queue_cycle.clone().detach().transpose(1, 0)[indices[i]].unsqueeze(0)
            else:
                temp_var = queue_cycle.clone().detach().transpose(1, 0)[indices[i]].unsqueeze(0)
                queue_cycle_wts = torch.cat((queue_cycle_wts, temp_var), 0)

        nn_embs = torch.bmm(weights, queue_cycle_wts).squeeze()
        weights = torch.squeeze(weights)

        """ Some non-important stuff for tsne plots / book keeping """
        sampled_indices_ret = indices.to(device)

        return nn_embs, perm_indices, sampled_indices_ret, weights
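
    # --- Illustrative sketch (added; not part of the original implementation) ---
    # The per-sample loop in cycle_back_topk_queue_without_self gathers, for every query,
    # the top-k queue columns and forms their softmax-weighted average (the soft nearest
    # neighbour). A loop-free equivalent under the same assumptions (queue of shape
    # dim x K, cycle logits of shape N x K); this helper is hypothetical and unused:
    def _soft_nn_topk_sketch(self, l_neg_cycle, queue_cycle):
        weights = nn.functional.softmax(l_neg_cycle / self.soft_nn_T, dim=1)
        weights, indices = weights.topk(self.soft_nn_support, dim=1)   # N x k
        queue_t = queue_cycle.clone().detach().transpose(1, 0)         # K x dim
        gathered = queue_t[indices]                                    # N x k x dim
        return torch.bmm(weights.unsqueeze(1), gathered).squeeze(1)    # N x dim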
    def cycle_back_queue_with_self(self, l_pos_cycle, l_neg_cycle, q_cycle, queue_cycle, meta=None, q=None, k=None):
        if self.soft_nn_support != -1:
            perm_indices = torch.randperm(self.cycle_K if self.sep_head else self.K)
            sampled_indices = perm_indices[:self.soft_nn_support]
            sampled_l_neg = l_neg_cycle[:, sampled_indices]
            logits_cycle = torch.cat([l_pos_cycle, sampled_l_neg], dim=1)
        else:
            logits_cycle = torch.cat([l_pos_cycle, l_neg_cycle], dim=1)
            perm_indices = None

        weights = nn.functional.softmax(logits_cycle / self.soft_nn_T, dim=1)
        num_neg = self.soft_nn_support if self.soft_nn_support != -1 \
            else (self.cycle_K if self.sep_head else self.K)
        weights_pos, weights_neg = torch.split(weights, [1, num_neg], dim=1)
        nn_embs_pos = weights_pos * q_cycle
        if self.soft_nn_support != -1:
            nn_embs_neg = torch.matmul(weights_neg, queue_cycle.clone().detach().transpose(1, 0)[sampled_indices])
        else:
            nn_embs_neg = torch.matmul(weights_neg, queue_cycle.clone().detach().transpose(1, 0))
        nn_embs = nn_embs_pos + nn_embs_neg
        return nn_embs, perm_indices

    def get_logits_cycle_back_cls_video_as_pos(self, nn_embs, can_cycle, queue_cycle, perm_indices, indices=None):
        if self.soft_nn_support == -1:
            back_logits_neg = torch.matmul(nn_embs, queue_cycle.clone().detach())
        else:
            remain_indices = perm_indices[self.soft_nn_support:]
            back_logits_neg = torch.matmul(nn_embs, queue_cycle.clone().detach()[:, remain_indices])
        pos_indices = torch.randint(high=can_cycle.shape[1], size=(nn_embs.shape[0],), device=can_cycle.device)
        pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can_cycle.shape[2])
        back_logits_pos = torch.einsum('nc, nc->n',
                                       [nn_embs, torch.gather(can_cycle, 1, pos_indices).squeeze()]).unsqueeze(-1)

        back_logits = torch.cat([back_logits_pos, back_logits_neg], dim=1)
        back_logits /= self.T

        cycle_back_labels = torch.zeros(back_logits.shape[0], dtype=torch.long).cuda()
        return back_logits, cycle_back_labels

    def get_logits_cycle_back_cls_video_as_pos_negs(self, nn_embs, can_cycle, negatives):

        back_logits_neg = torch.matmul(nn_embs, negatives.clone().detach())
        pos_indices = torch.randint(high=can_cycle.shape[1], size=(nn_embs.shape[0],), device=can_cycle.device)
        pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can_cycle.shape[2])
        back_logits_pos = torch.einsum('nc, nc->n',
                                       [nn_embs, torch.gather(can_cycle, 1, pos_indices).squeeze()]).unsqueeze(-1)

        back_logits = torch.cat([back_logits_pos, back_logits_neg], dim=1)
        back_logits /= self.T

        cycle_back_labels = torch.zeros(back_logits.shape[0], dtype=torch.long).cuda()
        return back_logits, cycle_back_labels
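
    # --- Illustrative note (added; not part of the original code) ---
    # Every get_logits_cycle_back_* variant returns a logit matrix whose column 0 is the
    # positive (a randomly drawn frame from the same video) and whose remaining columns
    # are negatives, so the matching label is always index 0. A sketch of how such
    # logits are typically consumed in the training loop (assumed usage, not taken from
    # this repository):
    #     criterion = nn.CrossEntropyLoss().cuda()
    #     back_logits, back_labels = model.get_logits_cycle_back_cls_video_as_pos_negs(
    #         nn_embs, can_cycle, negatives)
    #     loss = loss + soft_nn_loss_weight * criterion(back_logits, back_labels)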
    def get_logits_cycle_back_cls_video_as_pos_topk_negs(self, nn_embs, can_cycle, negatives):

        similarity_scores = torch.matmul(nn_embs, negatives.clone().detach())  ## 2x32
        (wts, indices_hard_negs) = similarity_scores.topk(self.cross_neg_topk_support_size, dim=1)

        # negatives_copy = negatives.clone().T  # 128,32
        topk_negatives = None
        num_samples = nn_embs.shape[0]

        for itr in range(num_samples):
            if itr == 0:
                topk_negatives = negatives[:, indices_hard_negs[itr]].unsqueeze(0)
            else:
                temp_element = negatives[:, indices_hard_negs[itr]].unsqueeze(0)
                topk_negatives = torch.cat((topk_negatives, temp_element), axis=0)

        nn_embs_reshaped = nn_embs.unsqueeze(1)
        back_logits_neg = torch.bmm(nn_embs_reshaped, topk_negatives)
        back_logits_neg = back_logits_neg.squeeze()

        pos_indices = torch.randint(high=can_cycle.shape[1], size=(nn_embs.shape[0],), device=can_cycle.device)
        pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can_cycle.shape[2])
        back_logits_pos = torch.einsum('nc, nc->n',
                                       [nn_embs, torch.gather(can_cycle, 1, pos_indices).squeeze()]).unsqueeze(-1)

        back_logits = torch.cat([back_logits_pos, back_logits_neg], dim=1)
        back_logits /= self.T

        cycle_back_labels = torch.zeros(back_logits.shape[0], dtype=torch.long).cuda()
        return back_logits, cycle_back_labels

    def get_logits_cycle_back_cls_video_as_pos_topk_negs_anchor_rev(self, nn_embs, can_cycle, negatives):

        # print(can_cycle.shape)
        can_cycle_sim = can_cycle.clone().squeeze()

        similarity_scores = torch.matmul(can_cycle_sim, negatives.clone().detach())  ## 2x32
        (wts, indices_hard_negs) = similarity_scores.topk(self.cross_neg_topk_support_size, dim=1)

        # negatives_copy = negatives.clone().T  # 128,32
        topk_negatives = None
        num_samples = nn_embs.shape[0]

        for itr in range(num_samples):
            if itr == 0:
                topk_negatives = negatives[:, indices_hard_negs[itr]].unsqueeze(0)
            else:
                temp_element = negatives[:, indices_hard_negs[itr]].unsqueeze(0)
                topk_negatives = torch.cat((topk_negatives, temp_element), axis=0)

        can_copy = can_cycle.clone()
        back_logits_neg = torch.bmm(can_copy, topk_negatives)
        back_logits_neg = back_logits_neg.squeeze()

        pos_indices = torch.randint(high=can_cycle.shape[1], size=(nn_embs.shape[0],), device=can_cycle.device)
        pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can_cycle.shape[2])
        back_logits_pos = torch.einsum('nc, nc->n',
                                       [nn_embs, torch.gather(can_cycle, 1, pos_indices).squeeze()]).unsqueeze(-1)

        back_logits = torch.cat([back_logits_pos, back_logits_neg], dim=1)
        back_logits /= self.T

        cycle_back_labels = torch.zeros(back_logits.shape[0], dtype=torch.long).cuda()
        return back_logits, cycle_back_labels

    def get_logits_topk_cycle_back_cls_video_as_pos(self, nn_embs, can_cycle, queue_cycle, sampled_indices, indices=None):
        assert self.soft_nn_support != -1  # TODO: what to do when soft_nn_support == -1, i.e. all elements
        # are to be used for the soft nn calculation

        num_samples = nn_embs.shape[0]

        """ Getting the leftover sample indices from the queue """
        remain_indices = None

        for i in range(num_samples):
            used_indices = sampled_indices[i]
            use_indices_np = used_indices.cpu().detach().numpy()
            total_possible_indices = np.arange(0, queue_cycle.shape[1], 1)
            remain_indices_temp = [idx for idx in total_possible_indices if idx not in use_indices_np]

            remain_indices_temp = torch.Tensor(remain_indices_temp).unsqueeze(0)
            remain_indices_temp = remain_indices_temp.long()

            if i == 0:
                remain_indices = remain_indices_temp
            else:
                remain_indices = torch.cat((remain_indices, remain_indices_temp), 0)

        """ Getting the leftover sample embeddings based on the indices generated above """
        queue_cycle_wts = None
        for i in range(num_samples):
            remain_index_sample = remain_indices[i]

            if i == 0:
                queue_cycle_wts = queue_cycle.clone().detach()[:, remain_index_sample].unsqueeze(0)
            else:
                temp_var = queue_cycle.clone().detach()[:, remain_index_sample].unsqueeze(0)
                queue_cycle_wts = torch.cat((queue_cycle_wts, temp_var), 0)

        nn_embs = nn_embs.unsqueeze(1)

        back_logits_neg = torch.bmm(nn_embs, queue_cycle_wts)

        nn_embs = nn_embs.squeeze()
        back_logits_neg = back_logits_neg.squeeze()

        pos_indices =
torch.randint(high=can_cycle.shape[1], size=(nn_embs.shape[0],), device=can_cycle.device) 869 | pos_indices = pos_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, can_cycle.shape[2]) 870 | back_logits_pos = torch.einsum('nc, nc->n', 871 | [nn_embs, torch.gather(can_cycle, 1, pos_indices).squeeze()]).unsqueeze(-1) 872 | 873 | back_logits = torch.cat([back_logits_pos, back_logits_neg], dim=1) 874 | back_logits /= self.T 875 | 876 | cycle_back_labels = torch.zeros(back_logits.shape[0], dtype=torch.long).cuda() 877 | return back_logits, cycle_back_labels 878 | 879 | 880 | # utils 881 | @torch.no_grad() 882 | def concat_all_gather(tensor): 883 | """ 884 | Performs all_gather operation on the provided tensors. 885 | *** Warning ***: torch.distributed.all_gather has no gradient. 886 | """ 887 | tensors_gather = [torch.ones_like(tensor) 888 | for _ in range(torch.distributed.get_world_size())] 889 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 890 | 891 | output = torch.cat(tensors_gather, dim=0) 892 | return output --------------------------------------------------------------------------------