├── README.md
├── TCE.yaml
├── config
│   ├── __init__.py
│   ├── default.py
│   └── pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml
├── datasets
│   ├── finetuning
│   │   ├── UCF101.py
│   │   └── __init__.py
│   └── pretraining
│       ├── __init__.py
│       ├── augmentation.py
│       └── kinetics400.py
├── finetune.py
├── images
│   ├── TCE.png
│   └── bowling_tsne_example.gif
├── loss
│   ├── NCEAverage.py
│   ├── NCECriterion.py
│   └── __init__.py
├── network
│   ├── __init__.py
│   ├── model.py
│   └── resnet.py
├── pretrain.py
├── utils
│   ├── __init__.py
│   ├── dataset_utils.py
│   ├── eval_utils.py
│   ├── log_utils.py
│   ├── network_utils.py
│   ├── optimizer.py
│   └── tsne_utils.py
└── visualise_tsne.py
/README.md:
--------------------------------------------------------------------------------
1 | # Temporally Coherent Embeddings for Self-Supervised Video Representation Learning
2 | This repository contains the code implementation used in the ICPR2020 paper Temporally Coherent Embeddings for Self-Supervised Video Representation Learning (TCE). \[[arXiv](https://arxiv.org/abs/2004.02753)] \[[Website](https://csiro-robotics.github.io/TCE-Webpage/)] Our contributions in this repository are:
3 | - A PyTorch implementation of the self-supervised training used in the TCE paper
4 | - A PyTorch implementation of action recognition fine-tuning
5 | - Pre-trained checkpoints for models trained using the TCE self-supervised training paradigm
6 | - A PyTorch implementation of t-SNE visualisations of the network output
7 |
8 | 
9 |
10 | We benchmark our code on Split 1 of the UCF101 action recognition dataset, providing pre-trained models for our downstream and upstream training. See [Models](#models) for our provided models and [Getting Started](#getting-started) for instructions on training and evaluation.
11 |
12 | If you find this repo useful for your research, please consider citing the paper
13 | ```
14 | @inproceedings{knights2020tce,
15 | title={Temporally Coherent Embeddings for Self-Supervised Video Representation Learning},
16 | author={Joshua Knights and Ben Harwood and Daniel Ward and Anthony Vanderkop and Olivia Mackenzie-Ross and Peyman Moghadam},
17 | booktitle={25th International Conference on Pattern Recognition (ICPR)},
18 | year={2020}
19 | }
20 |
21 | ```
22 |
23 |
24 | ## Updates
25 | - 23/04/2020 : Initial Commit
26 | - 30/11/2020 : ICPR Update
27 |
28 | ## Table of Contents
29 |
30 | - [Data Preparation](#data-preparation)
31 | - [Installation](#installation)
32 | - [Models](#models)
33 | - [Getting Started](#getting-started)
34 | - [Acknowledgements](#acknowledgements)
35 |
36 |
37 | ## Data Preparation
38 |
39 |
40 | ### Kinetics400
41 | Kinetics400 videos can be downloaded and split into frames using [Showmax/kinetics-downloader](https://github.com/Showmax/kinetics-downloader)
42 |
43 | The file directory should have the following layout:
44 | ```
45 | ├── kinetics400/train
46 | |
47 | ├── CLASS_001
48 | ├── CLASS_002
49 | .
50 | .
51 | .
52 | ├── CLASS_400
53 |     |
54 |     ├── VID_001
55 |     ├── VID_002
56 |     .
57 |     .
58 |     .
59 |     ├── VID_###
60 |         |
61 |         ├── frame1.jpg
62 |         ├── frame2.jpg
63 |         .
64 |         .
65 |         .
66 |         ├── frame###.jpg
67 | ```
68 | Once the dataset is downloaded and split into frames, edit the following parameter in config/default.py to point towards the extracted frames (the pre-training dataloader appends `train` to this path):
69 | - DATASET.KINETICS400.FRAMES_PATH = /path/to/kinetics400
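For example, the relevant line in config/default.py would look something like this (the path is a placeholder for your own download location):

```
_C.DATASET.KINETICS400.FRAMES_PATH = '/path/to/kinetics400'  # directory containing the train/ folder of extracted frames
```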
70 |
71 | ### UCF101
72 |
73 | UCF101 frames and splits can be downloaded directly from [feichtenhofer/twostreamfusion](https://github.com/feichtenhofer/twostreamfusion)
74 |
75 | ```
76 | wget http://ftp.tugraz.at/pub/feichtenhofer/tsfusion/data/ucf101_jpegs_256.zip.001
77 | wget http://ftp.tugraz.at/pub/feichtenhofer/tsfusion/data/ucf101_jpegs_256.zip.002
78 | wget http://ftp.tugraz.at/pub/feichtenhofer/tsfusion/data/ucf101_jpegs_256.zip.003
79 |
80 | cat ucf101_jpegs_256.zip* > ucf101_jpegs_256.zip
81 | unzip ucf101_jpegs_256.zip
82 | ```
83 | The file directory should have the following layout:
84 |
85 | ```
86 | ├── UCF101
87 | |
88 | ├── v_{_CLASS_001}_g01_c01
89 | . |
90 | . ├── frame000001.jpg
91 | . ├── frame000002.jpg
92 | . .
93 | . .
94 | . ├── frame000###.jpg
95 | .
96 | ├── v_{_CLASS_101}_g##_c##
97 | |
98 | ├── frame000001.jpg
99 | ├── frame000002.jpg
100 | .
101 | .
102 | ├── frame000###.jpg
103 | ```
104 |
105 | Once the dataset is downloaded and decompressed, edit the following parameters in config/default.py to point towards the frames and splits:
106 | - DATASET.UCF101.FRAMES_PATH = /path/to/UCF101_frames
107 | - DATASET.UCF101.SPLITS_PATH = /path/to/UCF101_splits
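Alternatively, because the configuration is built on yacs, these values can be overridden from the command line when launching the training scripts described in [Getting Started](#getting-started), rather than editing config/default.py (paths below are placeholders):

```
python finetune.py \
  --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \
  DATASET.UCF101.FRAMES_PATH /path/to/UCF101_frames \
  DATASET.UCF101.SPLITS_PATH /path/to/UCF101_splits \
  TRAIN.FINETUNING.SAVEDIR /path/to/savedir
```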
108 |
109 |
110 |
111 |
112 |
113 | ## Installation
114 |
115 |
116 | TCE is built using Python == 3.7.1 and PyTorch == 1.7.0
117 |
118 | We use Conda to set up the Python environment for this repository. To create the environment, run the following commands from the root directory:
119 |
120 | ```
121 | conda env create -f TCE.yaml
122 | conda activate TCE
123 | ```
124 |
125 | Once this is done, also specify a path in config/default.py to which assets (such as dataset pickles for faster setup) will be saved:
126 | - ASSETS_PATH = /path/to/assets/folder
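As a quick sanity check, the following minimal one-liner (run from the repository root inside the TCE environment) should print the PyTorch version (1.7.0) and the assets path you set:

```
python -c "import torch; from config import config; print(torch.__version__); print(config.ASSETS_PATH)"
```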
127 |
128 |
129 |
130 | ## Models
131 |
132 |
133 | | Architecture | Pre-Training Dataset | Link |
134 | |-------------- |---------------------- |---------------------------------------------------------------- |
135 | | ResNet-18 | Kinetics400 | [Link](https://cloudstor.aarnet.edu.au/plus/s/kNQKw5ATTbyamg2) |
136 | | ResNet-50 | Kinetics400 | [Link](https://cloudstor.aarnet.edu.au/plus/s/HbWxmhcUbfzQIQf) |
137 |
138 | ## Getting Started
139 |
140 | ### Self-Supervised Training
141 | We provide a script, pretrain.py, for TCE pre-training on the Kinetics400 dataset. To train, run the following command:
142 |
143 | ```
144 | python pretrain.py \
145 | --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \
146 | TRAIN.PRETRAINING.SAVEDIR /path/to/savedir
147 | ```
148 |
149 | If resuming from a previous pre-training checkpoint, set the flag `TRAIN.PRETRAINING.RESUME` to the path of the checkpoint to resume from, as in the example below.
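For example, a resumed pre-training run might look like the following (the checkpoint path is a placeholder):

```
python pretrain.py \
  --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \
  TRAIN.PRETRAINING.SAVEDIR /path/to/savedir \
  TRAIN.PRETRAINING.RESUME /path/to/pretraining_checkpoint.pth
```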
150 |
151 | ### Fine-tuning for action recognition
152 | We provide a fine-tuning script for action recognition on the UCF101 dataset, finetune.py. To train, run the following command:
153 |
154 | ```
155 | python finetune.py \
156 | --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \
157 | TRAIN.FINETUNING.CHECKPOINT "/path/to/pretrained_checkpoint" \
158 | TRAIN.FINETUNING.SAVEDIR "/path/to/savedir"
159 | ```
160 |
161 | If resuming training from an earlier fine-tuning checkpoint, set the flag `TRAIN.FINETUNING.RESUME` to True, as in the example below.
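For example, a resumed fine-tuning run might look like the following (paths are placeholders):

```
python finetune.py \
  --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \
  TRAIN.FINETUNING.CHECKPOINT "/path/to/finetuning_checkpoint" \
  TRAIN.FINETUNING.SAVEDIR "/path/to/savedir" \
  TRAIN.FINETUNING.RESUME True
```

finetune.py also accepts a `--val` flag, which runs validation on the supplied checkpoint without any further training.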
162 |
163 |
164 |
165 |
166 | ### Visualisation
167 |
168 | 
169 |
170 | To demonstrate the ability of our approach to create temporally coherent embeddings, we provide a package for generating t-SNE visualisations of our features similar to those found in the paper. This package can also be applied to other approaches and network architectures.
171 |
172 | The files in this repository used for generating t-SNE visualisations are:
173 | - `visualise_tsne.py` is a wrapper around t-SNE and our network architecture for end-to-end generation of the visualisation
174 | - `utils/tsne_utils.py` Contains t-SNE functionality for reducing the dimensionality of an array of embedded features for plotting, as well as tools to create an animated visualisation of the embedding's behaviour over time
175 |
176 | The following flags can be used as inputs for `visualise_tsne.py`:
177 | - `--cfg` : Path to config file
178 | - `--target` : Path to video to visualise t-SNE for. This video can either be a video file (avi, mp4) or a directory of images representing frames
179 | - `--ckpt` : Path to the model checkpoint to visualise the embedding space for
180 | - `--gif` : Use to visualise the change in the embedding space over time alongside the input video as a gif file
181 | - `--fps` : Set the framerate of the gif
182 | - `--save` : Path to save the output t-SNE to
183 |
184 | To visualise the embeddings from TCE, download our self-supervised model above and use the following command to visualise our embedding space as a gif:
185 |
186 | ```
187 | python visualise_tsne.py \
188 | --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \
189 | --target "/path/to/target/video" \
190 | --ckpt "/path/to/TCE_checkpoint" \
191 | --gif \
192 | --fps 25 \
193 | --save "/path/to/save/folder/t-SNE.gif"
194 | ```
195 |
196 | Alternatively, to visualise the t-SNE as a PNG image use the following:
197 |
198 | ```
199 | python visualise_tsne.py \
200 | --cfg config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml \
201 | --target "/path/to/target/video" \
202 | --ckpt "/path/to/TCE_checkpoint" \
203 | --save "/path/to/save/folder/t-SNE.png"
204 | ```
205 |
206 |
207 |
208 |
209 |
210 | ## Acknowledgements
211 | Parts of this code base are derived from Yonglong Tian's unsupervised learning algorithm [Contrastive Multiview Coding](https://github.com/HobbitLong/CMC) and Jeffrey Huang's implementation of [action recognition](https://github.com/jeffreyyihuang/two-stream-action-recognition).
212 |
213 |
214 |
215 |
216 |
--------------------------------------------------------------------------------
/TCE.yaml:
--------------------------------------------------------------------------------
1 | name: TCE
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - blas=1.0=mkl
9 | - bzip2=1.0.8=h7b6447c_0
10 | - ca-certificates=2020.10.14=0
11 | - cairo=1.14.12=h8948797_3
12 | - certifi=2020.11.8=py37h06a4308_0
13 | - cudatoolkit=10.2.89=hfd86e86_1
14 | - cycler=0.10.0=py37_0
15 | - dbus=1.13.18=hb2f20db_0
16 | - expat=2.2.10=he6710b0_2
17 | - ffmpeg=4.0=hcdf2ecd_0
18 | - fontconfig=2.13.0=h9420a91_0
19 | - freeglut=3.0.0=hf484d3e_5
20 | - freetype=2.10.4=h5ab3b9f_0
21 | - glib=2.63.1=h5a9c865_0
22 | - graphite2=1.3.14=h23475e2_0
23 | - gst-plugins-base=1.14.0=hbbd80ab_1
24 | - gstreamer=1.14.0=hb453b48_1
25 | - harfbuzz=1.8.8=hffaf4a1_0
26 | - hdf5=1.10.2=hba1933b_1
27 | - icu=58.2=he6710b0_3
28 | - imageio=2.9.0=py_0
29 | - intel-openmp=2020.2=254
30 | - jasper=2.0.14=h07fcdf6_1
31 | - joblib=0.17.0=py_0
32 | - jpeg=9b=h024ee3a_2
33 | - kiwisolver=1.3.0=py37h2531618_0
34 | - lcms2=2.11=h396b838_0
35 | - libedit=3.1.20191231=h14c3975_1
36 | - libffi=3.2.1=hf484d3e_1007
37 | - libgcc-ng=9.1.0=hdf63c60_0
38 | - libgfortran-ng=7.3.0=hdf63c60_0
39 | - libglu=9.0.0=hf484d3e_1
40 | - libopencv=3.4.2=hb342d67_1
41 | - libopus=1.3.1=h7b6447c_0
42 | - libpng=1.6.37=hbc83047_0
43 | - libprotobuf=3.13.0.1=h8b12597_0
44 | - libstdcxx-ng=9.1.0=hdf63c60_0
45 | - libtiff=4.1.0=h2733197_1
46 | - libuuid=1.0.3=h1bed415_2
47 | - libuv=1.40.0=h7b6447c_0
48 | - libvpx=1.7.0=h439df22_0
49 | - libxcb=1.14=h7b6447c_0
50 | - libxml2=2.9.10=hb55368b_3
51 | - lz4-c=1.9.2=heb0550a_3
52 | - matplotlib=3.3.2=0
53 | - matplotlib-base=3.3.2=py37h817c723_0
54 | - mkl=2020.2=256
55 | - mkl-service=2.3.0=py37he904b0f_0
56 | - mkl_fft=1.2.0=py37h23d657b_0
57 | - mkl_random=1.1.1=py37h0573a6f_0
58 | - ncurses=6.2=he6710b0_1
59 | - ninja=1.10.1=py37hfd86e86_0
60 | - numpy=1.19.2=py37h54aff64_0
61 | - numpy-base=1.19.2=py37hfa32c7d_0
62 | - olefile=0.46=py37_0
63 | - opencv=3.4.2=py37h6fd60c2_1
64 | - openssl=1.1.1h=h7b6447c_0
65 | - pcre=8.44=he6710b0_0
66 | - pillow=8.0.1=py37he98fc37_0
67 | - pip=20.2.4=py37h06a4308_0
68 | - pixman=0.40.0=h7b6447c_0
69 | - protobuf=3.13.0.1=py37h745909e_1
70 | - py-opencv=3.4.2=py37hb342d67_1
71 | - pyparsing=2.4.7=py_0
72 | - pyqt=5.9.2=py37h05f1152_2
73 | - python=3.7.1=h0371630_7
74 | - python-dateutil=2.8.1=py_0
75 | - python_abi=3.7=1_cp37m
76 | - pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0
77 | - pyyaml=5.3.1=py37hb5d75c8_1
78 | - qt=5.9.7=h5867ecd_1
79 | - readline=7.0=h7b6447c_5
80 | - scikit-learn=0.23.2=py37h0573a6f_0
81 | - scipy=1.5.2=py37h0b6359f_0
82 | - setuptools=50.3.1=py37h06a4308_1
83 | - sip=4.19.8=py37hf484d3e_0
84 | - six=1.15.0=py37h06a4308_0
85 | - sqlite=3.33.0=h62c20be_0
86 | - tensorboardx=2.1=py_0
87 | - termcolor=1.1.0=py37_1
88 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
89 | - tk=8.6.10=hbc83047_0
90 | - torchvision=0.8.1=py37_cu102
91 | - tornado=6.0.4=py37h7b6447c_1
92 | - tqdm=4.51.0=pyhd3eb1b0_0
93 | - typing_extensions=3.7.4.3=py_0
94 | - wheel=0.35.1=pyhd3eb1b0_0
95 | - xz=5.2.5=h7b6447c_0
96 | - yacs=0.1.6=py_0
97 | - yaml=0.2.5=h516909a_0
98 | - zlib=1.2.11=h7b6447c_3
99 | - zstd=1.4.5=h9ceee32_0
100 | prefix: /scratch1/kni101/miniconda3/envs/TCE
101 |
102 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .default import _C as config
2 | from .default import update_config
--------------------------------------------------------------------------------
/config/default.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from yacs.config import CfgNode as CN
4 |
5 | _C = CN()
6 |
7 |
8 | # -----------------------------------------------------------------------------
9 | # Misc
10 | # -----------------------------------------------------------------------------
11 | _C.WORKERS = 8
12 | _C.PRINT_FREQ = 20
13 | _C.ASSETS_PATH = '/scratch1/kni101/TCE/assets'
14 |
15 | # -----------------------------------------------------------------------------
16 | # Dataset
17 | # -----------------------------------------------------------------------------
18 | _C.DATASET = CN()
19 |
20 | _C.DATASET.KINETICS400 = CN()
21 | _C.DATASET.KINETICS400.FRAMES_PATH = '/datasets/work/d61-eif/source/kinetics400-frames'
22 |
23 | _C.DATASET.UCF101 = CN()
24 | _C.DATASET.UCF101.FRAMES_PATH = '/datasets/work/d61-eif/source/UCF-101-twostream/jpegs_256'
25 | _C.DATASET.UCF101.SPLITS_PATH = '/datasets/work/d61-eif/source/UCF-101-twostream/UCF_list'
26 | _C.DATASET.UCF101.SPLIT = 1
27 |
28 | _C.DATASET.PRETRAINING = CN()
29 | _C.DATASET.PRETRAINING.DATASET = 'kinetics400'
30 | _C.DATASET.PRETRAINING.MEAN = [0.485, 0.456, 0.406]
31 | _C.DATASET.PRETRAINING.STD = [0.229, 0.224, 0.225]
32 |
33 | _C.DATASET.FINETUNING = CN()
34 | _C.DATASET.FINETUNING.DATASET = 'UCF101'
35 | _C.DATASET.FINETUNING.MEAN = [0.485, 0.456, 0.406]
36 | _C.DATASET.FINETUNING.STD = [0.229, 0.224, 0.225]
37 |
38 |
39 | _C.DATASET.PRETRAINING.TRANSFORMATIONS = CN()
40 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.CROP_SIZE = (224,224)
41 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.SCALE_SIZE = (224,224)
42 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.HORIZONTAL_FLIP = True
43 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.RANDOM_GREY = True
44 | _C.DATASET.PRETRAINING.TRANSFORMATIONS.COLOUR_JITTER = True
45 |
46 | _C.DATASET.FINETUNING.TRANSFORMATIONS = CN()
47 | _C.DATASET.FINETUNING.TRANSFORMATIONS.RESIZE = (224,224)
48 | _C.DATASET.FINETUNING.TRANSFORMATIONS.CROP_SIZE = (224,224)
49 |
50 |
51 | # -----------------------------------------------------------------------------
52 | # Model
53 | # -----------------------------------------------------------------------------
54 | _C.MODEL = CN()
55 | _C.MODEL.TRUNK = 'resnet18'
56 | _C.MODEL.PRETRAINING = CN()
57 | _C.MODEL.PRETRAINING.FC_DIM = 128
58 |
59 | _C.MODEL.FINETUNING = CN()
60 | _C.MODEL.FINETUNING.NUM_CLASSES = -1
61 | _C.MODEL.FINETUNING.DROPOUT = 0
62 |
63 | # -----------------------------------------------------------------------------
64 | # Loss
65 | # -----------------------------------------------------------------------------
66 | _C.LOSS = CN()
67 | _C.LOSS.PRETRAINING = CN()
68 | _C.LOSS.PRETRAINING.ROTATION_WEIGHT = 1
69 |
70 | _C.LOSS.PRETRAINING.NCE = CN()
71 | _C.LOSS.PRETRAINING.NCE.NEGATIVES = 8192
72 | _C.LOSS.PRETRAINING.NCE.TEMPERATURE = 0.07
73 | _C.LOSS.PRETRAINING.NCE.MOMENTUM = 0.5
74 |
75 |
76 | _C.LOSS.PRETRAINING.MINING = CN()
77 | _C.LOSS.PRETRAINING.MINING.USE_MINING = True
78 | _C.LOSS.PRETRAINING.MINING.THRESH_LOW = -1
79 | _C.LOSS.PRETRAINING.MINING.THRESH_HIGH = 1
80 | _C.LOSS.PRETRAINING.MINING.THRESH_RATE = 5
81 | _C.LOSS.PRETRAINING.MINING.MAX_HARD_NEGATIVES_PERCENTAGE = 1
82 |
83 | # -----------------------------------------------------------------------------
84 | # Training
85 | # -----------------------------------------------------------------------------
86 | _C.TRAIN = CN()
87 | _C.TRAIN.PRETRAINING = CN()
88 | _C.TRAIN.PRETRAINING.LEARNING_RATE = 0.03
89 | _C.TRAIN.PRETRAINING.DECAY_EPOCHS = (25,)
90 | _C.TRAIN.PRETRAINING.DECAY_FACTOR = 0.1
91 |
92 | _C.TRAIN.PRETRAINING.SAVEDIR = ''
93 | _C.TRAIN.PRETRAINING.MOMENTUM = 0.9
94 | _C.TRAIN.PRETRAINING.WEIGHT_DECAY = 1e-4
95 | _C.TRAIN.PRETRAINING.SAVE_FREQ = 1
96 | _C.TRAIN.PRETRAINING.BATCH_SIZE = 100
97 | _C.TRAIN.PRETRAINING.EPOCHS = 50
98 | _C.TRAIN.PRETRAINING.RESUME = ''
99 | _C.TRAIN.PRETRAINING.FRAME_PADDING = 0
100 |
101 | _C.TRAIN.FINETUNING = CN()
102 | _C.TRAIN.FINETUNING.CHECKPOINT = ''
103 | _C.TRAIN.FINETUNING.SAVEDIR = ''
104 | _C.TRAIN.FINETUNING.RESUME = False
105 | _C.TRAIN.FINETUNING.BATCH_SIZE = 100
106 | _C.TRAIN.FINETUNING.LEARNING_RATE = 0.05
107 | _C.TRAIN.FINETUNING.EPOCHS = 900
108 | _C.TRAIN.FINETUNING.DECAY_EPOCHS = (375,)
109 | _C.TRAIN.FINETUNING.MOMENTUM = 0.9
110 | _C.TRAIN.FINETUNING.WEIGHT_DECAY = 0
111 | _C.TRAIN.FINETUNING.DECAY_FACTOR = 0.1
112 | _C.TRAIN.FINETUNING.VAL_FREQ = 3
113 | # -----------------------------------------------------------------------------
114 | # Visualisation
115 | # -----------------------------------------------------------------------------
116 | _C.VISUALISATION = CN()
117 | _C.VISUALISATION.TSNE = CN()
118 | _C.VISUALISATION.TSNE.CROP_SIZE = 224
119 | _C.VISUALISATION.TSNE.RESIZE = 256
120 |
121 |
122 | def update_config(cfg, args):
123 | cfg.defrost()
124 | if args.cfg != None:
125 | cfg.merge_from_file(args.cfg)
126 | cfg.merge_from_list(args.opts)
127 | cfg.freeze()
128 |
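A minimal usage sketch of this module, mirroring how finetune.py consumes it (the YAML path is the config shipped with the repository; the override value is illustrative):

```python
import argparse
from config import config, update_config

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', default=None)
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)

# Equivalent to passing these arguments on the command line.
args = parser.parse_args([
    '--cfg', 'config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml',
    'TRAIN.FINETUNING.SAVEDIR', '/tmp/tce',
])

update_config(config, args)              # merge the YAML file, then the command-line overrides
print(config.MODEL.TRUNK)                # resnet18
print(config.TRAIN.FINETUNING.SAVEDIR)   # /tmp/tce
```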
--------------------------------------------------------------------------------
/config/pretrain_kinetics400miningr_finetune_UCF101_resnet18.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | PRETRAINING:
3 | DATASET: 'kinetics400'
4 | TRANSFORMATIONS:
5 | CROP_SIZE: (224,224)
6 | SCALE_SIZE: (224,224)
7 | HORIZONTAL_FLIP: True
8 | RANDOM_GREY: True
9 | COLOUR_JITTER: True
10 | FINETUNING:
11 | TRANSFORMATIONS:
12 | RESIZE: (224,224)
13 | CROP_SIZE: (224,224)
14 | MODEL:
15 | TRUNK: 'resnet18'
16 | PRETRAINING:
17 | FC_DIM: 128
18 | FINETUNING:
19 | NUM_CLASSES: 101
20 | DROPOUT: 0
21 | LOSS:
22 | PRETRAINING:
23 | ROTATION_WEIGHT: 1
24 | NCE:
25 | NEGATIVES: 8192
26 | TEMPERATURE: 0.07
27 | MOMENTUM: 0.5
28 | MINING:
29 | THRESH_LOW: -1
30 | THRESH_HIGH: 1
31 | THRESH_RATE: 5
32 | MAX_HARD_NEGATIVES_PERCENTAGE: 1
33 | TRAIN:
34 | PRETRAINING:
35 | SAVEDIR: ''
36 | LEARNING_RATE: 0.03
37 | DECAY_EPOCHS: (25,)
38 | DECAY_FACTOR: 0.1
39 | MOMENTUM: 0.9
40 | WEIGHT_DECAY: 1e-4
41 | BATCH_SIZE: 100
42 | EPOCHS: 50
43 | FRAME_PADDING: 0
44 | FINETUNING:
45 | SAVEDIR: ''
46 | CHECKPOINT: ''
47 | RESUME: FALSE
48 | LEARNING_RATE: 0.05
49 | DECAY_EPOCHS: (375,)
50 | DECAY_FACTOR: 0.1
51 | MOMENTUM: 0.9
52 | WEIGHT_DECAY: 0
53 | BATCH_SIZE: 100
54 | EPOCHS: 900
55 | VAL_FREQ: 3
56 | VISUALISATION:
57 | TSNE:
58 | CROP_SIZE: 224
59 | RESIZE: 224
60 |
61 |
62 |
--------------------------------------------------------------------------------
/datasets/finetuning/UCF101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 | from tqdm import tqdm
5 | import PIL.Image as Image
6 | import multiprocessing as mp
7 |
8 | from utils import UCF101_splitter
9 | from torchvision import transforms
10 | from torch.utils.data import DataLoader
11 |
12 | class UCF_dataset:
13 | def __init__(self, root, file_dict, mode, transforms):
14 | self.root = root
15 | self.keys = list(file_dict.keys())
16 | self.gts = list(file_dict.values())
17 | self.mode = mode
18 | self.transform = transforms
19 |
20 | def __len__(self):
21 | return len(self.keys)
22 |
23 | def load_ucf_image(self, video_name, index):
24 |
25 | frame_path = os.path.join(self.root, '{}/frame{:06d}.jpg'.format(
26 | video_name, index))
27 |
28 | frame = Image.open(frame_path)
29 | frame_tensor = self.transform(frame)
30 |
31 | return frame_tensor
32 |
33 | def __getitem__(self, idx):
34 | if self.mode == 'train':
35 | # ===== Get training sample =====
36 | video_name, nb_clips = self.keys[idx].split(' ')
37 |
38 | nb_clips = int(nb_clips)
39 | label = self.gts[idx]
40 | clips = []
41 | clips.append(random.randint(1, int(nb_clips / 3)))
42 | clips.append(random.randint(int(nb_clips / 3), int(nb_clips * 2 / 3)))
43 | clips.append(random.randint(int(nb_clips * 2 / 3), nb_clips + 1))
44 |
45 | elif self.mode == 'val':
46 | video_name, index = self.keys[idx].split(' ')
47 | index = abs(int(index))
48 | else:
49 | raise ValueError('There are only train and val modes')
50 |
51 | label = self.gts[idx]
52 | label = int(label) - 1
53 |
54 | if self.mode == 'train':
55 | data = []
56 | for i in range(len(clips)):
57 | index = clips[i]
58 | data.append(self.load_ucf_image(video_name, index))
59 | sample = (data, label)
60 | elif self.mode == 'val':
61 | data = self.load_ucf_image(video_name, index)
62 | sample = (video_name.split('/')[0], data, label)
63 |
64 | return sample
65 |
66 |
67 | class UCF101:
68 | def __init__(self, cfg, logger):
69 | self.batch_size = cfg.TRAIN.FINETUNING.BATCH_SIZE
70 | self.num_workers = cfg.WORKERS
71 | self.root = cfg.DATASET.UCF101.FRAMES_PATH
72 | self.splits_path = cfg.DATASET.UCF101.SPLITS_PATH
73 | self.split = cfg.DATASET.UCF101.SPLIT
74 | self.cfg = cfg
75 | self.logger = logger
76 |
77 | # ===== Split the dataset =====
78 | splitter = UCF101_splitter(self.splits_path, self.split)
79 | self.train_videos, self.val_videos = splitter.split_video()
80 | self.load_frame_count()
81 | self.get_training_dict()
82 | self.get_validation_dict()
83 |
84 | @staticmethod
85 | def process_video(x):
86 | root, video, L = x
87 | key = video
88 | n_frames = len(os.listdir(os.path.join(root, video)))
89 | L.append([key.replace('HandStandPushups', 'HandstandPushups'), n_frames])
90 |
91 | def load_frame_count(self):
92 | pickle_path = os.path.join(self.cfg.ASSETS_PATH, 'UCF101_frame_count_split_{:02d}.pickle'.format(self.split))
93 | if not os.path.exists(pickle_path):
94 |
95 | # ===== Build frame count dict if no pickle file found =====
96 | self.frame_count = {}
97 | self.logger.info('Creating frame count dictionary for UCF101')
98 |
99 | # ===== Set up multiprocessing =====
100 | pool = mp.Pool(processes = 16)
101 | manager = mp.Manager()
102 | L = manager.list()
103 | in_args = []
104 | for video in os.listdir(self.root):
105 | in_args.append([self.root, video, L])
106 |
107 | # ===== Get frame counts using multiprocessing =====
108 | for _ in tqdm(pool.imap_unordered(self.process_video, in_args), total = len(in_args)):
109 | pass
110 |
111 | pool.close()
112 | pool.join()
113 | self.frame_count = {k:v for k,v in list(L)}
114 |
115 | with open(pickle_path, 'wb') as f:
116 | pickle.dump(self.frame_count, f)
117 | self.logger.info('Saved frame count dictionary to {}'.format(pickle_path))
118 |
119 | else:
120 |
121 | # ===== Load frame count dict if already exists
122 | with open(pickle_path, 'rb') as f:
123 | self.frame_count = pickle.load(f)
124 |
125 | def get_training_dict(self):
126 | self.dict_training = {}
127 | for video in self.train_videos:
128 | nb_frame = self.frame_count[video] - 10 + 1
129 | key = '{} {}'.format(video, nb_frame)
130 | self.dict_training[key] = self.train_videos[video]
131 |
132 | def get_validation_dict(self):
133 | self.dict_val = {}
134 | for video in self.val_videos:
135 | nb_frame = self.frame_count[video] - 10 + 1
136 | interval = nb_frame // 19
137 | for idx in range(19):
138 | frame = interval * idx + 1
139 | key = '{} {}'.format(video, frame)
140 | self.dict_val[key] = self.val_videos[video]
141 |
142 |
143 | def get_loaders(self):
144 | # ===== Get transforms =====
145 | tran = transforms.Compose([
146 | transforms.Resize(self.cfg.DATASET.FINETUNING.TRANSFORMATIONS.RESIZE),
147 | transforms.ToTensor(),
148 | transforms.Normalize(mean = self.cfg.DATASET.FINETUNING.MEAN, std = self.cfg.DATASET.FINETUNING.STD)
149 | ])
150 |
151 | # ===== Make Datasets =====
152 | train_dataset = UCF_dataset(
153 | root = self.cfg.DATASET.UCF101.FRAMES_PATH,
154 | file_dict = self.dict_training,
155 | mode = 'train',
156 | transforms = tran
157 | )
158 |
159 | val_dataset = UCF_dataset(
160 | root = self.cfg.DATASET.UCF101.FRAMES_PATH,
161 | file_dict = self.dict_val,
162 | mode = 'val',
163 | transforms = tran
164 | )
165 |
166 | self.logger.info('Created Train / Val splits')
167 | self.logger.info('Training Videos : {}'.format(len(train_dataset)))
168 | self.logger.info('Validation Videos : {}'.format(len(val_dataset) // 19))
169 |
170 | # ===== Create loaders =====
171 | train_loader = DataLoader(
172 | dataset = train_dataset,
173 | batch_size = self.cfg.TRAIN.FINETUNING.BATCH_SIZE,
174 | shuffle = True,
175 | num_workers = self.num_workers
176 | )
177 |
178 | val_loader = DataLoader(
179 | dataset = val_dataset,
180 | batch_size = self.cfg.TRAIN.FINETUNING.BATCH_SIZE,
181 | shuffle = True,
182 | num_workers = self.num_workers
183 | )
184 |
185 | val_gt = {k:v - 1 for k,v in self.val_videos.items()}
186 |
187 | return train_loader, val_loader, val_gt
188 |
189 | if __name__ == '__main__':
190 | import argparse
191 | from config import config, update_config
192 | from utils import *
193 |
194 | parser = argparse.ArgumentParser()
195 | parser.add_argument('--cfg', default = None)
196 | parser.add_argument('opts', default = None, nargs = argparse.REMAINDER)
197 | args = parser.parse_args()
198 | update_config(config, args)
199 | logger = setup_logger()
200 |
201 | train_loader, val_loader, val_gt = UCF101(config, logger).get_loaders()
202 |
203 |
--------------------------------------------------------------------------------
/datasets/finetuning/__init__.py:
--------------------------------------------------------------------------------
1 | from .UCF101 import *
--------------------------------------------------------------------------------
/datasets/pretraining/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | from .augmentation import *
5 | from .kinetics400 import Kinetics400
6 | from torchvision.transforms import Compose
7 |
8 |
9 | class PretrainTransforms:
10 | def __init__(self, cfg):
11 | img_transforms = []
12 | tensor_transforms = []
13 | # ===== Prepare scaling and crop sizes =====
14 | crop_size = cfg.DATASET.PRETRAINING.TRANSFORMATIONS.CROP_SIZE
15 | scale_size = cfg.DATASET.PRETRAINING.TRANSFORMATIONS.SCALE_SIZE
16 |
17 | # ===== Add optional augmentations =====
18 | if cfg.DATASET.PRETRAINING.TRANSFORMATIONS.HORIZONTAL_FLIP == True:
19 | img_transforms.append(RandomHorizontalFlip(consistent=True))
20 | if cfg.DATASET.PRETRAINING.TRANSFORMATIONS.RANDOM_GREY == True:
21 | img_transforms.append(RandomGray(consistent=False, p=0.5))
22 | if cfg.DATASET.PRETRAINING.TRANSFORMATIONS.COLOUR_JITTER == True:
23 | img_transforms.append(ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0))
24 |
25 | # ===== Add Scale and Crop transforms =====
26 | img_transforms.append(Scale(size=scale_size))
27 | img_transforms.append(RandomCrop(size=crop_size, consistent = True))
28 |
29 | # ===== Create Rotation Transform =====
30 | self.Rotation = Rotation()
31 |
32 | # ===== Create Tensor Transformations =====
33 | tensor_transforms.append(ToTensor())
34 | tensor_transforms.append(Normalize(
35 | mean = cfg.DATASET.PRETRAINING.MEAN,
36 | std = cfg.DATASET.PRETRAINING.STD,
37 | ))
38 |
39 | self.img_transforms = Compose(img_transforms)
40 | self.tensor_transforms = Compose(tensor_transforms)
41 |
42 | def __call__(self, anchor_frame, pair_frame):
43 |
44 | # ===== Transform Pair Together =====
45 | anchor_tensor, pair_tensor = self.tensor_transforms(self.img_transforms([anchor_frame, pair_frame]))
46 |
47 | # ===== Transform Rotation, get Rotation GT =====
48 | rotation_gt = np.random.randint(0,4)
49 | rotation_intermediate = self.img_transforms([anchor_frame])
50 | rotation_intermediate = self.Rotation(rotation_intermediate, rotation = 90 * rotation_gt)
51 | rotation_tensor = self.tensor_transforms(rotation_intermediate)[0]
52 |
53 | return anchor_tensor, pair_tensor, rotation_tensor, rotation_gt
54 |
55 | def get_pretraining_dataset(cfg):
56 | transforms = PretrainTransforms(cfg)
57 | dataset_name = cfg.DATASET.PRETRAINING.DATASET
58 |
59 | if dataset_name == 'kinetics400':
60 | dataset = Kinetics400(cfg, transforms)
61 | else:
62 | raise NotImplementedError
63 |
64 | n_data = len(dataset)
65 |
66 | train_loader = torch.utils.data.DataLoader(
67 | dataset = dataset,
68 | batch_size = cfg.TRAIN.PRETRAINING.BATCH_SIZE,
69 | shuffle = True,
70 | num_workers = cfg.WORKERS,
71 | pin_memory = True,
72 | sampler = None,
73 | drop_last = True
74 | )
75 |
76 | return train_loader, n_data
77 |
78 |
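A rough sketch of how this loader is consumed, assuming the Kinetics400 frame path and ASSETS_PATH in the config are set up (key names follow Kinetics400.__getitem__):

```python
from config import config
from datasets.pretraining import get_pretraining_dataset

train_loader, n_data = get_pretraining_dataset(config)   # n_data = number of videos

batch = next(iter(train_loader))
# Each batch is a dict collated by the default DataLoader collate_fn:
#   batch['anchor_tensor']   : [B, 3, H, W] anchor frames
#   batch['pair_tensor']     : [B, 3, H, W] temporally offset pair frames
#   batch['rotation_tensor'] : [B, 3, H, W] rotated copies of the anchor frames
#   batch['rotation_gt']     : [B] rotation class in {0, 1, 2, 3} (multiples of 90 degrees)
#   batch['membank_idx']     : [B] memory bank index of each anchor video
#   batch['negatives']       : [B, K + 1] indices sampled for the NCE memory bank
```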
--------------------------------------------------------------------------------
/datasets/pretraining/augmentation.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numbers
3 | import math
4 | import collections
5 | import numpy as np
6 | from PIL import ImageOps, Image
7 | #from joblib import Parallel, delayed
8 |
9 | import torchvision
10 | from torchvision import transforms
11 | import torchvision.transforms.functional as F
12 |
13 | class Padding:
14 | def __init__(self, pad):
15 | self.pad = pad
16 |
17 | def __call__(self, img):
18 | return ImageOps.expand(img, border=self.pad, fill=0)
19 |
20 | class Scale:
21 | def __init__(self, size, interpolation=Image.NEAREST):
22 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
23 | self.size = size
24 | self.interpolation = interpolation
25 |
26 | def __call__(self, imgmap):
27 | # assert len(imgmap) > 1 # list of images
28 | img1 = imgmap[0]
29 | if isinstance(self.size, int):
30 | w, h = img1.size
31 | if (w <= h and w == self.size) or (h <= w and h == self.size):
32 | return imgmap
33 | if w < h:
34 | ow = self.size
35 | oh = int(self.size * h / w)
36 | return [i.resize((ow, oh), self.interpolation) for i in imgmap]
37 | else:
38 | oh = self.size
39 | ow = int(self.size * w / h)
40 | return [i.resize((ow, oh), self.interpolation) for i in imgmap]
41 | else:
42 | return [i.resize(self.size, self.interpolation) for i in imgmap]
43 |
44 |
45 | class CenterCrop:
46 | def __init__(self, size, consistent=True):
47 | if isinstance(size, numbers.Number):
48 | self.size = (int(size), int(size))
49 | else:
50 | self.size = size
51 |
52 | def __call__(self, imgmap):
53 | img1 = imgmap[0]
54 | w, h = img1.size
55 | th, tw = self.size
56 | x1 = int(round((w - tw) / 2.))
57 | y1 = int(round((h - th) / 2.))
58 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap]
59 |
60 |
61 | class RandomCropWithProb:
62 | def __init__(self, size, p=0.8, consistent=True):
63 | if isinstance(size, numbers.Number):
64 | self.size = (int(size), int(size))
65 | else:
66 | self.size = size
67 | self.consistent = consistent
68 | self.threshold = p
69 |
70 | def __call__(self, imgmap):
71 | img1 = imgmap[0]
72 | w, h = img1.size
73 | if self.size is not None:
74 | th, tw = self.size
75 | if w == tw and h == th:
76 | return imgmap
77 | if self.consistent:
78 | if random.random() < self.threshold:
79 | x1 = random.randint(0, w - tw)
80 | y1 = random.randint(0, h - th)
81 | else:
82 | x1 = int(round((w - tw) / 2.))
83 | y1 = int(round((h - th) / 2.))
84 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap]
85 | else:
86 | result = []
87 | for i in imgmap:
88 | if random.random() < self.threshold:
89 | x1 = random.randint(0, w - tw)
90 | y1 = random.randint(0, h - th)
91 | else:
92 | x1 = int(round((w - tw) / 2.))
93 | y1 = int(round((h - th) / 2.))
94 | result.append(i.crop((x1, y1, x1 + tw, y1 + th)))
95 | return result
96 | else:
97 | return imgmap
98 |
99 | class RandomCrop:
100 | def __init__(self, size, consistent=True):
101 | if isinstance(size, numbers.Number):
102 | self.size = (int(size), int(size))
103 | else:
104 | self.size = size
105 | self.consistent = consistent
106 |
107 | def __call__(self, imgmap, flowmap=None):
108 | img1 = imgmap[0]
109 | w, h = img1.size
110 | if self.size is not None:
111 | th, tw = self.size
112 | if w == tw and h == th:
113 | return imgmap
114 | if not flowmap:
115 | if self.consistent:
116 | x1 = random.randint(0, w - tw)
117 | y1 = random.randint(0, h - th)
118 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap]
119 | else:
120 | result = []
121 | for i in imgmap:
122 | x1 = random.randint(0, w - tw)
123 | y1 = random.randint(0, h - th)
124 | result.append(i.crop((x1, y1, x1 + tw, y1 + th)))
125 | return result
126 | elif flowmap is not None:
127 | assert (not self.consistent)
128 | result = []
129 | for idx, i in enumerate(imgmap):
130 | proposal = []
131 | for j in range(3): # number of proposal: use the one with largest optical flow
132 | x = random.randint(0, w - tw)
133 | y = random.randint(0, h - th)
134 | proposal.append([x, y, abs(np.mean(flowmap[idx,y:y+th,x:x+tw,:]))])
135 | [x1, y1, _] = max(proposal, key=lambda x: x[-1])
136 | result.append(i.crop((x1, y1, x1 + tw, y1 + th)))
137 | return result
138 | else:
139 | raise ValueError('wrong case')
140 | else:
141 | return imgmap
142 |
143 |
144 | class RandomSizedCrop:
145 | def __init__(self, size, interpolation=Image.BILINEAR, consistent=True, p=1.0):
146 | self.size = size
147 | self.interpolation = interpolation
148 | self.consistent = consistent
149 | self.threshold = p
150 |
151 | def __call__(self, imgmap):
152 | img1 = imgmap[0]
153 | if random.random() < self.threshold: # do RandomSizedCrop
154 | for attempt in range(10):
155 | area = img1.size[0] * img1.size[1]
156 | target_area = random.uniform(0.5, 1) * area
157 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
158 |
159 | w = int(round(math.sqrt(target_area * aspect_ratio)))
160 | h = int(round(math.sqrt(target_area / aspect_ratio)))
161 |
162 | if self.consistent:
163 | if random.random() < 0.5:
164 | w, h = h, w
165 | if w <= img1.size[0] and h <= img1.size[1]:
166 | x1 = random.randint(0, img1.size[0] - w)
167 | y1 = random.randint(0, img1.size[1] - h)
168 |
169 | imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap]
170 | for i in imgmap: assert(i.size == (w, h))
171 |
172 | return [i.resize((self.size, self.size), self.interpolation) for i in imgmap]
173 | else:
174 | result = []
175 | for i in imgmap:
176 | if random.random() < 0.5:
177 | w, h = h, w
178 | if w <= img1.size[0] and h <= img1.size[1]:
179 | x1 = random.randint(0, img1.size[0] - w)
180 | y1 = random.randint(0, img1.size[1] - h)
181 | result.append(i.crop((x1, y1, x1 + w, y1 + h)))
182 | assert(result[-1].size == (w, h))
183 | else:
184 | result.append(i)
185 |
186 | assert len(result) == len(imgmap)
187 | return [i.resize((self.size, self.size), self.interpolation) for i in result]
188 |
189 | # Fallback
190 | scale = Scale(self.size, interpolation=self.interpolation)
191 | crop = CenterCrop(self.size)
192 | return crop(scale(imgmap))
193 | else: # don't do RandomSizedCrop, do CenterCrop
194 | crop = CenterCrop(self.size)
195 | return crop(imgmap)
196 |
197 |
198 | class RandomHorizontalFlip:
199 | def __init__(self, consistent=True, command=None):
200 | self.consistent = consistent
201 | if command == 'left':
202 | self.threshold = 0
203 | elif command == 'right':
204 | self.threshold = 1
205 | else:
206 | self.threshold = 0.5
207 | def __call__(self, imgmap):
208 | if self.consistent:
209 | if random.random() < self.threshold:
210 | return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap]
211 | else:
212 | return imgmap
213 | else:
214 | result = []
215 | for i in imgmap:
216 | if random.random() < self.threshold:
217 | result.append(i.transpose(Image.FLIP_LEFT_RIGHT))
218 | else:
219 | result.append(i)
220 | assert len(result) == len(imgmap)
221 | return result
222 |
223 |
224 | class RandomGray:
225 | '''Actually it is a channel splitting, not strictly grayscale images'''
226 | def __init__(self, consistent=True, p=0.5):
227 | self.consistent = consistent
228 | self.p = p # probability to apply grayscale
229 | def __call__(self, imgmap):
230 | if self.consistent:
231 | if random.random() < self.p:
232 | return [self.grayscale(i) for i in imgmap]
233 | else:
234 | return imgmap
235 | else:
236 | result = []
237 | for i in imgmap:
238 | if random.random() < self.p:
239 | result.append(self.grayscale(i))
240 | else:
241 | result.append(i)
242 | assert len(result) == len(imgmap)
243 | return result
244 |
245 | def grayscale(self, img):
246 | channel = np.random.choice(3)
247 | np_img = np.array(img)[:,:,channel]
248 | np_img = np.dstack([np_img, np_img, np_img])
249 | img = Image.fromarray(np_img, 'RGB')
250 | return img
251 |
252 |
253 | class ColorJitter(object):
254 | """Randomly change the brightness, contrast and saturation of an image. --modified from pytorch source code
255 | Args:
256 | brightness (float or tuple of float (min, max)): How much to jitter brightness.
257 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
258 | or the given [min, max]. Should be non negative numbers.
259 | contrast (float or tuple of float (min, max)): How much to jitter contrast.
260 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
261 | or the given [min, max]. Should be non negative numbers.
262 | saturation (float or tuple of float (min, max)): How much to jitter saturation.
263 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
264 | or the given [min, max]. Should be non negative numbers.
265 | hue (float or tuple of float (min, max)): How much to jitter hue.
266 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
267 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
268 | """
269 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0):
270 | self.brightness = self._check_input(brightness, 'brightness')
271 | self.contrast = self._check_input(contrast, 'contrast')
272 | self.saturation = self._check_input(saturation, 'saturation')
273 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
274 | clip_first_on_zero=False)
275 | self.consistent = consistent
276 | self.threshold = p
277 |
278 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
279 | if isinstance(value, numbers.Number):
280 | if value < 0:
281 | raise ValueError("If {} is a single number, it must be non negative.".format(name))
282 | value = [center - value, center + value]
283 | if clip_first_on_zero:
284 | value[0] = max(value[0], 0)
285 | elif isinstance(value, (tuple, list)) and len(value) == 2:
286 | if not bound[0] <= value[0] <= value[1] <= bound[1]:
287 | raise ValueError("{} values should be between {}".format(name, bound))
288 | else:
289 | raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
290 |
291 | # if value is 0 or (1., 1.) for brightness/contrast/saturation
292 | # or (0., 0.) for hue, do nothing
293 | if value[0] == value[1] == center:
294 | value = None
295 | return value
296 |
297 | @staticmethod
298 | def get_params(brightness, contrast, saturation, hue):
299 | """Get a randomized transform to be applied on image.
300 | Arguments are same as that of __init__.
301 | Returns:
302 | Transform which randomly adjusts brightness, contrast and
303 | saturation in a random order.
304 | """
305 | transforms = []
306 |
307 | if brightness is not None:
308 | brightness_factor = random.uniform(brightness[0], brightness[1])
309 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
310 |
311 | if contrast is not None:
312 | contrast_factor = random.uniform(contrast[0], contrast[1])
313 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
314 |
315 | if saturation is not None:
316 | saturation_factor = random.uniform(saturation[0], saturation[1])
317 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
318 |
319 | if hue is not None:
320 | hue_factor = random.uniform(hue[0], hue[1])
321 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor)))
322 |
323 | random.shuffle(transforms)
324 | transform = torchvision.transforms.Compose(transforms)
325 |
326 | return transform
327 |
328 | def __call__(self, imgmap):
329 | if random.random() < self.threshold: # do ColorJitter
330 | if self.consistent:
331 | transform = self.get_params(self.brightness, self.contrast,
332 | self.saturation, self.hue)
333 | return [transform(i) for i in imgmap]
334 | else:
335 | result = []
336 | for img in imgmap:
337 | transform = self.get_params(self.brightness, self.contrast,
338 | self.saturation, self.hue)
339 | result.append(transform(img))
340 | return result
341 | else: # don't do ColorJitter, do nothing
342 | return imgmap
343 |
344 | def __repr__(self):
345 | format_string = self.__class__.__name__ + '('
346 | format_string += 'brightness={0}'.format(self.brightness)
347 | format_string += ', contrast={0}'.format(self.contrast)
348 | format_string += ', saturation={0}'.format(self.saturation)
349 | format_string += ', hue={0})'.format(self.hue)
350 | return format_string
351 |
352 |
353 | class RandomRotation:
354 | def __init__(self, consistent=True, degree=15, p=1.0):
355 | self.consistent = consistent
356 | self.degree = degree
357 | self.threshold = p
358 | def __call__(self, imgmap):
359 | if random.random() < self.threshold: # do RandomRotation
360 | if self.consistent:
361 | deg = np.random.randint(-self.degree, self.degree, 1)[0]
362 | return [i.rotate(deg, expand=True) for i in imgmap]
363 | else:
364 | return [i.rotate(np.random.randint(-self.degree, self.degree, 1)[0], expand=True) for i in imgmap]
365 | else: # don't do RandomRotation, do nothing
366 | return imgmap
367 |
368 | class Rotation:
369 | def __init__(self, consistent=True, degree_inc=90, p=1.0):
370 | self.consistent = consistent
371 | self.degree_inc = degree_inc
372 | self.threshold = p
373 | def __call__(self, imgmap, rotation):
374 | deg = rotation
375 | return [i.rotate(deg, expand=True) for i in imgmap]
376 |
377 | class ToTensor:
378 | def __call__(self, imgmap):
379 | totensor = transforms.ToTensor()
380 | return [totensor(i) for i in imgmap]
381 |
382 | class Normalize:
383 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
384 | self.mean = mean
385 | self.std = std
386 | def __call__(self, imgmap):
387 | normalize = transforms.Normalize(mean=self.mean, std=self.std)
388 | return [normalize(i) for i in imgmap]
389 |
390 |
--------------------------------------------------------------------------------
/datasets/pretraining/kinetics400.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | import numpy as np
5 | import random
6 | from glob import glob
7 | from tqdm import tqdm
8 | import PIL.Image as Image
9 | from multiprocessing import Process, Manager, Pool
10 |
11 | class kineticsFolderInstance:
12 | """
13 | Dataset instance for a single kinetics video
14 | """
15 | def __init__(self, root, membank_idx, frame_padding):
16 | self.root = root
17 | self.membank_idx = membank_idx
18 | self.frame_padding = frame_padding
19 | self.vid_length = len(glob(os.path.join(root,'*')))
20 |
21 | def __len__(self):
22 | return self.vid_length
23 |
24 | def __call__(self):
25 | anchor_idx = np.random.randint(0, self.vid_length)
26 | if anchor_idx >= self.vid_length - (2 * self.frame_padding) - 1:
27 | pair_idx = anchor_idx - self.frame_padding
28 | else:
29 | pair_idx = anchor_idx + self.frame_padding
30 |
31 | anchor_path = os.path.join(self.root, 'frame{}.jpg'.format(anchor_idx))
32 | pair_path = os.path.join(self.root, 'frame{}.jpg'.format(pair_idx))
33 |
34 | return anchor_path, pair_path
35 |
36 | class Kinetics400:
37 | def __init__(self, cfg, transform):
38 | self.root = os.path.join(cfg.DATASET.KINETICS400.FRAMES_PATH, 'train')
39 | self.frame_padding = cfg.TRAIN.PRETRAINING.FRAME_PADDING
40 | self.transform = transform
41 | self.K = cfg.LOSS.PRETRAINING.NCE.NEGATIVES
42 | assert self.root != '', 'Please specify Kinetics400 Path in config'
43 | pickle_file = os.path.join(cfg.ASSETS_PATH, 'Kinetics400Dataset.pickle')
44 | if not os.path.exists(pickle_file):
45 | # ===== Makes Sample List =====
46 | self.samples = self.get_sample_list()
47 |
48 | # ===== Save Pickle File =====
49 | with open(pickle_file, 'wb') as f:
50 | pickle.dump(self.samples,f)
51 | else:
52 | # ===== Load Pickle File =====
53 | with open(pickle_file, 'rb') as f:
54 | self.samples = pickle.load(f)
55 | f.close()
56 |
57 | self.all_membank_negatives = list(range(len(self.samples)))
58 |
59 |
60 | def process_video(self, x):
61 | video, idx, L = x
62 | try:
63 | sample = kineticsFolderInstance(
64 | root = video,
65 | membank_idx = idx,
66 | frame_padding = self.frame_padding
67 | )
68 | if len(sample) != 0:
69 | L.append(sample)
70 | except:
71 | pass
72 |
73 | def get_sample_list(self):
74 | video_folders = glob(os.path.join(self.root,'*','*'))
75 |
76 | # ===== Set up multiprocessing =====
77 | L = Manager().list()
78 | pool = Pool(processes=32)
79 |
80 | # ===== Prepare inputs for multiprocessing =====
81 | inputs = [[video, idx, L] for idx, video in enumerate(video_folders)]
82 |
83 | # ===== Create sample list =====
84 | pbar = tqdm(total = len(inputs))
85 | print(len(inputs))
86 | for i in pool.imap_unordered(self.process_video, inputs):
87 | pbar.update(1)
88 |
89 | pool.close()
90 | pool.join()
91 |
92 | all_samples = list(L)
93 |
94 | # ===== Fix membank_idx due to failed samples =====
95 | for idx, sample in enumerate(all_samples):
96 | sample.membank_idx = idx
97 |
98 | return all_samples
99 |
100 |
101 |
102 | def __len__(self):
103 | return len(self.samples)
104 |
105 | def __getitem__(self, index):
106 |
107 | # Get sample from index
108 | sample = self.samples[index]
109 | membank_idx = sample.membank_idx
110 | anchor_path, pair_path = sample()
111 |
112 | anchor_frame = Image.open(anchor_path)
113 | pair_frame = Image.open(pair_path)
114 | anchor_tensor, pair_tensor, rotation_tensor, rotation_gt = self.transform(anchor_frame, pair_frame)
115 |
116 | # Get negatives
117 | potential_negatives = self.all_membank_negatives[:membank_idx] + self.all_membank_negatives[membank_idx + 1:]
118 | negatives = torch.tensor(random.sample(potential_negatives, self.K + 1))
119 |
120 | inputs = {
121 | 'anchor_tensor': anchor_tensor,
122 | 'pair_tensor': pair_tensor,
123 | 'membank_idx': membank_idx,
124 | 'rotation_tensor': rotation_tensor,
125 | 'rotation_gt': rotation_gt,
126 | 'negatives': negatives,
127 | }
128 |
129 | return inputs
130 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import argparse
5 | import logging
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | from tensorboardX import SummaryWriter
10 |
11 |
12 | from config import config, update_config
13 | from network import *
14 | from datasets.pretraining import get_pretraining_dataset
15 | from datasets.finetuning import *
16 | from loss import get_loss
17 | from utils import *
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser(description='Train TCE Self-Supervised')
22 | parser.add_argument('--cfg', help = 'Path to config file', type = str, default = None)
23 | parser.add_argument('--val', default = False, action = 'store_true', help = 'Just validate checkpoint')
24 | parser.add_argument('opts', help = 'Modify config using the command line',
25 | default = None, nargs=argparse.REMAINDER )
26 | args = parser.parse_args()
27 | update_config(config, args)
28 |
29 | return args
30 |
31 |
32 | def train(epoch, train_loader, model, criterion, optimizer, cfg, tboard, logger):
33 |
34 | # ===== Set up Meters =====
35 | batch_time = AverageMeter()
36 | loss_meter = AverageMeter()
37 | top1_meter = AverageMeter()
38 | top5_meter = AverageMeter()
39 |
40 | # ===== Switch to train mode =====
41 | model.train()
42 |
43 |
44 | # ===== Start training on batches =====
45 | for idx, (data, label) in enumerate(train_loader):
46 | end = time.time()
47 | bsz = data[0].size(0)
48 |
49 | # ===== Forwards =====
50 | data = [x.cuda() for x in data]
51 | label = label.cuda()
52 |
53 |
54 | for i, frame_tensor in enumerate(data):
55 | if i == 0:
56 | output = model(frame_tensor)
57 | else:
58 | output += model(frame_tensor)
59 |
60 | loss = criterion(output, label)
61 | prec1, prec5 = accuracy(output, label, topk=(1,5))
62 |
63 | # ===== Backwards =====
64 | optimizer.zero_grad()
65 | loss.backward()
66 | optimizer.step()
67 |
68 | # ===== Update meters and do logging =====
69 | batch_time.update(time.time() - end)
70 | loss_meter.update(loss.item(), bsz)
71 | top1_meter.update(prec1)
72 | top5_meter.update(prec5)
73 |
74 | if idx % cfg.PRINT_FREQ == 0:
75 | log_step = (epoch-1) * len(train_loader) + idx
76 | tboard.add_scalars('Train/Loss', {'val': loss_meter.val, 'avg': loss_meter.avg}, log_step)
77 | tboard.add_scalars('Train/Top1 Acc', {'val': top1_meter.val, 'avg': top1_meter.avg}, log_step)
78 | tboard.add_scalars('Train/Top5 Acc', {'val': top5_meter.val, 'avg': top5_meter.avg}, log_step)
79 |
80 | info = ('Epoch : {} ({}/{}) | '
81 | 'BT : {:02f} ({:02f}) | '
82 | 'Loss: {:02f} ({:02f}) | '
83 | 'Top1: {:02f} ({:02f}) | '
84 | 'Top5: {:02f} ({:02f})').format(
85 | epoch, idx, len(train_loader),
86 | batch_time.val, batch_time.avg,
87 | loss_meter.val, loss_meter.avg,
88 | top1_meter.val, top1_meter.avg,
89 | top5_meter.val, top5_meter.avg
90 | )
91 | logger.info(info)
92 |
93 | @torch.no_grad()
94 | def validate(epoch, val_loader, val_gt, model, criterion, cfg, tboard, logger):
95 |
96 | # ===== Switch to eval mode =====
97 | model.eval()
98 | video_predictions = {}
99 | end = time.time()
100 | pbar = tqdm(total = len(val_loader))
101 |
102 | for idx, (video_keys, data, label) in enumerate(val_loader):
103 | bsz = data.size(0)
104 | data = data.cuda()
105 | label = label.cuda()
106 |
107 | # ===== Calculate network output and pack into video_predictions =====
108 | output = model(data)
109 | for i in range(bsz):
110 | key = video_keys[i]
111 | pred = output[i]
112 | if key not in video_predictions.keys():
113 | video_predictions[key] = []
114 | video_predictions[key] = pred
115 | else:
116 | video_predictions[key] += pred
117 |
118 | pbar.update(1)
119 |
120 | print('\n')
121 | # ===== Get Eval Top1, Top5, Loss =====
122 | video_level_preds = torch.zeros(len(video_predictions), 101).float().cuda()
123 | video_level_labels = torch.zeros(len(video_predictions)).long().cuda()
124 |
125 | for idx, key in enumerate(sorted(video_predictions.keys())):
126 | video_level_preds[idx] = video_predictions[key]
127 | video_level_labels[idx] = val_gt[key]
128 |
129 | prec1, prec5 = accuracy(video_level_preds, video_level_labels, topk=(1,5))
130 | loss = criterion(video_level_preds, video_level_labels)
131 |
132 | # ===== Log =====
133 | logger.info('Validation complete for epoch {}, time taken {:02f} seconds'.format(epoch, time.time() - end))
134 | logger.info('Top1: {:02f} Top5: {:02f} Loss: {:02f}'.format(prec1, prec5, loss.item()))
135 |
136 | tboard.add_scalars('Val/Top1', {'prec1':prec1}, epoch)
137 | tboard.add_scalars('Val/Top5', {'prec5':prec5}, epoch)
138 | tboard.add_scalars('Val/Loss', {'loss':loss.item()}, epoch)
139 |
140 | return prec1
141 |
142 |
143 |
144 |
145 | def main():
146 | args = parse_args()
147 | logger = setup_logger()
148 | logger.info(config)
149 | if not os.path.exists(config.ASSETS_PATH):
150 | os.makedirs(config.ASSETS_PATH)
151 |
152 | # ===== Create the dataloaders =====
153 | UCF_dataset = UCF101(config, logger)
154 | train_loader, val_loader, val_gt = UCF_dataset.get_loaders()
155 |
156 | # ===== Create the model =====
157 | model = FineTuneNet(config)
158 | logger.info('Built Model, using {} backbone'.format(config.MODEL.TRUNK))
159 | if torch.cuda.device_count() > 1:
160 | model = torch.nn.DataParallel(model).cuda()
161 | else:
162 | model = model.cuda()
163 | logger.info('Training on {} GPUs'.format(torch.cuda.device_count()))
164 |
165 | # ===== Set the optimizer =====
166 | optimizer = get_optimizer(model, config, pretraining = False)
167 |
168 | # ===== Get the loss =====
169 | criterion = nn.CrossEntropyLoss().cuda()
170 |
171 | # ===== Load checkpoint =====
172 | if config.TRAIN.FINETUNING.CHECKPOINT:
173 | checkpoint = torch.load(config.TRAIN.FINETUNING.CHECKPOINT)
174 | # ===== Align checkpoint keys with model =====
175 | if 'module' in list(checkpoint['state_dict'].keys())[0] and 'module' not in list(model.state_dict().keys())[0]:
176 | checkpoint['state_dict'] = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items() }
177 | elif 'module' not in list(checkpoint['state_dict'].keys())[0] and 'module' in list(model.state_dict().keys())[0]:
178 | checkpoint['state_dict'] = {'module.' + k:v for k,v in checkpoint['state_dict'].items() }
179 |
180 | if not config.TRAIN.FINETUNING.RESUME:
181 | # ===== Load only backbone parameters for starting finetuning at epoch 0 =====
182 | forgiving_load_state_dict(checkpoint['state_dict'], model, logger)
183 | start_epoch = 1
184 | best_prec1 = 0
185 | else:
186 | model.load_state_dict(checkpoint['state_dict'])
187 | optimizer.load_state_dict(checkpoint['optimizer'])
188 | start_epoch = checkpoint['epoch']
189 | best_prec1 = checkpoint['best_prec1']
190 | else:
191 | logger.warning('No checkpoint specified for pre-training. Training from scratch.')
192 | start_epoch = 1
193 | best_prec1 = 0
194 |
195 | # ===== Set up save directory and TensorBoard =====
196 | assert config.TRAIN.FINETUNING.SAVEDIR, 'Please specify save directory'
197 | if not os.path.exists(config.TRAIN.FINETUNING.SAVEDIR):
198 | os.makedirs(config.TRAIN.FINETUNING.SAVEDIR)
199 | os.makedirs(os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'checkpoints'))
200 | os.makedirs(os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'tboard'))
201 |
202 | tboard = SummaryWriter(logdir = os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'tboard'))
203 |
204 | if args.val:
205 | logger.info('Running in Validation mode')
206 | validate(
207 | epoch = start_epoch,
208 | val_loader = val_loader,
209 | val_gt = val_gt,
210 | model = model,
211 | criterion = criterion,
212 | cfg = config,
213 | tboard = tboard,
214 | logger = logger,
215 | )
216 | else:
217 | # ===== Train Loop =====
218 | logger.info('Begin training')
219 | for epoch in range(start_epoch, config.TRAIN.FINETUNING.EPOCHS + 1):
220 | adjust_learning_rate(epoch, config, optimizer, pretraining = False, logger = logger)
221 | logger.info('Training epoch {}'.format(epoch))
222 |
223 | # ===== Train 1 epoch ======
224 | train(
225 | epoch = epoch,
226 | train_loader = train_loader,
227 | model = model,
228 | criterion = criterion,
229 | optimizer = optimizer,
230 | cfg = config,
231 | tboard = tboard,
232 | logger = logger
233 | )
234 |
235 | # ===== Validate periodically =====
236 | if epoch % config.TRAIN.FINETUNING.VAL_FREQ == 0:
237 | logger.info('Validating at epoch {}'.format(epoch))
238 | prec1 = validate(
239 | epoch = epoch,
240 | val_loader = val_loader,
241 | val_gt = val_gt,
242 | model = model,
243 | criterion = criterion,
244 | cfg = config,
245 | tboard = tboard,
246 | logger = logger,
247 | )
248 |
249 | # ===== Save if new best performance =====
250 | if prec1 > best_prec1:
251 | best_prec1 = prec1
252 | logger.info('New best top1 precision: {}'.format(best_prec1))
253 | checkpoint = {
254 | 'state_dict': model.state_dict(),
255 | 'optimizer': optimizer.state_dict(),
256 | 'epoch': epoch,
257 | 'best_prec1': best_prec1
258 | }
259 | save_path = os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'checkpoints', 'best_checkpoint.pth')
260 | torch.save(checkpoint, save_path)
261 |
262 | # ===== Save latest checkpoint =====
263 | checkpoint = {
264 | 'state_dict': model.state_dict(),
265 | 'optimizer': optimizer.state_dict(),
266 | 'epoch': epoch,
267 | 'best_prec1': best_prec1
268 | }
269 | save_path = os.path.join(config.TRAIN.FINETUNING.SAVEDIR, 'checkpoints', 'latest_checkpoint.pth')
270 | torch.save(checkpoint, save_path)
271 |
272 | if __name__ == '__main__':
273 | main()
--------------------------------------------------------------------------------
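Note on the checkpoint loading in `finetune.py` above: the `module.` prefix handling exists because `torch.nn.DataParallel` wraps the model and prepends `module.` to every parameter name, so single-GPU and multi-GPU checkpoints are not directly interchangeable. Below is a minimal, self-contained sketch of the same alignment; the helper name `align_module_prefix` is ours for illustration, not part of this repo.

```python
import torch.nn as nn

def align_module_prefix(state_dict, model):
    """Add or strip the 'module.' prefix so checkpoint keys match the target model."""
    ckpt_has = next(iter(state_dict)).startswith('module.')
    model_has = next(iter(model.state_dict())).startswith('module.')
    if ckpt_has and not model_has:
        return {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    if model_has and not ckpt_has:
        return {'module.' + k: v for k, v in state_dict.items()}
    return state_dict

# Toy example: a checkpoint saved from a DataParallel model, loaded into a plain one.
wrapped = nn.DataParallel(nn.Linear(4, 2))
plain = nn.Linear(4, 2)
plain.load_state_dict(align_module_prefix(wrapped.state_dict(), plain))
```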
/images/TCE.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csiro-robotics/TCE/0da560c7c8616fc124dce0b5b35cd3ebae20bb81/images/TCE.png
--------------------------------------------------------------------------------
/images/bowling_tsne_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csiro-robotics/TCE/0da560c7c8616fc124dce0b5b35cd3ebae20bb81/images/bowling_tsne_example.gif
--------------------------------------------------------------------------------
/loss/NCEAverage.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | class Normalise(nn.Module):
6 |
7 | def __init__(self, power=2):
8 | super(Normalise, self).__init__()
9 | self.power = power
10 |
11 | def forward(self, x):
12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
13 | out = x.div(norm)
14 | return out
15 |
16 | class NCEAverageKinetics(nn.Module):
17 |
18 | def __init__(self, feat_dim, n_data, K, T=0.07, momentum=0.5):
19 | super(NCEAverageKinetics, self).__init__()
20 | self.nLem = n_data
21 | self.unigrams = torch.ones(self.nLem)
22 | self.K = int(K)
23 | self.feat_dim = int(feat_dim)
24 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum]))
25 |
26 | stdv = 1. / math.sqrt(feat_dim / 3)
27 | self.register_buffer('memory_bank', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))
28 | self.l2norm = Normalise(2)
29 |
30 |
31 | @torch.no_grad()
32 | def update_memorybank(self, features, index):
33 | with torch.no_grad():
34 | self.memory_bank.index_copy_(0, index.view(-1), features)
35 |
36 |
37 | def contrast(self, anchor_feature, pair_feature, negatives):
38 | Z_c = self.params[2].item()
39 | T = self.params[1].item()
40 | K = int(self.params[0].item())
41 |
42 | batchSize = anchor_feature.size(0)
43 | # ===== Retrieve negatives from memory bank, reshape, insert positive into 0th index =====
44 | weight_c = torch.index_select(self.memory_bank, 0, negatives.view(-1)).detach()
45 | weight_c = weight_c.view(batchSize, K + 1, self.feat_dim)
46 | weight_c[:,0] = pair_feature
47 |
48 | # ===== BMM between positive and negative features + positive pairing =====
49 | out_c = torch.bmm(weight_c, anchor_feature.view(batchSize, self.feat_dim, 1))
50 | out_c = torch.exp(torch.div(out_c, T))
51 |
52 | if Z_c < 0:
53 | self.params[2] = out_c.mean() * self.memory_bank.size(0)
54 | Z_c = self.params[2].clone().detach().item()
55 | print("normalization constant Z_c is set to {:.1f}".format(Z_c))
56 |
57 | out_c = torch.div(out_c, Z_c).contiguous()
58 |
59 | return out_c
60 |
61 |
62 | def forward(self, inputs, _ = None):
63 |
64 | anchor_feature = inputs['anchor_feature']
65 | pair_feature= inputs['pair_feature']
66 | negatives = inputs['negatives']
67 | membank_idx = inputs['membank_idx']
68 |
69 | out_c = self.contrast(anchor_feature, pair_feature, negatives)
70 | self.update_memorybank(anchor_feature, membank_idx)
71 |
72 |
73 | return out_c
74 |
75 | class NCEAverageKineticsMining(nn.Module):
76 | def __init__(self, feat_dim, n_data, K, max_hard_negatives_percentage = 1, T=0.07):
77 | super(NCEAverageKineticsMining, self).__init__()
78 | self.nLem = n_data
79 | self.unigrams = torch.ones(self.nLem)
80 | self.K = int(K)
81 | self.feat_dim = int(feat_dim)
82 | self.max_negs = int(self.K * max_hard_negatives_percentage)
83 |
84 | self.register_buffer('params', torch.tensor([K, T, -1]))
85 |
86 | stdv = 1. / math.sqrt(feat_dim / 3)
87 | self.register_buffer('memory_bank', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))
88 | self.l2norm = Normalise(2)
89 |
90 | def update_memorybank(self, features, index):
91 | self.memory_bank.index_copy_(0, index.view(-1), features)
92 | return None
93 |
94 |
95 | def contrast(self, anchor_feature, pair_feature, negatives):
96 | Z_c = self.params[2].item()
97 | T = self.params[1].item()
98 | K = int(self.params[0].item())
99 | batchSize = anchor_feature.size(0)
100 |
101 | # ===== Retrieve negatives from memory bank, reshape, insert positive into 0th index =====
102 | weight_c = torch.index_select(self.memory_bank, 0, negatives.view(-1)).detach()
103 | weight_c = weight_c.view(batchSize, K + 1, self.feat_dim)
104 | weight_c[:,0] = pair_feature
105 |
106 | # ===== BMM between positive and negative features + positive pairing =====
107 | out_c = torch.bmm(weight_c, anchor_feature.view(batchSize, self.feat_dim, 1))
108 | out_c = torch.exp(torch.div(out_c, T))
109 |
110 | if Z_c < 0:
111 | self.params[2] = out_c.mean() * self.memory_bank.size(0)
112 | Z_c = self.params[2].clone().detach().item()
113 | print("normalization constant Z_c is set to {:.1f}".format(Z_c))
114 |
115 | out_c = torch.div(out_c, Z_c).contiguous()
116 |
117 | return out_c
118 |
119 | def mine_negative_examples(self, anchor_feature, threshold, membank_idx):
120 | with torch.no_grad():
121 | bs = anchor_feature.size(0)
122 |             cosine_scores = torch.matmul(anchor_feature, self.memory_bank.t())  # (bs, n_data); .t() transposes the bank, .view() would scramble its rows
123 | cosine_mask = cosine_scores >= threshold
124 | cosine_scores[cosine_mask] = -2
125 |
126 | scores, indices = cosine_scores.topk(k=self.K)
127 | negative_indices = torch.zeros(bs, self.K + 1).cuda()
128 |
129 | for rowid in range(bs):
130 | # Get all the hard examples, determine how many random examples are required
131 | row_topk = scores[rowid] # Top K values, sorted
132 | row_indices = indices[rowid] # Top K indices, associated
133 | row_idx = membank_idx[rowid]
134 | hard_examples = min((row_topk != -2).sum().item(), self.max_negs)
135 | rand_examples = self.K - hard_examples
136 |
137 | if hard_examples >= self.K:
138 | # Enough hard examples to use for all negatives samples
139 | negative_indices[rowid, 1:] = row_indices
140 | elif hard_examples == 0:
141 | probs = torch.ones(self.nLem)
142 | probs[row_idx] = 0
143 | rand_indices = torch.multinomial(probs, self.K, replacement = True)
144 | negative_indices[rowid, 1:] = rand_indices
145 | else:
146 | # Mix of hard and random examples
147 | negative_indices[rowid, 1:hard_examples + 1] = row_indices[:hard_examples]
148 | probs = torch.ones(self.nLem)
149 | probs[row_indices[:hard_examples]] = 0 # Don't sample hard negatives a second time
150 | probs[row_idx] = 0
151 | rand_indices = torch.multinomial(probs, self.K - hard_examples, replacement = True)
152 |
153 | negative_indices[rowid, hard_examples + 1:] = rand_indices
154 |
155 | return negative_indices.long().cuda()
156 |
157 | def forward(self, inputs, threshold):
158 |
159 | anchor_feature = inputs['anchor_feature']
160 | pair_feature= inputs['pair_feature']
161 | membank_idx = inputs['membank_idx']
162 |
163 | negatives = self.mine_negative_examples(anchor_feature, threshold, membank_idx)
164 | out_c = self.contrast(anchor_feature, pair_feature, negatives)
165 | self.update_memorybank(anchor_feature, membank_idx)
166 |
167 |
168 | return out_c
169 |
--------------------------------------------------------------------------------
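For reference, the core of `contrast()` above is a batched dot product between each anchor and its (1 positive + K negative) memory-bank entries. A minimal shape sketch with dummy tensors; the dimensions are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn.functional as F

batch, K, feat_dim, T = 2, 5, 8, 0.07

anchor = F.normalize(torch.randn(batch, feat_dim), dim=1)
# weight[:, 0] holds the positive pair, weight[:, 1:] holds K negatives from the memory bank
weight = F.normalize(torch.randn(batch, K + 1, feat_dim), dim=2)

out = torch.bmm(weight, anchor.view(batch, feat_dim, 1))   # (batch, K+1, 1) similarities
out = torch.exp(out / T)                                    # unnormalised NCE scores
print(out.shape)                                            # torch.Size([2, 6, 1])
```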
/loss/NCECriterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | eps = 1e-7
5 |
6 |
7 | class NCECriterion(nn.Module):
8 | """
9 | Eq. (12): L_{NCE}
10 | """
11 | def __init__(self, n_data):
12 | super(NCECriterion, self).__init__()
13 | self.n_data = n_data
14 |
15 | def forward(self, x):
16 |
17 | bsz = x.shape[0]
18 | m = x.size(1) - 1
19 |
20 | # noise distribution
21 | Pn = 1 / float(self.n_data)
22 |
23 | # loss for positive pair
24 | P_pos = x.select(1, 0)
25 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
26 |
27 | # loss for K negative pair
28 | P_neg = x.narrow(1, 1, m)
29 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
30 |
31 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
32 |
33 |
34 | return loss
35 |
36 | class NCEAvgCriterion(nn.Module):
37 | def __init__(self, n_data):
38 | super(NCEAvgCriterion, self).__init__()
39 | self.n_data = n_data
40 |
41 | def forward(self, x):
42 |
43 | bsz = x.shape[0]
44 | m = x.size(1) - 1
45 |
46 | # noise distribution
47 | Pn = 1 / float(self.n_data)
48 |
49 | # loss for positive pair
50 | P_pos = x.select(1, 0)
51 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
52 |
53 | # loss for K negative pair
54 | P_neg = x.narrow(1, 1, m)
55 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
56 |
57 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).mean(0)) / bsz
58 |
59 |
60 | return loss
61 |
62 | class NCEMaxCriterion(nn.Module):
63 | def __init__(self, n_data):
64 | super(NCEMaxCriterion, self).__init__()
65 | self.n_data = n_data
66 |
67 | def forward(self, x):
68 |
69 | bsz = x.shape[0]
70 | m = x.size(1) - 1
71 |
72 | # noise distribution
73 | Pn = 1 / float(self.n_data)
74 |
75 | # loss for positive pair
76 | P_pos = x.select(1, 0)
77 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
78 |
79 | # loss for K negative pair
80 | P_neg = x.narrow(1, 1, m)
81 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
82 | log_D0_max, log_D0_indexes = log_D0.view(-1, 1).max(0)
83 |
84 | loss = - (log_D1.sum(0) + log_D0_max) / bsz
85 |
86 |
87 | return loss
88 |
89 | class ContrastCriterion(nn.Module):
90 | def __init__(self, n_data):
91 | super(ContrastCriterion, self).__init__()
92 | self.n_data = n_data
93 |
94 | def forward(self, x):
95 |
96 | bsz = x.shape[0]
97 | m = x.size(1) - 1
98 |
99 | # noise distribution
100 | Pn = 1 / float(self.n_data)
101 |
102 | # loss for positive pair
103 | P_pos = x.select(1, 0)
104 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
105 |
106 | # loss for K negative pair
107 | P_neg = x.narrow(1, 1, m)
108 |         log_D0 = torch.mean(torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps))).log_()
109 |
110 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
111 |
112 |
113 | return loss
--------------------------------------------------------------------------------
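The criteria above follow the noise-contrastive estimation objective: with noise distribution P_n = 1/n_data and m negatives, the positive pair contributes log(P_pos / (P_pos + m·P_n)) and each negative contributes log(m·P_n / (P_neg + m·P_n)). A small numerical sketch of the same computation as `NCECriterion`, on arbitrary scores:

```python
import torch

n_data, eps = 1000, 1e-7
x = torch.rand(4, 6)          # (batch, 1 positive + 5 negatives), arbitrary scores
m = x.size(1) - 1
Pn = 1.0 / n_data

P_pos = x[:, 0]
P_neg = x[:, 1:]
log_D1 = torch.log(P_pos / (P_pos + m * Pn + eps))      # positive term
log_D0 = torch.log((m * Pn) / (P_neg + m * Pn + eps))   # negative terms
loss = -(log_D1.sum() + log_D0.sum()) / x.size(0)
print(loss)
```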
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from loss.NCECriterion import NCECriterion
4 | from loss.NCEAverage import *
5 |
6 |
7 | def get_loss(cfg, n_data):
8 | if cfg.LOSS.PRETRAINING.MINING == False:
9 | average = NCEAverageKinetics(
10 | feat_dim = cfg.MODEL.PRETRAINING.FC_DIM,
11 | n_data = n_data,
12 | K = cfg.LOSS.PRETRAINING.NCE.NEGATIVES,
13 | T = cfg.LOSS.PRETRAINING.NCE.TEMPERATURE,
14 | momentum = cfg.LOSS.PRETRAINING.NCE.MOMENTUM
15 | )
16 | else:
17 | average = NCEAverageKineticsMining(
18 | feat_dim = cfg.MODEL.PRETRAINING.FC_DIM,
19 | n_data = n_data,
20 | K = cfg.LOSS.PRETRAINING.NCE.NEGATIVES,
21 | T = cfg.LOSS.PRETRAINING.NCE.TEMPERATURE,
22 | max_hard_negatives_percentage = cfg.LOSS.PRETRAINING.MINING.MAX_HARD_NEGATIVES_PERCENTAGE
23 | )
24 | criterion = NCECriterion(n_data)
25 |
26 | TCELoss = NCELoss(average, criterion)
27 | RotationLoss = nn.CrossEntropyLoss()
28 |
29 | return TCELoss, RotationLoss
30 |
31 | class NCELoss(nn.Module):
32 | def __init__(self, Average, Criterion):
33 | super(NCELoss, self).__init__()
34 | self.Average = Average
35 | self.Criterion = Criterion
36 |
37 | def forward(self, inputs, threshold):
38 | NCE_Average = self.Average(inputs, threshold)
39 | NCE_loss = self.Criterion(NCE_Average)
40 | return NCE_loss
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import PreTrainNet, FineTuneNet
--------------------------------------------------------------------------------
/network/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils import get_backbone
4 |
5 | class Normalise(nn.Module):
6 |
7 | def __init__(self, power=2):
8 | super(Normalise, self).__init__()
9 | self.power = power
10 |
11 | def forward(self, x):
12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
13 | out = x.div(norm)
14 | return out
15 |
16 | class PreTrainNet(nn.Module):
17 | def __init__(self, cfg):
18 | super(PreTrainNet, self).__init__()
19 | trunk = cfg.MODEL.TRUNK
20 | fc_dim = cfg.MODEL.PRETRAINING.FC_DIM
21 |
22 | self.backbone, backbone_channels = get_backbone(trunk)
23 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
24 | self.fc_contrastive = nn.Linear(backbone_channels, fc_dim)
25 | self.l2norm = Normalise(2)
26 |
27 | self.fc_rotation_1 = nn.Linear(backbone_channels, 200)
28 | self.bn_rotation_1 = nn.BatchNorm1d(num_features = 200)
29 | self.fc_rotation_2 = nn.Linear(200, 200)
30 | self.bn_rotation_2 = nn.BatchNorm1d(num_features = 200)
31 | self.fc_rotation_3 = nn.Linear(200,4)
32 |
33 |
34 |
35 | def forward(self, x, rotation = False):
36 | x = self.backbone(x)
37 | x = self.avgpool(x)
38 | x = torch.flatten(x, 1)
39 | if rotation == False:
40 | x = self.fc_contrastive(x)
41 | x = self.l2norm(x)
42 | return x
43 | elif rotation == True:
44 | x = self.fc_rotation_1(x)
45 | x = self.backbone.relu(x)
46 | x = self.bn_rotation_1(x)
47 |
48 | x = self.fc_rotation_2(x)
49 | x = self.backbone.relu(x)
50 | x = self.bn_rotation_2(x)
51 |
52 | x = self.fc_rotation_3(x)
53 | return x
54 | else:
55 |             raise ValueError('rotation must be a Boolean: True or False')
56 |
57 | class FineTuneNet(nn.Module):
58 | def __init__(self, cfg):
59 | super(FineTuneNet, self).__init__()
60 | trunk = cfg.MODEL.TRUNK
61 | num_classes = cfg.MODEL.FINETUNING.NUM_CLASSES
62 | assert num_classes > 0 and isinstance(num_classes, int), 'Please give a positive integer for the number of classes in the finetuning stage'
63 | p_dropout = cfg.MODEL.FINETUNING.DROPOUT
64 |
65 | self.backbone, backbone_channels = get_backbone(trunk)
66 | self.avgpool = nn.AdaptiveAvgPool2d((1,1))
67 | self.dropout = nn.Dropout(p = p_dropout)
68 | self.fc = nn.Linear(backbone_channels, num_classes)
69 |
70 | def forward(self, x):
71 | x = self.backbone(x)
72 | x = self.avgpool(x)
73 | x = torch.flatten(x,1)
74 | x = self.dropout(x)
75 | x = self.fc(x)
76 |
77 | return x
78 |
--------------------------------------------------------------------------------
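The `Normalise` module defined here (and again in `loss/NCEAverage.py`) is a plain Lp normalisation along the feature dimension; for p=2 it matches `torch.nn.functional.normalize`. A quick stand-alone check, not repo code:

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 128)
norm = x.pow(2).sum(1, keepdim=True).pow(0.5)   # what Normalise(power=2) computes
out = x / norm
print(torch.allclose(out, F.normalize(x, p=2, dim=1), atol=1e-6))  # True
```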
/network/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import numpy as np
5 | import torch.utils.model_zoo as model_zoo
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152']
9 |
10 | model_urls = {
11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | }
17 |
18 |
19 | def conv3x3(in_planes, out_planes, stride=1):
20 | """3x3 convolution with padding"""
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, bias=False)
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self, inplanes, planes, stride=1, downsample=None):
28 | super(BasicBlock, self).__init__()
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn1 = nn.BatchNorm2d(planes)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.conv2 = conv3x3(planes, planes)
33 | self.bn2 = nn.BatchNorm2d(planes)
34 | self.downsample = downsample
35 | self.stride = stride
36 |
37 | def forward(self, x):
38 | residual = x
39 |
40 | out = self.conv1(x)
41 | out = self.bn1(out)
42 | out = self.relu(out)
43 |
44 | out = self.conv2(out)
45 | out = self.bn2(out)
46 |
47 | if self.downsample is not None:
48 | residual = self.downsample(x)
49 |
50 | out += residual
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 4
58 |
59 | def __init__(self, inplanes, planes, stride=1, downsample=None):
60 | super(Bottleneck, self).__init__()
61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
62 | self.bn1 = nn.BatchNorm2d(planes)
63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
64 | padding=1, bias=False)
65 | self.bn2 = nn.BatchNorm2d(planes)
66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
67 | self.bn3 = nn.BatchNorm2d(planes * 4)
68 | self.relu = nn.ReLU(inplace=True)
69 | self.downsample = downsample
70 | self.stride = stride
71 |
72 | def forward(self, x):
73 | residual = x
74 | out = self.conv1(x)
75 | out = self.bn1(out)
76 | out = self.relu(out)
77 |
78 | out = self.conv2(out)
79 | out = self.bn2(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv3(out)
83 | out = self.bn3(out)
84 |
85 | if self.downsample is not None:
86 | residual = self.downsample(x)
87 |
88 | out += residual
89 | out = self.relu(out)
90 |
91 | return out
92 |
93 |
94 | class ResNet(nn.Module):
95 |
96 | def __init__(self, block, layers, in_channel=3, width=1):
97 | self.inplanes = 64
98 | super(ResNet, self).__init__()
99 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3,
100 | bias=False)
101 | self.bn1 = nn.BatchNorm2d(64)
102 | self.relu = nn.ReLU(inplace=True)
103 |
104 | self.base = int(64 * width)
105 |
106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
107 | self.layer1 = self._make_layer(block, self.base, layers[0])
108 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2)
109 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2)
110 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2)
111 | self.avgpool = nn.AvgPool2d(7, stride=1)
112 |
113 | for m in self.modules():
114 | if isinstance(m, nn.Conv2d):
115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
116 | m.weight.data.normal_(0, math.sqrt(2. / n))
117 | elif isinstance(m, nn.BatchNorm2d):
118 | m.weight.data.fill_(1)
119 | m.bias.data.zero_()
120 |
121 | def _make_layer(self, block, planes, blocks, stride=1):
122 | downsample = None
123 | if stride != 1 or self.inplanes != planes * block.expansion:
124 | downsample = nn.Sequential(
125 | nn.Conv2d(self.inplanes, planes * block.expansion,
126 | kernel_size=1, stride=stride, bias=False),
127 | nn.BatchNorm2d(planes * block.expansion),
128 | )
129 |
130 | layers = []
131 | layers.append(block(self.inplanes, planes, stride, downsample))
132 | self.inplanes = planes * block.expansion
133 | for i in range(1, blocks):
134 | layers.append(block(self.inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x, layer=7):
139 | if layer <= 0:
140 | return x
141 | x = self.conv1(x)
142 | x = self.bn1(x)
143 | x = self.relu(x)
144 |
145 | x = self.maxpool(x)
146 | if layer == 1:
147 | return x
148 | x = self.layer1(x)
149 | if layer == 2:
150 | return x
151 | x = self.layer2(x)
152 | if layer == 3:
153 | return x
154 | x = self.layer3(x)
155 | if layer == 4:
156 | return x
157 | x = self.layer4(x)
158 | if layer == 5:
159 | return x
160 | return x
161 |
162 |
163 | def resnet18(pretrained=False, **kwargs):
164 | """Constructs a ResNet-18 model.
165 | Args:
166 | pretrained (bool): If True, returns a model pre-trained on ImageNet
167 | """
168 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
169 | if pretrained:
170 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
171 | return model
172 |
173 |
174 | def resnet34(pretrained=False, **kwargs):
175 | """Constructs a ResNet-34 model.
176 | Args:
177 | pretrained (bool): If True, returns a model pre-trained on ImageNet
178 | """
179 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
180 | if pretrained:
181 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
182 | return model
183 |
184 |
185 | def resnet50(pretrained=False, **kwargs):
186 | """Constructs a ResNet-50 model.
187 | Args:
188 | pretrained (bool): If True, returns a model pre-trained on ImageNet
189 | """
190 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
191 | if pretrained:
192 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
193 | return model
194 |
195 |
196 | def resnet101(pretrained=False, **kwargs):
197 | """Constructs a ResNet-101 model.
198 | Args:
199 | pretrained (bool): If True, returns a model pre-trained on ImageNet
200 | """
201 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
202 | if pretrained:
203 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
204 | return model
205 |
206 |
207 | def resnet152(pretrained=False, **kwargs):
208 | """Constructs a ResNet-152 model.
209 | Args:
210 | pretrained (bool): If True, returns a model pre-trained on ImageNet
211 | """
212 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
213 | if pretrained:
214 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
215 | return model
216 |
217 |
218 | class InsResNet50(nn.Module):
219 | """Encoder for instance discrimination and MoCo"""
220 | def __init__(self, width=1):
221 | super(InsResNet50, self).__init__()
222 | self.encoder = resnet50(width=width)
223 | self.encoder = nn.DataParallel(self.encoder)
224 |
225 | def forward(self, x, layer=7):
226 | return self.encoder(x, layer)
227 |
228 |
229 | class ResNetV1(nn.Module):
230 | def __init__(self, name='resnet50'):
231 | super(ResNetV1, self).__init__()
232 | if name == 'resnet50':
233 | self.l_to_ab = resnet50(in_channel=1, width=0.5)
234 | self.ab_to_l = resnet50(in_channel=2, width=0.5)
235 | elif name == 'resnet18':
236 | self.l_to_ab = resnet18(in_channel=1, width=0.5)
237 | self.ab_to_l = resnet18(in_channel=2, width=0.5)
238 | elif name == 'resnet101':
239 | self.l_to_ab = resnet101(in_channel=1, width=0.5)
240 | self.ab_to_l = resnet101(in_channel=2, width=0.5)
241 | else:
242 | raise NotImplementedError('model {} is not implemented'.format(name))
243 |
244 | def forward(self, x, layer=7):
245 | l, ab = torch.split(x, [1, 2], dim=1)
246 | feat_l = self.l_to_ab(l, layer)
247 | feat_ab = self.ab_to_l(ab, layer)
248 | return feat_l, feat_ab
249 |
250 |
251 | class ResNetV2(nn.Module):
252 | def __init__(self, name='resnet50'):
253 | super(ResNetV2, self).__init__()
254 | if name == 'resnet50':
255 | self.l_to_ab = resnet50(in_channel=1, width=1)
256 | self.ab_to_l = resnet50(in_channel=2, width=1)
257 | elif name == 'resnet18':
258 | self.l_to_ab = resnet18(in_channel=1, width=1)
259 | self.ab_to_l = resnet18(in_channel=2, width=1)
260 | elif name == 'resnet101':
261 | self.l_to_ab = resnet101(in_channel=1, width=1)
262 | self.ab_to_l = resnet101(in_channel=2, width=1)
263 | else:
264 | raise NotImplementedError('model {} is not implemented'.format(name))
265 |
266 | def forward(self, x, layer=7):
267 | l, ab = torch.split(x, [1, 2], dim=1)
268 | feat_l = self.l_to_ab(l, layer)
269 | feat_ab = self.ab_to_l(ab, layer)
270 | return feat_l, feat_ab
271 |
272 |
273 | class ResNetV3(nn.Module):
274 | def __init__(self, name='resnet50'):
275 | super(ResNetV3, self).__init__()
276 | if name == 'resnet50':
277 | self.l_to_ab = resnet50(in_channel=1, width=2)
278 | self.ab_to_l = resnet50(in_channel=2, width=2)
279 | elif name == 'resnet18':
280 | self.l_to_ab = resnet18(in_channel=1, width=2)
281 | self.ab_to_l = resnet18(in_channel=2, width=2)
282 | elif name == 'resnet101':
283 | self.l_to_ab = resnet101(in_channel=1, width=2)
284 | self.ab_to_l = resnet101(in_channel=2, width=2)
285 | else:
286 | raise NotImplementedError('model {} is not implemented'.format(name))
287 |
288 | def forward(self, x, layer=7):
289 | l, ab = torch.split(x, [1, 2], dim=1)
290 | feat_l = self.l_to_ab(l, layer)
291 | feat_ab = self.ab_to_l(ab, layer)
292 | return feat_l, feat_ab
293 |
294 |
295 | class MyResNetsCMC(nn.Module):
296 | def __init__(self, name='resnet50v1'):
297 | super(MyResNetsCMC, self).__init__()
298 | if name.endswith('v1'):
299 | self.encoder = ResNetV1(name[:-2])
300 | elif name.endswith('v2'):
301 | self.encoder = ResNetV2(name[:-2])
302 | elif name.endswith('v3'):
303 | self.encoder = ResNetV3(name[:-2])
304 | else:
305 | raise NotImplementedError('model not support: {}'.format(name))
306 |
307 | self.encoder = nn.DataParallel(self.encoder)
308 |
309 | def forward(self, x, layer=7):
310 | return self.encoder(x, layer)
311 |
--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import argparse
5 | import logging
6 | import torch
7 | import numpy as np
8 | from tensorboardX import SummaryWriter
9 |
10 |
11 | from config import config, update_config
12 | from network import *
13 | from datasets.pretraining import get_pretraining_dataset
14 | from loss import get_loss
15 | from utils import *
16 |
17 |
18 | torch.manual_seed(0)
19 | np.random.seed(0)
20 | torch.backends.cudnn.deterministic = True
21 |
22 | # TODO
23 | # Test Learning Rate Decay
24 | # Test end-of-epoch
25 | # Test Loading Checkpoints
26 | # Train end-to-end
27 |
28 | def parse_args():
29 | parser = argparse.ArgumentParser(description='Train TCE Self-Supervised')
30 | parser.add_argument('--cfg', help = 'Path to config file', type = str, default = None)
31 | parser.add_argument('opts', help = 'Modify config using the command line',
32 | default = None, nargs=argparse.REMAINDER )
33 | args = parser.parse_args()
34 | update_config(config, args)
35 |
36 | return args
37 |
38 | def train(epoch, train_loader, model, NCELoss, RotationLoss, optimizer, config, tboard, logger):
39 |
40 | # ===== Get total steps, set up meters =====
41 | total_steps = config.TRAIN.PRETRAINING.EPOCHS * len(train_loader)
42 | model.train()
43 | NCELoss.train()
44 |
45 | batch_time = AverageMeter()
46 |
47 | MainLossMeter = AverageMeter()
48 | RotationLossMeter = AverageMeter()
49 | TotalLossMeter = AverageMeter()
50 | RotationAccMeter = AverageMeter()
51 |
52 | tl = config.LOSS.PRETRAINING.MINING.THRESH_LOW
53 | th = config.LOSS.PRETRAINING.MINING.THRESH_HIGH
54 | tr = config.LOSS.PRETRAINING.MINING.THRESH_RATE
55 |
56 | for idx, inputs in enumerate(train_loader):
57 | end = time.time()
58 | # ===== Prepare data and get threshold ======
59 | anchor_tensor = inputs['anchor_tensor'].cuda()
60 | pair_tensor = inputs['pair_tensor'].cuda()
61 | rotation_tensor = inputs['rotation_tensor'].cuda()
62 | rotation_gt = inputs['rotation_gt'].cuda()
63 | inputs['negatives'] = inputs['negatives'].cuda()
64 | inputs['membank_idx'] = inputs['membank_idx'].cuda()
65 |
66 | step = (epoch - 1) * len(train_loader) + idx + 1
67 | threshold = tl + (th - tl) * (1 - math.exp(tr * step / total_steps))
68 |
69 | # ===== Forward =====
70 | bsz = anchor_tensor.size(0)
71 | inputs['anchor_feature'] = model(anchor_tensor, rotation = False)
72 | inputs['pair_feature'] = model(pair_tensor, rotation = False)
73 | rotation_feature = model(rotation_tensor, rotation = True)
74 |
75 | main_loss = NCELoss(inputs, threshold)
76 | with torch.no_grad():
77 | rotation_indexes = rotation_feature.max(1)[1]
78 | rotation_accuracy = int(torch.sum(torch.eq(rotation_indexes, rotation_gt).long())) / bsz
79 | rotation_loss = RotationLoss(rotation_feature, rotation_gt).mul(config.LOSS.PRETRAINING.ROTATION_WEIGHT)
80 | total_loss = main_loss + rotation_loss
81 |
82 | # ===== Backward =====
83 | optimizer.zero_grad()
84 | total_loss.backward()
85 | optimizer.step()
86 |
87 | # ===== Update Meters =====
88 | batch_time.update(time.time() - end, bsz)
89 | MainLossMeter.update(main_loss.item(), bsz)
90 | RotationLossMeter.update(rotation_loss.item(), bsz)
91 | TotalLossMeter.update(total_loss.item(), bsz)
92 | RotationAccMeter.update(rotation_accuracy, bsz)
93 |
94 |         # ===== Print and update TensorBoard =====
95 | if idx % config.PRINT_FREQ == 0:
96 | logger.info('Train: [{0}][{1}/{2}]\t'
97 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
98 | 'Main Loss: {mainloss.val:.3f} ({mainloss.avg:.3f})\t'
99 | 'Rotation Loss {rotloss.val:.3f} ({rotloss.avg:.3f})\t'
100 | 'Total Loss {totloss.val:.3f} ({totloss.avg:.3f})\t'.format(
101 | epoch, idx + 1, len(train_loader), batch_time=batch_time,
102 | mainloss = MainLossMeter,
103 | rotloss = RotationLossMeter,
104 | totloss = TotalLossMeter,
105 | ))
106 | tboard.add_scalars("Main Loss",
107 | {"Absolute":MainLossMeter.val, "Average":MainLossMeter.avg}, step)
108 | tboard.add_scalars("Rotation Loss",
109 | {"Absolute":RotationLossMeter.val, "Average":RotationLossMeter.avg}, step)
110 | tboard.add_scalars("Total Loss",
111 | {"Absolute":TotalLossMeter.val, "Average":TotalLossMeter.avg}, step)
112 | tboard.add_scalars("Rotation Accuracy",
113 | {"Absolute":RotationAccMeter.val, "Average":RotationAccMeter.avg}, step)
114 |
115 | return_dic = {
116 | 'Main Loss' : {'Average' : MainLossMeter.avg},
117 | 'Rotation Loss' : {'Average' : RotationLossMeter.avg},
118 | 'Total Loss' : {'Average' : TotalLossMeter.avg},
119 | 'Rotation Accuracy' : {'Average' : RotationAccMeter.avg},
120 | }
121 |
122 | return return_dic
123 |
124 |
125 | def main():
126 | args = parse_args()
127 | logger = setup_logger()
128 | logger.info(config)
129 | if not os.path.exists(config.ASSETS_PATH):
130 | os.makedirs(config.ASSETS_PATH)
131 |
132 | # ===== Create the dataloader =====
133 | train_loader, n_data = get_pretraining_dataset(config)
134 | logger.info('Training with {} Train Samples'.format(n_data))
135 |
136 | # ===== Create the model =====
137 | model = PreTrainNet(config)
138 | logger.info('Built Model, using {} backbone'.format(config.MODEL.TRUNK))
139 | if torch.cuda.device_count() > 1:
140 | model = torch.nn.DataParallel(model).cuda()
141 | else:
142 | model = model.cuda()
143 | logger.info('Training on {} GPUs'.format(torch.cuda.device_count()))
144 |
145 | # ===== Set the optimizer =====
146 | optimizer = get_optimizer(model, config, pretraining = True )
147 |
148 | # ===== Get the loss =====
149 | NCELoss, RotationLoss = get_loss(config, n_data)
150 | NCELoss = NCELoss.cuda()
151 | RotationLoss = RotationLoss.cuda()
152 |
153 |     # ===== Resume from an earlier checkpoint =====
154 | start_epoch = 1
155 | if config.TRAIN.PRETRAINING.RESUME:
156 | try:
157 | checkpoint = torch.load(config.TRAIN.PRETRAINING.RESUME)
158 | except FileNotFoundError:
159 | raise FileNotFoundError('No Checkpoint found at path {}'.format(config.TRAIN.PRETRAINING.RESUME))
160 |
161 | start_epoch = checkpoint['epoch'] + 1
162 | # ===== Align checkpoint keys with model =====
163 | if 'module' in list(checkpoint['state_dict'].keys())[0] and 'module' not in list(model.state_dict().keys())[0]:
164 | checkpoint['state_dict'] = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items() }
165 | elif 'module' not in list(checkpoint['state_dict'].keys())[0] and 'module' in list(model.state_dict().keys())[0]:
166 | checkpoint['state_dict'] = {'module.' + k:v for k,v in checkpoint['state_dict'].items() }
167 |
168 | model.load_state_dict(checkpoint['state_dict'])
169 | optimizer.load_state_dict(checkpoint['optimizer'])
170 | NCELoss.load_state_dict(checkpoint['NCELoss'])
171 | logger.info('Loaded Checkpoint from "{}"'.format(config.TRAIN.PRETRAINING.RESUME))
172 | else:
173 | logger.info('Training from Random Initialisation')
174 |
175 | # ===== Set up Save Directory and TensorBoard =====
176 | assert config.TRAIN.PRETRAINING.SAVEDIR, 'Please specify save directory for model'
177 | if not os.path.exists(config.TRAIN.PRETRAINING.SAVEDIR):
178 | os.makedirs(config.TRAIN.PRETRAINING.SAVEDIR)
179 | os.makedirs(os.path.join(config.TRAIN.PRETRAINING.SAVEDIR, 'checkpoints'))
180 | os.makedirs(os.path.join(config.TRAIN.PRETRAINING.SAVEDIR, 'tboard'))
181 |
182 | tboard = SummaryWriter(logdir = os.path.join(config.TRAIN.PRETRAINING.SAVEDIR, 'tboard'))
183 |
184 | # ===== Train Loop =====
185 | logger.info('Begin training')
186 | for epoch in range(start_epoch, config.TRAIN.PRETRAINING.EPOCHS + 1):
187 | adjust_learning_rate(epoch, config, optimizer, pretraining = True, logger = logger)
188 | logger.info('Training Epoch {}'.format(epoch))
189 |
190 | return_dic = train(
191 | epoch = epoch,
192 | train_loader = train_loader,
193 | model = model,
194 | NCELoss = NCELoss,
195 | RotationLoss = RotationLoss,
196 | optimizer = optimizer,
197 | config = config,
198 | tboard = tboard,
199 | logger = logger,
200 | )
201 |
202 | for key, value in return_dic.items():
203 | tboard.add_scalars(key, value, epoch)
204 |
205 | if epoch % config.TRAIN.PRETRAINING.SAVE_FREQ == 0:
206 | state = {
207 | 'config': config,
208 | 'state_dict': model.state_dict(),
209 | 'optimizer': optimizer.state_dict(),
210 | 'NCELoss': NCELoss.state_dict(),
211 | 'epoch': epoch,
212 | }
213 | save_file = os.path.join(config.TRAIN.PRETRAINING.SAVEDIR, 'checkpoints', 'ckpt_epoch_{}.pth'.format(epoch))
214 | logger.info('Saved Checkpoint to {}'.format(save_file))
215 | torch.save(state, save_file)
216 |
217 | if __name__ == '__main__':
218 | main()
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
--------------------------------------------------------------------------------
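The mining threshold in `train()` above follows `threshold = tl + (th - tl) * (1 - exp(tr * step / total_steps))`, i.e. it ramps from `THRESH_LOW` towards `THRESH_HIGH` over training when `THRESH_RATE` is negative. A small sketch of the curve using hypothetical values; these are not the repo's defaults:

```python
import math

tl, th, tr = 0.4, 0.8, -5.0      # hypothetical THRESH_LOW / THRESH_HIGH / THRESH_RATE
total_steps = 1000

for step in (1, 250, 500, 1000):
    threshold = tl + (th - tl) * (1 - math.exp(tr * step / total_steps))
    print(step, round(threshold, 3))
# ramps from ~0.40 at the first step towards ~0.80 by the end of training
```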
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .optimizer import *
2 | from .log_utils import *
3 | from .tsne_utils import *
4 | from .network_utils import *
5 | from .dataset_utils import *
6 | from .eval_utils import *
--------------------------------------------------------------------------------
/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import os, pickle
2 |
3 | class UCF101_splitter():
4 | def __init__(self, root, split):
5 | self.root = root
6 | self.split = split
7 |
8 | def get_action_index(self):
9 | '''
10 | Create a dictionary relating integer labels (e.g. 1, 2, ... 101) to their respective action classes
11 | '''
12 | self.action_label = {}
13 |
14 | # ===== Open class index text file and retrieve lines =====
15 | with open(os.path.join(self.root,'classInd.txt')) as f:
16 | content = f.readlines()
17 | content = [x.strip('\r\n') for x in content]
18 | f.close()
19 |
20 | # ===== Process the text file to associate labels with classes =====
21 | for line in content:
22 | label,action = line.split(' ')
23 | if action not in self.action_label.keys():
24 | self.action_label[action]=label
25 |
26 | return None
27 |
28 |
29 | def split_video(self):
30 | '''
31 | Create dictionaries for the train and testing splits, relating the video name to the paths to the frames contained within that video
32 | '''
33 | # ===== Create dictionary relating labels to action classes =====
34 | self.get_action_index()
35 | train_video = None
36 | for path,subdir,files in os.walk(self.root):
37 | for filename in files:
38 | # ===== Create dictionary for train split ======
39 | if filename.split('.')[0] == 'trainlist{:02d}'.format(self.split):
40 | train_video = self.file2_dic(os.path.join(self.root,filename))
41 |
42 | # ===== Create dictionary for test split =====
43 | if filename.split('.')[0] == 'testlist{:02d}'.format(self.split):
44 | test_video = self.file2_dic(os.path.join(self.root,filename))
45 |
46 | # ===== Correct issue with handstandpushups class : Inconsistency between text file and folder name capitalisation =====
47 | self.train_video = self.name_HandstandPushups(train_video)
48 | self.test_video = self.name_HandstandPushups(test_video)
49 | return self.train_video, self.test_video
50 |
51 | def file2_dic(self,fname):
52 | '''
53 | Get a text file containing the list of videos, and return a dictionary relating the video name to the paths to the frames in the video
54 |
55 | Arguments:
56 | fname : Path to the text file containing a list of videos in the split
57 | '''
58 |
59 | # ===== Create a list of videonames from the file =====
60 | with open(fname) as f:
61 | content = f.readlines()
62 | content = [x.strip('\r\n') for x in content]
63 | f.close()
64 |
65 | # ===== Create a dictionary relating every video name to a list of paths to the frames within that video =====
66 | dic={}
67 | for line in content:
68 | video = line.split('/',1)[1].split(' ',1)[0]
69 | key = 'v_' + video.split('_',1)[1].split('.',1)[0]
70 | label = self.action_label[line.split('/')[0]]
71 | dic[key] = int(label)
72 | return dic
73 |
74 | def name_HandstandPushups(self,dic):
75 | '''
76 | Account for a discrepancy between the capitalisation of HandstandPushups in the split file and in the files in the dataset
77 | '''
78 | dic = {k.replace('HandStandPushups', 'HandstandPushups'):v for k,v in dic.items()}
79 | return dic
--------------------------------------------------------------------------------
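For reference, `file2_dic` above parses UCF101 split files whose lines have the form `<ClassName>/v_<ClassName>_gXX_cYY.avi [label]`. A worked sketch of the key extraction on one such line; the example line and the stand-in label mapping are illustrative, not read from disk:

```python
line = 'ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi 1'
action_label = {'ApplyEyeMakeup': '1'}   # stand-in for the classInd.txt mapping

video = line.split('/', 1)[1].split(' ', 1)[0]        # 'v_ApplyEyeMakeup_g08_c01.avi'
key = 'v_' + video.split('_', 1)[1].split('.', 1)[0]  # 'v_ApplyEyeMakeup_g08_c01'
label = int(action_label[line.split('/')[0]])         # 1
print(key, label)
```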
/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | def accuracy(output, target, topk=(1,)):
2 | """Computes the precision@k for the specified values of k"""
3 | maxk = max(topk)
4 | batch_size = target.size(0)
5 |
6 | _, pred = output.topk(maxk, 1, True, True)
7 | pred = pred.t()
8 | correct = pred.eq(target.view(1, -1).expand_as(pred))
9 |
10 | res = []
11 | for k in topk:
12 | correct_k = correct[:k].contiguous().view(-1).float().sum(0)
13 | res.append(correct_k.mul_(100.0 / batch_size))
14 | return res
--------------------------------------------------------------------------------
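A minimal usage sketch for `accuracy()` above, assuming the function is in scope (e.g. pasted into a session); the toy logits are arbitrary:

```python
import torch

# Logits for 3 samples over 4 classes; ground-truth classes are 2, 0, 1.
output = torch.tensor([[0.10, 0.20, 0.90, 0.00],
                       [0.80, 0.10, 0.00, 0.10],
                       [0.10, 0.35, 0.40, 0.15]])
target = torch.tensor([2, 0, 1])

top1, top2 = accuracy(output, target, topk=(1, 2))  # accuracy() as defined above
print(top1.item(), top2.item())                     # ~66.7 (2 of 3 at top-1), 100.0 at top-2
```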
/utils/log_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | from termcolor import colored
4 |
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value"""
7 | def __init__(self):
8 | self.val = 0
9 | self.avg = 0
10 | self.sum = 0
11 | self.count = 0
12 | self.reset()
13 |
14 | def reset(self):
15 | self.val = 0
16 | self.avg = 0
17 | self.sum = 0
18 | self.count = 0
19 |
20 | def update(self, val, n=1):
21 | self.val = val
22 | self.sum += val * n
23 | self.count += n
24 | self.avg = self.sum / self.count
25 |
26 | class _ColorfulFormatter(logging.Formatter):
27 | def __init__(self, *args, **kwargs):
28 | self._root_name = kwargs.pop("root_name") + "."
29 | self._abbrev_name = kwargs.pop("abbrev_name", "")
30 | if len(self._abbrev_name):
31 | self._abbrev_name = self._abbrev_name + "."
32 | super(_ColorfulFormatter, self).__init__(*args, **kwargs)
33 |
34 | def formatMessage(self, record):
35 | record.name = record.name.replace(self._root_name, self._abbrev_name)
36 | log = super(_ColorfulFormatter, self).formatMessage(record)
37 | if record.levelno == logging.WARNING:
38 | prefix = colored("WARNING", "red", attrs=["blink"])
39 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
40 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
41 | else:
42 | return log
43 | return prefix + " " + log
44 |
45 |
46 | def setup_logger(color=True, name="TCE", abbrev_name=None):
47 | logger = logging.getLogger(name)
48 | logger.setLevel(logging.DEBUG)
49 | logger.propagate = False
50 |
51 | if abbrev_name is None:
52 | abbrev_name = name
53 |
54 | plain_formatter = logging.Formatter(
55 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
56 | )
57 | # stdout logging: master only
58 | ch = logging.StreamHandler(stream=sys.stdout)
59 | ch.setLevel(logging.DEBUG)
60 | if color:
61 | formatter = _ColorfulFormatter(
62 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
63 | datefmt="%m/%d %H:%M:%S",
64 | root_name=name,
65 | abbrev_name=str(abbrev_name),
66 | )
67 | else:
68 | formatter = plain_formatter
69 | ch.setFormatter(formatter)
70 | logger.addHandler(ch)
71 |
72 | return logger
73 |
--------------------------------------------------------------------------------
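A quick sketch of how `AverageMeter` above tracks running statistics, assuming the class is in scope; the meters in `pretrain.py` and `finetune.py` are updated the same way:

```python
meter = AverageMeter()
meter.update(2.0, n=4)          # e.g. a batch of 4 samples with loss 2.0
meter.update(1.0, n=4)
print(meter.val, meter.avg)     # 1.0 (last value), 1.5 (weighted running average)
```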
/utils/network_utils.py:
--------------------------------------------------------------------------------
1 | from network.resnet import *
2 |
3 | def get_backbone(trunk):
4 | if trunk == 'resnet18':
5 | backbone = resnet18()
6 | backbone_channels = 512
7 | elif trunk == 'resnet34':
8 | backbone = resnet34()
9 | backbone_channels = 512
10 | elif trunk == 'resnet50':
11 | backbone = resnet50()
12 | backbone_channels = 2048
13 | elif trunk == 'resnet101':
14 |         backbone = resnet101()
15 | backbone_channels = 2048
16 | else:
17 | raise NotImplementedError('Backbone "{}" not currently supported'.format(trunk))
18 | return backbone, backbone_channels
19 |
20 | def forgiving_load_state_dict(state_dict, model, logger):
21 | keys_dict = state_dict.keys()
22 | keys_model = model.state_dict().keys()
23 |
24 | keys_shared = [k for k in keys_dict if k in keys_model]
25 | keys_missing = [k for k in keys_model if k not in keys_dict]
26 | keys_unexpected = [k for k in keys_dict if k not in keys_model]
27 |
28 | load_dict = {k:v for k,v in state_dict.items() if k in keys_shared}
29 | model.load_state_dict(load_dict, strict = False)
30 |
31 | logger.info('Missing Keys in checkpoint : ')
32 | for k in keys_missing:
33 | logger.info('\t\t{}'.format(k))
34 | logger.info('Unexpected Keys in checkpoint : ')
35 | for k in keys_unexpected:
36 | logger.info('\t\t{}'.format(k))
37 |
--------------------------------------------------------------------------------
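The partial loading above is what lets a pre-training checkpoint initialise only the backbone of `FineTuneNet`: keys present in both the checkpoint and the model are copied, and everything else (the contrastive/rotation heads on one side, the new classifier on the other) is reported and skipped. A toy illustration of the key intersection, using plain sets and omitting logging:

```python
checkpoint_keys = {'backbone.conv1.weight', 'fc_contrastive.weight'}   # from pre-training
model_keys = {'backbone.conv1.weight', 'fc.weight'}                    # fine-tuning model

shared = checkpoint_keys & model_keys          # loaded: {'backbone.conv1.weight'}
missing = model_keys - checkpoint_keys         # left at init: {'fc.weight'}
unexpected = checkpoint_keys - model_keys      # ignored: {'fc_contrastive.weight'}
print(shared, missing, unexpected)
```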
/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def get_optimizer(model, config, pretraining):
5 | if pretraining == True:
6 | lr = config.TRAIN.PRETRAINING.LEARNING_RATE
7 | momentum = config.TRAIN.PRETRAINING.MOMENTUM
8 | weight_decay = config.TRAIN.PRETRAINING.WEIGHT_DECAY
9 | elif pretraining == False:
10 | lr = config.TRAIN.FINETUNING.LEARNING_RATE
11 | momentum = config.TRAIN.FINETUNING.MOMENTUM
12 | weight_decay = config.TRAIN.FINETUNING.WEIGHT_DECAY
13 |
14 | optimizer = torch.optim.SGD(
15 | model.parameters(),
16 | lr = lr,
17 | momentum = momentum,
18 | weight_decay = weight_decay
19 | )
20 |
21 | return optimizer
22 |
23 | def adjust_learning_rate(epoch, cfg, optimizer, pretraining, logger):
24 | """Sets the learning rate to the initial LR decayed by the given rate every steep step"""
25 | if pretraining == True:
26 | decay_epochs = cfg.TRAIN.PRETRAINING.DECAY_EPOCHS
27 | lr = cfg.TRAIN.PRETRAINING.LEARNING_RATE
28 | decay_factor = cfg.TRAIN.PRETRAINING.DECAY_FACTOR
29 | elif pretraining == False:
30 | decay_epochs = cfg.TRAIN.FINETUNING.DECAY_EPOCHS
31 | lr = cfg.TRAIN.FINETUNING.LEARNING_RATE
32 | decay_factor = cfg.TRAIN.FINETUNING.DECAY_FACTOR
33 |
34 |
35 | steps = np.sum(epoch > np.asarray(decay_epochs))
36 | if steps > 0:
37 | new_lr = lr * (decay_factor ** steps)
38 | logger.info('Learning rate for epoch set to {}'.format(new_lr))
39 | for param_group in optimizer.param_groups:
40 | param_group['lr'] = new_lr
41 | else:
42 | logger.info('Learning rate for epoch set to {}'.format(lr))
43 |
--------------------------------------------------------------------------------
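A worked example of the step decay in `adjust_learning_rate` above, using hypothetical config values rather than the repo's defaults: with `LEARNING_RATE = 0.01`, `DECAY_EPOCHS = [60, 80]` and `DECAY_FACTOR = 0.1`, the learning rate stays at 0.01 up to epoch 60, drops to 0.001 for epochs 61-80, and to 1e-4 afterwards.

```python
import numpy as np

lr, decay_epochs, decay_factor = 0.01, [60, 80], 0.1   # hypothetical values

for epoch in (10, 61, 90):
    steps = np.sum(epoch > np.asarray(decay_epochs))
    print(epoch, lr * (decay_factor ** steps))
# -> 0.01, 0.001, 1e-4
```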
/utils/tsne_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import numpy as np
5 | from tqdm import tqdm
6 | from glob import glob
7 |
8 | from sklearn.manifold import TSNE
9 | from sklearn.decomposition import PCA
10 |
11 | import PIL.Image as Image
12 | from torchvision import transforms
13 | import matplotlib
14 | matplotlib.use('Agg')
15 | import matplotlib.pyplot as plt
16 |
17 | def create_gif_frames(reduced_embeddings, target, logger):
18 |
19 | # ===== Set up axis and points =====
20 | fig = plt.figure()
21 | ax = fig.add_subplot(1,1,1)
22 |
23 | ax.axes.xaxis.set_visible(False)
24 | ax.axes.yaxis.set_visible(False)
25 |
26 | embeddings_y = reduced_embeddings[:,0]
27 | embeddings_x = reduced_embeddings[:,1]
28 |
29 | # ===== Set up frame loader =====
30 | is_folder = os.path.isdir(target)
31 | if is_folder:
32 | imlist = sorted(glob(os.path.join(target, '*')))
33 | n_images = len(imlist)
34 | else:
35 | reader = cv2.VideoCapture(target)
36 | n_images = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
37 |         assert reader.isOpened(), 'Target was neither a directory nor a valid video file'
38 | # ===== Iterate over frames =====
39 | frames = []
40 | logger.info('Beginning frame generation')
41 |
42 | for idx in tqdm(range(n_images)):
43 | ax.cla()
44 | # ===== Get frame =====
45 | if is_folder:
46 | frame = cv2.imread(imlist[idx])
47 | else:
48 | tf, frame = reader.read()
49 | if not tf:
50 |                 logger.info('Reached end of video at frame {} rather than {}'.format(idx, n_images))
51 | break
52 | h, w = frame.shape[:2]
53 | frame = frame[:,:,::-1]
54 | height, width = frame.shape[:2]
55 | # ===== Plot on graph, highlighting the current frame's embedding =====
56 | x = list(embeddings_x[:idx]) + list(embeddings_x[idx+1:])
57 | y = list(embeddings_y[:idx]) + list(embeddings_y[idx+1:])
58 | ax.plot(x,y,',', color='b')
59 | ax.plot(embeddings_x[idx], embeddings_y[idx], 'o', color='r')
60 | fig.canvas.draw()
61 |
62 |         tsne = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (3,))
63 | tsne = cv2.resize(tsne, (width,height))
64 | frames.append(np.concatenate((frame, tsne), axis = 1))
65 |
66 | return frames
67 |
68 | class get_blob:
69 | def __init__(self, cfg):
70 | self.resize = transforms.Resize(cfg.VISUALISATION.TSNE.RESIZE)
71 | self.crop = transforms.CenterCrop(cfg.VISUALISATION.TSNE.CROP_SIZE)
72 | self.totensor = transforms.ToTensor()
73 | self.normalize = transforms.Normalize(
74 | mean = cfg.DATASET.PRETRAINING.MEAN,
75 | std = cfg.DATASET.PRETRAINING.STD
76 | )
77 |
78 | def __call__(self, frame):
79 | frame = Image.fromarray(frame)
80 | frame = self.resize(frame)
81 | frame = self.crop(frame)
82 | frame = self.totensor(frame)
83 | frame = self.normalize(frame)
84 | frame = frame.unsqueeze(0)
85 | frame = frame.cuda()
86 |
87 | return frame
88 |
89 | def fit_pca(input_data, logger, num_dims=0, threshold=0.85, max_error=0.05, debug=False):
90 |
91 | optimised = False
92 |     # ===== Fit PCA to num_dims output dimensions if num_dims is not 0 =====
93 | if num_dims > 0:
94 | pca = PCA(n_components=num_dims)
95 | # ===== Fit PCA =====
96 | reduced = pca.fit_transform(input_data)
97 | explained_variance = np.sum(pca.explained_variance_ratio_)
98 | logger.info("PCA transform fitted. Explained variance in {} dims is {}.".format(num_dims, explained_variance))
99 | optimised = True
100 |
101 | """
102 | Get data input shape, set variables for the PCA loop
103 | lower, upper is the lower and upper bounds we are trying to converge for PCA dimensionality to achieve
104 | the desired explained_variance, as a fraction of the original input dimensionality
105 | previous is the previous attempt to fit the PCA's dimensionality
106 | """
107 | dims = min(input_data.shape[0], input_data.shape[1])
108 | lower, upper = (0., 1.)
109 | previous = -1
110 |
111 | while not optimised:
112 | # ===== Iterate the PCA until desired explained_variance achieved =====
113 |
114 | # ===== Fit a PCA to dimensionality halfway between lower and upper bounds already found =====
115 | num_dims = int(dims * (0.5*(upper - lower) + lower))
116 | if num_dims == previous:
117 | # ===== Settle on PCA dimensionality if this iteration has the same number of dimensions as the last =====
118 | num_dims = int(upper * dims)
119 | optimised = True
120 |
121 | # ===== Fit PCA to num_dims for this iteration =====
122 | t1 = time.time()
123 | logger.info('Fitting PCA')
124 | pca = PCA(n_components=num_dims)
125 | reduced = pca.fit_transform(input_data)
126 | logger.info('Time Taken = {} seconds'.format(time.time() - t1))
127 | explained_variance = np.sum(pca.explained_variance_ratio_)
128 | previous = num_dims
129 |
130 | if debug:
131 | logger.info('Lower&Upper: ({}, {})\tNumber of dimensions: {}\tExplained Variance: {}'
132 | .format(lower, upper, num_dims, explained_variance))
133 |
134 | if explained_variance < threshold:
135 |             # ===== Raise the lower bound if explained variance is not yet high enough =====
136 | lower = num_dims / dims
137 | else:
138 |             # ===== Lower the upper bound if explained variance is already high enough =====
139 | upper = num_dims / dims
140 |
141 | if upper - lower < max_error:
142 | # ===== Settle on PCA dimensionality if upper and lower bounds within acceptable range of each other =====
143 | optimised = True
144 |
145 | return reduced
146 |
147 | def fit_tsne(input_data, logger, pca=True, pca_threshold=0.85, pca_error=0.05, pca_num_dims=0, num_dims=2, num_iterations=500, debug=False):
148 | '''
149 | Performs TSNE on the input data to reduce it to num_dims dimensions.
150 | Will first perform PCA by default to reduce the number of dimensions and make
151 | fitting tsne faster
152 |
153 | Arguments:
154 |         input_data : A numpy array containing the data for the PCA, where:
155 |                      rows are samples and columns are feature dimensions, i.e. shape (n_samples, n_features)
156 |         pca : If True, perform some dimensionality reduction with a PCA before the TSNE to reduce computation time
157 | pca_threshold : Explained variance threshold for PCA
158 | pca_error : Acceptable distance between lower and upper bounds for the pca to be considered converged, as a value between
159 | 0 and 1
160 | pca_num_dims : If set to a non-zero value, PCA will reduce input to num_dims dimensions
161 | num_dims : Number of dimensions to reduce to using TSNE
162 | num_iterations : Number of iterations to run the TSNE for
163 | debug : If True, print debug information
164 | '''
165 | # ===== Reduce data with PCA first if pca=True =====
166 | if pca:
167 | input_data = fit_pca(input_data, logger, num_dims=pca_num_dims, threshold=pca_threshold, max_error=pca_error, debug=debug)
168 | t1 = time.time()
169 | logger.info('Fitting TSNE: ')
170 | tsne=TSNE(n_iter=num_iterations, n_components=num_dims)
171 |
172 | reduced = tsne.fit_transform(input_data)
173 | logger.info('Time Taken = {} seconds'.format(time.time() - t1))
174 |
175 | return reduced
--------------------------------------------------------------------------------
/visualise_tsne.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import imageio
5 | import argparse
6 | import numpy as np
7 | from glob import glob
8 | from tqdm import tqdm
9 | from glob import glob
10 |
11 | import matplotlib
12 | matplotlib.use('Agg')
13 | import matplotlib.pyplot as plt
14 |
15 | from network import PreTrainNet
16 | from config import config, update_config
17 | from utils import *
18 |
19 |
20 | def parse_args():
21 |     parser = argparse.ArgumentParser(description='Visualise TCE embeddings with t-SNE')
22 | parser.add_argument('--cfg', help = 'Path to config file', type = str, default = None)
23 | parser.add_argument('--target', help = 'Path to visualisation target. Can be a video or folder of images', default = None)
24 | parser.add_argument('--ckpt', help = 'Checkpoint to visualise', type=str, required = True)
25 | parser.add_argument('--gif', action = 'store_true', default = False, help = 'Save output as a gif with the corresponding video alongside')
26 | parser.add_argument('--fps', type = float, help = 'Frames per second for gif', default = 30)
27 | parser.add_argument('--save', help = 'Save path', default = None)
28 | parser.add_argument('opts', help = 'Modify config using the command line',
29 | default = None, nargs=argparse.REMAINDER )
30 | args = parser.parse_args()
31 | update_config(config, args)
32 |
33 | return args
34 |
35 |
36 |
37 | if __name__ == '__main__':
38 | args = parse_args()
39 | logger = setup_logger()
40 |
41 | # ===== Get model =====
42 | model = PreTrainNet(config).eval()
43 | model = torch.nn.DataParallel(model).cuda()
44 |
45 | # ===== Load Model Checkpoint =====
46 | checkpoint = torch.load(args.ckpt)
47 | model.load_state_dict(checkpoint['state_dict'])
48 | logger.info('Loaded Checkpoint from {}'.format(args.ckpt))
49 |
50 |
51 | # ===== Get input transformation class =====
52 | create_blob = get_blob(config)
53 |
54 | # ===== Create either list of frames or videoreader object =====
55 | is_folder = os.path.isdir(args.target)
56 | if is_folder:
57 | imlist = sorted(glob(os.path.join(args.target, '*')))
58 | n_images = len(imlist)
59 | else:
60 | reader = cv2.VideoCapture(args.target)
61 | n_images = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
62 |         assert reader.isOpened(), 'Target was neither a directory nor a valid video file'
63 |
64 | # ===== Process video frames with network =====
65 | embeddings = []
66 | logger.info('Processing {} Frames'.format(n_images))
67 | for idx in tqdm(range(n_images)):
68 | if is_folder:
69 | frame = cv2.imread(imlist[idx])
70 | else:
71 | tf, frame = reader.read()
72 | if not tf:
73 |             logger.info('Reached end of video at frame {} rather than {}'.format(idx, n_images))
74 | break
75 |
76 | # ===== Get image dimensions, convert from BGR to RGB, create input blob =====
77 | height, width = np.shape(frame)[:2]
78 | frame = frame[:, :, ::-1]
79 | blob = create_blob(frame)
80 |
81 | # ===== Get embedding from network =====
82 | with torch.no_grad():
83 | embeddings.append(model(blob).cpu().numpy())
84 |
85 | if not is_folder:
86 | reader.release()
87 |
88 | # ===== Use TSNE to reduce embeddings to 2D for plotting =====
89 | embeddings = np.array(embeddings)[:,0,:]
90 | reduced_embeddings = fit_tsne(embeddings, logger)
91 |
92 |
93 | # ===== Create Plot =====
94 | if args.gif:
95 | # ===== Create gif if flag set =====
96 | frames = create_gif_frames(reduced_embeddings, args.target, logger)
97 | imageio.mimsave(args.save, frames, duration = 1 / args.fps)
98 | else:
99 | # ===== Create TSNE single image =====
100 | fig, ax = plt.subplots()
101 | ax.axes.xaxis.set_visible(False)
102 | ax.axes.yaxis.set_visible(False)
103 |
104 | colormap = plt.cm.plasma(np.linspace(0,1, len(reduced_embeddings)))
105 | for idx, point in enumerate(reduced_embeddings):
106 | ax.plot(point[0], point[1], '.', color = colormap[idx])
107 |
108 | # ===== Write plot to image file =====
109 | fig.canvas.draw()
110 |         tsne = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (3,))
111 | Image.fromarray(tsne).save(args.save)
112 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------