├── .idea
│   ├── TBE.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   ├── vcs.xml
│   └── workspace.xml
├── dataset.md
├── figures
│   ├── PPL.jpg
│   ├── be_visualization.png
│   ├── example1.gif
│   ├── example2.gif
│   ├── mix_videos.png
│   ├── motivation.png
│   └── motivation_example.png
├── readme.md
└── src
    └── Contrastive
        ├── __init__.py
        ├── augment
        │   ├── basic_augmentation
        │   │   ├── eda.py
        │   │   ├── mixup_methods.py
        │   │   ├── net_mixup.py
        │   │   ├── noise.py
        │   │   ├── rotation.py
        │   │   ├── temporal_augment.py
        │   │   ├── temporal_dropout.py
        │   │   ├── temporal_shuffle.py
        │   │   ├── temporal_sub.py
        │   │   ├── triplet.py
        │   │   └── video_color_jitter.py
        │   ├── config.py
        │   ├── gen_negative.py
        │   ├── gen_positive.py
        │   └── video_transformations
        │       ├── functional.py
        │       ├── video_transform_PIL_or_np.py
        │       ├── videotransforms.py
        │       └── volume_transforms.py
        ├── data
        │   ├── __init__.py
        │   ├── base.py
        │   ├── config.py
        │   ├── dataloader.py
        │   ├── dataset.py
        │   ├── decode_on_the_fly.py
        │   ├── on_the_fly_test.py
        │   └── video_dataset.py
        ├── feature_extractor.py
        ├── ft.py
        ├── loss
        │   ├── NCE
        │   │   ├── Link.py
        │   │   ├── NCEAverage.py
        │   │   ├── NCECriterion.py
        │   │   ├── __init__.py
        │   │   └── alias_multinomial.py
        │   ├── __init__.py
        │   ├── config.py
        │   └── tcr.py
        ├── main.py
        ├── model
        │   ├── __init__.py
        │   ├── c3d.py
        │   ├── config.py
        │   ├── i3d.py
        │   ├── model.py
        │   ├── mutual_net.py
        │   ├── r2p1d.py
        │   ├── r3d.py
        │   ├── s3d.py
        │   └── s3d_g.py
        ├── option.py
        ├── pt.py
        ├── reterival.py
        ├── scripts
        │   ├── Diving48
        │   │   ├── ft.sh
        │   │   ├── pt_and_ft.sh
        │   │   ├── pt_and_ft_moco.sh
        │   │   └── pt_and_ft_moco_ucf_to_diving48.sh
        │   ├── evaluation
        │   │   ├── hmdb51_i3d_eval.sh
        │   │   └── ucf101_i3d_eval.sh
        │   ├── feature_extract
        │   │   ├── hmdb51_extract.sh
        │   │   └── ucf101_extract.sh
        │   ├── hmdb51
        │   │   ├── ft.sh
        │   │   ├── pt_and_ft.sh
        │   │   └── pt_and_ft_hmdb51.sh
        │   ├── kinetics
        │   │   └── pt_and_ft.sh
        │   ├── something-something-v1
        │   │   ├── i3d_ft.sh
        │   │   ├── i3d_pt_and_ft.sh
        │   │   ├── i3d_pt_and_ft_multi_gpus.sh
        │   │   ├── r3d_ft.sh
        │   │   ├── r3d_ft_multi_gpus.sh
        │   │   ├── r3d_pt_and_ft.sh
        │   │   └── r3d_pt_and_ft_multi_gpus.sh
        │   ├── ucf101
        │   │   └── pt_and_ft.sh
        │   └── visualization
        │       ├── ucf101_data_augmentation_visualize.sh
        │       └── ucf101_triplet_visualize.sh
        ├── test.py
        └── utils
            ├── data_process
            │   ├── gen_diving48_frames.py
            │   ├── gen_diving48_lists.py
            │   ├── gen_hmdb51_dir.py
            │   ├── gen_hmdb51_frames.py
            │   ├── gen_hmdb51_sta_list.py
            │   ├── gen_sub_dataset.py
            │   ├── semi_data_split.py
            │   └── unrar_hmdb51_sta.sh
            ├── gradient_check.py
            ├── learning_rate_adjust.py
            ├── load_weights.py
            ├── moment_update.py
            ├── recoder.py
            ├── utils.py
            └── visualization
                ├── augmentation_visualization.py
                ├── color_palettes.py
                ├── confusion_matrix.py
                ├── mixup_visualization.py
                ├── multi100.txt
                ├── pearson_calculate.txt
                ├── pearson_correlation.py
                ├── plot_class_distribution.py
                ├── plot_flow_v_rgb_distribution.py
                ├── single_video_visualization.py
                ├── t_SNE_Visualization.py
                ├── triplet_visualization.py
                └── video_write_test.py
/dataset.md:
--------------------------------------------------------------------------------
1 | # Dataset Prepare
2 | For Kinetics, we decode videos on the fly; each row in the txt file includes:
3 | > video_path class
4 |
5 | We load the videos directly, so please place the training set on an SSD for fast IO.
6 |
7 | Prepare the other datasets (UCF101/Diving48/Sth-v1/HMDB51/Actor-HMDB51); each row of the txt file is as below:
8 |
9 | > video_path class frames_num
10 |
11 | These datasets are saved as frames. We offer lists for all these datasets on [Google Drive](https://drive.google.com/drive/folders/1ndq0rdxEvubBrbXny8RuGCTETXU2hr1N?usp=sharing).
12 |
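For illustration, hypothetical rows following the formats above (the paths, labels, and frame counts are placeholders) would look like:

```
# Kinetics (decoded on the fly): video_path class
/path/to/kinetics400/train/abseiling/video_0001.mp4 0

# Frame datasets (UCF101/HMDB51/Diving48/Sth-v1): video_path class frames_num
/path/to/ucf101/frames/v_ApplyEyeMakeup_g08_c01 0 121
```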
13 | ## Kinetics
14 | As some YouTube links have been lost, we use a copy of Kinetics-400 from [Non-local](https://github.com/facebookresearch/video-nonlocal-net); the training set now contains 234643 videos and the val set 19761.
15 | All the videos are in mpeg/avi format.
16 |
17 | ## UCF101/HMDB51
18 | These two video datasets each contain three train/test splits.
19 | Download UCF101 from [crcv](https://www.crcv.ucf.edu/data/UCF101.php) and HMDB51 from [serre-lab](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/).
20 |
21 | ## Diving48
22 | (1). Download Diving48 from [ucsd](http://www.svcl.ucsd.edu/projects/resound/dataset.html)
23 |
24 | (2). Generate frames using the script __src/Contrastive/utils/data_process/gen_diving48_frames.py__
25 |
26 | (3). Generate lists using the script __src/Contrastive/utils/data_process/gen_diving48_lists.py__
27 |
28 | ## Sth-v1
29 |
30 | ## HMDB51-STA/Actor-HMDB51
31 | HMDB51-STA is built by removing large camera motion from [HMDB51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/).
32 | We provide our generated Actor-HMDB51 in [google_drive]().
33 |
34 | ### Semi-supervised Subdataset
35 | We also provide a script to get the subset of Kinetics-400. Please refer to Contrastive/utils/data_process/semi_data_split.py for details.
36 |
--------------------------------------------------------------------------------
/figures/PPL.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/PPL.jpg
--------------------------------------------------------------------------------
/figures/be_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/be_visualization.png
--------------------------------------------------------------------------------
/figures/example1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/example1.gif
--------------------------------------------------------------------------------
/figures/example2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/example2.gif
--------------------------------------------------------------------------------
/figures/mix_videos.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/mix_videos.png
--------------------------------------------------------------------------------
/figures/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/motivation.png
--------------------------------------------------------------------------------
/figures/motivation_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/figures/motivation_example.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 |
6 | # TBE
7 |
8 | The source code for our paper "Removing the Background by Adding the Background: Towards Background Robust Self-supervised Video Representation Learning" [[arxiv](https://arxiv.org/abs/2009.05769)]
9 | [[code](https://github.com/FingerRec/BE)][[Project Website](https://fingerrec.github.io/index_files/jinpeng/papers/CVPR2021/project_website.html)]
10 |
11 |
12 |
13 |
14 |
15 |
16 | ## Citation
17 |
18 | ```bash
19 | @inproceedings{wang2021removing,
20 | title={Removing the Background by Adding the Background: Towards Background Robust Self-supervised Video Representation Learning},
21 | author={Wang, Jinpeng and Gao, Yuting and Li, Ke and Lin, Yiqi and Ma, Andy J and Cheng, Hao and Peng, Pai and Ji, Rongrong and Sun, Xing},
22 | booktitle={CVPR},
23 | year={2021}
24 | }
25 | ```
26 |
27 | ## News
28 | [2020.3.7] The first version of TBE is released!
29 |
30 | ## 0. Motivation
31 |
32 | - In camera-fixed situations, the static background in most frames remains similar in pixel distribution.
33 |
34 | 
35 |
36 |
37 | - We ask the model to be **temporally sensitive** rather than **statically sensitive**.
38 |
39 | - We ask the model to filter out the additive **Background Noise**, which means erasing the background in each frame of the video.
40 |
41 |
42 | ### Activation Map Visualization of BE
43 |
44 | 
45 |
46 | ### GIF
47 | 
48 |
49 | #### A harder example
50 | 
51 |
52 | ## 1. Plug BE into any self-supervised learning method in two steps
53 |
54 | The implementation of BE is very simple; you can implement it in two lines of Python:
55 | ```python
56 | rand_index = random.randint(0, t - 1)
57 | mixed_x[j] = (1 - prob) * x + prob * x[rand_index]
58 | ```
59 |
60 | Then, you just need to define a consistency loss such as MSE:
61 |
62 | ```python
63 | loss = MSE(F(mixed_x),F(x))
64 | ```
65 |
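Below is a minimal runnable sketch of this idea in PyTorch (an illustration only, not the exact code of this repo; see src/Contrastive/augment/gen_positive.py for the real implementation, and `F` denotes a hypothetical video encoder):

```python
import torch

def background_erasing(x, prob=0.3):
    """Mix one randomly selected (static) frame of the clip into every frame,
    so the shared background becomes an additive nuisance signal.
    x: clip tensor of shape (B, C, T, H, W)."""
    t = x.size(2)
    rand_index = torch.randint(0, t, (1,)).item()
    static_frame = x[:, :, rand_index:rand_index + 1]   # (B, C, 1, H, W)
    return (1 - prob) * x + prob * static_frame         # broadcasts over T

# feat, feat_be = F(x), F(background_erasing(x))        # F: hypothetical encoder
# loss = torch.nn.functional.mse_loss(feat_be, feat)
```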
66 |
67 | ## 2. Installation
68 |
69 | ### Dataset Prepare
70 | Please refer to [dataset.md](dataset.md) for details.
71 |
72 |
73 | ### Requirements
74 | - Python 3
75 | - PyTorch 1.1+
76 | - PIL
77 | - Intel (on-the-fly decode)
78 | - skvideo.io
79 | - Matplotlib (gradient_check)
80 |
81 | **As the Kinetics dataset is IO-heavy, we decode the avi/mpeg videos on the fly. Please refer to data/video_dataset.py for details.**
82 |
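A minimal sketch of such on-the-fly decoding with OpenCV (an illustration under stated assumptions, not the repo's actual loader; see src/Contrastive/data/video_dataset.py and data/decode_on_the_fly.py for the real code):

```python
import cv2
import numpy as np

def decode_clip(video_path, num_frames=16):
    """Uniformly sample num_frames RGB frames from a video file."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    indices = np.linspace(0, max(total - 1, 0), num_frames).astype(int)
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame = cap.read()
        if ok:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    return np.stack(frames) if frames else None  # (T, H, W, C)
```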
83 | ## 3. Structure
84 | - datasets
85 | - list
86 | - hmdb51: the train/val lists of HMDB51/Actor-HMDB51
87 | - hmdb51_sta: the train/val lists of HMDB51_STA
88 | - ucf101: the train/val lists of UCF101
89 | - kinetics-400: the train/val lists of kinetics-400
90 | - diving48: the train/val lists of diving48
91 | - experiments
92 | - logs: experiment records in detail, including logs and trained models
93 | - gradientes:
94 | - visualization:
95 | - pretrained_model:
96 | - src
97 | - Contrastive
98 | - data: load data
99 | - loss: the losses evaluated in this paper
100 | - model: network architectures
101 | - scripts: train/eval scripts
102 | - augment: detailed implementation of the BE augmentation
103 | - utils
104 | - feature_extractor.py: feature extractor given a pretrained model
105 | - main.py: the main function of pretrain / finetune
106 | - trainer.py
107 | - option.py
108 | - pt.py: BE pretrain
109 | - ft.py: BE finetune
110 | - Pretext
111 | - main.py: the main function of pretrain / finetune
112 | - loss: the losses, including the classification loss
113 | ## 4. Run
114 | ### (1). Download dataset lists and pretrained model
115 | A copy of the dataset lists is provided in [anonymous]().
116 | The Kinetics-pretrained models are provided in [anonymous]().
117 |
118 | ```bash
119 | cd .. && mkdir datasets
120 | mv [path_to_lists] datasets/
121 | mkdir experiments && cd experiments
122 | mkdir pretrained_models logs
123 | mv [path_to_pretrained_model] pretrained_models/
124 | ```
125 |
126 | Download and extract frames of Actor-HMDB51.
127 | ```bash
128 | wget -c anonymous
129 | unzip
130 | python utils/data_process/gen_hmdb51_dir.py
131 | python utils/data_process/gen_hmdb51_frames.py
132 | ```
133 | ### (2). Network Architecture
134 | The network architectures are in the folder **src/Contrastive/model/**
135 |
136 | | Method | #logits_channel |
137 | | ---- | ---- |
138 | | C3D | 512 |
139 | | R2P1D | 2048 |
140 | | I3D | 1024 |
141 | | R3D | 2048 |
142 |
143 | All the logits channels are fed into a fc layer with a 128-D output.
144 |
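As a hedged illustration (not the repo's exact module), this amounts to a single linear projection on top of the pooled backbone features:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits_channel = 1024                       # e.g. I3D, following the table above
fc = nn.Linear(logits_channel, 128)         # projection to the 128-D embedding

pooled = torch.randn(8, logits_channel)     # hypothetical pooled features for a batch of 8 clips
embedding = F.normalize(fc(pooled), dim=1)  # 8 x 128, L2-normalized
```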
145 |
146 |
147 | > For simplicity, we divide the source into Contrastive and Pretext; "--method pt_and_ft" means pretrain and finetune in one run.
148 |
149 | ### Action Recognition
150 | #### Random Initialization
151 | For the random-initialization baseline, just comment out --ft_weights in line 11 of ft.sh,
152 | like below:
153 |
154 | ```bash
155 | #!/usr/bin/env bash
156 | python main.py \
157 | --method ft --arch i3d \
158 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \
159 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \
160 | --ft_root /data1/DataSet/Diving48/rgb_frames/ \
161 | --ft_dataset diving48 --ft_mode rgb \
162 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
163 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
164 | --ft_print-freq 100 --ft_fixed 0 # \
165 | # --ft_weights ../experiments/kinetics_contrastive.pth
166 | ```
167 |
168 |
169 | #### BE(Contrastive)
170 | ##### Kinetics
171 | ```bash
172 | bash scripts/kinetics/pt_and_ft.sh
173 | ```
174 | ##### UCF101
175 | ```bash
176 | bash scripts/ucf101/pt_and_ft.sh
177 | ```
178 | ##### Diving48
179 | ```bash
180 | bash scripts/Diving48/pt_and_ft.sh
181 | ```
182 |
183 |
184 | **For triplet-loss optimization and the MoCo baseline, just modify --pt_method:**
185 |
186 | #### BE (Triplet)
187 |
188 | ```bash
189 | --pt_method be_triplet
190 | ```
191 |
192 | #### BE(Pretext)
193 | ```bash
194 | bash scripts/hmdb51/i3d_pt_and_ft_flip_cls.sh
195 | ```
196 |
197 | or
198 |
199 | ```bash
200 | bash scripts/hmdb51/c3d_pt_and_ft_flip.sh
201 | ```
202 |
203 | **Notice: more training options and ablation studies can be found in the scripts.**
204 |
205 |
206 | ### Video Retrieval and Other Visualization
207 |
208 | #### (1). Feature Extractor
209 | As STCR can be easily extended to other video representation tasks, we offer scripts to perform feature extraction.
210 | ```bash
211 | python feature_extractor.py
212 | ```
213 |
214 | The features will be saved as a single numpy file of shape [video_nums, features_dim] for further visualization.
215 |
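For example, the saved array can be inspected like this (the file name below is a hypothetical placeholder):

```python
import numpy as np

features = np.load("hmdb51_train_features.npy")  # placeholder path
print(features.shape)                            # (video_nums, features_dim)
```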
216 | #### (2). Retrieval Evaluation
217 | Modify lines 60-62 in reterival.py.
218 | ```bash
219 | python reterival.py
220 | ```
221 |
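A hedged sketch of what such a retrieval evaluation computes (top-k recall over cosine similarity between pre-extracted features; the file names below are placeholders, not the exact paths used in reterival.py):

```python
import numpy as np

train_feat = np.load("train_features.npy")   # [N_train, D], placeholder path
test_feat = np.load("test_features.npy")     # [N_test, D], placeholder path
train_labels = np.load("train_labels.npy")
test_labels = np.load("test_labels.npy")

def l2_normalize(x):
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-8)

# cosine similarity between every test clip and every train clip
sim = l2_normalize(test_feat) @ l2_normalize(train_feat).T
for k in (1, 5, 10, 20, 50):
    topk = np.argsort(-sim, axis=1)[:, :k]
    hit = (train_labels[topk] == test_labels[:, None]).any(axis=1)
    print(f"R@{k}: {hit.mean() * 100:.1f}")
```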
222 | ## Results
223 | ### Action Recognition
224 | #### Kinetics Pretrained (I3D)
225 | | Method | UCF101 | HMDB51 | Diving48 |
226 | | ---- | ---- | ---- | ---- |
227 | | Random Initialization | 57.9 | 29.6| 17.4|
228 | | MoCo Baseline | 70.4 | 36.3| 47.9 |
229 | | BE | 86.5 | 56.2| 62.6|
230 |
231 |
232 |
233 | ### Video Retrieval (HMDB51-C3D)
234 | | Method | @1 | @5 | @10| @20|@50 |
235 | | ---- | ---- | ---- | ---- | ---- | ---- |
236 | | BE | 10.2 | 27.6| 40.5 |56.2|76.6|
237 |
238 | ## More Visualization
239 | ### T-SNE
240 | Please refer to __utils/visualization/t_SNE_Visualization.py__ for details.
241 |
242 | ### Confusion_Matrix
243 | Please refer to __utils/visualization/confusion_matrix.py__ for details.
244 |
245 |
246 |
247 | ## Acknowledgement
248 | This work is partly based on [UEL](https://github.com/mangye16/Unsupervised_Embedding_Learning) and [MoCo](https://github.com/facebookresearch/moco).
249 |
250 | ## License
251 | The code is released under the CC-BY-NC 4.0 license.
--------------------------------------------------------------------------------
/src/Contrastive/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/__init__.py
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/eda.py:
--------------------------------------------------------------------------------
1 | # Easy data augmentation techniques for video
2 | # Jason Wei and Kai Zou
3 |
4 | import torch.nn as nn
5 | import random
6 | from random import shuffle
7 | #random.seed(1)
8 | # cleaning up text
9 | import re
10 | # NOTE: synonym_replacement / get_synonyms below are text-EDA leftovers; they additionally
11 | # require nltk's wordnet and a stop_words list, and are not called in the video pipeline.
12 |
13 | ########################################################################
14 | # Synonym replacement
15 | # Replace n words in the sentence with synonyms from wordnet
16 | ########################################################################
17 |
18 |
19 | def synonym_replacement(words, n):
20 | new_words = words.copy()
21 | random_word_list = list(set([word for word in words if word not in stop_words]))
22 | random.shuffle(random_word_list)
23 | num_replaced = 0
24 | for random_word in random_word_list:
25 | synonyms = get_synonyms(random_word)
26 | if len(synonyms) >= 1:
27 | synonym = random.choice(list(synonyms))
28 | new_words = [synonym if word == random_word else word for word in new_words]
29 | # print("replaced", random_word, "with", synonym)
30 | num_replaced += 1
31 | if num_replaced >= n: # only replace up to n words
32 | break
33 |
34 | # this is stupid but we need it, trust me
35 | sentence = ' '.join(new_words)
36 | new_words = sentence.split(' ')
37 |
38 | return new_words
39 |
40 |
41 | def get_synonyms(word):
42 | synonyms = set()
43 | for syn in wordnet.synsets(word):
44 | for l in syn.lemmas():
45 | synonym = l.name().replace("_", " ").replace("-", " ").lower()
46 | synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm'])
47 | synonyms.add(synonym)
48 | if word in synonyms:
49 | synonyms.remove(word)
50 | return list(synonyms)
51 |
52 |
53 | ########################################################################
54 | # Random deletion
55 | # Randomly delete words from the sentence with probability p
56 | ########################################################################
57 |
58 | def random_deletion(videos, p):
59 | # obviously, if there's only one word, don't delete it
60 | b, c, t, h, w = videos.size()
61 | if t == 1:
62 | return videos
63 |
64 | # randomly delete words with probability p, padding or loop?
65 | '''
66 | # method 1
67 | new_videos = videos.copy()
68 | for i in range(t):
69 | r = random.uniform(0, 1)
70 | if r <= p:
71 | new_videos[:, :, i, :, :] = 0
72 | '''
73 | # method 2 loop
74 | new_videos = videos
75 | count = 0
76 | for i in range(t):
77 | r = random.uniform(0, 1)
78 | if r <= p:
79 | continue
80 | else:
81 | new_videos[:, :, count, :, :] = videos[:, :, i, :, :,]
82 | count += 1
83 | for i in range(t - count):
84 | new_videos[:, :, i, :, :] = videos[:, :, i, :, :]
85 | # if you end up deleting all words, just return a random word
86 | if new_videos.size()[2] == 0:
87 | rand_int = random.randint(0, t - 1)
88 | return [videos[:, :, rand_int, :, :]]
89 |
90 | return new_videos
91 |
92 |
93 | ########################################################################
94 | # Random swap
95 | # Randomly swap two words in the sentence n times
96 | ########################################################################
97 |
98 | def random_swap(videos, n):
99 | new_videos = videos
100 | for _ in range(n):
101 | new_videos = swap_word(new_videos)
102 | return new_videos
103 |
104 |
105 | def swap_word(new_videos):
106 | b, c, t, h, w = new_videos.size()
107 | random_idx_1 = random.randint(0, t - 1)
108 | random_idx_2 = random_idx_1
109 | counter = 0
110 | while random_idx_2 == random_idx_1:
111 | random_idx_2 = random.randint(0, t - 1)
112 | counter += 1
113 | if counter > 3:
114 | return new_videos
115 | new_videos[:, :, random_idx_1, :, :], new_videos[:, :, random_idx_2, :, :] \
116 | = new_videos[:, :, random_idx_2, :, :], new_videos[:, :, random_idx_1, :, :]
117 | return new_videos
118 |
119 |
120 | ########################################################################
121 | # Random insertion
122 | # Randomly insert n words into the sentence
123 | ########################################################################
124 |
125 | def random_insertion(videos, n):
126 | new_videos = videos
127 | for _ in range(n):
128 | new_videos = add_picture(new_videos)
129 | return new_videos
130 |
131 |
132 | def add_picture(videos):
133 | b, c, t, h, w = videos.size()
134 | new_videos = videos
135 | random_idx = random.randint(0, t - 1)
136 | random_idx2 = random.randint(0, t - 1)
137 | # this is from the same sample, may be need modify
138 | new_videos[:, :, random_idx+1:, :, :] = videos[:, :, random_idx:t-1, :, :]
139 | new_videos[:, :, random_idx, :, :] = videos[:, :, random_idx2, :, :]
140 | return new_videos
141 |
142 | ########################################################################
143 | # main data augmentation function
144 | ########################################################################
145 |
146 |
147 | class VideoEda(nn.Module):
148 | def __init__(self, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=4):
149 | super(VideoEda, self).__init__()
150 | self.alpha_sr = alpha_sr
151 | self.alpha_ri = alpha_ri
152 | self.alpha_rs = alpha_rs
153 | self.p_rd = p_rd
154 | self.num_aug = num_aug
155 |
156 | def eda(self, inp):
157 | b, c, t, h, w = inp.size()
158 |
159 | augmented_sentences = []
160 | num_new_per_technique = int(self.num_aug / 4) + 1
161 | n_sr = max(1, int(self.alpha_sr * t))
162 | n_ri = max(1, int(self.alpha_ri * t))
163 | n_rs = max(1, int(self.alpha_rs * t))
164 |
165 | # # sr synonym replacement
166 | # for _ in range(num_new_per_technique):
167 | # inp = synonym_replacement(inp, n_sr)
168 |
169 | # ri random insertion
170 | for _ in range(num_new_per_technique):
171 | inp = random_insertion(inp, n_ri)
172 |
173 | # rs random swap
174 | for _ in range(num_new_per_technique):
175 | inp = random_swap(inp, n_rs)
176 |
177 | # # rd random delte
178 | # for _ in range(num_new_per_technique):
179 | # inp = random_deletion(inp, self.p_rd)
180 |
181 | return inp
182 |
183 | def forward(self, x):
184 | return self.eda(x)
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/net_mixup.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class NETMIXUP(object):
5 | def __init__(self, alpha):
6 | self.alpha = alpha
7 |
8 | def gen_prob(self):
9 | lam = np.random.beta(self.alpha, self.alpha)
10 | return lam
11 |
12 | def construct(self, a, b, mixup_radio=0.5):
13 | c = mixup_radio * a + (1 - mixup_radio) * b
14 | return c
15 |
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/noise.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | """
6 | usage
7 | z_rand = generate_noise([1,nzx,nzy], device=opt.device)
8 | z_rand = z_rand.expand(1,3,Z_opt.shape[2],Z_opt.shape[3])
9 | z_prev1 = 0.95*Z_opt +0.05*z_rand
10 | """
11 |
12 |
13 | def upsampling(im, sx, sy):
14 | m = nn.Upsample(size=[round(sx), round(sy)], mode='bilinear', align_corners=True)
15 | return m(im)
16 |
17 |
18 | def generate_noise(size, num_samp=1, device='cuda', type='gaussian', scale=1):
19 | if type == 'gaussian':
20 | noise = torch.randn(num_samp, size[0], round(size[1]/scale), round(size[2]/scale))
21 | noise = upsampling(noise, size[1], size[2])
22 | if type == 'gaussian_mixture':
23 | noise1 = torch.randn(num_samp, size[0], size[1], size[2]) + 5
24 | noise2 = torch.randn(num_samp, size[0], size[1], size[2])
25 | noise = noise1 + noise2
26 | if type == 'uniform':
27 |         noise = torch.rand(num_samp, size[0], size[1], size[2])  # uniform noise in [0, 1)
28 | return noise
29 |
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/rotation.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 |
5 | def batch_rotation(l_new_data):
6 | B, C, T, H, W = l_new_data.size()
7 | rotation_type = random.randint(0, 2)
8 | # print(rotation_type)
9 | if rotation_type == 0:
10 | index = list(range(T - 1, -1, -1))
11 | rotation_data = l_new_data[:, :, index, :, :]
12 | elif rotation_type == 1:
13 | index = list(range(H - 1, -1, -1))
14 | rotation_data = l_new_data[:, :, :, index, :]
15 | else:
16 | index = list(range(W - 1, -1, -1))
17 | rotation_data = l_new_data[:, :, :, :, index]
18 | return rotation_data, rotation_type
19 |
20 |
21 | def sample_rotation(l_new_data, rotation_type, trace=False):
22 | """
23 | her/vec flip (0, 1)
24 | rotation 90/180/270 degree(2, 3, 4)
25 | her flip + rotate90 / ver flip + rotate 90 (5, 6)
26 | :param l_new_data:
27 | :param rotation_type
28 | :return:
29 | """
30 | B, C, T, H, W = l_new_data.size()
31 | if not trace:
32 | rotated_data = torch.zeros_like(l_new_data).cuda()
33 | else:
34 | rotated_data = l_new_data
35 | for i in range(B):
36 | if rotation_type[i] == 0:
37 | rotated_data[i] = l_new_data[i].flip(2)
38 | elif rotation_type[i] == 1:
39 | rotated_data[i] = l_new_data[i].flip(3)
40 | elif rotation_type[i] == 2:
41 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2)
42 | elif rotation_type[i] == 3:
43 | rotated_data[i] = l_new_data[i].flip(2).flip(3)
44 | elif rotation_type[i] == 4:
45 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3)
46 | elif rotation_type[i] == 5:
47 | rotated_data[i] = l_new_data[i].flip(2).transpose(2, 3).flip(2)
48 | elif rotation_type[i] == 6:
49 | rotated_data[i] = l_new_data[i].flip(3).transpose(2, 3).flip(2)
50 | else:
51 | rotated_data[i] = l_new_data[i]
52 | return rotated_data
53 |
54 |
55 | def sample_rotation_cls(l_new_data, rotation_type):
56 | """
57 | her/vec flip (0, 1)
58 | rotation 90/180/270 degree(2, 3, 4)
59 | her flip + rotate90 / ver flip + rotate 90 (5, 6)
60 | :param l_new_data:
61 | :param rotation_type:
62 | :return:
63 | """
64 | B, C, T, H, W = l_new_data.size()
65 | rotated_data = torch.zeros_like(l_new_data).cuda()
66 | for i in range(B):
67 | if rotation_type[i] == 0:
68 | rotated_data[i] = l_new_data[i].flip(2)
69 | elif rotation_type[i] == 1:
70 | rotated_data[i] = l_new_data[i].flip(3)
71 | elif rotation_type[i] == 2:
72 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2)
73 | elif rotation_type[i] == 3:
74 | rotated_data[i] = l_new_data[i].flip(2).flip(3)
75 | elif rotation_type[i] == 4:
76 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3)
77 | elif rotation_type[i] == 5:
78 | rotated_data[i] = l_new_data[i].flip(2).transpose(2, 3).flip(2)
79 | elif rotation_type[i] == 6:
80 | rotated_data[i] = l_new_data[i].flip(3).transpose(2, 3).flip(2)
81 | else:
82 | rotated_data[i] = l_new_data[i]
83 | return rotated_data
84 |
85 |
86 | def four_rotation_cls(l_new_data, rotation_type):
87 | """
88 | rotation 0/90/180/270 degree(0,1,2,3)
89 | :return:
90 | """
91 | B, C, T, H, W = l_new_data.size()
92 | rotated_data = torch.zeros_like(l_new_data).cuda()
93 | for i in range(B):
94 | if rotation_type[i] == 1:
95 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2)
96 | elif rotation_type[i] == 2:
97 | rotated_data[i] = l_new_data[i].flip(2).flip(3)
98 | elif rotation_type[i] == 3:
99 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3)
100 | else:
101 | rotated_data[i] = l_new_data[i]
102 | return rotated_data
103 |
104 |
105 | def all_flips(l_new_data, flip_type, trace=False):
106 | """
107 | her/vec flip (0, 1)
108 | rotation 90/180/270 degree(2, 3, 4)
109 | her flip + rotate90 / ver flip + rotate 90 (5, 6)
110 | :param l_new_data:
111 | :param rotation_type
112 | :return:
113 | """
114 | """
115 | :param l_new_data:
116 | :param rotation_type
117 | :return:
118 | """
119 | B, C, T, H, W = l_new_data.size()
120 | if not trace:
121 | rotated_data = torch.zeros_like(l_new_data).cuda()
122 | else:
123 | rotated_data = l_new_data
124 |     rot_type = flip_type % 4
125 |     flip_type = flip_type // 4
126 | # flip at first
127 | for i in range(B):
128 | if flip_type[i] == 0:
129 | rotated_data[i] = l_new_data[i]
130 | elif flip_type[i] == 1: # left-right flip
131 | rotated_data[i] = l_new_data[i].flip(3)
132 | elif flip_type[i] == 2: # temporal flip
133 | rotated_data[i] = l_new_data[i].flip(1)
134 | else: # left-right + temporal flip
135 | rotated_data[i] = l_new_data[i].flip(3).flip(1)
136 | # then rotation
137 | for i in range(B):
138 | if rot_type[i] == 0:
139 | rotated_data[i] = l_new_data[i]
140 | elif rot_type[i] == 1: # 90 degree
141 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2)
142 | elif rot_type[i] == 2: # 180 degree
143 | rotated_data[i] = l_new_data[i].flip(2).flip(3)
144 | else: # 270 degree
145 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3)
146 |
147 | return rotated_data
148 |
149 |
150 | if __name__ == '__main__':
151 | a = torch.tensor([[1,2],[3,4]]).view(1, 1, 1, 2, 2)
152 | print(a.size())
153 | for i in range(8):
154 | print(torch.tensor(i))
155 | print(sample_rotation_cls(a, torch.tensor([i])))
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/temporal_augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 |
5 | def temporal_augment(l_new_data, augment_type, trace=False):
6 | """
7 | :param l_new_data:
8 | :param rotation_type
9 | :return:
10 | """
11 | B, C, T, H, W = l_new_data.size()
12 | if not trace:
13 | rotated_data = torch.zeros_like(l_new_data).cuda()
14 | else:
15 | rotated_data = l_new_data
16 | flip_type = augment_type // 4
17 | rot_type = augment_type % 4
18 | # flip at first
19 | for i in range(B):
20 | if flip_type[i] == 0:
21 | rotated_data[i] = l_new_data[i]
22 | elif flip_type[i] == 1: # left-right flip
23 | rotated_data[i] = l_new_data[i].flip(3)
24 | elif flip_type[i] == 2: # temporal flip
25 | rotated_data[i] = l_new_data[i].flip(1)
26 | else: # left-right + temporal flip
27 | rotated_data[i] = l_new_data[i].flip(3).flip(1)
28 | # then rotation
29 | for i in range(B):
30 | if rot_type[i] == 0:
31 | rotated_data[i] = l_new_data[i]
32 | elif rot_type[i] == 1: # 90 degree
33 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2)
34 | elif rot_type[i] == 2: # 180 degree
35 | rotated_data[i] = l_new_data[i].flip(2).flip(3)
36 | else: # 270 degree
37 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3)
38 |
39 | return rotated_data
40 |
41 |
42 | def inverse_temporal_augment(l_new_data, augment_type, trace=False):
43 | """
44 | :param l_new_data:
45 | :param rotation_type
46 | :return:
47 | """
48 | B, C, T, H, W = l_new_data.size()
49 | if not trace:
50 | rotated_data = torch.zeros_like(l_new_data).cuda()
51 | else:
52 | rotated_data = l_new_data
53 | flip_type = augment_type // 4
54 | rot_type = augment_type % 4
55 | # flip at first
56 | for i in range(B):
57 | if flip_type[i] == 0:
58 | rotated_data[i] = l_new_data[i]
59 | elif flip_type[i] == 1: # left-right flip
60 | rotated_data[i] = l_new_data[i].flip(3)
61 | elif flip_type[i] == 2: # temporal flip
62 | rotated_data[i] = l_new_data[i].flip(1)
63 | else: # left-right + temporal flip
64 | rotated_data[i] = l_new_data[i].flip(3).flip(1)
65 | # then rotation
66 | for i in range(B):
67 | if rot_type[i] == 0:
68 | rotated_data[i] = l_new_data[i]
69 | elif rot_type[i] == 1: # -90 degree
70 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(3)
71 | elif rot_type[i] == 2: # -180 degree
72 | rotated_data[i] = l_new_data[i].flip(3).flip(2)
73 | else: # -270 degree
74 | rotated_data[i] = l_new_data[i].transpose(2, 3).flip(2)
75 |
76 | return rotated_data
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/temporal_dropout.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | # @Time : 2019-05-13 19:34
5 | # @Author : Awiny
6 | # @Site :
7 | # @Project : pytorch_i3d
8 | # @File : TemporalDropoutBlock.py
9 | # @Software: PyCharm
10 | # @Github : https://github.com/FingerRec
11 | # @Blog : http://fingerrec.github.io
12 | """
13 | import scipy.io
14 | import os
15 | import torch
16 |
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import random
20 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #close the warning
21 |
22 |
23 | class DropBlock2D(nn.Module):
24 | r"""Randomly zeroes 2D spatial blocks of the input tensor.
25 | As described in the paper
26 | `DropBlock: A regularization method for convolutional networks`_ ,
27 | dropping whole blocks of feature map allows to remove semantic
28 | information as compared to regular dropout.
29 | Args:
30 | drop_prob (float): probability of an element to be dropped.
31 | block_size (int): size of the block to drop
32 | Shape:
33 | - Input: `(N, C, H, W)`
34 | - Output: `(N, C, H, W)`
35 | .. _DropBlock: A regularization method for convolutional networks:
36 | https://arxiv.org/abs/1810.12890
37 | """
38 |
39 | def __init__(self, drop_prob, block_size):
40 | super(DropBlock2D, self).__init__()
41 |
42 | self.drop_prob = drop_prob
43 | self.block_size = block_size
44 |
45 | def forward(self, x):
46 | # shape: (bsize, channels, height, width)
47 |
48 | assert x.dim() == 4, \
49 | "Expected input with 4 dimensions (bsize, channels, height, width)"
50 |
51 | if not self.training or self.drop_prob == 0.:
52 | return x
53 | else:
54 | # get gamma value
55 | gamma = self._compute_gamma(x)
56 |
57 | # sample mask
58 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float()
59 |
60 | # place mask on input device
61 | mask = mask.to(x.device)
62 |
63 | # compute block mask
64 | block_mask = self._compute_block_mask(mask)
65 |
66 | # apply block mask
67 | out = x * block_mask[:, None, :, :]
68 |
69 | # scale output
70 | out = out * block_mask.numel() / block_mask.sum()
71 |
72 | return out
73 |
74 | def _compute_block_mask(self, mask):
75 | block_mask = F.max_pool2d(input=mask[:, None, :, :],
76 | kernel_size=(self.block_size, self.block_size),
77 | stride=(1, 1),
78 | padding=self.block_size // 2)
79 |
80 | if self.block_size % 2 == 0:
81 | block_mask = block_mask[:, :, :-1, :-1]
82 |
83 | block_mask = 1 - block_mask.squeeze(1)
84 |
85 | return block_mask
86 |
87 | def _compute_gamma(self, x):
88 | return self.drop_prob / (self.block_size ** 2)
89 |
90 |
91 | class DropBlock3D(DropBlock2D):
92 | r"""Randomly zeroes 3D spatial blocks of the input tensor.
93 | An extension to the concept described in the paper
94 | `DropBlock: A regularization method for convolutional networks`_ ,
95 | dropping whole blocks of feature map allows to remove semantic
96 | information as compared to regular dropout.
97 | Args:
98 | drop_prob (float): probability of an element to be dropped.
99 | block_size (int): size of the block to drop
100 | Shape:
101 | - Input: `(N, C, D, H, W)`
102 | - Output: `(N, C, D, H, W)`
103 | .. _DropBlock: A regularization method for convolutional networks:
104 | https://arxiv.org/abs/1810.12890
105 | """
106 |
107 | def __init__(self, drop_prob, block_size):
108 | super(DropBlock3D, self).__init__(drop_prob, block_size)
109 |
110 | def forward(self, x):
111 | # shape: (bsize, channels, depth, height, width)
112 |
113 | assert x.dim() == 5, \
114 | "Expected input with 5 dimensions (bsize, channels, depth, height, width)"
115 |
116 | if not self.training or self.drop_prob == 0.:
117 | return x
118 | else:
119 | # get gamma value
120 | gamma = self._compute_gamma(x)
121 |
122 | # sample mask
123 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float()
124 |
125 | # place mask on input device
126 | mask = mask.to(x.device)
127 |
128 | # compute block mask
129 | block_mask = self._compute_block_mask(mask)
130 |
131 | # apply block mask
132 | out = x * block_mask[:, None, :, :, :]
133 |
134 | # scale output
135 | out = out * block_mask.numel() / block_mask.sum()
136 |
137 | return out
138 |
139 | def _compute_block_mask(self, mask):
140 | block_mask = F.max_pool3d(input=mask[:, None, :, :, :],
141 | kernel_size=(self.block_size, self.block_size, self.block_size),
142 | stride=(1, 1, 1),
143 | padding=self.block_size // 2)
144 |
145 | if self.block_size % 2 == 0:
146 | block_mask = block_mask[:, :, :-1, :-1, :-1]
147 |
148 | block_mask = 1 - block_mask.squeeze(1)
149 |
150 | return block_mask
151 |
152 | def _compute_gamma(self, x):
153 | return self.drop_prob / (self.block_size ** 3)
154 |
155 |
156 | class TemporalDropoutBlock(nn.Module):
157 | """
158 | method1, for 3d feature map BxCxTxHxW reshape as Bx[CxT]xHxW
159 | """
160 | def __init__(self, dropout_radio):
161 | super(TemporalDropoutBlock, self).__init__()
162 | self.dropout = nn.Dropout(dropout_radio)
163 |
164 | def forward(self, x):
165 | b, c, t, h, w = x.size()
166 | x = x.view(b, c*t, h, w)
167 | x = self.dropout(x)
168 | x = x.view(b, c, t, h, w)
169 | return x
170 |
171 |
172 | class TemporalDropoutBlock3D(nn.Module):
173 | r"""
174 | method2, for 3d feature map BxCxTxHxW, random dropout in T
175 | """
176 |
177 | def __init__(self, drop_prob):
178 | super(TemporalDropoutBlock3D, self).__init__()
179 | self.dropout = nn.Dropout3d(drop_prob)
180 |
181 | def forward(self, x):
182 | x = x.permute(0, 2, 1, 3, 4)
183 | x = self.dropout(x)
184 | x = x.permute(0, 2, 1, 3, 4)
185 | return x
186 |
187 |
188 | class TemporalBranchDropout(nn.Module):
189 | """
190 | Branch dropout
191 | """
192 | def __init__(self, drop_prob):
193 | super(TemporalBranchDropout, self).__init__()
194 | self.dropout = nn.Dropout(1)
195 | self.drop_prob = drop_prob
196 |
197 | def forward(self, x):
198 | prob = random.random()
199 | if prob < self.drop_prob:
200 | x = self.dropout(x)
201 | #print(x)
202 | else:
203 | x = x
204 | return x
205 |
206 |
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/temporal_shuffle.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | # @Time : 2019-05-21 10:37
5 | # @Author : Awiny
6 | # @Site :
7 | # @Project : amax-pytorch-i3d
8 | # @File : TemporalShuffle.py
9 | # @Software: PyCharm
10 | # @Github : https://github.com/FingerRec
11 | # @Blog : http://fingerrec.github.io
12 | """
13 |
14 | import torch
15 | import torch.nn as nn
16 | import os
17 | import difflib
18 |
19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #close the warning
20 |
21 |
22 | class TemporalShuffle(nn.Module):
23 | """
24 | for this module, random shuffle temporal dim, we want to find if the temporal information is important
25 | """
26 | def __init__(self, s=1):
27 | super(TemporalShuffle, self).__init__()
28 | self.s = s
29 |
30 | def forward(self, x):
31 | """
32 | random shuffle temporal dim
33 | :param x: b x c x t x h x w
34 | :return: out: b x c x t' x h x w
35 | """
36 | t = x.size(2)
37 | origin_idx = list(range(t))
38 | idxs = []
39 | K = 4
40 | similarity = 1
41 | # ==================================method1========================
42 | while similarity >= 1:
43 | if self.s == 1:
44 | idxs = torch.randperm(t)
45 | elif self.s == 2:
46 | idx = torch.randperm(K)
47 | for i in range(K):
48 | for j in range(t // K):
49 | idxs.append(idx[i].item() * t // K + j)
50 | else:
51 | for i in range(K):
52 | idx = torch.randperm(t//K)
53 | for j in range(len(idx)):
54 | idxs.append(t//K*i + idx[j].item())
55 | similarity = difflib.SequenceMatcher(None, idxs, origin_idx).ratio()
56 | # print(idxs)
57 | # print(similarity)
58 | out = x[:, :, idxs, :, :]
59 | return out
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/temporal_sub.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | # @Time : 2019-05-12 21:54
5 | # @Author : Awiny
6 | # @Site :
7 | # @Project : amax-pytorch-i3d
8 | # @File : TemporalSubBlock.py
9 | # @Software: PyCharm
10 | # @Github : https://github.com/FingerRec
11 | # @Blog : http://fingerrec.github.io
12 | """
13 | import scipy.io
14 | import os
15 | import torch.nn as nn
16 | import torch
17 |
18 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #close the warning
19 |
20 | class TemporalSubBlock(nn.Module):
21 | def __init__(self):
22 | super(TemporalSubBlock, self).__init__()
23 | def forward(self, x):
24 | b, c, t, h, w = x.size()
25 | y = torch.zeros((b, c, t-1, h, w)).cuda()
26 | for i in range(t-1):
27 | y[:,:,i,:,:] = x[:,:,i+1,:,:] - x[:,:,i,:,:]
28 | return y
29 |
30 | class TemporalSubMeanBlock(nn.Module):
31 | def __init__(self):
32 | super(TemporalSubMeanBlock, self).__init__()
33 |
34 | def forward(self, x):
35 | b, c, t, h, w = x.size()
36 | mean = x.mean(2)
37 | for i in range(t):
38 | x[:,:,i,:,:] -= mean
39 | return x
40 |
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/triplet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import difflib
3 | import random
4 | from augment.basic_augmentation.mixup_methods import SpatialMixup
5 |
6 |
7 | def swap_one_time(seq):
8 | """
9 | swap a seq random place one time
10 | :param seq:
11 | :return:
12 | """
13 | length = len(seq)
14 | index = random.randint(0, length-1)
15 | index2 = random.randint(0, length-1)
16 | new_seq = seq.clone()
17 | new_seq[index] = seq[index2]
18 | new_seq[index2] = seq[index]
19 | return new_seq
20 |
21 |
22 | def gen_sim_seq(K, radio=0.95, segments=4):
23 | """
24 | generate shuffled video sequences as negative, while random shuffle is always zero, control the segments
25 | (eg. divide into 4 segments and shuffle these segments)
26 | :param K:
27 | :param radio:
28 | :return:
29 | """
30 | similarity = 1
31 | idx = torch.arange(K)
32 | assert K % segments == 0
33 | seg_len = K // segments
34 | origin_idx = torch.arange(K).long()
35 | # revise_idx = torch.tensor(list(range(K - 1, -1, -1)))
36 | # print(revise_idx)
37 | while similarity > radio:
38 | # seg_idx = torch.randperm(segments)
39 | # for i in range(segments):
40 | # idx[i*seg_len:(i+1)*seg_len] = origin_idx[seg_idx[i]*seg_len:(seg_idx[i]+1)*seg_len]
41 | idx = swap_one_time(idx)
42 | similarity = difflib.SequenceMatcher(None, idx, origin_idx).ratio()
43 | # print(idx)
44 | # print(origin_idx)
45 | # print(similarity)
46 | return idx
47 |
48 |
49 | def batch_lst(lst, k):
50 | return lst[k:] + lst[:k]
51 |
52 |
53 | class TRIPLET(object):
54 | def __init__(self, t_radio=0.95, s_radio=0.7):
55 | self.t_radio = t_radio
56 | self.s_radio = s_radio
57 | self.spaital_mixup = SpatialMixup(0.8)
58 |
59 | def construct(self, input):
60 | b, c, t, h, w = input.size()
61 | #spatial_noise = generate_noise((c, h, w), b)
62 | # print(sum(sum(sum(sum(spatial_noise)))))
63 | # print(sum(sum(sum(sum(sum(input))))))
64 | # postive = torch.zeros_like(input)
65 | # drop_radio = random.random() * self.s_radio
66 | postive = self.spaital_mixup.mixup_data(input, trace=False)
67 | # for i in range(t):
68 | # postive[:, :, i, :, :] = (1 - drop_radio) * input[:, :, i, :, :] + drop_radio * spatial_noise[:, :, :, :]
69 | # if the match low than radio, as it negative
70 | # negative_seq = gen_sim_seq(t, self.radio)
71 | # negative = torch.zeros_like(input)
72 | # negative = negative[:, :, negative_seq, :, :]
73 | index = random.randint(1, b-1)
74 | indexs = batch_lst(list(range(b)), index)
75 | negative = input[indexs]
76 | # print(negative - input)
77 | return input, postive, negative
--------------------------------------------------------------------------------
/src/Contrastive/augment/basic_augmentation/video_color_jitter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 | class ColorJitter(object):
5 |
6 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
7 | self.brightness = brightness
8 | self.contrast = contrast
9 | self.saturation = saturation
10 |
11 | def __call__(self, imgs):
12 | t, h, w, c = imgs.shape
13 | #print(t, h, w, c)
14 | self.transforms = []
15 | if self.brightness != 0:
16 | self.transforms.append(Brightness(self.brightness))
17 | if self.contrast != 0:
18 | self.transforms.append(Contrast(self.contrast))
19 | if self.saturation != 0:
20 | self.transforms.append(Saturation(self.saturation))
21 |
22 | random.shuffle(self.transforms)
23 | transform = Compose(self.transforms)
24 | # print(transform)
25 | for i in range(t):
26 | imgs[i, :, :, :] = transform(imgs[i, :, :, :])
27 | return imgs
28 |
29 | class Saturation(object):
30 |
31 | def __init__(self, var):
32 | self.var = var
33 |
34 | def __call__(self, img):
35 | gs = Grayscale()(img)
36 | alpha = random.uniform(-self.var, self.var)
37 | # return img.lerp(gs, alpha)
38 | cover_img = img
39 | for i in range(3):
40 | cover_img[:,:,i] = (1-alpha) * img[:,:,i] + alpha * gs
41 | return cover_img
42 |
43 |
44 | class Brightness(object):
45 |
46 | def __init__(self, var):
47 | self.var = var
48 |
49 | def __call__(self, img):
50 | # gs = img.new().resize_as_(img).zero_()
51 | alpha = random.uniform(-self.var, self.var)
52 |         return (1 - alpha) * img  # equivalent to img.lerp(zeros, alpha), cf. the commented lines
53 | # return img.lerp(gs, alpha)
54 |
55 |
56 | class Contrast(object):
57 |
58 | def __init__(self, var):
59 | self.var = var
60 |
61 | def __call__(self, img):
62 | # gs = Grayscale()(img)
63 | # gs.fill_(gs.mean())
64 | # alpha = random.uniform(-self.var, self.var)
65 | # return img.lerp(gs, alpha)
66 | return np.mean(img) + self.var * (img- np.mean(img))
67 |
68 | class Grayscale(object):
69 |
70 | def __call__(self, img):
71 | # gs = img.clone()
72 | # gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2])
73 | # gs[1].copy_(gs[0])
74 | # gs[2].copy_(gs[0])
75 | gs = np.dot(img[...,:3], [0.2989, 0.5870, 0.1140])
76 | return gs
77 |
78 |
79 | class Compose(object):
80 | """Composes several transforms together.
81 | Args:
82 | transforms (list of ``Transform`` objects): list of transforms to compose.
83 | Example:
84 | >>> transforms.Compose([
85 | >>> transforms.CenterCrop(10),
86 | >>> transforms.ToTensor(),
87 | >>> ])
88 | """
89 |
90 | def __init__(self, transforms):
91 | self.transforms = transforms
92 |
93 | def __call__(self, img):
94 | for t in self.transforms:
95 | img = t(img)
96 | return img
97 |
98 | def __repr__(self):
99 | format_string = self.__class__.__name__ + '('
100 | for t in self.transforms:
101 | format_string += '\n'
102 | format_string += ' {0}'.format(t)
103 | format_string += '\n)'
104 | return format_string
105 |
--------------------------------------------------------------------------------
/src/Contrastive/augment/config.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 |
4 |
5 | class TC(nn.Module):
6 | def __init__(self, args):
7 | super(TC, self).__init__()
8 | self.args = args
9 |
10 | def forward(self, input):
11 | output = input.cuda()
12 | # output = self.mixup.mixup_data(output)
13 | output = torch.autograd.Variable(output)
14 | return output
15 |
--------------------------------------------------------------------------------
/src/Contrastive/augment/gen_negative.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import numpy as np
3 | import torch
4 | from .basic_augmentation.temporal_shuffle import TemporalShuffle
5 | from .basic_augmentation.temporal_dropout import TemporalDropoutBlock3D
6 | from .basic_augmentation.eda import VideoEda
7 |
8 |
9 | class GenNegative(nn.Module):
10 | def __init__(self, prob=0.3):
11 | super(GenNegative, self).__init__()
12 | self.prob = prob
13 | self.t_shuffle = TemporalShuffle()
14 | self.t_drop = TemporalDropoutBlock3D(0.1)
15 | self.t_eda = VideoEda()
16 |
17 | def temporal_dropout(self, x):
18 | return self.t_drop(x)
19 |
20 | def temporal_shuffle(self, x):
21 | return self.t_shuffle(x)
22 |
23 | def temporal_eda(self, x):
24 | return self.t_eda(x)
25 |
26 | def forward(self, x):
27 | # x = self.temporal_shuffle(x)
28 | # x = self.temporal_dropout(x)
29 | x = self.temporal_eda(x)
30 | return x
--------------------------------------------------------------------------------
/src/Contrastive/augment/gen_positive.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from augment.basic_augmentation.mixup_methods import *
3 | from augment.video_transformations.videotransforms import ColorJitter
4 |
5 |
6 | class GenPositive(nn.Module):
7 | def __init__(self, prob=0.3):
8 | super(GenPositive, self).__init__()
9 | self.iv_mixup = SpatialMixup(0.3, trace=False, version=2)
10 | self.im_mixup = SpatialMixup(0.3, trace=False, version=3)
11 | self.cut = Cut(1, 0.05)
12 | self.prob = prob
13 |
14 | def intra_video_mixup(self, x):
15 | return self.iv_mixup.mixup_data(x)
16 |
17 | def inter_video_mixup(self, x):
18 | return self.im_mixup.mixup_data(x)
19 |
20 | def video_cut(self, x):
21 | return self.cut.cut_data(x)
22 |
23 | def forward(self, x):
24 | # x = self.inter_video_mixup(x)
25 | x = self.intra_video_mixup(x)
26 | return x
27 |
--------------------------------------------------------------------------------
/src/Contrastive/augment/video_transformations/functional.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import torch
3 | import cv2  # needed by resize_clip for numpy.ndarray inputs
4 | import numpy as np
5 | import PIL
6 |
7 |
8 | def _is_tensor_clip(clip):
9 | return torch.is_tensor(clip) and clip.ndimension() == 4
10 |
11 |
12 | def crop_clip(clip, min_h, min_w, h, w):
13 | if isinstance(clip[0], np.ndarray):
14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
15 |
16 | elif isinstance(clip[0], PIL.Image.Image):
17 | cropped = [
18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
19 | ]
20 | else:
21 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
22 | 'but got list of {0}'.format(type(clip[0])))
23 | return cropped
24 |
25 |
26 | def resize_clip(clip, size, interpolation='bilinear'):
27 | if isinstance(clip[0], np.ndarray):
28 | if isinstance(size, numbers.Number):
29 | im_h, im_w, im_c = clip[0].shape
30 | # Min spatial dim already matches minimal size
31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
32 | and im_h == size):
33 | return clip
34 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
35 | size = (new_w, new_h)
36 | else:
37 | size = size[1], size[0]
38 | if interpolation == 'bilinear':
39 | np_inter = cv2.INTER_LINEAR
40 | else:
41 | np_inter = cv2.INTER_NEAREST
42 | scaled = [
43 | cv2.resize(img, size, interpolation=np_inter) for img in clip
44 | ]
45 | elif isinstance(clip[0], PIL.Image.Image):
46 | if isinstance(size, numbers.Number):
47 | im_w, im_h = clip[0].size
48 | # Min spatial dim already matches minimal size
49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
50 | and im_h == size):
51 | return clip
52 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
53 | size = (new_w, new_h)
54 | else:
55 | size = size[1], size[0]
56 |             if interpolation == 'bilinear':
57 |                 pil_inter = PIL.Image.BILINEAR
58 |             else:
59 |                 pil_inter = PIL.Image.NEAREST
60 | scaled = [img.resize(size, pil_inter) for img in clip]
61 | else:
62 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
63 | 'but got list of {0}'.format(type(clip[0])))
64 | return scaled
65 |
66 |
67 | def get_resize_sizes(im_h, im_w, size):
68 | if im_w < im_h:
69 | ow = size
70 | oh = int(size * im_h / im_w)
71 | else:
72 | oh = size
73 | ow = int(size * im_w / im_h)
74 | return oh, ow
75 |
76 |
77 | def normalize(clip, mean, std, inplace=False):
78 | if not _is_tensor_clip(clip):
79 | raise TypeError('tensor is not a torch clip.')
80 |
81 | if not inplace:
82 | clip = clip.clone()
83 |
84 | dtype = clip.dtype
85 | dim = len(mean)
86 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
87 | std = torch.as_tensor(std, dtype=dtype, device=clip.device)
88 | # print(clip.size())
89 | # if dim == 3:
90 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
91 | # else:
92 | # clip.sub_(mean[:, None, None]).div_(std[:, None, None])
93 | return clip
--------------------------------------------------------------------------------
/src/Contrastive/augment/video_transformations/volume_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 |
5 | def convert_img(img):
6 |     """Converts (H, W, C) numpy.ndarray to (C, H, W) format
7 | """
8 | if len(img.shape) == 3:
9 | img = img.transpose(2, 0, 1)
10 | if len(img.shape) == 2:
11 | img = np.expand_dims(img, 0)
12 | return img
13 |
14 |
15 | class ClipToTensor(object):
16 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
17 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
18 | """
19 |
20 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
21 | self.channel_nb = channel_nb
22 | self.div_255 = div_255
23 | self.numpy = numpy
24 |
25 | def __call__(self, clip):
26 | """
27 | Args: clip (list of numpy.ndarray): clip (list of images)
28 | to be converted to tensor.
29 | """
30 | # Retrieve shape
31 | if isinstance(clip[0], np.ndarray):
32 | h, w, ch = clip[0].shape
33 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
34 | ch)
35 | elif isinstance(clip[0], Image.Image):
36 | w, h = clip[0].size
37 | else:
38 | raise TypeError('Expected numpy.ndarray or PIL.Image\
39 | but got list of {0}'.format(type(clip[0])))
40 |
41 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
42 |
43 | # Convert
44 | for img_idx, img in enumerate(clip):
45 | if isinstance(img, np.ndarray):
46 | pass
47 | elif isinstance(img, Image.Image):
48 | img = np.array(img, copy=False)
49 | else:
50 | raise TypeError('Expected numpy.ndarray or PIL.Image\
51 | but got list of {0}'.format(type(clip[0])))
52 | img = convert_img(img)
53 | np_clip[:, img_idx, :, :] = img
54 | if self.numpy:
55 | if self.div_255:
56 | np_clip = np_clip / 255
57 | return np_clip
58 |
59 | else:
60 | tensor_clip = torch.from_numpy(np_clip)
61 |
62 | if not isinstance(tensor_clip, torch.FloatTensor):
63 | tensor_clip = tensor_clip.float()
64 | if self.div_255:
65 | tensor_clip = tensor_clip.div(255)
66 | return tensor_clip
67 |
68 |
69 | class ToTensor(object):
70 | """Converts numpy array to tensor
71 | """
72 |
73 | def __call__(self, array):
74 | tensor = torch.from_numpy(array)
75 | return tensor
--------------------------------------------------------------------------------
/src/Contrastive/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/data/__init__.py
--------------------------------------------------------------------------------
/src/Contrastive/data/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 |
4 |
5 | # from config import opt
6 | def video_to_tensor(pic):
7 | """Convert a ``numpy.ndarray`` to tensor.
8 | Converts a numpy.ndarray (T x H x W x C)
9 | to a torch.FloatTensor of shape (C x T x H x W)
10 |
11 | Args:
12 | pic (numpy.ndarray): Video to be converted to tensor.
13 | Returns:
14 | Tensor: Converted video.
15 | """
16 | # return torch.from_numpy(pic)
17 | return torch.from_numpy(pic.transpose([3, 0, 1, 2])).type(torch.FloatTensor)
18 |
19 |
20 | class VideoRecord(object):
21 | def __init__(self, row):
22 | self._data = row
23 |
24 | @property
25 | def path(self):
26 | return self._data[0]
27 |
28 | @property
29 | def num_frames(self):
30 | return int(self._data[1]) - 1
31 |
32 | @property
33 | def label(self):
34 | return int(self._data[2])
35 |
36 |
37 | def video_frame_count(video_path):
38 | cap = cv2.VideoCapture(video_path)
39 | if not cap.isOpened():
40 | #print("could not open: ", video_path)
41 | return -1
42 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
43 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) )
44 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) )
45 | return length, width, height
46 |
47 | #
48 | # if __name__ == '__main__':
49 | # print("")
--------------------------------------------------------------------------------
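
A quick check of video_to_tensor above (synthetic data): the channel axis is moved in front of time.

    import numpy as np
    from data.base import video_to_tensor

    clip = np.random.rand(16, 112, 112, 3).astype(np.float32)   # T x H x W x C
    tensor = video_to_tensor(clip)
    print(tensor.shape)   # torch.Size([3, 16, 112, 112]), dtype torch.float32
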
/src/Contrastive/data/config.py:
--------------------------------------------------------------------------------
1 | import augment.video_transformations.videotransforms as videotransforms
2 | import augment.video_transformations.video_transform_PIL_or_np as video_transform
3 | from augment.video_transformations.volume_transforms import ClipToTensor
4 |
5 | from torchvision import transforms
6 |
7 |
8 | def pt_data_config(args):
9 | if args.pt_dataset == 'ucf101':
10 | num_class = 101
11 | image_tmpl = "frame{:06d}.jpg"
12 | elif args.pt_dataset == 'hmdb51':
13 | num_class = 51
14 | # image_tmpl = "frame{:06d}.jpg"
15 | image_tmpl = "img_{:05d}.jpg"
16 | # image_tmpl = "image_{:05d}.jpg"
17 | elif args.pt_dataset == 'kinetics':
18 | num_class = 400
19 | image_tmpl = "img_{:05d}.jpg"
20 | # args.root = "/data1/DataSet/Kinetics/compress/"
21 | elif args.pt_dataset == 'sth_v1':
22 | num_class = 174
23 | image_tmpl = "{:05d}.jpg"
24 | elif args.pt_dataset == 'diving48':
25 | num_class = 48
26 | image_tmpl = "image_{:05d}.jpg"
27 | else:
28 | raise ValueError('Unknown dataset ' + args.pt_dataset)
29 | return num_class, int(args.pt_data_length), image_tmpl
30 |
31 |
32 | def ft_data_config(args):
33 | if args.ft_dataset == 'ucf101':
34 | num_class = 101
35 | image_tmpl = "frame{:06d}.jpg"
36 | elif args.ft_dataset == 'hmdb51':
37 | num_class = 51
38 | # image_tmpl = "frame{:06d}.jpg"
39 | image_tmpl = "img_{:05d}.jpg"
40 | # image_tmpl = "image_{:05d}.jpg"
41 | elif args.ft_dataset == 'kinetics':
42 | num_class = 400
43 | image_tmpl = "img_{:05d}.jpg"
44 | # args.root = "/data1/DataSet/Kinetics/compress/"
45 | elif args.ft_dataset == 'sth_v1':
46 | num_class = 174
47 | image_tmpl = "{:05d}.jpg"
48 | elif args.ft_dataset == 'diving48':
49 | num_class = 48
50 | image_tmpl = "image_{:05d}.jpg"
51 | else:
52 | raise ValueError('Unknown dataset ' + args.ft_dataset)
53 | return num_class, int(args.ft_data_length), image_tmpl
54 |
55 |
56 | def pt_augmentation_config(args):
57 | if int(args.pt_spatial_size) == 112:
58 | # print("??????????????????????????")
59 | resize_size = 128
60 | else:
61 | resize_size = 256
62 | if args.pt_mode == 'rgb':
63 | normalize = video_transform.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
64 | else:
65 | normalize = video_transform.Normalize(mean=[0.485, 0.456], std=[0.229, 0.224])
66 | train_transforms = transforms.Compose([
67 | # videotransforms.RandomCrop(int(args.spatial_size)),
68 | video_transform.RandomRotation(10),
69 | # video_transform.ColorDistortion(1),
70 | # video_transform.STA_RandomRotation(10),
71 | # video_transform.Each_RandomRotation(10),
72 | video_transform.Resize(resize_size),
73 | video_transform.RandomCrop(int(args.pt_spatial_size)),
74 | video_transform.ColorJitter(0.5, 0.5, 0.25, 0.5),
75 | ClipToTensor(channel_nb=3 if args.pt_mode == 'rgb' else 2),
76 | normalize
77 | # videotransforms.ColorJitter(),
78 | # videotransforms.RandomHorizontalFlip()
79 | ])
80 | test_transforms = transforms.Compose([
81 | video_transform.Resize(resize_size),
82 | video_transform.CenterCrop(int(args.pt_spatial_size)),
83 | ClipToTensor(channel_nb=3 if args.pt_mode == 'rgb' else 2),
84 | normalize
85 | ]
86 | )
87 | eval_transforms = transforms.Compose([
88 | video_transform.Resize(resize_size),
89 | video_transform.CenterCrop(int(args.pt_spatial_size)),
90 | ClipToTensor(channel_nb=3 if args.pt_mode == 'rgb' else 2),
91 | normalize
92 | ]
93 | )
94 | return train_transforms, test_transforms, eval_transforms
95 |
96 |
97 | def ft_augmentation_config(args):
98 | if int(args.ft_spatial_size) == 112:
99 | resize_size = 128
100 | else:
101 | resize_size = 256
102 | if args.ft_mode == 'rgb':
103 | normalize = video_transform.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
104 | else:
105 | normalize = video_transform.Normalize(mean=[0.485, 0.456], std=[0.229, 0.224])
106 | train_transforms = transforms.Compose([
107 | video_transform.RandomRotation(10),
108 | video_transform.Resize(resize_size),
109 | video_transform.RandomCrop(int(args.ft_spatial_size)),
110 | video_transform.ColorJitter(0.5, 0.5, 0.25, 0.5),
111 | ClipToTensor(channel_nb=3 if args.ft_mode == 'rgb' else 2),
112 | normalize
113 | ])
114 | test_transforms = transforms.Compose([
115 | video_transform.Resize(resize_size),
116 | video_transform.CenterCrop(int(args.ft_spatial_size)),
117 | ClipToTensor(channel_nb=3 if args.ft_mode == 'rgb' else 2),
118 | normalize
119 | ]
120 | )
121 | eval_transforms = transforms.Compose([
122 | video_transform.Resize(resize_size),
123 | video_transform.CenterCrop(int(args.ft_spatial_size)),
124 | ClipToTensor(channel_nb=3 if args.ft_mode == 'rgb' else 2),
125 | normalize
126 | ]
127 | )
128 | return train_transforms, test_transforms, eval_transforms
129 |
--------------------------------------------------------------------------------
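
A sketch of how the pre-training config helpers above compose; the Namespace fields mirror the attributes read by these functions and are otherwise hypothetical stand-ins for option.py:

    from argparse import Namespace
    from data.config import pt_data_config, pt_augmentation_config

    args = Namespace(pt_dataset='ucf101', pt_data_length=16,
                     pt_spatial_size=112, pt_mode='rgb')
    num_class, data_length, image_tmpl = pt_data_config(args)
    train_tf, test_tf, eval_tf = pt_augmentation_config(args)
    print(num_class, data_length, image_tmpl)   # 101 16 frame{:06d}.jpg
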
/src/Contrastive/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def pt_data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms):
5 | if args.pt_dataset in ['ucf101', 'hmdb51', 'diving48', 'sth_v1']:
6 | from data.dataset import DataSet as DataSet
7 | elif args.pt_dataset == 'kinetics':
8 | from data.video_dataset import VideoDataSet as DataSet
9 | else:
10 | raise Exception("unsupported dataset")
11 | train_dataset = DataSet(args, args.pt_root, args.pt_train_list, num_segments=1, new_length=data_length,
12 | stride=args.pt_stride, modality=args.pt_mode, dataset=args.pt_dataset, test_mode=False,
13 | image_tmpl=image_tmpl if args.pt_mode in ["rgb", "RGBDiff"]
14 | else args.pt_flow_prefix + "{}_{:05d}.jpg", transform=train_transforms)
15 | print("training samples:{}".format(train_dataset.__len__()))
16 | train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.pt_batch_size, shuffle=True,
17 | num_workers=args.pt_workers, pin_memory=True)
18 | val_dataset = DataSet(args, args.pt_root, args.pt_val_list, num_segments=1, new_length=data_length,
19 | stride=args.pt_stride, modality=args.pt_mode, test_mode=True, dataset=args.pt_dataset,
20 | image_tmpl=image_tmpl if args.pt_mode in ["rgb", "RGBDiff"] else args.pt_flow_prefix + "{}_{:05d}.jpg",
21 | random_shift=False, transform=test_transforms)
22 | print("val samples:{}".format(val_dataset.__len__()))
23 | val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.pt_batch_size, shuffle=False,
24 | num_workers=args.pt_workers, pin_memory=True)
25 | eval_dataset = DataSet(args, args.pt_root, args.pt_val_list, num_segments=1, new_length=data_length,
26 | stride=args.pt_stride, modality=args.pt_mode, test_mode=True, dataset=args.pt_dataset,
27 | image_tmpl=image_tmpl if args.pt_mode in ["rgb", "RGBDiff"] else args.pt_flow_prefix + "{}_{:05d}.jpg",
28 | random_shift=False, transform=eval_transforms, full_video=True)
29 | print("eval samples:{}".format(eval_dataset.__len__()))
30 | eval_data_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.pt_batch_size, shuffle=False,
31 | num_workers=args.pt_workers, pin_memory=True)
32 | return train_data_loader, val_data_loader, eval_data_loader, train_dataset.__len__(), val_dataset.__len__(), eval_dataset.__len__()
33 |
34 |
35 | def ft_data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms):
36 | if args.ft_dataset in ['ucf101', 'hmdb51', 'diving48', 'sth_v1']:
37 | from data.dataset import DataSet as DataSet
38 | elif args.ft_dataset == 'kinetics':
39 | from data.video_dataset import VideoDataSet as DataSet
40 | else:
41 | raise Exception("unsupported dataset")
42 | train_dataset = DataSet(args, args.ft_root, args.ft_train_list, num_segments=1, new_length=data_length,
43 | stride=args.ft_stride, modality=args.ft_mode, dataset=args.ft_dataset, test_mode=False,
44 | image_tmpl=image_tmpl if args.ft_mode in ["rgb", "RGBDiff"]
45 | else args.ft_flow_prefix + "{}_{:05d}.jpg", transform=train_transforms)
46 | print("training samples:{}".format(train_dataset.__len__()))
47 | train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.ft_batch_size, shuffle=True,
48 | num_workers=args.ft_workers, pin_memory=True)
49 | val_dataset = DataSet(args, args.ft_root, args.ft_val_list, num_segments=1, new_length=data_length,
50 | stride=args.ft_stride, modality=args.ft_mode, test_mode=True, dataset=args.ft_dataset,
51 | image_tmpl=image_tmpl if args.ft_mode in ["rgb", "RGBDiff"] else args.ft_flow_prefix + "{}_{:05d}.jpg",
52 | random_shift=False, transform=test_transforms)
53 | print("val samples:{}".format(val_dataset.__len__()))
54 | val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.ft_batch_size, shuffle=False,
55 | num_workers=args.ft_workers, pin_memory=True)
56 | eval_dataset = DataSet(args, args.ft_root, args.ft_val_list, num_segments=1, new_length=data_length,
57 | stride=args.ft_stride, modality=args.ft_mode, test_mode=True, dataset=args.ft_dataset,
58 | image_tmpl=image_tmpl if args.ft_mode in ["rgb", "RGBDiff"] else args.ft_flow_prefix + "{}_{:05d}.jpg",
59 | random_shift=False, transform=eval_transforms, full_video=True)
60 | print("eval samples:{}".format(eval_dataset.__len__()))
61 | eval_data_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.ft_batch_size, shuffle=False,
62 | num_workers=args.ft_workers, pin_memory=True)
63 | return train_data_loader, val_data_loader, eval_data_loader, train_dataset.__len__(), val_dataset.__len__(), eval_dataset.__len__()
64 |
--------------------------------------------------------------------------------
/src/Contrastive/data/on_the_fly_test.py:
--------------------------------------------------------------------------------
1 | import lintel
2 | import numpy as np
3 | import cv2
4 | import skvideo.io
5 |
6 | def video_frame_count(video_path):
7 | cap = cv2.VideoCapture(video_path)
8 | if not cap.isOpened():
9 | #print("could not open: ", video_path)
10 | return -1
11 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
12 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) )
13 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) )
14 | return length, width, height
15 |
16 | v_path = "/data1/DataSet/Kinetics/compress/val_256/bee_keeping/p3_wp0Cq6Lo.mp4"
17 | video_frames_num, width, height = video_frame_count(v_path)
18 |
19 | # video, width, height, seek_index = lintel.loadvid(open(v_path, 'rb').read(), should_random_seek=False)
20 | # video = np.reshape(np.frombuffer(video, dtype=np.uint8), (-1, height, width, 3))
21 | # num_frames = video.shape
22 | # print(video_frames_num, num_frames)
23 | #
24 | # videodata = skvideo.io.vread(v_path, inputdict={'-r': '4'})
25 | # print(videodata.shape)
26 | print(video_frames_num)
27 | f = open(v_path, 'rb')
28 | video = f.read()
29 | f.close()
30 |
31 | # ffmpeg's frame count is one less than cv2's
32 | frame_nums = [0,201,280, 295, 296]
33 | decoded_frames = lintel.loadvid_frame_nums(video,
34 | frame_nums=frame_nums,
35 | width=width,
36 | height=height)
37 | decoded_frames = np.frombuffer(decoded_frames, dtype=np.uint8)
38 | decoded_frames = np.reshape(
39 | decoded_frames,
40 | newshape=(len(frame_nums), height, width, 3))
41 |
42 | print(np.shape(decoded_frames)[0])
43 |
44 | # The Kinetics dataset in torchvision can be adapted easily to also return audio, but the starting index cannot be constrained
45 | # self.video_clips = VideoClips(
46 | # video_list,
47 | # frames_per_clip,
48 | # step_between_clips,
49 | # frame_rate,
50 | # _precomputed_metadata,
51 | # num_workers=num_workers,
52 | # _video_width=_video_width,
53 | # _video_height=_video_height,
54 | # _video_min_dimension=_video_min_dimension,
55 | # _audio_samples=_audio_samples,
56 | # _audio_channels=_audio_channels,
57 | # )
--------------------------------------------------------------------------------
/src/Contrastive/feature_extractor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import torch.backends.cudnn as cudnn
5 | from data.config import data_config, augmentation_pretext_config
6 | from data.dataloader import data_loader_init
7 | from model.config import pretext_model_config
8 | from augment.config import TC
9 | from bk.option_old import args
10 | import torch.nn as nn
11 |
12 |
13 | def single_extract(tc, val_loader, model):
14 | model.eval()
15 | features = {'data':[], 'target':[]}
16 | with torch.no_grad():
17 | for i, (input, target, index) in enumerate(val_loader):
18 | inputs = input
19 | # inputs = tc(input)
20 | output = model(inputs)
21 | # print(output.size())
22 | output = nn.AdaptiveAvgPool3d(1)(output).view(output.size(0), output.size(1))
23 | # print(output.size())
24 | # print(target)
25 | for j in range(output.size(0)):
26 | features['data'].append(output[j].cpu().numpy())
27 | features['target'].append(target[j].cpu().numpy())
28 | if i % 10 == 0:
29 | print("{}/{} finished".format(i, len(val_loader)))
30 | return features
31 |
32 |
33 | def feature_extract(tc, data_loader, model):
34 | features = single_extract(tc, data_loader, model)
35 | return features
36 |
37 |
38 | def main():
39 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning
40 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
41 | cudnn.benchmark = True
42 | # == dataset config==
43 | num_class, data_length, image_tmpl = data_config(args)
44 | train_transforms, test_transforms, _ = augmentation_pretext_config(args)
45 | train_data_loader, val_data_loader, _, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms,
46 | test_transforms, _)
47 | # == model config==
48 | model = pretext_model_config(args, num_class)
49 | tc = TC(args)
50 | # front = "contrastive_kinetics_warpping_{}".format(args.dataset)
51 | # front = "triplet_ucf101_warpping_{}".format(args.dataset)
52 | front = "{}_{}_{}".format(args.arch, args.front, args.dataset)
53 | dir = '../experiments/features/{}'.format(front)
54 | if not os.path.exists(dir):
55 | os.makedirs(dir)
56 | features = feature_extract(tc, val_data_loader, model)
57 | np.save('{}/val_features.npy'.format(dir), features)
58 | features = feature_extract(tc, train_data_loader, model)
59 | np.save('{}/train_features.npy'.format(dir), features)
60 |
61 |
62 | if __name__ == '__main__':
63 | main()
64 |
--------------------------------------------------------------------------------
/src/Contrastive/ft.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import time
5 | from utils.utils import Timer
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | from utils.utils import AverageMeter
9 | from data.config import ft_data_config, ft_augmentation_config
10 | from data.dataloader import ft_data_loader_init
11 | from model.config import ft_model_config
12 | from loss.config import ft_optim_init
13 | from utils.learning_rate_adjust import ft_adjust_learning_rate
14 | from augment.config import TC
15 | from utils.utils import accuracy
16 | import random
17 | from datetime import datetime
18 |
19 | lowest_val_loss = float('inf')
20 | best_prec1 = 0
21 | torch.manual_seed(1)
22 |
23 |
24 | def fine_tune_train_and_val(args, recorder):
25 | # =
26 | global lowest_val_loss, best_prec1
27 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning
28 | torch.manual_seed(1)
29 | cudnn.benchmark = True
30 | timer = Timer()
31 | # == dataset config==
32 | num_class, data_length, image_tmpl = ft_data_config(args)
33 | train_transforms, test_transforms, eval_transforms = ft_augmentation_config(args)
34 | train_data_loader, val_data_loader, _, _, _, _ = ft_data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms)
35 | # == model config==
36 | model = ft_model_config(args, num_class)
37 | recorder.record_message('a', '='*100)
38 | recorder.record_message('a', '-'*40+'finetune'+'-'*40)
39 | recorder.record_message('a', '='*100)
40 | # == optim config==
41 | train_criterion, val_criterion, optimizer = ft_optim_init(args, model)
42 | # == data augmentation(self-supervised) config==
43 | tc = TC(args)
44 | # == train and eval==
45 | print('*'*70+'Step2: fine tune'+'*'*50)
46 | for epoch in range(args.ft_start_epoch, args.ft_epochs):
47 | timer.tic()
48 | ft_adjust_learning_rate(optimizer, args.ft_lr, epoch, args.ft_lr_steps)
49 | train_prec1, train_loss = train(args, tc, train_data_loader, model, train_criterion, optimizer, epoch, recorder)
50 | # train_prec1, train_loss = random.random() * 100, random.random()
51 | recorder.record_ft_train(train_loss / 5.0, train_prec1 / 100.0)
52 | if (epoch + 1) % args.ft_eval_freq == 0:
53 | val_prec1, val_loss = validate(args, tc, val_data_loader, model, val_criterion, recorder)
54 | # val_prec1, val_loss = random.random() * 100, random.random()
55 | recorder.record_ft_val(val_loss / 5.0, val_prec1 / 100.0)
56 | is_best = val_prec1 > best_prec1
57 | best_prec1 = max(val_prec1, best_prec1)
58 | checkpoint = {'epoch': epoch + 1, 'arch': "i3d", 'state_dict': model.state_dict(),
59 | 'best_prec1': best_prec1}
60 | recorder.save_ft_model(checkpoint, is_best)
61 | timer.toc()
62 | left_time = timer.average_time * (args.ft_epochs - epoch)
63 | message = "Step2: fine tune best_prec1 is: {} left time is : {} now is : {}".format(best_prec1, timer.format(left_time), datetime.now())
64 | print(message)
65 | recorder.record_message('a', message)
66 | return recorder.filename
67 |
68 |
69 | def train(args, tc, train_loader, model, criterion, optimizer, epoch, recorder, MoCo_init=False):
70 | batch_time = AverageMeter()
71 | data_time = AverageMeter()
72 | losses = AverageMeter()
73 | top1 = AverageMeter()
74 | top3 = AverageMeter()
75 | if MoCo_init:
76 | model.eval()
77 | else:
78 | model.train()
79 | end = time.time()
80 | for i, (input, target, index) in enumerate(train_loader):
81 | data_time.update(time.time() - end)
82 | target = target.cuda()
83 | index = index.cuda()
84 | inputs = tc(input)
85 | target = torch.autograd.Variable(target)
86 | output = model(inputs)
87 | loss = criterion(output, target) # + mse_loss
88 | prec1, prec3 = accuracy(output.data, target, topk=(1, 3))
89 | losses.update(loss.data.item(), input.size(0))
90 | top1.update(prec1.item(), input.size(0))
91 | top3.update(prec3.item(), input.size(0))
92 |
93 | optimizer.zero_grad()
94 | loss.backward()
95 | # # gradient check
96 | # plot_grad_flow(model.module.base_model.named_parameters())
97 | optimizer.step()
98 | batch_time.update(time.time() - end)
99 | end = time.time()
100 |
101 | if i % args.ft_print_freq == 0:
102 | message = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
103 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
104 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
105 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
106 | epoch, i, len(train_loader), batch_time=batch_time,
107 | data_time=data_time, loss=losses, lr=optimizer.param_groups[-1]['lr']))
108 | print(message)
109 | recorder.record_message('a', message)
110 | message = "Finetune Training: Top1:{} Top3:{}".format(top1.avg, top3.avg)
111 | print(message)
112 | recorder.record_message('a', message)
113 | return top1.avg, losses.avg
114 |
115 |
116 | def validate(args, tc, val_loader, model, criterion, recorder, MoCo_init=False):
117 | batch_time = AverageMeter()
118 | losses = AverageMeter()
119 | top1 = AverageMeter()
120 | top3 = AverageMeter()
121 | # switch to evaluate mode
122 | model.eval()
123 | end = time.time()
124 | with torch.no_grad():
125 | for i, (input, target, index) in enumerate(val_loader):
126 | target = target.cuda()
127 | inputs = tc(input)
128 | target = torch.autograd.Variable(target)
129 | output = model(inputs)
130 | loss = criterion(output, target)
131 | prec1, prec3 = accuracy(output.data, target, topk=(1, 3))
132 | losses.update(loss.data.item(), input.size(0))
133 | top1.update(prec1.item(), input.size(0))
134 | top3.update(prec3.item(), input.size(0))
135 | batch_time.update(time.time() - end)
136 | end = time.time()
137 |
138 | if i % args.ft_print_freq == 0:
139 | message = ('Test: [{0}/{1}]\t'
140 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
141 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
142 | i, len(val_loader), batch_time=batch_time, loss=losses
143 | ))
144 | print(message)
145 | recorder.record_message('a', message)
146 | message = "Finetune Eval: Top1:{} Top3:{}".format(top1.avg, top3.avg)
147 | print(message)
148 | recorder.record_message('a', message)
149 | return top1.avg, losses.avg
150 |
151 |
--------------------------------------------------------------------------------
/src/Contrastive/loss/NCE/Link.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class Node:
4 | def __init__(self, data, _pre=None, _next=None):
5 | self.data = data # in format [loss, feature]
6 | # self.data = data[0]
7 | # self.loss = data[1]
8 | # self.index = data[2]
9 | self._pre = _pre
10 | self._next = _next
11 | def __str__(self):
12 | return str(self.data[2])
13 |
14 | class DoublyLink:
15 | def __init__(self):
16 | self.tail = None
17 | self.head = None
18 | self.size = 0
19 |
20 | # at the end
21 | def append(self, new_node):
22 | tmp_node = self.tail
23 | tmp_node._pre = new_node
24 | new_node._next = tmp_node
25 | new_node._pre = None
26 | self.tail = new_node
27 | return new_node
28 |
29 | # at the head
30 | def add_first(self, new_node):
31 | tmp_node = self.head
32 | tmp_node._next = new_node
33 | new_node._pre = tmp_node
34 | new_node._next = None
35 | self.head = new_node
36 | return new_node
37 |
38 | def insert_before(self, node, new_node):
39 | node._next._pre = new_node
40 | new_node._next = node._next
41 | new_node._pre = node
42 | node._next = new_node
43 | return new_node
44 |
45 | def insert_after(self, node, new_node):
46 | if node._pre is None:
47 | return self.append(new_node)
48 | else:
49 | return self.insert_before(node._pre, new_node)
50 | # node._next = new_node
51 | # new_node._next = None
52 | # new_node._pre = node
53 | # self.head = new_node
54 | # return new_node
55 |
56 | def insert(self, data):
57 | if isinstance(data, Node):
58 | tmp_node = data
59 | else:
60 | tmp_node = Node(data)
61 | if self.size == 0:
62 | self.tail = tmp_node
63 | self.head = self.tail
64 | else:
65 | # pre_node = self.head
66 | tmp_node = self.add_first(tmp_node)
67 | # while pre_node.data[0] > tmp_node.data[0] and pre_node._pre != None:
68 | # pre_node = pre_node._pre
69 | # #insert before
70 | # # print(pre_node._pre, pre_node._next)
71 | # if pre_node._pre is None and pre_node.data[0] >= tmp_node.data[0]:
72 | # tmp_node = self.append(tmp_node)
73 | # elif pre_node._next is None and pre_node.data[0] < tmp_node.data[0]:
74 | # tmp_node = self.add_first(tmp_node)
75 | # elif pre_node._next is None and pre_node.data[0] >= tmp_node.data[0]:
76 | # tmp_node = self.insert_after(pre_node, tmp_node)
77 | # else:
78 | # tmp_node = self.insert_before(pre_node, tmp_node)
79 | self.size += 1
80 | return tmp_node
81 |
82 | def remove(self, node):
83 | if node == self.head:
84 | self.head._pre._next = None
85 | self.head = self.head._pre
86 | elif node == self.tail:
87 | self.tail._next._pre = None
88 | self.tail = self.tail._next
89 | else:
90 | node._next._pre = node._pre
91 | node._pre._next = node._next
92 | self.size -= 1
93 |
94 | def __str__(self):
95 | str_text = ""
96 | cur_node = self.head
97 | count = 0
98 | while cur_node != None:
99 | str_text += str(cur_node.data[2]) + " "
100 | cur_node = cur_node._pre
101 | count += 1
102 | if count > 20:
103 | break
104 | return str_text
105 |
106 |
107 | class LRUCache:
108 | def __init__(self, size):
109 | self.size = size
110 | self.hash_map = dict()
111 | self.link = DoublyLink()
112 | self.LRU_init(size)
113 |
114 | def LRU_init(self, size):
115 | for i in range(size):
116 | self.set(i, [1e-8, torch.rand(128), i])
117 |
118 | def set(self, key, value):
119 | if self.size == self.link.size:
120 | self.link.remove(self.link.tail)
121 | if key in self.hash_map:
122 | self.link.remove(self.hash_map.get(key))
123 | tmp_node = self.link.insert(value)
124 | self.hash_map.__setitem__(key, tmp_node)
125 |
126 | def get(self, key):
127 | tmp_node = self.hash_map.get(key)
128 | self.link.remove(tmp_node)
129 | self.link.insert(tmp_node)
130 | return tmp_node.data
131 |
132 | def get_queue(self, num, keys):
133 | queue = torch.rand(num, 128).cuda()
134 | num_queue = 0
135 | cur_node = self.link.head
136 | while num_queue < num:
137 | if cur_node.data[2] not in keys:
138 | queue[num_queue] = cur_node.data[1]
139 | num_queue += 1
140 | cur_node = cur_node._pre
141 | # while num_queue < num:
142 | # # if cur_node.data[2] not in keys:
143 | # queue[num_queue] = cur_node.data
144 | # num_queue += 1
145 | # cur_node = cur_node._next
146 | return queue
147 |
148 | def update_queue(self, queue):
149 | num = queue.size(0)
150 | # print(num)
151 | cur_node = self.link.head
152 | for i in range(num):
153 | queue[i] = cur_node.data[1]
154 | cur_node = cur_node._pre
155 | return queue
156 |
157 | def batch_set(self, keys, values, losses):
158 | # print(self.link)
159 | num = len(values)
160 | for i in range(num):
161 | self.set(keys[i], [losses[i].item(), values[i], keys[i]]) # add loss,data,key
162 |
163 | # r = LRUCache(3)
164 | # r.set("1", ["1","1"])
165 | # r.set("2", ["2","2"])
166 | # r.set("3", ["3","3"])
167 | # print(r.link.size)
168 | # r.get("1")
169 | # print(r.link)
170 | # r.set("4", ["4","4"])
171 | # print(r.link)
172 |
173 |
--------------------------------------------------------------------------------
/src/Contrastive/loss/NCE/NCECriterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | eps = 1e-7
5 |
6 |
7 | class NCECriterion(nn.Module):
8 | """
9 | Eq. (12): L_{NCE}
10 | """
11 | def __init__(self, n_data):
12 | super(NCECriterion, self).__init__()
13 | self.n_data = n_data
14 |
15 | def forward(self, x):
16 | bsz = x.shape[0]
17 | m = x.size(1) - 1
18 |
19 | # noise distribution
20 | Pn = 1 / float(self.n_data)
21 |
22 | # loss for positive pair
23 | P_pos = x.select(1, 0)
24 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
25 |
26 | # loss for K negative pair
27 | P_neg = x.narrow(1, 1, m)
28 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
29 |
30 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
31 |
32 | return loss
33 |
34 |
35 | class NCESoftmaxLoss(nn.Module):
36 | """Softmax cross-entropy loss (a.k.a., info-NCE loss in CPC paper)"""
37 | def __init__(self):
38 | super(NCESoftmaxLoss, self).__init__()
39 | self.criterion = nn.CrossEntropyLoss()
40 |
41 | def forward(self, x):
42 | bsz = x.shape[0]
43 | x = x.squeeze()
44 | label = torch.zeros([bsz]).cuda().long()
45 | loss = self.criterion(x, label)
46 | return loss
47 |
--------------------------------------------------------------------------------
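
Written out, NCECriterion above is the standard NCE objective with m negatives per sample and a uniform noise distribution P_n = 1 / n_data (the eps in the code only guards against division by zero):

    L_{NCE} = -\frac{1}{B}\sum_{i=1}^{B}\Bigg[
        \log\frac{P(x_i^{+})}{P(x_i^{+}) + m P_n}
        + \sum_{k=1}^{m}\log\frac{m P_n}{P(x_{i,k}^{-}) + m P_n}
    \Bigg]

where B is the batch size, P(x_i^{+}) is the positive-pair score (column 0 of the input) and P(x_{i,k}^{-}) are the m negative scores.
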
/src/Contrastive/loss/NCE/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/loss/NCE/__init__.py
--------------------------------------------------------------------------------
/src/Contrastive/loss/NCE/alias_multinomial.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AliasMethod(object):
5 | """
6 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
7 | """
8 | def __init__(self, probs):
9 |
10 | if probs.sum() > 1:
11 | probs.div_(probs.sum())
12 | K = len(probs)
13 | self.prob = torch.zeros(K)
14 | self.alias = torch.LongTensor([0]*K)
15 |
16 | # Sort the data into the outcomes with probabilities
17 | # that are larger and smaller than 1/K.
18 | smaller = []
19 | larger = []
20 | for kk, prob in enumerate(probs):
21 | self.prob[kk] = K*prob
22 | if self.prob[kk] < 1.0:
23 | smaller.append(kk)
24 | else:
25 | larger.append(kk)
26 |
27 | # Loop though and create little binary mixtures that
28 | # appropriately allocate the larger outcomes over the
29 | # overall uniform mixture.
30 | while len(smaller) > 0 and len(larger) > 0:
31 | small = smaller.pop()
32 | large = larger.pop()
33 |
34 | self.alias[small] = large
35 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]
36 |
37 | if self.prob[large] < 1.0:
38 | smaller.append(large)
39 | else:
40 | larger.append(large)
41 |
42 | for last_one in smaller+larger:
43 | self.prob[last_one] = 1
44 |
45 | def cuda(self):
46 | self.prob = self.prob.cuda()
47 | self.alias = self.alias.cuda()
48 |
49 | def draw(self, N):
50 | """
51 | Draw N samples from multinomial
52 | :param N: number of samples
53 | :return: samples
54 | """
55 | K = self.alias.size(0)
56 |
57 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
58 | prob = self.prob.index_select(0, kk)
59 | alias = self.alias.index_select(0, kk)
60 | # b is whether a random number is greater than q
61 | b = torch.bernoulli(prob)
62 | oq = kk.mul(b.long())
63 | oj = alias.mul((1-b).long())
64 |
65 | return oq + oj
66 |
--------------------------------------------------------------------------------
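
A small CPU-only sketch of AliasMethod above with synthetic probabilities; the empirical sample frequencies should approach the input distribution:

    import torch
    from loss.NCE.alias_multinomial import AliasMethod

    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    sampler = AliasMethod(probs)
    samples = sampler.draw(100000)                             # LongTensor of indices in [0, 3]
    print(torch.bincount(samples).float() / samples.numel())   # roughly [0.1, 0.2, 0.3, 0.4]
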
/src/Contrastive/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/loss/__init__.py
--------------------------------------------------------------------------------
/src/Contrastive/loss/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 |
4 | from loss.NCE.NCEAverage import MemoryMoCo
5 | from loss.NCE.NCECriterion import NCECriterion
6 | from loss.NCE.NCECriterion import NCESoftmaxLoss
7 |
8 |
9 | def get_fine_tuning_parameters(model, ft_begin_module='custom'):
10 | if not ft_begin_module:
11 | return model.parameters()
12 | parameters = []
13 | add_flag = False
14 | for k, v in model.named_parameters():
15 | parameters.append({'params': v})
16 | # if ft_begin_module in k:
17 | # add_flag = True
18 | # print(k)
19 | # parameters.append({'params': v})
20 | # if add_flag:
21 | # parameters.append({'params': v})
22 | return parameters
23 |
24 |
25 | def pt_optim_init(args, model, n_data):
26 | contrast = MemoryMoCo(128, n_data, args.pt_nce_k, args.pt_nce_t, args.pt_softmax).cuda()
27 | criterion = NCESoftmaxLoss() if args.pt_softmax else NCECriterion(n_data)
28 | criterion = criterion.cuda()
29 |
30 | optimizer = torch.optim.SGD(model.parameters(),
31 | lr=args.pt_learning_rate,
32 | momentum=args.pt_momentum,
33 | weight_decay=args.pt_weight_decay)
34 | return contrast, criterion, optimizer
35 |
36 |
37 | def ft_optim_init(args, model):
38 | train_criterion = torch.nn.NLLLoss().cuda()
39 | val_criterion = torch.nn.NLLLoss().cuda()
40 |
41 | if args.ft_fixed == 1:
42 | parameters = get_fine_tuning_parameters(model, ft_begin_module='custom')
43 | else:
44 | parameters = model.parameters()
45 | if args.ft_optim == 'sgd':
46 | optimizer = optim.SGD(parameters,
47 | lr=args.ft_lr,
48 | momentum=args.ft_momentum,
49 | weight_decay=args.ft_weight_decay)
50 | elif args.ft_optim == 'adam':
51 | optimizer = optim.Adam(parameters, lr=args.ft_lr)
52 | else:
53 | raise Exception("unsupported optim")
54 | if args.ft_fixed == 1:
55 | count = 0
56 | for param_group in optimizer.param_groups:
57 | count += 1
58 | print("param group is: {}".format(count))
59 | count2 = 0
60 | for param_group in optimizer.param_groups:
61 | count2 += 1
62 | param_group['lr'] = param_group['lr'] * count2 / count
63 | return train_criterion, val_criterion, optimizer
64 |
65 |
--------------------------------------------------------------------------------
/src/Contrastive/loss/tcr.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | def flip(x, dim):
6 | indices = [slice(None)] * x.dim()
7 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
8 | dtype=torch.long, device=x.device)
9 | return x[tuple(indices)]
10 |
11 |
12 | def tcr(feats_o, feats_r):
13 | loss = nn.MSELoss()
14 | feats_r = flip(feats_r, 2)
15 | b, c, t, h, w = feats_o.size()
16 | o_t = nn.AdaptiveAvgPool3d((t, 1, 1))(feats_o)
17 | o_r = nn.AdaptiveAvgPool3d((t, 1, 1))(feats_r)
18 | output = loss(o_t, o_r)
19 | return output
--------------------------------------------------------------------------------
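
A minimal sanity check of the temporal-consistency term above (random features): tcr re-flips its second argument along time and compares the spatially pooled activations, so features of a clip and its exact time reversal give zero loss.

    import torch
    from loss.tcr import tcr, flip

    feats_o = torch.rand(2, 64, 8, 7, 7)    # B x C x T x H x W features of the original clip
    feats_r = flip(feats_o, 2)              # stand-in for features of the time-reversed clip
    print(tcr(feats_o, feats_r).item())                     # 0.0
    print(tcr(feats_o, torch.rand_like(feats_o)).item())    # > 0 for unrelated features
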
/src/Contrastive/main.py:
--------------------------------------------------------------------------------
1 | from option import args
2 | import datetime
3 | from pt import pretext_train
4 | from ft import fine_tune_train_and_val
5 | import os
6 | from utils.recoder import Record
7 |
8 |
9 | def main():
10 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
11 | args.date = datetime.datetime.today().strftime('%m-%d-%H%M')
12 | recorder = Record(args)
13 | message = ""
14 | if args.method == 'pt':
15 | args.status = 'pt'
16 | pretext_train(args, recorder)
17 | print("finished pretrain with weight from: {}".format(args.ft_weights))
18 | elif args.method == 'ft':
19 | args.status = 'ft'
20 | fine_tune_train_and_val(args, recorder)
21 | print("finished finetune with weight from: {}".format(args.ft_weights))
22 | elif args.method == 'pt_and_ft':
23 | args.status = 'pt'
24 | checkpoints_path = pretext_train(args, recorder)
25 | print("finished pretrain, the weight is in: {}".format(args.ft_weights))
26 | args.status = 'ft'
27 | args.ft_weights = checkpoints_path
28 | fine_tune_train_and_val(args, recorder)
29 | print("finished finetune with weight from: {}".format(checkpoints_path))
30 | else:
31 | raise Exception("wrong method!")
32 |
33 |
34 | if __name__ == '__main__':
35 | main()
36 |
--------------------------------------------------------------------------------
/src/Contrastive/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/BE/632c3aa0eaa3acc24a545ec05a9a36f96592cb2c/src/Contrastive/model/__init__.py
--------------------------------------------------------------------------------
/src/Contrastive/model/c3d.py:
--------------------------------------------------------------------------------
1 | """C3D"""
2 | import math
3 | from collections import OrderedDict
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn.modules.utils import _triple
8 | from model.i3d import Flatten, Normalize
9 | import torch.nn.functional as F
10 |
11 |
12 | class C3D(nn.Module):
13 | """C3D with BN and pool5 to be AdaptiveAvgPool3d(1)."""
14 |
15 | def __init__(self, with_classifier=False, num_classes=101):
16 | super(C3D, self).__init__()
17 | self.with_classifier = with_classifier
18 | self.num_classes = num_classes
19 |
20 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
21 | self.bn1 = nn.BatchNorm3d(64)
22 | self.relu1 = nn.ReLU()
23 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
24 |
25 | self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
26 | self.bn2 = nn.BatchNorm3d(128)
27 | self.relu2 = nn.ReLU()
28 | self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
29 |
30 | self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
31 | self.bn3a = nn.BatchNorm3d(256)
32 | self.relu3a = nn.ReLU()
33 | self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
34 | self.bn3b = nn.BatchNorm3d(256)
35 | self.relu3b = nn.ReLU()
36 | self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
37 |
38 | self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
39 | self.bn4a = nn.BatchNorm3d(512)
40 | self.relu4a = nn.ReLU()
41 | self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
42 | self.bn4b = nn.BatchNorm3d(512)
43 | self.relu4b = nn.ReLU()
44 | self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
45 |
46 | self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
47 | self.bn5a = nn.BatchNorm3d(512)
48 | self.relu5a = nn.ReLU()
49 | self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
50 | self.bn5b = nn.BatchNorm3d(512)
51 | self.relu5b = nn.ReLU()
52 | self.pool5 = nn.AdaptiveAvgPool3d(1)
53 |
54 | if self.with_classifier:
55 | self.linear = nn.Linear(512, self.num_classes)
56 | else:
57 | self.id_head = nn.Sequential(
58 | torch.nn.AdaptiveAvgPool3d((1, 1, 1)),
59 | Flatten(),
60 | torch.nn.Linear(512, 128),
61 | Normalize(2)
62 | )
63 |
64 | def forward(self, x, return_conv=False):
65 | x = self.conv1(x)
66 | x = self.bn1(x)
67 | x = self.relu1(x)
68 | x = self.pool1(x)
69 |
70 | x = self.conv2(x)
71 | x = self.bn2(x)
72 | x = self.relu2(x)
73 | x = self.pool2(x)
74 |
75 | x = self.conv3a(x)
76 | x = self.bn3a(x)
77 | x = self.relu3a(x)
78 | x = self.conv3b(x)
79 | x = self.bn3b(x)
80 | x = self.relu3b(x)
81 | x = self.pool3(x)
82 |
83 | x = self.conv4a(x)
84 | x = self.bn4a(x)
85 | x = self.relu4a(x)
86 | x = self.conv4b(x)
87 | x = self.bn4b(x)
88 | x = self.relu4b(x)
89 | x = self.pool4(x)
90 |
91 | x = self.conv5a(x)
92 | x = self.bn5a(x)
93 | x = self.relu5a(x)
94 | x = self.conv5b(x)
95 | x = self.bn5b(x)
96 | x = self.relu5b(x)
97 |
98 | if return_conv:
99 | return x
100 | if not self.with_classifier:
101 | id_out = self.id_head(x)
102 | return id_out, 0, 0
103 |
104 | x = self.pool5(x)
105 | x = x.view(-1, 512)
106 |
107 | if self.with_classifier:
108 | x = self.linear(x)
109 | x = F.log_softmax(x, dim=1)
110 | return x
111 |
112 | #
113 | # if __name__ == '__main__':
114 | # c3d = C3D()
--------------------------------------------------------------------------------
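
A quick shape check for the C3D backbone above (synthetic input, classifier head enabled):

    import torch
    from model.c3d import C3D

    model = C3D(with_classifier=True, num_classes=101)
    clip = torch.rand(2, 3, 16, 112, 112)   # B x C x T x H x W
    logits = model(clip)                    # log-probabilities from log_softmax
    print(logits.shape)                     # torch.Size([2, 101])
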
/src/Contrastive/model/config.py:
--------------------------------------------------------------------------------
1 | from model.i3d import I3D
2 | from model.r2p1d import R2Plus1DNet
3 | from model.r3d import resnet18, resnet34, resnet50
4 | from model.c3d import C3D
5 | from model.s3d_g import S3D_G
6 | from model.s3d import S3DG
7 | import torch.nn as nn
8 | from model.model import TCN
9 | import torch
10 | from utils.load_weights import ft_load_weight
11 |
12 |
13 | def pt_model_config(args, num_class):
14 | if args.arch == 'i3d':
15 | model = I3D(num_classes=101, modality=args.pt_mode, with_classifier=False)
16 | model_ema = I3D(num_classes=101, modality=args.pt_mode, with_classifier=False)
17 | elif args.arch == 'r2p1d':
18 | model = R2Plus1DNet((1, 1, 1, 1), num_classes=num_class, with_classifier=False)
19 | model_ema = R2Plus1DNet((1, 1, 1, 1), num_classes=num_class, with_classifier=False)
20 | elif args.arch == 'r3d18':
21 | model = resnet18(num_classes=num_class, with_classifier=False)
22 | model_ema = resnet18(num_classes=num_class, with_classifier=False)
23 | elif args.arch == 'r3d34':
24 | model = resnet34(num_classes=num_class, with_classifier=False)
25 | model_ema = resnet34(num_classes=num_class, with_classifier=False)
26 | elif args.arch == 'r3d50':
27 | model = resnet50(num_classes=num_class, with_classifier=False)
28 | model_ema = resnet50(num_classes=num_class, with_classifier=False)
29 | elif args.arch == 'c3d':
30 | model = C3D(with_classifier=False, num_classes=num_class)
31 | model_ema = C3D(with_classifier=False, num_classes=num_class)
32 | elif args.arch == 's3d':
33 | model = S3D_G(num_class=num_class, in_channel=3, gate=True, with_classifier=False)
34 | model_ema = S3D_G(num_class=num_class, in_channel=3, gate=True, with_classifier=False)
35 | else:
36 | raise NotImplementedError("unsupported arch!")
37 | model = torch.nn.DataParallel(model)
38 | model_ema = torch.nn.DataParallel(model_ema)
39 | return model, model_ema
40 |
41 |
42 | def ft_model_config(args, num_class):
43 | with_classifier = True
44 | if args.arch == 'i3d':
45 | base_model = I3D(num_classes=num_class, modality=args.ft_mode, dropout_prob=args.ft_dropout, with_classifier=with_classifier)
46 | # args.logits_channel = 1024
47 | if args.ft_spatial_size == '112':
48 | out_size = (int(args.ft_data_length) // 8, 4, 4)
49 | else:
50 | out_size = (int(args.ft_data_length) // 8, 7, 7)
51 | elif args.arch == 'r2p1d':
52 | base_model = R2Plus1DNet((1, 1, 1, 1), num_classes=num_class, with_classifier=with_classifier)
53 | # args.logits_channel = 512
54 | out_size = (4, 4, 4)
55 | elif args.arch == 'c3d':
56 | base_model = C3D(num_classes=num_class, with_classifier=with_classifier)
57 | # args.logits_channel = 512
58 | out_size = (4, 4, 4)
59 | elif args.arch == 'r3d18':
60 | base_model = resnet18(num_classes=num_class, sample_size=int(args.ft_spatial_size), with_classifier=with_classifier)
61 | # args.logits_channel = 512
62 | out_size = (4, 4, 4)
63 | elif args.arch == 'r3d34':
64 | base_model = resnet34(num_classes=num_class, sample_size=int(args.ft_spatial_size), with_classifier=with_classifier)
65 | # args.logits_channel = 512
66 | out_size = (4, 4, 4)
67 | elif args.arch == 'r3d50':
68 | base_model = resnet50(num_classes=num_class, sample_size=int(args.ft_spatial_size), with_classifier=with_classifier)
69 | # args.logits_channel = 512
70 | out_size = (4, 4, 4)
71 | elif args.arch == 's3d':
72 | # base_model = S3D_G(num_class=num_class, drop_prob=args.dropout, in_channel=3)
73 | base_model = S3DG(num_classes=num_class, dropout_keep_prob=args.ft_dropout, input_channel=3, spatial_squeeze=True, with_classifier=True)
74 | # args.logits_channel = 1024
75 | out_size = (2, 7, 7)
76 | else:
77 | raise Exception("unsupported arch!")
78 | base_model = ft_load_weight(args, base_model)
79 | model = TCN(base_model, out_size, args)
80 | model = nn.DataParallel(model).cuda()
81 | # cudnn.benchmark = True
82 | return model
83 |
--------------------------------------------------------------------------------
/src/Contrastive/model/model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | import torch
4 | import numpy as np
5 |
6 | class Flatten(nn.Module):
7 | def __init__(self):
8 | super(Flatten, self).__init__()
9 |
10 | def forward(self, input):
11 | return input.view(input.size(0), -1)
12 |
13 |
14 | class Normalize(nn.Module):
15 | def __init__(self, power=2):
16 | super(Normalize, self).__init__()
17 | self.power = power
18 |
19 | def forward(self, x):
20 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power)
21 | out = x.div(norm)
22 | return out
23 |
24 |
25 | class Sharpen(nn.Module):
26 | def __init__(self, tempeature=0.5):
27 | super(Sharpen, self).__init__()
28 | self.T = tempeature
29 |
30 | def forward(self, probabilities):
31 | tempered = torch.pow(probabilities, 1 / self.T)
32 | tempered = tempered / tempered.sum(dim=-1, keepdim=True)
33 | return tempered
34 |
35 | class MotionEnhance(nn.Module):
36 | def __init__(self, beta=1, maxium_radio=0.3):
37 | super(MotionEnhance, self).__init__()
38 | self.beta = beta
39 | self.maxium_radio = maxium_radio
40 |
41 | def forward(self, x):
42 | b, c, t, h, w = x.size()
43 | mean = nn.AdaptiveAvgPool3d((1, h, w))(x)
44 | lam = np.random.beta(self.beta, self.beta) * self.maxium_radio
45 | out = (x - mean * lam) * (1 / (1 - lam))
46 | return out
47 |
48 |
49 | class TCN(nn.Module):
50 | """
51 | Encode a video clip into 128-dimensional features and classify it.
52 | Two implementation options: reshape, or encode adjacent samples into the batch dimension.
53 | """
54 | def __init__(self, base_model, out_size, args):
55 | super(TCN, self).__init__()
56 | self.base_model = base_model
57 | self.args = args
58 | self.l2norm = Normalize(2)
59 | print("fine tune ...")
60 |
61 | def forward(self, input):
62 | output = self.base_model(input, return_conv=False)
63 | # print(output.size())
64 | # output = F.log_softmax(output, dim=1)
65 | return output
66 |
--------------------------------------------------------------------------------
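
For reference, MotionEnhance above suppresses the static background by subtracting a scaled temporal mean and renormalizing; with lambda redrawn on every call it computes

    \bar{x}_{c}(h,w) = \frac{1}{T}\sum_{t=1}^{T} x_{c}(t,h,w), \qquad
    \lambda \sim r_{\max}\cdot\mathrm{Beta}(\beta,\beta), \qquad
    \tilde{x} = \frac{x - \lambda\,\bar{x}}{1-\lambda}

where r_max corresponds to maxium_radio and beta to the Beta-distribution parameter.
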
/src/Contrastive/model/mutual_net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class MutualNet(nn.Module):
5 | def __init__(self, embeddingnet):
6 | super(MutualNet, self).__init__()
7 | self.embeddingnet = embeddingnet
8 |
9 | def forward(self, x, y, z):
10 | feature_x = self.embeddingnet(x)
11 | feature_y = self.embeddingnet(y)
12 | return feature_x, feature_y
--------------------------------------------------------------------------------
/src/Contrastive/model/s3d_g.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from collections import OrderedDict
4 | from model.i3d import Flatten, Normalize
5 | # import ipdb
6 | import torch.nn.functional as F
7 |
8 |
9 | class BasicConv3d(nn.Module):
10 | def __init__(self,
11 | in_channel,
12 | out_channel,
13 | kernel_size=1,
14 | stride=1,
15 | padding=0,
16 | use_bias=False,
17 | use_bn=True,
18 | activation='rule'):
19 | super(BasicConv3d, self).__init__()
20 |
21 | self.use_bn = use_bn
22 | self.activation = activation
23 | self.conv3d = nn.Conv3d(in_channel, out_channel, kernel_size=kernel_size,
24 | stride=stride, padding=padding, bias=use_bias)
25 | if use_bn:
26 | self.bn = nn.BatchNorm3d(out_channel, eps=1e-3, momentum=0.001, affine=True)
27 | if activation == 'rule':
28 | self.activation = nn.ReLU()
29 |
30 | def forward(self, x):
31 | x = self.conv3d(x)
32 | if self.use_bn:
33 | x = self.bn(x)
34 | if self.activation is not None:
35 | x = self.activation(x)
36 | # ipdb.set_trace()
37 | return x
38 |
39 |
40 | class sep_conv(nn.Module):
41 | def __init__(self,
42 | in_channel,
43 | out_channel,
44 | kernel_size,
45 | stride=1,
46 | padding=0,
47 | use_bias=True,
48 | use_bn=True,
49 | activation='rule',
50 | gate=True):
51 | super(sep_conv, self).__init__()
52 | down = BasicConv3d(in_channel, out_channel, (1,kernel_size,kernel_size), stride=stride,
53 | padding=(0,padding,padding), use_bias=False, use_bn=True)
54 | up = BasicConv3d(out_channel, out_channel, (kernel_size,1,1), stride=1,
55 | padding=(padding,0,0), use_bias=False, use_bn=True)
56 | self.sep_conv = nn.Sequential(down, up)
57 |
58 | # gating
59 | if gate:
60 | self.gate = gate
61 | self.squeeze = nn.AdaptiveAvgPool3d(1)
62 | self.excitation = nn.Conv3d(out_channel, out_channel, 1)
63 | self.sigmoid = nn.Sigmoid()
64 | else:
65 | self.gate = False
66 |
67 | def forward(self, x):
68 | x = self.sep_conv(x)
69 | # ipdb.set_trace()
70 | if self.gate:
71 | temp = x
72 | weight = self.squeeze(x)
73 | weight = self.excitation(weight)
74 | weight = self.sigmoid(weight)
75 | x = weight * x
76 | return x
77 |
78 |
79 | class sep_inc(nn.Module):
80 | def __init__(self, in_channel, out_channel, gate=True):
81 | super(sep_inc, self).__init__()
82 | # branch 0
83 | self.branch0 = BasicConv3d(in_channel, out_channel[0], kernel_size=(1,1,1), stride=1, padding=0)
84 | # branch 1
85 | branch1_conv1 = BasicConv3d(in_channel, out_channel[1],kernel_size=(1,1,1), stride=1, padding=0)
86 | branch1_sep_conv = sep_conv(out_channel[1], out_channel[2], kernel_size=3, stride=1, padding=1, gate=gate)
87 | self.branch1 = nn.Sequential(branch1_conv1, branch1_sep_conv)
88 | # branch 2
89 | branch2_conv1 = BasicConv3d(in_channel, out_channel[3],kernel_size=(1,1,1), stride=1, padding=0)
90 | branch2_sep_conv = sep_conv(out_channel[3], out_channel[4], kernel_size=3, stride=1, padding=1, gate=gate)
91 | self.branch2 = nn.Sequential(branch2_conv1, branch2_sep_conv)
92 | # branch 3
93 | branch3_pool = nn.MaxPool3d(kernel_size=3, stride=1, padding=1)
94 | branch3_conv = BasicConv3d(in_channel, out_channel[5], kernel_size=(1,1,1))
95 | self.branch3 = nn.Sequential(branch3_pool, branch3_conv)
96 |
97 | def forward(self, x):
98 | # ipdb.set_trace()
99 | out_0 = self.branch0(x)
100 | out_1 = self.branch1(x)
101 | out_2 = self.branch2(x)
102 | out_3 = self.branch3(x)
103 | out = torch.cat((out_0, out_1, out_2, out_3), 1)
104 | return out
105 |
106 |
107 | class S3D_G(nn.Module):
108 | def __init__(self, num_class=400, drop_prob=0.5, in_channel=3, gate=True, with_classifier=True):
109 | super(S3D_G, self).__init__()
110 | self.feature = nn.Sequential(OrderedDict([
111 | ('sepConv1', sep_conv(in_channel, 64, kernel_size=7, stride=2, padding=3, gate=gate)), # (64,32,112,112)
112 | ('maxPool1', nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))), # (64,32,56,56)
113 | ('basicConv3d', BasicConv3d(64, 64, kernel_size=1, stride=1)), # (64,32,56,56)
114 | ('sep_conv2', sep_conv(64, 192, kernel_size=3, stride=1, padding=1, gate=gate)), # (192,32,56,56)
115 | ('maxPool2', nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))), # (192,32,28,28)
116 | ('sepInc_3b', sep_inc(192, [64,96,128,16,32,32], gate=gate)), # (256,32,28,28)
117 | ('sepInc_3c', sep_inc(256, [128,128,192,32,96,64], gate=gate)), # (480,32,28,28)
118 | ('maxPool3', nn.MaxPool3d(kernel_size=(3,3,3), stride=(2,2,2), padding=(1,1,1))), # (480,16,14,14)
119 | ('sepInc_4b', sep_inc(480, [192, 96, 208, 16, 48, 64], gate=gate)), # (512,16,14,14)
120 | ('sepInc_4c', sep_inc(512, [160, 112, 224, 24, 64, 64], gate=gate)), # (512,16,14,14)
121 | ('sepInc_4d', sep_inc(512, [128, 128, 256, 24, 64, 64], gate=gate)), # (512,16,14,14)
122 | ('sepInc_4e', sep_inc(512, [112, 144, 288, 32, 64, 64], gate=gate)), # (528,16,14,14)
123 | ('sepInc_4f', sep_inc(528, [256, 160, 320, 32, 128, 128], gate=gate)), # (832,16,14,14)
124 | ('maxpool4', nn.MaxPool3d(kernel_size=(2,2,2),stride=(2,2,2),padding=(0,0,0))), # (832,8,7,7)
125 | ('sepInc_5b', sep_inc(832, [256, 160, 320, 32, 128, 128], gate=gate)), # (832,8,7,7)
126 | ('sepInc_5c', sep_inc(832, [384, 192, 384, 48, 128, 128], gate=gate)), # (1024,8,7,7)
127 | ]))
128 | self.avgPool = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1) # (1024,7,1,1)
129 | self.drop = nn.Dropout3d(drop_prob)
130 | self.fc = nn.Conv3d(1024, num_class, kernel_size=1, stride=1, bias=True) # (num_class,7,1,1)
131 | self.softmax = nn.Softmax(1)
132 | self.with_classifier = with_classifier
133 | if not with_classifier:
134 | self.id_head = nn.Sequential(
135 | torch.nn.AdaptiveAvgPool3d((1, 1, 1)),
136 | Flatten(),
137 | torch.nn.Linear(1024, 128),
138 | Normalize(2)
139 | )
140 |
141 | def forward(self, x, return_conv=False):
142 | # ipdb.set_trace()
143 | out = self.feature(x) # (batch_size,num_class,7,1,1)
144 | if return_conv:
145 | return out
146 | if not self.with_classifier:
147 | out = self.id_head(out)
148 | return out, 0, 0
149 | out = self.drop(self.avgPool(out))
150 | out = self.fc(out)
151 | # squeeze the spatial dimension
152 | out = out.squeeze(3) # (batch_size,num_class,7,1)
153 | out = out.squeeze(3) # (batch_size,num_class,7)
154 | out = out.mean(2) # (batch_size,num_class)
155 | return F.log_softmax(out, dim=1)
156 | # prediction = self.softmax(out)
157 | # return prediction
158 |
159 |
160 | if __name__ == "__main__":
161 | model = S3D_G()
162 | x = torch.rand((1,3,64,224,224))
163 | out = model(x)
--------------------------------------------------------------------------------
/src/Contrastive/reterival.py:
--------------------------------------------------------------------------------
1 | """Video retrieval experiment, top-k."""
2 | import os
3 | import json
4 | import numpy as np
5 | from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
6 |
7 |
8 | def topk_retrieval(feature_dir):
9 | """Extract features from test split and search on train split features."""
10 | print('Load local .npy files. from ...', feature_dir)
11 | train_features = np.load(os.path.join(feature_dir, 'train_features.npy'), allow_pickle=True).item()
12 | X_train = train_features['data']
13 | y_train = train_features['target']
14 | # X_train = np.mean(X_train, 1)
15 | # y_train = y_train[:, 0]
16 | # X_train = X_train.reshape((-1, X_train.shape[-1]))
17 | # y_train = y_train.reshape(-1)
18 |
19 | val_features = np.load(os.path.join(feature_dir, 'val_features.npy'), allow_pickle=True).item()
20 | X_test = val_features['data']
21 | y_test = val_features['target']
22 | # X_test = np.mean(X_test, 1)
23 | # y_test = y_test[:, 0]
24 | # X_test = X_test.reshape((-1, X_test.shape[-1]))
25 | # y_test = y_test.reshape(-1)
26 |
27 | ks = [1, 5, 10, 20, 50]
28 | topk_correct = {k: 0 for k in ks}
29 |
30 | distances = cosine_distances(X_test, X_train)
31 | indices = np.argsort(distances) # 1530 x 3570
32 |
33 | for k in ks:
34 | top_k_indices = indices[:, :k]
35 | for ind, test_label in zip(top_k_indices, y_test):
36 | # print(ind)
37 | for j in range(len(ind)):
38 | labels = y_train[ind[j]]
39 | if test_label in labels:
40 | topk_correct[k] += 1
41 | break
42 |
43 | for k in ks:
44 | correct = topk_correct[k]
45 | total = len(X_test)
46 | print('Top-{}, correct = {:.2f}, total = {}, acc = {:.3f}%'.format(k, correct, total, correct / total * 100))
47 |
48 | with open(os.path.join(feature_dir, 'topk_correct.json'), 'w') as fp:
49 | json.dump(topk_correct, fp)
50 |
51 |
52 | if __name__ == '__main__':
53 | # front = "contrastive_kinetics_warpping"
54 | # front = "triplet_ucf101_warpping_hmdb51"
55 | # front = "contrastive_ucf101_warpping_hmdb51"
56 | # front = "contrastive_ucf101_warpping_ucf101"
57 | # front = "triplet_ucf101_warpping_hmdb51"
58 | # front = "i3d_fully_supervised_kinetics_warpping_hmdb51_finetune"
59 | # front = "c3d_fully_supervised_ucf101_warpping_hmdb51"
60 | front = "c3d_c3d_contrastive_ucf101_warpping_ucf101_ucf101"
61 | feature_dirs = "../experiments/features/{}".format(front)
62 | topk_retrieval(feature_dirs)
63 |
64 |
65 | # ==============================kinetics BE contrastive pretrain =======================
66 | # hmdb51: 11.9 / 31.3 / 44.452 / 60.432 / 81.23
67 | # ucf101: 13.0/35.16/44.0/64.78/83.76
68 |
69 | # ===============================ucf101 BE triplet pretrain===================================
70 | # hmdb51: 3.922 / 17.386 / 29.15 / 45.359 / 69.020 (may not best)
71 |
72 | # =============================ucf101 BE contrastive pretrain===========================
73 | # hmdb51: 11.4 / 31.2 / 46.5/ 60.4 / 80.876
74 | # ucf101: 17.394 / 35.184 / 45.308 / 57.811 / 73.962
75 | # c3d
76 | # hmdb51: 8.23/25.88/38.10/51.96/75.0
77 |
--------------------------------------------------------------------------------
/src/Contrastive/scripts/Diving48/ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py \
3 | --method ft --arch i3d \
4 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \
5 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \
6 | --ft_root /data1/DataSet/Diving48/rgb_frames/ \
7 | --ft_dataset diving48 --ft_mode rgb \
8 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
9 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
10 | --ft_print-freq 100 --ft_fixed 0 \
11 | --ft_weights ../experiments/kinetics_contrastive.pth
--------------------------------------------------------------------------------
/src/Contrastive/scripts/Diving48/pt_and_ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 1 --method pt_and_ft --pt_method be \
3 | --pt_batch_size 8 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset diving48 \
6 | --pt_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \
7 | --pt_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \
8 | --pt_root /data1/yutinggao/data/videos/Diving48_rgb_frames \
9 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \
10 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \
11 | --ft_root /data1/yutinggao/data/videos/Diving48_rgb_frames \
12 | --ft_dataset diving48 --ft_mode rgb \
13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
14 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
15 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/Diving48/pt_and_ft_moco.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 1 --method pt_and_ft --pt_method moco \
3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset diving48 \
6 | --pt_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \
7 | --pt_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \
8 | --pt_root /data1/DataSet/Diving48/rgb_frames/ \
9 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \
10 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \
11 | --ft_root /data1/DataSet/Diving48/rgb_frames/ \
12 | --ft_dataset diving48 --ft_mode rgb \
13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
14 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
15 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/Diving48/pt_and_ft_moco_ucf_to_diving48.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 1 --method pt_and_ft --pt_method moco \
3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset ucf101 \
6 | --pt_train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \
7 | --pt_val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \
8 | --ft_train_list ../datasets/lists/diving48/diving48_v2_train_no_front.txt \
9 | --ft_val_list ../datasets/lists/diving48/diving48_v2_test_no_front.txt \
10 | --ft_root /data1/DataSet/Diving48/rgb_frames/ \
11 | --ft_dataset diving48 --ft_mode rgb \
12 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
13 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
14 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/evaluation/hmdb51_i3d_eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python test.py \
3 | --method ft \
4 | --train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \
5 | --val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \
6 | --dataset hmdb51 \
7 | --arch i3d \
8 | --mode rgb \
9 | --batch_size 1 \
10 | --stride 1 \
11 | --data_length 64 \
12 | --clip_size 64 \
13 | --spatial_size 224 \
14 | --workers 1 \
15 | --dropout 0.5 \
16 | --gpus 2 \
17 | --weights ../experiments/logs/hmdb51_i3d_ft/ft_04-11-0112/fine_tune_rgb_model_latest.pth.tar #26?
18 | #--weights ../experiments/logs/hmdb51_i3d_ft/ft_02-18-1134/fine_tune_rgb_model_latest.pth.tar #25.94
19 | #--weights ../experiments/logs/hmdb51_i3d_ft/ft_03-23-2206/fine_tune_rgb_model_latest.pth.tar #49.86
--------------------------------------------------------------------------------
/src/Contrastive/scripts/evaluation/ucf101_i3d_eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python test.py \
3 | --method ft \
4 | --train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \
5 | --val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \
6 | --dataset ucf101 \
7 | --arch i3d \
8 | --mode rgb \
9 | --batch_size 1 \
10 | --stride 1 \
11 | --data_length 64 \
12 | --clip_size 64 \
13 | --spatial_size 224 \
14 | --workers 1 \
15 | --dropout 0.5 \
16 | --gpus 2 \
17 | --weights ../experiments/fine_tune_rgb_model_latest.pth.tar #Scratch: ?
18 | #--weights ../experiments/logs/ucf101_i3d_ft/ft_04-02-1128/fine_tune_rgb_model_latest.pth.tar #SSL: 78.83
19 | #--weights ../experiments/logs/hmdb51_i3d_ft/ft_02-18-1134/fine_tune_rgb_model_latest.pth.tar #25.94
20 | #--weights ../experiments/logs/hmdb51_i3d_ft/ft_03-23-2206/fine_tune_rgb_model_latest.pth.tar #49.86
--------------------------------------------------------------------------------
/src/Contrastive/scripts/feature_extract/hmdb51_extract.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python feature_extractor.py \
3 | --eval_indict feature_extract \
4 | --method ft \
5 | --train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \
6 | --val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \
7 | --dataset hmdb51 \
8 | --arch c3d \
9 | --mode rgb \
10 | --lr 0.001 \
11 | --lr_steps 10 20 25 30 35 40 \
12 | --epochs 45 \
13 | --batch_size 1 \
14 | --data_length 64 \
15 | --spatial_size 224 \
16 | --workers 8 \
17 | --dropout 0.5 \
18 | --gpus 1 \
19 | --logs_path ../experiments/logs/hmdb51_i3d_ft \
20 | --print-freq 100 \
21 | --front c3d_fully_supervised_ucf101_warpping_hmdb51 \
22 | --weights ../experiments/MoCo/ucf101/models/08-19-1644_aug_CJ/ckpt_epoch_20.pth
23 | # --weights ../experiments/Pretrained/i3d_rgb_imagenet.pt # kinetics fully supervised
24 | # --weights ../experiments/Pretrained/i3d_model_rgb.pth # kinetics fully supervised finetune
25 | # --weights ../experiments/triplet/kinetics/models/08-18-1957_aug_CJ/ckpt_epoch_30.pth #ucf101_triplet
26 | #--weights ../experiments/MoCo/ucf101/models/08-18-1956_aug_CJ/ckpt_epoch_40.pth # ucf101_contrastive_wrapping
27 | #--weights ../experiments/MoCo/ucf101/models/08-12-1150_aug_CJ/ckpt_epoch_42.pth #kinetics
28 | #--weights ../experiments/logs/hmdb51_i3d_ft/ft_03-23-2206/fine_tune_rgb_model_latest.pth.tar
29 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-18-1837/fine_tune_rgb_model_best.pth.tar
30 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-19-1046/fine_tune_rgb_model_best.pth.tar
31 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-19-1058/fine_tune_rgb_model_best.pth.tar
32 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-19-1058/net_mixup_rgb_model_best.pth.tar
33 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-19-1046/flip_cls_rgb_model_best.pth.tar
34 | #--weights ../experiments/pretrained_model/model_rgb.pth
35 | #--weights ../experiments/logs/hmdb51_i3d_pt_and_ft/pt_and_ft_02-15-1229/mutual_loss_rgb_model_latest.pth.tar
--------------------------------------------------------------------------------
/src/Contrastive/scripts/feature_extract/ucf101_extract.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python feature_extractor.py \
3 | --eval_indict feature_extract \
4 | --method ft \
5 | --train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \
6 | --val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \
7 | --dataset ucf101 \
8 | --arch c3d \
9 | --mode rgb \
10 | --lr 0.001 \
11 | --lr_steps 10 20 25 30 35 40 \
12 | --epochs 45 \
13 | --batch_size 1 \
14 | --data_length 64 \
15 | --spatial_size 224 \
16 | --workers 8 \
17 | --dropout 0.5 \
18 | --gpus 2 \
19 | --logs_path ../experiments/logs/hmdb51_c3d_ft \
20 | --print-freq 100 \
21 | --front c3d_contrastive_ucf101_warpping_ucf101 \
22 | --weights ../experiments/MoCo/ucf101/models/08-19-1644_aug_CJ/ckpt_epoch_20.pth #ucf101_triplet
23 | # --weights ../experiments/MoCo/ucf101/models/08-18-1956_aug_CJ/ckpt_epoch_40.pth # ucf101_contrastive_wrapping
24 | # --weights ../experiments/triplet/ucf101/models/08-04-2112_aug_CJ/ckpt_epoch_10.pth # ucf101_triplet
25 | # --weights ../experiments/MoCo/ucf101/models/08-14-1615_aug_CJ/ckpt_epoch_150.pth # ucf101_contrastive
26 | # --weights ../experiments/MoCo/ucf101/models/08-12-1150_aug_CJ/ckpt_epoch_42.pth # kinetics
--------------------------------------------------------------------------------
/src/Contrastive/scripts/hmdb51/ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 0 --method ft --arch i3d --pt_method be \
3 | --ft_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \
4 | --ft_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \
5 | --ft_dataset hmdb51 --ft_mode rgb \
6 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
7 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
8 | --ft_print-freq 100 --ft_fixed 0 \
9 | --ft_weights ../experiments/ucf101_contrastive.pth
--------------------------------------------------------------------------------
/src/Contrastive/scripts/hmdb51/pt_and_ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 0 --method pt_and_ft --pt_method be \
3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 200 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset ucf101 \
6 | --pt_train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \
7 | --pt_val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \
8 | --ft_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \
9 | --ft_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \
10 | --ft_dataset hmdb51 --ft_mode rgb \
11 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
12 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
13 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/hmdb51/pt_and_ft_hmdb51.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 3 --method pt_and_ft --pt_method be \
3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset hmdb51 \
6 | --pt_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \
7 | --pt_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \
8 | --ft_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \
9 | --ft_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \
10 | --ft_dataset hmdb51 --ft_mode rgb \
11 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
12 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
13 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/kinetics/pt_and_ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 0,1,2,3,4,5,6,7 --method pt_and_ft --pt_method be \
3 | --pt_batch_size 64 --pt_workers 16 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 65536 --pt_softmax \
5 | --pt_moco --pt_epochs 50 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset kinetics \
6 | --pt_train_list ../datasets/lists/kinetics-400/ssd_kinetics_video_trainlist.txt \
7 | --pt_val_list ../datasets/lists/kinetics-400/ssd_kinetics_video_vallist.txt \
8 | --ft_train_list ../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt \
9 | --ft_val_list ../datasets/lists/hmdb51/hmdb51_rgb_val_split_1.txt \
10 | --ft_dataset hmdb51 --ft_mode rgb \
11 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
12 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
13 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/something-something-v1/i3d_ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 2 --method ft \
3 | --arch i3d \
4 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
5 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
6 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \
7 | --ft_dataset sth_v1 --ft_mode rgb \
8 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 6 \
9 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 6 --ft_stride 1 --ft_dropout 0.5 \
10 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/something-something-v1/i3d_pt_and_ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 3 --method pt_and_ft --pt_method be \
3 | --pt_batch_size 48 --pt_workers 8 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset sth_v1 \
6 | --pt_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
7 | --pt_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
8 | --pt_root /data1/DataSet/something-something/20bn-something-something-v1/ \
9 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
10 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
11 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \
12 | --ft_dataset sth_v1 --ft_mode rgb \
13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 48 \
14 | --ft_data_length 16 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \
15 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/something-something-v1/i3d_pt_and_ft_multi_gpus.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 0,1,2,3 --method pt_and_ft --pt_method be --arch i3d \
3 | --pt_batch_size 128 --pt_workers 16 --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset sth_v1 \
6 | --pt_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
7 | --pt_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
8 | --pt_root /data1/DataSet/something-something/20bn-something-something-v1/ \
9 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
10 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
11 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \
12 | --ft_dataset sth_v1 --ft_mode rgb \
13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 32 \
14 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \
15 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/something-something-v1/r3d_ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 2 --method ft \
3 | --arch r3d18 \
4 | --ft_weights ../experiments/sth_v1_to_sth_v1/11-10-0006/pt/models/current.pth \
5 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
6 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
7 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \
8 | --ft_dataset sth_v1 --ft_mode rgb \
9 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 6 \
10 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 6 --ft_stride 1 --ft_dropout 0.5 \
11 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/something-something-v1/r3d_ft_multi_gpus.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 0,1,2,3 --method ft \
3 | --arch r3d18 \
4 | --ft_weights ../experiments/sth_v1_to_sth_v1/11-10-0006/pt/models/current.pth \
5 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
6 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
7 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \
8 | --ft_dataset sth_v1 --ft_mode rgb \
9 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 32 \
10 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \
11 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/something-something-v1/r3d_pt_and_ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 3 --method pt_and_ft --pt_method be \
3 | --pt_batch_size 48 --pt_workers 8 --arch r3d18 --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset sth_v1 \
6 | --pt_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
7 | --pt_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
8 | --pt_root /data1/DataSet/something-something/20bn-something-something-v1/ \
9 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
10 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
11 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \
12 | --ft_dataset sth_v1 --ft_mode rgb \
13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 48 \
14 | --ft_data_length 16 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \
15 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/something-something-v1/r3d_pt_and_ft_multi_gpus.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 0,1,2,3 --method pt_and_ft --pt_method be --arch r3d18 \
3 | --pt_batch_size 128 --pt_workers 16 --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 10 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset sth_v1 \
6 | --pt_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
7 | --pt_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
8 | --pt_root /data1/DataSet/something-something/20bn-something-something-v1/ \
9 | --ft_train_list ../datasets/lists/something_something_v1/train_videofolder.txt \
10 | --ft_val_list ../datasets/lists/something_something_v1/val_videofolder.txt \
11 | --ft_root /data1/DataSet/something-something/20bn-something-something-v1/ \
12 | --ft_dataset sth_v1 --ft_mode rgb \
13 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 32 \
14 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 8 --ft_stride 1 --ft_dropout 0.5 \
15 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/ucf101/pt_and_ft.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --gpus 0 --method pt_and_ft --pt_method be \
3 | --pt_batch_size 4 --pt_workers 4 --arch i3d --pt_spatial_size 224 --pt_stride 4 --pt_data_length 16 \
4 | --pt_nce_k 3569 --pt_softmax \
5 | --pt_moco --pt_epochs 200 --pt_save_freq 4 --pt_print_freq 100 --pt_dataset ucf101 \
6 | --pt_train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \
7 | --pt_val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \
8 | --ft_train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \
9 | --ft_val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \
10 | --ft_dataset ucf101 --ft_mode rgb \
11 | --ft_lr 0.001 --ft_lr_steps 10 20 25 30 35 40 --ft_epochs 45 --ft_batch_size 4 \
12 | --ft_data_length 64 --ft_spatial_size 224 --ft_workers 4 --ft_stride 1 --ft_dropout 0.5 \
13 | --ft_print-freq 100 --ft_fixed 0
--------------------------------------------------------------------------------
/src/Contrastive/scripts/visualization/ucf101_data_augmentation_visualize.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python augmentation_visualization.py \
3 | --eval_indict loss --pt_loss distrubt \
4 | --train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \
5 | --val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \
6 | --dataset ucf101 \
7 | --arch i3d \
8 | --mode rgb \
9 | --lr 0.001 \
10 | --lr_steps 10 20 25 30 35 40 \
11 | --epochs 45 \
12 | --batch_size 1 \
13 | --data_length 16 \
14 | --spatial_size 224 \
15 | --workers 8 \
16 | --stride 4 \
17 | --dropout 0.5 \
18 | --gpus 3 \
19 | --logs_path ../experiments/logs/ucf101_i3d_ft \
20 | --print-freq 100
--------------------------------------------------------------------------------
/src/Contrastive/scripts/visualization/ucf101_triplet_visualize.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python triplet_visualization.py \
3 | --eval_indict loss --pt_loss MoCo \
4 | --train_list ../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt \
5 | --val_list ../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt \
6 | --dataset ucf101 \
7 | --arch i3d \
8 | --mode rgb \
9 | --lr 0.001 \
10 | --lr_steps 10 20 25 30 35 40 \
11 | --epochs 45 \
12 | --batch_size 4 \
13 | --data_length 16 \
14 | --spatial_size 224 \
15 | --workers 8 \
16 | --stride 4 \
17 | --dropout 0.5 \
18 | --gpus 3 \
19 | --logs_path ../experiments/logs/ucf101_i3d_ft \
20 | --print-freq 100
--------------------------------------------------------------------------------
/src/Contrastive/test.py:
--------------------------------------------------------------------------------
1 | """
2 | input: 16 x 112 x 112 x 3, with no overlapped
3 | """
4 | import torch.nn.parallel
5 | import torch.optim
6 | from utils.utils import *
7 | import os
8 | from data.config import data_config, augmentation_config
9 | from data.dataloader import data_loader_init
10 | from model.config import model_config
11 | from bk.option_old import args
12 | import datetime
13 |
14 |
15 | def get_action_index(list_txt='data/classInd.txt'):
16 | action_label = []
17 | with open(list_txt) as f:
18 | content = f.readlines()
19 | content = [x.strip('\r\n') for x in content]
20 | f.close()
21 | for line in content:
22 | label, action = line.split(' ')
23 | action_label.append(action)
24 | return action_label
25 |
26 |
27 | def plot_matrix_test(list_txt, cfu_mat="../experiments/evaluation/ucf101/_confusion.npy", date="", prefix="flow"):
28 | classes = get_action_index(list_txt)
29 | confuse_matrix = np.load(cfu_mat)
30 | plot_confuse_matrix(confuse_matrix, classes, date=date, prefix=prefix)
31 | plt.show()
32 |
33 |
34 | def eval_video(net, video_data):
35 |     '''
36 |     Average predictions over overlapping clips (temporal stride 10); currently only the first two clips are used, full multi-clip averaging is still a TODO.
37 |     '''
38 | i, datas, label = video_data
39 | output = None
40 | net.eval()
41 | with torch.no_grad():
42 | for data in datas:
43 | if len(data.size()) == 4:
44 | data = data.unsqueeze(0)
45 | # print(data.size())
46 | overlapped_clips = 1 + int((data.size(2) - args.clip_size) / 10)
47 | # print(overlapped_clips)
48 | for i in range(overlapped_clips):
49 | if i > 1:
50 | break
51 | # print(data.size()) # 3 x 47 x 112 x 112
52 | clip_data = data[:, :, 10 * i:10 * i + args.clip_size, :, :]
53 | input_var = torch.autograd.Variable(clip_data)
54 | res = net(input_var)
55 | # print(torch.exp(res), label)
56 | res = torch.exp(res).data.cpu().numpy().copy()
57 | if output is None:
58 | output = res / overlapped_clips
59 | else:
60 | output += res / overlapped_clips
61 | return output, label
62 |
63 |
64 | def main(prefix='flow_scratch'):
65 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
66 | # == dataset config==
67 | num_class, data_length, image_tmpl = data_config(args)
68 | train_transforms, test_transforms, eval_transforms = augmentation_config(args)
69 | _, eval_data_loader, _, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms,
70 | test_transforms, eval_transforms)
71 | model = model_config(args, num_class)
72 | output = []
73 | total_num = len(eval_data_loader)
74 | for i, (data, label, index) in enumerate(eval_data_loader):
75 | proc_start_time = time.time()
76 | rst = eval_video(model, (i, data, label))
77 | output.append(rst)
78 | cnt_time = time.time() - proc_start_time
79 | if i % 10 == 0:
80 | print('video {} done, total {}/{}, average {} sec/video'.format(i, i + 1,
81 | total_num,
82 | float(cnt_time) / (i + 1)))
83 | if i > 300:
84 | video_pred = [np.argmax(x[0]) for x in output]
85 | video_labels = [x[1] for x in output]
86 | cf = confusion_matrix(video_labels, video_pred).astype(float)
87 | cls_cnt = cf.sum(axis=1)
88 | cls_hit = np.diag(cf)
89 | cls_acc = cls_hit / cls_cnt
90 | print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
91 |
92 | date = datetime.datetime.today().strftime('%m-%d-%H%M')
93 | # =====output: every video's num and every video's label
94 | # =====x[0]:softmax value x[1]:label
95 | if not os.path.isdir("../experiments/evaluation/{}/{}".format(args.dataset, date)):
96 | os.mkdir("../experiments/evaluation/{}/{}".format(args.dataset, date))
97 | video_pred = [np.argmax(x[0]) for x in output]
98 | np.save("../experiments/evaluation/{}/{}/{}_video_pred.npy".format(args.dataset, date, prefix), video_pred)
99 | video_labels = [x[1] for x in output]
100 | np.save("../experiments/evaluation/{}/{}/{}_video_labels.npy".format(args.dataset, date, prefix), video_labels)
101 | cf = confusion_matrix(video_labels, video_pred).astype(float)
102 | np.save("../experiments/evaluation/{}/{}/{}_confusion.npy".format(args.dataset, date, prefix), cf)
103 | cf_name = "../experiments/evaluation/{}/{}/{}_confusion.npy".format(args.dataset, date, prefix)
104 | cls_cnt = cf.sum(axis=1)
105 | cls_hit = np.diag(cf)
106 | cls_acc = cls_hit / cls_cnt
107 | print(cls_acc)
108 | print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
109 |
110 | name_list = [x.strip().split()[0] for x in open(args.val_list)]
111 | order_dict = {e: i for i, e in enumerate(sorted(name_list))}
112 | reorder_output = [None] * len(output)
113 | reorder_label = [None] * len(output)
114 | for i in range(len(output)):
115 | idx = order_dict[name_list[i]]
116 | reorder_output[idx] = output[i]
117 | reorder_label[idx] = video_labels[i]
118 | np.savez('../experiments/evaluation/' + args.dataset + '/' + date + "/" + prefix + args.mode + 'res',
119 | scores=reorder_output, labels=reorder_label)
120 | return cf_name, date
121 |
122 |
123 | def plot_confuse_matrix(matrix, classes,
124 | date="",
125 | prefix="flow",
126 | normalize=True,
127 | title=None,
128 | cmap=plt.cm.Blues
129 | ):
130 | """
131 | :param matrix:
132 | :param classes:
133 | :param normalize:
134 | :param title:
135 | :param cmap:
136 | :return:
137 | """
138 | if not title:
139 | if normalize:
140 | title = 'Normalized confusion matrix'
141 | else:
142 | title = 'Confusion matrix, without normalization'
143 |
144 | # Compute confusion matrix
145 | cm = matrix
146 | # Only use the labels that appear in the data
147 | if normalize:
148 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
149 | print("Normalized confusion matrix")
150 | else:
151 | print('Confusion matrix, without normalization')
152 |
153 | print(cm)
154 |
155 | fig, ax = plt.subplots()
156 |
157 | # We change the fontsize of minor ticks label
158 | ax.tick_params(axis='both', which='major', labelsize=6)
159 | ax.tick_params(axis='both', which='minor', labelsize=4)
160 |
161 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
162 | ax.figure.colorbar(im, ax=ax)
163 | # We want to show all ticks...
164 | ax.set(xticks=np.arange(cm.shape[1]),
165 | yticks=np.arange(cm.shape[0]),
166 | # ... and label them with the respective list entries
167 | xticklabels=classes, yticklabels=classes,
168 | title=title,
169 | ylabel='True label',
170 | xlabel='Predicted label')
171 |
172 | # Rotate the tick labels and set their alignment.
173 | plt.setp(ax.get_xticklabels(), rotation=60, ha="right",
174 | rotation_mode="anchor")
175 |
176 | # Loop over data dimensions and create text annotations.
177 | # fmt = '.2f'
178 | # thresh = cm.max() / 2.
179 | # for i in range(cm.shape[0]):
180 | # for j in range(cm.shape[1]):
181 | # ax.text(j, i, format(cm[i, j], fmt),
182 | # ha="center", va="center",
183 | # color="white" if cm[i, j] > thresh else "black")
184 | # fig.tight_layout()
185 | print("date is: {}".format(date))
186 | plt.savefig("../experiments/evaluation/hmdb51/{}/{}confuse.png".format(date, prefix))
187 | return ax
188 |
189 |
190 | if __name__ == '__main__':
191 | # prefix = 'TCA'
192 | cf_name, date = main(args.prefix)
193 | # cf_name = "../experiments/evaluation/hmdb51/03-24-1659/confusion.npy"
194 | classList = "../datasets/lists/hmdb51/hmdb51_classInd.txt"
195 | plot_matrix_test(classList, cf_name, date, prefix=args.prefix)
--------------------------------------------------------------------------------
/src/Contrastive/utils/data_process/gen_diving48_frames.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import sys
4 | import subprocess
5 |
6 |
7 | def process(dir_path, dst_dir_path):
8 | class_path = dir_path
9 | if not os.path.isdir(class_path):
10 | return
11 |
12 | dst_class_path = dst_dir_path
13 | if not os.path.exists(dst_class_path):
14 | os.mkdir(dst_class_path)
15 |
16 | for file_name in os.listdir(class_path):
17 | if '.mp4' not in file_name:
18 | continue
19 | name, ext = os.path.splitext(file_name)
20 | dst_directory_path = os.path.join(dst_class_path, name)
21 |
22 | video_file_path = os.path.join(class_path, file_name)
23 | try:
24 | if os.path.exists(dst_directory_path):
25 | if not os.path.exists(os.path.join(dst_directory_path, 'image_00001.jpg')):
26 | subprocess.call('rm -r \"{}\"'.format(dst_directory_path), shell=True)
27 | print('remove {}'.format(dst_directory_path))
28 | os.mkdir(dst_directory_path)
29 | else:
30 | continue
31 | else:
32 | os.mkdir(dst_directory_path)
33 | except:
34 | print(dst_directory_path)
35 | continue
36 | cmd = 'ffmpeg -i \"{}\" -vf scale=-1:240 \"{}/image_%05d.jpg\"'.format(video_file_path, dst_directory_path)
37 | print(cmd)
38 | subprocess.call(cmd, shell=True)
39 | print('\n')
40 |
41 |
42 | if __name__=="__main__":
43 | # /data1/DataSet/Diving48/rgb
44 | # /data1/DataSet/Diving48/rgb_frames
45 | dir_path = sys.argv[1]
46 | dst_dir_path = sys.argv[2]
47 |
48 | process(dir_path, dst_dir_path)
49 |
--------------------------------------------------------------------------------
/src/Contrastive/utils/data_process/gen_diving48_lists.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | out = []
5 | count = 0
6 | front = ''
7 | with open('../datasets/lists/diving48/lists/Diving48_V2_train.json', 'r') as f:
8 | lines = json.load(f)
9 | for line in lines:
10 | # line = line.strip()
11 | count += 1
12 | new_line = front + str(line['vid_name']) + ' ' + str(line['end_frame']) + ' ' + str(line['label']) + '\n'
13 | out.append(new_line)
14 | if count % 100 == 0 and count != 0:
15 | print(count)
16 |
17 | with open('../datasets/lists/diving48/diving48_v2_train_no_front.txt', 'a') as f:
18 | f.writelines(out)
--------------------------------------------------------------------------------
/src/Contrastive/utils/data_process/gen_hmdb51_dir.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import sys
4 | import subprocess
5 |
6 |
7 | def process(dir_path, dst_dir_path):
8 | count0 = 0
9 | for class_name in os.listdir(dir_path):
10 | count1 = 0
11 | class_path = os.path.join(dir_path, class_name)
12 | cmd = 'mv {}/* {}'.format(class_path, dst_dir_path)
13 | subprocess.call(cmd, shell=True)
14 | count1 += 1
15 | count0 += 1
16 |
17 |
18 | if __name__=="__main__":
19 | # /data1/DataSet/hmdb51_sta_frames
20 | # /data1/DataSet/hmdb51_sta_frames2
21 | dir_path = sys.argv[1]
22 | dst_dir_path = sys.argv[2]
23 |
24 | process(dir_path, dst_dir_path)
25 |
--------------------------------------------------------------------------------
/src/Contrastive/utils/data_process/gen_hmdb51_frames.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import sys
4 | import subprocess
5 |
6 |
7 | def process(dir_path, dst_dir_path):
8 | count0 = 0
9 | for class_name in os.listdir(dir_path):
10 | count1 = 0
11 | class_path = os.path.join(dir_path, class_name)
12 | if not os.path.isdir(class_path):
13 | return
14 |
15 | dst_class_path = os.path.join(dst_dir_path, class_name)
16 | if not os.path.exists(dst_class_path):
17 | os.mkdir(dst_class_path)
18 | for file_name in os.listdir(class_path):
19 | if file_name[-4:] != '.avi':
20 | continue
21 | name, ext = os.path.splitext(file_name)
22 | dst_directory_path = os.path.join(dst_class_path, name)
23 |
24 | video_file_path = os.path.join(class_path, file_name)
25 | try:
26 | if os.path.exists(dst_directory_path):
27 | if not os.path.exists(os.path.join(dst_directory_path, 'image_00001.jpg')):
28 | subprocess.call('rm -r \"{}\"'.format(dst_directory_path), shell=True)
29 | print('remove {}'.format(dst_directory_path))
30 | os.mkdir(dst_directory_path)
31 | else:
32 | continue
33 | else:
34 | os.mkdir(dst_directory_path)
35 | except:
36 | print(dst_directory_path)
37 | continue
38 | cmd = 'ffmpeg -i \"{}\" -vf scale=-1:240 \"{}/image_%05d.jpg\"'.format(video_file_path, dst_directory_path)
39 | # print(cmd)
40 | subprocess.call(cmd, shell=True)
41 | count1 += 1
42 | print("{}/{} classes: {}/{} videos finished".format(count0, len(os.listdir(dir_path)), count1, len(os.listdir(class_path))))
43 | # print('\n')
44 | count0 += 1
45 |
46 |
47 | if __name__=="__main__":
48 | # /data1/DataSet/hmdb51_sta_new
49 | # /data1/DataSet/hmdb51_sta_frames
50 | dir_path = sys.argv[1]
51 | dst_dir_path = sys.argv[2]
52 |
53 | process(dir_path, dst_dir_path)
54 |
--------------------------------------------------------------------------------
/src/Contrastive/utils/data_process/gen_hmdb51_sta_list.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | out = []
5 | count = 0
6 | front = ''
7 | with open('../datasets/lists/hmdb51/lists/Diving48_V2_train.json', 'r') as f:
8 | lines = json.load(f)
9 | for line in lines:
10 | # line = line.strip()
11 | count += 1
12 | new_line = front + str(line['vid_name']) + ' ' + str(line['end_frame']) + ' ' + str(line['label']) + '\n'
13 | out.append(new_line)
14 | if count % 100 == 0 and count != 0:
15 | print(count)
16 |
17 | with open('../datasets/lists/diving48/diving48_v2_train_no_front.txt', 'a') as f:
18 | f.writelines(out)
--------------------------------------------------------------------------------
/src/Contrastive/utils/data_process/gen_sub_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def data_split(path, new_file_path, radio):
4 | video_list = [x for x in open(path)]
5 | # if not os.path.exists(new_file_path):
6 | count = 0
7 | new_list = []
8 | for line in video_list:
9 | count += 1
10 | if count % radio == 0:
11 | new_list.append(line)
12 | with open(new_file_path, 'w') as f:
13 | for item in new_list:
14 | f.write("%s" % item)
15 |
16 | if __name__ == '__main__':
17 | radio = 10
18 | path = "../datasets/lists/kinetics-400/ssd_kinetics_video_trainlist.txt"
19 | new_file_path = "../datasets/lists/kinetics-400/ssd_kinetics_video_trainlist_{}of{}.txt".format(1, radio)
20 | data_split(path, new_file_path, radio)
--------------------------------------------------------------------------------
/src/Contrastive/utils/data_process/semi_data_split.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 |
5 | def split(file, new_file, radio=0.1, cls_num=51):
6 | """
7 | ucf101/hmdb51
8 | :param file:
9 | :param new_file:
10 | :param radio:
11 | :param cls_num:
12 | :return:
13 | """
14 | cls_nums = np.zeros(cls_num)
15 | f = open(file, 'r')
16 | for data in f.readlines():
17 | name, frame, cls = data.split(" ")
18 | cls = int(cls)
19 | cls_nums[cls] += 1
20 | new_lists = list()
21 | new_cls_nums = np.zeros(cls_num)
22 | f = open(file, 'r')
23 | for data in f.readlines():
24 | name, frame, cls = data.split(" ")
25 | cls = int(cls)
26 | if new_cls_nums[cls] < radio * cls_nums[cls]:
27 | new_lists.append(data)
28 | new_cls_nums[cls] += 1
29 | with open(new_file, 'w') as f:
30 | f.writelines(new_lists)
31 | return
32 |
33 |
34 | def split_kinetics(file, new_file, radio=0.1, cls_num=400):
35 | cls_nums = np.zeros(cls_num)
36 | f = open(file, 'r')
37 | for data in f.readlines():
38 | name, cls = data.split(" ")
39 | cls = int(cls)
40 | cls_nums[cls] += 1
41 | new_lists = list()
42 | new_cls_nums = np.zeros(cls_num)
43 | f = open(file, 'r')
44 | for data in f.readlines():
45 | name, cls = data.split(" ")
46 | cls = int(cls)
47 | if new_cls_nums[cls] < radio * cls_nums[cls]:
48 | new_lists.append(data)
49 | new_cls_nums[cls] += 1
50 | with open(new_file, 'w') as f:
51 | f.writelines(new_lists)
52 | return
53 |
54 |
55 | if __name__ == '__main__':
56 | file = '../datasets/lists/kinetics-400/ssd_kinetics_video_trainlist.txt'
57 | radio = 0.5
58 | new_file = '../datasets/lists/kinetics-400/{}_ssd_kinetics_video_trainlist.txt'.format(radio)
59 | split_kinetics(file, new_file, radio=radio, cls_num=400)
--------------------------------------------------------------------------------
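A usage sketch for split() above: keep roughly 10% of each class of the HMDB51 training list. The list paths follow the ones used elsewhere in these scripts but are assumptions, and split() is assumed to be in scope (e.g. called from this file's __main__ block):

# keep ~10% of the HMDB51 split-1 training list, balanced per class
file = '../datasets/lists/hmdb51/hmdb51_rgb_train_split_1.txt'
new_file = '../datasets/lists/hmdb51/0.1_hmdb51_rgb_train_split_1.txt'
split(file, new_file, radio=0.1, cls_num=51)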
/src/Contrastive/utils/data_process/unrar_hmdb51_sta.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # !/bin/bash
3 |
4 | src_path=`readlink -f $1`
5 | dst_path=`readlink -f $2`
6 |
7 | rar_files=`find $src_path -name '*.rar'`
8 | IFS=$'\n'; array=$rar_files; unset IFS
9 | for rar_file in $array; do
10 | file_path=`echo $rar_file | sed -e "s;$src_path;$dst_path;"`
11 | ext_path=${file_path%/*}
12 | if [ ! -d $ext_path ]; then
13 | mkdir -p $ext_path
14 | fi
15 | unrar x $rar_file $ext_path
16 | done
17 |
18 | # bash utils/data_process/unrar_hmdb51_sta.sh /data1/DataSet/hmdb51_sta /data1/DataSet/hmdb51_sta_new
--------------------------------------------------------------------------------
/src/Contrastive/utils/gradient_check.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from matplotlib.lines import Line2D
4 |
5 | #
6 | # def plot_grad_flow(named_parameters):
7 | # '''Plots the gradients flowing through different layers in the net during training.
8 | # Can be used for checking for possible gradient vanishing / exploding problems.
9 | #
10 | # Usage: Plug this function in Trainer class after loss.backwards() as
11 | # "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
12 | # # for k, v in named_parameters:
13 | # # print(k)
14 | # ave_grads = []
15 | # max_grads = []
16 | # layers = []
17 | # for n, p in named_parameters:
18 | # if (p.requires_grad) and ("bias" not in n):
19 | # layers.append(n)
20 | # ave_grads.append(p.grad.abs().mean())
21 | # max_grads.append(p.grad.abs().max())
22 | # plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
23 | # plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
24 | # plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
25 | # plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
26 | # plt.xlim(left=0, right=len(ave_grads))
27 | # plt.ylim(bottom=-0.001, top=0.02) # zoom in on the lower gradient regions
28 | # plt.xlabel("Layers")
29 | # plt.ylabel("average gradient")
30 | # plt.title("Gradient flow")
31 | # plt.grid(True)
32 | # plt.legend([Line2D([0], [0], color="c", lw=4),
33 | # Line2D([0], [0], color="b", lw=4),
34 | # Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
35 | # plt.savefig('../experiments/gradients/r2p1d_gradient_benchmark.png', dpi=720, bbox_inches='tight')
36 |
37 |
38 | def plot_grad_flow(named_parameters):
39 | '''Plots the gradients flowing through different layers in the net during training.
40 | Can be used for checking for possible gradient vanishing / exploding problems.
41 |
42 |     Usage: Plug this function into the Trainer class after loss.backward() as
43 |     "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
44 | ave_grads = []
45 | max_grads = []
46 | layers = []
47 | for n, p in named_parameters:
48 | if (p.requires_grad) and ("bias" not in n):
49 | layers.append(n)
50 | ave_grads.append(p.grad.abs().mean())
51 | max_grads.append(p.grad.abs().max())
52 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
53 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
54 | plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
55 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
56 | plt.xlim(left=0, right=len(ave_grads))
57 | plt.ylim(bottom=-0.001, top=0.02) # zoom in on the lower gradient regions
58 | plt.xlabel("Layers")
59 | plt.ylabel("average gradient")
60 | plt.title("Gradient flow")
61 | plt.grid(True)
62 | plt.legend([Line2D([0], [0], color="c", lw=4),
63 | Line2D([0], [0], color="b", lw=4),
64 | Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
--------------------------------------------------------------------------------
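A hypothetical snippet showing where plot_grad_flow() is meant to be called, following its docstring; model, criterion, train_loader and the output path are placeholders, not names defined by this repository:

import matplotlib.pyplot as plt
from utils.gradient_check import plot_grad_flow

clips, labels = next(iter(train_loader))      # placeholder dataloader
loss = criterion(model(clips), labels)        # placeholder model / loss
loss.backward()
plot_grad_flow(model.named_parameters())      # bar chart of mean/max gradient per layer
plt.savefig('grad_flow.png', dpi=300, bbox_inches='tight')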
/src/Contrastive/utils/learning_rate_adjust.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def pt_adjust_learning_rate(epoch, opt, optimizer):
5 | """Sets the learning rate to the initial LR decayed by 0.2 every steep step"""
6 | # if epoch < 2:
7 | # for param_group in optimizer.param_groups:
8 | # param_group['lr'] = 1e-7
9 | # return 0
10 | # print(epoch)
11 | # print(np.asarray(opt.pt_lr_decay_epochs))
12 | steps = np.sum(epoch > np.asarray(opt.pt_lr_decay_epochs))
13 | if steps > 0:
14 | new_lr = opt.pt_learning_rate * (opt.pt_lr_decay_rate ** steps)
15 | for param_group in optimizer.param_groups:
16 | param_group['lr'] = new_lr
17 |
18 |
19 | def ft_adjust_learning_rate(optimizer, initial_lr, epoch, lr_steps):
20 |     """Decay the fine-tuning LR by 0.3 for every milestone in lr_steps that the epoch has reached"""
21 |     decay = 0.3 ** (sum(epoch >= np.array(lr_steps)))
22 |     lr = initial_lr * decay
23 |     for param_group in optimizer.param_groups:
24 |         param_group['lr'] = lr
25 |
26 |
--------------------------------------------------------------------------------
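A worked example of the fine-tuning schedule implemented above, using the --ft_lr 0.001 and --ft_lr_steps 10 20 25 30 35 40 values from the scripts; the loop simply re-evaluates the same decay formula:

import numpy as np

initial_lr, lr_steps = 0.001, [10, 20, 25, 30, 35, 40]
for epoch in [0, 10, 22, 41]:
    decay = 0.3 ** (sum(epoch >= np.array(lr_steps)))
    print(epoch, initial_lr * decay)   # 1e-3, 3e-4, 9e-5 (two milestones passed), ~7.29e-7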
/src/Contrastive/utils/load_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.nn.init import xavier_uniform_, constant_, zeros_, normal_, kaiming_uniform_, kaiming_normal_
4 | import torch.nn as nn
5 | import math
6 |
7 |
8 | def weight_transform(model_dict, pretrain_dict):
9 |     '''
10 |     Copy weights from pretrain_dict into model_dict for every key that exists in the
11 |     model and does not belong to a 'custom' layer; returns the updated model_dict.
12 |     '''
13 | count = 0
14 | # for k, v in pretrain_dict.items():
15 | # count += 1
16 | # print(k)
17 | # if count > 100:
18 | # break
19 | # for k, v in model_dict.items():
20 | # print(k)
21 | # count += 1
22 | # if count > 20:
23 | # break
24 |
25 | weight_dict = {k:v for k, v in pretrain_dict.items() if k in model_dict and 'custom' not in k} # and 'custom' not in k
26 | for k, v in weight_dict.items():
27 | print("load: {}".format(k))
28 | # print(weight_dict)
29 | model_dict.update(weight_dict)
30 | return model_dict
31 |
32 |
33 | def weights_init(model):
34 | """ Initializes the weights of the CNN model using the Xavier
35 | initialization.
36 | """
37 | # for m in model.modules():
38 | # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv1d):
39 | # xavier_uniform_(m.weight, gain=math.sqrt(2.0))
40 | # if m.bias:
41 | # constant_(m.bias, 0.1)
42 | # elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm3d):
43 | # normal_(m.weight, 1.0, 0.02)
44 | # zeros_(m.bias)
45 | if isinstance(model, nn.Conv2d) or isinstance(model, nn.Conv3d) or isinstance(model, nn.Conv1d):
46 | xavier_uniform_(model.weight.data, gain=math.sqrt(2.0))
47 | # if model.bias:
48 | # constant_(model.bias.data, 0.1)
49 | elif isinstance(model, nn.BatchNorm2d) or isinstance(model, nn.BatchNorm1d) or isinstance(model, nn.BatchNorm3d):
50 | normal_(model.weight.data, 1.0, 0.02)
51 |
52 |
53 | def pt_load_weight(args, model, model_ema, optimizer, contrast):
54 | # random initialization
55 | model.apply(weights_init)
56 | model_ema.apply(weights_init)
57 | if args.pt_resume:
58 | if os.path.isfile(args.pt_resume):
59 | print("=> loading checkpoint '{}'".format(args.resume))
60 | checkpoint = torch.load(args.resume, map_location='cpu')
61 | # checkpoint = torch.load(args.resume)
62 | args.start_epoch = checkpoint['epoch'] + 1
63 | model.load_state_dict(checkpoint['model'])
64 | optimizer.load_state_dict(checkpoint['optimizer'])
65 | contrast.load_state_dict(checkpoint['contrast'])
66 | model_ema.load_state_dict(checkpoint['model_ema'])
67 | print("=> loaded successfully '{}' (epoch {})"
68 | .format(args.resume, checkpoint['epoch']))
69 | del checkpoint
70 | torch.cuda.empty_cache()
71 | else:
72 | print("=> no checkpoint found at '{}'".format(args.resume))
73 |
74 | return model, model_ema
75 |
76 |
77 | def ft_load_weight(args, model):
78 | if args.ft_weights == "":
79 | print("no pretrained model available. Train from Scratch.....................")
80 | model.apply(weights_init)
81 | # weights_init(model)
82 | else:
83 | # weights_init(model)
84 | model.apply(weights_init)
85 | checkpoint = torch.load(args.ft_weights)
86 | try:
87 | print("model epoch {} lowese val: {}".format(checkpoint['epoch'], checkpoint['lowest_val']))
88 | except KeyError as e:
89 | try:
90 | print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1']))
91 | except KeyError:
92 | print("not train from this code!")
93 | try:
94 | # print("????")
95 | print(args.ft_weights.split('/')[-1][:9])
96 | if args.ft_weights.split('/')[-1][:11] == 'mutual_loss':
97 | pretrain_dict = {('.'.join(k.split('.')[2:]))[2:]: v for k, v in list(checkpoint['state_dict'].items())}
98 | elif args.ft_weights.split('/')[-1][:9] == 'byol_ckpt':
99 | # print("???")
100 | pretrain_dict = {k[26:]: v for k, v in list(checkpoint['model'].items())}
101 | # pretrain_dict = {k[7:]: v for k, v in list(checkpoint['model_ema'].items())}
102 | elif args.ft_weights.split('/')[-1][:4] in ['ckpt', 'curr', 'ucf1']:
103 | # print("???")
104 | pretrain_dict = {k[7:]: v for k, v in list(checkpoint['model'].items())}
105 | # pretrain_dict = {k: v for k, v in list(checkpoint['model'].items())}
106 | elif args.ft_weights.split('/')[-1][:3] == 'i3d':
107 | pretrain_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint['state_dict'].items())}
108 | else:
109 | pretrain_dict = {'.'.join(k.split('.')[2:]): v for k, v in list(checkpoint['state_dict'].items())}
110 | except KeyError:
111 | # pretrain_dict = checkpoint['model']
112 | pretrain_dict = checkpoint
113 | # pretrain_dict = {k[26:]: v for k, v in list(checkpoint['state_dict'].items())}
114 | model_dict = model.state_dict()
115 | model_dict = weight_transform(model_dict, pretrain_dict)
116 | model.load_state_dict(model_dict)
117 | if args.ft_resume:
118 | if os.path.isfile(args.ft_resume):
119 | print(("=> loading checkpoints '{}'".format(args.ft_resume)))
120 | checkpoint = torch.load(args.ft_resume)
121 | args.start_epoch = checkpoint['epoch']
122 | best_prec1 = checkpoint['best_prec1']
123 | model.load_state_dict(checkpoint['state_dict'])
124 | print(("=> loaded checkpoints '{}' (epoch {}) best_prec1 {}"
125 | .format(args.evaluate, checkpoint['epoch'], best_prec1)))
126 | else:
127 | print(("=> no checkpoints found at '{}'".format(args.ft_resume)))
128 | return model
--------------------------------------------------------------------------------
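A condensed sketch of how weight_transform() is used to partially load pretrained weights, mirroring the core steps of ft_load_weight(); the checkpoint path comes from scripts/hmdb51/ft.sh, while `model` and the key handling here are placeholders rather than the repository's exact prefix-stripping logic:

import torch

from utils.load_weights import weight_transform

checkpoint = torch.load('../experiments/ucf101_contrastive.pth', map_location='cpu')
pretrain_dict = checkpoint.get('model', checkpoint)        # wrapped checkpoint or raw state dict
model_dict = model.state_dict()                            # `model` built elsewhere, e.g. by model_config()
model_dict = weight_transform(model_dict, pretrain_dict)   # keep matching, non-'custom' keys only
model.load_state_dict(model_dict)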
/src/Contrastive/utils/moment_update.py:
--------------------------------------------------------------------------------
1 | def update_ema_variables(ema_model, model, alpha=0.9):
2 |     # Exponential moving average of the model parameters into the EMA model
3 |     for ema_param, param in zip(ema_model.parameters(), model.parameters()):
4 |         ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
5 |
6 |
7 | def moment_update(model, model_ema, m):
8 |     """ model_ema = m * model_ema + (1 - m) * model """
9 |     for p1, p2 in zip(model.parameters(), model_ema.parameters()):
10 |         p2.data.mul_(m).add_(p1.detach().data, alpha=1 - m)
11 |         # p2.data.mul_(m).add_(p1.data, alpha=1 - m)
--------------------------------------------------------------------------------
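A hedged sketch of where moment_update() sits in a MoCo-style pretraining step: the key encoder (model_ema) is updated as an exponential moving average of the query encoder after each optimizer step. model, model_ema, contrast, optimizer, pretrain_loader and the momentum value 0.999 are placeholders here, not this repository's exact training loop:

import torch

from utils.moment_update import moment_update

for clips_q, clips_k in pretrain_loader:        # placeholder loader yielding positive pairs
    q = model(clips_q)                          # query features (gradients flow)
    with torch.no_grad():
        k = model_ema(clips_k)                  # key features from the momentum encoder
    loss = contrast(q, k)                       # placeholder contrastive head returning a scalar loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    moment_update(model, model_ema, m=0.999)    # model_ema = m * model_ema + (1 - m) * model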
/src/Contrastive/utils/recoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import os
5 | import shutil
6 | import datetime
7 | import shutil
8 |
9 |
10 | class Record:
11 | def __init__(self, args):
12 | super(Record, self).__init__()
13 | if args.method == 'pt_and_ft':
14 | self.origin_path = '../experiments/' + args.pt_dataset + '_to_' + args.ft_dataset
15 | elif args.method == 'pt':
16 | self.origin_path = '../experiments/pt_' + args.pt_dataset
17 | else:
18 | self.origin_path = '../experiments/ft_' + args.ft_dataset
19 | if not os.path.exists(self.origin_path):
20 | os.mkdir(self.origin_path)
21 | self.path = self.origin_path + '/' + args.date
22 | if not os.path.exists(self.path):
23 | os.mkdir(self.path)
24 | if args.method in ['pt', 'pt_and_ft']:
25 | self.pt_path = self.path + '/pt'
26 | self.pt_model_path = '{}/models'.format(self.pt_path)
27 | if not os.path.exists(self.pt_path):
28 | os.mkdir(self.pt_path)
29 | if not os.path.exists(self.pt_model_path):
30 | os.mkdir(self.pt_model_path)
31 | if args.method in ['ft', 'pt_and_ft']:
32 | self.ft_path = self.path + '/ft'
33 | self.ft_model_path = '{}/models'.format(self.ft_path)
34 | if not os.path.exists(self.ft_path):
35 | os.mkdir(self.ft_path)
36 | if not os.path.exists(self.ft_model_path):
37 | os.mkdir(self.ft_model_path)
38 | self.args = args
39 | # pretrain init
40 | self.pt_init()
41 | self.pt_train_loss_list = list()
42 | self.pt_checkpoint = ''
43 | # finetune init
44 | self.ft_init()
45 | self.ft_train_acc_list = list()
46 | self.ft_val_acc_list = list()
47 | self.ft_train_loss_list = list()
48 | self.ft_val_loss_list = list()
49 | self.front = self.args.method
50 | print(self.args)
51 | self.record_txt = os.path.join(self.path, self.front + '_logs.txt')
52 | self.record_init(args, 'w')
53 | self.src_init()
54 | self.filename = ''
55 | self.best_name = ''
56 |
57 | def pt_init(self):
58 | return
59 |
60 | def ft_init(self):
61 | return
62 |
63 | def src_init(self):
64 | if not os.path.exists(self.path + '/src_record'):
65 | shutil.copytree('../src', self.path + '/src_record')
66 |
67 | def record_init(self, args, open_type):
68 | with open(self.record_txt, open_type) as f:
69 | for arg in vars(args):
70 | f.write('{}: {}\n'.format(arg, getattr(args, arg)))
71 | f.write('\n')
72 |
73 | def record_message(self, open_type, message):
74 | with open(self.record_txt, open_type) as f:
75 | f.write(message + '\n\n')
76 |
77 | def record_ft_train(self, loss, acc=0):
78 | self.ft_train_acc_list.append(acc)
79 | self.ft_train_loss_list.append(loss)
80 |
81 | def record_ft_val(self, loss, acc=0):
82 | self.ft_val_acc_list.append(acc)
83 | self.ft_val_loss_list.append(loss)
84 |
85 | def record_pt_train(self, loss):
86 | self.pt_train_loss_list.append(loss)
87 |
88 | def plot_figure(self, plot_list, name='_performance'):
89 | epoch = len(plot_list[0][0])
90 | axis = np.linspace(1, epoch, epoch)
91 | fig = plt.figure()
92 | plt.title(self.args.arch + '_' + self.args.status + name)
93 | for i in range(len(plot_list)):
94 | plt.plot(axis, plot_list[i][0], label=plot_list[i][1])
95 | plt.legend()
96 | plt.xlabel('Epochs')
97 | plt.ylabel('%')
98 | plt.grid(True)
99 | plt.savefig(os.path.join(self.path, '{}.pdf'.format(self.args.status + name)))
100 | plt.close(fig)
101 |
102 | def save_ft_model(self, model, is_best=False):
103 | self.save_ft_checkpoint(self.args, model, is_best)
104 | plot_list = list()
105 | plot_list.append([self.ft_train_acc_list, 'train_acc'])
106 | plot_list.append([self.ft_val_acc_list, 'val_acc'])
107 | plot_list.append([self.ft_train_loss_list, 'train_loss'])
108 | plot_list.append([self.ft_val_loss_list, 'val_loss'])
109 | self.plot_figure(plot_list)
110 |
111 | def save_pt_model(self, args, state, epoch):
112 | self.save_pt_checkpoint(args, state, epoch)
113 | plot_list = list()
114 | plot_list.append([self.pt_train_loss_list, 'train_loss'])
115 | self.plot_figure(plot_list)
116 | print('==> Saving...')
117 |
118 | def save_pt_checkpoint(self, args, state, epoch):
119 | save_file = os.path.join(self.pt_model_path, 'current.pth')
120 | self.pt_checkpoint = save_file
121 | torch.save(state, save_file)
122 | if epoch % args.pt_save_freq == 0:
123 | save_file = os.path.join(self.pt_model_path, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
124 | torch.save(state, save_file)
125 | # help release GPU memory
126 | del state
127 | torch.cuda.empty_cache()
128 |
129 | def save_ft_checkpoint(self, args, state, is_best):
130 | self.filename = self.ft_path + '/' + args.ft_mode + '_model_latest.pth.tar'
131 | torch.save(state, self.filename)
132 | if is_best:
133 | self.best_name = self.ft_path + '/' + args.ft_mode + '_model_best.pth.tar'
134 | shutil.copyfile(self.filename, self.best_name)
135 |
--------------------------------------------------------------------------------
/src/Contrastive/utils/utils.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import numpy as np
3 | import torch
4 | import time
5 | import shutil
6 | import matplotlib.pyplot as plt
7 | from sklearn.metrics import confusion_matrix  # used by test.py via "from utils.utils import *"
8 |
9 | class Timer(object):
10 | """
11 | docstring for Timer
12 | """
13 | def __init__(self):
14 | super(Timer, self).__init__()
15 | self.total_time = 0.0
16 | self.calls = 0
17 | self.start_time = 0.0
18 | self.diff = 0.0
19 | self.average_time = 0.0
20 |
21 | def tic(self):
22 | self.start_time = time.time()
23 |
24 | def toc(self, average = False):
25 | self.diff = time.time() - self.start_time
26 | self.calls += 1
27 | self.total_time += self.diff
28 | self.average_time = self.total_time / self.calls
29 | if average:
30 | return self.average_time
31 | else:
32 | return self.diff
33 |
34 | def format(self, time):
35 | m,s = divmod(time, 60)
36 | h,m = divmod(m, 60)
37 | d,h = divmod(h, 24)
38 | return ("{}d:{}h:{}m:{}s".format(int(d), int(h), int(m), int(s)))
39 |
40 | def end_time(self, extra_time):
41 | """
42 | calculate the end time for training, show local time
43 | """
44 |         localtime = time.asctime(time.localtime(time.time() + extra_time))
45 | return localtime
46 |
47 |
48 | class AverageMeter(object):
49 | """Computes and stores the average and current value"""
50 |
51 | def __init__(self):
52 | self.reset()
53 |
54 | def reset(self):
55 | self.val = 0
56 | self.avg = 0
57 | self.sum = 0
58 | self.count = 0
59 |
60 | def update(self, val, n=1):
61 | self.val = val
62 | self.sum += val * n
63 | self.count += n
64 | self.avg = self.sum / self.count
65 |
66 |
67 | class Logger(object):
68 |
69 | def __init__(self, path, header):
70 | self.log_file = open(path, 'w')
71 | self.logger = csv.writer(self.log_file, delimiter='\t')
72 |
73 | self.logger.writerow(header)
74 | self.header = header
75 |
76 |     def __del__(self):
77 | self.log_file.close()
78 |
79 | def log(self, values):
80 | write_values = []
81 | for col in self.header:
82 | assert col in values
83 | write_values.append(values[col])
84 |
85 | self.logger.writerow(write_values)
86 | self.log_file.flush()
87 |
88 |
89 | def load_value_file(file_path):
90 | with open(file_path, 'r') as input_file:
91 | value = float(input_file.read().rstrip('\n\r'))
92 |
93 | return value
94 |
95 |
96 | def calculate_accuracy(outputs, targets):
97 | batch_size = targets.size(0)
98 |
99 | _, pred = outputs.topk(1, 1, True)
100 | pred = pred.t()
101 | correct = pred.eq(targets.view(1, -1))
102 |     n_correct_elems = correct.float().sum().item()
103 |
104 | return n_correct_elems / batch_size
105 |
106 |
107 | class TrainingHelper(object):
108 | def __init__(self, image):
109 | self.image = image
110 | def congratulation(self):
111 | """
112 |         If training finishes successfully, print a congratulatory banner.
113 |         """
114 |         for i in range(40):
115 |             print('*' * i)
116 |         print('finished training')
117 |
118 |
119 | def submission_file(ids, outputs, filename):
120 | """ write list of ids and outputs to filename"""
121 | with open(filename, 'w') as f:
122 | for vid, output in zip(ids, outputs):
123 | scores = ['{:g}'.format(x)
124 | for x in output]
125 | f.write('{} {}\n'.format(vid, ' '.join(scores)))
126 |
127 |
128 | def accuracy(output, target, topk=(1,)):
129 | """
130 | Computes the precision@k for the specified values of k
131 | output: 16(batch_size) x 101
132 | target: 16 x 1
133 | """
134 | maxk = max(topk)
135 | batch_size = target.size(0)
136 |
137 | _, pred = output.topk(maxk, 1, True, True)
138 | pred = pred.t()
139 | correct = pred.eq(target.view(1, -1).expand_as(pred)) # 5 x 16
140 | #print(correct)
141 |
142 | res = []
143 | for k in topk:
144 |         correct_k = correct[:k].reshape(-1).float().sum(0)
145 | res.append(correct_k.mul_(100.0 / batch_size))
146 | return res
147 |
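# Illustrative usage (sketch): for a (batch_size x num_classes) score tensor,
#   prec1, prec5 = accuracy(output, target, topk=(1, 5))
# returns the top-1 and top-5 precision of the batch as percentages.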
148 |
149 | def accuracy_mixup(output, targets, target_a, target_b, lam, topk=(1,)):
150 | """Computes the precision@k for the specified values of k"""
151 | maxk = max(topk)
152 | batch_size = targets.size(0)
153 |
154 | _, pred = output.topk(maxk, 1, True, True)
155 | pred = pred.t() #5 x 20
156 | correct_1 = pred.eq(target_a.data.view(1, -1).expand_as(pred))
157 | correct_2 = pred.eq(target_b.data.view(1, -1).expand_as(pred))
158 |
159 | res = []
160 | for k in topk:
161 |         correct_k = lam * correct_1[:k].reshape(-1).float().sum(0) + (1 - lam) * correct_2[:k].reshape(-1).float().sum(0)
162 | res.append(correct_k.mul_(100.0 / batch_size))
163 | return res
164 |
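# Note: accuracy_mixup scores a mixup batch as the convex combination of the two
# label sets, acc = lam * acc(target_a) + (1 - lam) * acc(target_b), mirroring the
# mixup loss lam * CE(output, target_a) + (1 - lam) * CE(output, target_b).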
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/augmentation_visualization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | from data.config import data_config, augmentation_config
7 | from data.dataloader import data_loader_init
8 | from model.config import model_config
9 | from augment.gen_positive import GenPositive
10 | from bk.option_old import args
11 | import numpy as np
12 |
13 | lowest_val_loss = float('inf')
14 | best_prec1 = 0
15 | torch.manual_seed(1)
16 |
17 |
18 | def test(train_loader, model, suffix='each_rotation_10'):
19 | model.eval()
20 | for i, (inputs, _, index) in enumerate(train_loader):
21 | dir_path = "../experiments/visualization/augmentation/{}_{}".format(suffix, index[0])
22 | # if not os.path.exists(dir_path):
23 | # os.makedirs(dir_path)
24 | augmentation_video = inputs[0].permute(1, 2, 3, 0).cpu().numpy()
25 | np.save("{}.npy".format(dir_path), augmentation_video)
26 | # print(augmentation_video.shape)
27 | # path = "{}/{}.avi".format(dir_path, str(index[0].cpu().numpy()))
28 | # print(path)
29 | # save_video(augmentation_video, path)
30 | print("{}/{} finished".format(i, len(train_loader)))
31 | # output = model(anchor)
32 | return True
33 |
34 |
35 | def main():
36 |     # == environment config ==
37 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning
38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
39 | torch.manual_seed(1)
40 | cudnn.benchmark = True
41 | # == dataset config==
42 | num_class, data_length, image_tmpl = data_config(args)
43 | train_transforms, test_transforms, eval_transforms = augmentation_config(args)
44 | train_data_loader, val_data_loader, _, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms)
45 | # == model config==
46 | model = model_config(args, num_class)
47 | pos_aug = GenPositive()
48 | # == train and eval==
49 | test(val_data_loader, model)
50 | return 1
51 |
52 |
53 | if __name__ == '__main__':
54 | main()
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/color_palettes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import seaborn as sns
3 | import matplotlib.patches as mpatches
4 | import matplotlib.pyplot as plt
5 | # sns.set()
6 |
7 | # palette = np.array(sns.color_palette("hls", 10))
8 | patch = mpatches.Patch(color='red', label='red')
9 | plt.legend(handles=[patch])
10 | plt.savefig("../../../experiments/visualization/color_paletters.png")
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/confusion_matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | def get_action_index(list_txt='data/classInd.txt'):
5 | action_label = []
6 | with open(list_txt) as f:
7 | content = f.readlines()
8 | content = [x.strip('\r\n') for x in content]
9 | f.close()
10 | for line in content:
11 | label, action = line.split(' ')
12 | action_label.append(action)
13 | return action_label
14 |
15 | def plot_matrix_test(list_txt, cfu_mat="../experiments/evaluation/ucf101/_confusion.npy", s_path="../experiments/evaluation/hmdb51/03-24-1659/confuse.png"):
16 | classes = get_action_index(list_txt)
17 | confuse_matrix = np.load(cfu_mat)
18 | plot_confuse_matrix(confuse_matrix, classes, s_path)
19 | plt.show()
20 |
21 | def plot_confuse_matrix(matrix, classes,
22 | s_path = "../experiments/evaluation/hmdb51/03-24-1659/confuse.png",
23 | normalize=True,
24 | title=None,
25 | cmap=plt.cm.Blues
26 | ):
27 |     """
28 |     :param matrix: confusion matrix as an (n_classes x n_classes) numpy array
29 |     :param classes: list of class names, in label order
30 |     :param s_path: path where the rendered figure is saved
31 |     :param normalize: if True, normalize each row of the matrix to sum to 1
32 |     :param title: figure title; a default is chosen based on `normalize`
33 |     :param cmap: matplotlib colormap for the heatmap
34 |     :return: the matplotlib Axes holding the plot
35 |     """
36 | if not title:
37 | if normalize:
38 | title = 'Normalized confusion matrix'
39 | else:
40 | title = 'Confusion matrix, without normalization'
41 |
42 | # Compute confusion matrix
43 | cm = matrix
44 | # Only use the labels that appear in the data
45 | if normalize:
46 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
47 | print("Normalized confusion matrix")
48 | else:
49 | print('Confusion matrix, without normalization')
50 |
51 | print(cm)
52 |
53 | fig, ax = plt.subplots()
54 |
55 | # We change the fontsize of minor ticks label
56 | ax.tick_params(axis='both', which='major', labelsize=3)
57 | ax.tick_params(axis='both', which='minor', labelsize=1)
58 |
59 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
60 | ax.figure.colorbar(im, ax=ax)
61 | # We want to show all ticks...
62 | ax.set(xticks=np.arange(cm.shape[1]),
63 | yticks=np.arange(cm.shape[0]),
64 | # ... and label them with the respective list entries
65 | xticklabels=classes, yticklabels=classes,
66 | title=title,
67 | ylabel='True label',
68 | xlabel='Predicted label')
69 |
70 | # Rotate the tick labels and set their alignment.
71 | plt.setp(ax.get_xticklabels(), rotation=60, ha="right",
72 | rotation_mode="anchor")
73 |
74 | # Loop over data dimensions and create text annotations.
75 | fmt = '.2f'
76 | # thresh = cm.max() / 2.
77 | thresh = 0.2
78 | for i in range(cm.shape[0]):
79 | for j in range(cm.shape[1]):
80 | if cm[i, j] > thresh and i != j:
81 | ax.text(j, i, "({},{})".format(classes[i], classes[j]) + format(cm[i, j], fmt),
82 |                         ha="center", va="center", fontsize='smaller',
83 | color="black")
84 | fig.tight_layout()
85 | plt.savefig(s_path, dpi=1024)
86 | return ax
87 |
88 | if __name__ == '__main__':
89 | # cf_name = "../experiments/evaluation/ucf101/04-12-1132/confusion.npy"
90 | # classList = "../datasets/lists/ucf101/classInd.txt"
91 | # save_path = "../experiments/evaluation/ucf101/04-12-1132/78.8_confuse.png"
92 | # plot_matrix_test(classList, cf_name, save_path)
93 | cf_name = "../experiments/evaluation/ucf101/04-14-1141/confusion.npy"
94 | classList = "../datasets/lists/ucf101/classInd.txt"
95 | save_path = "../experiments/evaluation/ucf101/04-14-1141/63.3_confuse.png"
96 | plot_matrix_test(classList, cf_name, save_path)
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/mixup_visualization.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 |
4 |
5 | def mixup(im1, im2, prob):
6 | img1 = cv2.imread(im1)
7 | img2 = cv2.imread(im2)
8 | img = img1 * (1-prob) + img2 * prob
9 | return img
10 |
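# The blend follows the usual mixup formula img = (1 - prob) * img1 + prob * img2;
# sketch: mixup('img_00001.jpg', 'img_00015.jpg', 0.3) keeps 70% of the first frame
# and 30% of the second (the paths here are illustrative).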
11 |
12 | name = '6arrowswithin30seconds_shoot_bow_f_nm_np1_fr_med_1/'
13 | video1 = "/data/jinhuama/DataSet/hmdb51/" + name
14 | index = 15
15 | new_dir = "../experiments/visualizations/mixup/"
16 | prob = 0.3
17 |
18 | if not os.path.exists(new_dir + name):
19 | os.mkdir(new_dir + name)
20 | for image in os.listdir(video1):
21 | image1 = video1 + image
22 | image2 = video1 + "img_{:05d}.jpg".format(index)
23 | new_img = mixup(image1, image2, prob)
24 | cv2.imwrite(new_dir + name + image, new_img)
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/multi100.txt:
--------------------------------------------------------------------------------
1 | 33.3
2 | 43.0
3 | 6.0
4 | -16.0
5 | 12.0
6 | 15.0
7 | 25.0
8 | -3.0
9 | 6.0
10 | 23.0
11 | 12.0
12 | -10.6
13 | 35.0
14 | 22.0
15 | 3.0
16 | -16.0
17 | 23.0
18 | 12.0
19 | 20.0
20 | 14.0
21 | 13.0
22 | 33.0
23 | -6.0
24 | 13.3
25 | 10.0
26 | -10.0
27 | 44.0
28 | -24.0
29 | 6.0
30 | -13.3
31 | 15.0
32 | 20.0
33 | 26.6
34 | 20.0
35 | 36.6
36 | 26.6
37 | -13.3
38 | -10.0
39 | -10.0
40 | 19.0
41 | 10.0
42 | 16.6
43 | 10.0
44 | 6.0
45 | 46.3
46 | 36.6
47 | 52.4
48 | 6.6
49 | 20.0
50 | 46.6
51 | 6.3
52 |
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/pearson_calculate.txt:
--------------------------------------------------------------------------------
1 | 0.22551442576666667 0.333
2 | 0.10818058883333334 0.43
3 | 0.5162171408333334 0.06
4 | 0.4882573183 -0.16
5 | 0.0708494733 0.12
6 | 0.13339997128533332 0.15
7 | 0.0516073793 0.25
8 | 0.16686041497466667 -0.03
9 | 0.2689299939 0.06
10 | 0.16904905083333333 0.23
11 | 0.19995855823333333 0.12
12 | 0.2766080123 -0.106
13 | 0.10459493276666666 0.35
14 | 0.13576242733333332 0.22
15 | 0.2628492088 0.03
16 | 0.6472266601333333 -0.16
17 | 0.0820547952 0.23
18 | 0.0504232759 0.12
19 | 0.29175092593333335 0.2
20 | 0.0324064948 0.14
21 | 0.12532078933333332 0.13
22 | 0.0307350102 0.33
23 | 0.5459736722666667 -0.06
24 | 0.11691634886666666 0.133
25 | 0.0928835198 0.1
26 | 0.5027140373666666 -0.1
27 | 0.400281556 0.44
28 | 0.22275721296666667 -0.24
29 | 0.24755843986666665 0.06
30 | 0.30486327076666664 -0.133
31 | 0.12616917856666665 0.15
32 | 0.34176297733333333 0.2
33 | 0.091833597 0.266
34 | 0.13851137033333333 0.2
35 | 0.09020716253333333 0.366
36 | 0.09902224873333335 0.266
37 | 0.17993909336666666 -0.133
38 | 0.0476045226 -0.1
39 | 0.8139881392666667 -0.1
40 | 0.18950731996666664 0.19
41 | 0.1327474982 0.1
42 | 0.1923174741 0.166
43 | 0.06810675673333333 0.1
44 | 0.1636814759 0.06
45 | 0.07896679256666667 0.463
46 | 0.16391625416666666 0.366
47 | 0.11104895753333333 0.524
48 | 0.0198147214 0.066
49 | 0.0264955944 0.2
50 | 0.0226321177 0.466
51 | 0.0656435398 0.063
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/pearson_correlation.py:
--------------------------------------------------------------------------------
1 | # from scipy import stats
2 | # import os
3 | #
4 | # x = []
5 | # y = []
6 | # with open('pearson_calculate.txt', 'r') as f:
7 | # lines = f.readlines()
8 | # for line in lines:
9 | # if len(line.strip().split(' '))==2:
10 | # x.append(float(line.strip().split(' ')[0]))
11 | # y.append(float(line.strip().split(' ')[1]))
12 | # else:
13 | # x.append(float(line.strip().split('\t')[0]))
14 | # y.append(float(line.strip().split('\t')[1]))
15 | # print(len(x), len(y))
16 | # print(stats.pearsonr(x, y))
17 | #
18 |
19 | from scipy import stats
20 | import os
21 |
22 | x = []
23 | y = []
24 | g = open('multi100.txt', 'w')
25 | with open('pearson_calculate.txt', 'r') as f:
26 | lines = f.readlines()
27 | for line in lines:
28 |         if len(line.strip().split(' ')) == 2:
29 | g.write(str(float(line.strip().split(' ')[1])*100) + '\n')
30 | else:
31 | g.write(
32 | str(float(line.strip().split('\t')[1]) * 100) + '\n')
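# The correlation itself can be recovered from the two columns with scipy, e.g.
# (illustrative sketch, mirroring the commented-out block above):
#   xs, ys = zip(*(map(float, ln.split()) for ln in open('pearson_calculate.txt')))
#   print(stats.pearsonr(xs, ys))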
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/plot_class_distribution.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pickle
4 |
5 |
6 | def plot_histgorm(l, num=51):
7 | x = np.arange(1, num+1)
8 | fig, axs = plt.subplots(1, 1, sharex=True)
9 |     axs.hist(l, bins=num)
10 | plt.savefig('visualization/histgorm.png')
11 |
12 |
13 | def get_action_index(cls_list='../../../datasets/lists/hmdb51/hmdb51_classInd.txt'):
14 | action_label = []
15 | with open(cls_list) as f:
16 | content = f.readlines()
17 | content = [x.strip('\r\n') for x in content]
18 | f.close()
19 | for line in content:
20 | label, action = line.split(' ')
21 | action_label.append(action)
22 | return action_label
23 |
24 |
25 | def analyse_record(scratch_label, scratch_predict, ssl_label, ssl_predict, dataset='hmdb51'):
26 | if dataset == 'hmdb51':
27 | class_num = 51
28 | elif dataset == 'ucf101':
29 | class_num = 101
30 | else:
31 |         raise Exception("dataset not implemented!")
32 | scratch_wrong = 0
33 | scratch_clsses = np.zeros(class_num)
34 | scratch_real_classes = np.zeros(class_num)
35 | for i in range(len(scratch_label)):
36 | max_index = scratch_predict[i]
37 | label = scratch_label[i]
38 | scratch_real_classes[label] += 1
39 | if max_index != label:
40 | scratch_wrong += 1
41 | else:
42 | scratch_clsses[label] += 1
43 |
44 | #=============================our self-supervised learning =======================
45 | self_supervised_wrong = 0
46 | self_supervised_classes = np.zeros(class_num)
47 | self_supervised_real_classes = np.zeros(class_num)
48 | for i in range(len(ssl_label)):
49 | max_index = ssl_predict[i]
50 | label = ssl_label[i]
51 | self_supervised_real_classes[label] += 1
52 | if max_index != label:
53 | self_supervised_wrong += 1
54 | else:
55 | self_supervised_classes[label] += 1
56 | print("scratch Top-1 is: {}".format(np.mean(scratch_clsses/(scratch_real_classes+1))))
57 | print("SSL Top-1 is: {}".format(np.mean(self_supervised_classes/self_supervised_real_classes)))
58 | arr = (self_supervised_classes/self_supervised_real_classes - scratch_clsses/(scratch_real_classes+1))
59 | topk = arr.argsort()[-5:][::-1]
60 | mink = arr.argsort()[:5][::1]
61 | if dataset == 'hmdb51':
62 | classes = get_action_index()
63 | elif dataset == 'ucf101':
64 | classes = get_action_index(cls_list='../../../datasets/lists/ucf101/classInd.txt')
65 | else:
66 |         raise Exception("dataset not implemented!")
67 | print("five largest ======>>")
68 | for i in range(5):
69 | print(classes[topk[i]], arr[topk[i]])
70 |     print("five smallest ======>>")
71 | for i in range(5):
72 | print(classes[mink[i]], arr[mink[i]])
73 | # print(self_supervised_real_classes/30, self_supervised_real_classes/30)
74 | # print((scratch_best_clsses - scratch_clsses) / 30)
75 | # print(scratch__wrong)
76 | # print(scratch__clsses/30, real_classes/30)
77 | # plot_histgorm(l_clsses, m_clsses, s_clsses)
78 | # rows = []
79 | # classes = get_action_index()
80 | # for i in range(len(l_clsses)):
81 | # rows.append((l_clsses[i]/30, classes[i],
82 | # l_clsses[i] / 30 ))
83 | # header = ['l', 'm', 's', 'class', 'avg']
84 | # with open('visualization/store.csv', 'w') as f:
85 | # f_csv = csv.writer(f)
86 | # f_csv.writerow(header)
87 | # f_csv.writerows(rows)
88 | return True
89 |
90 |
91 | if __name__ == '__main__':
92 | # #====================================HMDB51============================================
93 | # #51.3
94 | # ssl_label = np.load("../../../experiments/evaluation/hmdb51/03-24-1659/video_labels.npy")
95 | # ssl_predict = np.load("../../../experiments/evaluation/hmdb51/03-24-1659/video_pred.npy")
96 | # #31.9
97 | # scratch_label = np.load("../../../experiments/evaluation/hmdb51/04-11-0926/video_labels.npy")
98 | # scratch_predict = np.load("../../../experiments/evaluation/hmdb51/04-11-0926/video_pred.npy")
99 | # analyse_record(scratch_label, scratch_predict, ssl_label, ssl_predict)
100 | # ====================================UCF101============================================
101 | # 78.83
102 | ssl_label = np.load("../../../experiments/evaluation/ucf101/04-12-1132/video_labels.npy")
103 | ssl_predict = np.load("../../../experiments/evaluation/ucf101/04-12-1132/video_pred.npy")
104 | # 63.3?
105 | scratch_label = np.load("../../../experiments/evaluation/ucf101/04-14-1141/video_labels.npy")
106 | scratch_predict = np.load("../../../experiments/evaluation/ucf101/04-14-1141/video_pred.npy")
107 | analyse_record(scratch_label, scratch_predict, ssl_label, ssl_predict, dataset='ucf101')
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/single_video_visualization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | from utils.utils import Timer
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | from data.config import data_config, augmentation_config
8 | from data.dataloader import data_loader_init
9 | from model.config import model_config
10 | from loss.config import optim_init
11 | from augment.gen_positive import GenPositive
12 | from bk.option_old import args
13 | import numpy as np
14 | import cv2
15 | import skvideo.io
16 |
17 |
18 | def save_one_video(video, idx=1, title='origin'):
19 | video = video.squeeze(0)
20 | video_tensor = torch.tensor(video.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3
21 | path = "../experiments/gen_videos/{}_{}.jpg".format(title, idx)
22 | img = np.zeros((video_tensor.size(1), video_tensor.size(2), video_tensor.size(3)), np.uint8)
23 | img.fill(90)
24 | cv2.putText(img, title, (10, 50), cv2.FONT_HERSHEY_SCRIPT_COMPLEX, 1, (0, 255, 255), 1, cv2.LINE_AA)
25 | output = img
26 | for i in range(video_tensor.shape[0]):
27 | if i % 3 == 0:
28 | output = np.concatenate((output, np.uint8(video_tensor[i] * 255)), axis=1)
29 | cv2.imwrite(path, output)
30 | print("index: {} finished".format(idx))
31 | return output
32 |
33 |
34 | def rgb_flow(prvs, next):
35 | hsv = np.zeros_like(prvs)
36 | hsv[..., 1] = 255
37 | prvs = cv2.cvtColor(prvs, cv2.COLOR_RGB2GRAY)
38 | next = cv2.cvtColor(next, cv2.COLOR_RGB2GRAY)
39 | flow = cv2.calcOpticalFlowFarneback(prvs, next, None, 0.5, 3, 15, 3, 5, 1.2, 0)
40 | mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
41 | hsv[..., 0] = ang * 180 / np.pi / 2
42 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
43 | flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
44 | return flow_rgb
45 |
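# rgb_flow visualizes dense Farneback optical flow as an RGB image: the flow angle
# is mapped to hue and the min-max normalized magnitude to value in HSV space,
# which is then converted back to RGB. Sketch: flow_img = rgb_flow(frame_t, frame_t1)
# for two consecutive uint8 color frames of the same size.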
46 |
47 | def write_video(path, output):
48 | fps = 30
49 | writer = skvideo.io.FFmpegWriter(path,
50 | outputdict={'-b': '300000000', '-r': str(fps)})
51 | # print(len(output))
52 | for frame in output:
53 | frame = np.array(frame)
54 | writer.writeFrame(frame)
55 | writer.close()
56 |
57 |
58 | def save_as_video(a, p, n, tn, idx=1, title='triples'):
59 | a = a.squeeze(0)
60 | p = p.squeeze(0)
61 | n = n.squeeze(0)
62 | tn = tn.squeeze(0)
63 | path = "../experiments/gen_videos/{}_{}.mp4".format(title, idx)
64 | a_p_path = "../experiments/gen_videos/{}_{}.mp4".format('a&p_', idx)
65 | a_n_path = "../experiments/gen_videos/{}_{}.mp4".format('a&n_', idx)
66 | flows_path = "../experiments/gen_videos/{}_{}.mp4".format('flows_', idx)
67 | a_tensor = torch.tensor(a.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3
68 | p_tensor = torch.tensor(p.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3
69 | n_tensor = torch.tensor(n.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3
70 | tn_tensor = torch.tensor(tn.detach().cpu().numpy().transpose(1, 2, 3, 0)) # 16 x 31 x 31 x 3
71 | output = []
72 | a_p_output = []
73 | a_n_output = []
74 | flows_output = []
75 | # print(a_tensor.size(0))
76 | for i in range(a_tensor.size(0) - 1):
77 | a_img = np.uint8(a_tensor[i] * 255)
78 | a_img = cv2.cvtColor(a_img, cv2.COLOR_BGR2RGB)
79 | p_img = np.uint8(p_tensor[i] * 255)
80 | n_img = np.uint8(n_tensor[i] * 255)
81 | tn_img = cv2.cvtColor(np.uint8(tn_tensor[i] * 255), cv2.COLOR_BGR2RGB)
82 | a_img_next = np.uint8(a_tensor[i+1] * 255)
83 | p_img_next = np.uint8(p_tensor[i+1] * 255)
84 | n_img_next = np.uint8(n_tensor[i + 1] * 255)
85 | tn_img_next = cv2.cvtColor(np.uint8(tn_tensor[i + 1] * 255), cv2.COLOR_BGR2RGB)
86 | flow_a = rgb_flow(a_img, a_img_next)
87 | flow_p = rgb_flow(p_img, p_img_next)
88 | flow_n = rgb_flow(n_img, n_img_next)
89 | flow_tn = rgb_flow(tn_img, tn_img_next)
90 |
91 | rgb_cat = np.concatenate((a_img, p_img, n_img, tn_img), 1)
92 | flow_cat = np.concatenate((flow_a, flow_p, flow_n, flow_tn), 1)
93 | img = np.concatenate((rgb_cat, flow_cat), 0)
94 | output.append(img)
95 | # a_p_output.append(np.concatenate((a_img, p_img), axis=1))
96 | # a_n_output.append(np.concatenate((a_img, n_img), axis=1))
97 | # flows_output.append(np.concatenate((flow_a, flow_p, flow_n), axis=1))
98 | # write_video(a_p_path, a_p_output)
99 | # write_video(a_n_path, a_n_output)
100 | # write_video(flows_path, flows_output)
101 | write_video(path, output)
102 | print("video: {} finished".format(idx))
103 | return
104 |
105 |
106 | def validate(tc, val_loader, model):
107 | # switch to evaluate mode
108 | model.eval()
109 | with torch.no_grad():
110 | for i, (inputs, target, index) in enumerate(val_loader):
111 | if i > 100:
112 | break
113 | # target = target.cuda(async=True)
114 | target = target.cuda()
115 | for j in range(len(inputs)):
116 | inputs[j] = inputs[j].float()
117 | inputs[j] = inputs[j].cuda()
118 | # ===================forward=====================
119 | anchor, positive, negative, t_wrap, s_wrap = inputs
120 | positive = tc(positive)
121 | save_as_video(anchor, positive, negative, t_wrap, idx=index.cpu().data)
122 | # anchor_cat = save_one_video(anchor, idx=index, title='anchor')
123 | # positive_cat = save_one_video(positive, idx=index, title='positive')
124 | # negative_cat = save_one_video(negative, idx=index, title='negative')
125 | # imgs = np.concatenate((anchor_cat, positive_cat, negative_cat))
126 | # path = "../experiments/gen_videos/{}_{}.jpg".format('triplet', index)
127 | # cv2.imwrite(path, imgs)
128 | return None
129 |
130 |
131 | if __name__ == '__main__':
132 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning
133 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
134 | args.batch_size = 1
135 | args.data_length = 128
136 | args.stride = 1
137 | args.spatial_size = 224
138 | args.mode = 'rgb'
139 | args.eval_indict = 'loss'
140 | args.pt_loss = 'flow'
141 | args.workers = 1
142 | args.print_freq = 100
143 | args.dataset = 'ucf101'
144 | args.train_list = '../datasets/lists/ucf101/ucf101_rgb_train_split_1.txt'
145 | args.val_list = '../datasets/lists/ucf101/ucf101_rgb_val_split_1.txt'
146 | # args.root = ""
147 | args.root = "/data1/awinywang/Data/ft_local/ucf101/jpegs_256/" # 144
148 | pos_aug = GenPositive()
149 | torch.manual_seed(1)
150 | cudnn.benchmark = True
151 | timer = Timer()
152 | # == dataset config==
153 | num_class, data_length, image_tmpl = data_config(args)
154 | train_transforms, test_transforms, eval_transforms = augmentation_config(args)
155 | train_data_loader, val_data_loader, eval_data_loader, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms)
156 | # == model config==
157 | model = model_config(args, num_class)
158 | # == optim config==
159 | train_criterion, val_criterion, optimizer = optim_init(args, model)
160 | # == data augmentation(self-supervised) config==
161 | validate(pos_aug, val_data_loader, model)
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/t_SNE_Visualization.py:
--------------------------------------------------------------------------------
1 | # That's an impressive list of imports.
2 | import numpy as np
3 | import torch
4 | from numpy import linalg
5 | from numpy.linalg import norm
6 | from scipy.spatial.distance import squareform, pdist
7 |
8 | # We import sklearn.
9 | import sklearn
10 | from sklearn.manifold import TSNE
11 | from sklearn.datasets import load_digits
12 | from sklearn.preprocessing import scale
13 |
14 | # We'll hack a bit with the t-SNE code in sklearn 0.15.2.
15 | from sklearn.metrics.pairwise import pairwise_distances
16 | from sklearn.manifold.t_sne import (_joint_probabilities,
17 | _kl_divergence)
18 | # Random state.
19 | RS = 20150101
20 |
21 | # We'll use matplotlib for graphics.
22 | import matplotlib.pyplot as plt
23 | import matplotlib.patheffects as PathEffects
24 | import matplotlib
25 | # %matplotlib inline
26 |
27 | # We import seaborn to make nice plots.
28 | import seaborn as sns
29 | import os
30 | sns.set_style('darkgrid')
31 | sns.set_palette('muted')
32 | sns.set_context("notebook", font_scale=1.5,
33 | rc={"lines.linewidth": 2.5})
34 |
35 | # We'll generate an animation with matplotlib and moviepy.
36 | # from moviepy.video.io.bindings import mplfig_to_npimage
37 | # import moviepy.editor as mpy
38 |
39 |
40 | def load_data(file_name):
41 | # digits = load_digits()
42 | # digits.data.shape # 1797 x 64
43 | # print(digits.data.shape)
44 | # print(digits['DESCR'])
45 | # return digits
46 | features = np.load(file_name, allow_pickle='TRUE').item()
47 | return features
48 |
49 |
50 | def scatter(x, colors, num_class=10):
51 | # We choose a color palette with seaborn.
52 | palette = np.array(sns.color_palette("hls", num_class))
53 | # sns.palplot(sns.color_palette("hls", 10))
54 | # We create a scatter plot.
55 | labels=['brush_hair', 'cartwheel', 'catch', 'chew',
56 | 'clap', 'climb', 'climb_stairs', 'dive', 'draw_sword',
57 | 'dribble']
58 | f = plt.figure(figsize=(8, 8))
59 | # print(colors.astype(np.int))
60 | ax = plt.subplot(aspect='equal')
61 | # for i in range(10):
62 | # sc = ax.scatter(x[:, 0][30*i:30*(i+1)], x[:, 1][30*i:30*(i+1)], c=palette[colors.astype(np.int)][30*i:30*(i+1)],
63 | # s=40,
64 | # label=labels[i],
65 | # )
66 |     sc = ax.scatter(x[:, 0], x[:, 1], c=palette[colors.astype(int)],
67 | s=150,
68 | #label=colors.astype(np.int)[30],
69 | )
70 | # ax.legend(loc="best", title="Classes", bbox_to_anchor=(0.2, 0.4))
71 | plt.xlim(-25, 25)
72 | plt.ylim(-25, 25)
73 | ax.axis('off')
74 | ax.axis('tight')
75 |
76 | # We add the labels for each digit.
77 | txts = []
78 | for i in range(num_class):
79 | # Position of each label.
80 | xtext, ytext = np.median(x[colors == i, :], axis=0)
81 | txt = ax.text(xtext, ytext, str(i), fontsize=24)
82 | # ax.legend(ytext, "a")
83 | txt.set_path_effects([
84 | PathEffects.Stroke(linewidth=5, foreground="w"),
85 | PathEffects.Normal()])
86 | txts.append(txt)
87 | # ax.legend(('a','b','c','d','e'))
88 | return f, ax, sc, txts
89 |
90 |
91 | def tsne_visualize(data, file_name, num_class=101):
92 | # nrows, ncols = 2, 5
93 | # plt.figure(figsize=(6,3))
94 | # plt.gray()
95 | # for i in range(ncols * nrows):
96 | # ax = plt.subplot(nrows, ncols, i + 1)
97 | # ax.matshow(digits.images[i,...])
98 | # plt.xticks([]); plt.yticks([])
99 | # plt.title(digits.target[i])
100 | # plt.savefig('../../../experiments/visualization/digits-generated.png', dpi=150)
101 |
102 | # We first reorder the data points according to the handwritten numbers.
103 | datas = []
104 | labels = []
105 | nums = len(data['target'])
106 | for i in range(num_class, num_class + 10):
107 | for j in range(nums):
108 | if data['target'][j] == i:
109 | datas.append(data['data'][j])
110 | X = np.vstack(datas)
111 | for i in range(num_class, num_class + 10):
112 | for j in range(nums):
113 | if data['target'][j] == i:
114 | labels.append(data['target'][j] - num_class)
115 | y = np.hstack(labels)
116 | # X = np.vstack([data['data'][data['target']==i].cpu()
117 | # for i in range(10)])
118 | # y = np.hstack([data['target'][data['target']==i].cpu()
119 | # for i in range(10)])
120 | # print(y)
121 | digits_proj = TSNE(random_state=RS).fit_transform(X)
122 | scatter(digits_proj, y)
123 | plt.savefig(file_name, dpi=120)
124 |
125 |
126 | # for begin_index in range(1,41,5):
127 | # front = 'self_supervised_kineticsMoCo'
128 | # s_path = "../../../experiments/visualization/" + front + str(begin_index) + 'to' + str(begin_index + 10)
129 | # if not os.path.exists(s_path):
130 | # os.mkdir(s_path)
131 | # features_file = "../../../experiments/visualization/{}_hmdb51_features.npy".format(front)
132 | # file_name = "{}/{}_hmdb51_tsne-generated.png".format(s_path, front)
133 | # data = load_data(features_file)
134 | # tsne_visualize(data, file_name, begin_index)
135 |
136 | # for front in ['scratch', 'fully_supervised_37', 'self_supervised_37','fully_supervised_scratch']:
137 | # # front = 'scratch'
138 | # features_file = "../../../experiments/visualization/{}_hmdb51_features.npy".format(front)
139 | # file_name = "{}/{}hmdb51_tsne-generated.png".format(s_path, front)
140 | # data = load_data(features_file)
141 | # tsne_visualize(data, file_name, begin_index)
142 |
143 |
144 | if __name__ == '__main__':
145 | # front = 'contrastive_kinetics_warpping_hmdb51'
146 | # front = 'contrastive_ucf101_warpping_hmdb51'
147 | front = 'i3d_fully_supervised_kinetics_warpping_hmdb51'
148 | for begin_index in range(1, 41, 5):
149 | s_path = "../../../experiments/visualization/TSNE/" + front + str(begin_index) + 'to' + str(begin_index + 10)
150 | if not os.path.exists(s_path):
151 | os.mkdir(s_path)
152 | features_file = "../../../experiments/features/{}/val_features.npy".format(front)
153 | file_name = "{}/{}_hmdb51_tsne-generated.png".format(s_path, front)
154 | data = load_data(features_file)
155 | tsne_visualize(data, file_name, begin_index)
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/triplet_visualization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | from data.config import data_config, augmentation_config
7 | from data.dataloader import data_loader_init
8 | from model.config import model_config
9 | from augment.gen_positive import GenPositive
10 | from utils.visualization.triplet_visualization import triplet_visualize, save_img
11 | from bk.option_old import args
12 |
13 | lowest_val_loss = float('inf')
14 | best_prec1 = 0
15 | torch.manual_seed(1)
16 |
17 |
18 | def test(train_loader, model, pos_aug):
19 | model.eval()
20 | for i, (inputs, target, index) in enumerate(train_loader):
21 | anchor, positive, negative = inputs
22 | anchor = pos_aug(anchor)
23 | dir_path = "../experiments/visualization/triplet2/{}".format(index[0])
24 | if not os.path.exists(dir_path):
25 | os.makedirs(dir_path)
26 | mask_img = triplet_visualize(anchor.cpu().numpy(), positive.cpu().numpy(), negative.cpu().numpy(), dir_path)
27 | path = "{}/{}.png".format(dir_path, str(index[0].cpu().numpy()))
28 | save_img(mask_img, path)
29 | print("{}/{} finished".format(i, len(train_loader)))
30 | # output = model(anchor)
31 | return True
32 |
33 |
34 | def main():
35 |     # == environment config ==
36 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # close the warning
37 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
38 | torch.manual_seed(1)
39 | cudnn.benchmark = True
40 | # == dataset config==
41 | num_class, data_length, image_tmpl = data_config(args)
42 | train_transforms, test_transforms, eval_transforms = augmentation_config(args)
43 | train_data_loader, val_data_loader, _, _, _, _ = data_loader_init(args, data_length, image_tmpl, train_transforms, test_transforms, eval_transforms)
44 | # == model config==
45 | model = model_config(args, num_class)
46 | pos_aug = GenPositive()
47 | # == train and eval==
48 | test(train_data_loader, model, pos_aug)
49 | return 1
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
--------------------------------------------------------------------------------
/src/Contrastive/utils/visualization/video_write_test.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import skvideo.io
3 | import numpy as np
4 | import os
5 | import augment.video_transformations.video_transform_PIL_or_np as video_transform
6 | from torchvision import transforms
7 | import skimage.transform
8 | import random
9 | from PIL import Image
10 | import torchvision
11 |
12 |
13 | def read_video(video):
14 | cap = cv2.VideoCapture(video)
15 | frames = list()
16 | while True:
17 | ret, frame = cap.read()
18 |         if frame is None:
19 | break
20 | else:
21 | frames.append(frame)
22 | return frames
23 |
24 |
25 | def write_video(name, frames):
26 | # fshape = frames[0].shape
27 | # fheight = fshape[0]
28 | # fwidth = fshape[1]
29 | # writer = cv2.VideoWriter(name,
30 | # cv2.VideoWriter_fourcc(*"MJPG"), 30, (fheight, fwidth))
31 | # for i in range(len(frames)):
32 | # writer.write(frames[i])
33 | # writer.release()
34 | writer = skvideo.io.FFmpegWriter(name,
35 | outputdict={'-b': '300000000'})
36 | for frame in frames:
37 | frame = np.array(frame)
38 | writer.writeFrame(frame)
39 | writer.close()
40 | return 1
41 |
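# Illustrative usage (sketch): frames is expected to be an iterable of H x W x 3
# uint8 arrays, e.g.
#   write_video('../experiments/gen_videos/clip.avi', read_video('../experiments/test.mp4'))
# (the output path is illustrative).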
42 |
43 | if __name__ == '__main__':
44 | # video = "../experiments/test.mp4"
45 | # frames = read_video(video)
46 |
47 | # # aug = video_transform.RandomRotation(10),
48 | # # video_transform.STA_RandomRotation(10),
49 | # # video_transform.Each_RandomRotation(10),
50 | # # train_transforms = transforms.Compose([video_transform.Resize(128),
51 | # # video_transform.RandomCrop(112), aug])
52 | # frames = np.array(frames)
53 | # prefix = "random_rotation"
54 | # angle = random.uniform(-45, 45)
55 | # rotated = [skimage.transform.rotate(img, angle) for img in frames]
56 | # name = '../experiments/gen_videos/test_{}.avi'.format(prefix)
57 | # write_video(name, rotated)
58 | # prefix = "STA_rotation"
59 | # bsz = len(frames)
60 | # angles = [(i + 1) / (bsz + 1) * angle for i in range(bsz)]
61 | # rotated = [skimage.transform.rotate(img, angles[i]) for i, img in enumerate(frames)]
62 | # name = '../experiments/gen_videos/test_{}.avi'.format(prefix)
63 | # write_video(name, rotated)
64 | # prefix = "each_random_rotation"
65 | # angles = [random.uniform(-45, 45) for i in range(bsz)]
66 | # rotated = [skimage.transform.rotate(img, angles[i]) for i, img in enumerate(frames)]
67 | # name = '../experiments/gen_videos/test_{}.avi'.format(prefix)
68 | # write_video(name, rotated)
69 | # # for i in range(10):
70 | # # seqs = np.load("{}_{}.npy".format("../experiments/augmentation/{}".format(prefix), i))
71 | # # seqs = (seqs+1)/2*255
72 | # # out_dir = "../experiments/gen_videos/{}".format(i)
73 | # # if not os.path.exists(out_dir):
74 | # # os.makedirs(out_dir)
75 | # # for j in range(16):
76 | # # cv2.imwrite("{}/{}_{}.jpg".format(out_dir, prefix, j), seqs[j])
77 | # # name = '../experiments/gen_videos/test_{}.mp4'.format(i)
78 | # # write_video(name, seqs)
79 | # video = "../experiments/test.mp4"
80 | # frames = read_video(video)
81 | # images = []
82 | # frames = np.array(frames)
83 | # for i in range(len(frames)):
84 | # img = Image.fromarray(np.uint8(frames[i]))
85 | # images.append(img)
86 | # # Create img transform function sequence
87 | # img_transforms = []
88 | # brightness = 0.5
89 | # contrast = 0.5
90 | # saturation = 0.5
91 | # hue = 0.2
92 | # if brightness is not None:
93 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
94 | # if saturation is not None:
95 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
96 | # if hue is not None:
97 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
98 | # if contrast is not None:
99 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
100 | # random.shuffle(img_transforms)
101 | #
102 | # # Apply to all images
103 | # jittered_clip = []
104 | # for img in images:
105 | # for func in img_transforms:
106 | # jittered_img = func(img)
107 | # jittered_clip.append(jittered_img)
108 | # name = '../experiments/gen_videos/test_{}.avi'.format('jitter')
109 | # write_video(name, jittered_clip)
110 | video = "../experiments/test.mp4"
111 | frames = read_video(video)
112 | images = []
113 | frames = np.array(frames)
114 | for i in range(len(frames)):
115 | img = Image.fromarray(np.uint8(frames[i]))
116 | images.append(img)
117 | # Create img transform function sequence
118 | brightness = 0.5
119 | contrast = 0.5
120 | saturation = 0.5
121 | hue = 0.2
122 | # img_transforms = []
123 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
124 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
125 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
126 | # img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
127 | # random.shuffle(img_transforms)
128 | #
129 | # # Apply to all images
130 | # jittered_clip = []
131 | # for img in images:
132 | # for func in img_transforms:
133 | # jittered_img = func(img)
134 | # jittered_clip.append(jittered_img)
135 | # name = '../experiments/gen_videos/test_{}.avi'.format('sta_jitter')
136 | # write_video(name, jittered_clip)
137 | # Apply to all images
138 | jittered_clip = []
139 | for i, img in enumerate(images):
140 | t_brightness = (i+1)/(len(images)+1) * brightness
141 | t_contrast = (i + 1) / (len(images) + 1) * contrast
142 | t_saturation = (i + 1) / (len(images) + 1) * saturation
143 | t_hue = (i + 1) / (len(images) + 1) * hue
144 | img_transforms = []
145 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, t_brightness))
146 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, t_saturation))
147 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, t_hue))
148 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, t_contrast))
149 | random.shuffle(img_transforms)
150 |         for func in img_transforms:
151 |             img = func(img)  # chain all four jitters on the same frame
152 |         jittered_clip.append(img)
153 | name = '../experiments/gen_videos/test_{}.avi'.format('sta_jitter')
154 | write_video(name, jittered_clip)
--------------------------------------------------------------------------------