├── .idea ├── TBE.iml ├── deployment.xml ├── encodings.xml ├── misc.xml ├── modules.xml ├── remote-mappings.xml ├── vcs.xml └── workspace.xml ├── dataset.md ├── figures ├── PPL.jpg ├── be_visualization.png ├── example1.gif ├── example2.gif ├── mix_videos.png ├── motivation.png └── motivation_example.png ├── readme.md └── src └── Contrastive ├── __init__.py ├── augment ├── basic_augmentation │ ├── eda.py │ ├── mixup_methods.py │ ├── net_mixup.py │ ├── noise.py │ ├── rotation.py │ ├── temporal_augment.py │ ├── temporal_dropout.py │ ├── temporal_shuffle.py │ ├── temporal_sub.py │ ├── triplet.py │ └── video_color_jitter.py ├── config.py ├── gen_negative.py ├── gen_positive.py └── video_transformations │ ├── functional.py │ ├── video_transform_PIL_or_np.py │ ├── videotransforms.py │ └── volume_transforms.py ├── data ├── __init__.py ├── base.py ├── config.py ├── dataloader.py ├── dataset.py ├── decode_on_the_fly.py ├── on_the_fly_test.py └── video_dataset.py ├── feature_extractor.py ├── ft.py ├── loss ├── NCE │ ├── Link.py │ ├── NCEAverage.py │ ├── NCECriterion.py │ ├── __init__.py │ └── alias_multinomial.py ├── __init__.py ├── config.py └── tcr.py ├── main.py ├── model ├── __init__.py ├── c3d.py ├── config.py ├── i3d.py ├── model.py ├── mutual_net.py ├── r2p1d.py ├── r3d.py ├── s3d.py └── s3d_g.py ├── option.py ├── pt.py ├── reterival.py ├── scripts ├── Diving48 │ ├── ft.sh │ ├── pt_and_ft.sh │ ├── pt_and_ft_moco.sh │ └── pt_and_ft_moco_ucf_to_diving48.sh ├── evaluation │ ├── hmdb51_i3d_eval.sh │ └── ucf101_i3d_eval.sh ├── feature_extract │ ├── hmdb51_extract.sh │ └── ucf101_extract.sh ├── hmdb51 │ ├── ft.sh │ ├── pt_and_ft.sh │ └── pt_and_ft_hmdb51.sh ├── kinetics │ └── pt_and_ft.sh ├── something-something-v1 │ ├── i3d_ft.sh │ ├── i3d_pt_and_ft.sh │ ├── i3d_pt_and_ft_multi_gpus.sh │ ├── r3d_ft.sh │ ├── r3d_ft_multi_gpus.sh │ ├── r3d_pt_and_ft.sh │ └── r3d_pt_and_ft_multi_gpus.sh ├── ucf101 │ └── pt_and_ft.sh └── visualization │ ├── ucf101_data_augmentation_visualize.sh │ └── ucf101_triplet_visualize.sh ├── test.py └── utils ├── data_process ├── gen_diving48_frames.py ├── gen_diving48_lists.py ├── gen_hmdb51_dir.py ├── gen_hmdb51_frames.py ├── gen_hmdb51_sta_list.py ├── gen_sub_dataset.py ├── semi_data_split.py └── unrar_hmdb51_sta.sh ├── gradient_check.py ├── learning_rate_adjust.py ├── load_weights.py ├── moment_update.py ├── recoder.py ├── utils.py └── visualization ├── augmentation_visualization.py ├── color_palettes.py ├── confusion_matrix.py ├── mixup_visualization.py ├── multi100.txt ├── pearson_calculate.txt ├── pearson_correlation.py ├── plot_class_distribution.py ├── plot_flow_v_rgb_distribution.py ├── single_video_visualization.py ├── t_SNE_Visualization.py ├── triplet_visualization.py └── video_write_test.py /.idea/TBE.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 14 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 
4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/remote-mappings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /dataset.md: -------------------------------------------------------------------------------- 1 | # Dataset Prepare 2 | For Kinetics, we decode videos on the fly; each row in the txt file includes: 3 | > video_path class 4 | 5 | We load the videos directly, so please place the training set on an SSD for fast IO. 6 | 7 | For the other datasets (UCF101/Diving48/Sth-v1/HMDB51/Actor-HMDB51), each row of the txt file is as below: 8 | 9 | > video_path class frames_num 10 | 11 | These datasets are stored as extracted frames. We provide lists for all of these datasets in [Google Drive](https://drive.google.com/drive/folders/1ndq0rdxEvubBrbXny8RuGCTETXU2hr1N?usp=sharing). 12 | 13 | ## Kinetics 14 | As some YouTube links are no longer available, we use a copy of Kinetics-400 from [Non-local](https://github.com/facebookresearch/video-nonlocal-net); the training set now contains 234643 videos and the val set 19761 videos. 15 | All the videos are in mpeg/avi format. 16 | 17 | ## UCF101/HMDB51 18 | These two video datasets contain three train/test splits. 19 | Download UCF101 from [crcv](https://www.crcv.ucf.edu/data/UCF101.php) and HMDB51 from [serre-lab](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/). 20 | 21 | ## Diving48 22 | (1). Download Diving48 from [ucsd](http://www.svcl.ucsd.edu/projects/resound/dataset.html) 23 | 24 | (2). Generate frames using the script __src/Contrastive/utils/data_process/gen_diving48_frames.py__ 25 | 26 | (3). Generate lists using the script __src/Contrastive/utils/data_process/gen_diving48_lists.py__ 27 | 28 | ## Sth-v1 29 | 30 | ## HMDB51-STA/Actor-HMDB51 31 | HMDB51-STA is built by removing clips with large camera motion from [HMDB51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/). 32 | We provide our generated Actor-HMDB51 in [Google Drive](). 33 | 34 | ### Semi-supervised Subdataset 35 | We also provide a script to generate the semi-supervised subset of Kinetics-400. Please refer to Contrastive/utils/data_process/semi_data_split.py for details.
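Below is a minimal sketch (not part of the repo; the function name and example path are illustrative) of how a list file can be parsed, assuming the row orders documented above: `video_path class frames_num` for the frame datasets and `video_path class` for Kinetics.

```python
# Illustrative parser for the list formats documented above.
from collections import namedtuple

Record = namedtuple("Record", ["path", "label", "num_frames"])

def load_list(list_file):
    records = []
    with open(list_file) as f:
        for line in f:
            parts = line.strip().split()
            if not parts:
                continue
            if len(parts) == 3:
                # frame datasets: video_path class frames_num
                path, label, num_frames = parts
                records.append(Record(path, int(label), int(num_frames)))
            else:
                # Kinetics (decoded on the fly): video_path class
                path, label = parts[0], parts[1]
                records.append(Record(path, int(label), -1))  # -1: frame count unknown
    return records

# e.g. records = load_list("../datasets/lists/ucf101/train.txt")  # hypothetical path
```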
36 | -------------------------------------------------------------------------------- /figures/PPL.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/PPL.jpg -------------------------------------------------------------------------------- /figures/be_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/be_visualization.png -------------------------------------------------------------------------------- /figures/example1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/example1.gif -------------------------------------------------------------------------------- /figures/example2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/example2.gif -------------------------------------------------------------------------------- /figures/mix_videos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/mix_videos.png -------------------------------------------------------------------------------- /figures/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/motivation.png -------------------------------------------------------------------------------- /figures/motivation_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/motivation_example.png -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ![](https://img.shields.io/badge/-self--supervised--learning-brightgreen) 2 | ![](https://img.shields.io/badge/-action%20recognition-yellowgreen) 3 | ![](https://img.shields.io/badge/-implicit%20bias-lightgrey) 4 | ![](https://img.shields.io/badge/-pytorch-red) 5 | 6 | # TBE 7 | 8 | The source code for our paper "Removing the Background by Adding the Background: Towards Background Robust Self-supervised Video Representation Learning" [[arxiv](https://arxiv.org/abs/2009.05769)] 9 | [[code](https://github.com/FingerRec/BE)][[Project Website](https://fingerrec.github.io/index_files/jinpeng/papers/CVPR2021/project_website.html)] 10 | 11 |
12 | image 13 |
14 | 15 | 16 | ## Citation 17 | 18 | ```bibtex 19 | @inproceedings{wang2021removing, 20 | title={Removing the Background by Adding the Background: Towards Background Robust Self-supervised Video Representation Learning}, 21 | author={Wang, Jinpeng and Gao, Yuting and Li, Ke and Lin, Yiqi and Ma, Andy J and Cheng, Hao and Peng, Pai and Ji, Rongrong and Sun, Xing}, 22 | booktitle={CVPR}, 23 | year={2021} 24 | } 25 | ``` 26 | 27 | ## News 28 | [2020.3.7] The first version of TBE is released! 29 | 30 | ## 0. Motivation 31 | 32 | - In the camera-fixed situation, the static background in most frames remains similar in pixel distribution. 33 | 34 | ![](figures/motivation.png) 35 | 36 | 37 | - We ask the model to be **temporally sensitive** rather than **statically sensitive**. 38 | 39 | - We ask the model to filter out the additive **Background Noise**, which means erasing the background in each frame of the video. 40 | 41 | 42 | ### Activation Map Visualization of BE 43 | 44 | ![](figures/motivation_example.png) 45 | 46 | ### GIF 47 | ![](figures/example1.gif) 48 | 49 | #### A harder example 50 | ![](figures/example2.gif) 51 | 52 | ## 1. Plug BE into any self-supervised learning method in two steps 53 | 54 | The implementation of BE is very simple; it can be written in two lines of Python: 55 | ```python 56 | rand_index = random.randint(0, t - 1) 57 | mixed_x[j] = (1 - prob) * x[j] + prob * x[rand_index] 58 | ``` 59 | 60 | Then define a loss function such as MSE: 61 | 62 | ```python 63 | loss = MSE(F(mixed_x), F(x)) 64 | ``` 65 | 66 | 67 | ## 2. Installation 68 | 69 | ### Dataset Prepare 70 | Please refer to [dataset.md](dataset.md) for details. 71 | 72 | 73 | ### Requirements 74 | - Python3 75 | - pytorch1.1+ 76 | - PIL 77 | - Intel (on the fly decode) 78 | - Skvideo.io 79 | - Matplotlib (gradient_check) 80 | 81 | **As the Kinetics dataset is IO-intensive, we decode the avi/mpeg videos on the fly. Please refer to data/video_dataset.py for details.** 82 | 83 | ## 3. Structure 84 | - datasets 85 | - list 86 | - hmdb51: the train/val lists of HMDB51/Actor-HMDB51 87 | - hmdb51_sta: the train/val lists of HMDB51_STA 88 | - ucf101: the train/val lists of UCF101 89 | - kinetics-400: the train/val lists of kinetics-400 90 | - diving48: the train/val lists of diving48 91 | - experiments 92 | - logs: detailed experiment records, including logs and trained models 93 | - gradientes: 94 | - visualization: 95 | - pretrained_model: 96 | - src 97 | - Contrastive 98 | - data: data loading 99 | - loss: the losses evaluated in this paper 100 | - model: network architectures 101 | - scripts: train/eval scripts 102 | - augmentation: detailed implementation of the BE augmentation 103 | - utils 104 | - feature_extract.py: feature extractor given a pretrained model 105 | - main.py: the main function of pretrain / finetune 106 | - trainer.py 107 | - option.py 108 | - pt.py: BE pretrain 109 | - ft.py: BE finetune 110 | - Pretext 111 | - main.py: the main function of pretrain / finetune 112 | - loss: the losses, including the classification loss 113 | ## 4. Run 114 | ### (1). Download dataset lists and pretrained model 115 | A copy of both dataset lists is provided in [anonymous](). 116 | The Kinetics-pretrained models are provided in [anonymous](). 117 | 118 | ```bash 119 | cd .. && mkdir datasets 120 | mv [path_to_lists] datasets 121 | mkdir experiments && cd experiments 122 | mkdir pretrained_models logs 123 | mv [path_to_pretrained_model] ../experiments/pretrained_model 124 | ``` 125 | 126 | Download and extract frames of Actor-HMDB51.
127 | ```bash 128 | wget -c anonymous 129 | unzip 130 | python utils/data_process/gen_hmdb51_dir.py 131 | python utils/data_process/gen_hmdb51_frames.py 132 | ``` 133 | ### (2). Network Architecture 134 | The network architectures are defined in the folder **src/model/[].py** 135 | 136 | | Method | #logits_channel | 137 | | ---- | ---- | 138 | | C3D | 512 | 139 | | R2P1D | 2048 | 140 | | I3D | 1024 | 141 | | R3D | 2048 | 142 | 143 | All the logits channels are fed into a fc layer with a 128-D output. 144 | 145 | 146 | 147 | > For simplicity, we divide the source into Contrastive and Pretext; "--method pt_and_ft" means pretrain and finetune in one run. 148 | 149 | ### Action Recognition 150 | #### Random Initialization 151 | For the random initialization baseline, just comment out --ft_weights in line 11 of ft.sh, 152 | as below: 153 | 154 | ```bash 155 | #!/usr/bin/env bash 156 | python main.py \ 157 | --method ft --arch i3d \ 158 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \ 159 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \ 160 | --ft_root /data1/DataSet/Diving48/rgb_frames/ \ 161 | --ft_dataset diving48 --ft_mode rgb \ 162 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 163 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \ 164 | --ft_print-freq 100 --ft_fixed 0 # \ 165 | # --ft_weights ../experiments/kinetics_contrastive.pth 166 | ``` 167 | 168 | 169 | #### BE (Contrastive) 170 | ##### Kinetics 171 | ```bash 172 | bash scripts/kinetics/pt_and_ft.sh 173 | ``` 174 | ##### UCF101 175 | ```bash 176 | bash scripts/ucf101/pt_and_ft.sh 177 | ``` 178 | ##### Diving48 179 | ```bash 180 | bash scripts/Diving48/pt_and_ft.sh 181 | ``` 182 | 183 | 184 | **For Triplet loss optimization and the MoCo baseline, just modify --pt_method** 185 | 186 | #### BE (Triplet) 187 | 188 | ```bash 189 | --pt_method be_triplet 190 | ``` 191 | 192 | #### BE (Pretext) 193 | ```bash 194 | bash scripts/hmdb51/i3d_pt_and_ft_flip_cls.sh 195 | ``` 196 | 197 | or 198 | 199 | ```bash 200 | bash scripts/hmdb51/c3d_pt_and_ft_flip.sh 201 | ``` 202 | 203 | **Notice: more training options and ablation studies can be found in the scripts.** 204 | 205 | 206 | ### Video Retrieval and Other Visualization 207 | 208 | #### (1). Feature Extractor 209 | As STCR can be easily extended to other video representation tasks, we offer a script to perform feature extraction. 210 | ```bash 211 | python feature_extractor.py 212 | ``` 213 | 214 | The features will be saved as a single numpy file with shape [video_nums, features_dim] for further visualization. 215 | 216 | #### (2). Retrieval Evaluation 217 | Modify lines 60-62 in reterival.py. 218 | ```bash 219 | python reterival.py 220 | ``` 221 | 222 | ## Results 223 | ### Action Recognition 224 | #### Kinetics Pretrained (I3D) 225 | | Method | UCF101 | HMDB51 | Diving48 | 226 | | ---- | ---- | ---- | ---- | 227 | | Random Initialization | 57.9 | 29.6 | 17.4 | 228 | | MoCo Baseline | 70.4 | 36.3 | 47.9 | 229 | | BE | 86.5 | 56.2 | 62.6 | 230 | 231 | 232 | 233 | ### Video Retrieval (HMDB51-C3D) 234 | | Method | @1 | @5 | @10 | @20 | @50 | 235 | | ---- | ---- | ---- | ---- | ---- | ---- | 236 | | BE | 10.2 | 27.6 | 40.5 | 56.2 | 76.6 | 237 | 238 | ## More Visualization 239 | ### T-SNE 240 | Please refer to __utils/visualization/t_SNE_Visualization.py__ for details. 241 | 242 | ### Confusion_Matrix 243 | Please refer to __utils/visualization/confusion_matrix.py__ for details.
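As a companion to the retrieval evaluation above, here is a minimal sketch (not the repo's reterival.py; the file names are hypothetical) of computing top-k recall from features saved by feature_extractor.py in the [video_nums, features_dim] format:

```python
# Hedged sketch: top-k retrieval recall from saved feature files.
import numpy as np

def topk_recall(train_feat, train_labels, test_feat, test_labels, ks=(1, 5, 10, 20, 50)):
    # L2-normalise so that the dot product equals cosine similarity
    train_feat = train_feat / np.linalg.norm(train_feat, axis=1, keepdims=True)
    test_feat = test_feat / np.linalg.norm(test_feat, axis=1, keepdims=True)
    sim = test_feat @ train_feat.T                   # [num_test, num_train]
    ranked = np.argsort(-sim, axis=1)                # nearest neighbours first
    hits = train_labels[ranked] == test_labels[:, None]
    return {k: float(hits[:, :k].any(axis=1).mean()) for k in ks}

# recalls = topk_recall(np.load("train_feat.npy"), np.load("train_label.npy"),
#                       np.load("test_feat.npy"), np.load("test_label.npy"))
```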
244 | 245 | 246 | 247 | ## Acknowledgement 248 | This work is partly based on [UEL](https://github.com/mangye16/Unsupervised_Embedding_Learning) and [MoCo](https://github.com/facebookresearch/moco). 249 | 250 | ## License 251 | The code are released under the CC-BY-NC 4.0 LICENSE. -------------------------------------------------------------------------------- /src/Contrastive/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/__init__.py -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/eda.py: -------------------------------------------------------------------------------- 1 | # Easy data augmentation techniques for video 2 | # Jason Wei and Kai Zou 3 | 4 | import torch.nn as nn 5 | import random 6 | from random import shuffle 7 | #random.seed(1) 8 | # cleaning up text 9 | import re 10 | 11 | 12 | 13 | ######################################################################## 14 | # Synonym replacement 15 | # Replace n words in the sentence with synonyms from wordnet 16 | ######################################################################## 17 | 18 | 19 | def synonym_replacement(words, n): 20 | new_words = words.copy() 21 | random_word_list = list(set([word for word in words if word not in stop_words])) 22 | random.shuffle(random_word_list) 23 | num_replaced = 0 24 | for random_word in random_word_list: 25 | synonyms = get_synonyms(random_word) 26 | if len(synonyms) >= 1: 27 | synonym = random.choice(list(synonyms)) 28 | new_words = [synonym if word == random_word else word for word in new_words] 29 | # print("replaced", random_word, "with", synonym) 30 | num_replaced += 1 31 | if num_replaced >= n: # only replace up to n words 32 | break 33 | 34 | # this is stupid but we need it, trust me 35 | sentence = ' '.join(new_words) 36 | new_words = sentence.split(' ') 37 | 38 | return new_words 39 | 40 | 41 | def get_synonyms(word): 42 | synonyms = set() 43 | for syn in wordnet.synsets(word): 44 | for l in syn.lemmas(): 45 | synonym = l.name().replace("_", " ").replace("-", " ").lower() 46 | synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm']) 47 | synonyms.add(synonym) 48 | if word in synonyms: 49 | synonyms.remove(word) 50 | return list(synonyms) 51 | 52 | 53 | ######################################################################## 54 | # Random deletion 55 | # Randomly delete words from the sentence with probability p 56 | ######################################################################## 57 | 58 | def random_deletion(videos, p): 59 | # obviously, if there's only one word, don't delete it 60 | b, c, t, h, w = videos.size() 61 | if t == 1: 62 | return videos 63 | 64 | # randomly delete words with probability p, padding or loop? 
65 | ''' 66 | # method 1 67 | new_videos = videos.copy() 68 | for i in range(t): 69 | r = random.uniform(0, 1) 70 | if r <= p: 71 | new_videos[:, :, i, :, :] = 0 72 | ''' 73 | # method 2 loop 74 | new_videos = videos 75 | count = 0 76 | for i in range(t): 77 | r = random.uniform(0, 1) 78 | if r <= p: 79 | continue 80 | else: 81 | new_videos[:, :, count, :, :] = videos[:, :, i, :, :,] 82 | count += 1 83 | for i in range(t - count): 84 | new_videos[:, :, i, :, :] = videos[:, :, i, :, :] 85 | # if you end up deleting all words, just return a random word 86 | if new_videos.size()[2] == 0: 87 | rand_int = random.randint(0, t - 1) 88 | return [videos[:, :, rand_int, :, :]] 89 | 90 | return new_videos 91 | 92 | 93 | ######################################################################## 94 | # Random swap 95 | # Randomly swap two words in the sentence n times 96 | ######################################################################## 97 | 98 | def random_swap(videos, n): 99 | new_videos = videos 100 | for _ in range(n): 101 | new_videos = swap_word(new_videos) 102 | return new_videos 103 | 104 | 105 | def swap_word(new_videos): 106 | b, c, t, h, w = new_videos.size() 107 | random_idx_1 = random.randint(0, t - 1) 108 | random_idx_2 = random_idx_1 109 | counter = 0 110 | while random_idx_2 == random_idx_1: 111 | random_idx_2 = random.randint(0, t - 1) 112 | counter += 1 113 | if counter > 3: 114 | return new_videos 115 | new_videos[:, :, random_idx_1, :, :], new_videos[:, :, random_idx_2, :, :] \ 116 | = new_videos[:, :, random_idx_2, :, :], new_videos[:, :, random_idx_1, :, :] 117 | return new_videos 118 | 119 | 120 | ######################################################################## 121 | # Random insertion 122 | # Randomly insert n words into the sentence 123 | ######################################################################## 124 | 125 | def random_insertion(videos, n): 126 | new_videos = videos 127 | for _ in range(n): 128 | new_videos = add_picture(new_videos) 129 | return new_videos 130 | 131 | 132 | def add_picture(videos): 133 | b, c, t, h, w = videos.size() 134 | new_videos = videos 135 | random_idx = random.randint(0, t - 1) 136 | random_idx2 = random.randint(0, t - 1) 137 | # this is from the same sample, may be need modify 138 | new_videos[:, :, random_idx+1:, :, :] = videos[:, :, random_idx:t-1, :, :] 139 | new_videos[:, :, random_idx, :, :] = videos[:, :, random_idx2, :, :] 140 | return new_videos 141 | 142 | ######################################################################## 143 | # main data augmentation function 144 | ######################################################################## 145 | 146 | 147 | class VideoEda(nn.Module): 148 | def __init__(self, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=4): 149 | super(VideoEda, self).__init__() 150 | self.alpha_sr = alpha_sr 151 | self.alpha_ri = alpha_ri 152 | self.alpha_rs = alpha_rs 153 | self.p_rd = p_rd 154 | self.num_aug = num_aug 155 | 156 | def eda(self, inp): 157 | b, c, t, h, w = inp.size() 158 | 159 | augmented_sentences = [] 160 | num_new_per_technique = int(self.num_aug / 4) + 1 161 | n_sr = max(1, int(self.alpha_sr * t)) 162 | n_ri = max(1, int(self.alpha_ri * t)) 163 | n_rs = max(1, int(self.alpha_rs * t)) 164 | 165 | # # sr synonym replacement 166 | # for _ in range(num_new_per_technique): 167 | # inp = synonym_replacement(inp, n_sr) 168 | 169 | # ri random insertion 170 | for _ in range(num_new_per_technique): 171 | inp = random_insertion(inp, n_ri) 172 | 173 | # rs random swap 
174 | for _ in range(num_new_per_technique): 175 | inp = random_swap(inp, n_rs) 176 | 177 | # # rd random delte 178 | # for _ in range(num_new_per_technique): 179 | # inp = random_deletion(inp, self.p_rd) 180 | 181 | return inp 182 | 183 | def forward(self, x): 184 | return self.eda(x) -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/net_mixup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class NETMIXUP(object): 5 | def __init__(self, alpha): 6 | self.alpha = alpha 7 | 8 | def gen_prob(self): 9 | lam = np.random.beta(self.alpha, self.alpha) 10 | return lam 11 | 12 | def construct(self, a, b, mixup_radio=0.5): 13 | c = mixup_radio * a + (1 - mixup_radio) * b 14 | return c 15 | -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | """ 6 | usage 7 | z_rand = generate_noise([1,nzx,nzy], device=opt.device) 8 | z_rand = z_rand.expand(1,3,Z_opt.shape[2],Z_opt.shape[3]) 9 | z_prev1 = 0.95*Z_opt +0.05*z_rand 10 | """ 11 | 12 | 13 | def upsampling(im, sx, sy): 14 | m = nn.Upsample(size=[round(sx), round(sy)], mode='bilinear', align_corners=True) 15 | return m(im) 16 | 17 | 18 | def generate_noise(size, num_samp=1, device='cuda', type='gaussian', scale=1): 19 | if type == 'gaussian': 20 | noise = torch.randn(num_samp, size[0], round(size[1]/scale), round(size[2]/scale)) 21 | noise = upsampling(noise, size[1], size[2]) 22 | if type == 'gaussian_mixture': 23 | noise1 = torch.randn(num_samp, size[0], size[1], size[2]) + 5 24 | noise2 = torch.randn(num_samp, size[0], size[1], size[2]) 25 | noise = noise1 + noise2 26 | if type == 'uniform': 27 | noise = torch.randn(num_samp, size[0], size[1], size[2]) 28 | return noise 29 | -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/rotation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | def batch_rotation(l_new_data): 6 | B, C, T, H, W = l_new_data.size() 7 | rotation_type = random.randint(0, 2) 8 | # print(rotation_type) 9 | if rotation_type == 0: 10 | index = list(range(T - 1, -1, -1)) 11 | rotation_data = l_new_data[:, :, index, :, :] 12 | elif rotation_type == 1: 13 | index = list(range(H - 1, -1, -1)) 14 | rotation_data = l_new_data[:, :, :, index, :] 15 | else: 16 | index = list(range(W - 1, -1, -1)) 17 | rotation_data = l_new_data[:, :, :, :, index] 18 | return rotation_data, rotation_type 19 | 20 | 21 | def sample_rotation(l_new_data, rotation_type, trace=False): 22 | """ 23 | her/vec flip (0, 1) 24 | rotation 90/180/270 degree(2, 3, 4) 25 | her flip + rotate90 / ver flip + rotate 90 (5, 6) 26 | :param l_new_data: 27 | :param rotation_type 28 | :return: 29 | """ 30 | B, C, T, H, W = l_new_data.size() 31 | if not trace: 32 | rotated_data = torch.zeros_like(l_new_data).cuda() 33 | else: 34 | rotated_data = l_new_data 35 | for i in range(B): 36 | if rotation_type[i] == 0: 37 | rotated_data[i] = l_new_data[i].flip(2) 38 | elif rotation_type[i] == 1: 39 | rotated_data[i] = l_new_data[i].flip(3) 40 | elif rotation_type[i] == 2: 41 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2) 42 | elif rotation_type[i] == 3: 43 | rotated_data[i] = 
l_new_data[i].flip(2).flip(3) 44 | elif rotation_type[i] == 4: 45 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3) 46 | elif rotation_type[i] == 5: 47 | rotated_data[i] = l_new_data[i].flip(2).transpose(2, 3).flip(2) 48 | elif rotation_type[i] == 6: 49 | rotated_data[i] = l_new_data[i].flip(3).transpose(2, 3).flip(2) 50 | else: 51 | rotated_data[i] = l_new_data[i] 52 | return rotated_data 53 | 54 | 55 | def sample_rotation_cls(l_new_data, rotation_type): 56 | """ 57 | her/vec flip (0, 1) 58 | rotation 90/180/270 degree(2, 3, 4) 59 | her flip + rotate90 / ver flip + rotate 90 (5, 6) 60 | :param l_new_data: 61 | :param rotation_type: 62 | :return: 63 | """ 64 | B, C, T, H, W = l_new_data.size() 65 | rotated_data = torch.zeros_like(l_new_data).cuda() 66 | for i in range(B): 67 | if rotation_type[i] == 0: 68 | rotated_data[i] = l_new_data[i].flip(2) 69 | elif rotation_type[i] == 1: 70 | rotated_data[i] = l_new_data[i].flip(3) 71 | elif rotation_type[i] == 2: 72 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2) 73 | elif rotation_type[i] == 3: 74 | rotated_data[i] = l_new_data[i].flip(2).flip(3) 75 | elif rotation_type[i] == 4: 76 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3) 77 | elif rotation_type[i] == 5: 78 | rotated_data[i] = l_new_data[i].flip(2).transpose(2, 3).flip(2) 79 | elif rotation_type[i] == 6: 80 | rotated_data[i] = l_new_data[i].flip(3).transpose(2, 3).flip(2) 81 | else: 82 | rotated_data[i] = l_new_data[i] 83 | return rotated_data 84 | 85 | 86 | def four_rotation_cls(l_new_data, rotation_type): 87 | """ 88 | rotation 0/90/180/270 degree(0,1,2,3) 89 | :return: 90 | """ 91 | B, C, T, H, W = l_new_data.size() 92 | rotated_data = torch.zeros_like(l_new_data).cuda() 93 | for i in range(B): 94 | if rotation_type[i] == 1: 95 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2) 96 | elif rotation_type[i] == 2: 97 | rotated_data[i] = l_new_data[i].flip(2).flip(3) 98 | elif rotation_type[i] == 3: 99 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3) 100 | else: 101 | rotated_data[i] = l_new_data[i] 102 | return rotated_data 103 | 104 | 105 | def all_flips(l_new_data, flip_type, trace=False): 106 | """ 107 | her/vec flip (0, 1) 108 | rotation 90/180/270 degree(2, 3, 4) 109 | her flip + rotate90 / ver flip + rotate 90 (5, 6) 110 | :param l_new_data: 111 | :param rotation_type 112 | :return: 113 | """ 114 | """ 115 | :param l_new_data: 116 | :param rotation_type 117 | :return: 118 | """ 119 | B, C, T, H, W = l_new_data.size() 120 | if not trace: 121 | rotated_data = torch.zeros_like(l_new_data).cuda() 122 | else: 123 | rotated_data = l_new_data 124 | flip_type = flip_type // 4 125 | rot_type = flip_type % 4 126 | # flip at first 127 | for i in range(B): 128 | if flip_type[i] == 0: 129 | rotated_data[i] = l_new_data[i] 130 | elif flip_type[i] == 1: # left-right flip 131 | rotated_data[i] = l_new_data[i].flip(3) 132 | elif flip_type[i] == 2: # temporal flip 133 | rotated_data[i] = l_new_data[i].flip(1) 134 | else: # left-right + temporal flip 135 | rotated_data[i] = l_new_data[i].flip(3).flip(1) 136 | # then rotation 137 | for i in range(B): 138 | if rot_type[i] == 0: 139 | rotated_data[i] = l_new_data[i] 140 | elif rot_type[i] == 1: # 90 degree 141 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2) 142 | elif rot_type[i] == 2: # 180 degree 143 | rotated_data[i] = l_new_data[i].flip(2).flip(3) 144 | else: # 270 degree 145 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3) 146 | 147 | return rotated_data 148 | 149 | 150 | if __name__ == 
'__main__': 151 | a = torch.tensor([[1,2],[3,4]]).view(1, 1, 1, 2, 2) 152 | print(a.size()) 153 | for i in range(8): 154 | print(torch.tensor(i)) 155 | print(sample_rotation_cls(a, torch.tensor([i]))) -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/temporal_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | def temporal_augment(l_new_data, augment_type, trace=False): 6 | """ 7 | :param l_new_data: 8 | :param rotation_type 9 | :return: 10 | """ 11 | B, C, T, H, W = l_new_data.size() 12 | if not trace: 13 | rotated_data = torch.zeros_like(l_new_data).cuda() 14 | else: 15 | rotated_data = l_new_data 16 | flip_type = augment_type // 4 17 | rot_type = augment_type % 4 18 | # flip at first 19 | for i in range(B): 20 | if flip_type[i] == 0: 21 | rotated_data[i] = l_new_data[i] 22 | elif flip_type[i] == 1: # left-right flip 23 | rotated_data[i] = l_new_data[i].flip(3) 24 | elif flip_type[i] == 2: # temporal flip 25 | rotated_data[i] = l_new_data[i].flip(1) 26 | else: # left-right + temporal flip 27 | rotated_data[i] = l_new_data[i].flip(3).flip(1) 28 | # then rotation 29 | for i in range(B): 30 | if rot_type[i] == 0: 31 | rotated_data[i] = l_new_data[i] 32 | elif rot_type[i] == 1: # 90 degree 33 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2) 34 | elif rot_type[i] == 2: # 180 degree 35 | rotated_data[i] = l_new_data[i].flip(2).flip(3) 36 | else: # 270 degree 37 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3) 38 | 39 | return rotated_data 40 | 41 | 42 | def inverse_temporal_augment(l_new_data, augment_type, trace=False): 43 | """ 44 | :param l_new_data: 45 | :param rotation_type 46 | :return: 47 | """ 48 | B, C, T, H, W = l_new_data.size() 49 | if not trace: 50 | rotated_data = torch.zeros_like(l_new_data).cuda() 51 | else: 52 | rotated_data = l_new_data 53 | flip_type = augment_type // 4 54 | rot_type = augment_type % 4 55 | # flip at first 56 | for i in range(B): 57 | if flip_type[i] == 0: 58 | rotated_data[i] = l_new_data[i] 59 | elif flip_type[i] == 1: # left-right flip 60 | rotated_data[i] = l_new_data[i].flip(3) 61 | elif flip_type[i] == 2: # temporal flip 62 | rotated_data[i] = l_new_data[i].flip(1) 63 | else: # left-right + temporal flip 64 | rotated_data[i] = l_new_data[i].flip(3).flip(1) 65 | # then rotation 66 | for i in range(B): 67 | if rot_type[i] == 0: 68 | rotated_data[i] = l_new_data[i] 69 | elif rot_type[i] == 1: # -90 degree 70 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3) 71 | elif rot_type[i] == 2: # -180 degree 72 | rotated_data[i] = l_new_data[i].flip(3).flip(2) 73 | else: # -270 degree 74 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2) 75 | 76 | return rotated_data -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/temporal_dropout.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | # @Time : 2019-05-13 19:34 5 | # @Author : Awiny 6 | # @Site : 7 | # @Project : pytorch_i3d 8 | # @File : TemporalDropoutBlock.py 9 | # @Software: PyCharm 10 | # @Github : https://github.com/FingerRec 11 | # @Blog : http://fingerrec.github.io 12 | """ 13 | import scipy.io 14 | import os 15 | import torch 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import random 20 | 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #close the warning 21 | 22 | 23 | class DropBlock2D(nn.Module): 24 | r"""Randomly zeroes 2D spatial blocks of the input tensor. 25 | As described in the paper 26 | `DropBlock: A regularization method for convolutional networks`_ , 27 | dropping whole blocks of feature map allows to remove semantic 28 | information as compared to regular dropout. 29 | Args: 30 | drop_prob (float): probability of an element to be dropped. 31 | block_size (int): size of the block to drop 32 | Shape: 33 | - Input: `(N, C, H, W)` 34 | - Output: `(N, C, H, W)` 35 | .. _DropBlock: A regularization method for convolutional networks: 36 | https://arxiv.org/abs/1810.12890 37 | """ 38 | 39 | def __init__(self, drop_prob, block_size): 40 | super(DropBlock2D, self).__init__() 41 | 42 | self.drop_prob = drop_prob 43 | self.block_size = block_size 44 | 45 | def forward(self, x): 46 | # shape: (bsize, channels, height, width) 47 | 48 | assert x.dim() == 4, \ 49 | "Expected input with 4 dimensions (bsize, channels, height, width)" 50 | 51 | if not self.training or self.drop_prob == 0.: 52 | return x 53 | else: 54 | # get gamma value 55 | gamma = self._compute_gamma(x) 56 | 57 | # sample mask 58 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float() 59 | 60 | # place mask on input device 61 | mask = mask.to(x.device) 62 | 63 | # compute block mask 64 | block_mask = self._compute_block_mask(mask) 65 | 66 | # apply block mask 67 | out = x * block_mask[:, None, :, :] 68 | 69 | # scale output 70 | out = out * block_mask.numel() / block_mask.sum() 71 | 72 | return out 73 | 74 | def _compute_block_mask(self, mask): 75 | block_mask = F.max_pool2d(input=mask[:, None, :, :], 76 | kernel_size=(self.block_size, self.block_size), 77 | stride=(1, 1), 78 | padding=self.block_size // 2) 79 | 80 | if self.block_size % 2 == 0: 81 | block_mask = block_mask[:, :, :-1, :-1] 82 | 83 | block_mask = 1 - block_mask.squeeze(1) 84 | 85 | return block_mask 86 | 87 | def _compute_gamma(self, x): 88 | return self.drop_prob / (self.block_size ** 2) 89 | 90 | 91 | class DropBlock3D(DropBlock2D): 92 | r"""Randomly zeroes 3D spatial blocks of the input tensor. 93 | An extension to the concept described in the paper 94 | `DropBlock: A regularization method for convolutional networks`_ , 95 | dropping whole blocks of feature map allows to remove semantic 96 | information as compared to regular dropout. 97 | Args: 98 | drop_prob (float): probability of an element to be dropped. 99 | block_size (int): size of the block to drop 100 | Shape: 101 | - Input: `(N, C, D, H, W)` 102 | - Output: `(N, C, D, H, W)` 103 | .. 
_DropBlock: A regularization method for convolutional networks: 104 | https://arxiv.org/abs/1810.12890 105 | """ 106 | 107 | def __init__(self, drop_prob, block_size): 108 | super(DropBlock3D, self).__init__(drop_prob, block_size) 109 | 110 | def forward(self, x): 111 | # shape: (bsize, channels, depth, height, width) 112 | 113 | assert x.dim() == 5, \ 114 | "Expected input with 5 dimensions (bsize, channels, depth, height, width)" 115 | 116 | if not self.training or self.drop_prob == 0.: 117 | return x 118 | else: 119 | # get gamma value 120 | gamma = self._compute_gamma(x) 121 | 122 | # sample mask 123 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float() 124 | 125 | # place mask on input device 126 | mask = mask.to(x.device) 127 | 128 | # compute block mask 129 | block_mask = self._compute_block_mask(mask) 130 | 131 | # apply block mask 132 | out = x * block_mask[:, None, :, :, :] 133 | 134 | # scale output 135 | out = out * block_mask.numel() / block_mask.sum() 136 | 137 | return out 138 | 139 | def _compute_block_mask(self, mask): 140 | block_mask = F.max_pool3d(input=mask[:, None, :, :, :], 141 | kernel_size=(self.block_size, self.block_size, self.block_size), 142 | stride=(1, 1, 1), 143 | padding=self.block_size // 2) 144 | 145 | if self.block_size % 2 == 0: 146 | block_mask = block_mask[:, :, :-1, :-1, :-1] 147 | 148 | block_mask = 1 - block_mask.squeeze(1) 149 | 150 | return block_mask 151 | 152 | def _compute_gamma(self, x): 153 | return self.drop_prob / (self.block_size ** 3) 154 | 155 | 156 | class TemporalDropoutBlock(nn.Module): 157 | """ 158 | method1, for 3d feature map BxCxTxHxW reshape as Bx[CxT]xHxW 159 | """ 160 | def __init__(self, dropout_radio): 161 | super(TemporalDropoutBlock, self).__init__() 162 | self.dropout = nn.Dropout(dropout_radio) 163 | 164 | def forward(self, x): 165 | b, c, t, h, w = x.size() 166 | x = x.view(b, c*t, h, w) 167 | x = self.dropout(x) 168 | x = x.view(b, c, t, h, w) 169 | return x 170 | 171 | 172 | class TemporalDropoutBlock3D(nn.Module): 173 | r""" 174 | method2, for 3d feature map BxCxTxHxW, random dropout in T 175 | """ 176 | 177 | def __init__(self, drop_prob): 178 | super(TemporalDropoutBlock3D, self).__init__() 179 | self.dropout = nn.Dropout3d(drop_prob) 180 | 181 | def forward(self, x): 182 | x = x.permute(0, 2, 1, 3, 4) 183 | x = self.dropout(x) 184 | x = x.permute(0, 2, 1, 3, 4) 185 | return x 186 | 187 | 188 | class TemporalBranchDropout(nn.Module): 189 | """ 190 | Branch dropout 191 | """ 192 | def __init__(self, drop_prob): 193 | super(TemporalBranchDropout, self).__init__() 194 | self.dropout = nn.Dropout(1) 195 | self.drop_prob = drop_prob 196 | 197 | def forward(self, x): 198 | prob = random.random() 199 | if prob < self.drop_prob: 200 | x = self.dropout(x) 201 | #print(x) 202 | else: 203 | x = x 204 | return x 205 | 206 | -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/temporal_shuffle.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | # @Time : 2019-05-21 10:37 5 | # @Author : Awiny 6 | # @Site : 7 | # @Project : amax-pytorch-i3d 8 | # @File : TemporalShuffle.py 9 | # @Software: PyCharm 10 | # @Github : https://github.com/FingerRec 11 | # @Blog : http://fingerrec.github.io 12 | """ 13 | 14 | import torch 15 | import torch.nn as nn 16 | import os 17 | import difflib 18 | 19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #close the warning 20 | 
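# Usage sketch (added for illustration, not part of the original file):
#   shuffle = TemporalShuffle(s=2)            # s=1: fully random order; s=2: shuffle 4 segments; s=3: shuffle within segments
#   clip = torch.randn(2, 3, 16, 112, 112)    # B x C x T x H x W, with T divisible by 4 for s=2/3
#   shuffled = shuffle(clip)                  # same shape, frames permuted along the temporal axis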
21 | 22 | class TemporalShuffle(nn.Module): 23 | """ 24 | for this module, random shuffle temporal dim, we want to find if the temporal information is important 25 | """ 26 | def __init__(self, s=1): 27 | super(TemporalShuffle, self).__init__() 28 | self.s = s 29 | 30 | def forward(self, x): 31 | """ 32 | random shuffle temporal dim 33 | :param x: b x c x t x h x w 34 | :return: out: b x c x t' x h x w 35 | """ 36 | t = x.size(2) 37 | origin_idx = list(range(t)) 38 | idxs = [] 39 | K = 4 40 | similarity = 1 41 | # ==================================method1======================== 42 | while similarity >= 1: 43 | if self.s == 1: 44 | idxs = torch.randperm(t) 45 | elif self.s == 2: 46 | idx = torch.randperm(K) 47 | for i in range(K): 48 | for j in range(t // K): 49 | idxs.append(idx[i].item() * t // K + j) 50 | else: 51 | for i in range(K): 52 | idx = torch.randperm(t//K) 53 | for j in range(len(idx)): 54 | idxs.append(t//K*i + idx[j].item()) 55 | similarity = difflib.SequenceMatcher(None, idxs, origin_idx).ratio() 56 | # print(idxs) 57 | # print(similarity) 58 | out = x[:, :, idxs, :, :] 59 | return out -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/temporal_sub.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | # @Time : 2019-05-12 21:54 5 | # @Author : Awiny 6 | # @Site : 7 | # @Project : amax-pytorch-i3d 8 | # @File : TemporalSubBlock.py 9 | # @Software: PyCharm 10 | # @Github : https://github.com/FingerRec 11 | # @Blog : http://fingerrec.github.io 12 | """ 13 | import scipy.io 14 | import os 15 | import torch.nn as nn 16 | import torch 17 | 18 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #close the warning 19 | 20 | class TemporalSubBlock(nn.Module): 21 | def __init__(self): 22 | super(TemporalSubBlock, self).__init__() 23 | def forward(self, x): 24 | b, c, t, h, w = x.size() 25 | y = torch.zeros((b, c, t-1, h, w)).cuda() 26 | for i in range(t-1): 27 | y[:,:,i,:,:] = x[:,:,i+1,:,:] - x[:,:,i,:,:] 28 | return y 29 | 30 | class TemporalSubMeanBlock(nn.Module): 31 | def __init__(self): 32 | super(TemporalSubMeanBlock, self).__init__() 33 | 34 | def forward(self, x): 35 | b, c, t, h, w = x.size() 36 | mean = x.mean(2) 37 | for i in range(t): 38 | x[:,:,i,:,:] -= mean 39 | return x 40 | -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/triplet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import difflib 3 | import random 4 | from augment.basic_augmentation.mixup_methods import SpatialMixup 5 | 6 | 7 | def swap_one_time(seq): 8 | """ 9 | swap a seq random place one time 10 | :param seq: 11 | :return: 12 | """ 13 | length = len(seq) 14 | index = random.randint(0, length-1) 15 | index2 = random.randint(0, length-1) 16 | new_seq = seq.clone() 17 | new_seq[index] = seq[index2] 18 | new_seq[index2] = seq[index] 19 | return new_seq 20 | 21 | 22 | def gen_sim_seq(K, radio=0.95, segments=4): 23 | """ 24 | generate shuffled video sequences as negative, while random shuffle is always zero, control the segments 25 | (eg. 
divide into 4 segments and shuffle these segments) 26 | :param K: 27 | :param radio: 28 | :return: 29 | """ 30 | similarity = 1 31 | idx = torch.arange(K) 32 | assert K % segments == 0 33 | seg_len = K // segments 34 | origin_idx = torch.arange(K).long() 35 | # revise_idx = torch.tensor(list(range(K - 1, -1, -1))) 36 | # print(revise_idx) 37 | while similarity > radio: 38 | # seg_idx = torch.randperm(segments) 39 | # for i in range(segments): 40 | # idx[i*seg_len:(i+1)*seg_len] = origin_idx[seg_idx[i]*seg_len:(seg_idx[i]+1)*seg_len] 41 | idx = swap_one_time(idx) 42 | similarity = difflib.SequenceMatcher(None, idx, origin_idx).ratio() 43 | # print(idx) 44 | # print(origin_idx) 45 | # print(similarity) 46 | return idx 47 | 48 | 49 | def batch_lst(lst, k): 50 | return lst[k:] + lst[:k] 51 | 52 | 53 | class TRIPLET(object): 54 | def __init__(self, t_radio=0.95, s_radio=0.7): 55 | self.t_radio = t_radio 56 | self.s_radio = s_radio 57 | self.spaital_mixup = SpatialMixup(0.8) 58 | 59 | def construct(self, input): 60 | b, c, t, h, w = input.size() 61 | #spatial_noise = generate_noise((c, h, w), b) 62 | # print(sum(sum(sum(sum(spatial_noise))))) 63 | # print(sum(sum(sum(sum(sum(input)))))) 64 | # postive = torch.zeros_like(input) 65 | # drop_radio = random.random() * self.s_radio 66 | postive = self.spaital_mixup.mixup_data(input, trace=False) 67 | # for i in range(t): 68 | # postive[:, :, i, :, :] = (1 - drop_radio) * input[:, :, i, :, :] + drop_radio * spatial_noise[:, :, :, :] 69 | # if the match low than radio, as it negative 70 | # negative_seq = gen_sim_seq(t, self.radio) 71 | # negative = torch.zeros_like(input) 72 | # negative = negative[:, :, negative_seq, :, :] 73 | index = random.randint(1, b-1) 74 | indexs = batch_lst(list(range(b)), index) 75 | negative = input[indexs] 76 | # print(negative - input) 77 | return input, postive, negative -------------------------------------------------------------------------------- /src/Contrastive/augment/basic_augmentation/video_color_jitter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | class ColorJitter(object): 5 | 6 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 7 | self.brightness = brightness 8 | self.contrast = contrast 9 | self.saturation = saturation 10 | 11 | def __call__(self, imgs): 12 | t, h, w, c = imgs.shape 13 | #print(t, h, w, c) 14 | self.transforms = [] 15 | if self.brightness != 0: 16 | self.transforms.append(Brightness(self.brightness)) 17 | if self.contrast != 0: 18 | self.transforms.append(Contrast(self.contrast)) 19 | if self.saturation != 0: 20 | self.transforms.append(Saturation(self.saturation)) 21 | 22 | random.shuffle(self.transforms) 23 | transform = Compose(self.transforms) 24 | # print(transform) 25 | for i in range(t): 26 | imgs[i, :, :, :] = transform(imgs[i, :, :, :]) 27 | return imgs 28 | 29 | class Saturation(object): 30 | 31 | def __init__(self, var): 32 | self.var = var 33 | 34 | def __call__(self, img): 35 | gs = Grayscale()(img) 36 | alpha = random.uniform(-self.var, self.var) 37 | # return img.lerp(gs, alpha) 38 | cover_img = img 39 | for i in range(3): 40 | cover_img[:,:,i] = (1-alpha) * img[:,:,i] + alpha * gs 41 | return cover_img 42 | 43 | 44 | class Brightness(object): 45 | 46 | def __init__(self, var): 47 | self.var = var 48 | 49 | def __call__(self, img): 50 | # gs = img.new().resize_as_(img).zero_() 51 | alpha = random.uniform(-self.var, self.var) 52 | return alpha * img 53 | # return 
img.lerp(gs, alpha) 54 | 55 | 56 | class Contrast(object): 57 | 58 | def __init__(self, var): 59 | self.var = var 60 | 61 | def __call__(self, img): 62 | # gs = Grayscale()(img) 63 | # gs.fill_(gs.mean()) 64 | # alpha = random.uniform(-self.var, self.var) 65 | # return img.lerp(gs, alpha) 66 | return np.mean(img) + self.var * (img- np.mean(img)) 67 | 68 | class Grayscale(object): 69 | 70 | def __call__(self, img): 71 | # gs = img.clone() 72 | # gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 73 | # gs[1].copy_(gs[0]) 74 | # gs[2].copy_(gs[0]) 75 | gs = np.dot(img[...,:3], [0.2989, 0.5870, 0.1140]) 76 | return gs 77 | 78 | 79 | class Compose(object): 80 | """Composes several transforms together. 81 | Args: 82 | transforms (list of ``Transform`` objects): list of transforms to compose. 83 | Example: 84 | >>> transforms.Compose([ 85 | >>> transforms.CenterCrop(10), 86 | >>> transforms.ToTensor(), 87 | >>> ]) 88 | """ 89 | 90 | def __init__(self, transforms): 91 | self.transforms = transforms 92 | 93 | def __call__(self, img): 94 | for t in self.transforms: 95 | img = t(img) 96 | return img 97 | 98 | def __repr__(self): 99 | format_string = self.__class__.__name__ + '(' 100 | for t in self.transforms: 101 | format_string += '\n' 102 | format_string += ' {0}'.format(t) 103 | format_string += '\n)' 104 | return format_string 105 | -------------------------------------------------------------------------------- /src/Contrastive/augment/config.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | 5 | class TC(nn.Module): 6 | def __init__(self, args): 7 | super(TC, self).__init__() 8 | self.args = args 9 | 10 | def forward(self, input): 11 | output = input.cuda() 12 | # output = self.mixup.mixup_data(output) 13 | output = torch.autograd.Variable(output) 14 | return output 15 | -------------------------------------------------------------------------------- /src/Contrastive/augment/gen_negative.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import numpy as np 3 | import torch 4 | from .basic_augmentation.temporal_shuffle import TemporalShuffle 5 | from .basic_augmentation.temporal_dropout import TemporalDropoutBlock3D 6 | from .basic_augmentation.eda import VideoEda 7 | 8 | 9 | class GenNegative(nn.Module): 10 | def __init__(self, prob=0.3): 11 | super(GenNegative, self).__init__() 12 | self.prob = prob 13 | self.t_shuffle = TemporalShuffle() 14 | self.t_drop = TemporalDropoutBlock3D(0.1) 15 | self.t_eda = VideoEda() 16 | 17 | def temporal_dropout(self, x): 18 | return self.t_drop(x) 19 | 20 | def temporal_shuffle(self, x): 21 | return self.t_shuffle(x) 22 | 23 | def temporal_eda(self, x): 24 | return self.t_eda(x) 25 | 26 | def forward(self, x): 27 | # x = self.temporal_shuffle(x) 28 | # x = self.temporal_dropout(x) 29 | x = self.temporal_eda(x) 30 | return x -------------------------------------------------------------------------------- /src/Contrastive/augment/gen_positive.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from augment.basic_augmentation.mixup_methods import * 3 | from augment.video_transformations.videotransforms import ColorJitter 4 | 5 | 6 | class GenPositive(nn.Module): 7 | def __init__(self, prob=0.3): 8 | super(GenPositive, self).__init__() 9 | self.iv_mixup = SpatialMixup(0.3, trace=False, version=2) 10 | self.im_mixup = SpatialMixup(0.3, trace=False, 
version=3) 11 | self.cut = Cut(1, 0.05) 12 | self.prob = prob 13 | 14 | def intra_video_mixup(self, x): 15 | return self.iv_mixup.mixup_data(x) 16 | 17 | def inter_video_mixup(self, x): 18 | return self.im_mixup.mixup_data(x) 19 | 20 | def video_cut(self, x): 21 | return self.cut.cut_data(x) 22 | 23 | def forward(self, x): 24 | # x = self.inter_video_mixup(x) 25 | x = self.intra_video_mixup(x) 26 | return x 27 | -------------------------------------------------------------------------------- /src/Contrastive/augment/video_transformations/functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import torch 3 | # import cv2 4 | import numpy as np 5 | import PIL 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation='bilinear'): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Min spatial dim already matches minimal size 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 32 | and im_h == size): 33 | return clip 34 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 35 | size = (new_w, new_h) 36 | else: 37 | size = size[1], size[0] 38 | if interpolation == 'bilinear': 39 | np_inter = cv2.INTER_LINEAR 40 | else: 41 | np_inter = cv2.INTER_NEAREST 42 | scaled = [ 43 | cv2.resize(img, size, interpolation=np_inter) for img in clip 44 | ] 45 | elif isinstance(clip[0], PIL.Image.Image): 46 | if isinstance(size, numbers.Number): 47 | im_w, im_h = clip[0].size 48 | # Min spatial dim already matches minimal size 49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 50 | and im_h == size): 51 | return clip 52 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 53 | size = (new_w, new_h) 54 | else: 55 | size = size[1], size[0] 56 | if interpolation == 'bilinear': 57 | pil_inter = PIL.Image.NEAREST 58 | else: 59 | pil_inter = PIL.Image.BILINEAR 60 | scaled = [img.resize(size, pil_inter) for img in clip] 61 | else: 62 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 63 | 'but got list of {0}'.format(type(clip[0]))) 64 | return scaled 65 | 66 | 67 | def get_resize_sizes(im_h, im_w, size): 68 | if im_w < im_h: 69 | ow = size 70 | oh = int(size * im_h / im_w) 71 | else: 72 | oh = size 73 | ow = int(size * im_w / im_h) 74 | return oh, ow 75 | 76 | 77 | def normalize(clip, mean, std, inplace=False): 78 | if not _is_tensor_clip(clip): 79 | raise TypeError('tensor is not a torch clip.') 80 | 81 | if not inplace: 82 | clip = clip.clone() 83 | 84 | dtype = clip.dtype 85 | dim = len(mean) 86 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 87 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 88 | # print(clip.size()) 89 | # if dim == 3: 90 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 91 | # else: 92 | # clip.sub_(mean[:, None, None]).div_(std[:, None, None]) 93 | return clip -------------------------------------------------------------------------------- 
/src/Contrastive/augment/video_transformations/volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | def convert_img(img): 6 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format 7 | """ 8 | if len(img.shape) == 3: 9 | img = img.transpose(2, 0, 1) 10 | if len(img.shape) == 2: 11 | img = np.expand_dims(img, 0) 12 | return img 13 | 14 | 15 | class ClipToTensor(object): 16 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 17 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 18 | """ 19 | 20 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 21 | self.channel_nb = channel_nb 22 | self.div_255 = div_255 23 | self.numpy = numpy 24 | 25 | def __call__(self, clip): 26 | """ 27 | Args: clip (list of numpy.ndarray): clip (list of images) 28 | to be converted to tensor. 29 | """ 30 | # Retrieve shape 31 | if isinstance(clip[0], np.ndarray): 32 | h, w, ch = clip[0].shape 33 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 34 | ch) 35 | elif isinstance(clip[0], Image.Image): 36 | w, h = clip[0].size 37 | else: 38 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 39 | but got list of {0}'.format(type(clip[0]))) 40 | 41 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 42 | 43 | # Convert 44 | for img_idx, img in enumerate(clip): 45 | if isinstance(img, np.ndarray): 46 | pass 47 | elif isinstance(img, Image.Image): 48 | img = np.array(img, copy=False) 49 | else: 50 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 51 | but got list of {0}'.format(type(clip[0]))) 52 | img = convert_img(img) 53 | np_clip[:, img_idx, :, :] = img 54 | if self.numpy: 55 | if self.div_255: 56 | np_clip = np_clip / 255 57 | return np_clip 58 | 59 | else: 60 | tensor_clip = torch.from_numpy(np_clip) 61 | 62 | if not isinstance(tensor_clip, torch.FloatTensor): 63 | tensor_clip = tensor_clip.float() 64 | if self.div_255: 65 | tensor_clip = tensor_clip.div(255) 66 | return tensor_clip 67 | 68 | 69 | class ToTensor(object): 70 | """Converts numpy array to tensor 71 | """ 72 | 73 | def __call__(self, array): 74 | tensor = torch.from_numpy(array) 75 | return tensor -------------------------------------------------------------------------------- /src/Contrastive/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/data/__init__.py -------------------------------------------------------------------------------- /src/Contrastive/data/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | 4 | 5 | # from config import opt 6 | def video_to_tensor(pic): 7 | """Convert a ``numpy.ndarray`` to tensor. 8 | Converts a numpy.ndarray (T x H x W x C) 9 | to a torch.FloatTensor of shape (C x T x H x W) 10 | 11 | Args: 12 | pic (numpy.ndarray): Video to be converted to tensor. 13 | Returns: 14 | Tensor: Converted video. 
15 | """ 16 | # return torch.from_numpy(pic) 17 | return torch.from_numpy(pic.transpose([3, 0, 1, 2])).type(torch.FloatTensor) 18 | 19 | 20 | class VideoRecord(object): 21 | def __init__(self, row): 22 | self._data = row 23 | 24 | @property 25 | def path(self): 26 | return self._data[0] 27 | 28 | @property 29 | def num_frames(self): 30 | return int(self._data[1]) - 1 31 | 32 | @property 33 | def label(self): 34 | return int(self._data[2]) 35 | 36 | 37 | def video_frame_count(video_path): 38 | cap = cv2.VideoCapture(video_path) 39 | if not cap.isOpened(): 40 | #print("could not open: ", video_path) 41 | return -1 42 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 43 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) ) 44 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) ) 45 | return length, width, height 46 | 47 | # 48 | # if __name__ == '__main__': 49 | # print("") -------------------------------------------------------------------------------- /src/Contrastive/data/config.py: -------------------------------------------------------------------------------- 1 | import augment.video_transformations.videotransforms as videotransforms 2 | import augment.video_transformations.video_transform_PIL_or_np as video_transform 3 | from augment.video_transformations.volume_transforms import ClipToTensor 4 | 5 | from torchvision import transforms 6 | 7 | 8 | def pt_data_config(args): 9 | if args.pt_dataset == 'ucf101': 10 | num_class = 101 11 | image_tmpl = "frame{:06d}.jpg" 12 | elif args.pt_dataset == 'hmdb51': 13 | num_class = 51 14 | # image_tmpl = "frame{:06d}.jpg" 15 | image_tmpl = "img_{:05d}.jpg" 16 | # image_tmpl = "image_{:05d}.jpg" 17 | elif args.pt_dataset == 'kinetics': 18 | num_class = 400 19 | image_tmpl = "img_{:05d}.jpg" 20 | # args.root = "/data1/DataSet/Kinetics/compress/" 21 | elif args.pt_dataset == 'sth_v1': 22 | num_class = 174 23 | image_tmpl = "{:05d}.jpg" 24 | elif args.pt_dataset == 'diving48': 25 | num_class = 48 26 | image_tmpl = "image_{:05d}.jpg" 27 | else: 28 | raise ValueError('Unknown dataset ' + args.dataset) 29 | return num_class, int(args.pt_data_length), image_tmpl 30 | 31 | 32 | def ft_data_config(args): 33 | if args.ft_dataset == 'ucf101': 34 | num_class = 101 35 | image_tmpl = "frame{:06d}.jpg" 36 | elif args.ft_dataset == 'hmdb51': 37 | num_class = 51 38 | # image_tmpl = "frame{:06d}.jpg" 39 | image_tmpl = "img_{:05d}.jpg" 40 | # image_tmpl = "image_{:05d}.jpg" 41 | elif args.ft_dataset == 'kinetics': 42 | num_class = 400 43 | image_tmpl = "img_{:05d}.jpg" 44 | # args.root = "/data1/DataSet/Kinetics/compress/" 45 | elif args.ft_dataset == 'sth_v1': 46 | num_class = 174 47 | image_tmpl = "{:05d}.jpg" 48 | elif args.ft_dataset == 'diving48': 49 | num_class = 48 50 | image_tmpl = "image_{:05d}.jpg" 51 | else: 52 | raise ValueError('Unknown dataset ' + args.dataset) 53 | return num_class, int(args.ft_data_length), image_tmpl 54 | 55 | 56 | def pt_augmentation_config(args): 57 | if int(args.pt_spatial_size) == 112: 58 | # print("??????????????????????????") 59 | resize_size = 128 60 | else: 61 | resize_size = 256 62 | if args.pt_mode == 'rgb': 63 | normalize = video_transform.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 64 | else: 65 | normalize = video_transform.Normalize(mean=[0.485, 0.456], std=[0.229, 0.224]) 66 | train_transforms = transforms.Compose([ 67 | # videotransforms.RandomCrop(int(args.spatial_size)), 68 | video_transform.RandomRotation(10), 69 | # video_transform.ColorDistortion(1), 70 | # video_transform.STA_RandomRotation(10), 71 
| # video_transform.Each_RandomRotation(10), 72 | video_transform.Resize(resize_size), 73 | video_transform.RandomCrop(int(args.pt_spatial_size)), 74 | video_transform.ColorJitter(0.5, 0.5, 0.25, 0.5), 75 | ClipToTensor(channel_nb=3 if args.pt_mode == 'rgb' else 2), 76 | normalize 77 | # videotransforms.ColorJitter(), 78 | # videotransforms.RandomHorizontalFlip() 79 | ]) 80 | test_transforms = transforms.Compose([ 81 | video_transform.Resize(resize_size), 82 | video_transform.CenterCrop(int(args.pt_spatial_size)), 83 | ClipToTensor(channel_nb=3 if args.pt_mode == 'rgb' else 2), 84 | normalize 85 | ] 86 | ) 87 | eval_transfroms = transforms.Compose([ 88 | video_transform.Resize(resize_size), 89 | video_transform.CenterCrop(int(args.pt_spatial_size)), 90 | ClipToTensor(channel_nb=3 if args.pt_mode == 'rgb' else 2), 91 | normalize 92 | ] 93 | ) 94 | return train_transforms, test_transforms, eval_transfroms 95 | 96 | 97 | def ft_augmentation_config(args): 98 | if int(args.ft_spatial_size) == 112: 99 | resize_size = 128 100 | else: 101 | resize_size = 256 102 | if args.ft_mode == 'rgb': 103 | normalize = video_transform.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 104 | else: 105 | normalize = video_transform.Normalize(mean=[0.485, 0.456], std=[0.229, 0.224]) 106 | train_transforms = transforms.Compose([ 107 | video_transform.RandomRotation(10), 108 | video_transform.Resize(resize_size), 109 | video_transform.RandomCrop(int(args.ft_spatial_size)), 110 | video_transform.ColorJitter(0.5, 0.5, 0.25, 0.5), 111 | ClipToTensor(channel_nb=3 if args.ft_mode == 'rgb' else 2), 112 | normalize 113 | ]) 114 | test_transforms = transforms.Compose([ 115 | video_transform.Resize(resize_size), 116 | video_transform.CenterCrop(int(args.ft_spatial_size)), 117 | ClipToTensor(channel_nb=3 if args.ft_mode == 'rgb' else 2), 118 | normalize 119 | ] 120 | ) 121 | eval_transfroms = transforms.Compose([ 122 | video_transform.Resize(resize_size), 123 | video_transform.CenterCrop(int(args.ft_spatial_size)), 124 | ClipToTensor(channel_nb=3 if args.ft_mode == 'rgb' else 2), 125 | normalize 126 | ] 127 | ) 128 | return train_transforms, test_transforms, eval_transfroms 129 | -------------------------------------------------------------------------------- /src/Contrastive/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pt_data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms): 5 | if args.pt_dataset in ['ucf101', 'hmdb51', 'diving48', 'sth_v1']: 6 | from data.dataset import DataSet as DataSet 7 | elif args.pt_dataset == 'kinetics': 8 | from data.video_dataset import VideoDataSet as DataSet 9 | else: 10 | Exception("unsupported dataset") 11 | train_dataset = DataSet(args, args.pt_root, args.pt_train_list, num_segments=1, new_length=data_length, 12 | stride=args.pt_stride, modality=args.pt_mode, dataset=args.pt_dataset, test_mode=False, 13 | image_tmpl=image_tmpl if args.pt_mode in ["rgb", "RGBDiff"] 14 | else args.pt_flow_prefix + "{}_{:05d}.jpg", transform=train_transforms) 15 | print("training samples:{}".format(train_dataset.__len__())) 16 | train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.pt_batch_size, shuffle=True, 17 | num_workers=args.pt_workers, pin_memory=True) 18 | val_dataset = DataSet(args, args.pt_root, args.pt_val_list, num_segments=1, new_length=data_length, 19 | stride=args.pt_stride, modality=args.pt_mode, test_mode=True, 
dataset=args.pt_dataset, 20 | image_tmpl=image_tmpl if args.pt_mode in ["rgb", "RGBDiff"] else args.pt_flow_prefix + "{}_{:05d}.jpg", 21 | random_shift=False, transform=test_transforms) 22 | print("val samples:{}".format(val_dataset.__len__())) 23 | val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.pt_batch_size, shuffle=False, 24 | num_workers=args.pt_workers, pin_memory=True) 25 | eval_dataset = DataSet(args, args.pt_root, args.pt_val_list, num_segments=1, new_length=data_length, 26 | stride=args.pt_stride, modality=args.pt_mode, test_mode=True, dataset=args.pt_dataset, 27 | image_tmpl=image_tmpl if args.pt_mode in ["rgb", "RGBDiff"] else args.pt_flow_prefix + "{}_{:05d}.jpg", 28 | random_shift=False, transform=eval_transforms, full_video=True) 29 | print("eval samples:{}".format(eval_dataset.__len__())) 30 | eval_data_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.pt_batch_size, shuffle=False, 31 | num_workers=args.pt_workers, pin_memory=True) 32 | return train_data_loader, val_data_loader, eval_data_loader, train_dataset.__len__(), val_dataset.__len__(), eval_dataset.__len__() 33 | 34 | 35 | def ft_data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms): 36 | if args.ft_dataset in ['ucf101', 'hmdb51', 'diving48', 'sth_v1']: 37 | from data.dataset import DataSet as DataSet 38 | elif args.ft_dataset == 'kinetics': 39 | from data.video_dataset import VideoDataSet as DataSet 40 | else: 41 | Exception("unsupported dataset") 42 | train_dataset = DataSet(args, args.ft_root, args.ft_train_list, num_segments=1, new_length=data_length, 43 | stride=args.ft_stride, modality=args.ft_mode, dataset=args.ft_dataset, test_mode=False, 44 | image_tmpl=image_tmpl if args.ft_mode in ["rgb", "RGBDiff"] 45 | else args.flow_prefix + "{}_{:05d}.jpg", transform=train_transforms) 46 | print("training samples:{}".format(train_dataset.__len__())) 47 | train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.ft_batch_size, shuffle=True, 48 | num_workers=args.ft_workers, pin_memory=True) 49 | val_dataset = DataSet(args, args.ft_root, args.ft_val_list, num_segments=1, new_length=data_length, 50 | stride=args.ft_stride, modality=args.ft_mode, test_mode=True, dataset=args.ft_dataset, 51 | image_tmpl=image_tmpl if args.ft_mode in ["rgb", "RGBDiff"] else args.flow_prefix + "{}_{:05d}.jpg", 52 | random_shift=False, transform=test_transforms) 53 | print("val samples:{}".format(val_dataset.__len__())) 54 | val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.ft_batch_size, shuffle=False, 55 | num_workers=args.ft_workers, pin_memory=True) 56 | eval_dataset = DataSet(args, args.ft_root, args.ft_val_list, num_segments=1, new_length=data_length, 57 | stride=args.ft_stride, modality=args.ft_mode, test_mode=True, dataset=args.ft_dataset, 58 | image_tmpl=image_tmpl if args.ft_mode in ["rgb", "RGBDiff"] else args.ft_flow_prefix + "{}_{:05d}.jpg", 59 | random_shift=False, transform=eval_transforms, full_video=True) 60 | print("eval samples:{}".format(eval_dataset.__len__())) 61 | eval_data_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.ft_batch_size, shuffle=False, 62 | num_workers=args.ft_workers, pin_memory=True) 63 | return train_data_loader, val_data_loader, eval_data_loader, train_dataset.__len__(), val_dataset.__len__(), eval_dataset.__len__() 64 | -------------------------------------------------------------------------------- /src/Contrastive/data/on_the_fly_test.py: 
-------------------------------------------------------------------------------- 1 | import lintel 2 | import numpy as np 3 | import cv2 4 | import skvideo.io 5 | 6 | def video_frame_count(video_path): 7 | cap = cv2.VideoCapture(video_path) 8 | if not cap.isOpened(): 9 | #print("could not open: ", video_path) 10 | return -1 11 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 12 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) ) 13 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) ) 14 | return length, width, height 15 | 16 | v_path = "/data1/DataSet/Kinetics/compress/val_256/bee_keeping/p3_wp0Cq6Lo.mp4" 17 | video_frames_num, width, height = video_frame_count(v_path) 18 | 19 | # video, width, height, seek_index = lintel.loadvid(open(v_path, 'rb').read(), should_random_seek=False) 20 | # video = np.reshape(np.frombuffer(video, dtype=np.uint8), (-1, height, width, 3)) 21 | # num_frames = video.shape 22 | # print(video_frames_num, num_frames) 23 | # 24 | # videodata = skvideo.io.vread(v_path, inputdict={'-r': '4'}) 25 | # print(videodata.shape) 26 | print(video_frames_num) 27 | f = open(v_path, 'rb') 28 | video = f.read() 29 | f.close() 30 | 31 | # ffmpeg's frame count is one less than cv2's 32 | frame_nums = [0,201,280, 295, 296] 33 | decoded_frames = lintel.loadvid_frame_nums(video, 34 | frame_nums=frame_nums, 35 | width=width, 36 | height=height) 37 | decoded_frames = np.frombuffer(decoded_frames, dtype=np.uint8) 38 | decoded_frames = np.reshape( 39 | decoded_frames, 40 | newshape=(len(frame_nums), height, width, 3)) 41 | 42 | print(np.shape(decoded_frames)[0]) 43 | 44 | # torchvision's Kinetics dataset could be adapted easily and would also give access to audio, but it cannot pin the starting frame index 45 | # self.video_clips = VideoClips( 46 | # video_list, 47 | # frames_per_clip, 48 | # step_between_clips, 49 | # frame_rate, 50 | # _precomputed_metadata, 51 | # num_workers=num_workers, 52 | # _video_width=_video_width, 53 | # _video_height=_video_height, 54 | # _video_min_dimension=_video_min_dimension, 55 | # _audio_samples=_audio_samples, 56 | # _audio_channels=_audio_channels, 57 | # ) -------------------------------------------------------------------------------- /src/Contrastive/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.backends.cudnn as cudnn 5 | from data.config import data_config, augmentation_pretext_config 6 | from data.dataloader import data_loader_init 7 | from model.config import pretext_model_config 8 | from augment.config import TC 9 | from bk.option_old import args 10 | import torch.nn as nn 11 | 12 | 13 | def single_extract(tc, val_loader, model): 14 | model.eval() 15 | features = {'data':[], 'target':[]} 16 | with torch.no_grad(): 17 | for i, (input, target, index) in enumerate(val_loader): 18 | inputs = input 19 | # inputs = tc(input) 20 | output = model(inputs) 21 | # print(output.size()) 22 | output = nn.AdaptiveAvgPool3d(1)(output).view(output.size(0), output.size(1)) 23 | # print(output.size()) 24 | # print(target) 25 | for j in range(output.size(0)): 26 | features['data'].append(output[j].cpu().numpy()) 27 | features['target'].append(target[j].cpu().numpy()) 28 | if i % 10 == 0: 29 | print("{}/{} finished".format(i, len(val_loader))) 30 | return features 31 | 32 | 33 | def feature_extract(tc, data_loader, model): 34 | features = single_extract(tc, data_loader, model) 35 | return features 36 | 37 | 38 | def main(): 39 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning 40 | os.environ["CUDA_VISIBLE_DEVICES"] =
args.gpus 41 | cudnn.benchmark = True 42 | # == dataset config== 43 | num_class, data_length, image_tmpl = data_config(args) 44 | train_transforms, test_transforms, _ = augmentation_pretext_config(args) 45 | train_data_loader, val_data_loader, _, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms, 46 | test_transforms, _) 47 | # == model config== 48 | model = pretext_model_config(args, num_class) 49 | tc = TC(args) 50 | # front = "contrastive_kinetics_warpping_{}".format(args.dataset) 51 | # front = "triplet_ucf101_warpping_{}".format(args.dataset) 52 | front = "{}_{}_{}".format(args.arch, args.front, args.dataset) 53 | dir = '../experiments/features/{}'.format(front) 54 | if not os.path.exists(dir): 55 | os.makedirs(dir) 56 | features = feature_extract(tc, val_data_loader, model) 57 | np.save('{}/val_features.npy'.format(dir), features) 58 | features = feature_extract(tc, train_data_loader, model) 59 | np.save('{}/train_features.npy'.format(dir), features) 60 | 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /src/Contrastive/ft.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import time 5 | from utils.utils import Timer 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from utils.utils import AverageMeter 9 | from data.config import ft_data_config, ft_augmentation_config 10 | from data.dataloader import ft_data_loader_init 11 | from model.config import ft_model_config 12 | from loss.config import ft_optim_init 13 | from utils.learning_rate_adjust import ft_adjust_learning_rate 14 | from augment.config import TC 15 | from utils.utils import accuracy 16 | import random 17 | from datetime import datetime 18 | 19 | lowest_val_loss = float('inf') 20 | best_prec1 = 0 21 | torch.manual_seed(1) 22 | 23 | 24 | def fine_tune_train_and_val(args, recorder): 25 | # = 26 | global lowest_val_loss, best_prec1 27 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning 28 | torch.manual_seed(1) 29 | cudnn.benchmark = True 30 | timer = Timer() 31 | # == dataset config== 32 | num_class, data_length, image_tmpl = ft_data_config(args) 33 | train_transforms, test_transforms, eval_transforms = ft_augmentation_config(args) 34 | train_data_loader, val_data_loader, _, _, _, _ = ft_data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms) 35 | # == model config== 36 | model = ft_model_config(args, num_class) 37 | recorder.record_message('a', '='*100) 38 | recorder.record_message('a', '-'*40+'finetune'+'-'*40) 39 | recorder.record_message('a', '='*100) 40 | # == optim config== 41 | train_criterion, val_criterion, optimizer = ft_optim_init(args, model) 42 | # == data augmentation(self-supervised) config== 43 | tc = TC(args) 44 | # == train and eval== 45 | print('*'*70+'Step2: fine tune'+'*'*50) 46 | for epoch in range(args.ft_start_epoch, args.ft_epochs): 47 | timer.tic() 48 | ft_adjust_learning_rate(optimizer, args.ft_lr, epoch, args.ft_lr_steps) 49 | train_prec1, train_loss = train(args, tc, train_data_loader, model, train_criterion, optimizer, epoch, recorder) 50 | # train_prec1, train_loss = random.random() * 100, random.random() 51 | recorder.record_ft_train(train_loss / 5.0, train_prec1 / 100.0) 52 | if (epoch + 1) % args.ft_eval_freq == 0: 53 | val_prec1, val_loss = validate(args, tc, val_data_loader, model, val_criterion, recorder) 54 
| # val_prec1, val_loss = random.random() * 100, random.random() 55 | recorder.record_ft_val(val_loss / 5.0, val_prec1 / 100.0) 56 | is_best = val_prec1 > best_prec1 57 | best_prec1 = max(val_prec1, best_prec1) 58 | checkpoint = {'epoch': epoch + 1, 'arch': "i3d", 'state_dict': model.state_dict(), 59 | 'best_prec1': best_prec1} 60 | recorder.save_ft_model(checkpoint, is_best) 61 | timer.toc() 62 | left_time = timer.average_time * (args.ft_epochs - epoch) 63 | message = "Step2: fine tune best_prec1 is: {} left time is : {} now is : {}".format(best_prec1, timer.format(left_time), datetime.now()) 64 | print(message) 65 | recorder.record_message('a', message) 66 | return recorder.filename 67 | 68 | 69 | def train(args, tc, train_loader, model, criterion, optimizer, epoch, recorder, MoCo_init=False): 70 | batch_time = AverageMeter() 71 | data_time = AverageMeter() 72 | losses = AverageMeter() 73 | top1 = AverageMeter() 74 | top3 = AverageMeter() 75 | if MoCo_init: 76 | model.eval() 77 | else: 78 | model.train() 79 | end = time.time() 80 | for i, (input, target, index) in enumerate(train_loader): 81 | data_time.update(time.time() - end) 82 | target = target.cuda() 83 | index = index.cuda() 84 | inputs = tc(input) 85 | target = torch.autograd.Variable(target) 86 | output = model(inputs) 87 | loss = criterion(output, target) # + mse_loss 88 | prec1, prec3 = accuracy(output.data, target, topk=(1, 3)) 89 | losses.update(loss.data.item(), input.size(0)) 90 | top1.update(prec1.item(), input.size(0)) 91 | top3.update(prec3.item(), input.size(0)) 92 | 93 | optimizer.zero_grad() 94 | loss.backward() 95 | # # gradient check 96 | # plot_grad_flow(model.module.base_model.named_parameters()) 97 | optimizer.step() 98 | batch_time.update(time.time() - end) 99 | end = time.time() 100 | 101 | if i % args.ft_print_freq == 0: 102 | message = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 103 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 104 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 105 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 106 | epoch, i, len(train_loader), batch_time=batch_time, 107 | data_time=data_time, loss=losses, lr=optimizer.param_groups[-1]['lr'])) 108 | print(message) 109 | recorder.record_message('a', message) 110 | message = "Finetune Training: Top1:{} Top3:{}".format(top1.avg, top3.avg) 111 | print(message) 112 | recorder.record_message('a', message) 113 | return top1.avg, losses.avg 114 | 115 | 116 | def validate(args, tc, val_loader, model, criterion, recorder, MoCo_init=False): 117 | batch_time = AverageMeter() 118 | losses = AverageMeter() 119 | top1 = AverageMeter() 120 | top3 = AverageMeter() 121 | # switch to evaluate mode 122 | model.eval() 123 | end = time.time() 124 | with torch.no_grad(): 125 | for i, (input, target, index) in enumerate(val_loader): 126 | target = target.cuda() 127 | inputs = tc(input) 128 | target = torch.autograd.Variable(target) 129 | output = model(inputs) 130 | loss = criterion(output, target) 131 | prec1, prec3 = accuracy(output.data, target, topk=(1, 3)) 132 | losses.update(loss.data.item(), input.size(0)) 133 | top1.update(prec1.item(), input.size(0)) 134 | top3.update(prec3.item(), input.size(0)) 135 | batch_time.update(time.time() - end) 136 | end = time.time() 137 | 138 | if i % args.ft_print_freq == 0: 139 | message = ('Test: [{0}/{1}]\t' 140 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 141 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 142 | i, len(val_loader), batch_time=batch_time, loss=losses 143 | )) 144 | 
print(message) 145 | recorder.record_message('a', message) 146 | message = "Finetune Eval: Top1:{} Top3:{}".format(top1.avg, top3.avg) 147 | print(message) 148 | recorder.record_message('a', message) 149 | return top1.avg, losses.avg 150 | 151 | -------------------------------------------------------------------------------- /src/Contrastive/loss/NCE/Link.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Node: 4 | def __init__(self, data, _pre=None, _next=None): 5 | self.data = data # in format [loss, feature] 6 | # self.data = data[0] 7 | # self.loss = data[1] 8 | # self.index = data[2] 9 | self._pre = _pre 10 | self._next = _next 11 | def __str__(self): 12 | return str(self.data[2]) 13 | 14 | class DoublyLink: 15 | def __init__(self): 16 | self.tail = None 17 | self.head = None 18 | self.size = 0 19 | 20 | # at the end 21 | def append(self, new_node): 22 | tmp_node = self.tail 23 | tmp_node._pre = new_node 24 | new_node._next = tmp_node 25 | new_node._pre = None 26 | self.tail = new_node 27 | return new_node 28 | 29 | # at the head 30 | def add_first(self, new_node): 31 | tmp_node = self.head 32 | tmp_node._next = new_node 33 | new_node._pre = tmp_node 34 | new_node._next = None 35 | self.head = new_node 36 | return new_node 37 | 38 | def insert_before(self, node, new_node): 39 | node._next._pre = new_node 40 | new_node._next = node._next 41 | new_node._pre = node 42 | node._next = new_node 43 | return new_node 44 | 45 | def insert_after(self, node, new_node): 46 | if node._pre is None: 47 | return self.append(new_node) 48 | else: 49 | return self.insert_before(node._pre, new_node) 50 | # node._next = new_node 51 | # new_node._next = None 52 | # new_node._pre = node 53 | # self.head = new_node 54 | # return new_node 55 | 56 | def insert(self, data): 57 | if isinstance(data, Node): 58 | tmp_node = data 59 | else: 60 | tmp_node = Node(data) 61 | if self.size == 0: 62 | self.tail = tmp_node 63 | self.head = self.tail 64 | else: 65 | # pre_node = self.head 66 | tmp_node = self.add_first(tmp_node) 67 | # while pre_node.data[0] > tmp_node.data[0] and pre_node._pre != None: 68 | # pre_node = pre_node._pre 69 | # #insert before 70 | # # print(pre_node._pre, pre_node._next) 71 | # if pre_node._pre is None and pre_node.data[0] >= tmp_node.data[0]: 72 | # tmp_node = self.append(tmp_node) 73 | # elif pre_node._next is None and pre_node.data[0] < tmp_node.data[0]: 74 | # tmp_node = self.add_first(tmp_node) 75 | # elif pre_node._next is None and pre_node.data[0] >= tmp_node.data[0]: 76 | # tmp_node = self.insert_after(pre_node, tmp_node) 77 | # else: 78 | # tmp_node = self.insert_before(pre_node, tmp_node) 79 | self.size += 1 80 | return tmp_node 81 | 82 | def remove(self, node): 83 | if node == self.head: 84 | self.head._pre._next = None 85 | self.head = self.head._pre 86 | elif node == self.tail: 87 | self.tail._next._pre = None 88 | self.tail = self.tail._next 89 | else: 90 | node._next._pre = node._pre 91 | node._pre._next = node._next 92 | self.size -= 1 93 | 94 | def __str__(self): 95 | str_text = "" 96 | cur_node = self.head 97 | count = 0 98 | while cur_node != None: 99 | str_text += str(cur_node.data[2]) + " " 100 | cur_node = cur_node._pre 101 | count += 1 102 | if count > 20: 103 | break 104 | return str_text 105 | 106 | 107 | class LRUCache: 108 | def __init__(self, size): 109 | self.size = size 110 | self.hash_map = dict() 111 | self.link = DoublyLink() 112 | self.LRU_init(size) 113 | 114 | def LRU_init(self, size): 115 | 
for i in range(size): 116 | self.set(i, [1e-8, torch.rand(128), i]) 117 | 118 | def set(self, key, value): 119 | if self.size == self.link.size: 120 | self.link.remove(self.link.tail) 121 | if key in self.hash_map: 122 | self.link.remove(self.hash_map.get(key)) 123 | tmp_node = self.link.insert(value) 124 | self.hash_map.__setitem__(key, tmp_node) 125 | 126 | def get(self, key): 127 | tmp_node = self.hash_map.get(key) 128 | self.link.remove(tmp_node) 129 | self.link.insert(tmp_node) 130 | return tmp_node.data 131 | 132 | def get_queue(self, num, keys): 133 | queue = torch.rand(num, 128).cuda() 134 | num_queue = 0 135 | cur_node = self.link.head 136 | while num_queue < num: 137 | if cur_node.data[2] not in keys: 138 | queue[num_queue] = cur_node.data[1] 139 | num_queue += 1 140 | cur_node = cur_node._pre 141 | # while num_queue < num: 142 | # # if cur_node.data[2] not in keys: 143 | # queue[num_queue] = cur_node.data 144 | # num_queue += 1 145 | # cur_node = cur_node._next 146 | return queue 147 | 148 | def update_queue(self, queue): 149 | num = queue.size(0) 150 | # print(num) 151 | cur_node = self.link.head 152 | for i in range(num): 153 | queue[i] = cur_node.data[1] 154 | cur_node = cur_node._pre 155 | return queue 156 | 157 | def batch_set(self, keys, values, losses): 158 | # print(self.link) 159 | num = len(values) 160 | for i in range(num): 161 | self.set(keys[i], [losses[i].item(), values[i], keys[i]]) # add loss,data,key 162 | 163 | # r = LRUCache(3) 164 | # r.set("1", ["1","1"]) 165 | # r.set("2", ["2","2"]) 166 | # r.set("3", ["3","3"]) 167 | # print(r.link.size) 168 | # r.get("1") 169 | # print(r.link) 170 | # r.set("4", ["4","4"]) 171 | # print(r.link) 172 | 173 | -------------------------------------------------------------------------------- /src/Contrastive/loss/NCE/NCECriterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | eps = 1e-7 5 | 6 | 7 | class NCECriterion(nn.Module): 8 | """ 9 | Eq. 
(12): L_{NCE} 10 | """ 11 | def __init__(self, n_data): 12 | super(NCECriterion, self).__init__() 13 | self.n_data = n_data 14 | 15 | def forward(self, x): 16 | bsz = x.shape[0] 17 | m = x.size(1) - 1 18 | 19 | # noise distribution 20 | Pn = 1 / float(self.n_data) 21 | 22 | # loss for positive pair 23 | P_pos = x.select(1, 0) 24 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_() 25 | 26 | # loss for K negative pair 27 | P_neg = x.narrow(1, 1, m) 28 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_() 29 | 30 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz 31 | 32 | return loss 33 | 34 | 35 | class NCESoftmaxLoss(nn.Module): 36 | """Softmax cross-entropy loss (a.k.a., info-NCE loss in CPC paper)""" 37 | def __init__(self): 38 | super(NCESoftmaxLoss, self).__init__() 39 | self.criterion = nn.CrossEntropyLoss() 40 | 41 | def forward(self, x): 42 | bsz = x.shape[0] 43 | x = x.squeeze() 44 | label = torch.zeros([bsz]).cuda().long() 45 | loss = self.criterion(x, label) 46 | return loss 47 | -------------------------------------------------------------------------------- /src/Contrastive/loss/NCE/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/loss/NCE/__init__.py -------------------------------------------------------------------------------- /src/Contrastive/loss/NCE/alias_multinomial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AliasMethod(object): 5 | """ 6 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 7 | """ 8 | def __init__(self, probs): 9 | 10 | if probs.sum() > 1: 11 | probs.div_(probs.sum()) 12 | K = len(probs) 13 | self.prob = torch.zeros(K) 14 | self.alias = torch.LongTensor([0]*K) 15 | 16 | # Sort the data into the outcomes with probabilities 17 | # that are larger and smaller than 1/K. 18 | smaller = [] 19 | larger = [] 20 | for kk, prob in enumerate(probs): 21 | self.prob[kk] = K*prob 22 | if self.prob[kk] < 1.0: 23 | smaller.append(kk) 24 | else: 25 | larger.append(kk) 26 | 27 | # Loop though and create little binary mixtures that 28 | # appropriately allocate the larger outcomes over the 29 | # overall uniform mixture. 
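# Sketch of the invariant the loop below establishes (the alias / Vose method):
# after construction, every bucket kk holds at most two outcomes -- kk itself,
# kept with probability self.prob[kk], and self.alias[kk], taken otherwise --
# which is what lets draw() further down sample the multinomial in O(1) per
# sample with one uniform bucket index plus one Bernoulli trial.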
30 | while len(smaller) > 0 and len(larger) > 0: 31 | small = smaller.pop() 32 | large = larger.pop() 33 | 34 | self.alias[small] = large 35 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small] 36 | 37 | if self.prob[large] < 1.0: 38 | smaller.append(large) 39 | else: 40 | larger.append(large) 41 | 42 | for last_one in smaller+larger: 43 | self.prob[last_one] = 1 44 | 45 | def cuda(self): 46 | self.prob = self.prob.cuda() 47 | self.alias = self.alias.cuda() 48 | 49 | def draw(self, N): 50 | """ 51 | Draw N samples from multinomial 52 | :param N: number of samples 53 | :return: samples 54 | """ 55 | K = self.alias.size(0) 56 | 57 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K) 58 | prob = self.prob.index_select(0, kk) 59 | alias = self.alias.index_select(0, kk) 60 | # b is whether a random number is greater than q 61 | b = torch.bernoulli(prob) 62 | oq = kk.mul(b.long()) 63 | oj = alias.mul((1-b).long()) 64 | 65 | return oq + oj 66 | -------------------------------------------------------------------------------- /src/Contrastive/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/loss/__init__.py -------------------------------------------------------------------------------- /src/Contrastive/loss/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | 4 | from loss.NCE.NCEAverage import MemoryMoCo 5 | from loss.NCE.NCECriterion import NCECriterion 6 | from loss.NCE.NCECriterion import NCESoftmaxLoss 7 | 8 | 9 | def get_fine_tuning_parameters(model, ft_begin_module='custom'): 10 | if not ft_begin_module: 11 | return model.parameters() 12 | parameters = [] 13 | add_flag = False 14 | for k, v in model.named_parameters(): 15 | parameters.append({'params': v}) 16 | # if ft_begin_module in k: 17 | # add_flag = True 18 | # print(k) 19 | # parameters.append({'params': v}) 20 | # if add_flag: 21 | # parameters.append({'params': v}) 22 | return parameters 23 | 24 | 25 | def pt_optim_init(args, model, n_data): 26 | contrast = MemoryMoCo(128, n_data, args.pt_nce_k, args.pt_nce_t, args.pt_softmax).cuda() 27 | criterion = NCESoftmaxLoss() if args.pt_softmax else NCECriterion(n_data) 28 | criterion = criterion.cuda() 29 | 30 | optimizer = torch.optim.SGD(model.parameters(), 31 | lr=args.pt_learning_rate, 32 | momentum=args.pt_momentum, 33 | weight_decay=args.pt_weight_decay) 34 | return contrast, criterion, optimizer 35 | 36 | 37 | def ft_optim_init(args, model): 38 | train_criterion = torch.nn.NLLLoss().cuda() 39 | val_criterion = torch.nn.NLLLoss().cuda() 40 | 41 | if args.ft_fixed == 1: 42 | parameters = get_fine_tuning_parameters(model, ft_begin_module='custom') 43 | else: 44 | parameters = model.parameters() 45 | if args.ft_optim == 'sgd': 46 | optimizer = optim.SGD(parameters, 47 | lr=args.ft_lr, 48 | momentum=args.ft_momentum, 49 | weight_decay=args.ft_weight_decay) 50 | elif args.ft_optim == 'adam': 51 | optimizer = optim.Adam(parameters, lr=args.lr) 52 | else: 53 | Exception("not supported optim") 54 | if args.ft_fixed == 1: 55 | count = 0 56 | for param_group in optimizer.param_groups: 57 | count += 1 58 | print("param group is: {}".format(count)) 59 | count2 = 0 60 | for param_group in optimizer.param_groups: 61 | count2 += 1 62 | param_group['lr'] = param_group['lr'] * count2 / count 63 | return train_criterion, 
val_criterion, optimizer 64 | 65 | -------------------------------------------------------------------------------- /src/Contrastive/loss/tcr.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def flip(x, dim): 6 | indices = [slice(None)] * x.dim() 7 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, 8 | dtype=torch.long, device=x.device) 9 | return x[tuple(indices)] 10 | 11 | 12 | def tcr(feats_o, feats_r): 13 | loss = nn.MSELoss() 14 | feats_r = flip(feats_r, 2) 15 | b, c, t, h, w = feats_o.size() 16 | o_t = nn.AdaptiveAvgPool3d((t, 1, 1))(feats_o) 17 | o_r = nn.AdaptiveAvgPool3d((t, 1, 1))(feats_r) 18 | output = loss(o_t, o_r) 19 | return output -------------------------------------------------------------------------------- /src/Contrastive/main.py: -------------------------------------------------------------------------------- 1 | from option import args 2 | import datetime 3 | from pt import pretext_train 4 | from ft import fine_tune_train_and_val 5 | import os 6 | from utils.recoder import Record 7 | 8 | 9 | def main(): 10 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 11 | args.date = datetime.datetime.today().strftime('%m-%d-%H%M') 12 | recorder = Record(args) 13 | message = "" 14 | if args.method == 'pt': 15 | args.status = 'pt' 16 | pretext_train(args, recorder) 17 | print("finished pretrain with weight from: {}".format(args.ft_weights)) 18 | elif args.method == 'ft': 19 | args.status = 'ft' 20 | fine_tune_train_and_val(args, recorder) 21 | print("finished finetune with weight from: {}".format(args.ft_weights)) 22 | elif args.method == 'pt_and_ft': 23 | args.status = 'pt' 24 | checkpoints_path = pretext_train(args, recorder) 25 | print("finished pretrain, the weight is in: {}".format(args.ft_weights)) 26 | args.status = 'ft' 27 | args.ft_weights = checkpoints_path 28 | fine_tune_train_and_val(args, recorder) 29 | print("finished finetune with weight from: {}".format(checkpoints_path)) 30 | else: 31 | Exception("wrong method!") 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /src/Contrastive/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/model/__init__.py -------------------------------------------------------------------------------- /src/Contrastive/model/c3d.py: -------------------------------------------------------------------------------- 1 | """C3D""" 2 | import math 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn.modules.utils import _triple 8 | from model.i3d import Flatten, Normalize 9 | import torch.nn.functional as F 10 | 11 | 12 | class C3D(nn.Module): 13 | """C3D with BN and pool5 to be AdaptiveAvgPool3d(1).""" 14 | 15 | def __init__(self, with_classifier=False, num_classes=101): 16 | super(C3D, self).__init__() 17 | self.with_classifier = with_classifier 18 | self.num_classes = num_classes 19 | 20 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 21 | self.bn1 = nn.BatchNorm3d(64) 22 | self.relu1 = nn.ReLU() 23 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 24 | 25 | self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 26 | self.bn2 = nn.BatchNorm3d(128) 27 | self.relu2 = nn.ReLU() 28 | self.pool2 = 
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 29 | 30 | self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 31 | self.bn3a = nn.BatchNorm3d(256) 32 | self.relu3a = nn.ReLU() 33 | self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 34 | self.bn3b = nn.BatchNorm3d(256) 35 | self.relu3b = nn.ReLU() 36 | self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 37 | 38 | self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 39 | self.bn4a = nn.BatchNorm3d(512) 40 | self.relu4a = nn.ReLU() 41 | self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 42 | self.bn4b = nn.BatchNorm3d(512) 43 | self.relu4b = nn.ReLU() 44 | self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 45 | 46 | self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 47 | self.bn5a = nn.BatchNorm3d(512) 48 | self.relu5a = nn.ReLU() 49 | self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 50 | self.bn5b = nn.BatchNorm3d(512) 51 | self.relu5b = nn.ReLU() 52 | self.pool5 = nn.AdaptiveAvgPool3d(1) 53 | 54 | if self.with_classifier: 55 | self.linear = nn.Linear(512, self.num_classes) 56 | else: 57 | self.id_head = nn.Sequential( 58 | torch.nn.AdaptiveAvgPool3d((1, 1, 1)), 59 | Flatten(), 60 | torch.nn.Linear(512, 128), 61 | Normalize(2) 62 | ) 63 | 64 | def forward(self, x, return_conv=False): 65 | x = self.conv1(x) 66 | x = self.bn1(x) 67 | x = self.relu1(x) 68 | x = self.pool1(x) 69 | 70 | x = self.conv2(x) 71 | x = self.bn2(x) 72 | x = self.relu2(x) 73 | x = self.pool2(x) 74 | 75 | x = self.conv3a(x) 76 | x = self.bn3a(x) 77 | x = self.relu3a(x) 78 | x = self.conv3b(x) 79 | x = self.bn3b(x) 80 | x = self.relu3b(x) 81 | x = self.pool3(x) 82 | 83 | x = self.conv4a(x) 84 | x = self.bn4a(x) 85 | x = self.relu4a(x) 86 | x = self.conv4b(x) 87 | x = self.bn4b(x) 88 | x = self.relu4b(x) 89 | x = self.pool4(x) 90 | 91 | x = self.conv5a(x) 92 | x = self.bn5a(x) 93 | x = self.relu5a(x) 94 | x = self.conv5b(x) 95 | x = self.bn5b(x) 96 | x = self.relu5b(x) 97 | 98 | if return_conv: 99 | return x 100 | if not self.with_classifier: 101 | id_out = self.id_head(x) 102 | return id_out, 0, 0 103 | 104 | x = self.pool5(x) 105 | x = x.view(-1, 512) 106 | 107 | if self.with_classifier: 108 | x = self.linear(x) 109 | x = F.log_softmax(x, dim=1) 110 | return x 111 | 112 | # 113 | # if __name__ == '__main__': 114 | # c3d = C3D() -------------------------------------------------------------------------------- /src/Contrastive/model/config.py: -------------------------------------------------------------------------------- 1 | from model.i3d import I3D 2 | from model.r2p1d import R2Plus1DNet 3 | from model.r3d import resnet18, resnet34, resnet50 4 | from model.c3d import C3D 5 | from model.s3d_g import S3D_G 6 | from model.s3d import S3DG 7 | import torch.nn as nn 8 | from model.model import TCN 9 | import torch 10 | from utils.load_weights import ft_load_weight 11 | 12 | 13 | def pt_model_config(args, num_class): 14 | if args.arch == 'i3d': 15 | model = I3D(num_classes=101, modality=args.pt_mode, with_classifier=False) 16 | model_ema = I3D(num_classes=101, modality=args.pt_mode, with_classifier=False) 17 | elif args.arch == 'r2p1d': 18 | model = R2Plus1DNet((1, 1, 1, 1), num_classes=num_class, with_classifier=False) 19 | model_ema = R2Plus1DNet((1, 1, 1, 1), num_classes=num_class, with_classifier=False) 20 | elif args.arch == 'r3d18': 21 | model = resnet18(num_classes=num_class, 
with_classifier=False) 22 | model_ema = resnet18(num_classes=num_class, with_classifier=False) 23 | elif args.arch == 'r3d34': 24 | model = resnet34(num_classes=num_class, with_classifier=False) 25 | model_ema = resnet34(num_classes=num_class, with_classifier=False) 26 | elif args.arch == 'r3d50': 27 | model = resnet50(num_classes=num_class, with_classifier=False) 28 | model_ema = resnet50(num_classes=num_class, with_classifier=False) 29 | elif args.arch == 'c3d': 30 | model = C3D(with_classifier=False, num_classes=num_class) 31 | model_ema = C3D(with_classifier=False, num_classes=num_class) 32 | elif args.arch == 's3d': 33 | model = S3D_G(num_class=num_class, in_channel=3, gate=True, with_classifier=False) 34 | model_ema = S3D_G(num_class=num_class, in_channel=3, gate=True, with_classifier=False) 35 | else: 36 | Exception("Not implemene error!") 37 | model = torch.nn.DataParallel(model) 38 | model_ema = torch.nn.DataParallel(model_ema) 39 | return model, model_ema 40 | 41 | 42 | def ft_model_config(args, num_class): 43 | with_classifier = True 44 | if args.arch == 'i3d': 45 | base_model = I3D(num_classes=num_class, modality=args.ft_mode, dropout_prob=args.ft_dropout, with_classifier=with_classifier) 46 | # args.logits_channel = 1024 47 | if args.ft_spatial_size == '112': 48 | out_size = (int(args.ft_data_length) // 8, 4, 4) 49 | else: 50 | out_size = (int(args.ft_data_length) // 8, 7, 7) 51 | elif args.arch == 'r2p1d': 52 | base_model = R2Plus1DNet((1, 1, 1, 1), num_classes=num_class, with_classifier=with_classifier) 53 | # args.logits_channel = 512 54 | out_size = (4, 4, 4) 55 | elif args.arch == 'c3d': 56 | base_model = C3D(num_classes=num_class, with_classifier=with_classifier) 57 | # args.logits_channel = 512 58 | out_size = (4, 4, 4) 59 | elif args.arch == 'r3d18': 60 | base_model = resnet18(num_classes=num_class, sample_size=int(args.ft_spatial_size), with_classifier=with_classifier) 61 | # args.logits_channel = 512 62 | out_size = (4, 4, 4) 63 | elif args.arch == 'r3d34': 64 | base_model = resnet34(num_classes=num_class, sample_size=int(args.ft_spatial_size), with_classifier=with_classifier) 65 | # args.logits_channel = 512 66 | out_size = (4, 4, 4) 67 | elif args.arch == 'r3d50': 68 | base_model = resnet50(num_classes=num_class, sample_size=int(args.ft_spatial_size), with_classifier=with_classifier) 69 | # args.logits_channel = 512 70 | out_size = (4, 4, 4) 71 | elif args.arch == 's3d': 72 | # base_model = S3D_G(num_class=num_class, drop_prob=args.dropout, in_channel=3) 73 | base_model = S3DG(num_classes=num_class, dropout_keep_prob=args.ft_dropout, input_channel=3, spatial_squeeze=True, with_classifier=True) 74 | # args.logits_channel = 1024 75 | out_size = (2, 7, 7) 76 | else: 77 | Exception("unsuporrted arch!") 78 | base_model = ft_load_weight(args, base_model) 79 | model = TCN(base_model, out_size, args) 80 | model = nn.DataParallel(model).cuda() 81 | # cudnn.benchmark = True 82 | return model 83 | -------------------------------------------------------------------------------- /src/Contrastive/model/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import torch 4 | import numpy as np 5 | 6 | class Flatten(nn.Module): 7 | def __init__(self): 8 | super(Flatten, self).__init__() 9 | 10 | def forward(self, input): 11 | return input.view(input.size(0), -1) 12 | 13 | 14 | class Normalize(nn.Module): 15 | def __init__(self, power=2): 16 | super(Normalize, self).__init__() 17 | 
self.power = power 18 | 19 | def forward(self, x): 20 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power) 21 | out = x.div(norm) 22 | return out 23 | 24 | 25 | class Sharpen(nn.Module): 26 | def __init__(self, tempeature=0.5): 27 | super(Sharpen, self).__init__() 28 | self.T = tempeature 29 | 30 | def forward(self, probabilities): 31 | tempered = torch.pow(probabilities, 1 / self.T) 32 | tempered = tempered / tempered.sum(dim=-1, keepdim=True) 33 | return tempered 34 | 35 | class MotionEnhance(nn.Module): 36 | def __init__(self, beta=1, maxium_radio=0.3): 37 | super(MotionEnhance, self).__init__() 38 | self.beta = beta 39 | self.maxium_radio = maxium_radio 40 | 41 | def forward(self, x): 42 | b, c, t, h, w = x.size() 43 | mean = nn.AdaptiveAvgPool3d((1, h, w))(x) 44 | lam = np.random.beta(self.beta, self.beta) * self.maxium_radio 45 | out = (x - mean * lam) * (1 / (1 - lam)) 46 | return out 47 | 48 | 49 | class TCN(nn.Module): 50 | """ 51 | encode a video clip into 128 dimension features and classify 52 | two implement ways, reshape and encode adjcent samples into batch dimension 53 | """ 54 | def __init__(self, base_model, out_size, args): 55 | super(TCN, self).__init__() 56 | self.base_model = base_model 57 | self.args = args 58 | self.l2norm = Normalize(2) 59 | print("fine tune ...") 60 | 61 | def forward(self, input): 62 | output = self.base_model(input, return_conv=False) 63 | # print(output.size()) 64 | # output = F.log_softmax(output, dim=1) 65 | return output 66 | -------------------------------------------------------------------------------- /src/Contrastive/model/mutual_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class MutualNet(nn.Module): 5 | def __init__(self, embeddingnet): 6 | super(MutualNet, self).__init__() 7 | self.embeddingnet = embeddingnet 8 | 9 | def forward(self, x, y, z): 10 | feature_x = self.embeddingnet(x) 11 | feature_y = self.embeddingnet(y) 12 | return feature_x, feature_y -------------------------------------------------------------------------------- /src/Contrastive/model/s3d_g.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from model.i3d import Flatten, Normalize 5 | # import ipdb 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicConv3d(nn.Module): 10 | def __init__(self, 11 | in_channel, 12 | out_channel, 13 | kernel_size=1, 14 | stride=1, 15 | padding=0, 16 | use_bias=False, 17 | use_bn=True, 18 | activation='rule'): 19 | super(BasicConv3d, self).__init__() 20 | 21 | self.use_bn = use_bn 22 | self.activation = activation 23 | self.conv3d = nn.Conv3d(in_channel, out_channel, kernel_size=kernel_size, 24 | stride=stride, padding=padding, bias=use_bias) 25 | if use_bn: 26 | self.bn = nn.BatchNorm3d(out_channel, eps=1e-3, momentum=0.001, affine=True) 27 | if activation == 'rule': 28 | self.activation = nn.ReLU() 29 | 30 | def forward(self, x): 31 | x = self.conv3d(x) 32 | if self.use_bn: 33 | x = self.bn(x) 34 | if self.activation is not None: 35 | x = self.activation(x) 36 | # ipdb.set_trace() 37 | return x 38 | 39 | 40 | class sep_conv(nn.Module): 41 | def __init__(self, 42 | in_channel, 43 | out_channel, 44 | kernel_size, 45 | stride=1, 46 | padding=0, 47 | use_bias=True, 48 | use_bn=True, 49 | activation='rule', 50 | gate=True): 51 | super(sep_conv, self).__init__() 52 | down = BasicConv3d(in_channel, out_channel, 
(1,kernel_size,kernel_size), stride=stride, 53 | padding=(0,padding,padding), use_bias=False, use_bn=True) 54 | up = BasicConv3d(out_channel, out_channel, (kernel_size,1,1), stride=1, 55 | padding=(padding,0,0), use_bias=False, use_bn=True) 56 | self.sep_conv = nn.Sequential(down, up) 57 | 58 | # gating 59 | if gate: 60 | self.gate = gate 61 | self.squeeze = nn.AdaptiveAvgPool3d(1) 62 | self.excitation = nn.Conv3d(out_channel, out_channel, 1) 63 | self.sigmoid = nn.Sigmoid() 64 | else: 65 | self.gate = False 66 | 67 | def forward(self, x): 68 | x = self.sep_conv(x) 69 | # ipdb.set_trace() 70 | if self.gate: 71 | temp = x 72 | weight = self.squeeze(x) 73 | weight = self.excitation(weight) 74 | weight = self.sigmoid(weight) 75 | x = weight * x 76 | return x 77 | 78 | 79 | class sep_inc(nn.Module): 80 | def __init__(self, in_channel, out_channel, gate=True): 81 | super(sep_inc, self).__init__() 82 | # branch 0 83 | self.branch0 = BasicConv3d(in_channel, out_channel[0], kernel_size=(1,1,1), stride=1, padding=0) 84 | # branch 1 85 | branch1_conv1 = BasicConv3d(in_channel, out_channel[1],kernel_size=(1,1,1), stride=1, padding=0) 86 | branch1_sep_conv = sep_conv(out_channel[1], out_channel[2], kernel_size=3, stride=1, padding=1, gate=gate) 87 | self.branch1 = nn.Sequential(branch1_conv1, branch1_sep_conv) 88 | # branch 2 89 | branch2_conv1 = BasicConv3d(in_channel, out_channel[3],kernel_size=(1,1,1), stride=1, padding=0) 90 | branch2_sep_conv = sep_conv(out_channel[3], out_channel[4], kernel_size=3, stride=1, padding=1, gate=gate) 91 | self.branch2 = nn.Sequential(branch2_conv1, branch2_sep_conv) 92 | # branch 3 93 | branch3_pool = nn.MaxPool3d(kernel_size=3, stride=1, padding=1) 94 | branch3_conv = BasicConv3d(in_channel, out_channel[5], kernel_size=(1,1,1)) 95 | self.branch3 = nn.Sequential(branch3_pool, branch3_conv) 96 | 97 | def forward(self, x): 98 | # ipdb.set_trace() 99 | out_0 = self.branch0(x) 100 | out_1 = self.branch1(x) 101 | out_2 = self.branch2(x) 102 | out_3 = self.branch3(x) 103 | out = torch.cat((out_0, out_1, out_2, out_3), 1) 104 | return out 105 | 106 | 107 | class S3D_G(nn.Module): 108 | def __init__(self, num_class=400, drop_prob=0.5, in_channel=3, gate=True, with_classifier=True): 109 | super(S3D_G, self).__init__() 110 | self.feature = nn.Sequential(OrderedDict([ 111 | ('sepConv1', sep_conv(in_channel, 64, kernel_size=7, stride=2, padding=3, gate=gate)), # (64,32,112,112) 112 | ('maxPool1', nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))), # (64,32,56,56) 113 | ('basicConv3d', BasicConv3d(64, 64, kernel_size=1, stride=1)), # (64,32,56,56) 114 | ('sep_conv2', sep_conv(64, 192, kernel_size=3, stride=1, padding=1, gate=gate)), # (192,32,56,56) 115 | ('maxPool2', nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))), # (192,32,28,28) 116 | ('sepInc_3b', sep_inc(192, [64,96,128,16,32,32], gate=gate)), # (256,32,28,28) 117 | ('sepInc_3c', sep_inc(256, [128,128,192,32,96,64], gate=gate)), # (480,32,28,28) 118 | ('maxPool3', nn.MaxPool3d(kernel_size=(3,3,3), stride=(2,2,2), padding=(1,1,1))), # (480,16,14,14) 119 | ('sepInc_4b', sep_inc(480, [192, 96, 208, 16, 48, 64], gate=gate)), # (512,16,14,14) 120 | ('sepInc_4c', sep_inc(512, [160, 112, 224, 24, 64, 64], gate=gate)), # (512,16,14,14) 121 | ('sepInc_4d', sep_inc(512, [128, 128, 256, 24, 64, 64], gate=gate)), # (512,16,14,14) 122 | ('sepInc_4e', sep_inc(512, [112, 144, 288, 32, 64, 64], gate=gate)), # (528,16,14,14) 123 | ('sepInc_4f', sep_inc(528, [256, 160, 320, 32, 128, 128], gate=gate)), # 
(832,16,14,14) 124 | ('maxpool4', nn.MaxPool3d(kernel_size=(2,2,2),stride=(2,2,2),padding=(0,0,0))), # (832,8,7,7) 125 | ('sepInc_5b', sep_inc(832, [256, 160, 320, 32, 128, 128], gate=gate)), # (832,8,7,7) 126 | ('sepInc_5c', sep_inc(832, [384, 192, 384, 48, 128, 128], gate=gate)), # (1024,8,7,7) 127 | ])) 128 | self.avgPool = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1) # (1024,7,1,1) 129 | self.drop = nn.Dropout3d(drop_prob) 130 | self.fc = nn.Conv3d(1024, num_class, kernel_size=1, stride=1, bias=True) # (num_class,7,1,1) 131 | self.softmax = nn.Softmax(1) 132 | self.with_classifier = with_classifier 133 | if not with_classifier: 134 | self.id_head = nn.Sequential( 135 | torch.nn.AdaptiveAvgPool3d((1, 1, 1)), 136 | Flatten(), 137 | torch.nn.Linear(1024, 128), 138 | Normalize(2) 139 | ) 140 | 141 | def forward(self, x, return_conv=False): 142 | # ipdb.set_trace() 143 | out = self.feature(x) # (batch_size,num_class,7,1,1) 144 | if return_conv: 145 | return out 146 | if not self.with_classifier: 147 | out = self.id_head(out) 148 | return out, 0, 0 149 | out = self.drop(self.avgPool(out)) 150 | out = self.fc(out) 151 | # squeeze the spatial dimension 152 | out = out.squeeze(3) # (batch_size,num_class,7,1) 153 | out = out.squeeze(3) # (batch_size,num_class,7) 154 | out = out.mean(2) # (batch_size,num_class) 155 | return F.log_softmax(out, dim=1) 156 | # prediction = self.softmax(out) 157 | # return prediction 158 | 159 | 160 | if __name__ == "__main__": 161 | model = S3D_G() 162 | x = torch.rand((1,3,64,224,224)) 163 | p, out = model(x) -------------------------------------------------------------------------------- /src/Contrastive/reterival.py: -------------------------------------------------------------------------------- 1 | """Video retrieval experiment, top-k.""" 2 | import os 3 | import json 4 | import numpy as np 5 | from sklearn.metrics.pairwise import cosine_distances, euclidean_distances 6 | 7 | 8 | def topk_retrieval(feature_dir): 9 | """Extract features from test split and search on train split features.""" 10 | print('Load local .npy files. 
from ...', feature_dir) 11 | train_features = np.load(os.path.join(feature_dir, 'train_features.npy'), allow_pickle=True).item() 12 | X_train = train_features['data'] 13 | y_train = train_features['target'] 14 | # X_train = np.mean(X_train, 1) 15 | # y_train = y_train[:, 0] 16 | # X_train = X_train.reshape((-1, X_train.shape[-1])) 17 | # y_train = y_train.reshape(-1) 18 | 19 | val_features = np.load(os.path.join(feature_dir, 'val_features.npy'), allow_pickle=True).item() 20 | X_test = val_features['data'] 21 | y_test = val_features['target'] 22 | # X_test = np.mean(X_test, 1) 23 | # y_test = y_test[:, 0] 24 | # X_test = X_test.reshape((-1, X_test.shape[-1])) 25 | # y_test = y_test.reshape(-1) 26 | 27 | ks = [1, 5, 10, 20, 50] 28 | topk_correct = {k: 0 for k in ks} 29 | 30 | distances = cosine_distances(X_test, X_train) 31 | indices = np.argsort(distances) # 1530 x 3570 32 | 33 | for k in ks: 34 | top_k_indices = indices[:, :k] 35 | for ind, test_label in zip(top_k_indices, y_test): 36 | # print(ind) 37 | for j in range(len(ind)): 38 | labels = y_train[ind[j]] 39 | if test_label in labels: 40 | topk_correct[k] += 1 41 | break 42 | 43 | for k in ks: 44 | correct = topk_correct[k] 45 | total = len(X_test) 46 | print('Top-{}, correct = {:.2f}, total = {}, acc = {:.3f}%'.format(k, correct, total, correct / total * 100)) 47 | 48 | with open(os.path.join(feature_dir, 'topk_correct.json'), 'w') as fp: 49 | json.dump(topk_correct, fp) 50 | 51 | 52 | if __name__ == '__main__': 53 | # front = "contrastive_kinetics_warpping" 54 | # front = "triplet_ucf101_warpping_hmdb51" 55 | # front = "contrastive_ucf101_warpping_hmdb51" 56 | # front = "contrastive_ucf101_warpping_ucf101" 57 | # front = "triplet_ucf101_warpping_hmdb51" 58 | # front = "i3d_fully_supervised_kinetics_warpping_hmdb51_finetune" 59 | # front = "c3d_fully_supervised_ucf101_warpping_hmdb51" 60 | front = "c3d_c3d_contrastive_ucf101_warpping_ucf101_ucf101" 61 | feature_dirs = "../experiments/features/{}".format(front) 62 | topk_retrieval(feature_dirs) 63 | 64 | 65 | # ==============================kinetics BE contrastive pretrain ======================= 66 | # hmdb51: 11.9 / 31.3 / 44.452 / 60.432 / 81.23 67 | # ucf101: 13.0/35.16/44.0/64.78/83.76 68 | 69 | # ===============================ucf101 BE triplet pretrain=================================== 70 | # hmdb51: 3.922 / 17.386 / 29.15 / 45.359 / 69.020 (may not best) 71 | 72 | # =============================ucf101 BE contrastive pretrain=========================== 73 | # hmdb51: 11.4 / 31.2 / 46.5/ 60.4 / 80.876 74 | # ucf101: 17.394 / 35.184 / 45.308 / 57.811 / 73.962 75 | # c3d 76 | # hmdb51: 8.23/25.88/38.10/51.96/75.0 77 | -------------------------------------------------------------------------------- /src/Contrastive/scripts/Diving48/ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py \ 3 | --method ft --arch i3d \ 4 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \ 5 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \ 6 | --ft_root /data1/DataSet/Diving48/rgb_frames/ \ 7 | --ft_dataset diving48 --ft_mode rgb \ 8 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 9 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \ 10 | --ft_print-freq 100 --ft_fixed 0 \ 11 | --ft_weights ../experiments/kinetics_contrastive.pth 
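feature_extractor.py and reterival.py above share one on-disk convention: clip features are written with np.save as a plain dict {'data': [...], 'target': [...]}, so they have to be re-loaded with allow_pickle=True followed by .item(). A minimal sketch of querying a single validation clip's nearest training clips under that convention (the feature_dir value is a placeholder for one of the directories created by feature_extractor.py):

import numpy as np
from sklearn.metrics.pairwise import cosine_distances

feature_dir = "../experiments/features/<front>"  # placeholder, same layout as in reterival.py
train = np.load("{}/train_features.npy".format(feature_dir), allow_pickle=True).item()
val = np.load("{}/val_features.npy".format(feature_dir), allow_pickle=True).item()
X_train, y_train = np.stack(train['data']), np.asarray(train['target'])

query = np.asarray(val['data'][0])[None, :]   # one clip feature as a 1 x D row vector
nearest = np.argsort(cosine_distances(query, X_train)[0])[:5]
print("query label:", val['target'][0], "top-5 retrieved labels:", y_train[nearest])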
-------------------------------------------------------------------------------- /src/Contrastive/scripts/Diving48/pt_and_ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 1 --method pt_and_ft --pt_method be \ 3 | --pt_batch_size 8 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset diving48 \ 6 | --pt_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \ 7 | --pt_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \ 8 | --pt_root /data1/yutinggao/data/videos/Diving48_rgb_frames \ 9 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \ 10 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \ 11 | --ft_root /data1/yutinggao/data/videos/Diving48_rgb_frames \ 12 | --ft_dataset diving48 --ft_mode rgb \ 13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 14 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \ 15 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/Diving48/pt_and_ft_moco.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 1 --method pt_and_ft --pt_method moco \ 3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset diving48 \ 6 | --pt_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \ 7 | --pt_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \ 8 | --pt_root /data1/DataSet/Diving48/rgb_frames/ \ 9 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \ 10 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \ 11 | --ft_root /data1/DataSet/Diving48/rgb_frames/ \ 12 | --ft_dataset diving48 --ft_mode rgb \ 13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 14 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \ 15 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/Diving48/pt_and_ft_moco_ucf_to_diving48.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 1 --method pt_and_ft --pt_method moco \ 3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset ucf101 \ 6 | --pt_train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \ 7 | --pt_val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \ 8 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \ 9 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \ 10 | --ft_root /data1/DataSet/Diving48/rgb_frames/ \ 11 | --ft_dataset diving48 --ft_mode rgb \ 12 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 13 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 
--ft_stride 1 --ft_dropout 0.5 \ 14 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/evaluation/hmdb51_i3d_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python test.py \ 3 | --method ft \ 4 | --train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \ 5 | --val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \ 6 | --dataset hmdb51 \ 7 | --arch i3d \ 8 | --mode rgb \ 9 | --batch_size 1 \ 10 | --stride 1 \ 11 | --data_length 64 \ 12 | --clip_size 64 \ 13 | --spatial_size 224 \ 14 | --workers 1 \ 15 | --dropout 0.5 \ 16 | --gpus 2 \ 17 | --weights ../experiments/logs/hmdb51_i3d_ft/ft_04-11-0112/fine_tune_rgb_model_latest.pth.tar #26? 18 | #--weights ../experiments/logs/hmdb51_i3d_ft/ft_02-18-1134/fine_tune_rgb_model_latest.pth.tar #25.94 19 | #--weights ../experiments/logs/hmdb51_i3d_ft/ft_03-23-2206/fine_tune_rgb_model_latest.pth.tar #49.86 -------------------------------------------------------------------------------- /src/Contrastive/scripts/evaluation/ucf101_i3d_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python test.py \ 3 | --method ft \ 4 | --train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \ 5 | --val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \ 6 | --dataset ucf101 \ 7 | --arch i3d \ 8 | --mode rgb \ 9 | --batch_size 1 \ 10 | --stride 1 \ 11 | --data_length 64 \ 12 | --clip_size 64 \ 13 | --spatial_size 224 \ 14 | --workers 1 \ 15 | --dropout 0.5 \ 16 | --gpus 2 \ 17 | --weights ../experiments/fine_tune_rgb_model_latest.pth.tar #Scratch: ? 18 | #--weights ../experiments/logs/ucf101_i3d_ft/ft_04-02-1128/fine_tune_rgb_model_latest.pth.tar #SSL: 78.83 19 | #--weights ../experiments/logs/hmdb51_i3d_ft/ft_02-18-1134/fine_tune_rgb_model_latest.pth.tar #25.94 20 | #--weights ../experiments/logs/hmdb51_i3d_ft/ft_03-23-2206/fine_tune_rgb_model_latest.pth.tar #49.86 -------------------------------------------------------------------------------- /src/Contrastive/scripts/feature_extract/hmdb51_extract.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python feature_extractor.py \ 3 | --eval_indict feature_extract \ 4 | --method ft \ 5 | --train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \ 6 | --val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \ 7 | --dataset hmdb51 \ 8 | --arch c3d \ 9 | --mode rgb \ 10 | --lr 0.001 \ 11 | --lr_steps 10 20 25 30 35 40 \ 12 | --epochs 45 \ 13 | --batch_size 1 \ 14 | --data_length 64 \ 15 | --spatial_size 224 \ 16 | --workers 8 \ 17 | --dropout 0.5 \ 18 | --gpus 1 \ 19 | --logs_path ../experiments/logs/hmdb51_i3d_ft \ 20 | --print-freq 100 \ 21 | --front c3d_fully_supervised_ucf101_warpping_hmdb51 \ 22 | --weights ../experiments/MoCo/ucf101/models/08-19-1644_aug_CJ/ckpt_epoch_20.pth 23 | # --weights ../experiments/Pretrained/i3d_rgb_imagenet.pt # kinetics fully supervised 24 | # --weights ../experiments/Pretrained/i3d_model_rgb.pth # kinetics fully supervised finetune 25 | # --weights ../experiments/triplet/kinetics/models/08-18-1957_aug_CJ/ckpt_epoch_30.pth #ucf101_triplet 26 | #--weights ../experiments/MoCo/ucf101/models/08-18-1956_aug_CJ/ckpt_epoch_40.pth # ucf101_contrastive_wrapping 27 | #--weights ../experiments/MoCo/ucf101/models/08-12-1150_aug_CJ/ckpt_epoch_42.pth #kinetics 28 | 
#--weights ../experiments/logs/hmdb51_i3d_ft/ft_03-23-2206/fine_tune_rgb_model_latest.pth.tar 29 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-18-1837/fine_tune_rgb_model_best.pth.tar 30 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-19-1046/fine_tune_rgb_model_best.pth.tar 31 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-19-1058/fine_tune_rgb_model_best.pth.tar 32 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-19-1058/net_mixup_rgb_model_best.pth.tar 33 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-19-1046/flip_cls_rgb_model_best.pth.tar 34 | #--weights ../experiments/pretrained_model/model_rgb.pth 35 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-15-1229/mutual_loss_rgb_model_latest.pth.tar -------------------------------------------------------------------------------- /src/Contrastive/scripts/feature_extract/ucf101_extract.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python feature_extractor.py \ 3 | --eval_indict feature_extract \ 4 | --method ft \ 5 | --train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \ 6 | --val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \ 7 | --dataset ucf101 \ 8 | --arch c3d \ 9 | --mode rgb \ 10 | --lr 0.001 \ 11 | --lr_steps 10 20 25 30 35 40 \ 12 | --epochs 45 \ 13 | --batch_size 1 \ 14 | --data_length 64 \ 15 | --spatial_size 224 \ 16 | --workers 8 \ 17 | --dropout 0.5 \ 18 | --gpus 2 \ 19 | --logs_path ../experiments/logs/hmdb51_c3d_ft \ 20 | --print-freq 100 \ 21 | --front c3d_contrastive_ucf101_warpping_ucf101 \ 22 | --weights ../experiments/MoCo/ucf101/models/08-19-1644_aug_CJ/ckpt_epoch_20.pth #ucf101_triplet 23 | # --weights ../experiments/MoCo/ucf101/models/08-18-1956_aug_CJ/ckpt_epoch_40.pth # ucf101_contrastive_wrapping 24 | # --weights ../experiments/triplet/ucf101/models/08-04-2112_aug_CJ/ckpt_epoch_10.pth # ucf101_triplet 25 | # --weights ../experiments/MoCo/ucf101/models/08-14-1615_aug_CJ/ckpt_epoch_150.pth # ucf101_contrastive 26 | # --weights ../experiments/MoCo/ucf101/models/08-12-1150_aug_CJ/ckpt_epoch_42.pth # kinetics -------------------------------------------------------------------------------- /src/Contrastive/scripts/hmdb51/ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 0 --method ft --arch i3d --pt_method be \ 3 | --ft_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \ 4 | --ft_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \ 5 | --ft_dataset hmdb51 --ft_mode rgb \ 6 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 7 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \ 8 | --ft_print-freq 100 --ft_fixed 0 \ 9 | --ft_weights ../experiments/ucf101_contrastive.pth -------------------------------------------------------------------------------- /src/Contrastive/scripts/hmdb51/pt_and_ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 0 --method pt_and_ft --pt_method be \ 3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 200 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset ucf101 \ 6 | --pt_train_list 
../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \ 7 | --pt_val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \ 8 | --ft_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \ 9 | --ft_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \ 10 | --ft_dataset hmdb51 --ft_mode rgb \ 11 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 12 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \ 13 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/hmdb51/pt_and_ft_hmdb51.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 3 --method pt_and_ft --pt_method be \ 3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset hmdb51 \ 6 | --pt_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \ 7 | --pt_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \ 8 | --ft_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \ 9 | --ft_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \ 10 | --ft_dataset hmdb51 --ft_mode rgb \ 11 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 12 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \ 13 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/kinetics/pt_and_ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 0,1,2,3,4,5,6,7 --method pt_and_ft --pt_method be \ 3 | --pt_batch_size 64 --pt_workers 16 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 65536 --pt_softmax \ 5 | --pt_moco --pt_epochs 50 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset kinetics \ 6 | --pt_train_list ../datasets/lists/kinetics-400/ssd_kinetics_video_trainlist.txt \ 7 | --pt_val_list ../datasets/lists/kinetics-400/ssd_kinetics_video_vallist.txt \ 8 | --ft_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \ 9 | --ft_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \ 10 | --ft_dataset hmdb51 --ft_mode rgb \ 11 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 12 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \ 13 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/something-something-v1/i3d_ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 2 --method ft \ 3 | --arch i3d \ 4 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 5 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 6 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 7 | --ft_dataset sth_v1 --ft_mode rgb \ 8 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 6 \ 9 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 6 --ft_stride 1 --ft_dropout 0.5 \ 10 | --ft_print-freq 100 --ft_fixed 0 
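Note: every fine-tuning script above shares the `--ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40` schedule, which `utils/learning_rate_adjust.py` (shown later in this dump) turns into a step decay: the base learning rate is multiplied by 0.3 each time a milestone epoch is passed. A minimal sketch of the resulting values, with the printed epochs chosen purely for illustration:

import numpy as np

base_lr = 0.001                      # --ft_lr
lr_steps = [10, 20, 25, 30, 35, 40]  # --ft_lr_steps

def lr_at(epoch):
    # one factor of 0.3 per milestone already reached, as in ft_adjust_learning_rate
    return base_lr * 0.3 ** int(np.sum(epoch >= np.array(lr_steps)))

for epoch in (0, 10, 20, 40):        # illustrative epochs only
    print(epoch, lr_at(epoch))       # -> 0.001, 0.0003, 9e-05, ~7.3e-07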
-------------------------------------------------------------------------------- /src/Contrastive/scripts/something-something-v1/i3d_pt_and_ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 3 --method pt_and_ft --pt_method be \ 3 | --pt_batch_size 48 --pt_workers 8 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset sth_v1 \ 6 | --pt_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 7 | --pt_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 8 | --pt_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 9 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 10 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 11 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 12 | --ft_dataset sth_v1 --ft_mode rgb \ 13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 48 \ 14 | --ft_data_length 16 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \ 15 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/something-something-v1/i3d_pt_and_ft_multi_gpus.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 0,1,2,3 --method pt_and_ft --pt_method be --arch i3d \ 3 | --pt_batch_size 128 --pt_workers 16 --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset sth_v1 \ 6 | --pt_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 7 | --pt_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 8 | --pt_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 9 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 10 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 11 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 12 | --ft_dataset sth_v1 --ft_mode rgb \ 13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 32 \ 14 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \ 15 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/something-something-v1/r3d_ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 2 --method ft \ 3 | --arch r3d18 \ 4 | --ft_weights ../experiments/sth_v1_to_sth_v1/11-10-0006/pt/models/current.pth \ 5 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 6 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 7 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 8 | --ft_dataset sth_v1 --ft_mode rgb \ 9 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 6 \ 10 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 6 --ft_stride 1 --ft_dropout 0.5 \ 11 | --ft_print-freq 100 --ft_fixed 0 
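Note: the `--ft_weights .../pt/models/current.pth` argument above is resolved by `utils/load_weights.py` (shown later in this dump). Because the file name starts with `curr`, the checkpoint's `model` state dict is read, the `module.` prefix left over from `DataParallel` is stripped from each key, and only keys that also exist in the fine-tuning model and whose names do not contain `custom` are copied over. A minimal sketch of that loading path; the checkpoint path is illustrative and `model` is assumed to already be built (e.g. by `model_config`):

import torch

ckpt = torch.load('current.pth', map_location='cpu')      # illustrative path
pretrain = {k[7:]: v for k, v in ckpt['model'].items()}   # strip "module." prefix

model_dict = model.state_dict()
matched = {k: v for k, v in pretrain.items()
           if k in model_dict and 'custom' not in k}      # keep only matching backbone keys
model_dict.update(matched)
model.load_state_dict(model_dict)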
-------------------------------------------------------------------------------- /src/Contrastive/scripts/something-something-v1/r3d_ft_multi_gpus.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 0,1,2,3 --method ft \ 3 | --arch r3d18 \ 4 | --ft_weights ../experiments/sth_v1_to_sth_v1/11-10-0006/pt/models/current.pth \ 5 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 6 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 7 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 8 | --ft_dataset sth_v1 --ft_mode rgb \ 9 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 32 \ 10 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \ 11 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/something-something-v1/r3d_pt_and_ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 3 --method pt_and_ft --pt_method be \ 3 | --pt_batch_size 48 --pt_workers 8 --arch r3d18 --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset sth_v1 \ 6 | --pt_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 7 | --pt_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 8 | --pt_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 9 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 10 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 11 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 12 | --ft_dataset sth_v1 --ft_mode rgb \ 13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 48 \ 14 | --ft_data_length 16 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \ 15 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/something-something-v1/r3d_pt_and_ft_multi_gpus.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 0,1,2,3 --method pt_and_ft --pt_method be --arch r3d18 \ 3 | --pt_batch_size 128 --pt_workers 16 --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset sth_v1 \ 6 | --pt_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 7 | --pt_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 8 | --pt_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 9 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \ 10 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \ 11 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \ 12 | --ft_dataset sth_v1 --ft_mode rgb \ 13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 32 \ 14 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \ 15 | --ft_print-freq 100 --ft_fixed 0 
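Note: the `--pt_moco` runs above keep two encoders during pre-training: the query network trained by back-propagation and a key network (`model_ema`) that is only updated as an exponential moving average of the query weights, as implemented in `utils/moment_update.py` later in this dump. A minimal sketch of that update; the momentum value 0.999 is illustrative, since the repo passes it in explicitly:

import torch

@torch.no_grad()
def moment_update(model, model_ema, m=0.999):
    """model_ema = m * model_ema + (1 - m) * model, applied parameter-wise."""
    for p, p_ema in zip(model.parameters(), model_ema.parameters()):
        p_ema.data.mul_(m).add_(p.detach().data, alpha=1.0 - m)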
-------------------------------------------------------------------------------- /src/Contrastive/scripts/ucf101/pt_and_ft.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --gpus 0 --method pt_and_ft --pt_method be \ 3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \ 4 | --pt_nce_k 3569 --pt_softmax \ 5 | --pt_moco --pt_epochs 200 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset ucf101 \ 6 | --pt_train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \ 7 | --pt_val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \ 8 | --ft_train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \ 9 | --ft_val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \ 10 | --ft_dataset ucf101 --ft_mode rgb \ 11 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \ 12 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \ 13 | --ft_print-freq 100 --ft_fixed 0 -------------------------------------------------------------------------------- /src/Contrastive/scripts/visualization/ucf101_data_augmentation_visualize.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python augmentation_visualization.py \ 3 | --eval_indict loss --pt_loss distrubt \ 4 | --train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \ 5 | --val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \ 6 | --dataset ucf101 \ 7 | --arch i3d \ 8 | --mode rgb \ 9 | --lr 0.001 \ 10 | --lr_steps 10 20 25 30 35 40 \ 11 | --epochs 45 \ 12 | --batch_size 1 \ 13 | --data_length 16 \ 14 | --spatial_size 224 \ 15 | --workers 8 \ 16 | --stride 4 \ 17 | --dropout 0.5 \ 18 | --gpus 3 \ 19 | --logs_path ../experiments/logs/ucf101_i3d_ft \ 20 | --print-freq 100 -------------------------------------------------------------------------------- /src/Contrastive/scripts/visualization/ucf101_triplet_visualize.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python triplet_visualization.py \ 3 | --eval_indict loss --pt_loss MoCo \ 4 | --train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \ 5 | --val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \ 6 | --dataset ucf101 \ 7 | --arch i3d \ 8 | --mode rgb \ 9 | --lr 0.001 \ 10 | --lr_steps 10 20 25 30 35 40 \ 11 | --epochs 45 \ 12 | --batch_size 4 \ 13 | --data_length 16 \ 14 | --spatial_size 224 \ 15 | --workers 8 \ 16 | --stride 4 \ 17 | --dropout 0.5 \ 18 | --gpus 3 \ 19 | --logs_path ../experiments/logs/ucf101_i3d_ft \ 20 | --print-freq 100 -------------------------------------------------------------------------------- /src/Contrastive/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | input: 16 x 112 x 112 x 3, with no overlapped 3 | """ 4 | import torch.nn.parallel 5 | import torch.optim 6 | from utils.utils import * 7 | import os 8 | from data.config import data_config, augmentation_config 9 | from data.dataloader import data_loader_init 10 | from model.config import model_config 11 | from bk.option_old import args 12 | import datetime 13 | 14 | 15 | def get_action_index(list_txt='data/classInd.txt'): 16 | action_label = [] 17 | with open(list_txt) as f: 18 | content = f.readlines() 19 | content = [x.strip('\r\n') for x in content] 20 | f.close() 21 | for line in content: 22 | 
label, action = line.split(' ') 23 | action_label.append(action) 24 | return action_label 25 | 26 | 27 | def plot_matrix_test(list_txt, cfu_mat="../experiments/evaluation/ucf101/_confusion.npy", date="",prefix="flow"): 28 | classes = get_action_index(list_txt) 29 | confuse_matrix = np.load(cfu_mat) 30 | plot_confuse_matrix(confuse_matrix, classes, date=date, prefix=prefix) 31 | plt.show() 32 | 33 | 34 | def eval_video(net, video_data): 35 | ''' 36 | average 10 clips, do it later 37 | ''' 38 | i, datas, label = video_data 39 | output = None 40 | net.eval() 41 | with torch.no_grad(): 42 | for data in datas: 43 | if len(data.size()) == 4: 44 | data = data.unsqueeze(0) 45 | # print(data.size()) 46 | overlapped_clips = 1 + int((data.size(2) - args.clip_size) / 10) 47 | # print(overlapped_clips) 48 | for i in range(overlapped_clips): 49 | if i > 1: 50 | break 51 | # print(data.size()) # 3 x 47 x 112 x 112 52 | clip_data = data[:, :, 10 * i:10 * i + args.clip_size, :, :] 53 | input_var = torch.autograd.Variable(clip_data) 54 | res = net(input_var) 55 | # print(torch.exp(res), label) 56 | res = torch.exp(res).data.cpu().numpy().copy() 57 | if output is None: 58 | output = res / overlapped_clips 59 | else: 60 | output += res / overlapped_clips 61 | return output, label 62 | 63 | 64 | def main(prefix='flow_scratch'): 65 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 66 | # == dataset config== 67 | num_class, data_length, image_tmpl = data_config(args) 68 | train_transforms, test_transforms, eval_transforms = augmentation_config(args) 69 | _, eval_data_loader, _, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms, 70 | test_transforms, eval_transforms) 71 | model = model_config(args, num_class) 72 | output = [] 73 | total_num = len(eval_data_loader) 74 | for i, (data, label, index) in enumerate(eval_data_loader): 75 | proc_start_time = time.time() 76 | rst = eval_video(model, (i, data, label)) 77 | output.append(rst) 78 | cnt_time = time.time() - proc_start_time 79 | if i % 10 == 0: 80 | print('video {} done, total {}/{}, average {} sec/video'.format(i, i + 1, 81 | total_num, 82 | float(cnt_time) / (i + 1))) 83 | if i > 300: 84 | video_pred = [np.argmax(x[0]) for x in output] 85 | video_labels = [x[1] for x in output] 86 | cf = confusion_matrix(video_labels, video_pred).astype(float) 87 | cls_cnt = cf.sum(axis=1) 88 | cls_hit = np.diag(cf) 89 | cls_acc = cls_hit / cls_cnt 90 | print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) 91 | 92 | date = datetime.datetime.today().strftime('%m-%d-%H%M') 93 | # =====output: every video's num and every video's label 94 | # =====x[0]:softmax value x[1]:label 95 | if not os.path.isdir("../experiments/evaluation/{}/{}".format(args.dataset, date)): 96 | os.mkdir("../experiments/evaluation/{}/{}".format(args.dataset, date)) 97 | video_pred = [np.argmax(x[0]) for x in output] 98 | np.save("../experiments/evaluation/{}/{}/{}_video_pred.npy".format(args.dataset, date, prefix), video_pred) 99 | video_labels = [x[1] for x in output] 100 | np.save("../experiments/evaluation/{}/{}/{}_video_labels.npy".format(args.dataset, date, prefix), video_labels) 101 | cf = confusion_matrix(video_labels, video_pred).astype(float) 102 | np.save("../experiments/evaluation/{}/{}/{}_confusion.npy".format(args.dataset, date, prefix), cf) 103 | cf_name = "../experiments/evaluation/{}/{}/{}_confusion.npy".format(args.dataset, date, prefix) 104 | cls_cnt = cf.sum(axis=1) 105 | cls_hit = np.diag(cf) 106 | cls_acc = cls_hit / cls_cnt 107 | print(cls_acc) 108 | 
print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) 109 | 110 | name_list = [x.strip().split()[0] for x in open(args.val_list)] 111 | order_dict = {e: i for i, e in enumerate(sorted(name_list))} 112 | reorder_output = [None] * len(output) 113 | reorder_label = [None] * len(output) 114 | for i in range(len(output)): 115 | idx = order_dict[name_list[i]] 116 | reorder_output[idx] = output[i] 117 | reorder_label[idx] = video_labels[i] 118 | np.savez('../experiments/evaluation/' + args.dataset + '/' + date + "/" + prefix + args.mode + 'res', 119 | scores=reorder_output, labels=reorder_label) 120 | return cf_name, date 121 | 122 | 123 | def plot_confuse_matrix(matrix, classes, 124 | date="", 125 | prefix="flow", 126 | normalize=True, 127 | title=None, 128 | cmap=plt.cm.Blues 129 | ): 130 | """ 131 | :param matrix: 132 | :param classes: 133 | :param normalize: 134 | :param title: 135 | :param cmap: 136 | :return: 137 | """ 138 | if not title: 139 | if normalize: 140 | title = 'Normalized confusion matrix' 141 | else: 142 | title = 'Confusion matrix, without normalization' 143 | 144 | # Compute confusion matrix 145 | cm = matrix 146 | # Only use the labels that appear in the data 147 | if normalize: 148 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 149 | print("Normalized confusion matrix") 150 | else: 151 | print('Confusion matrix, without normalization') 152 | 153 | print(cm) 154 | 155 | fig, ax = plt.subplots() 156 | 157 | # We change the fontsize of minor ticks label 158 | ax.tick_params(axis='both', which='major', labelsize=6) 159 | ax.tick_params(axis='both', which='minor', labelsize=4) 160 | 161 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 162 | ax.figure.colorbar(im, ax=ax) 163 | # We want to show all ticks... 164 | ax.set(xticks=np.arange(cm.shape[1]), 165 | yticks=np.arange(cm.shape[0]), 166 | # ... and label them with the respective list entries 167 | xticklabels=classes, yticklabels=classes, 168 | title=title, 169 | ylabel='True label', 170 | xlabel='Predicted label') 171 | 172 | # Rotate the tick labels and set their alignment. 173 | plt.setp(ax.get_xticklabels(), rotation=60, ha="right", 174 | rotation_mode="anchor") 175 | 176 | # Loop over data dimensions and create text annotations. 177 | # fmt = '.2f' 178 | # thresh = cm.max() / 2. 
179 | # for i in range(cm.shape[0]): 180 | # for j in range(cm.shape[1]): 181 | # ax.text(j, i, format(cm[i, j], fmt), 182 | # ha="center", va="center", 183 | # color="white" if cm[i, j] > thresh else "black") 184 | # fig.tight_layout() 185 | print("date is: {}".format(date)) 186 | plt.savefig("../experiments/evaluation/hmdb51/{}/{}confuse.png".format(date, prefix)) 187 | return ax 188 | 189 | 190 | if __name__ == '__main__': 191 | # prefix = 'TCA' 192 | cf_name, date = main(args.prefix) 193 | # cf_name = "../experiments/evaluation/hmdb51/03-24-1659/confusion.npy" 194 | classList = "../datasets/lists/hmdb51/hmdb51_classInd.txt" 195 | plot_matrix_test(classList, cf_name, date, prefix=args.prefix) -------------------------------------------------------------------------------- /src/Contrastive/utils/data_process/gen_diving48_frames.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | import subprocess 5 | 6 | 7 | def process(dir_path, dst_dir_path): 8 | class_path = dir_path 9 | if not os.path.isdir(class_path): 10 | return 11 | 12 | dst_class_path = dst_dir_path 13 | if not os.path.exists(dst_class_path): 14 | os.mkdir(dst_class_path) 15 | 16 | for file_name in os.listdir(class_path): 17 | if '.mp4' not in file_name: 18 | continue 19 | name, ext = os.path.splitext(file_name) 20 | dst_directory_path = os.path.join(dst_class_path, name) 21 | 22 | video_file_path = os.path.join(class_path, file_name) 23 | try: 24 | if os.path.exists(dst_directory_path): 25 | if not os.path.exists(os.path.join(dst_directory_path, 'image_00001.jpg')): 26 | subprocess.call('rm -r \"{}\"'.format(dst_directory_path), shell=True) 27 | print('remove {}'.format(dst_directory_path)) 28 | os.mkdir(dst_directory_path) 29 | else: 30 | continue 31 | else: 32 | os.mkdir(dst_directory_path) 33 | except: 34 | print(dst_directory_path) 35 | continue 36 | cmd = 'ffmpeg -i \"{}\" -vf scale=-1:240 \"{}/image_%05d.jpg\"'.format(video_file_path, dst_directory_path) 37 | print(cmd) 38 | subprocess.call(cmd, shell=True) 39 | print('\n') 40 | 41 | 42 | if __name__=="__main__": 43 | # /data1/DataSet/Diving48/rgb 44 | # /data1/DataSet/Diving48/rgb_frames 45 | dir_path = sys.argv[1] 46 | dst_dir_path = sys.argv[2] 47 | 48 | process(dir_path, dst_dir_path) 49 | -------------------------------------------------------------------------------- /src/Contrastive/utils/data_process/gen_diving48_lists.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | out = [] 5 | count = 0 6 | front = '' 7 | with open('../datasets/lists/diving48/lists/Diving48_V2_train.json', 'r') as f: 8 | lines = json.load(f) 9 | for line in lines: 10 | # line = line.strip() 11 | count += 1 12 | new_line = front + str(line['vid_name']) + ' ' + str(line['end_frame']) + ' ' + str(line['label']) + '\n' 13 | out.append(new_line) 14 | if count % 100 == 0 and count != 0: 15 | print(count) 16 | 17 | with open('../datasets/lists/diving48/diving48_v2_train_no_front.txt', 'a') as f: 18 | f.writelines(out) -------------------------------------------------------------------------------- /src/Contrastive/utils/data_process/gen_hmdb51_dir.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | import subprocess 5 | 6 | 7 | def process(dir_path, dst_dir_path): 8 | count0 = 0 9 | for class_name in 
os.listdir(dir_path): 10 | count1 = 0 11 | class_path = os.path.join(dir_path, class_name) 12 | cmd = 'mv {}/* {}'.format(class_path, dst_dir_path) 13 | subprocess.call(cmd, shell=True) 14 | count1 += 1 15 | count0 += 1 16 | 17 | 18 | if __name__=="__main__": 19 | # /data1/DataSet/hmdb51_sta_frames 20 | # /data1/DataSet/hmdb51_sta_frames2 21 | dir_path = sys.argv[1] 22 | dst_dir_path = sys.argv[2] 23 | 24 | process(dir_path, dst_dir_path) 25 | -------------------------------------------------------------------------------- /src/Contrastive/utils/data_process/gen_hmdb51_frames.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | import subprocess 5 | 6 | 7 | def process(dir_path, dst_dir_path): 8 | count0 = 0 9 | for class_name in os.listdir(dir_path): 10 | count1 = 0 11 | class_path = os.path.join(dir_path, class_name) 12 | if not os.path.isdir(class_path): 13 | return 14 | 15 | dst_class_path = os.path.join(dst_dir_path, class_name) 16 | if not os.path.exists(dst_class_path): 17 | os.mkdir(dst_class_path) 18 | for file_name in os.listdir(class_path): 19 | if file_name[-4:] != '.avi': 20 | continue 21 | name, ext = os.path.splitext(file_name) 22 | dst_directory_path = os.path.join(dst_class_path, name) 23 | 24 | video_file_path = os.path.join(class_path, file_name) 25 | try: 26 | if os.path.exists(dst_directory_path): 27 | if not os.path.exists(os.path.join(dst_directory_path, 'image_00001.jpg')): 28 | subprocess.call('rm -r \"{}\"'.format(dst_directory_path), shell=True) 29 | print('remove {}'.format(dst_directory_path)) 30 | os.mkdir(dst_directory_path) 31 | else: 32 | continue 33 | else: 34 | os.mkdir(dst_directory_path) 35 | except: 36 | print(dst_directory_path) 37 | continue 38 | cmd = 'ffmpeg -i \"{}\" -vf scale=-1:240 \"{}/image_%05d.jpg\"'.format(video_file_path, dst_directory_path) 39 | # print(cmd) 40 | subprocess.call(cmd, shell=True) 41 | count1 += 1 42 | print("{}/{} classes: {}/{} videos finished".format(count0, len(os.listdir(dir_path)), count1, len(os.listdir(class_path)))) 43 | # print('\n') 44 | count0 += 1 45 | 46 | 47 | if __name__=="__main__": 48 | # /data1/DataSet/hmdb51_sta_new 49 | # /data1/DataSet/hmdb51_sta_frames 50 | dir_path = sys.argv[1] 51 | dst_dir_path = sys.argv[2] 52 | 53 | process(dir_path, dst_dir_path) 54 | -------------------------------------------------------------------------------- /src/Contrastive/utils/data_process/gen_hmdb51_sta_list.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | out = [] 5 | count = 0 6 | front = '' 7 | with open('../datasets/lists/hmdb51/lists/Diving48_V2_train.json', 'r') as f: 8 | lines = json.load(f) 9 | for line in lines: 10 | # line = line.strip() 11 | count += 1 12 | new_line = front + str(line['vid_name']) + ' ' + str(line['end_frame']) + ' ' + str(line['label']) + '\n' 13 | out.append(new_line) 14 | if count % 100 == 0 and count != 0: 15 | print(count) 16 | 17 | with open('../datasets/lists/diving48/diving48_v2_train_no_front.txt', 'a') as f: 18 | f.writelines(out) -------------------------------------------------------------------------------- /src/Contrastive/utils/data_process/gen_sub_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def data_split(path, new_file_path, radio): 4 | video_list = [x for x in open(path)] 5 | # if not os.path.exists(new_file_path): 6 | count 
= 0 7 | new_list = [] 8 | for line in video_list: 9 | count += 1 10 | if count % radio == 0: 11 | new_list.append(line) 12 | with open(new_file_path, 'w') as f: 13 | for item in new_list: 14 | f.write("%s" % item) 15 | 16 | if __name__ == '__main__': 17 | radio = 10 18 | path = "../datasets/lists/kinetics-400/ssd_kinetics_video_trainlist.txt" 19 | new_file_path = "../datasets/lists/kinetics-400/ssd_kinetics_video_trainlist_{}of{}.txt".format(1, radio) 20 | data_split(path, new_file_path, radio) -------------------------------------------------------------------------------- /src/Contrastive/utils/data_process/semi_data_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | def split(file, new_file, radio=0.1, cls_num=51): 6 | """ 7 | ucf101/hmdb51 8 | :param file: 9 | :param new_file: 10 | :param radio: 11 | :param cls_num: 12 | :return: 13 | """ 14 | cls_nums = np.zeros(cls_num) 15 | f = open(file, 'r') 16 | for data in f.readlines(): 17 | name, frame, cls = data.split(" ") 18 | cls = int(cls) 19 | cls_nums[cls] += 1 20 | new_lists = list() 21 | new_cls_nums = np.zeros(cls_num) 22 | f = open(file, 'r') 23 | for data in f.readlines(): 24 | name, frame, cls = data.split(" ") 25 | cls = int(cls) 26 | if new_cls_nums[cls] < radio * cls_nums[cls]: 27 | new_lists.append(data) 28 | new_cls_nums[cls] += 1 29 | with open(new_file, 'w') as f: 30 | f.writelines(new_lists) 31 | return 32 | 33 | 34 | def split_kinetics(file, new_file, radio=0.1, cls_num=400): 35 | cls_nums = np.zeros(cls_num) 36 | f = open(file, 'r') 37 | for data in f.readlines(): 38 | name, cls = data.split(" ") 39 | cls = int(cls) 40 | cls_nums[cls] += 1 41 | new_lists = list() 42 | new_cls_nums = np.zeros(cls_num) 43 | f = open(file, 'r') 44 | for data in f.readlines(): 45 | name, cls = data.split(" ") 46 | cls = int(cls) 47 | if new_cls_nums[cls] < radio * cls_nums[cls]: 48 | new_lists.append(data) 49 | new_cls_nums[cls] += 1 50 | with open(new_file, 'w') as f: 51 | f.writelines(new_lists) 52 | return 53 | 54 | 55 | if __name__ == '__main__': 56 | file = '../datasets/lists/kinetics-400/ssd_kinetics_video_trainlist.txt' 57 | radio = 0.5 58 | new_file = '../datasets/lists/kinetics-400/{}_ssd_kinetics_video_trainlist.txt'.format(radio) 59 | split_kinetics(file, new_file, radio=radio, cls_num=400) -------------------------------------------------------------------------------- /src/Contrastive/utils/data_process/unrar_hmdb51_sta.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # !/bin/bash 3 | 4 | src_path=`readlink -f $1` 5 | dst_path=`readlink -f $2` 6 | 7 | rar_files=`find $src_path -name '*.rar'` 8 | IFS=$'\n'; array=$rar_files; unset IFS 9 | for rar_file in $array; do 10 | file_path=`echo $rar_file | sed -e "s;$src_path;$dst_path;"` 11 | ext_path=${file_path%/*} 12 | if [ ! 
-d $ext_path ]; then 13 | mkdir -p $ext_path 14 | fi 15 | unrar x $rar_file $ext_path 16 | done 17 | 18 | # bash utils/data_process/unrar_hmdb51_sta.sh /data1/DataSet/hmdb51_sta /data1/DataSet/hmdb51_sta_new -------------------------------------------------------------------------------- /src/Contrastive/utils/gradient_check.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from matplotlib.lines import Line2D 4 | 5 | # 6 | # def plot_grad_flow(named_parameters): 7 | # '''Plots the gradients flowing through different layers in the net during training. 8 | # Can be used for checking for possible gradient vanishing / exploding problems. 9 | # 10 | # Usage: Plug this function in Trainer class after loss.backwards() as 11 | # "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow''' 12 | # # for k, v in named_parameters: 13 | # # print(k) 14 | # ave_grads = [] 15 | # max_grads = [] 16 | # layers = [] 17 | # for n, p in named_parameters: 18 | # if (p.requires_grad) and ("bias" not in n): 19 | # layers.append(n) 20 | # ave_grads.append(p.grad.abs().mean()) 21 | # max_grads.append(p.grad.abs().max()) 22 | # plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") 23 | # plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") 24 | # plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") 25 | # plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") 26 | # plt.xlim(left=0, right=len(ave_grads)) 27 | # plt.ylim(bottom=-0.001, top=0.02) # zoom in on the lower gradient regions 28 | # plt.xlabel("Layers") 29 | # plt.ylabel("average gradient") 30 | # plt.title("Gradient flow") 31 | # plt.grid(True) 32 | # plt.legend([Line2D([0], [0], color="c", lw=4), 33 | # Line2D([0], [0], color="b", lw=4), 34 | # Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient']) 35 | # plt.savefig('../experiments/gradients/r2p1d_gradient_benchmark.png', dpi=720, bbox_inches='tight') 36 | 37 | 38 | def plot_grad_flow(named_parameters): 39 | '''Plots the gradients flowing through different layers in the net during training. 40 | Can be used for checking for possible gradient vanishing / exploding problems. 
41 | 42 | Usage: Plug this function in Trainer class after loss.backwards() as 43 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow''' 44 | ave_grads = [] 45 | max_grads = [] 46 | layers = [] 47 | for n, p in named_parameters: 48 | if (p.requires_grad) and ("bias" not in n): 49 | layers.append(n) 50 | ave_grads.append(p.grad.abs().mean()) 51 | max_grads.append(p.grad.abs().max()) 52 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") 53 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") 54 | plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") 55 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") 56 | plt.xlim(left=0, right=len(ave_grads)) 57 | plt.ylim(bottom=-0.001, top=0.02) # zoom in on the lower gradient regions 58 | plt.xlabel("Layers") 59 | plt.ylabel("average gradient") 60 | plt.title("Gradient flow") 61 | plt.grid(True) 62 | plt.legend([Line2D([0], [0], color="c", lw=4), 63 | Line2D([0], [0], color="b", lw=4), 64 | Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient']) -------------------------------------------------------------------------------- /src/Contrastive/utils/learning_rate_adjust.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def pt_adjust_learning_rate(epoch, opt, optimizer): 5 | """Sets the learning rate to the initial LR decayed by 0.2 every steep step""" 6 | # if epoch < 2: 7 | # for param_group in optimizer.param_groups: 8 | # param_group['lr'] = 1e-7 9 | # return 0 10 | # print(epoch) 11 | # print(np.asarray(opt.pt_lr_decay_epochs)) 12 | steps = np.sum(epoch > np.asarray(opt.pt_lr_decay_epochs)) 13 | if steps > 0: 14 | new_lr = opt.pt_learning_rate * (opt.pt_lr_decay_rate ** steps) 15 | for param_group in optimizer.param_groups: 16 | param_group['lr'] = new_lr 17 | 18 | 19 | def ft_adjust_learning_rate(optimizer, intial_lr, epoch, lr_steps): 20 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 21 | decay = 0.3 ** (sum(epoch >= np.array(lr_steps))) 22 | lr = intial_lr * decay 23 | for param_group in optimizer.param_groups: 24 | param_group['lr'] = lr 25 | 26 | -------------------------------------------------------------------------------- /src/Contrastive/utils/load_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.nn.init import xavier_uniform_, constant_, zeros_, normal_, kaiming_uniform_, kaiming_normal_ 4 | import torch.nn as nn 5 | import math 6 | 7 | 8 | def weight_transform(model_dict, pretrain_dict): 9 | ''' 10 | 11 | :return: 12 | ''' 13 | count = 0 14 | # for k, v in pretrain_dict.items(): 15 | # count += 1 16 | # print(k) 17 | # if count > 100: 18 | # break 19 | # for k, v in model_dict.items(): 20 | # print(k) 21 | # count += 1 22 | # if count > 20: 23 | # break 24 | 25 | weight_dict = {k:v for k, v in pretrain_dict.items() if k in model_dict and 'custom' not in k} # and 'custom' not in k 26 | for k, v in weight_dict.items(): 27 | print("load: {}".format(k)) 28 | # print(weight_dict) 29 | model_dict.update(weight_dict) 30 | return model_dict 31 | 32 | 33 | def weights_init(model): 34 | """ Initializes the weights of the CNN model using the Xavier 35 | initialization. 
36 | """ 37 | # for m in model.modules(): 38 | # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv1d): 39 | # xavier_uniform_(m.weight, gain=math.sqrt(2.0)) 40 | # if m.bias: 41 | # constant_(m.bias, 0.1) 42 | # elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm3d): 43 | # normal_(m.weight, 1.0, 0.02) 44 | # zeros_(m.bias) 45 | if isinstance(model, nn.Conv2d) or isinstance(model, nn.Conv3d) or isinstance(model, nn.Conv1d): 46 | xavier_uniform_(model.weight.data, gain=math.sqrt(2.0)) 47 | # if model.bias: 48 | # constant_(model.bias.data, 0.1) 49 | elif isinstance(model, nn.BatchNorm2d) or isinstance(model, nn.BatchNorm1d) or isinstance(model, nn.BatchNorm3d): 50 | normal_(model.weight.data, 1.0, 0.02) 51 | 52 | 53 | def pt_load_weight(args, model, model_ema, optimizer, contrast): 54 | # random initialization 55 | model.apply(weights_init) 56 | model_ema.apply(weights_init) 57 | if args.pt_resume: 58 | if os.path.isfile(args.pt_resume): 59 | print("=> loading checkpoint '{}'".format(args.resume)) 60 | checkpoint = torch.load(args.resume, map_location='cpu') 61 | # checkpoint = torch.load(args.resume) 62 | args.start_epoch = checkpoint['epoch'] + 1 63 | model.load_state_dict(checkpoint['model']) 64 | optimizer.load_state_dict(checkpoint['optimizer']) 65 | contrast.load_state_dict(checkpoint['contrast']) 66 | model_ema.load_state_dict(checkpoint['model_ema']) 67 | print("=> loaded successfully '{}' (epoch {})" 68 | .format(args.resume, checkpoint['epoch'])) 69 | del checkpoint 70 | torch.cuda.empty_cache() 71 | else: 72 | print("=> no checkpoint found at '{}'".format(args.resume)) 73 | 74 | return model, model_ema 75 | 76 | 77 | def ft_load_weight(args, model): 78 | if args.ft_weights == "": 79 | print("no pretrained model available. 
Train from Scratch.....................") 80 | model.apply(weights_init) 81 | # weights_init(model) 82 | else: 83 | # weights_init(model) 84 | model.apply(weights_init) 85 | checkpoint = torch.load(args.ft_weights) 86 | try: 87 | print("model epoch {} lowese val: {}".format(checkpoint['epoch'], checkpoint['lowest_val'])) 88 | except KeyError as e: 89 | try: 90 | print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1'])) 91 | except KeyError: 92 | print("not train from this code!") 93 | try: 94 | # print("????") 95 | print(args.ft_weights.split('/')[-1][:9]) 96 | if args.ft_weights.split('/')[-1][:11] == 'mutual_loss': 97 | pretrain_dict = {('.'.join(k.split('.')[2:]))[2:]: v for k, v in list(checkpoint['state_dict'].items())} 98 | elif args.ft_weights.split('/')[-1][:9] == 'byol_ckpt': 99 | # print("???") 100 | pretrain_dict = {k[26:]: v for k, v in list(checkpoint['model'].items())} 101 | # pretrain_dict = {k[7:]: v for k, v in list(checkpoint['model_ema'].items())} 102 | elif args.ft_weights.split('/')[-1][:4] in ['ckpt', 'curr', 'ucf1']: 103 | # print("???") 104 | pretrain_dict = {k[7:]: v for k, v in list(checkpoint['model'].items())} 105 | # pretrain_dict = {k: v for k, v in list(checkpoint['model'].items())} 106 | elif args.ft_weights.split('/')[-1][:3] == 'i3d': 107 | pretrain_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint['state_dict'].items())} 108 | else: 109 | pretrain_dict = {'.'.join(k.split('.')[2:]): v for k, v in list(checkpoint['state_dict'].items())} 110 | except KeyError: 111 | # pretrain_dict = checkpoint['model'] 112 | pretrain_dict = checkpoint 113 | # pretrain_dict = {k[26:]: v for k, v in list(checkpoint['state_dict'].items())} 114 | model_dict = model.state_dict() 115 | model_dict = weight_transform(model_dict, pretrain_dict) 116 | model.load_state_dict(model_dict) 117 | if args.ft_resume: 118 | if os.path.isfile(args.ft_resume): 119 | print(("=> loading checkpoints '{}'".format(args.ft_resume))) 120 | checkpoint = torch.load(args.ft_resume) 121 | args.start_epoch = checkpoint['epoch'] 122 | best_prec1 = checkpoint['best_prec1'] 123 | model.load_state_dict(checkpoint['state_dict']) 124 | print(("=> loaded checkpoints '{}' (epoch {}) best_prec1 {}" 125 | .format(args.evaluate, checkpoint['epoch'], best_prec1))) 126 | else: 127 | print(("=> no checkpoints found at '{}'".format(args.ft_resume))) 128 | return model -------------------------------------------------------------------------------- /src/Contrastive/utils/moment_update.py: -------------------------------------------------------------------------------- 1 | def update_ema_variables(ema_model, model, alpha=0.9): 2 | # Use the true average until the exponential average is more correct 3 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 4 | ema_param.data.mul_(alpha).add_(1 - alpha, param.data) 5 | 6 | 7 | def moment_update(model, model_ema, m): 8 | """ model_ema = m * model_ema + (1 - m) model """ 9 | for p1, p2 in zip(model.parameters(), model_ema.parameters()): 10 | p2.data.mul_(m).add_(1 - m, p1.detach().data) 11 | # p2.data.mul_(m).add_(1 - m, p1.data) -------------------------------------------------------------------------------- /src/Contrastive/utils/recoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | import shutil 6 | import datetime 7 | import shutil 8 | 9 | 10 | class Record: 11 | def 
__init__(self, args): 12 | super(Record, self).__init__() 13 | if args.method == 'pt_and_ft': 14 | self.origin_path = '../experiments/' + args.pt_dataset + '_to_' + args.ft_dataset 15 | elif args.method == 'pt': 16 | self.origin_path = '../experiments/pt_' + args.pt_dataset 17 | else: 18 | self.origin_path = '../experiments/ft_' + args.ft_dataset 19 | if not os.path.exists(self.origin_path): 20 | os.mkdir(self.origin_path) 21 | self.path = self.origin_path + '/' + args.date 22 | if not os.path.exists(self.path): 23 | os.mkdir(self.path) 24 | if args.method in ['pt', 'pt_and_ft']: 25 | self.pt_path = self.path + '/pt' 26 | self.pt_model_path = '{}/models'.format(self.pt_path) 27 | if not os.path.exists(self.pt_path): 28 | os.mkdir(self.pt_path) 29 | if not os.path.exists(self.pt_model_path): 30 | os.mkdir(self.pt_model_path) 31 | if args.method in ['ft', 'pt_and_ft']: 32 | self.ft_path = self.path + '/ft' 33 | self.ft_model_path = '{}/models'.format(self.ft_path) 34 | if not os.path.exists(self.ft_path): 35 | os.mkdir(self.ft_path) 36 | if not os.path.exists(self.ft_model_path): 37 | os.mkdir(self.ft_model_path) 38 | self.args = args 39 | # pretrain init 40 | self.pt_init() 41 | self.pt_train_loss_list = list() 42 | self.pt_checkpoint = '' 43 | # finetune init 44 | self.ft_init() 45 | self.ft_train_acc_list = list() 46 | self.ft_val_acc_list = list() 47 | self.ft_train_loss_list = list() 48 | self.ft_val_loss_list = list() 49 | self.front = self.args.method 50 | print(self.args) 51 | self.record_txt = os.path.join(self.path, self.front + '_logs.txt') 52 | self.record_init(args, 'w') 53 | self.src_init() 54 | self.filename = '' 55 | self.best_name = '' 56 | 57 | def pt_init(self): 58 | return 59 | 60 | def ft_init(self): 61 | return 62 | 63 | def src_init(self): 64 | if not os.path.exists(self.path + '/src_record'): 65 | shutil.copytree('../src', self.path + '/src_record') 66 | 67 | def record_init(self, args, open_type): 68 | with open(self.record_txt, open_type) as f: 69 | for arg in vars(args): 70 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 71 | f.write('\n') 72 | 73 | def record_message(self, open_type, message): 74 | with open(self.record_txt, open_type) as f: 75 | f.write(message + '\n\n') 76 | 77 | def record_ft_train(self, loss, acc=0): 78 | self.ft_train_acc_list.append(acc) 79 | self.ft_train_loss_list.append(loss) 80 | 81 | def record_ft_val(self, loss, acc=0): 82 | self.ft_val_acc_list.append(acc) 83 | self.ft_val_loss_list.append(loss) 84 | 85 | def record_pt_train(self, loss): 86 | self.pt_train_loss_list.append(loss) 87 | 88 | def plot_figure(self, plot_list, name='_performance'): 89 | epoch = len(plot_list[0][0]) 90 | axis = np.linspace(1, epoch, epoch) 91 | fig = plt.figure() 92 | plt.title(self.args.arch + '_' + self.args.status + name) 93 | for i in range(len(plot_list)): 94 | plt.plot(axis, plot_list[i][0], label=plot_list[i][1]) 95 | plt.legend() 96 | plt.xlabel('Epochs') 97 | plt.ylabel('%') 98 | plt.grid(True) 99 | plt.savefig(os.path.join(self.path, '{}.pdf'.format(self.args.status + name))) 100 | plt.close(fig) 101 | 102 | def save_ft_model(self, model, is_best=False): 103 | self.save_ft_checkpoint(self.args, model, is_best) 104 | plot_list = list() 105 | plot_list.append([self.ft_train_acc_list, 'train_acc']) 106 | plot_list.append([self.ft_val_acc_list, 'val_acc']) 107 | plot_list.append([self.ft_train_loss_list, 'train_loss']) 108 | plot_list.append([self.ft_val_loss_list, 'val_loss']) 109 | self.plot_figure(plot_list) 110 | 111 | def save_pt_model(self, 
args, state, epoch): 112 | self.save_pt_checkpoint(args, state, epoch) 113 | plot_list = list() 114 | plot_list.append([self.pt_train_loss_list, 'train_loss']) 115 | self.plot_figure(plot_list) 116 | print('==> Saving...') 117 | 118 | def save_pt_checkpoint(self, args, state, epoch): 119 | save_file = os.path.join(self.pt_model_path, 'current.pth') 120 | self.pt_checkpoint = save_file 121 | torch.save(state, save_file) 122 | if epoch % args.pt_save_freq == 0: 123 | save_file = os.path.join(self.pt_model_path, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 124 | torch.save(state, save_file) 125 | # help release GPU memory 126 | del state 127 | torch.cuda.empty_cache() 128 | 129 | def save_ft_checkpoint(self, args, state, is_best): 130 | self.filename = self.ft_path + '/' + args.ft_mode + '_model_latest.pth.tar' 131 | torch.save(state, self.filename) 132 | if is_best: 133 | self.best_name = self.ft_path + '/' + args.ft_mode + '_model_best.pth.tar' 134 | shutil.copyfile(self.filename, self.best_name) 135 | -------------------------------------------------------------------------------- /src/Contrastive/utils/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import torch 4 | import time 5 | import shutil 6 | import matplotlib.pyplot as plt 7 | # from sklearn.metrics import confusion_matrix 8 | 9 | class Timer(object): 10 | """ 11 | docstring for Timer 12 | """ 13 | def __init__(self): 14 | super(Timer, self).__init__() 15 | self.total_time = 0.0 16 | self.calls = 0 17 | self.start_time = 0.0 18 | self.diff = 0.0 19 | self.average_time = 0.0 20 | 21 | def tic(self): 22 | self.start_time = time.time() 23 | 24 | def toc(self, average = False): 25 | self.diff = time.time() - self.start_time 26 | self.calls += 1 27 | self.total_time += self.diff 28 | self.average_time = self.total_time / self.calls 29 | if average: 30 | return self.average_time 31 | else: 32 | return self.diff 33 | 34 | def format(self, time): 35 | m,s = divmod(time, 60) 36 | h,m = divmod(m, 60) 37 | d,h = divmod(h, 24) 38 | return ("{}d:{}h:{}m:{}s".format(int(d), int(h), int(m), int(s))) 39 | 40 | def end_time(self, extra_time): 41 | """ 42 | calculate the end time for training, show local time 43 | """ 44 | localtime= time.asctime(time.localtime(time.time() + extra_time)) 45 | return localtime 46 | 47 | 48 | class AverageMeter(object): 49 | """Computes and stores the average and current value""" 50 | 51 | def __init__(self): 52 | self.reset() 53 | 54 | def reset(self): 55 | self.val = 0 56 | self.avg = 0 57 | self.sum = 0 58 | self.count = 0 59 | 60 | def update(self, val, n=1): 61 | self.val = val 62 | self.sum += val * n 63 | self.count += n 64 | self.avg = self.sum / self.count 65 | 66 | 67 | class Logger(object): 68 | 69 | def __init__(self, path, header): 70 | self.log_file = open(path, 'w') 71 | self.logger = csv.writer(self.log_file, delimiter='\t') 72 | 73 | self.logger.writerow(header) 74 | self.header = header 75 | 76 | def __del(self): 77 | self.log_file.close() 78 | 79 | def log(self, values): 80 | write_values = [] 81 | for col in self.header: 82 | assert col in values 83 | write_values.append(values[col]) 84 | 85 | self.logger.writerow(write_values) 86 | self.log_file.flush() 87 | 88 | 89 | def load_value_file(file_path): 90 | with open(file_path, 'r') as input_file: 91 | value = float(input_file.read().rstrip('\n\r')) 92 | 93 | return value 94 | 95 | 96 | def calculate_accuracy(outputs, targets): 97 | batch_size = targets.size(0) 98 
| 99 | _, pred = outputs.topk(1, 1, True) 100 | pred = pred.t() 101 | correct = pred.eq(targets.view(1, -1)) 102 | n_correct_elems = correct.float().sum().data[0] 103 | 104 | return n_correct_elems / batch_size 105 | 106 | 107 | class TrainingHelper(object): 108 | def __init__(self, image): 109 | self.image = image 110 | def congratulation(self): 111 | """ 112 | if finish training success, print congratulation information 113 | """ 114 | for i in range(40): 115 | print('*')*i 116 | print('finish training') 117 | 118 | 119 | def submission_file(ids, outputs, filename): 120 | """ write list of ids and outputs to filename""" 121 | with open(filename, 'w') as f: 122 | for vid, output in zip(ids, outputs): 123 | scores = ['{:g}'.format(x) 124 | for x in output] 125 | f.write('{} {}\n'.format(vid, ' '.join(scores))) 126 | 127 | 128 | def accuracy(output, target, topk=(1,)): 129 | """ 130 | Computes the precision@k for the specified values of k 131 | output: 16(batch_size) x 101 132 | target: 16 x 1 133 | """ 134 | maxk = max(topk) 135 | batch_size = target.size(0) 136 | 137 | _, pred = output.topk(maxk, 1, True, True) 138 | pred = pred.t() 139 | correct = pred.eq(target.view(1, -1).expand_as(pred)) # 5 x 16 140 | #print(correct) 141 | 142 | res = [] 143 | for k in topk: 144 | correct_k = correct[:k].view(-1).float().sum(0) 145 | res.append(correct_k.mul_(100.0 / batch_size)) 146 | return res 147 | 148 | 149 | def accuracy_mixup(output, targets, target_a, target_b, lam, topk=(1,)): 150 | """Computes the precision@k for the specified values of k""" 151 | maxk = max(topk) 152 | batch_size = targets.size(0) 153 | 154 | _, pred = output.topk(maxk, 1, True, True) 155 | pred = pred.t() #5 x 20 156 | correct_1 = pred.eq(target_a.data.view(1, -1).expand_as(pred)) 157 | correct_2 = pred.eq(target_b.data.view(1, -1).expand_as(pred)) 158 | 159 | res = [] 160 | for k in topk: 161 | correct_k = lam * correct_1[:k].view(-1).float().sum(0) + (1-lam) * correct_2[:k].view(-1).float().sum(0) 162 | res.append(correct_k.mul_(100.0 / batch_size)) 163 | return res 164 | -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/augmentation_visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | from data.config import data_config, augmentation_config 7 | from data.dataloader import data_loader_init 8 | from model.config import model_config 9 | from augment.gen_positive import GenPositive 10 | from bk.option_old import args 11 | import numpy as np 12 | 13 | lowest_val_loss = float('inf') 14 | best_prec1 = 0 15 | torch.manual_seed(1) 16 | 17 | 18 | def test(train_loader, model, suffix='each_rotation_10'): 19 | model.eval() 20 | for i, (inputs, _, index) in enumerate(train_loader): 21 | dir_path = "../experiments/visualization/augmentation/{}_{}".format(suffix, index[0]) 22 | # if not os.path.exists(dir_path): 23 | # os.makedirs(dir_path) 24 | augmentation_video = inputs[0].permute(1, 2, 3, 0).cpu().numpy() 25 | np.save("{}.npy".format(dir_path), augmentation_video) 26 | # print(augmentation_video.shape) 27 | # path = "{}/{}.avi".format(dir_path, str(index[0].cpu().numpy())) 28 | # print(path) 29 | # save_video(augmentation_video, path) 30 | print("{}/{} finished".format(i, len(train_loader))) 31 | # output = model(anchor) 32 | return True 33 | 34 | 35 | def main(): 36 | # = 37 | 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning 38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 39 | torch.manual_seed(1) 40 | cudnn.benchmark = True 41 | # == dataset config== 42 | num_class, data_length, image_tmpl = data_config(args) 43 | train_transforms, test_transforms, eval_transforms = augmentation_config(args) 44 | train_data_loader, val_data_loader, _, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms) 45 | # == model config== 46 | model = model_config(args, num_class) 47 | pos_aug = GenPositive() 48 | # == train and eval== 49 | test(val_data_loader, model) 50 | return 1 51 | 52 | 53 | if __name__ == '__main__': 54 | main() -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/color_palettes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import seaborn as sns 3 | import matplotlib.patches as mpatches 4 | import matplotlib.pyplot as plt 5 | # sns.set() 6 | 7 | # palette = np.array(sns.color_palette("hls", 10)) 8 | patch = mpatches.Patch(color='red', label='red') 9 | plt.legend(handles=[patch]) 10 | plt.savefig("../../../experiments/visualization/color_paletters.png") -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def get_action_index(list_txt='data/classInd.txt'): 5 | action_label = [] 6 | with open(list_txt) as f: 7 | content = f.readlines() 8 | content = [x.strip('\r\n') for x in content] 9 | f.close() 10 | for line in content: 11 | label, action = line.split(' ') 12 | action_label.append(action) 13 | return action_label 14 | 15 | def plot_matrix_test(list_txt, cfu_mat="../experiments/evaluation/ucf101/_confusion.npy", s_path="../experiments/evaluation/hmdb51/03-24-1659/confuse.png"): 16 | classes = get_action_index(list_txt) 17 | confuse_matrix = np.load(cfu_mat) 18 | plot_confuse_matrix(confuse_matrix, classes, s_path) 19 | plt.show() 20 | 21 | def plot_confuse_matrix(matrix, classes, 22 | s_path = "../experiments/evaluation/hmdb51/03-24-1659/confuse.png", 23 | normalize=True, 24 | title=None, 25 | cmap=plt.cm.Blues 26 | ): 27 | """ 28 | :param matrix: 29 | :param classes: 30 | :param s_path: 31 | :param normalize: 32 | :param title: 33 | :param cmap: 34 | :return: 35 | """ 36 | if not title: 37 | if normalize: 38 | title = 'Normalized confusion matrix' 39 | else: 40 | title = 'Confusion matrix, without normalization' 41 | 42 | # Compute confusion matrix 43 | cm = matrix 44 | # Only use the labels that appear in the data 45 | if normalize: 46 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 47 | print("Normalized confusion matrix") 48 | else: 49 | print('Confusion matrix, without normalization') 50 | 51 | print(cm) 52 | 53 | fig, ax = plt.subplots() 54 | 55 | # We change the fontsize of minor ticks label 56 | ax.tick_params(axis='both', which='major', labelsize=3) 57 | ax.tick_params(axis='both', which='minor', labelsize=1) 58 | 59 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 60 | ax.figure.colorbar(im, ax=ax) 61 | # We want to show all ticks... 62 | ax.set(xticks=np.arange(cm.shape[1]), 63 | yticks=np.arange(cm.shape[0]), 64 | # ... 
and label them with the respective list entries 65 | xticklabels=classes, yticklabels=classes, 66 | title=title, 67 | ylabel='True label', 68 | xlabel='Predicted label') 69 | 70 | # Rotate the tick labels and set their alignment. 71 | plt.setp(ax.get_xticklabels(), rotation=60, ha="right", 72 | rotation_mode="anchor") 73 | 74 | # Loop over data dimensions and create text annotations. 75 | fmt = '.2f' 76 | # thresh = cm.max() / 2. 77 | thresh = 0.2 78 | for i in range(cm.shape[0]): 79 | for j in range(cm.shape[1]): 80 | if cm[i, j] > thresh and i != j: 81 | ax.text(j, i, "({},{})".format(classes[i], classes[j]) + format(cm[i, j], fmt), 82 | ha="center", va="center",fontsize='smaller', 83 | color="black") 84 | fig.tight_layout() 85 | plt.savefig(s_path, dpi=1024) 86 | return ax 87 | 88 | if __name__ == '__main__': 89 | # cf_name = "../experiments/evaluation/ucf101/04-12-1132/confusion.npy" 90 | # classList = "../datasets/lists/ucf101/classInd.txt" 91 | # save_path = "../experiments/evaluation/ucf101/04-12-1132/78.8_confuse.png" 92 | # plot_matrix_test(classList, cf_name, save_path) 93 | cf_name = "../experiments/evaluation/ucf101/04-14-1141/confusion.npy" 94 | classList = "../datasets/lists/ucf101/classInd.txt" 95 | save_path = "../experiments/evaluation/ucf101/04-14-1141/63.3_confuse.png" 96 | plot_matrix_test(classList, cf_name, save_path) -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/mixup_visualization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | 4 | 5 | def mixup(im1, im2, prob): 6 | img1 = cv2.imread(im1) 7 | img2 = cv2.imread(im2) 8 | img = img1 * (1-prob) + img2 * prob 9 | return img 10 | 11 | 12 | name = '6arrowswithin30seconds_shoot_bow_f_nm_np1_fr_med_1/' 13 | video1 = "/data/jinhuama/DataSet/hmdb51/" + name 14 | index = 15 15 | new_dir = "../experiments/visualizations/mixup/" 16 | prob = 0.3 17 | 18 | if not os.path.exists(new_dir + name): 19 | os.mkdir(new_dir + name) 20 | for image in os.listdir(video1): 21 | image1 = video1 + image 22 | image2 = video1 + "img_{:05d}.jpg".format(index) 23 | new_img = mixup(image1, image2, prob) 24 | cv2.imwrite(new_dir + name + image, new_img) -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/multi100.txt: -------------------------------------------------------------------------------- 1 | 33.3 2 | 43.0 3 | 6.0 4 | -16.0 5 | 12.0 6 | 15.0 7 | 25.0 8 | -3.0 9 | 6.0 10 | 23.0 11 | 12.0 12 | -10.6 13 | 35.0 14 | 22.0 15 | 3.0 16 | -16.0 17 | 23.0 18 | 12.0 19 | 20.0 20 | 14.0 21 | 13.0 22 | 33.0 23 | -6.0 24 | 13.3 25 | 10.0 26 | -10.0 27 | 44.0 28 | -24.0 29 | 6.0 30 | -13.3 31 | 15.0 32 | 20.0 33 | 26.6 34 | 20.0 35 | 36.6 36 | 26.6 37 | -13.3 38 | -10.0 39 | -10.0 40 | 19.0 41 | 10.0 42 | 16.6 43 | 10.0 44 | 6.0 45 | 46.3 46 | 36.6 47 | 52.4 48 | 6.6 49 | 20.0 50 | 46.6 51 | 6.3 52 | -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/pearson_calculate.txt: -------------------------------------------------------------------------------- 1 | 0.22551442576666667 0.333 2 | 0.10818058883333334 0.43 3 | 0.5162171408333334 0.06 4 | 0.4882573183 -0.16 5 | 0.0708494733 0.12 6 | 0.13339997128533332 0.15 7 | 0.0516073793 0.25 8 | 0.16686041497466667 -0.03 9 | 0.2689299939 0.06 10 | 0.16904905083333333 0.23 11 | 0.19995855823333333 0.12 12 | 0.2766080123 -0.106 13 | 
0.10459493276666666 0.35 14 | 0.13576242733333332 0.22 15 | 0.2628492088 0.03 16 | 0.6472266601333333 -0.16 17 | 0.0820547952 0.23 18 | 0.0504232759 0.12 19 | 0.29175092593333335 0.2 20 | 0.0324064948 0.14 21 | 0.12532078933333332 0.13 22 | 0.0307350102 0.33 23 | 0.5459736722666667 -0.06 24 | 0.11691634886666666 0.133 25 | 0.0928835198 0.1 26 | 0.5027140373666666 -0.1 27 | 0.400281556 0.44 28 | 0.22275721296666667 -0.24 29 | 0.24755843986666665 0.06 30 | 0.30486327076666664 -0.133 31 | 0.12616917856666665 0.15 32 | 0.34176297733333333 0.2 33 | 0.091833597 0.266 34 | 0.13851137033333333 0.2 35 | 0.09020716253333333 0.366 36 | 0.09902224873333335 0.266 37 | 0.17993909336666666 -0.133 38 | 0.0476045226 -0.1 39 | 0.8139881392666667 -0.1 40 | 0.18950731996666664 0.19 41 | 0.1327474982 0.1 42 | 0.1923174741 0.166 43 | 0.06810675673333333 0.1 44 | 0.1636814759 0.06 45 | 0.07896679256666667 0.463 46 | 0.16391625416666666 0.366 47 | 0.11104895753333333 0.524 48 | 0.0198147214 0.066 49 | 0.0264955944 0.2 50 | 0.0226321177 0.466 51 | 0.0656435398 0.063 -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/pearson_correlation.py: -------------------------------------------------------------------------------- 1 | # from scipy import stats 2 | # import os 3 | # 4 | # x = [] 5 | # y = [] 6 | # with open('pearson_calculate.txt', 'r') as f: 7 | # lines = f.readlines() 8 | # for line in lines: 9 | # if len(line.strip().split(' '))==2: 10 | # x.append(float(line.strip().split(' ')[0])) 11 | # y.append(float(line.strip().split(' ')[1])) 12 | # else: 13 | # x.append(float(line.strip().split('\t')[0])) 14 | # y.append(float(line.strip().split('\t')[1])) 15 | # print(len(x), len(y)) 16 | # print(stats.pearsonr(x, y)) 17 | # 18 | 19 | from scipy import stats 20 | import os 21 | 22 | x = [] 23 | y = [] 24 | g = open('multi100.txt', 'w') 25 | with open('pearson_calculate.txt', 'r') as f: 26 | lines = f.readlines() 27 | for line in lines: 28 | if len(line.strip().split(' ')) is 2: 29 | g.write(str(float(line.strip().split(' ')[1])*100) + '\n') 30 | else: 31 | g.write( 32 | str(float(line.strip().split('\t')[1]) * 100) + '\n') -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/plot_class_distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pickle 4 | 5 | 6 | def plot_histgorm(l, num=51): 7 | x = np.arange(1, num+1) 8 | fig, axs = plt.subplots(1, 1, sharex=True) 9 | axs[0].hist(l, bins=num) 10 | plt.savefig('visualization/histgorm.png') 11 | 12 | 13 | def get_action_index(cls_list='../../../datasets/lists/hmdb51/hmdb51_classInd.txt'): 14 | action_label = [] 15 | with open(cls_list) as f: 16 | content = f.readlines() 17 | content = [x.strip('\r\n') for x in content] 18 | f.close() 19 | for line in content: 20 | label, action = line.split(' ') 21 | action_label.append(action) 22 | return action_label 23 | 24 | 25 | def analyse_record(scratch_label, scratch_predict, ssl_label, ssl_predict, dataset='hmdb51'): 26 | if dataset == 'hmdb51': 27 | class_num = 51 28 | elif dataset == 'ucf101': 29 | class_num = 101 30 | else: 31 | Exception("not implement dataset!") 32 | scratch_wrong = 0 33 | scratch_clsses = np.zeros(class_num) 34 | scratch_real_classes = np.zeros(class_num) 35 | for i in range(len(scratch_label)): 36 | max_index = scratch_predict[i] 37 | label = scratch_label[i] 38 | 
scratch_real_classes[label] += 1 39 | if max_index != label: 40 | scratch_wrong += 1 41 | else: 42 | scratch_clsses[label] += 1 43 | 44 | #=============================our self-supervised learning ======================= 45 | self_supervised_wrong = 0 46 | self_supervised_classes = np.zeros(class_num) 47 | self_supervised_real_classes = np.zeros(class_num) 48 | for i in range(len(ssl_label)): 49 | max_index = ssl_predict[i] 50 | label = ssl_label[i] 51 | self_supervised_real_classes[label] += 1 52 | if max_index != label: 53 | self_supervised_wrong += 1 54 | else: 55 | self_supervised_classes[label] += 1 56 | print("scratch Top-1 is: {}".format(np.mean(scratch_clsses/(scratch_real_classes+1)))) 57 | print("SSL Top-1 is: {}".format(np.mean(self_supervised_classes/self_supervised_real_classes))) 58 | arr = (self_supervised_classes/self_supervised_real_classes - scratch_clsses/(scratch_real_classes+1)) 59 | topk = arr.argsort()[-5:][::-1] 60 | mink = arr.argsort()[:5][::1] 61 | if dataset == 'hmdb51': 62 | classes = get_action_index() 63 | elif dataset == 'ucf101': 64 | classes = get_action_index(cls_list='../../../datasets/lists/ucf101/classInd.txt') 65 | else: 66 | Exception("not implement dataset!") 67 | print("five largest ======>>") 68 | for i in range(5): 69 | print(classes[topk[i]], arr[topk[i]]) 70 | print("five minium ======>>") 71 | for i in range(5): 72 | print(classes[mink[i]], arr[mink[i]]) 73 | # print(self_supervised_real_classes/30, self_supervised_real_classes/30) 74 | # print((scratch_best_clsses - scratch_clsses) / 30) 75 | # print(scratch__wrong) 76 | # print(scratch__clsses/30, real_classes/30) 77 | # plot_histgorm(l_clsses, m_clsses, s_clsses) 78 | # rows = [] 79 | # classes = get_action_index() 80 | # for i in range(len(l_clsses)): 81 | # rows.append((l_clsses[i]/30, classes[i], 82 | # l_clsses[i] / 30 )) 83 | # header = ['l', 'm', 's', 'class', 'avg'] 84 | # with open('visualization/store.csv', 'w') as f: 85 | # f_csv = csv.writer(f) 86 | # f_csv.writerow(header) 87 | # f_csv.writerows(rows) 88 | return True 89 | 90 | 91 | if __name__ == '__main__': 92 | # #====================================HMDB51============================================ 93 | # #51.3 94 | # ssl_label = np.load("../../../experiments/evaluation/hmdb51/03-24-1659/video_labels.npy") 95 | # ssl_predict = np.load("../../../experiments/evaluation/hmdb51/03-24-1659/video_pred.npy") 96 | # #31.9 97 | # scratch_label = np.load("../../../experiments/evaluation/hmdb51/04-11-0926/video_labels.npy") 98 | # scratch_predict = np.load("../../../experiments/evaluation/hmdb51/04-11-0926/video_pred.npy") 99 | # analyse_record(scratch_label, scratch_predict, ssl_label, ssl_predict) 100 | # ====================================UCF101============================================ 101 | # 78.83 102 | ssl_label = np.load("../../../experiments/evaluation/ucf101/04-12-1132/video_labels.npy") 103 | ssl_predict = np.load("../../../experiments/evaluation/ucf101/04-12-1132/video_pred.npy") 104 | # 63.3? 
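# The run loaded below is the train-from-scratch baseline (cf. the '63.3' confusion-matrix plot
# for run 04-14-1141); analyse_record() compares it with the SSL-pretrained run loaded above
# (04-12-1132, 78.83 top-1) and prints each model's mean per-class accuracy together with the
# five classes that gain the most and the five that gain the least from self-supervised pretraining.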
105 | scratch_label = np.load("../../../experiments/evaluation/ucf101/04-14-1141/video_labels.npy") 106 | scratch_predict = np.load("../../../experiments/evaluation/ucf101/04-14-1141/video_pred.npy") 107 | analyse_record(scratch_label, scratch_predict, ssl_label, ssl_predict, dataset='ucf101') -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/single_video_visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | from utils.utils import Timer 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | from data.config import data_config, augmentation_config 8 | from data.dataloader import data_loader_init 9 | from model.config import model_config 10 | from loss.config import optim_init 11 | from augment.gen_positive import GenPositive 12 | from bk.option_old import args 13 | import numpy as np 14 | import cv2 15 | import skvideo.io 16 | 17 | 18 | def save_one_video(video, idx=1, title='origin'): 19 | video = video.squeeze(0) 20 | video_tensor = torch.tensor(video.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3 21 | path = "../experiments/gen_videos/{}_{}.jpg".format(title, idx) 22 | img = np.zeros((video_tensor.size(1), video_tensor.size(2), video_tensor.size(3)), np.uint8) 23 | img.fill(90) 24 | cv2.putText(img, title, (10, 50), cv2.FONT_HERSHEY_SCRIPT_COMPLEX, 1, (0, 255, 255), 1, cv2.LINE_AA) 25 | output = img 26 | for i in range(video_tensor.shape[0]): 27 | if i % 3 == 0: 28 | output = np.concatenate((output, np.uint8(video_tensor[i] * 255)), axis=1) 29 | cv2.imwrite(path, output) 30 | print("index: {} finished".format(idx)) 31 | return output 32 | 33 | 34 | def rgb_flow(prvs, next): 35 | hsv = np.zeros_like(prvs) 36 | hsv[..., 1] = 255 37 | prvs = cv2.cvtColor(prvs, cv2.COLOR_RGB2GRAY) 38 | next = cv2.cvtColor(next, cv2.COLOR_RGB2GRAY) 39 | flow = cv2.calcOpticalFlowFarneback(prvs, next, None, 0.5, 3, 15, 3, 5, 1.2, 0) 40 | mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) 41 | hsv[..., 0] = ang * 180 / np.pi / 2 42 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 43 | flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) 44 | return flow_rgb 45 | 46 | 47 | def write_video(path, output): 48 | fps = 30 49 | writer = skvideo.io.FFmpegWriter(path, 50 | outputdict={'-b': '300000000', '-r': str(fps)}) 51 | # print(len(output)) 52 | for frame in output: 53 | frame = np.array(frame) 54 | writer.writeFrame(frame) 55 | writer.close() 56 | 57 | 58 | def save_as_video(a, p, n, tn, idx=1, title='triples'): 59 | a = a.squeeze(0) 60 | p = p.squeeze(0) 61 | n = n.squeeze(0) 62 | tn = tn.squeeze(0) 63 | path = "../experiments/gen_videos/{}_{}.mp4".format(title, idx) 64 | a_p_path = "../experiments/gen_videos/{}_{}.mp4".format('a&p_', idx) 65 | a_n_path = "../experiments/gen_videos/{}_{}.mp4".format('a&n_', idx) 66 | flows_path = "../experiments/gen_videos/{}_{}.mp4".format('flows_', idx) 67 | a_tensor = torch.tensor(a.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3 68 | p_tensor = torch.tensor(p.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3 69 | n_tensor = torch.tensor(n.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3 70 | tn_tensor = torch.tensor(tn.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3 71 | output = [] 72 | a_p_output = [] 73 | a_n_output = [] 74 | flows_output = [] 75 | # print(a_tensor.size(0)) 76 | for i 
in range(a_tensor.size(0) - 1): 77 | a_img = np.uint8(a_tensor[i] * 255) 78 | a_img = cv2.cvtColor(a_img, cv2.COLOR_BGR2RGB) 79 | p_img = np.uint8(p_tensor[i] * 255) 80 | n_img = np.uint8(n_tensor[i] * 255) 81 | tn_img = cv2.cvtColor(np.uint8(tn_tensor[i] * 255), cv2.COLOR_BGR2RGB) 82 | a_img_next = np.uint8(a_tensor[i+1] * 255) 83 | p_img_next = np.uint8(p_tensor[i+1] * 255) 84 | n_img_next = np.uint8(n_tensor[i + 1] * 255) 85 | tn_img_next = cv2.cvtColor(np.uint8(tn_tensor[i + 1] * 255), cv2.COLOR_BGR2RGB) 86 | flow_a = rgb_flow(a_img, a_img_next) 87 | flow_p = rgb_flow(p_img, p_img_next) 88 | flow_n = rgb_flow(n_img, n_img_next) 89 | flow_tn = rgb_flow(tn_img, tn_img_next) 90 | 91 | rgb_cat = np.concatenate((a_img, p_img, n_img, tn_img), 1) 92 | flow_cat = np.concatenate((flow_a, flow_p, flow_n, flow_tn), 1) 93 | img = np.concatenate((rgb_cat, flow_cat), 0) 94 | output.append(img) 95 | # a_p_output.append(np.concatenate((a_img, p_img), axis=1)) 96 | # a_n_output.append(np.concatenate((a_img, n_img), axis=1)) 97 | # flows_output.append(np.concatenate((flow_a, flow_p, flow_n), axis=1)) 98 | # write_video(a_p_path, a_p_output) 99 | # write_video(a_n_path, a_n_output) 100 | # write_video(flows_path, flows_output) 101 | write_video(path, output) 102 | print("video: {} finished".format(idx)) 103 | return 104 | 105 | 106 | def validate(tc, val_loader, model): 107 | # switch to evaluate mode 108 | model.eval() 109 | with torch.no_grad(): 110 | for i, (inputs, target, index) in enumerate(val_loader): 111 | if i > 100: 112 | break 113 | # target = target.cuda(async=True) 114 | target = target.cuda() 115 | for j in range(len(inputs)): 116 | inputs[j] = inputs[j].float() 117 | inputs[j] = inputs[j].cuda() 118 | # ===================forward===================== 119 | anchor, positive, negative, t_wrap, s_wrap = inputs 120 | positive = tc(positive) 121 | save_as_video(anchor, positive, negative, t_wrap, idx=index.cpu().data) 122 | # anchor_cat = save_one_video(anchor, idx=index, title='anchor') 123 | # positive_cat = save_one_video(positive, idx=index, title='positive') 124 | # negative_cat = save_one_video(negative, idx=index, title='negative') 125 | # imgs = np.concatenate((anchor_cat, positive_cat, negative_cat)) 126 | # path = "../experiments/gen_videos/{}_{}.jpg".format('triplet', index) 127 | # cv2.imwrite(path, imgs) 128 | return None 129 | 130 | 131 | if __name__ == '__main__': 132 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning 133 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 134 | args.batch_size = 1 135 | args.data_length = 128 136 | args.stride = 1 137 | args.spatial_size = 224 138 | args.mode = 'rgb' 139 | args.eval_indict = 'loss' 140 | args.pt_loss = 'flow' 141 | args.workers = 1 142 | args.print_freq = 100 143 | args.dataset = 'ucf101' 144 | args.train_list = '../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt' 145 | args.val_list = '../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt' 146 | # args.root = "" 147 | args.root = "/data1/awinywang/Data/ft_local/ucf101/jpegs_256/" # 144 148 | pos_aug = GenPositive() 149 | torch.manual_seed(1) 150 | cudnn.benchmark = True 151 | timer = Timer() 152 | # == dataset config== 153 | num_class, data_length, image_tmpl = data_config(args) 154 | train_transforms, test_transforms, eval_transforms = augmentation_config(args) 155 | train_data_loader, val_data_loader, eval_data_loader, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms) 156 | # == model config== 157 | model 
= model_config(args, num_class) 158 | # == optim config== 159 | train_criterion, val_criterion, optimizer = optim_init(args, model) 160 | # == data augmentation(self-supervised) config== 161 | validate(pos_aug, val_data_loader, model) -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/t_SNE_Visualization.py: -------------------------------------------------------------------------------- 1 | # That's an impressive list of imports. 2 | import numpy as np 3 | import torch 4 | from numpy import linalg 5 | from numpy.linalg import norm 6 | from scipy.spatial.distance import squareform, pdist 7 | 8 | # We import sklearn. 9 | import sklearn 10 | from sklearn.manifold import TSNE 11 | from sklearn.datasets import load_digits 12 | from sklearn.preprocessing import scale 13 | 14 | # We'll hack a bit with the t-SNE code in sklearn 0.15.2. 15 | from sklearn.metrics.pairwise import pairwise_distances 16 | from sklearn.manifold.t_sne import (_joint_probabilities, 17 | _kl_divergence) 18 | # Random state. 19 | RS = 20150101 20 | 21 | # We'll use matplotlib for graphics. 22 | import matplotlib.pyplot as plt 23 | import matplotlib.patheffects as PathEffects 24 | import matplotlib 25 | # %matplotlib inline 26 | 27 | # We import seaborn to make nice plots. 28 | import seaborn as sns 29 | import os 30 | sns.set_style('darkgrid') 31 | sns.set_palette('muted') 32 | sns.set_context("notebook", font_scale=1.5, 33 | rc={"lines.linewidth": 2.5}) 34 | 35 | # We'll generate an animation with matplotlib and moviepy. 36 | # from moviepy.video.io.bindings import mplfig_to_npimage 37 | # import moviepy.editor as mpy 38 | 39 | 40 | def load_data(file_name): 41 | # digits = load_digits() 42 | # digits.data.shape # 1797 x 64 43 | # print(digits.data.shape) 44 | # print(digits['DESCR']) 45 | # return digits 46 | features = np.load(file_name, allow_pickle='TRUE').item() 47 | return features 48 | 49 | 50 | def scatter(x, colors, num_class=10): 51 | # We choose a color palette with seaborn. 52 | palette = np.array(sns.color_palette("hls", num_class)) 53 | # sns.palplot(sns.color_palette("hls", 10)) 54 | # We create a scatter plot. 55 | labels=['brush_hair', 'cartwheel', 'catch', 'chew', 56 | 'clap', 'climb', 'climb_stairs', 'dive', 'draw_sword', 57 | 'dribble'] 58 | f = plt.figure(figsize=(8, 8)) 59 | # print(colors.astype(np.int)) 60 | ax = plt.subplot(aspect='equal') 61 | # for i in range(10): 62 | # sc = ax.scatter(x[:, 0][30*i:30*(i+1)], x[:, 1][30*i:30*(i+1)], c=palette[colors.astype(np.int)][30*i:30*(i+1)], 63 | # s=40, 64 | # label=labels[i], 65 | # ) 66 | sc = ax.scatter(x[:,0], x[:,1], c=palette[colors.astype(np.int)], 67 | s=150, 68 | #label=colors.astype(np.int)[30], 69 | ) 70 | # ax.legend(loc="best", title="Classes", bbox_to_anchor=(0.2, 0.4)) 71 | plt.xlim(-25, 25) 72 | plt.ylim(-25, 25) 73 | ax.axis('off') 74 | ax.axis('tight') 75 | 76 | # We add the labels for each digit. 77 | txts = [] 78 | for i in range(num_class): 79 | # Position of each label. 
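# Each class label is drawn at the median of that class's 2-D t-SNE coordinates, which keeps the
# text near the centre of its cluster and is robust to outlying points; the white stroke applied
# below keeps the glyph readable on top of the coloured markers.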
80 | xtext, ytext = np.median(x[colors == i, :], axis=0) 81 | txt = ax.text(xtext, ytext, str(i), fontsize=24) 82 | # ax.legend(ytext, "a") 83 | txt.set_path_effects([ 84 | PathEffects.Stroke(linewidth=5, foreground="w"), 85 | PathEffects.Normal()]) 86 | txts.append(txt) 87 | # ax.legend(('a','b','c','d','e')) 88 | return f, ax, sc, txts 89 | 90 | 91 | def tsne_visualize(data, file_name, num_class=101): 92 | # nrows, ncols = 2, 5 93 | # plt.figure(figsize=(6,3)) 94 | # plt.gray() 95 | # for i in range(ncols * nrows): 96 | # ax = plt.subplot(nrows, ncols, i + 1) 97 | # ax.matshow(digits.images[i,...]) 98 | # plt.xticks([]); plt.yticks([]) 99 | # plt.title(digits.target[i]) 100 | # plt.savefig('../../../experiments/visualization/digits-generated.png', dpi=150) 101 | 102 | # We first reorder the data points according to the handwritten numbers. 103 | datas = [] 104 | labels = [] 105 | nums = len(data['target']) 106 | for i in range(num_class, num_class + 10): 107 | for j in range(nums): 108 | if data['target'][j] == i: 109 | datas.append(data['data'][j]) 110 | X = np.vstack(datas) 111 | for i in range(num_class, num_class + 10): 112 | for j in range(nums): 113 | if data['target'][j] == i: 114 | labels.append(data['target'][j] - num_class) 115 | y = np.hstack(labels) 116 | # X = np.vstack([data['data'][data['target']==i].cpu() 117 | # for i in range(10)]) 118 | # y = np.hstack([data['target'][data['target']==i].cpu() 119 | # for i in range(10)]) 120 | # print(y) 121 | digits_proj = TSNE(random_state=RS).fit_transform(X) 122 | scatter(digits_proj, y) 123 | plt.savefig(file_name, dpi=120) 124 | 125 | 126 | # for begin_index in range(1,41,5): 127 | # front = 'self_supervised_kineticsMoCo' 128 | # s_path = "../../../experiments/visualization/" + front + str(begin_index) + 'to' + str(begin_index + 10) 129 | # if not os.path.exists(s_path): 130 | # os.mkdir(s_path) 131 | # features_file = "../../../experiments/visualization/{}_hmdb51_features.npy".format(front) 132 | # file_name = "{}/{}_hmdb51_tsne-generated.png".format(s_path, front) 133 | # data = load_data(features_file) 134 | # tsne_visualize(data, file_name, begin_index) 135 | 136 | # for front in ['scratch', 'fully_supervised_37', 'self_supervised_37','fully_supervised_scratch']: 137 | # # front = 'scratch' 138 | # features_file = "../../../experiments/visualization/{}_hmdb51_features.npy".format(front) 139 | # file_name = "{}/{}hmdb51_tsne-generated.png".format(s_path, front) 140 | # data = load_data(features_file) 141 | # tsne_visualize(data, file_name, begin_index) 142 | 143 | 144 | if __name__ == '__main__': 145 | # front = 'contrastive_kinetics_warpping_hmdb51' 146 | # front = 'contrastive_ucf101_warpping_hmdb51' 147 | front = 'i3d_fully_supervised_kinetics_warpping_hmdb51' 148 | for begin_index in range(1, 41, 5): 149 | s_path = "../../../experiments/visualization/TSNE/" + front + str(begin_index) + 'to' + str(begin_index + 10) 150 | if not os.path.exists(s_path): 151 | os.mkdir(s_path) 152 | features_file = "../../../experiments/features/{}/val_features.npy".format(front) 153 | file_name = "{}/{}_hmdb51_tsne-generated.png".format(s_path, front) 154 | data = load_data(features_file) 155 | tsne_visualize(data, file_name, begin_index) -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/triplet_visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import torch 
5 | import torch.backends.cudnn as cudnn 6 | from data.config import data_config, augmentation_config 7 | from data.dataloader import data_loader_init 8 | from model.config import model_config 9 | from augment.gen_positive import GenPositive 10 | from utils.visualization.triplet_visualization import triplet_visualize, save_img 11 | from bk.option_old import args 12 | 13 | lowest_val_loss = float('inf') 14 | best_prec1 = 0 15 | torch.manual_seed(1) 16 | 17 | 18 | def test(train_loader, model, pos_aug): 19 | model.eval() 20 | for i, (inputs, target, index) in enumerate(train_loader): 21 | anchor, positive, negative = inputs 22 | anchor = pos_aug(anchor) 23 | dir_path = "../experiments/visualization/triplet2/{}".format(index[0]) 24 | if not os.path.exists(dir_path): 25 | os.makedirs(dir_path) 26 | mask_img = triplet_visualize(anchor.cpu().numpy(), positive.cpu().numpy(), negative.cpu().numpy(), dir_path) 27 | path = "{}/{}.png".format(dir_path, str(index[0].cpu().numpy())) 28 | save_img(mask_img, path) 29 | print("{}/{} finished".format(i, len(train_loader))) 30 | # output = model(anchor) 31 | return True 32 | 33 | 34 | def main(): 35 | # = 36 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning 37 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 38 | torch.manual_seed(1) 39 | cudnn.benchmark = True 40 | # == dataset config== 41 | num_class, data_length, image_tmpl = data_config(args) 42 | train_transforms, test_transforms, eval_transforms = augmentation_config(args) 43 | train_data_loader, val_data_loader, _, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms) 44 | # == model config== 45 | model = model_config(args, num_class) 46 | pos_aug = GenPositive() 47 | # == train and eval== 48 | test(train_data_loader, model, pos_aug) 49 | return 1 50 | 51 | 52 | if __name__ == '__main__': 53 | main() -------------------------------------------------------------------------------- /src/Contrastive/utils/visualization/video_write_test.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import skvideo.io 3 | import numpy as np 4 | import os 5 | import augment.video_transformations.video_transform_PIL_or_np as video_transform 6 | from torchvision import transforms 7 | import skimage.transform 8 | import random 9 | from PIL import Image 10 | import torchvision 11 | 12 | 13 | def read_video(video): 14 | cap = cv2.VideoCapture(video) 15 | frames = list() 16 | while True: 17 | ret, frame = cap.read() 18 | if type(frame) is type(None): 19 | break 20 | else: 21 | frames.append(frame) 22 | return frames 23 | 24 | 25 | def write_video(name, frames): 26 | # fshape = frames[0].shape 27 | # fheight = fshape[0] 28 | # fwidth = fshape[1] 29 | # writer = cv2.VideoWriter(name, 30 | # cv2.VideoWriter_fourcc(*"MJPG"), 30, (fheight, fwidth)) 31 | # for i in range(len(frames)): 32 | # writer.write(frames[i]) 33 | # writer.release() 34 | writer = skvideo.io.FFmpegWriter(name, 35 | outputdict={'-b': '300000000'}) 36 | for frame in frames: 37 | frame = np.array(frame) 38 | writer.writeFrame(frame) 39 | writer.close() 40 | return 1 41 | 42 | 43 | if __name__ == '__main__': 44 | # video = "../experiments/test.mp4" 45 | # frames = read_video(video) 46 | 47 | # # aug = video_transform.RandomRotation(10), 48 | # # video_transform.STA_RandomRotation(10), 49 | # # video_transform.Each_RandomRotation(10), 50 | # # train_transforms = transforms.Compose([video_transform.Resize(128), 51 | # # 
video_transform.RandomCrop(112), aug]) 52 | # frames = np.array(frames) 53 | # prefix = "random_rotation" 54 | # angle = random.uniform(-45, 45) 55 | # rotated = [skimage.transform.rotate(img, angle) for img in frames] 56 | # name = '../experiments/gen_videos/test_{}.avi'.format(prefix) 57 | # write_video(name, rotated) 58 | # prefix = "STA_rotation" 59 | # bsz = len(frames) 60 | # angles = [(i + 1) / (bsz + 1) * angle for i in range(bsz)] 61 | # rotated = [skimage.transform.rotate(img, angles[i]) for i, img in enumerate(frames)] 62 | # name = '../experiments/gen_videos/test_{}.avi'.format(prefix) 63 | # write_video(name, rotated) 64 | # prefix = "each_random_rotation" 65 | # angles = [random.uniform(-45, 45) for i in range(bsz)] 66 | # rotated = [skimage.transform.rotate(img, angles[i]) for i, img in enumerate(frames)] 67 | # name = '../experiments/gen_videos/test_{}.avi'.format(prefix) 68 | # write_video(name, rotated) 69 | # # for i in range(10): 70 | # # seqs = np.load("{}_{}.npy".format("../experiments/augmentation/{}".format(prefix), i)) 71 | # # seqs = (seqs+1)/2*255 72 | # # out_dir = "../experiments/gen_videos/{}".format(i) 73 | # # if not os.path.exists(out_dir): 74 | # # os.makedirs(out_dir) 75 | # # for j in range(16): 76 | # # cv2.imwrite("{}/{}_{}.jpg".format(out_dir, prefix, j), seqs[j]) 77 | # # name = '../experiments/gen_videos/test_{}.mp4'.format(i) 78 | # # write_video(name, seqs) 79 | # video = "../experiments/test.mp4" 80 | # frames = read_video(video) 81 | # images = [] 82 | # frames = np.array(frames) 83 | # for i in range(len(frames)): 84 | # img = Image.fromarray(np.uint8(frames[i])) 85 | # images.append(img) 86 | # # Create img transform function sequence 87 | # img_transforms = [] 88 | # brightness = 0.5 89 | # contrast = 0.5 90 | # saturation = 0.5 91 | # hue = 0.2 92 | # if brightness is not None: 93 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 94 | # if saturation is not None: 95 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 96 | # if hue is not None: 97 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 98 | # if contrast is not None: 99 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) 100 | # random.shuffle(img_transforms) 101 | # 102 | # # Apply to all images 103 | # jittered_clip = [] 104 | # for img in images: 105 | # for func in img_transforms: 106 | # jittered_img = func(img) 107 | # jittered_clip.append(jittered_img) 108 | # name = '../experiments/gen_videos/test_{}.avi'.format('jitter') 109 | # write_video(name, jittered_clip) 110 | video = "../experiments/test.mp4" 111 | frames = read_video(video) 112 | images = [] 113 | frames = np.array(frames) 114 | for i in range(len(frames)): 115 | img = Image.fromarray(np.uint8(frames[i])) 116 | images.append(img) 117 | # Create img transform function sequence 118 | brightness = 0.5 119 | contrast = 0.5 120 | saturation = 0.5 121 | hue = 0.2 122 | # img_transforms = [] 123 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 124 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 125 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 126 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) 
127 | # random.shuffle(img_transforms) 128 | # 129 | # # Apply to all images 130 | # jittered_clip = [] 131 | # for img in images: 132 | # for func in img_transforms: 133 | # jittered_img = func(img) 134 | # jittered_clip.append(jittered_img) 135 | # name = '../experiments/gen_videos/test_{}.avi'.format('sta_jitter') 136 | # write_video(name, jittered_clip) 137 | # Apply to all images 138 | jittered_clip = [] 139 | for i, img in enumerate(images): 140 | t_brightness = (i+1)/(len(images)+1) * brightness 141 | t_contrast = (i + 1) / (len(images) + 1) * contrast 142 | t_saturation = (i + 1) / (len(images) + 1) * saturation 143 | t_hue = (i + 1) / (len(images) + 1) * hue 144 | img_transforms = [] 145 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, t_brightness)) 146 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, t_saturation)) 147 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, t_hue)) 148 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, t_contrast)) 149 | random.shuffle(img_transforms) 150 | for func in img_transforms: 151 | img = func(img) # chain the jitters so all four adjustments take effect, not only the last one 152 | jittered_clip.append(img) 153 | name = '../experiments/gen_videos/test_{}.avi'.format('sta_jitter') 154 | write_video(name, jittered_clip) --------------------------------------------------------------------------------
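The loop at the end of video_write_test.py is the 'sta_jitter' variant of colour jitter used for the visualisations: the brightness, contrast, saturation and hue strengths grow linearly with the frame index, and the four adjustments are chained on every frame in a random order. Below is a minimal, self-contained sketch of the same idea as a reusable clip transform; the class name STAColorJitter and its interface are illustrative choices, not part of the repository, and the factors shrink towards zero for early frames exactly as in the test script (a production augmentation would more typically sample factors around 1).

import random
import torchvision.transforms.functional as TF


class STAColorJitter(object):
    """Colour jitter whose strength grows linearly over the clip (illustrative sketch)."""

    def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def __call__(self, clip):
        # clip: list of PIL images; returns a jittered list of the same length
        n = len(clip)
        out = []
        for i, img in enumerate(clip):
            scale = (i + 1) / (n + 1)  # frame-dependent strength, as in the test script
            ops = [
                lambda im, s=scale: TF.adjust_brightness(im, self.brightness * s),
                lambda im, s=scale: TF.adjust_contrast(im, self.contrast * s),
                lambda im, s=scale: TF.adjust_saturation(im, self.saturation * s),
                lambda im, s=scale: TF.adjust_hue(im, self.hue * s),
            ]
            random.shuffle(ops)   # apply the four adjustments in a random order
            for op in ops:        # chain them so every adjustment affects the frame
                img = op(img)
            out.append(img)
        return out

Applied to the list of PIL frames read by read_video(), STAColorJitter()(images) yields a clip equivalent (up to the random shuffling order) to the one the script writes to test_sta_jitter.avi.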