├── README.md ├── TCE.yaml ├── config ├── __init__.py ├── default.py └── pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml ├── datasets ├── finetuning │ ├── UCF101.py │ └── __init__.py └── pretraining │ ├── __init__.py │ ├── augmentation.py │ └── kinetics400.py ├── finetune.py ├── images ├── TCE.png └── bowling_tsne_example.gif ├── loss ├── NCEAverage.py ├── NCECriterion.py └── __init__.py ├── network ├── __init__.py ├── model.py └── resnet.py ├── pretrain.py ├── utils ├── __init__.py ├── dataset_utils.py ├── eval_utils.py ├── log_utils.py ├── network_utils.py ├── optimizer.py └── tsne_utils.py └── visualise_tsne.py /README.md: -------------------------------------------------------------------------------- 1 | # Temporally Coherent Embeddings for Self-Supervised Video Representation Learning 2 | This repository contains the code implementation used in the ICPR 2020 paper Temporally Coherent Embeddings for Self-Supervised Video Representation Learning (TCE). \[[arXiv](https://arxiv.org/abs/2004.02753)] \[[Website](https://csiro-robotics.github.io/TCE-Webpage/)] Our contributions in this repository are: 3 | - A PyTorch implementation of the self-supervised training used in the TCE paper 4 | - A PyTorch implementation of action recognition fine-tuning 5 | - Pre-trained checkpoints for models trained using the TCE self-supervised training paradigm 6 | - A PyTorch implementation of t-SNE visualisations of the network output 7 | 8 | ![Network Architecture](images/TCE.png) 9 | 10 | We benchmark our code on Split 1 of the UCF101 action recognition dataset, providing pre-trained models for our downstream and upstream training. See [Models](#models) for our provided models and [Getting Started](#getting-started) for instructions on training and evaluation. 11 | 12 | If you find this repo useful for your research, please consider citing the paper: 13 | ``` 14 | @inproceedings{knights2020tce, 15 | title={Temporally Coherent Embeddings for Self-Supervised Video Representation Learning}, 16 | author={Joshua Knights and Ben Harwood and Daniel Ward and Anthony Vanderkop and Olivia Mackenzie-Ross and Peyman Moghadam}, 17 | booktitle={25th International Conference on Pattern Recognition (ICPR)}, 18 | year={2020} 19 | } 20 | 21 | ``` 22 | 23 | 24 | ## Updates 25 | - 23/04/2020 : Initial Commit 26 | - 30/11/2020 : ICPR Update 27 | 28 | ## Table of Contents 29 | 30 | - [Data Preparation](#data-preparation) 31 | - [Installation](#installation) 32 | - [Models](#models) 33 | - [Getting Started](#getting-started) 34 | - [Acknowledgements](#acknowledgements) 35 | 36 | 37 | ## Data Preparation 38 | 39 | 40 | ### Kinetics400 41 | Kinetics400 videos can be downloaded and split into frames directly from [Showmax/kinetics-downloader](https://github.com/Showmax/kinetics-downloader) 42 | 43 | The file directory should have the following layout: 44 | ``` 45 | ├── kinetics400/train 46 | | 47 | ├── CLASS_001 48 | ├── CLASS_002 49 | . 50 | . 51 | . 52 | CLASS_400 53 | | 54 | ├── VID_001 55 | ├── VID_002 56 | . 57 | . 58 | . 59 | ├── VID_### 60 | | 61 | ├── frame1.jpg 62 | ├── frame2.jpg 63 | . 64 | . 65 | . 
66 | ├── frame###.jpg 67 | ``` 68 | Once the dataset is downloaded and split into frames, edit the following parameter in config/default.py to point towards the frames: 69 | - DATASET.KINETICS400.FRAMES_PATH = /path/to/kinetics400/train 70 | 71 | ### UCF101 72 | 73 | UCF101 frames and splits can be downloaded directly from [feichtenhofer/twostreamfusion](https://github.com/feichtenhofer/twostreamfusion) 74 | 75 | ``` 76 | wget http://ftp.tugraz.at/pub/feichtenhofer/tsfusion/data/ucf101_jpegs_256.zip.001 77 | wget http://ftp.tugraz.at/pub/feichtenhofer/tsfusion/data/ucf101_jpegs_256.zip.002 78 | wget http://ftp.tugraz.at/pub/feichtenhofer/tsfusion/data/ucf101_jpegs_256.zip.003 79 | 80 | cat ucf101_jpegs_256.zip* > ucf101_jpegs_256.zip 81 | unzip ucf101_jpegs_256.zip 82 | ``` 83 | The file directory should have the following layout: 84 | 85 | ``` 86 | ├── UCF101 87 | | 88 | ├── v_{_CLASS_001}_g01_c01 89 | . | 90 | . ├── frame000001.jpg 91 | . ├── frame000002.jpg 92 | . . 93 | . . 94 | . ├── frame000###.jpg 95 | . 96 | ├── v_{_CLASS_101}_g##_c## 97 | | 98 | ├── frame000001.jpg 99 | ├── frame000002.jpg 100 | . 101 | . 102 | ├── frame000###.jpg 103 | ``` 104 | 105 | Once the dataset is downloaded and decompressed, edit the following parameters in config/default.py to point towards the frames and splits: 106 | - DATASET.UCF101.FRAMES_PATH = /path/to/UCF101_frames 107 | - DATASET.UCF101.SPLITS_PATH = /path/to/UCF101_splits 108 | 109 | 110 | 111 | 112 | 113 | ## Installation 114 | 115 | 116 | TCE is built using Python == 3.7.1 and PyTorch == 1.7.0. 117 | 118 | We use Conda to set up the Python environment for this repository. In order to create the environment, run the following commands from the root directory: 119 | 120 | ``` 121 | conda env create -f TCE.yaml 122 | conda activate TCE 123 | ``` 124 | 125 | Once this is done, also specify a path in config/default.py where assets (such as dataset pickles for faster setup) will be saved: 126 | - ASSETS_PATH = /path/to/assets/folder 127 | 128 | 129 | 130 | ## Models 131 | 132 | 133 | | Architecture | Pre-Training Dataset | Link | 134 | |-------------- |---------------------- |---------------------------------------------------------------- | 135 | | ResNet-18 | Kinetics400 | [Link](https://cloudstor.aarnet.edu.au/plus/s/kNQKw5ATTbyamg2) | 136 | | ResNet-50 | Kinetics400 | [Link](https://cloudstor.aarnet.edu.au/plus/s/HbWxmhcUbfzQIQf) | 137 | 138 | ## Getting Started 139 | 140 | ### Self-Supervised Training 141 | We provide a script for pre-training with the Kinetics400 dataset using TCE, pretrain.py. To train, run the following command: 142 | 143 | ``` 144 | python pretrain.py \ 145 | --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \ 146 | TRAIN.PRETRAINING.SAVEDIR /path/to/savedir 147 | ``` 148 | 149 | If resuming from a previous pre-training checkpoint, set the flag `TRAIN.PRETRAINING.RESUME` to the path of the checkpoint to resume from. 150 | 151 | ### Fine-tuning for action recognition 152 | We provide a fine-tuning script for action recognition on the UCF-101 dataset, finetune.py. 
To train, run the following command: 153 | 154 | ``` 155 | python finetune.py \ 156 | --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \ 157 | TRAIN.FINETUNING.CHECKPOINT "/path/to/pretrained_checkpoint" \ 158 | TRAIN.FINETUNING.SAVEDIR "/path/to/savedir" 159 | ``` 160 | 161 | If resuming training from an earlier fine-tuning checkpoint, set the flag `TRAIN.FINETUNING.RESUME` to True. 162 | 163 | 164 | 165 | 166 | ### Visualisation 167 | 168 | ![vid](images/bowling_tsne_example.gif) 169 | 170 | To demonstrate the ability of our approach to create temporally coherent embeddings, we provide a package for creating t-SNE visualisations of our features similar to those found in the paper. This package can also be applied to other approaches and network architectures; a minimal standalone sketch of the underlying dimensionality-reduction step is included at the end of this README. 171 | 172 | The files in this repository used for generating t-SNE visualisations are: 173 | - `visualise_tsne.py` is a wrapper for t-SNE and our network architecture for end-to-end generation of the t-SNE visualisation 174 | - `utils/tsne_utils.py` contains t-SNE functionality for reducing the dimensionality of an array of embedded features for plotting, as well as tools to create an animated visualisation of the embedding's behaviour over time 175 | 176 | The following flags can be used as inputs for `visualise_tsne.py`: 177 | - `--cfg` : Path to config file 178 | - `--target` : Path to video to visualise t-SNE for. This video can either be a video file (avi, mp4) or a directory of images representing frames 179 | - `--ckpt` : Path to the model checkpoint to visualise the embedding space for 180 | - `--gif` : Use to visualise the change in the embedding space over time alongside the input video as a gif file 181 | - `--fps` : Set the framerate of the gif 182 | - `--save` : Path to save the output t-SNE to 183 | 184 | To visualise the embeddings from TCE, download our self-supervised model above and use the following command to visualise our embedding space as a gif: 185 | 186 | ``` 187 | python visualise_tsne.py \ 188 | --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \ 189 | --target "/path/to/target/video" \ 190 | --ckpt "/path/to/TCE_checkpoint" \ 191 | --gif \ 192 | --fps 25 \ 193 | --save "/path/to/save/folder/t-SNE.gif" 194 | ``` 195 | 196 | Alternatively, to visualise the t-SNE as a PNG image use the following: 197 | 198 | ``` 199 | python visualise_tsne.py \ 200 | --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \ 201 | --target "/path/to/target/video" \ 202 | --ckpt "/path/to/TCE_checkpoint" \ 203 | --save "/path/to/save/folder/t-SNE.png" 204 | ``` 205 | 206 | 207 | 208 | 209 | 210 | ## Acknowledgements 211 | Parts of this code base are derived from Yonglong Tian's unsupervised learning algorithm [Contrastive Multiview Coding](https://github.com/HobbitLong/CMC) and Jeffrey Huang's implementation of [action recognition](https://github.com/jeffreyyihuang/two-stream-action-recognition). 
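## Standalone t-SNE Example

The snippet below is a minimal, self-contained sketch of the dimensionality-reduction step described in [Visualisation](#visualisation); it is not part of this repository and does not reproduce `visualise_tsne.py`. It embeds the frames of a video with a generic ImageNet-pretrained torchvision ResNet-18 (a stand-in backbone, not a TCE checkpoint), reduces the per-frame features to 2-D with scikit-learn's t-SNE, and colours each point by its frame index. The frame directory and all parameter values are placeholder assumptions; to visualise TCE features themselves, use `visualise_tsne.py` with one of the checkpoints above.

```
# Standalone sketch only: the backbone, paths and parameters below are
# illustrative assumptions and do not reproduce visualise_tsne.py.
from glob import glob

import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from PIL import Image
from sklearn.manifold import TSNE
from torchvision.models import resnet18

FRAMES_DIR = "/path/to/target/video_frames"  # hypothetical directory of frames

# Any frame-level feature extractor works; an ImageNet ResNet-18 is used
# here purely as a stand-in for a TCE-trained trunk.
model = resnet18(pretrained=True)
model.fc = torch.nn.Identity()  # keep the 512-D pooled feature
model.eval()

preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Assumes zero-padded frame names (frame000001.jpg, ...) so that a plain
# lexicographic sort is also chronological.
frame_paths = sorted(glob(FRAMES_DIR + "/*.jpg"))

features = []
with torch.no_grad():
    for path in frame_paths:
        img = preprocess(Image.open(path).convert("RGB")).unsqueeze(0)
        features.append(model(img).squeeze(0).numpy())
features = np.stack(features)  # (num_frames, 512)

# Reduce to 2-D; perplexity must stay below the number of frames.
tsne = TSNE(n_components=2, perplexity=min(30, len(features) - 1), random_state=0)
embedded = tsne.fit_transform(features)  # (num_frames, 2)

# Colour points by frame index to inspect temporal coherence.
plt.scatter(embedded[:, 0], embedded[:, 1], c=np.arange(len(embedded)), cmap="viridis")
plt.colorbar(label="frame index")
plt.savefig("tsne_frames.png", dpi=150)
```

For a temporally coherent embedding, the resulting scatter should trace a smooth colour gradient from the first frame to the last rather than a random mixture of colours.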
212 | 213 | 214 | 215 | 216 | -------------------------------------------------------------------------------- /TCE.yaml: -------------------------------------------------------------------------------- 1 | name: TCE 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h7b6447c_0 10 | - ca-certificates=2020.10.14=0 11 | - cairo=1.14.12=h8948797_3 12 | - certifi=2020.11.8=py37h06a4308_0 13 | - cudatoolkit=10.2.89=hfd86e86_1 14 | - cycler=0.10.0=py37_0 15 | - dbus=1.13.18=hb2f20db_0 16 | - expat=2.2.10=he6710b0_2 17 | - ffmpeg=4.0=hcdf2ecd_0 18 | - fontconfig=2.13.0=h9420a91_0 19 | - freeglut=3.0.0=hf484d3e_5 20 | - freetype=2.10.4=h5ab3b9f_0 21 | - glib=2.63.1=h5a9c865_0 22 | - graphite2=1.3.14=h23475e2_0 23 | - gst-plugins-base=1.14.0=hbbd80ab_1 24 | - gstreamer=1.14.0=hb453b48_1 25 | - harfbuzz=1.8.8=hffaf4a1_0 26 | - hdf5=1.10.2=hba1933b_1 27 | - icu=58.2=he6710b0_3 28 | - imageio=2.9.0=py_0 29 | - intel-openmp=2020.2=254 30 | - jasper=2.0.14=h07fcdf6_1 31 | - joblib=0.17.0=py_0 32 | - jpeg=9b=h024ee3a_2 33 | - kiwisolver=1.3.0=py37h2531618_0 34 | - lcms2=2.11=h396b838_0 35 | - libedit=3.1.20191231=h14c3975_1 36 | - libffi=3.2.1=hf484d3e_1007 37 | - libgcc-ng=9.1.0=hdf63c60_0 38 | - libgfortran-ng=7.3.0=hdf63c60_0 39 | - libglu=9.0.0=hf484d3e_1 40 | - libopencv=3.4.2=hb342d67_1 41 | - libopus=1.3.1=h7b6447c_0 42 | - libpng=1.6.37=hbc83047_0 43 | - libprotobuf=3.13.0.1=h8b12597_0 44 | - libstdcxx-ng=9.1.0=hdf63c60_0 45 | - libtiff=4.1.0=h2733197_1 46 | - libuuid=1.0.3=h1bed415_2 47 | - libuv=1.40.0=h7b6447c_0 48 | - libvpx=1.7.0=h439df22_0 49 | - libxcb=1.14=h7b6447c_0 50 | - libxml2=2.9.10=hb55368b_3 51 | - lz4-c=1.9.2=heb0550a_3 52 | - matplotlib=3.3.2=0 53 | - matplotlib-base=3.3.2=py37h817c723_0 54 | - mkl=2020.2=256 55 | - mkl-service=2.3.0=py37he904b0f_0 56 | - mkl_fft=1.2.0=py37h23d657b_0 57 | - mkl_random=1.1.1=py37h0573a6f_0 58 | - ncurses=6.2=he6710b0_1 59 | - ninja=1.10.1=py37hfd86e86_0 60 | - numpy=1.19.2=py37h54aff64_0 61 | - numpy-base=1.19.2=py37hfa32c7d_0 62 | - olefile=0.46=py37_0 63 | - opencv=3.4.2=py37h6fd60c2_1 64 | - openssl=1.1.1h=h7b6447c_0 65 | - pcre=8.44=he6710b0_0 66 | - pillow=8.0.1=py37he98fc37_0 67 | - pip=20.2.4=py37h06a4308_0 68 | - pixman=0.40.0=h7b6447c_0 69 | - protobuf=3.13.0.1=py37h745909e_1 70 | - py-opencv=3.4.2=py37hb342d67_1 71 | - pyparsing=2.4.7=py_0 72 | - pyqt=5.9.2=py37h05f1152_2 73 | - python=3.7.1=h0371630_7 74 | - python-dateutil=2.8.1=py_0 75 | - python_abi=3.7=1_cp37m 76 | - pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0 77 | - pyyaml=5.3.1=py37hb5d75c8_1 78 | - qt=5.9.7=h5867ecd_1 79 | - readline=7.0=h7b6447c_5 80 | - scikit-learn=0.23.2=py37h0573a6f_0 81 | - scipy=1.5.2=py37h0b6359f_0 82 | - setuptools=50.3.1=py37h06a4308_1 83 | - sip=4.19.8=py37hf484d3e_0 84 | - six=1.15.0=py37h06a4308_0 85 | - sqlite=3.33.0=h62c20be_0 86 | - tensorboardx=2.1=py_0 87 | - termcolor=1.1.0=py37_1 88 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 89 | - tk=8.6.10=hbc83047_0 90 | - torchvision=0.8.1=py37_cu102 91 | - tornado=6.0.4=py37h7b6447c_1 92 | - tqdm=4.51.0=pyhd3eb1b0_0 93 | - typing_extensions=3.7.4.3=py_0 94 | - wheel=0.35.1=pyhd3eb1b0_0 95 | - xz=5.2.5=h7b6447c_0 96 | - yacs=0.1.6=py_0 97 | - yaml=0.2.5=h516909a_0 98 | - zlib=1.2.11=h7b6447c_3 99 | - zstd=1.4.5=h9ceee32_0 100 | prefix: /scratch1/kni101/miniconda3/envs/TCE 101 | 102 | -------------------------------------------------------------------------------- /config/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .default import _C as config 2 | from .default import update_config -------------------------------------------------------------------------------- /config/default.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from yacs.config import CfgNode as CN 4 | 5 | _C = CN() 6 | 7 | 8 | # ----------------------------------------------------------------------------- 9 | # Misc 10 | # ----------------------------------------------------------------------------- 11 | _C.WORKERS = 8 12 | _C.PRINT_FREQ = 20 13 | _C.ASSETS_PATH = '/scratch1/kni101/TCE/assets' 14 | 15 | # ----------------------------------------------------------------------------- 16 | # Dataset 17 | # ----------------------------------------------------------------------------- 18 | _C.DATASET = CN() 19 | 20 | _C.DATASET.KINETICS400 = CN() 21 | _C.DATASET.KINETICS400.FRAMES_PATH = '/datasets/work/d61-eif/source/kinetics400-frames' 22 | 23 | _C.DATASET.UCF101 = CN() 24 | _C.DATASET.UCF101.FRAMES_PATH = '/datasets/work/d61-eif/source/UCF-101-twostream/jpegs_256' 25 | _C.DATASET.UCF101.SPLITS_PATH = '/datasets/work/d61-eif/source/UCF-101-twostream/UCF_list' 26 | _C.DATASET.UCF101.SPLIT = 1 27 | 28 | _C.DATASET.PRETRAINING = CN() 29 | _C.DATASET.PRETRAINING.DATASET = 'kinetics400' 30 | _C.DATASET.PRETRAINING.MEAN = [0.485, 0.456, 0.406] 31 | _C.DATASET.PRETRAINING.STD = [0.229, 0.224, 0.225] 32 | 33 | _C.DATASET.FINETUNING = CN() 34 | _C.DATASET.FINETUNING.DATASET = 'UCF101' 35 | _C.DATASET.FINETUNING.MEAN = [0.485, 0.456, 0.406] 36 | _C.DATASET.FINETUNING.STD = [0.229, 0.224, 0.225] 37 | 38 | 39 | _C.DATASET.PRETRAINING.TRANSFORMATIONS = CN() 40 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.CROP_SIZE = (224,224) 41 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.SCALE_SIZE = (224,224) 42 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.HORIZONTAL_FLIP = True 43 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.RANDOM_GREY = True 44 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.COLOUR_JITTER = True 45 | 46 | _C.DATASET.FINETUNING.TRANSFORMATIONS = CN() 47 | _C.DATASET.FINETUNING.TRANSFORMATIONS.RESIZE = (224,224) 48 | _C.DATASET.FINETUNING.TRANSFORMATIONS.CROP_SIZE = (224,224) 49 | 50 | 51 | # ----------------------------------------------------------------------------- 52 | # Model 53 | # ----------------------------------------------------------------------------- 54 | _C.MODEL = CN() 55 | _C.MODEL.TRUNK = 'resnet18' 56 | _C.MODEL.PRETRAINING = CN() 57 | _C.MODEL.PRETRAINING.FC_DIM = 128 58 | 59 | _C.MODEL.FINETUNING = CN() 60 | _C.MODEL.FINETUNING.NUM_CLASSES = -1 61 | _C.MODEL.FINETUNING.DROPOUT = 0 62 | 63 | # ----------------------------------------------------------------------------- 64 | # Loss 65 | # ----------------------------------------------------------------------------- 66 | _C.LOSS = CN() 67 | _C.LOSS.PRETRAINING = CN() 68 | _C.LOSS.PRETRAINING.ROTATION_WEIGHT = 1 69 | 70 | _C.LOSS.PRETRAINING.NCE = CN() 71 | _C.LOSS.PRETRAINING.NCE.NEGATIVES = 8192 72 | _C.LOSS.PRETRAINING.NCE.TEMPERATURE = 0.07 73 | _C.LOSS.PRETRAINING.NCE.MOMENTUM = 0.5 74 | 75 | 76 | _C.LOSS.PRETRAINING.MINING = CN() 77 | _C.LOSS.PRETRAINING.MINING.USE_MINING = True 78 | _C.LOSS.PRETRAINING.MINING.THRESH_LOW = -1 79 | _C.LOSS.PRETRAINING.MINING.THRESH_HIGH = 1 80 | _C.LOSS.PRETRAINING.MINING.THRESH_RATE = 5 81 | _C.LOSS.PRETRAINING.MINING.MAX_HARD_NEGATIVES_PERCENTAGE = 1 82 | 83 | # 
----------------------------------------------------------------------------- 84 | # Training 85 | # ----------------------------------------------------------------------------- 86 | _C.TRAIN = CN() 87 | _C.TRAIN.PRETRAINING = CN() 88 | _C.TRAIN.PRETRAINING.LEARNING_RATE = 0.03 89 | _C.TRAIN.PRETRAINING.DECAY_EPOCHS = (25,) 90 | _C.TRAIN.PRETRAINING.DECAY_FACTOR = 0.1 91 | 92 | _C.TRAIN.PRETRAINING.SAVEDIR = '' 93 | _C.TRAIN.PRETRAINING.MOMENTUM = 0.9 94 | _C.TRAIN.PRETRAINING.WEIGHT_DECAY = 1e-4 95 | _C.TRAIN.PRETRAINING.SAVE_FREQ = 1 96 | _C.TRAIN.PRETRAINING.BATCH_SIZE = 100 97 | _C.TRAIN.PRETRAINING.EPOCHS = 50 98 | _C.TRAIN.PRETRAINING.RESUME = '' 99 | _C.TRAIN.PRETRAINING.FRAME_PADDING = 0 100 | 101 | _C.TRAIN.FINETUNING = CN() 102 | _C.TRAIN.FINETUNING.CHECKPOINT = '' 103 | _C.TRAIN.FINETUNING.SAVEDIR = '' 104 | _C.TRAIN.FINETUNING.RESUME = False 105 | _C.TRAIN.FINETUNING.BATCH_SIZE = 100 106 | _C.TRAIN.FINETUNING.LEARNING_RATE = 0.05 107 | _C.TRAIN.FINETUNING.EPOCHS = 900 108 | _C.TRAIN.FINETUNING.DECAY_EPOCHS = (375,) 109 | _C.TRAIN.FINETUNING.MOMENTUM = 0.9 110 | _C.TRAIN.FINETUNING.WEIGHT_DECAY = 0 111 | _C.TRAIN.FINETUNING.DECAY_FACTOR = 0.1 112 | _C.TRAIN.FINETUNING.VAL_FREQ = 3 113 | # ----------------------------------------------------------------------------- 114 | # Visualisation 115 | # ----------------------------------------------------------------------------- 116 | _C.VISUALISATION = CN() 117 | _C.VISUALISATION.TSNE = CN() 118 | _C.VISUALISATION.TSNE.CROP_SIZE = 224 119 | _C.VISUALISATION.TSNE.RESIZE = 256 120 | 121 | 122 | def update_config(cfg, args): 123 | cfg.defrost() 124 | if args.cfg != None: 125 | cfg.merge_from_file(args.cfg) 126 | cfg.merge_from_list(args.opts) 127 | cfg.freeze() 128 | -------------------------------------------------------------------------------- /config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | PRETRAINING: 3 | DATASET: 'kinetics400' 4 | TRANSFORMATIONS: 5 | CROP_SIZE: (224,224) 6 | SCALE_SIZE: (224,224) 7 | HORIZONTAL_FLIP: True 8 | RANDOM_GREY: True 9 | COLOUR_JITTER: True 10 | FINETUNING: 11 | TRANSFORMATIONS: 12 | RESIZE: (224,224) 13 | CROP_SIZE: (224,224) 14 | MODEL: 15 | TRUNK: 'resnet18' 16 | PRETRAINING: 17 | FC_DIM: 128 18 | FINETUNING: 19 | NUM_CLASSES: 101 20 | DROPOUT: 0 21 | LOSS: 22 | PRETRAINING: 23 | ROTATION_WEIGHT: 1 24 | NCE: 25 | NEGATIVES: 8192 26 | TEMPERATURE: 0.07 27 | MOMENTUM: 0.5 28 | MINING: 29 | THRESH_LOW: -1 30 | THRESH_HIGH: 1 31 | THRESH_RATE: 5 32 | MAX_HARD_NEGATIVES_PERCENTAGE: 1 33 | TRAIN: 34 | PRETRAINING: 35 | SAVEDIR: '' 36 | LEARNING_RATE: 0.03 37 | DECAY_EPOCHS: (25,) 38 | DECAY_FACTOR: 0.1 39 | MOMENTUM: 0.9 40 | WEIGHT_DECAY: 1e-4 41 | BATCH_SIZE: 100 42 | EPOCHS: 50 43 | FRAME_PADDING: 0 44 | FINETUNING: 45 | SAVEDIR: '' 46 | CHECKPOINT: '' 47 | RESUME: FALSE 48 | LEARNING_RATE: 0.05 49 | DECAY_EPOCHS: (375,) 50 | DECAY_FACTOR: 0.1 51 | MOMENTUM: 0.9 52 | WEIGHT_DECAY: 0 53 | BATCH_SIZE: 100 54 | EPOCHS: 900 55 | VAL_FREQ: 3 56 | VISUALISATION: 57 | TSNE: 58 | CROP_SIZE: 224 59 | RESIZE: 224 60 | 61 | 62 | -------------------------------------------------------------------------------- /datasets/finetuning/UCF101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from tqdm import tqdm 5 | import PIL.Image as Image 6 | import multiprocessing as mp 7 | 8 | from utils import UCF101_splitter 9 | 
from torchvision import transforms 10 | from torch.utils.data import DataLoader 11 | 12 | class UCF_dataset: 13 | def __init__(self, root, file_dict, mode, transforms): 14 | self.root = root 15 | self.keys = list(file_dict.keys()) 16 | self.gts = list(file_dict.values()) 17 | self.mode = mode 18 | self.transform = transforms 19 | 20 | def __len__(self): 21 | return len(self.keys) 22 | 23 | def load_ucf_image(self, video_name, index): 24 | 25 | frame_path = os.path.join(self.root, '{}/frame{:06d}.jpg'.format( 26 | video_name, index)) 27 | 28 | frame = Image.open(frame_path) 29 | frame_tensor = self.transform(frame) 30 | 31 | return frame_tensor 32 | 33 | def __getitem__(self, idx): 34 | if self.mode == 'train': 35 | # ===== Get training sample ===== 36 | video_name, nb_clips = self.keys[idx].split(' ') 37 | 38 | nb_clips = int(nb_clips) 39 | label = self.gts[idx] 40 | clips = [] 41 | clips.append(random.randint(1, int(nb_clips / 3))) 42 | clips.append(random.randint(int(nb_clips / 3), int(nb_clips * 2 / 3))) 43 | clips.append(random.randint(int(nb_clips * 2 / 3), nb_clips + 1)) 44 | 45 | elif self.mode == 'val': 46 | video_name, index = self.keys[idx].split(' ') 47 | index = abs(int(index)) 48 | else: 49 | raise ValueError('There are only train and val modes') 50 | 51 | label = self.gts[idx] 52 | label = int(label) - 1 53 | 54 | if self.mode == 'train': 55 | data = [] 56 | for i in range(len(clips)): 57 | index = clips[i] 58 | data.append(self.load_ucf_image(video_name, index)) 59 | sample = (data, label) 60 | elif self.mode == 'val': 61 | data = self.load_ucf_image(video_name, index) 62 | sample = (video_name.split('/')[0], data, label) 63 | 64 | return sample 65 | 66 | 67 | class UCF101: 68 | def __init__(self, cfg, logger): 69 | self.batch_size = cfg.TRAIN.FINETUNING.BATCH_SIZE 70 | self.num_workers = cfg.WORKERS 71 | self.root = cfg.DATASET.UCF101.FRAMES_PATH 72 | self.splits_path = cfg.DATASET.UCF101.SPLITS_PATH 73 | self.split = cfg.DATASET.UCF101.SPLIT 74 | self.cfg = cfg 75 | self.logger = logger 76 | 77 | # ===== Split the dataset ===== 78 | splitter = UCF101_splitter(self.splits_path, self.split) 79 | self.train_videos, self.val_videos = splitter.split_video() 80 | self.load_frame_count() 81 | self.get_training_dict() 82 | self.get_validation_dict() 83 | 84 | @staticmethod 85 | def process_video(x): 86 | root, video, L = x 87 | key = video 88 | n_frames = len(os.listdir(os.path.join(root, video))) 89 | L.append([key.replace('HandStandPushups', 'HandstandPushups'), n_frames]) 90 | 91 | def load_frame_count(self): 92 | pickle_path = os.path.join(self.cfg.ASSETS_PATH, 'UCF101_frame_count_split_{:02d}.pickle'.format(self.split)) 93 | if not os.path.exists(pickle_path): 94 | 95 | # ===== Build frame count dict if no pickle file found ===== 96 | self.frame_count = {} 97 | self.logger.info('Creating frame count dictionary for UCF101') 98 | 99 | # ===== Set up multiprocessing ===== 100 | pool = mp.Pool(processes = 16) 101 | manager = mp.Manager() 102 | L = manager.list() 103 | in_args = [] 104 | for video in os.listdir(self.root): 105 | in_args.append([self.root, video, L]) 106 | 107 | # ===== Get frame counts using multiprocessing ===== 108 | for _ in tqdm(pool.imap_unordered(self.process_video, in_args), total = len(in_args)): 109 | pass 110 | 111 | pool.close() 112 | pool.join() 113 | self.frame_count = {k:v for k,v in list(L)} 114 | 115 | with open(pickle_path, 'wb') as f: 116 | pickle.dump(self.frame_count, f) 117 | self.logger.info('Saved frame count dictionary to 
{}'.format(pickle_path)) 118 | 119 | else: 120 | 121 | # ===== Load frame count dict if already exists 122 | with open(pickle_path, 'rb') as f: 123 | self.frame_count = pickle.load(f) 124 | 125 | def get_training_dict(self): 126 | self.dict_training = {} 127 | for video in self.train_videos: 128 | nb_frame = self.frame_count[video] - 10 + 1 129 | key = '{} {}'.format(video, nb_frame) 130 | self.dict_training[key] = self.train_videos[video] 131 | 132 | def get_validation_dict(self): 133 | self.dict_val = {} 134 | for video in self.val_videos: 135 | nb_frame = self.frame_count[video] - 10 + 1 136 | interval = nb_frame // 19 137 | for idx in range(19): 138 | frame = interval * idx + 1 139 | key = '{} {}'.format(video, frame) 140 | self.dict_val[key] = self.val_videos[video] 141 | 142 | 143 | def get_loaders(self): 144 | # ===== Get transforms ===== 145 | tran = transforms.Compose([ 146 | transforms.Resize(self.cfg.DATASET.FINETUNING.TRANSFORMATIONS.RESIZE), 147 | transforms.ToTensor(), 148 | transforms.Normalize(mean = self.cfg.DATASET.FINETUNING.MEAN, std = self.cfg.DATASET.FINETUNING.STD) 149 | ]) 150 | 151 | # ===== Make Datasets ===== 152 | train_dataset = UCF_dataset( 153 | root = self.cfg.DATASET.UCF101.FRAMES_PATH, 154 | file_dict = self.dict_training, 155 | mode = 'train', 156 | transforms = tran 157 | ) 158 | 159 | val_dataset = UCF_dataset( 160 | root = self.cfg.DATASET.UCF101.FRAMES_PATH, 161 | file_dict = self.dict_val, 162 | mode = 'val', 163 | transforms = tran 164 | ) 165 | 166 | self.logger.info('Created Train / Val splits') 167 | self.logger.info('Training Videos : {}'.format(len(train_dataset))) 168 | self.logger.info('Validation Videos : {}'.format(len(val_dataset) // 19)) 169 | 170 | # ===== Create loaders ===== 171 | train_loader = DataLoader( 172 | dataset = train_dataset, 173 | batch_size = self.cfg.TRAIN.FINETUNING.BATCH_SIZE, 174 | shuffle = True, 175 | num_workers = self.num_workers 176 | ) 177 | 178 | val_loader = DataLoader( 179 | dataset = val_dataset, 180 | batch_size = self.cfg.TRAIN.FINETUNING.BATCH_SIZE, 181 | shuffle = True, 182 | num_workers = self.num_workers 183 | ) 184 | 185 | val_gt = {k:v - 1 for k,v in self.val_videos.items()} 186 | 187 | return train_loader, val_loader, val_gt 188 | 189 | if __name__ == '__main__': 190 | import argparse 191 | from config import config, update_config 192 | from utils import * 193 | 194 | parser = argparse.ArgumentParser() 195 | parser.add_argument('--cfg', default = None) 196 | parser.add_argument('opts', default = None, nargs = argparse.REMAINDER) 197 | args = parser.parse_args() 198 | update_config(config, args) 199 | logger = setup_logger() 200 | 201 | train_loader, val_loader, val_gt = UCF101(cfg, logger) 202 | 203 | -------------------------------------------------------------------------------- /datasets/finetuning/__init__.py: -------------------------------------------------------------------------------- 1 | from .UCF101 import * -------------------------------------------------------------------------------- /datasets/pretraining/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from .augmentation import * 5 | from .kinetics400 import Kinetics400 6 | from torchvision.transforms import Compose 7 | 8 | 9 | class PretrainTransforms: 10 | def __init__(self, cfg): 11 | img_transforms = [] 12 | tensor_transforms = [] 13 | # ===== Prepare scaling and crop sizes ===== 14 | crop_size = 
cfg.DATASET.PRETRAINING.TRANSFORMATIONS.CROP_SIZE 15 | scale_size = cfg.DATASET.PRETRAINING.TRANSFORMATIONS.SCALE_SIZE 16 | 17 | # ===== Add optional augmentations ===== 18 | if cfg.DATASET.PRETRAINING.TRANSFORMATIONS.HORIZONTAL_FLIP == True: 19 | img_transforms.append(RandomHorizontalFlip(consistent=True)) 20 | if cfg.DATASET.PRETRAINING.TRANSFORMATIONS.RANDOM_GREY == True: 21 | img_transforms.append(RandomGray(consistent=False, p=0.5)) 22 | if cfg.DATASET.PRETRAINING.TRANSFORMATIONS.COLOUR_JITTER == True: 23 | img_transforms.append(ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0)) 24 | 25 | # ===== Add Scale and Crop transforms ===== 26 | img_transforms.append(Scale(size=scale_size)) 27 | img_transforms.append(RandomCrop(size=crop_size, consistent = True)) 28 | 29 | # ===== Create Rotation Transform ===== 30 | self.Rotation = Rotation() 31 | 32 | # ===== Create Tensor Transformations ===== 33 | tensor_transforms.append(ToTensor()) 34 | tensor_transforms.append(Normalize( 35 | mean = cfg.DATASET.PRETRAINING.MEAN, 36 | std = cfg.DATASET.PRETRAINING.STD, 37 | )) 38 | 39 | self.img_transforms = Compose(img_transforms) 40 | self.tensor_transforms = Compose(tensor_transforms) 41 | 42 | def __call__(self, anchor_frame, pair_frame): 43 | 44 | # ===== Transform Pair Together ===== 45 | anchor_tensor, pair_tensor = self.tensor_transforms(self.img_transforms([anchor_frame, pair_frame])) 46 | 47 | # ===== Transform Rotation, get Rotation GT ===== 48 | rotation_gt = np.random.randint(0,4) 49 | rotation_intermediate = self.img_transforms([anchor_frame]) 50 | rotation_intermediate = self.Rotation(rotation_intermediate, rotation = 90 * rotation_gt) 51 | rotation_tensor = self.tensor_transforms(rotation_intermediate)[0] 52 | 53 | return anchor_tensor, pair_tensor, rotation_tensor, rotation_gt 54 | 55 | def get_pretraining_dataset(cfg): 56 | transforms = PretrainTransforms(cfg) 57 | dataset_name = cfg.DATASET.PRETRAINING.DATASET 58 | 59 | if dataset_name == 'kinetics400': 60 | dataset = Kinetics400(cfg, transforms) 61 | else: 62 | raise NotImplementedError 63 | 64 | n_data = len(dataset) 65 | 66 | train_loader = torch.utils.data.DataLoader( 67 | dataset = dataset, 68 | batch_size = cfg.TRAIN.PRETRAINING.BATCH_SIZE, 69 | shuffle = True, 70 | num_workers = cfg.WORKERS, 71 | pin_memory = True, 72 | sampler = None, 73 | drop_last = True 74 | ) 75 | 76 | return train_loader, n_data 77 | 78 | -------------------------------------------------------------------------------- /datasets/pretraining/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import math 4 | import collections 5 | import numpy as np 6 | from PIL import ImageOps, Image 7 | #from joblib import Parallel, delayed 8 | 9 | import torchvision 10 | from torchvision import transforms 11 | import torchvision.transforms.functional as F 12 | 13 | class Padding: 14 | def __init__(self, pad): 15 | self.pad = pad 16 | 17 | def __call__(self, img): 18 | return ImageOps.expand(img, border=self.pad, fill=0) 19 | 20 | class Scale: 21 | def __init__(self, size, interpolation=Image.NEAREST): 22 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 23 | self.size = size 24 | self.interpolation = interpolation 25 | 26 | def __call__(self, imgmap): 27 | # assert len(imgmap) > 1 # list of images 28 | img1 = imgmap[0] 29 | if isinstance(self.size, int): 30 | w, h = img1.size 31 | if (w <= h and w == self.size) or (h 
<= w and h == self.size): 32 | return imgmap 33 | if w < h: 34 | ow = self.size 35 | oh = int(self.size * h / w) 36 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 37 | else: 38 | oh = self.size 39 | ow = int(self.size * w / h) 40 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 41 | else: 42 | return [i.resize(self.size, self.interpolation) for i in imgmap] 43 | 44 | 45 | class CenterCrop: 46 | def __init__(self, size, consistent=True): 47 | if isinstance(size, numbers.Number): 48 | self.size = (int(size), int(size)) 49 | else: 50 | self.size = size 51 | 52 | def __call__(self, imgmap): 53 | img1 = imgmap[0] 54 | w, h = img1.size 55 | th, tw = self.size 56 | x1 = int(round((w - tw) / 2.)) 57 | y1 = int(round((h - th) / 2.)) 58 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 59 | 60 | 61 | class RandomCropWithProb: 62 | def __init__(self, size, p=0.8, consistent=True): 63 | if isinstance(size, numbers.Number): 64 | self.size = (int(size), int(size)) 65 | else: 66 | self.size = size 67 | self.consistent = consistent 68 | self.threshold = p 69 | 70 | def __call__(self, imgmap): 71 | img1 = imgmap[0] 72 | w, h = img1.size 73 | if self.size is not None: 74 | th, tw = self.size 75 | if w == tw and h == th: 76 | return imgmap 77 | if self.consistent: 78 | if random.random() < self.threshold: 79 | x1 = random.randint(0, w - tw) 80 | y1 = random.randint(0, h - th) 81 | else: 82 | x1 = int(round((w - tw) / 2.)) 83 | y1 = int(round((h - th) / 2.)) 84 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 85 | else: 86 | result = [] 87 | for i in imgmap: 88 | if random.random() < self.threshold: 89 | x1 = random.randint(0, w - tw) 90 | y1 = random.randint(0, h - th) 91 | else: 92 | x1 = int(round((w - tw) / 2.)) 93 | y1 = int(round((h - th) / 2.)) 94 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 95 | return result 96 | else: 97 | return imgmap 98 | 99 | class RandomCrop: 100 | def __init__(self, size, consistent=True): 101 | if isinstance(size, numbers.Number): 102 | self.size = (int(size), int(size)) 103 | else: 104 | self.size = size 105 | self.consistent = consistent 106 | 107 | def __call__(self, imgmap, flowmap=None): 108 | img1 = imgmap[0] 109 | w, h = img1.size 110 | if self.size is not None: 111 | th, tw = self.size 112 | if w == tw and h == th: 113 | return imgmap 114 | if not flowmap: 115 | if self.consistent: 116 | x1 = random.randint(0, w - tw) 117 | y1 = random.randint(0, h - th) 118 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 119 | else: 120 | result = [] 121 | for i in imgmap: 122 | x1 = random.randint(0, w - tw) 123 | y1 = random.randint(0, h - th) 124 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 125 | return result 126 | elif flowmap is not None: 127 | assert (not self.consistent) 128 | result = [] 129 | for idx, i in enumerate(imgmap): 130 | proposal = [] 131 | for j in range(3): # number of proposal: use the one with largest optical flow 132 | x = random.randint(0, w - tw) 133 | y = random.randint(0, h - th) 134 | proposal.append([x, y, abs(np.mean(flowmap[idx,y:y+th,x:x+tw,:]))]) 135 | [x1, y1, _] = max(proposal, key=lambda x: x[-1]) 136 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 137 | return result 138 | else: 139 | raise ValueError('wrong case') 140 | else: 141 | return imgmap 142 | 143 | 144 | class RandomSizedCrop: 145 | def __init__(self, size, interpolation=Image.BILINEAR, consistent=True, p=1.0): 146 | self.size = size 147 | self.interpolation = interpolation 148 | 
self.consistent = consistent 149 | self.threshold = p 150 | 151 | def __call__(self, imgmap): 152 | img1 = imgmap[0] 153 | if random.random() < self.threshold: # do RandomSizedCrop 154 | for attempt in range(10): 155 | area = img1.size[0] * img1.size[1] 156 | target_area = random.uniform(0.5, 1) * area 157 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 158 | 159 | w = int(round(math.sqrt(target_area * aspect_ratio))) 160 | h = int(round(math.sqrt(target_area / aspect_ratio))) 161 | 162 | if self.consistent: 163 | if random.random() < 0.5: 164 | w, h = h, w 165 | if w <= img1.size[0] and h <= img1.size[1]: 166 | x1 = random.randint(0, img1.size[0] - w) 167 | y1 = random.randint(0, img1.size[1] - h) 168 | 169 | imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap] 170 | for i in imgmap: assert(i.size == (w, h)) 171 | 172 | return [i.resize((self.size, self.size), self.interpolation) for i in imgmap] 173 | else: 174 | result = [] 175 | for i in imgmap: 176 | if random.random() < 0.5: 177 | w, h = h, w 178 | if w <= img1.size[0] and h <= img1.size[1]: 179 | x1 = random.randint(0, img1.size[0] - w) 180 | y1 = random.randint(0, img1.size[1] - h) 181 | result.append(i.crop((x1, y1, x1 + w, y1 + h))) 182 | assert(result[-1].size == (w, h)) 183 | else: 184 | result.append(i) 185 | 186 | assert len(result) == len(imgmap) 187 | return [i.resize((self.size, self.size), self.interpolation) for i in result] 188 | 189 | # Fallback 190 | scale = Scale(self.size, interpolation=self.interpolation) 191 | crop = CenterCrop(self.size) 192 | return crop(scale(imgmap)) 193 | else: # don't do RandomSizedCrop, do CenterCrop 194 | crop = CenterCrop(self.size) 195 | return crop(imgmap) 196 | 197 | 198 | class RandomHorizontalFlip: 199 | def __init__(self, consistent=True, command=None): 200 | self.consistent = consistent 201 | if command == 'left': 202 | self.threshold = 0 203 | elif command == 'right': 204 | self.threshold = 1 205 | else: 206 | self.threshold = 0.5 207 | def __call__(self, imgmap): 208 | if self.consistent: 209 | if random.random() < self.threshold: 210 | return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] 211 | else: 212 | return imgmap 213 | else: 214 | result = [] 215 | for i in imgmap: 216 | if random.random() < self.threshold: 217 | result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) 218 | else: 219 | result.append(i) 220 | assert len(result) == len(imgmap) 221 | return result 222 | 223 | 224 | class RandomGray: 225 | '''Actually it is a channel splitting, not strictly grayscale images''' 226 | def __init__(self, consistent=True, p=0.5): 227 | self.consistent = consistent 228 | self.p = p # probability to apply grayscale 229 | def __call__(self, imgmap): 230 | if self.consistent: 231 | if random.random() < self.p: 232 | return [self.grayscale(i) for i in imgmap] 233 | else: 234 | return imgmap 235 | else: 236 | result = [] 237 | for i in imgmap: 238 | if random.random() < self.p: 239 | result.append(self.grayscale(i)) 240 | else: 241 | result.append(i) 242 | assert len(result) == len(imgmap) 243 | return result 244 | 245 | def grayscale(self, img): 246 | channel = np.random.choice(3) 247 | np_img = np.array(img)[:,:,channel] 248 | np_img = np.dstack([np_img, np_img, np_img]) 249 | img = Image.fromarray(np_img, 'RGB') 250 | return img 251 | 252 | 253 | class ColorJitter(object): 254 | """Randomly change the brightness, contrast and saturation of an image. 
--modified from pytorch source code 255 | Args: 256 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 257 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 258 | or the given [min, max]. Should be non negative numbers. 259 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 260 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 261 | or the given [min, max]. Should be non negative numbers. 262 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 263 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 264 | or the given [min, max]. Should be non negative numbers. 265 | hue (float or tuple of float (min, max)): How much to jitter hue. 266 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 267 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 268 | """ 269 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0): 270 | self.brightness = self._check_input(brightness, 'brightness') 271 | self.contrast = self._check_input(contrast, 'contrast') 272 | self.saturation = self._check_input(saturation, 'saturation') 273 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 274 | clip_first_on_zero=False) 275 | self.consistent = consistent 276 | self.threshold = p 277 | 278 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 279 | if isinstance(value, numbers.Number): 280 | if value < 0: 281 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 282 | value = [center - value, center + value] 283 | if clip_first_on_zero: 284 | value[0] = max(value[0], 0) 285 | elif isinstance(value, (tuple, list)) and len(value) == 2: 286 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 287 | raise ValueError("{} values should be between {}".format(name, bound)) 288 | else: 289 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 290 | 291 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 292 | # or (0., 0.) for hue, do nothing 293 | if value[0] == value[1] == center: 294 | value = None 295 | return value 296 | 297 | @staticmethod 298 | def get_params(brightness, contrast, saturation, hue): 299 | """Get a randomized transform to be applied on image. 300 | Arguments are same as that of __init__. 301 | Returns: 302 | Transform which randomly adjusts brightness, contrast and 303 | saturation in a random order. 
304 | """ 305 | transforms = [] 306 | 307 | if brightness is not None: 308 | brightness_factor = random.uniform(brightness[0], brightness[1]) 309 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 310 | 311 | if contrast is not None: 312 | contrast_factor = random.uniform(contrast[0], contrast[1]) 313 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 314 | 315 | if saturation is not None: 316 | saturation_factor = random.uniform(saturation[0], saturation[1]) 317 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 318 | 319 | if hue is not None: 320 | hue_factor = random.uniform(hue[0], hue[1]) 321 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor))) 322 | 323 | random.shuffle(transforms) 324 | transform = torchvision.transforms.Compose(transforms) 325 | 326 | return transform 327 | 328 | def __call__(self, imgmap): 329 | if random.random() < self.threshold: # do ColorJitter 330 | if self.consistent: 331 | transform = self.get_params(self.brightness, self.contrast, 332 | self.saturation, self.hue) 333 | return [transform(i) for i in imgmap] 334 | else: 335 | result = [] 336 | for img in imgmap: 337 | transform = self.get_params(self.brightness, self.contrast, 338 | self.saturation, self.hue) 339 | result.append(transform(img)) 340 | return result 341 | else: # don't do ColorJitter, do nothing 342 | return imgmap 343 | 344 | def __repr__(self): 345 | format_string = self.__class__.__name__ + '(' 346 | format_string += 'brightness={0}'.format(self.brightness) 347 | format_string += ', contrast={0}'.format(self.contrast) 348 | format_string += ', saturation={0}'.format(self.saturation) 349 | format_string += ', hue={0})'.format(self.hue) 350 | return format_string 351 | 352 | 353 | class RandomRotation: 354 | def __init__(self, consistent=True, degree=15, p=1.0): 355 | self.consistent = consistent 356 | self.degree = degree 357 | self.threshold = p 358 | def __call__(self, imgmap): 359 | if random.random() < self.threshold: # do RandomRotation 360 | if self.consistent: 361 | deg = np.random.randint(-self.degree, self.degree, 1)[0] 362 | return [i.rotate(deg, expand=True) for i in imgmap] 363 | else: 364 | return [i.rotate(np.random.randint(-self.degree, self.degree, 1)[0], expand=True) for i in imgmap] 365 | else: # don't do RandomRotation, do nothing 366 | return imgmap 367 | 368 | class Rotation: 369 | def __init__(self, consistent=True, degree_inc=90, p=1.0): 370 | self.consistent = consistent 371 | self.degree_inc = degree_inc 372 | self.threshold = p 373 | def __call__(self, imgmap, rotation): 374 | deg = rotation 375 | return [i.rotate(deg, expand=True) for i in imgmap] 376 | 377 | class ToTensor: 378 | def __call__(self, imgmap): 379 | totensor = transforms.ToTensor() 380 | return [totensor(i) for i in imgmap] 381 | 382 | class Normalize: 383 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 384 | self.mean = mean 385 | self.std = std 386 | def __call__(self, imgmap): 387 | normalize = transforms.Normalize(mean=self.mean, std=self.std) 388 | return [normalize(i) for i in imgmap] 389 | 390 | -------------------------------------------------------------------------------- /datasets/pretraining/kinetics400.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import numpy as 
np 5 | import random 6 | from glob import glob 7 | from tqdm import tqdm 8 | import PIL.Image as Image 9 | from multiprocessing import Process, Manager, Pool 10 | 11 | class kineticsFolderInstance: 12 | """ 13 | Dataset instance for a single kinetics video 14 | """ 15 | def __init__(self, root, membank_idx, frame_padding): 16 | self.root = root 17 | self.membank_idx = membank_idx 18 | self.frame_padding = frame_padding 19 | self.vid_length = len(glob(os.path.join(root,'*'))) 20 | 21 | def __len__(self): 22 | return self.vid_length 23 | 24 | def __call__(self): 25 | anchor_idx = np.random.randint(0, self.vid_length) 26 | if anchor_idx >= self.vid_length - (2 * self.frame_padding) - 1: 27 | pair_idx = anchor_idx - self.frame_padding 28 | else: 29 | pair_idx = anchor_idx + self.frame_padding 30 | 31 | anchor_path = os.path.join(self.root, 'frame{}.jpg'.format(anchor_idx)) 32 | pair_path = os.path.join(self.root, 'frame{}.jpg'.format(pair_idx)) 33 | 34 | return anchor_path, pair_path 35 | 36 | class Kinetics400: 37 | def __init__(self, cfg, transform): 38 | self.root = os.path.join(cfg.DATASET.KINETICS400.FRAMES_PATH, 'train') 39 | self.frame_padding = cfg.TRAIN.PRETRAINING.FRAME_PADDING 40 | self.transform = transform 41 | self.K = cfg.LOSS.PRETRAINING.NCE.NEGATIVES 42 | assert self.root is not '', 'Please specify Kinetics400 Path in config' 43 | pickle_file = os.path.join(cfg.ASSETS_PATH, 'Kinetics400Dataset.pickle') 44 | if not os.path.exists(pickle_file): 45 | # ===== Makes Sample List ===== 46 | self.samples = self.get_sample_list() 47 | 48 | # ===== Save Pickle File ===== 49 | with open(pickle_file, 'wb') as f: 50 | pickle.dump(self.samples,f) 51 | else: 52 | # ===== Load Pickle File ===== 53 | with open(pickle_file, 'rb') as f: 54 | self.samples = pickle.load(f) 55 | f.close() 56 | 57 | self.all_membank_negatives = list(range(len(self.samples))) 58 | 59 | 60 | def process_video(self, x): 61 | video, idx, L = x 62 | try: 63 | sample = kineticsFolderInstance( 64 | root = video, 65 | membank_idx = idx, 66 | frame_padding = self.frame_padding 67 | ) 68 | if len(sample) != 0: 69 | L.append(sample) 70 | except: 71 | pass 72 | 73 | def get_sample_list(self): 74 | video_folders = glob(os.path.join(self.root,'*','*')) 75 | 76 | # ===== Set up multiprocessing ===== 77 | L = Manager().list() 78 | pool = Pool(processes=32) 79 | 80 | # ===== Prepare inputs for multiprocessing ===== 81 | inputs = [[video, idx, L] for idx, video in enumerate(video_folders)] 82 | 83 | # ===== Create sample list ===== 84 | pbar = tqdm(total = len(inputs)) 85 | print(len(inputs)) 86 | for i in pool.imap_unordered(self.process_video, inputs): 87 | pbar.update(1) 88 | 89 | pool.close() 90 | pool.join() 91 | 92 | all_samples = list(L) 93 | 94 | # ===== Fix membank_idx due to failed samples ===== 95 | for idx, sample in enumerate(all_samples): 96 | sample.membank_idx = idx 97 | 98 | return all_samples 99 | 100 | 101 | 102 | def __len__(self): 103 | return len(self.samples) 104 | 105 | def __getitem__(self, index): 106 | 107 | # Get sample from index 108 | sample = self.samples[index] 109 | membank_idx = sample.membank_idx 110 | anchor_path, pair_path = sample() 111 | 112 | anchor_frame = Image.open(anchor_path) 113 | pair_frame = Image.open(pair_path) 114 | anchor_tensor, pair_tensor, rotation_tensor, rotation_gt = self.transform(anchor_frame, pair_frame) 115 | 116 | # Get negatives 117 | potential_negatives = self.all_membank_negatives[:membank_idx] + self.all_membank_negatives[membank_idx + 1:] 118 | negatives = 
torch.tensor(random.sample(potential_negatives, self.K + 1)) 119 | 120 | inputs = { 121 | 'anchor_tensor': anchor_tensor, 122 | 'pair_tensor': pair_tensor, 123 | 'membank_idx': membank_idx, 124 | 'rotation_tensor': rotation_tensor, 125 | 'rotation_gt': rotation_gt, 126 | 'negatives': negatives, 127 | } 128 | 129 | return inputs 130 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import argparse 5 | import logging 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | from tensorboardX import SummaryWriter 10 | 11 | 12 | from config import config, update_config 13 | from network import * 14 | from datasets.pretraining import get_pretraining_dataset 15 | from datasets.finetuning import * 16 | from loss import get_loss 17 | from utils import * 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description='Train TCE Self-Supervised') 22 | parser.add_argument('--cfg', help = 'Path to config file', type = str, default = None) 23 | parser.add_argument('--val', default = False, action = 'store_true', help = 'Just validate checkpoint') 24 | parser.add_argument('opts', help = 'Modify config using the command line', 25 | default = None, nargs=argparse.REMAINDER ) 26 | args = parser.parse_args() 27 | update_config(config, args) 28 | 29 | return args 30 | 31 | 32 | def train(epoch, train_loader, model, criterion, optimizer, cfg, tboard, logger): 33 | 34 | # ===== Set up Meters ===== 35 | batch_time = AverageMeter() 36 | loss_meter = AverageMeter() 37 | top1_meter = AverageMeter() 38 | top5_meter = AverageMeter() 39 | 40 | # ===== Switch to train mode ===== 41 | model.train() 42 | 43 | 44 | # ===== Start training on batches ===== 45 | for idx, (data, label) in enumerate(train_loader): 46 | end = time.time() 47 | bsz = data[0].size(0) 48 | 49 | # ===== Forwards ===== 50 | data = [x.cuda() for x in data] 51 | label = label.cuda() 52 | 53 | 54 | for i, frame_tensor in enumerate(data): 55 | if i == 0: 56 | output = model(frame_tensor) 57 | else: 58 | output += model(frame_tensor) 59 | 60 | loss = criterion(output, label) 61 | prec1, prec5 = accuracy(output, label, topk=(1,5)) 62 | 63 | # ===== Backwards ===== 64 | optimizer.zero_grad() 65 | loss.backward() 66 | optimizer.step() 67 | 68 | # ===== Update meters and do logging ===== 69 | batch_time.update(time.time() - end) 70 | loss_meter.update(loss.item(), bsz) 71 | top1_meter.update(prec1) 72 | top5_meter.update(prec5) 73 | 74 | if idx % cfg.PRINT_FREQ == 0: 75 | log_step = (epoch-1) * len(train_loader) + idx 76 | tboard.add_scalars('Train/Loss', {'val': loss_meter.val, 'avg': loss_meter.avg}, log_step) 77 | tboard.add_scalars('Train/Top1 Acc', {'val': top1_meter.val, 'avg': top1_meter.avg}, log_step) 78 | tboard.add_scalars('Train/Top5 Acc', {'val': top5_meter.val, 'avg': top5_meter.avg}, log_step) 79 | 80 | info = ('Epoch : {} ({}/{}) | ' 81 | 'BT : {:02f} ({:02f}) | ' 82 | 'Loss: {:02f} ({:02f}) | ' 83 | 'Top1: {:02f} ({:02f}) | ' 84 | 'Top5: {:02f} ({:02f})').format( 85 | epoch, idx, len(train_loader), 86 | batch_time.val, batch_time.avg, 87 | loss_meter.val, loss_meter.avg, 88 | top1_meter.val, top1_meter.avg, 89 | top5_meter.val, top5_meter.avg 90 | ) 91 | logger.info(info) 92 | 93 | @torch.no_grad() 94 | def validate(epoch, val_loader, val_gt, model, criterion, cfg, tboard, logger): 95 | 96 | # ===== Switch to eval mode ===== 97 | 
model.eval() 98 | video_predictions = {} 99 | end = time.time() 100 | pbar = tqdm(total = len(val_loader)) 101 | 102 | for idx, (video_keys, data, label) in enumerate(val_loader): 103 | bsz = data.size(0) 104 | data = data.cuda() 105 | label = label.cuda() 106 | 107 | # ===== Calculate network output and pack into video_predictions ===== 108 | output = model(data) 109 | for i in range(bsz): 110 | key = video_keys[i] 111 | pred = output[i] 112 | if key not in video_predictions.keys(): 113 | video_predictions[key] = [] 114 | video_predictions[key] = pred 115 | else: 116 | video_predictions[key] += pred 117 | 118 | pbar.update(1) 119 | 120 | print('\n') 121 | # ===== Get Eval Top1, Top5, Loss ===== 122 | video_level_preds = torch.zeros(len(video_predictions), 101).float().cuda() 123 | video_level_labels = torch.zeros(len(video_predictions)).long().cuda() 124 | 125 | for idx, key in enumerate(sorted(video_predictions.keys())): 126 | video_level_preds[idx] = video_predictions[key] 127 | video_level_labels[idx] = val_gt[key] 128 | 129 | prec1, prec5 = accuracy(video_level_preds, video_level_labels, topk=(1,5)) 130 | loss = criterion(video_level_preds, video_level_labels) 131 | 132 | # ===== Log ===== 133 | logger.info('Validation complete for epoch {}, time taken {:02f} seconds'.format(epoch, time.time() - end)) 134 | logger.info('Top1: {:02f} Top5: {:02f} Loss: {:02f}'.format(prec1, prec5, loss.item())) 135 | 136 | tboard.add_scalars('Val/Top1', {'prec1':prec1}, epoch) 137 | tboard.add_scalars('Val/Top5', {'prec5':prec5}, epoch) 138 | tboard.add_scalars('Val/Loss', {'loss':loss.item()}, epoch) 139 | 140 | return prec1 141 | 142 | 143 | 144 | 145 | def main(): 146 | args = parse_args() 147 | logger = setup_logger() 148 | logger.info(config) 149 | if not os.path.exists(config.ASSETS_PATH): 150 | os.makedirs(config.ASSETS_PATH) 151 | 152 | # ===== Create the dataloaders ===== 153 | UCF_dataset = UCF101(config, logger) 154 | train_loader, val_loader, val_gt = UCF_dataset.get_loaders() 155 | 156 | # ===== Create the model ===== 157 | model = FineTuneNet(config) 158 | logger.info('Built Model, using {} backbone'.format(config.MODEL.TRUNK)) 159 | if torch.cuda.device_count() > 1: 160 | model = torch.nn.DataParallel(model).cuda() 161 | else: 162 | model = model.cuda() 163 | logger.info('Training on {} GPUs'.format(torch.cuda.device_count())) 164 | 165 | # ===== Set the optimizer ===== 166 | optimizer = get_optimizer(model, config, pretraining = False) 167 | 168 | # ===== Get the loss ===== 169 | criterion = nn.CrossEntropyLoss().cuda() 170 | 171 | # ===== Load checkpoint ===== 172 | if config.TRAIN.FINETUNING.CHECKPOINT: 173 | checkpoint = torch.load(config.TRAIN.FINETUNING.CHECKPOINT) 174 | # ===== Align checkpoint keys with model ===== 175 | if 'module' in list(checkpoint['state_dict'].keys())[0] and 'module' not in list(model.state_dict().keys())[0]: 176 | checkpoint['state_dict'] = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items() } 177 | elif 'module' not in list(checkpoint['state_dict'].keys())[0] and 'module' in list(model.state_dict().keys())[0]: 178 | checkpoint['state_dict'] = {'module.' 
+ k:v for k,v in checkpoint['state_dict'].items() } 179 | 180 | if not config.TRAIN.FINETUNING.RESUME: 181 | # ===== Load only backbone parameters for starting finetuning at epoch 0 ===== 182 | forgiving_load_state_dict(checkpoint['state_dict'], model, logger) 183 | start_epoch = 1 184 | best_prec1 = 0 185 | else: 186 | model.load_state_dict(checkpoint['state_dict']) 187 | optimizer.load_state_dict(checkpoint['optimizer']) 188 | start_epoch = checkpoint['epoch'] 189 | best_prec1 = checkpoint['best_prec1'] 190 | else: 191 | logger.warning('No checkpoint specified for pre-training. Training from scratch.') 192 | start_epoch = 1 193 | best_prec1 = 0 194 | 195 | # ===== Set up save directory and TensorBoard ===== 196 | assert config.TRAIN.FINETUNING.SAVEDIR, 'Please specify save directory' 197 | if not os.path.exists(config.TRAIN.FINETUNING.SAVEDIR): 198 | os.makedirs(config.TRAIN.FINETUNING.SAVEDIR) 199 | os.makedirs(os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'checkpoints')) 200 | os.makedirs(os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'tboard')) 201 | 202 | tboard = SummaryWriter(logdir = os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'tboard')) 203 | 204 | if args.val: 205 | logger.info('Running in Validation mode') 206 | validate( 207 | epoch = start_epoch, 208 | val_loader = val_loader, 209 | val_gt = val_gt, 210 | model = model, 211 | criterion = criterion, 212 | cfg = config, 213 | tboard = tboard, 214 | logger = logger, 215 | ) 216 | else: 217 | # ===== Train Loop ===== 218 | logger.info('Begin training') 219 | for epoch in range(start_epoch, config.TRAIN.FINETUNING.EPOCHS + 1): 220 | adjust_learning_rate(epoch, config, optimizer, pretraining = False, logger = logger) 221 | logger.info('Training epoch {}'.format(epoch)) 222 | 223 | # ===== Train 1 epoch ====== 224 | train( 225 | epoch = epoch, 226 | train_loader = train_loader, 227 | model = model, 228 | criterion = criterion, 229 | optimizer = optimizer, 230 | cfg = config, 231 | tboard = tboard, 232 | logger = logger 233 | ) 234 | 235 | # ===== Validate periodically ===== 236 | if epoch % config.TRAIN.FINETUNING.VAL_FREQ == 0: 237 | logger.info('Validating at epoch {}'.format(epoch)) 238 | prec1 = validate( 239 | epoch = epoch, 240 | val_loader = val_loader, 241 | val_gt = val_gt, 242 | model = model, 243 | criterion = criterion, 244 | cfg = config, 245 | tboard = tboard, 246 | logger = logger, 247 | ) 248 | 249 | # ===== Save if new best performance ===== 250 | if prec1 > best_prec1: 251 | best_prec1 = prec1 252 | logger.info('New best top1 precision: {}'.format(best_prec1)) 253 | checkpoint = { 254 | 'state_dict': model.state_dict(), 255 | 'optimizer': optimizer.state_dict(), 256 | 'epoch': epoch, 257 | 'best_prec1': best_prec1 258 | } 259 | save_path = os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'checkpoints', 'best_checkpoint.pth') 260 | torch.save(checkpoint, save_path) 261 | 262 | # ===== Save latest checkpoint ===== 263 | checkpoint = { 264 | 'state_dict': model.state_dict(), 265 | 'optimizer': optimizer.state_dict(), 266 | 'epoch': epoch, 267 | 'best_prec1': best_prec1 268 | } 269 | save_path = os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'checkpoints', 'latest_checkpoint.pth') 270 | torch.save(checkpoint, save_path) 271 | 272 | if __name__ == '__main__': 273 | main() -------------------------------------------------------------------------------- /images/TCE.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/csiro-robotics/TCE/0da560c7c8616fc124dce0b5b35cd3ebae20bb81/images/TCE.png -------------------------------------------------------------------------------- /images/bowling_tsne_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csiro-robotics/TCE/0da560c7c8616fc124dce0b5b35cd3ebae20bb81/images/bowling_tsne_example.gif -------------------------------------------------------------------------------- /loss/NCEAverage.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class Normalise(nn.Module): 6 | 7 | def __init__(self, power=2): 8 | super(Normalise, self).__init__() 9 | self.power = power 10 | 11 | def forward(self, x): 12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 13 | out = x.div(norm) 14 | return out 15 | 16 | class NCEAverageKinetics(nn.Module): 17 | 18 | def __init__(self, feat_dim, n_data, K, T=0.07, momentum=0.5): 19 | super(NCEAverageKinetics, self).__init__() 20 | self.nLem = n_data 21 | self.unigrams = torch.ones(self.nLem) 22 | self.K = int(K) 23 | self.feat_dim = int(feat_dim) 24 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum])) 25 | 26 | stdv = 1. / math.sqrt(feat_dim / 3) 27 | self.register_buffer('memory_bank', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv)) 28 | self.l2norm = Normalise(2) 29 | 30 | 31 | @torch.no_grad() 32 | def update_memorybank(self, features, index): 33 | with torch.no_grad(): 34 | self.memory_bank.index_copy_(0, index.view(-1), features) 35 | 36 | 37 | def contrast(self, anchor_feature, pair_feature, negatives): 38 | Z_c = self.params[2].item() 39 | T = self.params[1].item() 40 | K = int(self.params[0].item()) 41 | 42 | batchSize = anchor_feature.size(0) 43 | # ===== Retrieve negatives from memory bank, reshape, insert positive into 0th index ===== 44 | weight_c = torch.index_select(self.memory_bank, 0, negatives.view(-1)).detach() 45 | weight_c = weight_c.view(batchSize, K + 1, self.feat_dim) 46 | weight_c[:,0] = pair_feature 47 | 48 | # ===== BMM between positive and negative features + positive pairing ===== 49 | out_c = torch.bmm(weight_c, anchor_feature.view(batchSize, self.feat_dim, 1)) 50 | out_c = torch.exp(torch.div(out_c, T)) 51 | 52 | if Z_c < 0: 53 | self.params[2] = out_c.mean() * self.memory_bank.size(0) 54 | Z_c = self.params[2].clone().detach().item() 55 | print("normalization constant Z_c is set to {:.1f}".format(Z_c)) 56 | 57 | out_c = torch.div(out_c, Z_c).contiguous() 58 | 59 | return out_c 60 | 61 | 62 | def forward(self, inputs, _ = None): 63 | 64 | anchor_feature = inputs['anchor_feature'] 65 | pair_feature= inputs['pair_feature'] 66 | negatives = inputs['negatives'] 67 | membank_idx = inputs['membank_idx'] 68 | 69 | out_c = self.contrast(anchor_feature, pair_feature, negatives) 70 | self.update_memorybank(anchor_feature, membank_idx) 71 | 72 | 73 | return out_c 74 | 75 | class NCEAverageKineticsMining(nn.Module): 76 | def __init__(self, feat_dim, n_data, K, max_hard_negatives_percentage = 1, T=0.07): 77 | super(NCEAverageKineticsMining, self).__init__() 78 | self.nLem = n_data 79 | self.unigrams = torch.ones(self.nLem) 80 | self.K = int(K) 81 | self.feat_dim = int(feat_dim) 82 | self.max_negs = int(self.K * max_hard_negatives_percentage) 83 | 84 | self.register_buffer('params', torch.tensor([K, T, -1])) 85 | 86 | stdv = 1. 
/ math.sqrt(feat_dim / 3) 87 | self.register_buffer('memory_bank', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv)) 88 | self.l2norm = Normalise(2) 89 | 90 | def update_memorybank(self, features, index): 91 | self.memory_bank.index_copy_(0, index.view(-1), features) 92 | return None 93 | 94 | 95 | def contrast(self, anchor_feature, pair_feature, negatives): 96 | Z_c = self.params[2].item() 97 | T = self.params[1].item() 98 | K = int(self.params[0].item()) 99 | batchSize = anchor_feature.size(0) 100 | 101 | # ===== Retrieve negatives from memory bank, reshape, insert positive into 0th index ===== 102 | weight_c = torch.index_select(self.memory_bank, 0, negatives.view(-1)).detach() 103 | weight_c = weight_c.view(batchSize, K + 1, self.feat_dim) 104 | weight_c[:,0] = pair_feature 105 | 106 | # ===== BMM between positive and negative features + positive pairing ===== 107 | out_c = torch.bmm(weight_c, anchor_feature.view(batchSize, self.feat_dim, 1)) 108 | out_c = torch.exp(torch.div(out_c, T)) 109 | 110 | if Z_c < 0: 111 | self.params[2] = out_c.mean() * self.memory_bank.size(0) 112 | Z_c = self.params[2].clone().detach().item() 113 | print("normalization constant Z_c is set to {:.1f}".format(Z_c)) 114 | 115 | out_c = torch.div(out_c, Z_c).contiguous() 116 | 117 | return out_c 118 | 119 | def mine_negative_examples(self, anchor_feature, threshold, membank_idx): 120 | with torch.no_grad(): 121 | bs = anchor_feature.size(0) 122 | cosine_scores = torch.matmul(anchor_feature, self.memory_bank.view(self.feat_dim, -1)) 123 | cosine_mask = cosine_scores >= threshold 124 | cosine_scores[cosine_mask] = -2 125 | 126 | scores, indices = cosine_scores.topk(k=self.K) 127 | negative_indices = torch.zeros(bs, self.K + 1).cuda() 128 | 129 | for rowid in range(bs): 130 | # Get all the hard examples, determine how many random examples are required 131 | row_topk = scores[rowid] # Top K values, sorted 132 | row_indices = indices[rowid] # Top K indices, associated 133 | row_idx = membank_idx[rowid] 134 | hard_examples = min((row_topk != -2).sum().item(), self.max_negs) 135 | rand_examples = self.K - hard_examples 136 | 137 | if hard_examples >= self.K: 138 | # Enough hard examples to use for all negatives samples 139 | negative_indices[rowid, 1:] = row_indices 140 | elif hard_examples == 0: 141 | probs = torch.ones(self.nLem) 142 | probs[row_idx] = 0 143 | rand_indices = torch.multinomial(probs, self.K, replacement = True) 144 | negative_indices[rowid, 1:] = rand_indices 145 | else: 146 | # Mix of hard and random examples 147 | negative_indices[rowid, 1:hard_examples + 1] = row_indices[:hard_examples] 148 | probs = torch.ones(self.nLem) 149 | probs[row_indices[:hard_examples]] = 0 # Don't sample hard negatives a second time 150 | probs[row_idx] = 0 151 | rand_indices = torch.multinomial(probs, self.K - hard_examples, replacement = True) 152 | 153 | negative_indices[rowid, hard_examples + 1:] = rand_indices 154 | 155 | return negative_indices.long().cuda() 156 | 157 | def forward(self, inputs, threshold): 158 | 159 | anchor_feature = inputs['anchor_feature'] 160 | pair_feature= inputs['pair_feature'] 161 | membank_idx = inputs['membank_idx'] 162 | 163 | negatives = self.mine_negative_examples(anchor_feature, threshold, membank_idx) 164 | out_c = self.contrast(anchor_feature, pair_feature, negatives) 165 | self.update_memorybank(anchor_feature, membank_idx) 166 | 167 | 168 | return out_c 169 | -------------------------------------------------------------------------------- /loss/NCECriterion.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | eps = 1e-7 5 | 6 | 7 | class NCECriterion(nn.Module): 8 | """ 9 | Eq. (12): L_{NCE} 10 | """ 11 | def __init__(self, n_data): 12 | super(NCECriterion, self).__init__() 13 | self.n_data = n_data 14 | 15 | def forward(self, x): 16 | 17 | bsz = x.shape[0] 18 | m = x.size(1) - 1 19 | 20 | # noise distribution 21 | Pn = 1 / float(self.n_data) 22 | 23 | # loss for positive pair 24 | P_pos = x.select(1, 0) 25 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_() 26 | 27 | # loss for K negative pair 28 | P_neg = x.narrow(1, 1, m) 29 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_() 30 | 31 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz 32 | 33 | 34 | return loss 35 | 36 | class NCEAvgCriterion(nn.Module): 37 | def __init__(self, n_data): 38 | super(NCEAvgCriterion, self).__init__() 39 | self.n_data = n_data 40 | 41 | def forward(self, x): 42 | 43 | bsz = x.shape[0] 44 | m = x.size(1) - 1 45 | 46 | # noise distribution 47 | Pn = 1 / float(self.n_data) 48 | 49 | # loss for positive pair 50 | P_pos = x.select(1, 0) 51 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_() 52 | 53 | # loss for K negative pair 54 | P_neg = x.narrow(1, 1, m) 55 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_() 56 | 57 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).mean(0)) / bsz 58 | 59 | 60 | return loss 61 | 62 | class NCEMaxCriterion(nn.Module): 63 | def __init__(self, n_data): 64 | super(NCEMaxCriterion, self).__init__() 65 | self.n_data = n_data 66 | 67 | def forward(self, x): 68 | 69 | bsz = x.shape[0] 70 | m = x.size(1) - 1 71 | 72 | # noise distribution 73 | Pn = 1 / float(self.n_data) 74 | 75 | # loss for positive pair 76 | P_pos = x.select(1, 0) 77 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_() 78 | 79 | # loss for K negative pair 80 | P_neg = x.narrow(1, 1, m) 81 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_() 82 | log_D0_max, log_D0_indexes = log_D0.view(-1, 1).max(0) 83 | 84 | loss = - (log_D1.sum(0) + log_D0_max) / bsz 85 | 86 | 87 | return loss 88 | 89 | class ContrastCriterion(nn.Module): 90 | def __init__(self, n_data): 91 | super(ContrastCriterion, self).__init__() 92 | self.n_data = n_data 93 | 94 | def forward(self, x): 95 | 96 | bsz = x.shape[0] 97 | m = x.size(1) - 1 98 | 99 | # noise distribution 100 | Pn = 1 / float(self.n_data) 101 | 102 | # loss for positive pair 103 | P_pos = x.select(1, 0) 104 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_() 105 | 106 | # loss for K negative pair 107 | P_neg = x.narrow(1, 1, m) 108 | log_D0 = torch.avg(torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps))).log_() 109 | 110 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz 111 | 112 | 113 | return loss -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from loss.NCECriterion import NCECriterion 4 | from loss.NCEAverage import * 5 | 6 | 7 | def get_loss(cfg, n_data): 8 | if cfg.LOSS.PRETRAINING.MINING == False: 9 | average = NCEAverageKinetics( 10 | feat_dim = cfg.MODEL.PRETRAINING.FC_DIM, 11 | n_data = n_data, 12 | K = cfg.LOSS.PRETRAINING.NCE.NEGATIVES, 13 | T = cfg.LOSS.PRETRAINING.NCE.TEMPERATURE, 14 | momentum = 
cfg.LOSS.PRETRAINING.NCE.MOMENTUM 15 | ) 16 | else: 17 | average = NCEAverageKineticsMining( 18 | feat_dim = cfg.MODEL.PRETRAINING.FC_DIM, 19 | n_data = n_data, 20 | K = cfg.LOSS.PRETRAINING.NCE.NEGATIVES, 21 | T = cfg.LOSS.PRETRAINING.NCE.TEMPERATURE, 22 | max_hard_negatives_percentage = cfg.LOSS.PRETRAINING.MINING.MAX_HARD_NEGATIVES_PERCENTAGE 23 | ) 24 | criterion = NCECriterion(n_data) 25 | 26 | TCELoss = NCELoss(average, criterion) 27 | RotationLoss = nn.CrossEntropyLoss() 28 | 29 | return TCELoss, RotationLoss 30 | 31 | class NCELoss(nn.Module): 32 | def __init__(self, Average, Criterion): 33 | super(NCELoss, self).__init__() 34 | self.Average = Average 35 | self.Criterion = Criterion 36 | 37 | def forward(self, inputs, threshold): 38 | NCE_Average = self.Average(inputs, threshold) 39 | NCE_loss = self.Criterion(NCE_Average) 40 | return NCE_loss -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PreTrainNet, FineTuneNet -------------------------------------------------------------------------------- /network/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import get_backbone 4 | 5 | class Normalise(nn.Module): 6 | 7 | def __init__(self, power=2): 8 | super(Normalise, self).__init__() 9 | self.power = power 10 | 11 | def forward(self, x): 12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 13 | out = x.div(norm) 14 | return out 15 | 16 | class PreTrainNet(nn.Module): 17 | def __init__(self, cfg): 18 | super(PreTrainNet, self).__init__() 19 | trunk = cfg.MODEL.TRUNK 20 | fc_dim = cfg.MODEL.PRETRAINING.FC_DIM 21 | 22 | self.backbone, backbone_channels = get_backbone(trunk) 23 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 24 | self.fc_contrastive = nn.Linear(backbone_channels, fc_dim) 25 | self.l2norm = Normalise(2) 26 | 27 | self.fc_rotation_1 = nn.Linear(backbone_channels, 200) 28 | self.bn_rotation_1 = nn.BatchNorm1d(num_features = 200) 29 | self.fc_rotation_2 = nn.Linear(200, 200) 30 | self.bn_rotation_2 = nn.BatchNorm1d(num_features = 200) 31 | self.fc_rotation_3 = nn.Linear(200,4) 32 | 33 | 34 | 35 | def forward(self, x, rotation = False): 36 | x = self.backbone(x) 37 | x = self.avgpool(x) 38 | x = torch.flatten(x, 1) 39 | if rotation == False: 40 | x = self.fc_contrastive(x) 41 | x = self.l2norm(x) 42 | return x 43 | elif rotation == True: 44 | x = self.fc_rotation_1(x) 45 | x = self.backbone.relu(x) 46 | x = self.bn_rotation_1(x) 47 | 48 | x = self.fc_rotation_2(x) 49 | x = self.backbone.relu(x) 50 | x = self.bn_rotation_2(x) 51 | 52 | x = self.fc_rotation_3(x) 53 | return x 54 | else: 55 | raise ValueError('Rotation is either a Boolean True or False') 56 | 57 | class FineTuneNet(nn.Module): 58 | def __init__(self, cfg): 59 | super(FineTuneNet, self).__init__() 60 | trunk = cfg.MODEL.TRUNK 61 | num_classes = cfg.MODEL.FINETUNING.NUM_CLASSES 62 | assert num_classes > 0 and isinstance(num_classes, int), 'Please give a positive integer for the number of classes in the finetuning stage' 63 | p_dropout = cfg.MODEL.FINETUNING.DROPOUT 64 | 65 | self.backbone, backbone_channels = get_backbone(trunk) 66 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 67 | self.dropout = nn.Dropout(p = p_dropout) 68 | self.fc = nn.Linear(backbone_channels, num_classes) 69 | 70 | def forward(self, x): 71 | x = self.backbone(x) 72 | x = 
self.avgpool(x) 73 | x = torch.flatten(x,1) 74 | x = self.dropout(x) 75 | x = self.fc(x) 76 | 77 | return x 78 | -------------------------------------------------------------------------------- /network/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | 96 | def __init__(self, block, layers, in_channel=3, width=1): 97 | self.inplanes = 64 98 | super(ResNet, self).__init__() 99 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, 100 | bias=False) 101 | self.bn1 = nn.BatchNorm2d(64) 102 | self.relu = nn.ReLU(inplace=True) 103 | 104 | self.base = int(64 * width) 105 | 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | self.layer1 = self._make_layer(block, self.base, layers[0]) 108 | self.layer2 = 
self._make_layer(block, self.base * 2, layers[1], stride=2) 109 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 110 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 111 | self.avgpool = nn.AvgPool2d(7, stride=1) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x, layer=7): 139 | if layer <= 0: 140 | return x 141 | x = self.conv1(x) 142 | x = self.bn1(x) 143 | x = self.relu(x) 144 | 145 | x = self.maxpool(x) 146 | if layer == 1: 147 | return x 148 | x = self.layer1(x) 149 | if layer == 2: 150 | return x 151 | x = self.layer2(x) 152 | if layer == 3: 153 | return x 154 | x = self.layer3(x) 155 | if layer == 4: 156 | return x 157 | x = self.layer4(x) 158 | if layer == 5: 159 | return x 160 | return x 161 | 162 | 163 | def resnet18(pretrained=False, **kwargs): 164 | """Constructs a ResNet-18 model. 165 | Args: 166 | pretrained (bool): If True, returns a model pre-trained on ImageNet 167 | """ 168 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 169 | if pretrained: 170 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 171 | return model 172 | 173 | 174 | def resnet34(pretrained=False, **kwargs): 175 | """Constructs a ResNet-34 model. 176 | Args: 177 | pretrained (bool): If True, returns a model pre-trained on ImageNet 178 | """ 179 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 180 | if pretrained: 181 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 182 | return model 183 | 184 | 185 | def resnet50(pretrained=False, **kwargs): 186 | """Constructs a ResNet-50 model. 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | """ 190 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 191 | if pretrained: 192 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 193 | return model 194 | 195 | 196 | def resnet101(pretrained=False, **kwargs): 197 | """Constructs a ResNet-101 model. 198 | Args: 199 | pretrained (bool): If True, returns a model pre-trained on ImageNet 200 | """ 201 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 202 | if pretrained: 203 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 204 | return model 205 | 206 | 207 | def resnet152(pretrained=False, **kwargs): 208 | """Constructs a ResNet-152 model. 
209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | """ 212 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 213 | if pretrained: 214 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 215 | return model 216 | 217 | 218 | class InsResNet50(nn.Module): 219 | """Encoder for instance discrimination and MoCo""" 220 | def __init__(self, width=1): 221 | super(InsResNet50, self).__init__() 222 | self.encoder = resnet50(width=width) 223 | self.encoder = nn.DataParallel(self.encoder) 224 | 225 | def forward(self, x, layer=7): 226 | return self.encoder(x, layer) 227 | 228 | 229 | class ResNetV1(nn.Module): 230 | def __init__(self, name='resnet50'): 231 | super(ResNetV1, self).__init__() 232 | if name == 'resnet50': 233 | self.l_to_ab = resnet50(in_channel=1, width=0.5) 234 | self.ab_to_l = resnet50(in_channel=2, width=0.5) 235 | elif name == 'resnet18': 236 | self.l_to_ab = resnet18(in_channel=1, width=0.5) 237 | self.ab_to_l = resnet18(in_channel=2, width=0.5) 238 | elif name == 'resnet101': 239 | self.l_to_ab = resnet101(in_channel=1, width=0.5) 240 | self.ab_to_l = resnet101(in_channel=2, width=0.5) 241 | else: 242 | raise NotImplementedError('model {} is not implemented'.format(name)) 243 | 244 | def forward(self, x, layer=7): 245 | l, ab = torch.split(x, [1, 2], dim=1) 246 | feat_l = self.l_to_ab(l, layer) 247 | feat_ab = self.ab_to_l(ab, layer) 248 | return feat_l, feat_ab 249 | 250 | 251 | class ResNetV2(nn.Module): 252 | def __init__(self, name='resnet50'): 253 | super(ResNetV2, self).__init__() 254 | if name == 'resnet50': 255 | self.l_to_ab = resnet50(in_channel=1, width=1) 256 | self.ab_to_l = resnet50(in_channel=2, width=1) 257 | elif name == 'resnet18': 258 | self.l_to_ab = resnet18(in_channel=1, width=1) 259 | self.ab_to_l = resnet18(in_channel=2, width=1) 260 | elif name == 'resnet101': 261 | self.l_to_ab = resnet101(in_channel=1, width=1) 262 | self.ab_to_l = resnet101(in_channel=2, width=1) 263 | else: 264 | raise NotImplementedError('model {} is not implemented'.format(name)) 265 | 266 | def forward(self, x, layer=7): 267 | l, ab = torch.split(x, [1, 2], dim=1) 268 | feat_l = self.l_to_ab(l, layer) 269 | feat_ab = self.ab_to_l(ab, layer) 270 | return feat_l, feat_ab 271 | 272 | 273 | class ResNetV3(nn.Module): 274 | def __init__(self, name='resnet50'): 275 | super(ResNetV3, self).__init__() 276 | if name == 'resnet50': 277 | self.l_to_ab = resnet50(in_channel=1, width=2) 278 | self.ab_to_l = resnet50(in_channel=2, width=2) 279 | elif name == 'resnet18': 280 | self.l_to_ab = resnet18(in_channel=1, width=2) 281 | self.ab_to_l = resnet18(in_channel=2, width=2) 282 | elif name == 'resnet101': 283 | self.l_to_ab = resnet101(in_channel=1, width=2) 284 | self.ab_to_l = resnet101(in_channel=2, width=2) 285 | else: 286 | raise NotImplementedError('model {} is not implemented'.format(name)) 287 | 288 | def forward(self, x, layer=7): 289 | l, ab = torch.split(x, [1, 2], dim=1) 290 | feat_l = self.l_to_ab(l, layer) 291 | feat_ab = self.ab_to_l(ab, layer) 292 | return feat_l, feat_ab 293 | 294 | 295 | class MyResNetsCMC(nn.Module): 296 | def __init__(self, name='resnet50v1'): 297 | super(MyResNetsCMC, self).__init__() 298 | if name.endswith('v1'): 299 | self.encoder = ResNetV1(name[:-2]) 300 | elif name.endswith('v2'): 301 | self.encoder = ResNetV2(name[:-2]) 302 | elif name.endswith('v3'): 303 | self.encoder = ResNetV3(name[:-2]) 304 | else: 305 | raise NotImplementedError('model not support: {}'.format(name)) 306 | 307 | 
self.encoder = nn.DataParallel(self.encoder) 308 | 309 | def forward(self, x, layer=7): 310 | return self.encoder(x, layer) 311 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import argparse 5 | import logging 6 | import torch 7 | import numpy as np 8 | from tensorboardX import SummaryWriter 9 | 10 | 11 | from config import config, update_config 12 | from network import * 13 | from datasets.pretraining import get_pretraining_dataset 14 | from loss import get_loss 15 | from utils import * 16 | 17 | 18 | torch.manual_seed(0) 19 | np.random.seed(0) 20 | torch.backends.cudnn.deterministic = True 21 | 22 | # TODO 23 | # Test Learning Rate Decay 24 | # Test end-of-epoch 25 | # Test Loading Checkpoints 26 | # Train end-to-end 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description='Train TCE Self-Supervised') 30 | parser.add_argument('--cfg', help = 'Path to config file', type = str, default = None) 31 | parser.add_argument('opts', help = 'Modify config using the command line', 32 | default = None, nargs=argparse.REMAINDER ) 33 | args = parser.parse_args() 34 | update_config(config, args) 35 | 36 | return args 37 | 38 | def train(epoch, train_loader, model, NCELoss, RotationLoss, optimizer, config, tboard, logger): 39 | 40 | # ===== Get total steps, set up meters ===== 41 | total_steps = config.TRAIN.PRETRAINING.EPOCHS * len(train_loader) 42 | model.train() 43 | NCELoss.train() 44 | 45 | batch_time = AverageMeter() 46 | 47 | MainLossMeter = AverageMeter() 48 | RotationLossMeter = AverageMeter() 49 | TotalLossMeter = AverageMeter() 50 | RotationAccMeter = AverageMeter() 51 | 52 | tl = config.LOSS.PRETRAINING.MINING.THRESH_LOW 53 | th = config.LOSS.PRETRAINING.MINING.THRESH_HIGH 54 | tr = config.LOSS.PRETRAINING.MINING.THRESH_RATE 55 | 56 | for idx, inputs in enumerate(train_loader): 57 | end = time.time() 58 | # ===== Prepare data and get threshold ====== 59 | anchor_tensor = inputs['anchor_tensor'].cuda() 60 | pair_tensor = inputs['pair_tensor'].cuda() 61 | rotation_tensor = inputs['rotation_tensor'].cuda() 62 | rotation_gt = inputs['rotation_gt'].cuda() 63 | inputs['negatives'] = inputs['negatives'].cuda() 64 | inputs['membank_idx'] = inputs['membank_idx'].cuda() 65 | 66 | step = (epoch - 1) * len(train_loader) + idx + 1 67 | threshold = tl + (th - tl) * (1 - math.exp(tr * step / total_steps)) 68 | 69 | # ===== Forward ===== 70 | bsz = anchor_tensor.size(0) 71 | inputs['anchor_feature'] = model(anchor_tensor, rotation = False) 72 | inputs['pair_feature'] = model(pair_tensor, rotation = False) 73 | rotation_feature = model(rotation_tensor, rotation = True) 74 | 75 | main_loss = NCELoss(inputs, threshold) 76 | with torch.no_grad(): 77 | rotation_indexes = rotation_feature.max(1)[1] 78 | rotation_accuracy = int(torch.sum(torch.eq(rotation_indexes, rotation_gt).long())) / bsz 79 | rotation_loss = RotationLoss(rotation_feature, rotation_gt).mul(config.LOSS.PRETRAINING.ROTATION_WEIGHT) 80 | total_loss = main_loss + rotation_loss 81 | 82 | # ===== Backward ===== 83 | optimizer.zero_grad() 84 | total_loss.backward() 85 | optimizer.step() 86 | 87 | # ===== Update Meters ===== 88 | batch_time.update(time.time() - end, bsz) 89 | MainLossMeter.update(main_loss.item(), bsz) 90 | RotationLossMeter.update(rotation_loss.item(), bsz) 91 | TotalLossMeter.update(total_loss.item(), bsz) 92 | 
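# NOTE (editorial, illustrative only; these numbers are hypothetical and not taken from the
# config): the mining threshold computed earlier in this loop ramps from THRESH_LOW towards
# THRESH_HIGH over training when THRESH_RATE is negative. For example, with tl = 0.3,
# th = 0.7 and tr = -5 the threshold is 0.3 at step 0, roughly
# 0.3 + 0.4 * (1 - exp(-2.5)) ~= 0.67 halfway through training, and approaches 0.7 by the
# final step; memory-bank entries whose cosine similarity to the anchor reaches this value
# are masked out before hard negatives are selected.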
RotationAccMeter.update(rotation_accuracy, bsz) 93 | 94 | # ===== Print and Update Tensorboards 95 | if idx % config.PRINT_FREQ == 0: 96 | logger.info('Train: [{0}][{1}/{2}]\t' 97 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 98 | 'Main Loss: {mainloss.val:.3f} ({mainloss.avg:.3f})\t' 99 | 'Rotation Loss {rotloss.val:.3f} ({rotloss.avg:.3f})\t' 100 | 'Total Loss {totloss.val:.3f} ({totloss.avg:.3f})\t'.format( 101 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 102 | mainloss = MainLossMeter, 103 | rotloss = RotationLossMeter, 104 | totloss = TotalLossMeter, 105 | )) 106 | tboard.add_scalars("Main Loss", 107 | {"Absolute":MainLossMeter.val, "Average":MainLossMeter.avg}, step) 108 | tboard.add_scalars("Rotation Loss", 109 | {"Absolute":RotationLossMeter.val, "Average":RotationLossMeter.avg}, step) 110 | tboard.add_scalars("Total Loss", 111 | {"Absolute":TotalLossMeter.val, "Average":TotalLossMeter.avg}, step) 112 | tboard.add_scalars("Rotation Accuracy", 113 | {"Absolute":RotationAccMeter.val, "Average":RotationAccMeter.avg}, step) 114 | 115 | return_dic = { 116 | 'Main Loss' : {'Average' : MainLossMeter.avg}, 117 | 'Rotation Loss' : {'Average' : RotationLossMeter.avg}, 118 | 'Total Loss' : {'Average' : TotalLossMeter.avg}, 119 | 'Rotation Accuracy' : {'Average' : RotationAccMeter.avg}, 120 | } 121 | 122 | return return_dic 123 | 124 | 125 | def main(): 126 | args = parse_args() 127 | logger = setup_logger() 128 | logger.info(config) 129 | if not os.path.exists(config.ASSETS_PATH): 130 | os.makedirs(config.ASSETS_PATH) 131 | 132 | # ===== Create the dataloader ===== 133 | train_loader, n_data = get_pretraining_dataset(config) 134 | logger.info('Training with {} Train Samples'.format(n_data)) 135 | 136 | # ===== Create the model ===== 137 | model = PreTrainNet(config) 138 | logger.info('Built Model, using {} backbone'.format(config.MODEL.TRUNK)) 139 | if torch.cuda.device_count() > 1: 140 | model = torch.nn.DataParallel(model).cuda() 141 | else: 142 | model = model.cuda() 143 | logger.info('Training on {} GPUs'.format(torch.cuda.device_count())) 144 | 145 | # ===== Set the optimizer ===== 146 | optimizer = get_optimizer(model, config, pretraining = True ) 147 | 148 | # ===== Get the loss ===== 149 | NCELoss, RotationLoss = get_loss(config, n_data) 150 | NCELoss = NCELoss.cuda() 151 | RotationLoss = RotationLoss.cuda() 152 | 153 | # ===== Resume from am earlier checkpoint ===== 154 | start_epoch = 1 155 | if config.TRAIN.PRETRAINING.RESUME: 156 | try: 157 | checkpoint = torch.load(config.TRAIN.PRETRAINING.RESUME) 158 | except FileNotFoundError: 159 | raise FileNotFoundError('No Checkpoint found at path {}'.format(config.TRAIN.PRETRAINING.RESUME)) 160 | 161 | start_epoch = checkpoint['epoch'] + 1 162 | # ===== Align checkpoint keys with model ===== 163 | if 'module' in list(checkpoint['state_dict'].keys())[0] and 'module' not in list(model.state_dict().keys())[0]: 164 | checkpoint['state_dict'] = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items() } 165 | elif 'module' not in list(checkpoint['state_dict'].keys())[0] and 'module' in list(model.state_dict().keys())[0]: 166 | checkpoint['state_dict'] = {'module.' 
+ k:v for k,v in checkpoint['state_dict'].items() } 167 | 168 | model.load_state_dict(checkpoint['state_dict']) 169 | optimizer.load_state_dict(checkpoint['optimizer']) 170 | NCELoss.load_state_dict(checkpoint['NCELoss']) 171 | logger.info('Loaded Checkpoint from "{}"'.format(config.TRAIN.PRETRAINING.RESUME)) 172 | else: 173 | logger.info('Training from Random Initialisation') 174 | 175 | # ===== Set up Save Directory and TensorBoard ===== 176 | assert config.TRAIN.PRETRAINING.SAVEDIR, 'Please specify save directory for model' 177 | if not os.path.exists(config.TRAIN.PRETRAINING.SAVEDIR): 178 | os.makedirs(config.TRAIN.PRETRAINING.SAVEDIR) 179 | os.makedirs(os.path.join(config.TRAIN.PRETRAINING.SAVEDIR, 'checkpoints')) 180 | os.makedirs(os.path.join(config.TRAIN.PRETRAINING.SAVEDIR, 'tboard')) 181 | 182 | tboard = SummaryWriter(logdir = os.path.join(config.TRAIN.PRETRAINING.SAVEDIR, 'tboard')) 183 | 184 | # ===== Train Loop ===== 185 | logger.info('Begin training') 186 | for epoch in range(start_epoch, config.TRAIN.PRETRAINING.EPOCHS + 1): 187 | adjust_learning_rate(epoch, config, optimizer, pretraining = True, logger = logger) 188 | logger.info('Training Epoch {}'.format(epoch)) 189 | 190 | return_dic = train( 191 | epoch = epoch, 192 | train_loader = train_loader, 193 | model = model, 194 | NCELoss = NCELoss, 195 | RotationLoss = RotationLoss, 196 | optimizer = optimizer, 197 | config = config, 198 | tboard = tboard, 199 | logger = logger, 200 | ) 201 | 202 | for key, value in return_dic.items(): 203 | tboard.add_scalars(key, value, epoch) 204 | 205 | if epoch % config.TRAIN.PRETRAINING.SAVE_FREQ == 0: 206 | state = { 207 | 'config': config, 208 | 'state_dict': model.state_dict(), 209 | 'optimizer': optimizer.state_dict(), 210 | 'NCELoss': NCELoss.state_dict(), 211 | 'epoch': epoch, 212 | } 213 | save_file = os.path.join(config.TRAIN.PRETRAINING.SAVEDIR, 'checkpoints', 'ckpt_epoch_{}.pth'.format(epoch)) 214 | logger.info('Saved Checkpoint to {}'.format(save_file)) 215 | torch.save(state, save_file) 216 | 217 | if __name__ == '__main__': 218 | main() 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimizer import * 2 | from .log_utils import * 3 | from .tsne_utils import * 4 | from .network_utils import * 5 | from .dataset_utils import * 6 | from .eval_utils import * -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os, pickle 2 | 3 | class UCF101_splitter(): 4 | def __init__(self, root, split): 5 | self.root = root 6 | self.split = split 7 | 8 | def get_action_index(self): 9 | ''' 10 | Create a dictionary relating integer labels (e.g. 1, 2, ... 
101) to their respective action classes 11 | ''' 12 | self.action_label = {} 13 | 14 | # ===== Open class index text file and retrieve lines ===== 15 | with open(os.path.join(self.root,'classInd.txt')) as f: 16 | content = f.readlines() 17 | content = [x.strip('\r\n') for x in content] 18 | f.close() 19 | 20 | # ===== Process the text file to associate labels with classes ===== 21 | for line in content: 22 | label,action = line.split(' ') 23 | if action not in self.action_label.keys(): 24 | self.action_label[action]=label 25 | 26 | return None 27 | 28 | 29 | def split_video(self): 30 | ''' 31 | Create dictionaries for the train and testing splits, relating the video name to the paths to the frames contained within that video 32 | ''' 33 | # ===== Create dictionary relating labels to action classes ===== 34 | self.get_action_index() 35 | train_video = None 36 | for path,subdir,files in os.walk(self.root): 37 | for filename in files: 38 | # ===== Create dictionary for train split ====== 39 | if filename.split('.')[0] == 'trainlist{:02d}'.format(self.split): 40 | train_video = self.file2_dic(os.path.join(self.root,filename)) 41 | 42 | # ===== Create dictionary for test split ===== 43 | if filename.split('.')[0] == 'testlist{:02d}'.format(self.split): 44 | test_video = self.file2_dic(os.path.join(self.root,filename)) 45 | 46 | # ===== Correct issue with handstandpushups class : Inconsistency between text file and folder name capitalisation ===== 47 | self.train_video = self.name_HandstandPushups(train_video) 48 | self.test_video = self.name_HandstandPushups(test_video) 49 | return self.train_video, self.test_video 50 | 51 | def file2_dic(self,fname): 52 | ''' 53 | Get a text file containing the list of videos, and return a dictionary relating the video name to the paths to the frames in the video 54 | 55 | Arguments: 56 | fname : Path to the text file containing a list of videos in the split 57 | ''' 58 | 59 | # ===== Create a list of videonames from the file ===== 60 | with open(fname) as f: 61 | content = f.readlines() 62 | content = [x.strip('\r\n') for x in content] 63 | f.close() 64 | 65 | # ===== Create a dictionary relating every video name to a list of paths to the frames within that video ===== 66 | dic={} 67 | for line in content: 68 | video = line.split('/',1)[1].split(' ',1)[0] 69 | key = 'v_' + video.split('_',1)[1].split('.',1)[0] 70 | label = self.action_label[line.split('/')[0]] 71 | dic[key] = int(label) 72 | return dic 73 | 74 | def name_HandstandPushups(self,dic): 75 | ''' 76 | Account for a discrepancy between the capitalisation of HandstandPushups in the split file and in the files in the dataset 77 | ''' 78 | dic = {k.replace('HandStandPushups', 'HandstandPushups'):v for k,v in dic.items()} 79 | return dic -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | def accuracy(output, target, topk=(1,)): 2 | """Computes the precision@k for the specified values of k""" 3 | maxk = max(topk) 4 | batch_size = target.size(0) 5 | 6 | _, pred = output.topk(maxk, 1, True, True) 7 | pred = pred.t() 8 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 9 | 10 | res = [] 11 | for k in topk: 12 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 13 | res.append(correct_k.mul_(100.0 / batch_size)) 14 | return res -------------------------------------------------------------------------------- /utils/log_utils.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | from termcolor import colored 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | 26 | class _ColorfulFormatter(logging.Formatter): 27 | def __init__(self, *args, **kwargs): 28 | self._root_name = kwargs.pop("root_name") + "." 29 | self._abbrev_name = kwargs.pop("abbrev_name", "") 30 | if len(self._abbrev_name): 31 | self._abbrev_name = self._abbrev_name + "." 32 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 33 | 34 | def formatMessage(self, record): 35 | record.name = record.name.replace(self._root_name, self._abbrev_name) 36 | log = super(_ColorfulFormatter, self).formatMessage(record) 37 | if record.levelno == logging.WARNING: 38 | prefix = colored("WARNING", "red", attrs=["blink"]) 39 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 40 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 41 | else: 42 | return log 43 | return prefix + " " + log 44 | 45 | 46 | def setup_logger(color=True, name="TCE", abbrev_name=None): 47 | logger = logging.getLogger(name) 48 | logger.setLevel(logging.DEBUG) 49 | logger.propagate = False 50 | 51 | if abbrev_name is None: 52 | abbrev_name = name 53 | 54 | plain_formatter = logging.Formatter( 55 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 56 | ) 57 | # stdout logging: master only 58 | ch = logging.StreamHandler(stream=sys.stdout) 59 | ch.setLevel(logging.DEBUG) 60 | if color: 61 | formatter = _ColorfulFormatter( 62 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 63 | datefmt="%m/%d %H:%M:%S", 64 | root_name=name, 65 | abbrev_name=str(abbrev_name), 66 | ) 67 | else: 68 | formatter = plain_formatter 69 | ch.setFormatter(formatter) 70 | logger.addHandler(ch) 71 | 72 | return logger 73 | -------------------------------------------------------------------------------- /utils/network_utils.py: -------------------------------------------------------------------------------- 1 | from network.resnet import * 2 | 3 | def get_backbone(trunk): 4 | if trunk == 'resnet18': 5 | backbone = resnet18() 6 | backbone_channels = 512 7 | elif trunk == 'resnet34': 8 | backbone = resnet34() 9 | backbone_channels = 512 10 | elif trunk == 'resnet50': 11 | backbone = resnet50() 12 | backbone_channels = 2048 13 | elif trunk == 'resnet101': 14 | backbone = resnet50() 15 | backbone_channels = 2048 16 | else: 17 | raise NotImplementedError('Backbone "{}" not currently supported'.format(trunk)) 18 | return backbone, backbone_channels 19 | 20 | def forgiving_load_state_dict(state_dict, model, logger): 21 | keys_dict = state_dict.keys() 22 | keys_model = model.state_dict().keys() 23 | 24 | keys_shared = [k for k in keys_dict if k in keys_model] 25 | keys_missing = [k for k in keys_model if k not in keys_dict] 26 | keys_unexpected = [k for k in keys_dict if k not in keys_model] 27 | 28 | load_dict = {k:v for k,v in state_dict.items() if k in keys_shared} 29 | model.load_state_dict(load_dict, strict = False) 30 | 31 | logger.info('Missing Keys in checkpoint : ') 32 | for 
k in keys_missing: 33 | logger.info('\t\t{}'.format(k)) 34 | logger.info('Unexpected Keys in checkpoint : ') 35 | for k in keys_unexpected: 36 | logger.info('\t\t{}'.format(k)) 37 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def get_optimizer(model, config, pretraining): # TODO Move to Utils 5 | if pretraining == True: 6 | lr = config.TRAIN.PRETRAINING.LEARNING_RATE 7 | momentum = config.TRAIN.PRETRAINING.MOMENTUM 8 | weight_decay = config.TRAIN.PRETRAINING.WEIGHT_DECAY 9 | elif pretraining == False: 10 | lr = config.TRAIN.FINETUNING.LEARNING_RATE 11 | momentum = config.TRAIN.FINETUNING.MOMENTUM 12 | weight_decay = config.TRAIN.FINETUNING.WEIGHT_DECAY 13 | 14 | optimizer = torch.optim.SGD( 15 | model.parameters(), 16 | lr = lr, 17 | momentum = momentum, 18 | weight_decay = weight_decay 19 | ) 20 | 21 | return optimizer 22 | 23 | def adjust_learning_rate(epoch, cfg, optimizer, pretraining, logger): 24 | """Sets the learning rate to the initial LR decayed by the given rate every steep step""" 25 | if pretraining == True: 26 | decay_epochs = cfg.TRAIN.PRETRAINING.DECAY_EPOCHS 27 | lr = cfg.TRAIN.PRETRAINING.LEARNING_RATE 28 | decay_factor = cfg.TRAIN.PRETRAINING.DECAY_FACTOR 29 | elif pretraining == False: 30 | decay_epochs = cfg.TRAIN.FINETUNING.DECAY_EPOCHS 31 | lr = cfg.TRAIN.FINETUNING.LEARNING_RATE 32 | decay_factor = cfg.TRAIN.FINETUNING.DECAY_FACTOR 33 | 34 | 35 | steps = np.sum(epoch > np.asarray(decay_epochs)) 36 | if steps > 0: 37 | new_lr = lr * (decay_factor ** steps) 38 | logger.info('Learning rate for epoch set to {}'.format(new_lr)) 39 | for param_group in optimizer.param_groups: 40 | param_group['lr'] = new_lr 41 | else: 42 | logger.info('Learning rate for epoch set to {}'.format(lr)) 43 | -------------------------------------------------------------------------------- /utils/tsne_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import numpy as np 5 | from tqdm import tqdm 6 | from glob import glob 7 | 8 | from sklearn.manifold import TSNE 9 | from sklearn.decomposition import PCA 10 | 11 | import PIL.Image as Image 12 | from torchvision import transforms 13 | import matplotlib 14 | matplotlib.use('Agg') 15 | import matplotlib.pyplot as plt 16 | 17 | def create_gif_frames(reduced_embeddings, target, logger): 18 | 19 | # ===== Set up axis and points ===== 20 | fig = plt.figure() 21 | ax = fig.add_subplot(1,1,1) 22 | 23 | ax.axes.xaxis.set_visible(False) 24 | ax.axes.yaxis.set_visible(False) 25 | 26 | embeddings_y = reduced_embeddings[:,0] 27 | embeddings_x = reduced_embeddings[:,1] 28 | 29 | # ===== Set up frame loader ===== 30 | is_folder = os.path.isdir(target) 31 | if is_folder: 32 | imlist = sorted(glob(os.path.join(target, '*'))) 33 | n_images = len(imlist) 34 | else: 35 | reader = cv2.VideoCapture(target) 36 | n_images = int(reader.get(cv2.CAP_PROP_FRAME_COUNT)) 37 | assert reader.isOpened(), 'Target was neither directory or a valid video file' 38 | # ===== Iterate over frames ===== 39 | frames = [] 40 | logger.info('Beginning frame generation') 41 | 42 | for idx in tqdm(range(n_images)): 43 | ax.cla() 44 | # ===== Get frame ===== 45 | if is_folder: 46 | frame = cv2.imread(imlist[idx]) 47 | else: 48 | tf, frame = reader.read() 49 | if not tf: 50 | logger.info('Reached end of video at frame {} 
rather than {}'.format(x, n_images)) 51 | break 52 | h, w = frame.shape[:2] 53 | frame = frame[:,:,::-1] 54 | height, width = frame.shape[:2] 55 | # ===== Plot on graph, highlighting the current frame's embedding ===== 56 | x = list(embeddings_x[:idx]) + list(embeddings_x[idx+1:]) 57 | y = list(embeddings_y[:idx]) + list(embeddings_y[idx+1:]) 58 | ax.plot(x,y,',', color='b') 59 | ax.plot(embeddings_x[idx], embeddings_y[idx], 'o', color='r') 60 | fig.canvas.draw() 61 | 62 | tsne = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='').reshape(fig.canvas.get_width_height()[::-1] + (3,)) 63 | tsne = cv2.resize(tsne, (width,height)) 64 | frames.append(np.concatenate((frame, tsne), axis = 1)) 65 | 66 | return frames 67 | 68 | class get_blob: 69 | def __init__(self, cfg): 70 | self.resize = transforms.Resize(cfg.VISUALISATION.TSNE.RESIZE) 71 | self.crop = transforms.CenterCrop(cfg.VISUALISATION.TSNE.CROP_SIZE) 72 | self.totensor = transforms.ToTensor() 73 | self.normalize = transforms.Normalize( 74 | mean = cfg.DATASET.PRETRAINING.MEAN, 75 | std = cfg.DATASET.PRETRAINING.STD 76 | ) 77 | 78 | def __call__(self, frame): 79 | frame = Image.fromarray(frame) 80 | frame = self.resize(frame) 81 | frame = self.crop(frame) 82 | frame = self.totensor(frame) 83 | frame = self.normalize(frame) 84 | frame = frame.unsqueeze(0) 85 | frame = frame.cuda() 86 | 87 | return frame 88 | 89 | def fit_pca(input_data, logger, num_dims=0, threshold=0.85, max_error=0.05, debug=False): 90 | 91 | optimised = False 92 | # ===== Fit PCA to num_dims output dimensions of num_dims is not 0 ===== 93 | if num_dims > 0: 94 | pca = PCA(n_components=num_dims) 95 | # ===== Fit PCA ===== 96 | reduced = pca.fit_transform(input_data) 97 | explained_variance = np.sum(pca.explained_variance_ratio_) 98 | logger.info("PCA transform fitted. Explained variance in {} dims is {}.".format(num_dims, explained_variance)) 99 | optimised = True 100 | 101 | """ 102 | Get data input shape, set variables for the PCA loop 103 | lower, upper is the lower and upper bounds we are trying to converge for PCA dimensionality to achieve 104 | the desired explained_variance, as a fraction of the original input dimensionality 105 | previous is the previous attempt to fit the PCA's dimensionality 106 | """ 107 | dims = min(input_data.shape[0], input_data.shape[1]) 108 | lower, upper = (0., 1.) 
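# Illustrative walk-through of the search below (editorial note, not in the original file):
# with dims = 512, threshold = 0.85 and max_error = 0.05, the first pass fits a PCA with
# int(512 * 0.5) = 256 components. If the explained variance is still below the threshold,
# the lower bound rises to 0.5 (more components are needed); otherwise the upper bound
# drops to 0.5, and the next pass tries the midpoint of the new interval. The loop settles
# once upper - lower < max_error, or once the candidate dimensionality stops changing.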
109 | previous = -1 110 | 111 | while not optimised: 112 | # ===== Iterate the PCA until desired explained_variance achieved ===== 113 | 114 | # ===== Fit a PCA to dimensionality halfway between lower and upper bounds already found ===== 115 | num_dims = int(dims * (0.5*(upper - lower) + lower)) 116 | if num_dims == previous: 117 | # ===== Settle on PCA dimensionality if this iteration has the same number of dimensions as the last ===== 118 | num_dims = int(upper * dims) 119 | optimised = True 120 | 121 | # ===== Fit PCA to num_dims for this iteration ===== 122 | t1 = time.time() 123 | logger.info('Fitting PCA') 124 | pca = PCA(n_components=num_dims) 125 | reduced = pca.fit_transform(input_data) 126 | logger.info('Time Taken = {} seconds'.format(time.time() - t1)) 127 | explained_variance = np.sum(pca.explained_variance_ratio_) 128 | previous = num_dims 129 | 130 | if debug: 131 | logger.info('Lower&Upper: ({}, {})\tNumber of dimensions: {}\tExplained Variance: {}' 132 | .format(lower, upper, num_dims, explained_variance)) 133 | 134 | if explained_variance < threshold: 135 | # ===== Raise lower bound if explained variance is high enough ===== 136 | lower = num_dims / dims 137 | else: 138 | # ===== Lower upper bound if explained variance is not high enough ===== 139 | upper = num_dims / dims 140 | 141 | if upper - lower < max_error: 142 | # ===== Settle on PCA dimensionality if upper and lower bounds within acceptable range of each other ===== 143 | optimised = True 144 | 145 | return reduced 146 | 147 | def fit_tsne(input_data, logger, pca=True, pca_threshold=0.85, pca_error=0.05, pca_num_dims=0, num_dims=2, num_iterations=500, debug=False): 148 | ''' 149 | Performs TSNE on the input data to reduce it to num_dims dimensions. 150 | Will first perform PCA by default to reduce the number of dimensions and make 151 | fitting tsne faster 152 | 153 | Arguments: 154 | input_data : A [SHAPE] [TENSOR, ARRAY?] 
containing the data for the PCA, where: #TODO 155 | DIM INFO 156 | pca : If True, erform some dimensionality reduction with a PCA before the TSNE to reduce computation time 157 | pca_threshold : Explained variance threshold for PCA 158 | pca_error : Acceptable distance between lower and upper bounds for the pca to be considered converged, as a value between 159 | 0 and 1 160 | pca_num_dims : If set to a non-zero value, PCA will reduce input to num_dims dimensions 161 | num_dims : Number of dimensions to reduce to using TSNE 162 | num_iterations : Number of iterations to run the TSNE for 163 | debug : If True, print debug information 164 | ''' 165 | # ===== Reduce data with PCA first if pca=True ===== 166 | if pca: 167 | input_data = fit_pca(input_data, logger, num_dims=pca_num_dims, threshold=pca_threshold, max_error=pca_error, debug=debug) 168 | t1 = time.time() 169 | logger.info('Fitting TSNE: ') 170 | tsne=TSNE(n_iter=num_iterations, n_components=num_dims) 171 | 172 | reduced = tsne.fit_transform(input_data) 173 | logger.info('Time Taken = {} seconds'.format(time.time() - t1)) 174 | 175 | return reduced -------------------------------------------------------------------------------- /visualise_tsne.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import imageio 5 | import argparse 6 | import numpy as np 7 | from glob import glob 8 | from tqdm import tqdm 9 | from glob import glob 10 | 11 | import matplotlib 12 | matplotlib.use('Agg') 13 | import matplotlib.pyplot as plt 14 | 15 | from network import PreTrainNet 16 | from config import config, update_config 17 | from utils import * 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description='Train TCE Self-Supervised') 22 | parser.add_argument('--cfg', help = 'Path to config file', type = str, default = None) 23 | parser.add_argument('--target', help = 'Path to visualisation target. 
Can be a video or folder of images', default = None) 24 | parser.add_argument('--ckpt', help = 'Checkpoint to visualise', type=str, required = True) 25 | parser.add_argument('--gif', action = 'store_true', default = False, help = 'Save output as a gif with the corresponding video alongside') 26 | parser.add_argument('--fps', type = float, help = 'Frames per second for gif', default = 30) 27 | parser.add_argument('--save', help = 'Save path', default = None) 28 | parser.add_argument('opts', help = 'Modify config using the command line', 29 | default = None, nargs=argparse.REMAINDER ) 30 | args = parser.parse_args() 31 | update_config(config, args) 32 | 33 | return args 34 | 35 | 36 | 37 | if __name__ == '__main__': 38 | args = parse_args() 39 | logger = setup_logger() 40 | 41 | # ===== Get model ===== 42 | model = PreTrainNet(config).eval() 43 | model = torch.nn.DataParallel(model).cuda() 44 | 45 | # ===== Load Model Checkpoint ===== 46 | checkpoint = torch.load(args.ckpt) 47 | model.load_state_dict(checkpoint['state_dict']) 48 | logger.info('Loaded Checkpoint from {}'.format(args.ckpt)) 49 | 50 | 51 | # ===== Get input transformation class ===== 52 | create_blob = get_blob(config) 53 | 54 | # ===== Create either list of frames or videoreader object ===== 55 | is_folder = os.path.isdir(args.target) 56 | if is_folder: 57 | imlist = sorted(glob(os.path.join(args.target, '*'))) 58 | n_images = len(imlist) 59 | else: 60 | reader = cv2.VideoCapture(args.target) 61 | n_images = int(reader.get(cv2.CAP_PROP_FRAME_COUNT)) 62 | assert reader.isOpened(), 'Target was neither directory or a valid video file' 63 | 64 | # ===== Process video frames with network ===== 65 | embeddings = [] 66 | logger.info('Processing {} Frames'.format(n_images)) 67 | for idx in tqdm(range(n_images)): 68 | if is_folder: 69 | frame = cv2.imread(imlist[idx]) 70 | else: 71 | tf, frame = reader.read() 72 | if not tf: 73 | logger.info('Reached end of video at frame {} rather than {}'.format(x, n_images)) 74 | break 75 | 76 | # ===== Get image dimensions, convert from BGR to RGB, create input blob ===== 77 | height, width = np.shape(frame)[:2] 78 | frame = frame[:, :, ::-1] 79 | blob = create_blob(frame) 80 | 81 | # ===== Get embedding from network ===== 82 | with torch.no_grad(): 83 | embeddings.append(model(blob).cpu().numpy()) 84 | 85 | if not is_folder: 86 | reader.release() 87 | 88 | # ===== Use TSNE to reduce embeddings to 2D for plotting ===== 89 | embeddings = np.array(embeddings)[:,0,:] 90 | reduced_embeddings = fit_tsne(embeddings, logger) 91 | 92 | 93 | # ===== Create Plot ===== 94 | if args.gif: 95 | # ===== Create gif if flag set ===== 96 | frames = create_gif_frames(reduced_embeddings, args.target, logger) 97 | imageio.mimsave(args.save, frames, duration = 1 / args.fps) 98 | else: 99 | # ===== Create TSNE single image ===== 100 | fig, ax = plt.subplots() 101 | ax.axes.xaxis.set_visible(False) 102 | ax.axes.yaxis.set_visible(False) 103 | 104 | colormap = plt.cm.plasma(np.linspace(0,1, len(reduced_embeddings))) 105 | for idx, point in enumerate(reduced_embeddings): 106 | ax.plot(point[0], point[1], '.', color = colormap[idx]) 107 | 108 | # ===== Write plot to image file ===== 109 | fig.canvas.draw() 110 | tsne = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='').reshape(fig.canvas.get_width_height()[::-1] + (3,)) 111 | Image.fromarray(tsne).save(args.save) 112 | 113 | 114 | 115 | 116 | --------------------------------------------------------------------------------
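As a closing illustration, below is a minimal, self-contained sketch (not part of the repository) of how the components in `loss/` combine for a single contrastive step. The feature dimension, memory-bank size, number of negatives and batch size are placeholder values; in the actual training loop these come from the YAML config, and the anchor and pair features are produced by `PreTrainNet` rather than sampled at random.

```
import torch
import torch.nn.functional as F
from loss.NCEAverage import NCEAverageKinetics
from loss.NCECriterion import NCECriterion

# Placeholder sizes -- the real values come from the config file
feat_dim, n_data, K, bsz = 128, 10000, 512, 8

average = NCEAverageKinetics(feat_dim=feat_dim, n_data=n_data, K=K)
criterion = NCECriterion(n_data)

inputs = {
    # L2-normalised embeddings of the anchor frame and its temporally-close pair
    'anchor_feature': F.normalize(torch.randn(bsz, feat_dim), dim=1),
    'pair_feature':   F.normalize(torch.randn(bsz, feat_dim), dim=1),
    # Column 0 of each row is reserved for the positive; the remaining K entries
    # index noise samples in the memory bank
    'negatives':      torch.randint(0, n_data, (bsz, K + 1)),
    # Memory-bank slots that the anchor features are written back to afterwards
    'membank_idx':    torch.arange(bsz),
}

out = average(inputs)    # shape (bsz, K + 1, 1): positive score first, then K negatives
loss = criterion(out)    # NCE loss of Eq. (12), averaged over the batch
print(loss.item())
```

Column 0 is reserved for the positive pair because `NCECriterion` reads the positive probability with `x.select(1, 0)` and treats the remaining columns as noise samples.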