├── .gitignore
├── LICENSE
├── README.md
├── SceneSeg
├── BiLSTM_protocol.py
├── __init__.py
├── main.py
└── movienet_seg_data.py
├── cluster
├── Group.py
└── cluster_test.ipynb
├── config
├── SCRL_pretrain_default.yaml
├── SCRL_pretrain_with_imagenet1k.yaml
└── SCRL_pretrain_without_imagenet1k.yaml
├── data
├── MovieNet_1.0_shotinfo.json
├── MovieNet_shot_num.json
├── data_preparation.py
├── movie1K.scene_seg_318_name_index_shotnum_label.v1.json
├── movie1K.split.v1.json
└── movienet_data.py
├── extract_embeddings.py
├── figures
└── puzzle_example.jpg
├── models
├── __init__.py
├── backbones
│ ├── __init__.py
│ └── visual
│ │ └── resnet.py
├── core
│ ├── SCRL_MoCo.py
│ └── __init__.py
└── factory.py
├── pretrain_main.py
├── pretrain_trainer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | output/
2 | compressed_shot_images/
3 | embeddings/
4 | checkpoints/
5 | SceneSeg/output/
6 | pretrain/
7 | __pycache__/
8 | *.pkl
9 | *.log
10 | *.txt
11 | *.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Scene Consistency Representation Learning for Video Scene Segmentation (CVPR2022)
2 | This is an official PyTorch implementation of SCRL, the CVPR2022 paper is available at [here](https://openaccess.thecvf.com/content/CVPR2022/html/Wu_Scene_Consistency_Representation_Learning_for_Video_Scene_Segmentation_CVPR_2022_paper.html).
3 |
4 | # Getting Started
5 |
6 | ## Data Preparation
7 | ### MovieNet Dataset
8 | Download MovieNet Dataset from its [Official Website](https://movienet.github.io/).
9 | ### SceneSeg318 Dataset
10 | Download the Annotation of [SceneSeg318](https://drive.google.com/drive/folders/1NFyL_IZvr1mQR3vR63XMYITU7rq9geY_?usp=sharing), you can find the download instructions in [LGSS](https://github.com/AnyiRao/SceneSeg/blob/master/docs/INSTALL.md) repository.
11 |
12 | ### Make Puzzles for pre-training
13 | In order to reduce the number of IO accesses and perform data augmentation (a.k.a *Scene Agnostic Clip-Shuffling* in the paper) at the same time, we suggest to stitch 16 shots into one image (puzzle) during the pre-training stage. You can make the data by yourself:
14 | ```
15 | python ./data/data_preparation.py
16 | ```
17 | And the processed data will be saved in `./compressed_shot_images/`, a puzzle example [figure](./figures/puzzle_example.jpg).
18 |
19 |
20 |
21 | ### Load the Data into Memory [Optional]
22 | We **strongly recommend** loading data into memory to speed up pre-training, which additionally requires your device to have at least 100GB of RAM.
23 | ```
24 | mkdir /tmpdata
25 | mount tmpfs /tmpdata -t tmpfs -o size=100G
26 | cp -r ./compressed_shot_images/ /tmpdata/
27 | ```
28 |
29 |
30 | ## Initialization Weights Preparation
31 | Download the ResNet-50 weights trained on ImageNet-1k ([resnet50-19c8e357.pth](https://download.pytorch.org/models/resnet50-19c8e357.pth)), and save it in `./pretrain/` folder.
32 |
33 | ## Prerequisites
34 | ### Requirements
35 | * python >= 3.6
36 | * pytorch >= 1.6
37 | * cv2
38 | * pickle
39 | * numpy
40 | * yaml
41 | * sklearn
42 |
43 | ### Hardware
44 | * 8 NVIDIA V100 (32GB) GPUs
45 |
46 | # Usage
47 | ### STEP 1: Encoder Pre-training
48 | Using the default configuration to pretrain the model. Make sure the data path is correct and the GPUs are sufficient (e.g. 8 NVIDIA V100 GPUs)
49 | ```
50 | python pretrain_main.py --config ./config/SCRL_pretrain_default.yaml
51 | ```
52 | The checkpoint, copy of config and log will be saved in `./output/`.
53 |
54 | ### STEP 2: Feature Extraction
55 |
56 | ```
57 | python extract_embeddings.py $CKP_PATH --shot_img_path $SHOT_PATH --Type all --gpu-id 0
58 | ```
59 | `$CKP_PATH` is the path of an encoder checkpoint, and `$SHOT_PATH` is the keyframe path of MovieNet.
60 | The extracted embeddings (in pickle format) and log will be saved in `./embeddings/`.
61 |
62 | ### STEP 3: Video Scene Segmentation Evaluation
63 |
64 | ```
65 | cd SceneSeg
66 |
67 | python main.py \
68 | -train $TRAIN_PKL_PATH \
69 | -test $TEST_PKL_PATH \
70 | -val $VAL_PKL_PATH \
71 | --seq-len 40 \
72 | --gpu-id 0
73 | ```
74 |
75 | The checkpoints and log will be saved in `./SceneSeg/output/`.
76 |
77 | ## Models
78 | We provide checkpoints, logs and results under two different pre-training settings, i.e. with and without ImageNet-1K initialization, respectively.
79 |
80 | | Initialization | AP | F1 | Config File | STEP 1
Pre-training | STEP 2
Embeddings| STEP 3
Fine-tuning |
81 | | :-----| :---- | :---- | :---- | :-----| :---- | :---- |
82 | | w/o ImageNet-1k | 55.16 | 51.32 | SCRL_pretrain
_without_imagenet1k.yaml | [ckp and log](https://drive.google.com/drive/folders/1ZYg9PFRU_lt3G5qJrldkguA52T2oxErR?usp=sharing) | [embedings](https://drive.google.com/drive/folders/1uen_HP3BZu8bcrPBikkgV3j9wzUjQ0C1?usp=sharing) | [ckps and log](https://drive.google.com/drive/folders/1rJbOnVbqTdPmnh2grIkePXOmwpNELnrK?usp=sharing) |
83 | | w/ ImageNet-1k | 56.65 | 52.45 | SCRL_pretrain
_with_imagenet1k.yaml | [ckp and log](https://drive.google.com/drive/folders/1BG5ZLqrPKKGTtDIZj8aps_QuWc6K3c3V?usp=sharing) | [embedings](https://drive.google.com/drive/folders/1NFvGhkvRxpmEJYNjRnwp3ybuHQaG25gW?usp=sharing) | [ckps and log](https://drive.google.com/drive/folders/1dE0JFi-MDua70_CgI1CvyLNRnhwLjaUV?usp=sharing) |
84 |
85 |
86 | ## License
87 | Please see [LICENSE](./LICENSE) file for the details.
88 |
89 | ## Acknowledgments
90 | Part of codes are borrowed from the following repositories:
91 | * [MoCo](https://github.com/facebookresearch/moco)
92 | * [LGSS](https://github.com/AnyiRao/SceneSeg)
93 |
94 | ## Citation
95 | Please cite our work if it's useful for your research.
96 | ```
97 | @InProceedings{Wu_2022_CVPR,
98 | author = {Wu, Haoqian and Chen, Keyu and Luo, Yanan and Qiao, Ruizhi and Ren, Bo and Liu, Haozhe and Xie, Weicheng and Shen, Linlin},
99 | title = {Scene Consistency Representation Learning for Video Scene Segmentation},
100 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
101 | month = {June},
102 | year = {2022},
103 | pages = {14021-14030}
104 | }
105 | ```
--------------------------------------------------------------------------------
/SceneSeg/BiLSTM_protocol.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class BiLSTM(nn.Module):
7 | def __init__(self, input_feature_dim=2048, fc_dim=1024, hidden_size=512,
8 | input_drop_rate=0.3, lstm_drop_rate=0.6, fc_drop_rate=0.7, use_bn=True):
9 | super(BiLSTM, self).__init__()
10 |
11 | input_size = input_feature_dim
12 | output_size = fc_dim
13 | self.embed_sizes = input_feature_dim
14 | self.embed_fc = nn.Linear(input_size, output_size)
15 | self.hidden_size = hidden_size
16 | self.lstm = nn.LSTM(
17 | input_size=output_size,
18 | hidden_size=self.hidden_size,
19 | num_layers=2,
20 | batch_first=True,
21 | dropout=lstm_drop_rate,
22 | bidirectional=True
23 | )
24 | # The probability is set to 0 by default
25 | self.input_shotmask = ShotMask(p=0)
26 | self.input_dropout = nn.Dropout(p=input_drop_rate)
27 | self.fc_dropout = nn.Dropout(p=fc_drop_rate)
28 | self.fc1 = nn.Linear(self.hidden_size*2, hidden_size)
29 | self.fc2 = nn.Linear(hidden_size, 2)
30 | self.softmax = nn.Softmax(2)
31 | self.use_bn = use_bn
32 |
33 | if self.use_bn:
34 | self.bn1 = nn.BatchNorm1d(output_size)
35 | self.bn2 = nn.BatchNorm1d(hidden_size)
36 |
37 |
38 | def forward(self, x, y):
39 | if self.training:
40 | x = self.input_shotmask(x, y)
41 | x = self.input_dropout(x)
42 | x = self.embed_fc(x)
43 |
44 | if self.use_bn:
45 | seq_len, C = x.shape[1:3]
46 | x = x.view(-1, C)
47 | x = self.bn1(x)
48 | x = x.view(-1, seq_len, C)
49 |
50 | x = self.fc_dropout(x)
51 | self.lstm.flatten_parameters()
52 | out, (_, _) = self.lstm(x, None)
53 | out = self.fc1(out)
54 | if self.use_bn:
55 | seq_len, C = out.shape[1:3]
56 | out = out.view(-1, C)
57 | out = self.bn2(out)
58 | out = out.view(-1, seq_len, C)
59 | out = self.fc_dropout(out)
60 | out = F.relu(out)
61 | out = self.fc2(out)
62 | if not self.training:
63 | out = self.softmax(out)
64 | return out
65 |
66 |
67 | class ShotMask(nn.Module):
68 | '''
69 | Drop the shot from the middle of a scene
70 | '''
71 | def __init__(self, p=0.2):
72 | super(ShotMask, self).__init__()
73 | self.p = p
74 |
75 | def forward(self, x, y):
76 | # keep the cue
77 | B, L , _ = x.size()
78 | y_shift = torch.cat([torch.zeros(B,1,1).bool().to(y.device), y.bool()],dim=1)[:,:L,:]
79 | self.mask = torch.rand(*y.size()) >= self.p
80 | self.mask = self.mask.bool().to(x.device) | y.bool() | y_shift
81 | out = x.mul(self.mask)
82 | return out
83 |
84 | if __name__ == '__main__':
85 | B, seq_len, C = 10, 20, 2048
86 | input = torch.randn(B, seq_len, C)
87 | model = BiLSTM()
88 | out = model(input)
89 | # torch.Size([10, 20, 2])
90 | print(out.size())
--------------------------------------------------------------------------------
/SceneSeg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/SceneSeg/__init__.py
--------------------------------------------------------------------------------
/SceneSeg/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.backends.cudnn as cudnn
5 | import torch.optim
6 | import numpy as np
7 | import os
8 | import argparse
9 | import random
10 | import time
11 | from sklearn.metrics import average_precision_score
12 | import shutil
13 | import os.path as osp
14 | from BiLSTM_protocol import BiLSTM
15 | from movienet_seg_data import MovieNet_SceneSeg_Dataset_Embeddings_Train, MovieNet_SceneSeg_Dataset_Embeddings_Val
16 |
17 | def main(args):
18 | setup_seed(100)
19 | model = BiLSTM(
20 | input_feature_dim=args.dim,
21 | input_drop_rate=args.input_drop_rate
22 | ).cuda()
23 |
24 | label_weights = torch.Tensor([args.loss_weight[0], args.loss_weight[1]]).cuda()
25 | criterion = nn.CrossEntropyLoss(label_weights).cuda()
26 |
27 |
28 | optimizer = torch.optim.SGD(model.parameters(),
29 | args.lr,
30 | momentum=args.momentum,
31 | weight_decay=args.weight_decay
32 | )
33 |
34 | train_dataset = MovieNet_SceneSeg_Dataset_Embeddings_Train(
35 | pkl_path=args.pkl_path_train,
36 | sampled_shot_num=args.seq_len,
37 | shuffle_p=args.sample_shulle_rate
38 | )
39 | val_dataset = MovieNet_SceneSeg_Dataset_Embeddings_Val(
40 | pkl_path=args.pkl_path_val,
41 | sampled_shot_num=args.seq_len
42 | )
43 |
44 | test_dataset = MovieNet_SceneSeg_Dataset_Embeddings_Val(
45 | pkl_path=args.pkl_path_test,
46 | sampled_shot_num=args.seq_len
47 | )
48 |
49 | train_loader = torch.utils.data.DataLoader(train_dataset, args.train_bs, num_workers=args.workers,
50 | shuffle=True, pin_memory=True, drop_last=False)
51 | test_loader = torch.utils.data.DataLoader(test_dataset, args.test_bs, num_workers=args.workers,
52 | shuffle=False, pin_memory=True, drop_last=False)
53 | val_loader = torch.utils.data.DataLoader(val_dataset, args.test_bs, num_workers=args.workers,
54 | shuffle=False, pin_memory=True, drop_last=False)
55 |
56 | train_fun = train
57 | test_fun = inference
58 |
59 | val_max_F1 = 0
60 | is_best = False
61 | test_info = {'mAP': 0, 'F1': 0}
62 | for epoch in range(1, args.epochs + 1):
63 | train_loader.dataset._shuffle_offset()
64 | adjust_learning_rate(args, optimizer, epoch)
65 | train_fun(args, model, train_loader, optimizer, epoch, criterion)
66 | if epoch % args.test_interval == 0 and epoch >= args.test_milestone:
67 | f1, map, acc_all = test_fun(args, model, val_loader)
68 | to_log(args, f'val set: {map, f1, acc_all}', True)
69 | if val_max_F1 < f1:
70 | val_max_F1 = f1
71 | f1_t, map_t, acc_all_t = test_fun(args, model, test_loader)
72 | test_info['mAP'] = map_t
73 | test_info['F1'] = f1_t
74 | is_best = True
75 | to_log(args, f'now best F1 on val is: {val_max_F1}', True)
76 | to_log(args, f'test set: {map_t, f1_t, acc_all_t}', True)
77 | else:
78 | is_best = False
79 | save_checkpoint({
80 | 'state_dict': model.state_dict(), 'epoch': epoch,
81 | }, is_best=is_best, fpath=os.path.join(args.save_dir, 'checkpoint.pth.tar'))
82 |
83 | to_log(args, f'best F1 on val: {val_max_F1}', True)
84 | to_log(args, f"the test set mAP: {test_info['mAP']}, F1: {test_info['F1']}", True)
85 |
86 |
87 | def train(args, model, train_loader, optimizer, epoch, criterion, log_interval=30):
88 | model.train()
89 | for batch_idx, (data, target, _) in enumerate(train_loader):
90 | data = data.cuda(non_blocking=True)
91 | target = target.unsqueeze(-1).cuda(non_blocking=True)
92 | output = model(data, target)
93 | output = output.view(-1, 2)
94 | target = target.view(-1)
95 | loss = criterion(output, target)
96 |
97 | optimizer.zero_grad()
98 | loss.backward()
99 | optimizer.step()
100 | if batch_idx % log_interval == 0:
101 | log = 'Train Epoch: {} [{}/{} ({:.0f}%)]'.format(epoch,
102 | int(batch_idx * len(data)), len(train_loader.dataset),
103 | 100. * batch_idx / len(train_loader)).ljust(40) + \
104 | 'Loss: {:.6f}'.format( loss.item())
105 | to_log(args, log, True)
106 |
107 | @torch.no_grad()
108 | def inference(args, model, loader, threshhold=0.5):
109 | model.eval()
110 | corr = 0
111 | total = 0
112 | stride = args.seq_len // 2
113 | result_all = {}
114 | for batch_idx, (data, target, imdb) in enumerate(loader):
115 | imdb = imdb[0]
116 | result_all[imdb] = None
117 | data = data.view(-1, args.dim).cuda(non_blocking=True)
118 | target = target.view(-1)
119 | data_len = data.size(0)
120 | gt_len = target.size(0)
121 | prob_all = []
122 | for w_id in range(data_len//stride):
123 | start_pos = w_id*stride
124 | _data = data[start_pos:start_pos + args.seq_len].unsqueeze(0)
125 | output = model(_data, None)
126 | output = output.view(-1, 2)
127 | prob = output[:, 1]
128 | prob = prob[stride//2:stride+stride//2].squeeze()
129 | prob_all.append(prob.cpu())
130 |
131 | # metrics
132 | preb_all = torch.cat(prob_all,axis=0)[:gt_len].numpy()
133 | pre = np.nan_to_num(preb_all) > threshhold
134 | gt = target.cpu().numpy().astype(int)
135 | pre = pre.astype(int)
136 | idx1 = np.where(gt == 1)[0]
137 | idx0 = np.where(gt == 0)[0]
138 | idx1_p = np.where(pre == 1)[0]
139 | idx0_p = np.where(pre == 0)[0]
140 | TP = len(np.where(gt[idx1] == pre[idx1])[0])
141 | FP = len(np.where(gt[idx1_p] != pre[idx1_p])[0])
142 | TN = len(np.where(gt[idx0] == pre[idx0])[0])
143 | FN = len(np.where(gt[idx0_p] != pre[idx0_p])[0])
144 | ap = get_ap(gt, preb_all, False)
145 | correct = len(np.where(gt == pre)[0])
146 | corr += correct
147 | total += gt_len
148 | recall = TP / (TP + FN + 1e-5)
149 | precision = TP / (TP + FP + 1e-5)
150 | f1 = 2 * recall * precision / (recall + precision + 1e-5)
151 | result_all[imdb] = (f1, ap, recall, precision)
152 | mAP_all_avg = 0
153 | F1_all_avg = 0
154 | for k, v in result_all.items():
155 | F1_all_avg += v[0]
156 | mAP_all_avg += v[1]
157 | F1_all_avg /= len(result_all.keys())
158 | mAP_all_avg /= len(result_all.keys())
159 | return F1_all_avg, mAP_all_avg, corr / total
160 |
161 |
162 | def setup_seed(seed):
163 | torch.manual_seed(seed)
164 | torch.cuda.manual_seed_all(seed)
165 | np.random.seed(seed)
166 | random.seed(seed)
167 | cudnn.benchmark = True
168 |
169 |
170 | def set_log(args):
171 | time_str = time.strftime("%Y-%m-%d_%H_%M_%S", time.localtime())
172 |
173 | args.log_file = './output/log_' + time_str + '.txt'
174 | args.save_dir = args.save_dir + 'seg_checkpoints/' + time_str + '/'
175 |
176 | if not os.path.exists(args.save_dir):
177 | os.makedirs(args.save_dir)
178 |
179 | if not os.path.exists('./output/'):
180 | os.makedirs('./output/')
181 |
182 | def to_log(args, content, echo=False):
183 | with open(args.log_file, 'a') as f:
184 | f.writelines(content+'\n')
185 | if echo:
186 | print(content)
187 |
188 | def adjust_learning_rate(args, optimizer, epoch):
189 | """Decay the learning rate based on schedule"""
190 | lr = args.lr
191 | for milestone in args.schedule:
192 | lr *= 0.1 if epoch >= milestone else 1.
193 | for param_group in optimizer.param_groups:
194 | param_group['lr'] = lr
195 |
196 | def get_ap(gts_raw,preds_raw,is_list=True):
197 | if is_list:
198 | gts,preds = [],[]
199 | for gt_raw in gts_raw:
200 | gts.extend(gt_raw.tolist())
201 | for pred_raw in preds_raw:
202 | preds.extend(pred_raw.tolist())
203 | else:
204 | gts = np.array(gts_raw)
205 | preds = np.array(preds_raw)
206 | # print ("AP ",average_precision_score(gts, preds))
207 | return average_precision_score(np.nan_to_num(gts), np.nan_to_num(preds))
208 | # return average_precision_score(gts, preds)
209 |
210 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
211 | os.makedirs(osp.dirname(fpath),exist_ok=True)
212 | torch.save(state, fpath)
213 | if is_best:
214 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
215 |
216 |
217 | def get_config():
218 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
219 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
220 | help='number of total epochs to run')
221 | # data
222 | parser.add_argument('-train', '--pkl-path-train', default='', type=str,
223 | help='the path of pickle train data')
224 |
225 | parser.add_argument('-test', '--pkl-path-test', default='', type=str,
226 | help='the path of pickle test data')
227 |
228 | parser.add_argument('-val', '--pkl-path-val', default='', type=str,
229 | help='the path of pickle val data')
230 |
231 | parser.add_argument('--train-bs', default=12, type=int)
232 | parser.add_argument('--test-bs', default=1, type=int)
233 | parser.add_argument('--shot-num', default=10, type=int)
234 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
235 | metavar='LR', help='initial learning rate', dest='lr')
236 | parser.add_argument('--gpu-id', type=str, default='0', help='gpu id')
237 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
238 | help='momentum of SGD solver')
239 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
240 | metavar='W', help='weight decay', dest='weight_decay')
241 |
242 | parser.add_argument('--save-dir', default='./output/', type=str,
243 | help='the path of checkpoints')
244 | # loss weight
245 | parser.add_argument('--loss-weight', default=[1, 4], nargs='+', type=float,
246 | help='loss weight')
247 | parser.add_argument('--sample-shulle-rate', default=1.0, type=float)
248 | parser.add_argument('--input-drop-rate', default=0.2, type=float)
249 | # lr schedule
250 | parser.add_argument('--schedule', default=[160, 180], nargs='+',
251 | help='learning rate schedule (when to drop lr by a ratio)')
252 |
253 | parser.add_argument('-j', '--workers', default=16, type=int,
254 | help='number of workers')
255 | parser.add_argument('--dim', default=2048, type=int)
256 | parser.add_argument('--seq-len', default=40, type=int)
257 | parser.add_argument('--test-interval', default=1, type=int)
258 | parser.add_argument('--test-milestone', default=100, type=int)
259 |
260 | args = parser.parse_args()
261 |
262 | # assert
263 | assert args.seq_len % 4 == 0
264 |
265 | # select GPU
266 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
267 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
268 |
269 | set_log(args)
270 | for arg in vars(args):
271 | to_log(args,arg.ljust(20)+':'+str(getattr(args, arg)), True)
272 | return args
273 |
274 | if __name__ == '__main__':
275 | args = get_config()
276 | main(args)
--------------------------------------------------------------------------------
/SceneSeg/movienet_seg_data.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 | import torch.utils.data as data
4 | import numpy as np
5 | import random
6 |
7 | class MovieNet_SceneSeg_Dataset_Embeddings_Train(data.Dataset):
8 | def __init__(self, pkl_path, frame_size=3, shot_num=1,
9 | sampled_shot_num=10, shuffle_p=0.5,random_cat=False):
10 | self.shot_num = shot_num
11 | self.pkl_path = pkl_path
12 | self.frame_size = frame_size
13 | self.sampled_shot_num = sampled_shot_num
14 | self.shuffle_p = shuffle_p
15 | self.dict_idx_shot = {}
16 | self.data_length = 0
17 | self.random_cat = random_cat
18 | fileObject = open(self.pkl_path, 'rb')
19 | self.pickle_data = pickle.load(fileObject)
20 | fileObject.close()
21 | self.total_video_num = len(self.pickle_data.keys())
22 | idx = 0
23 | self.shuffle_map = {}
24 | self.shuffle_offset = {}
25 | for k, v in self.pickle_data.items():
26 | video_shot_group_num = (len(v) // self.sampled_shot_num) - 1
27 | self.shuffle_map[k] = (len(v) - self.sampled_shot_num * video_shot_group_num)
28 | self.shuffle_offset[k] = 0
29 | for i in range(video_shot_group_num):
30 | self.dict_idx_shot[idx] = (k, i)
31 | idx += 1
32 | self._shuffle_offset()
33 | print(f'Train video num: {self.total_video_num}')
34 | print(f'total shot group: {idx}')
35 | self.data_length = idx
36 |
37 | def _shuffle_offset(self):
38 | for k, offset_upper_bound in self.shuffle_map.items():
39 | offset = random.randint(0, offset_upper_bound-1)
40 | offset = 0 if offset < 0 else offset
41 | self.shuffle_offset[k] = offset
42 |
43 | def _get_randomly_cat_clip(self, idx):
44 | k, i = self.dict_idx_shot[idx]
45 | sampled_len = self.sampled_shot_num // 2
46 | # randomly cat an another clip
47 | data1, label1, _ = self._get_clip_by_idx(idx, sampled_len)
48 | # fix last shot label
49 | label1[-1] = 1
50 | # random the index
51 | length = len(self.pickle_data[k])
52 | start = random.randint(0, length - sampled_len - 1)
53 |
54 | p = self.pickle_data[k][start : start + sampled_len]
55 | data = np.array([p[i][0] for i in range(sampled_len)])
56 | label = np.array([p[i][1] for i in range(sampled_len)])
57 | data2 = torch.from_numpy(data).squeeze(1)
58 | label2 = torch.from_numpy(label).long()
59 |
60 | data = torch.cat([data1, data2],dim=0)
61 | label = torch.cat([label1, label2],dim=0)
62 | return data, label, k
63 |
64 |
65 | def _seg_shuffle(self, data, label):
66 | new_d, new_l = [], []
67 | clips = []
68 | # find positive pos
69 | p_index = torch.where(label>=1)[0]
70 | start, end = 0, len(label)
71 | for i in p_index:
72 | i = i.item()
73 | clips.append((start, i+1))
74 | start = i+1
75 | if start != end:
76 | clips.append((start, end))
77 | # if the last clip is used for shulling
78 | # the label of the last shot might be changed
79 | label[-1] = 1
80 | clips_len = len(clips)
81 | index_list = random.sample(range(0, clips_len), clips_len)
82 | for i in index_list:
83 | s, e = clips[i]
84 | new_d.append(data[s:e])
85 | new_l.append(label[s:e])
86 | d = torch.cat(new_d,dim=0)
87 | l = torch.cat(new_l,dim=0)
88 | # when shuffling is done, fix the last shot label
89 | l[-1] = 0
90 | return d, l
91 |
92 | def _get_clip_by_idx(self, idx, length):
93 | k , i = self.dict_idx_shot[idx]
94 | offset = self.shuffle_offset[k]
95 | s = self.sampled_shot_num
96 | p = self.pickle_data[k][i*s+offset:(i+1)*s+offset][:length]
97 | data = np.array([p[i][0] for i in range(length)])
98 | label = np.array([p[i][1] for i in range(length)])
99 | data = torch.from_numpy(data).squeeze(1)
100 | label = torch.from_numpy(label).long()
101 | # fix last shot label
102 | label[-1] = 0
103 | return data, label, k
104 |
105 |
106 | def __getitem__(self, idx):
107 | if not self.random_cat:
108 | data, label, k = self._get_clip_by_idx(idx, self.sampled_shot_num)
109 | else:
110 | data, label, k = self._get_randomly_cat_clip(idx)
111 | if random.random() < self.shuffle_p:
112 | data, label = self._seg_shuffle(data, label)
113 | return data, label, k
114 |
115 | def __len__(self):
116 | return self.data_length
117 |
118 | class MovieNet_SceneSeg_Dataset_Embeddings_Val(data.Dataset):
119 | def __init__(self, pkl_path, frame_size=3, shot_num=1,
120 | sampled_shot_num=100):
121 | self.shot_num = shot_num
122 | self.pkl_path = pkl_path
123 | self.frame_size = frame_size
124 | self.sampled_shot_num = sampled_shot_num
125 | self.dict_idx_shot = {}
126 | self.data_length = 0
127 | fileObject = open(self.pkl_path, 'rb')
128 | self.pickle_data = pickle.load(fileObject)
129 | fileObject.close()
130 | self.total_video_num = len(self.pickle_data.keys())
131 | idx = 0
132 | for k, v in self.pickle_data.items():
133 | self.dict_idx_shot[idx] = (k, v)
134 | idx += 1
135 | print(f'video num: {self.total_video_num}')
136 | self.data_length = idx
137 |
138 | def _padding(self, data):
139 | stride = self.sampled_shot_num // 2
140 | shot_len = data.size(0)
141 | p_l = data[0].repeat(self.sampled_shot_num // 4, 1)
142 | p_r_len = self.sampled_shot_num // 4
143 | res = shot_len % (stride)
144 | if res != 0:
145 | p_r_len += (stride) - res
146 | p_r = data[-1].repeat(p_r_len, 1)
147 | pad_data = torch.cat((p_l, data, p_r),0)
148 | assert pad_data.size(0) % stride == 0
149 | return pad_data
150 |
151 | def __getitem__(self, idx):
152 | k, v = self.dict_idx_shot[idx]
153 | num_shot = len(v)
154 | data = np.array([v[i][0] for i in range(num_shot)])
155 | label = np.array([v[i][1] for i in range(num_shot)])
156 | data = torch.from_numpy(data).squeeze(1)
157 | data = self._padding(data)
158 | label = torch.from_numpy(label)
159 | return data, label, k
160 |
161 | def __len__(self):
162 | return self.data_length
--------------------------------------------------------------------------------
/cluster/Group.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import numpy as np
4 | import time
5 |
6 |
7 | class Cluster_GPU():
8 | '''
9 | A pytorch GPU implementation for K-Means algorithm,
10 | which is used for real-time clutering in SCRL.
11 | '''
12 | def __init__(self,
13 | num_clusters,
14 | shift_threshold=1e-2,
15 | max_iter=20,
16 | device=torch.device('cuda'),
17 | debug=False):
18 | self.cluster_func = KMeans_Mixed(
19 | num_clusters=num_clusters,
20 | shift_threshold=shift_threshold,
21 | max_iter=max_iter,
22 | device=device
23 | )
24 | self.device = device
25 | self.debug = debug
26 |
27 | def __call__(self, x):
28 | dimension = len(x.size())
29 | x = x.to(self.device)
30 | B = x.size(0)
31 | output_vector = x.clone().detach()
32 | # D == 2
33 | if dimension == 2:
34 | _, choice_cluster, choice_points = self.cluster_func(output_vector, debug=self.debug)
35 | # D >= 3
36 | elif dimension == 3:
37 | choice_cluster_list, cluster_points_list = [], []
38 | for batch in range(B):
39 | y = output_vector.narrow(dim=0, start=batch, length=1).squeeze(0)
40 | _, choice_cluster, choice_points = self.cluster_func(y, debug=self.debug)
41 | choice_cluster_list.append(choice_cluster)
42 | cluster_points_list.append(choice_points)
43 | choice_cluster = np.stack(choice_cluster_list)
44 | choice_points = np.stack(cluster_points_list)
45 | else:
46 | raise ValueError('Dimension of input must <= 3, got {dimension} instead')
47 | return choice_cluster, choice_points
48 |
49 |
50 | class KMeans_Mixed():
51 | '''
52 | This version uses GPU for tensor computation and
53 | CPU for indexing to improve the speed of the algorithm.
54 | '''
55 | def __init__(self,
56 | num_clusters,
57 | shift_threshold,
58 | max_iter,
59 | cluster_centers = [],
60 | device=torch.device('cuda')):
61 |
62 | self.num_clusters = num_clusters
63 | self.shift_threshold = shift_threshold
64 | self.max_iter = max_iter
65 | self.cluster_centers = cluster_centers
66 | self.device = device
67 | self.pairwise_distance_func = pairwise_distance
68 |
69 | def initialize(self, X):
70 | num_samples = len(X)
71 | initial_indices = np.random.choice(num_samples, self.num_clusters, replace=False)
72 | initial_state = X[initial_indices]
73 | return initial_state
74 |
75 | def __call__(self, tensor_input, debug=False):
76 | if debug:
77 | time_start=time.time()
78 |
79 | X = tensor_input
80 | X = X.to(self.device)
81 | choice_points = np.ones(self.num_clusters)
82 | # init cluster center
83 | if type(self.cluster_centers) == list:
84 | initial_state = self.initialize(X)
85 | else:
86 | if debug:
87 | print('resuming cluster')
88 | initial_state = self.cluster_centers
89 | dis = self.pairwise_distance_func(X, initial_state, self.device)
90 | choice_points = torch.argmin(dis, dim=0)
91 | initial_state = X[choice_points]
92 | initial_state = initial_state.to(self.device)
93 | iteration = 0
94 | status = 0
95 | while status == 0:
96 | # CPU is better at indexing, so transfer the data to the cpu
97 | dis = self.pairwise_distance_func(X, initial_state, self.device).cpu().numpy()
98 | choice_cluster = np.argmin(dis, axis=1)
99 | initial_state_pre = initial_state.clone()
100 | for index in range(self.num_clusters):
101 | selected = np.where(choice_cluster == index)
102 | selected = X[selected]
103 | initial_state[index] = selected.mean(dim=0)
104 | dis_new = self.pairwise_distance_func(X,
105 | initial_state[index].unsqueeze(0),
106 | self.device).cpu().numpy()
107 | culuster_pos = np.argmin(dis_new, axis=0)
108 | # a cluster has at least one sample
109 | while culuster_pos in choice_points[:index]:
110 | dis_new[culuster_pos] = np.inf
111 | culuster_pos = np.argmin(dis_new, axis=0)
112 |
113 | choice_points[index] = culuster_pos
114 | initial_state = X[choice_points]
115 |
116 | center_shift = torch.sum(torch.sum((initial_state - initial_state_pre) ** 2, dim=1))
117 |
118 | iteration = iteration + 1
119 |
120 | if center_shift **2 < self.shift_threshold:
121 | status = 1
122 | if iteration >= self.max_iter:
123 | status = 2
124 |
125 | if debug:
126 | print("Iter: {} center_shift: {:.5f}".format(iteration, center_shift))
127 |
128 | if debug:
129 | if status == 1:
130 | time_end=time.time()
131 | print('Time cost: {:.3f}'.format(time_end-time_start))
132 | print("Stopped for the center_shift!")
133 | else:
134 | time_end=time.time()
135 | print('Time cost: {:.3f}'.format(time_end-time_start))
136 | print("Stopped for the max_iter!")
137 | return initial_state, choice_cluster, choice_points
138 |
139 | # utils
140 | def pairwise_distance(data1, data2, device=torch.device('cuda')):
141 | data1, data2 = data1.to(device), data2.to(device)
142 | # N*1*M
143 | A = data1.unsqueeze(dim=1)
144 | # 1*N*M
145 | B = data2.unsqueeze(dim=0)
146 | dis = (A - B) ** 2.0
147 | dis = dis.sum(dim=-1)
148 | return dis
149 |
--------------------------------------------------------------------------------
/cluster/cluster_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from Group import Cluster_GPU as Cluster\n",
10 | "import matplotlib.pyplot as plt\n",
11 | "import torch\n",
12 | "import random"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "num_clusters = 3\n",
22 | "cluster = Cluster(num_clusters=num_clusters, \n",
23 | " max_iter=30, \n",
24 | " shift_threshold=1e-3, \n",
25 | " device='cuda',\n",
26 | " debug=True\n",
27 | ")"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 3,
33 | "metadata": {},
34 | "outputs": [
35 | {
36 | "name": "stdout",
37 | "output_type": "stream",
38 | "text": [
39 | "torch.Size([5, 300, 2])\n"
40 | ]
41 | }
42 | ],
43 | "source": [
44 | "num_point = 300\n",
45 | "batch = 5\n",
46 | "test_tensor = torch.cat([\n",
47 | " torch.randn(batch, num_point//3, 2) + 4, \n",
48 | " torch.randn(batch, num_point//3, 2) - 4, \n",
49 | " torch.randn(batch, num_point//3, 2)],\n",
50 | " dim=1\n",
51 | ")\n",
52 | "test_tensor = test_tensor.cuda()\n",
53 | "print(test_tensor.size())\n"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 4,
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "name": "stdout",
63 | "output_type": "stream",
64 | "text": [
65 | "Iter: 1 center_shift: 32.28633\n",
66 | "Iter: 2 center_shift: 6.91954\n",
67 | "Iter: 3 center_shift: 0.12403\n",
68 | "Iter: 4 center_shift: 0.00000\n",
69 | "Time cost: 0.025\n",
70 | "Stopped for the center_shift!\n",
71 | "Iter: 1 center_shift: 13.87646\n",
72 | "Iter: 2 center_shift: 4.93253\n",
73 | "Iter: 3 center_shift: 0.00000\n",
74 | "Time cost: 0.008\n",
75 | "Stopped for the center_shift!\n",
76 | "Iter: 1 center_shift: 3.94572\n",
77 | "Iter: 2 center_shift: 0.00000\n",
78 | "Time cost: 0.006\n",
79 | "Stopped for the center_shift!\n",
80 | "Iter: 1 center_shift: 5.07124\n",
81 | "Iter: 2 center_shift: 1.16594\n",
82 | "Iter: 3 center_shift: 0.00000\n",
83 | "Time cost: 0.008\n",
84 | "Stopped for the center_shift!\n",
85 | "Iter: 1 center_shift: 14.44800\n",
86 | "Iter: 2 center_shift: 3.34260\n",
87 | "Iter: 3 center_shift: 0.21062\n",
88 | "Iter: 4 center_shift: 0.00000\n",
89 | "Time cost: 0.011\n",
90 | "Stopped for the center_shift!\n"
91 | ]
92 | }
93 | ],
94 | "source": [
95 | "cluster_ids_stack, cluster_centers_stack = cluster(test_tensor)"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 5,
101 | "metadata": {},
102 | "outputs": [
103 | {
104 | "name": "stdout",
105 | "output_type": "stream",
106 | "text": [
107 | "(300, 2)\n",
108 | "(300,)\n",
109 | "(3,)\n"
110 | ]
111 | },
112 | {
113 | "data": {
114 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAx90lEQVR4nO2de5RU1Z3vv7+qU011R0jCSAZfBHQ6PgLXRxHUPMhDFO6ocWat9Ag35Epkhkwy08vJdW7So8u4WN5k9b1JdLidrFyYkMEbnca0OpqHI4KaEHMVpOIDBJ1WRNSWkUwngxGhu6r2/WP36T7n1D6Pqjr1/n7W6tVdp87ZZ1ct+O7f+e7f/m1RSoEQQkjzkqh3BwghhFQGhZwQQpocCjkhhDQ5FHJCCGlyKOSEENLkWPW46Yknnqjmzp1bj1sTQkjTks1mf6OUmuU9Xhchnzt3Lnbt2lWPWxNCSNMiIq+YjtNaIYSQJodCTgghTQ6FnBBCmhwKOSGENDkUckIIaXLqkrVCCKkP2ZEshkeH0T2zG5mTM/XuDokJCjkhbULftj4M7BxAUpLIqzx6F/Wif0l/vbtFYoDWCiFtQHYki4GdAzg6fhRvjb2Fo+NHMbBzANmRbL27RmKAQk5IGzA8OoykJF3HkpLE8OhwnXpE4oRCTkgb0D2zG3mVdx3Lqzy6Z3bXqUckTijkhLQBmZMz6F3Ui65UF6Z3TEdXqgu9i3o54dkicLKTkDahf0k/es7pYdZKC0IhJ6SNyJycqYmAM82xtlDICSGxEneaIweFcCjkhJDYcKY52gzsHEDPOT1liTBz36PByU5CSGzEmebI3PfoUMgJIbERZ5ojc9+jE4uQi8h7RORuEXleRPaJyMVxtEsIaS7iTHNk7nt04vLI1wF4UCn1GRHpANAVU7uEkCYjrjRHe1DweuSc8CxGlFKVNSDybgBPAzhdRWxs4cKFint2EkKiwKyVKUQkq5Ra6D0eR0Q+D8BhAP8oIucCyAK4Tin1dgxtE0LanFrlvjczcXjkFoALAHxPKXU+gLcB9HlPEpE1IrJLRHYdPnw4htsSQggB4hHy1wC8ppTaMfH6bmhhd6GU2qCUWqiUWjhr1qwYbksIIQSIQciVUocAvCoiZ04cugTA3krbJYTEQ3Yki817NjP/uoWJK2ulF8CdExkr+wF8PqZ2CSEV0G4rI9t1YjQWIVdKPQ2gaCaVEFI/4l4uX0vKEeR2G7ScsNYKIS1K2MrIRo1cyxHkZh604oBCTkiL4rcy8sEXH8TqvavrHrmaou4oguy8DtAD0v7f7vcdtCjkhJCmxbQysuecHgztHap75OqNunvO6cGyP1oWKsjO694ZfwcQoNPqxPH8ceQKOdd11VjO36gePIWckBbGu1x+eHQY9+6713VOrSNXU9R9+zO3Y+i5IRRQQC5vFmTTdVDAW2Nvuc5PJVJIJVOxL+dvZA+eQk5Ii+NdGRlHIapKIlOTdw8AR3NaoJOSRNpKI5VIueqrbN6z2XidFxHB+svXY+W5K0vqVxCN7sFTyAlpI+IoRFVpZGry7p3kVR4fOfkj+OKHvugaKLpndms7JYRpyWmwklPSZg86uXwOVtKKbfBpJA+eQk5Im1FJdcK4ItMl85bgof0PISEJt1Uywc6Rnbh15q3FbQqAkNJ8zicMe9AZz49jvDCOjkQHrKQVy+DTSCV1ubEEIS1CKSs4MydnsHz+8rIi0/H8uOvYeH488mYPfdv6sHjTYjx64FEAwCXzLsHiOYuLzkslUkVtDo8Oo9PqdB3rSnXhgpMuQNpKF9U/dw464wXd57HCWFk7DcVZZ70aMCInpAWo1URcLp+bFEWb8cJ40QSliexIFut2rMOx3LHJYw+//DDWX74eO0d2uo6bol0/S2bDFRsAFOfF+3nxQHm2SFx11qsBI3JCmpxa7m1pJS10JDrcxxIWdo7sNN7P+ZSw9hdrXWINYNK3vu7C63yjXbsNAK6oOG2lccm8SwCYnzC6Z3YXDTo25doi5T7JVBtG5IQ0OdWYiPPLSume2Q0raWGsMDZ5LFfIYeOvN2LjUxtdTwLOp4TxwniRJQNoqyOXz/lGu6Ynje2rtmPtL9Zi6/6t+PmBn+Phlx82PoEM7R0qyi13euS1FONq559TyAlpcuKeiAuyaZxZLwAmJyrt1EF74tP+2zSR6SQpyckME2+apN/E6vxZ8/Hwyw/jWO4YjuGY677eFaJOIe9IduDmj9+MpWcsrevip2rYXrRWCGly4pyIi2LT9C/px/ZV27H6/NXostzb89pPAkH+tJNUMuU74Pg9aewc2RlYQ8bv2mnJaTj9vafXPBKvhe3FiJyQFiCuibioNo3998anNrrOdT4JeJ8SLLGgoCaPW2Lh6g9ePSnA3j77PWksOnlR4H39rh0vjNc8XbBW+eeMyAlpEeKYiCvVprlk3iXG1D/TU0Lm5AxSyZT2qRMWzjrxLAztHcKan6zB4k2L0bfNvUOk35PGynNXhj6B2NdaMhWr5go5DO0divxdxLEhR63yzyXixvexsnDhQrVr166a35cQEk4UT9c7kXnZ6Zfhax//WlGFQjsrJZfP4Qs/+0KgZ96V6sL2VduLBiK/icKgCcTsSBZbXtqCtb9Yi7H81MSs3z3K+Q6iEmdbIpJVShXt/UAhJ4QUESSeW17aglu23+JKJXQKpEm4zpt9Htb8ZE1RgSsn0zumY8OVG7B8/vKK+m7fv1Ao4Fjene4Y5R7ZkSwWb1rsGnSiDgBBbcaRteIn5PTICSFFeDNIgGCBdE42mjJN1l++PrC+ChCP5WCskFjiParha5u+zzihR04ICSQ7ksU3fvkNrNuxDkfHjxaJODAlkH4iaOduO33tC0+5EF2pLpyQOiG2Je9+2TJpK+270MjpgWdHstj/2/1FC4kaqa6KCUbkhBBfJqNwVShalQlogUxIwiWQ7+TcFQrfyb2D7pndWD5/eVFmzX377sPye5bjrs/chavOuqri/pomF9NWGjctvsmVP26yfw79/hAG9wzCSljIFXKwxEJnqrOsCpG1hkJOCDESZlOYBBJAcXVCx2uvxTD43CDG8+MY3DMYi5D7lem94WM3BH6ub/3qW8hDDwD25GjaSqPvo301X0BUDhRyQogRX5simUYikXAJpD2Zt/+3+9GZ6nRNanamOo3+8oujL+LHL/wYBRRw/wv346XRl3DGzDMm3y93gjAsp970uWwRd5KQRM0XEJVLbEIuIkkAuwC8rpS6Iq52CSH1wWRTpBIpXHv+tbj2/GuNNoWpEqKfv3zjIzdO1l8Zz4/jxkduxObPbC5qMyhlz0/sgyYXwza2sMkVck2z52eck53XAdgXY3uEkDriXJDjrHi46ZlNkwtrvEvQj+WOQUEZFwk5eXH0Rdyz955JQc2rPO7eezdeGn0p8rJ2u7a534KiKJ/LrqJoJYpj2hXzV8S+52c5/Y1CLEIuIqcCuBzA9+NojxDSGPQv6cf6y9frnXmgl7k7hdVoU0wso99w5QZsX7W9KJLOjmTxp3f9qXHF45d+9qXA9D9nG5XUMLHrxWy4cgMe+/xjuP7i69GV6kKX1YWOZAeuOfcabPqTTRG/pXCqXXMlLmvl7wF8BcB0vxNEZA2ANQAwZ86cmG5LCKk2VtLCtOQ01wpJW1j9an77bdXWt60Ptz5+q2+d8EcPPIovfehLocva48j1dtovmZMzRb56nDZItWuuVByRi8gVAN5USgUOLUqpDUqphUqphbNmzar0toSQKuHNrw6qF5I5OYNLT7+0qA3TVm3ZkSy++f++6SviAJAv5DG4Z3DS+uhK6QjZuydoNWqYOGvVxG2DVLvmShzWykcAfFpEDgDYDOBTInJHDO0SQmqMScDCyuTe/PGbkbbSrnaO5Y4VTXr+4KkfoKAKgfcvoIC7996Nv7jgL9BzTs9kPvfQ3iGXmFZzD81q2CDV3vMz1lorIvIJAH8blrXCWiuE1IawwlJeKyGoxkhQW33b+vDtx7/t2sjBSli4/uLrJz3y5Xcvx13P3RWp3xedchGeffPZ0Hon1cgC2bxnc1FdmKAaLaX0odL+stYKIW1GUAqfX2Er72Iep48blNI3f9Z8eIPCXCGHdTvWTdois981O3Lfnxx5sijK96uL7tencgW2FBvE7zsuJy2yEmIVcqXUzwH8PM42CSGl47dNmt82bAM7B/CpuZ+a3LLNxitgJoFadd8q3PHsHb652bb47nh9R+T+m0oClOIpl1I61nSuaXWo6YnG9D0e+v0hDO0dqurWbl4YkRPSgoSl8JlWbD60/6GiY85JRlMN8oIq4KfDPw3sS/fMbhwdP4qdIzsj9z+VTGE8P45OqxNWwopU78RZA91vEIsqxttXbQ/dccm48lUBg3sGXRk+fveOEwo5IS1ImD3gLWx1PHdcpxjCvQnDsj9aBsAseD/+1x8H9iEpSVz9wasnBez1//Y61v58LTY9s2kyWr28+3L8bPhnrug7baXxo8/8CB3JDmTfyOL1I6/jlBmnYOkZS33v5RxkjuePB1pEToIGPDuDxc7i8Qq66TvOqRyshGVM1aSQE0JKwq94lC1MRYWtoDNGvOTyOWzesxn7f7s/0mbKACAQJCWJjmQHhvYOYfYJs9G/pB+zT5iN713xPfz5BX/uinRN1savXv0VBnYOIJfPYawwhlQiha//8utGmyKsuBfgb8uEDXir7ls1WRERgOv+pu+455yeou3kalEClzsEEdLCmDxtv6yMT879JLa9vG1SlBa8bwF2v7l7qoZKIefKSvEjiaSrCFWU3XWc/QRQlD1jk7bSeOzzj7naMn2eVCIFEcG05LSyPPL+Jf1Ydd8q3P7M7a5zo2TOxLm1mxdu9UYIARC8lRmASZ/Zu8emJRYSiYTLNrBJW2kUVAEXnXIRnjr0VOTUPRMmYXby6Q98GvevuD/086y/fD2spFVWWmB2JIsP/+DDRZ+1K9WFjZ/eGPpZqlUci+mHhLQRQUISZLvY72/es7nISulMdaLvo3144rUnsHX/VqQSqUk7YdkfLXNF005KtRa6Z3YXefhOHtr/ELIjWVd/7c8DpX3qnnN6sPLclZHv6U0LHB4dhiWWa84AiF4Rsdpbu3mhkBPSYkR5tA+r2e3nHS89Yylu+NgNgQNF76JerNuxTvvwgvJWMAYYBfbyf2eb/Uv6XTv8OL35cuie2T1ZKMxJ3BUR44J7dhLSQpSyvNxZW8T0XtCS8qBrJzEIYRSGR4ch4n/x8fzxouX/2ZEshvYOYSw/hqPjRyteVu/8/NWqiBgnjMgJaSHirLIXFrWbsAcSZzphqXnUuXzOWFjLnkRVSuELP/sC9hzeMxlxV6O6YDmfv15QyAlpIeKusleq1xuHoFpJCx2JDowV3LnYAACla6KPF8ZdA0S1qgvW2usuF1orhLQQ1a6y56WUkrdR6Z7ZDSvpjjETksC05DTXMedK1Vp/7kaDETkhLUatLAG/SdUodUqCsEXZWU2xoApFqYDeAaKZrJC4YR45ISSUOEveRm3/o//4UZfXbiUsWAlrMu3RmY0Tdr8487qrlSMeBeaRE0LKwq/kbZAXXoq37Nd+wuP8dlo6j/30957uEtGwdMs4V1pWc9VmJdAjJ4T44pfOmMvnYplczI5ksW7HuqL273z2TmNJ3aVnLHUVs/rGL79hvN727OPc7afaGyhXAoWcEDKJd/LSLwvFSlqxTC6u/cXaorrjUNFK6i7etBi3/OKWouudk6Bh5XxLIc624obWCiEEgNk26DmnpyjythfkVDq5mB3JYuv+rUXHxwvjSCVTrsnNsJK6TpxPBnGmJVZ7A+VKYEROCPG1DQBMRt4diQ4AmFyQY2/MHLrC04fh0WGkEqmi4xeferHxfFswjRs6QBfuMq1AjSstsZFTHBmRE9LmZEeyulyrz2YM/Uv6MX/WfKz+yWoA5gU55WCKcNNWGrcuvRVDe4d8Uxj9rrtp8U1YesbSov7EmZbYqCmOFHJC2hjbTgEQuF+nlbT0DkIx7nwTVIUxc3LGVzD9rrvhYzcE3isu0W3E1Z4UckLaFD+vucvqKqpaWC1/OCjCDRLMRo2M6wWFnJA2xeQ1d6W6sPr81bjm3GsiRcFxCGi5EW4jRsb1omIhF5HTAPxfAH8I7bJtUEqtq7RdQkh1MUXZAIpE3IZRcOMSR0SeA3C9UurXIjIdQFZEtiql9sbQNiElkc0Cw8NAdzeQCdCZqOe1MuVE2YyCG5OKhVwp9QaANyb+fktE9gE4BQCFnNSUvj5gYABIJoF8HujtBfoNq6ejnlcPaj3AMMpuDWItmiUicwFsBzBfKXXE894aAGsAYM6cOZlXXnkltvsSks0CixcDRx3zdl1dwPbtbkGMel49aOQBplLqWWiqlfArmhXbgiAROQHAPQD+xiviAKCU2qCUWqiUWjhr1qy4bksIAB3FJj1rRJJJfbyc82pNNqtF/OhR4K239O+BAX282bGX06/5yRos3rQYfdv66t2lliMWIReRFLSI36mUujeONgkphe5uHcU6yef18XLOqzWNOsBUSiMXmmolKhZy0bukbgSwTyl1a+VdIqR0MhltRXR1AdOn69+9vcV2SdTzak0tBxhvYaxq0siFplqJOLJWPgLgcwB2i8jTE8duUEo9EEPbhESmvx/o6QmfLIx6Xi2xBxivRx5332pdT7uRC021EtwhiJAGoppZK2G7+lSLRt2MoRnhDkGENAGZTHGWTVzCHscO9+XAFMfqQyEnpEGJOx2xnjYHFxJVF9YjJ01HNgts3twaqXl+VCMdsZHraZPKYEROmgq/KLXVltwHpSNW8vloc7QmFHLSNDijVJuBAeDQIWBoKJoF0SyCX810RNocrQetFdI0mKJUABgcjGZB9PXp5flr1ujffSUsMKy1ndOo+e6kMWFETpoGU5SaywGWBYxNbVzjWhFpR99AcTS/bh0wYwawdKkWSL9ovV41UBox3500JswjJ02FV1R7erSt4i2CZR+3z1uyBHj0UR21O0mngUQCWLAA2L3b7L03apGtZrGJSHwwj5xUlVqJiilKnT07XNwfesjc3rFj+veOHe7jt90GzJ+vo/1qTDpWSitXSiSlw4icVEwjiIpzIBke1j64M/qePh34xCeAhx8GCoUpAQ+iowNYscIc8UeNyKsxwDXyUwKpLlUvY0vakzvu0NFrvcuvZjLA8uX6t1/Gx803a7G76SZtqYQxNqZFvKcneNLRbyJ01Srgwx8GVq8ufXI1iFatlEjKh9YKKZu+Pi3izolGoP7WQ1gBqkwGOHLE/f6CBcBTT5k/y7Jl+npnZG1H2g8+aE59vPJK4Kc/1W3YbQ4M6EGh0u+lUUvxkvpBa4WUhenx3qZRHvPDbA37fTvzZXgY+B//wy3mps9iW0lA8efv6gI+9akpEfe+t3GjfnKolEaws0jt4WQniRW/nO6OjsbJd/YWoLJxCvzTT7sF8fzzi7NXvDaKN43Ry4MPmo/ncvFFzUxNJE4o5KQsuruBd95xH0smdcS5cmVpbVVrQjAsJ3x8XItrLjf1/u7dwPr1OkI39cdvALPJ5XQ6o4kVK+IVXL+BirQfFHISGyLA2WeXdk1Ui6AUsQ+qxxIWTSeTWsT97A+TP+3ktNOAl14qPn7FFcCmTcH9JqRcmLVCymJ4GOjsdB/r7CwtcyJqxkspS+uDqgaGRdOAedLQmZViT6R2dBRfm04Dr75afPyKK4Cf/CT4voRUAoWclEWlmRN9fTotzy/jxabUcq5BqXmmPluWFmC/1EJTCmF/v7aQvGJeKOj2nHR1AZ/9bOjX4Us7lOwllUMhJ2VRSVEnW5y9Ig4UDwal5kwHDTCmPl9/PfDYY8CGDTo7xWnrfOxjwO23634ePeoeRFauBL78ZXdbK1b496kcKinyRdoLeuSkbMrNnCgl46XUyN8W63Xrpo452+zp0YWygKliWfZ1ThYsAPbs8e9/JhOtXEC5GTx+JXtLyUNnLZb2gUJOKqKczAmTOHd0TGW8eAUojt3ls1lg7Vpg61YgldLtHDlibueOO/xF3JtC6P38TnG389Ntb70UKt1YgnnmbYZSquY/mUxGkfbmq19VqqtLqenT9e+vfjX4+K5dSg0O6t82fse6upQCpn4sS6mODvcxQJ/nvNamt7f4XPvnmmsq+3xh2J/phz8s/hx+/TW1Ue61pLEBsEsZNDWWiFxElgFYByAJ4PtKKY79JBCTLRHFTrC98aEhc8RpimSdeeJOCgVgy5biCHfRIvP5H/1otBTCUm0Rv+X+ptK6UaLxam0TRxqXioVcRJIAvgvgUgCvAXhSRH6slNpbaduktfHaEkEC5BRu00IeWyjD8rydHDsG3HKLtlictsPKlcB3vuMubZtIuEU8yH8uRUhtC8RUkTFscZIfrMXSfsSRtbIIwItKqf1KqTEAmwFcFUO7pM3wE6Bczp2CeOxYcZTtFMreXp1SaMr19nLsmDmd8YkngB/+ULd14YX62I036t9h2SS5HHD8ePHnMOWnr1unP5eprK5zcVIpkXQmowe1jo6p7yGOYl2kcYnDWjkFgHMZxGsALvSeJCJrAKwBgDlz5sRwW9JqeCc2x8eBSy4BDh6MtpAnl9M514cO6WN+9eBE3O/5RcsrVwIXXaQtjkIBuP9+4NvfLq746Nwyzn5ysEml9I/JFlm7NrguerlRdF+f7kehoPtpWfr17Nmc8GxZTMZ5KT8APgPti9uvPwfgO0HXcLKTBLFrl1JXXqlUOq0nC9NpPWHpnLxLJpVKpfQkXleXUhdeOPW330QloCc90+noE4F/9mf6XoBSIvrH1G46be5nR4eeuDR9Rm8/vNdFnSD1tuv3HXDCs/lBFSc7XwdwmuP1qRPHCCmbhx/W0aodsdorMFOpqWJdqZSOwi+7DHjkkfAaKtOm6cgYKN4azp5EdUbNL74I3HPPlN0TVPHZL7KeNq14tSeg75dKma9zpmJ6CcsNDypDwAnP1iUOIX8SQLeIzIMW8OUA/ksM7ZI2xSRGnZ1TXvQtt7h98oceMoulk0RCTxza4miLt50pcu+9xfnWX/pS9IlTP/zsEVP1SECL+4oVZhGPkhseNNnLCc/WpeLJTqVUDsBfA9gCYB+AHymlnqu0XdK+mMRofHzq71TK/Z5l+acY2qTTbrG3t4Sz9+P01nG57z4d5ZdKWO2WMJJJ3SfvBGrUmjPOMgT299TRUV5fSPMQSx65UuoBAA/E0RYh3knPd97RQt3fP5V66OWyy8y78tiYolG/NMG1a4EHHigvGs9kgO9+N3xpvF090rlBNDBltZjy56OmNJpWl3KZfmvDJfqkIbHFaMuWKSvFFj2nX25bDOedZ/bJ02ltq5iiUb/If8uW8i2V3buBffvCrZ6wfHevSJdTc4bC3T5QyEnDksmYJwU7O4HPfU6L3aJFU/VZvKTTwE03uYtjedtfsMC98OfEE4E33ii/z2NjuuTttGlmH9s5WWk/dQDFA5BXpOOqOUNaFFMqS7V/mH5IouJXO8VOTYxSp6WUtuP+cab8mfpn11a55ppofTfVlyHtA3zSD0UF5VRViYULF6pdu3bV/L6kOQnbZ9O5030ppVs3b9bRc1DaYqVMn65rnXd361Wgzns5+w2w7CwJR0SySqmF3uO0VkjD45y8279fv3ZOEjr95FK84e7u8GwXPzo69LWFQvB5tkUyPOzOvAH0a6cPTl+blAt3CGpnRrPAgc36d4OTyeiaI0uXxlcQKpPx39XHJpEortliZ490dupJzYThf1E67U75y+XMQl7uQEKIEwp5u/J0H7B1MbBzjf79dHPsI1bJFnMmenuD67gkElrs7ful07pWy9gY8Pbb5qjcnmR1bh1nWcUDQkdHeHYLIVHgP6N2ZDQLvDAA5B2G7QsDwJweYGbjP9uXu8WcieFhLdLefG6bzk5g2TIt+H7WjperrwZuuMF9rLtbi7az2Jad301IpTAib0eODAPiCUMlqY83CbbVUqmnHJbP7dy42c/acdLVpYXf1N84nyQIcUIhb0dmdAPKo0Yqr483ONmszjYx5Y2Xg1dgLUv/+Imt8/yuLnObflF2f7+2WzZscNsuhFQK0w/blaf7tJ0iSS3iZ/YC51WuLNVMoavmhsLOfgPmv52fx297Nm5yTKqJX/ohhbydGc1qO2VGdyzeeKVCGzQIZLPhedhx4OyD376gpfSbkDjxE3Ku7CSxUOnO7WGrMgcH9Xve9nt741vl6OyDaZMIbsxA6g18VnbSIyexEFSdL4woJVpNk5JHjwL/8A/Ahz8MrFpVWf+9fQjaF5SQRoNCTmKhkp3bowwCfpOMx47plL7bb69MzIN21rHhxgykUaGQk1ioJL0u6iBgZ32sXq0X3XgZHCw/m8XUh0o3iSCkVnBBEImNchfqlFKi1T62fn3xe5ZV/p6Ufn2Ia+ERIdWEWSuNTBxZJTFnplSTUrI/Vq3SdoqTOLJYmIFCGhlWP2w24sjzrlKueLUopfrfpk369+DgVL2SOKwPViAkzQgj8kZkNKsLWTlroSS7gEu3R4+q42ijCWAETdoJRuTNRFAtlKgiHEcbTUCjRdAcWEg9oJA3InHUQimljSby0etFFIGuZgkBQoKoKP1QRL4pIs+LyLMi8s8i8p6Y+tXezMxoPzvZBVjT9e8ze0sT2ahtNGld8lrS16fLA6xZo3/3Gb6iKIuaCKkWFXnkInIZgEeUUjkR+Z8AoJT6ath19MgjUu2slSAfHWCUjug1XjZv1kLvrFNu79e5fHnt+hsHtIcal6p45EqphxwvnwDwmUraIx5mZioX0aA2/Hz03WuBQw+Xlu3SovZM0KpTp8iZFhQdP958W7nRHmpO4lzZeS2Af/F7U0TWiMguEdl1+PDhGG9LysbkoxfGgTe26ig995b+/cJA8L6eNbRn4q5HHkbUVafOla2p1NTxL3zBbMU0IrSHmpdQIReRbSKyx/BzleOcGwHkANzp145SaoNSaqFSauGsWbPi6T0JJmxzZZOPftJlQCLlPi9o9yDntnFRhb9MonjVcVNK6YH+fr3iVES/HhtrLjGspPAZqS+h1opSaknQ+yKyCsAVAC5R9UhKJ2aiLgY6r1/v1WnbIgBwaJv7nKCMmRqlOTqjRZuBAb2Evto+bimlBywLmDbNvTenyYppRCopfEbqS0UeuYgsA/AVAB9XSh0NO5/UCNPmys+vA1IzgJOWFgus10c/s7d4EPAT5RptGxfVq64WUfPVm1kMS6l5QxqLSvPIvwNgGoCtop8nn1BK/WXFvSKVYYqSC8eA3bcAe74ePnnpjdJtETdNaNr2TFThL5NmEchmF8NyC5+R+sIl+q2IKa3QSTlL9cOsmhpkrdgZFYDOBlmxYqrmSqPBFD5SDfzSD1mPvBVxTmImDIW7gyYvTUSZ0JyZAeYur2rqoR0t5nLaix4aatyMkExG549TxEktoJC3Kuf166h7/k3FYm7ysIMyXIImNGtINqvF284GaaaMkKjUOr2StAastdLK2JOYuSPBHvbjq4ADg0Bi4p+D1zap0YRmGPWe8Kw2XIxDyoUReaMSlgNeCnZ0vmiD/u0U6cdXAS/fDqgxbZn42SaV1n6JgWaZ8CwHLsYhlcCIvBGpxoYQpqX6o1nglUHz+d48cL9MlhrS7BkhQbT60wapLhTyRsOUA/7CgBbRuMXzyDAgFoAx9/FCzmybxFH7pUJaNT2ulZ82SPWhtdJoVGNi0c+m8fO4566ou2AH0YoZIaWUAiDECyPyRsNYyOo4oMoso2fbNICOtOeuAC7epF87F/MA+h7vd7wfRItWO6wnrfq0QaoPFwTVglJFzxbfQk5PQiKpC1mddV1pXrnfwqB517jFutz+NcmmzoS0ClwQVC/KKfF6Xj+waL0jCs/rJfZ7v11aFoufHfPKYPmLeWpY7ZAQEg0KeTUZzepiVeWI3tsHARQ8B3PAyJbo95/RreuLexGrfM+9QRYHEUKmoEdeTZ5dqyNpJ94Sr362hvi06T0eZIscHAKUdzCYoJTFPM57RF0cRA+dkJpBIY8Dk2iNZoFDW4vPLYxPiV6Q13zSUuDZm92TnGLp4zZB19sWCDyim0iXtpjHdI+waof00AmpKRTySvETrSPDeoLSG5GfdKkWvbB88ZkZYO5ngQP/NLHlTEJPdjoHiqDrTRZIIq1rr8y/YaqNoKjZ7x6XbvdfHFTLPHhCCIBm88jjXLYeB0ETfyZ/OpEGFtys/w7zmp/u09ZIIgUoAO+/2h3Vhl1vskAkAZy8dKr9sEnYsN1/TBOk9NAJqTnNI+Q13OA3MibRso8fHNLpgzZiTUXUo1ng9/uLhd72mp0DRP6oTkE8OOQewMK86qD6KFEzT8opltUgBbYIaSeaw1pp1Md1k2jljwIH7gT+7RHo/agnEEv39+k+nckCaCEXC0h2ur3mA5sNA0TCPUk6M6PbOzCoI20Uir+POT16ezcFHYnb75Wyz+bsJcAbD008GUTY/adGOwYRQqZoDiGv0Qa/JWOL6cu3u4+PPAAkPTXAEyngjS06F9wp8CoJfLAPKOSB44eBl+8A3n22YYD4PTD65JQnbVsvKACFMQCWft05W1swXu8+d2Tqu4oSNTtXhKo88N4PAZnbon3fDVBgi5B2ojmEvJEf16d/wHDQFlcHKg+8/TpcIg4AyAMvfA849trUoT+4UEexk5H7xITpC/8bGP4/WiQPDnlWbOaAfE6L77vnh0+kepfmO6N50xPQ4V/qY1GW7wMNUWCLkHahOTzyBqmHbcQv3/vEDxf3912nmM91ijgA/PsO4HfP6b8LjgFM5bS4+pWeBXQE/u87wycc7ai5kNP2zsGhqXmHqCtCCSENQXNE5EDjPq775XtnbtV/O/s7mi0+FxP+tpeRn0Gb2wbEck+kOlF54A8WAS9tLD7ufIIZzU4sGBoD8hNPD3bUPqPb3L69IrRRvntCCIBmichtarDBb8nMzABnX69TC+2fs6935IIvd09Qes89+Y99Gg4oZqYKuoph0Ths6ch/3srwJ5iweYcZJssIjWFnEUJcxBKRi8j1AL4FYJZS6jdxtNkU2Atq5vTonze2TGWI+GF6sthykbZTotI5WwvzK3e5I+fERGaM332cBM072OmRXuqdJUQIMVKxkIvIaQAuA3Cw8u40Ed6skPcsAH63W79+7uvBy9KdE4GjWeDMvwHeGgYO/0qnLSpDoSsnR18H3nyseOVoIjXlb9sCPne5fx/80gRN6Y/JLuCkZaFfCyGk9sQRkd8G4CsA7o+hrfrhVy8l6jJ0b0QdJc/dOxjM6Zn4O0TIVUGnOJoi6jceBHasnnht2CjC+Zn8onZTtG4fJ4Q0HBUJuYhcBeB1pdQzIn7pG5PnrgGwBgDmzJlTyW3j5/FVOiNDJr6OM3v1b7/CT34rOp2E5bmbBoODQ3op/st3oKjYlYs8cHg7MO+/6nOdA4E3LdHOcb94k39dGG8fuaiHkKYiVMhFZBuA2Ya3bgRwA7StEopSagOADYDeIaiEPlaXx1c5FvRMZG9487cBd4TtF7E6yR8L3p7Nb7LxpGXAkefDPfPCODD2H7qAle3N+42lBwb1Cs1SVsc2apYQIaSIUCFXSi0xHReRBQDmAbCj8VMB/FpEFimlDsXay2oxmtUi50UVJpa9O3BG2M6IVRWKKxwCWsSfuBY4tM28iMZvsjGZBkZ/Hd53lQdevx/oePdUVF4YN9syCUvnlntruxTGg58aqr2ohzXLCYmFsq0VpdRuAO+zX4vIAQALmypr5ciwFrm8ZxUmCijKzPTmYdsR6xtbgN23GMRcaVF1WhtO/OyLVwZhzCs3kR8DXtwItw3jY/lMm1Us8mq8/E2dK4U1ywmJjeZZEFQN/Cbv5n5Wp/iFecR2xDp+xLNZsocD/2S+3mtfzDgbGJoBJDom2gqZ9BQpjuqtLuC95wG/2aEHKUDfe3o3IB3u/knH1LxALWnUImiENCmx/S9WSs2Nq62a4a05UsjphTZ29BzVI7YFee83gYN3Fb+vAiwMr33xJ69pfx3Q6Yy/PwC8dh/w5iPu6xJp4NxvAM/cCBTecdwrr4tbAcWrSr1PH5KYyhuvpcXRqEXQCGlS2jsiB4In9UrxiGdmgHP+O3DwRzCuynwr4sYKnY555RPmapF9xlB7/f1XA2d/GTj+b/5PDt7PcmYvsO/bU3ZKIQc8+VdT+e+1sjgauQgaIU0IhRyIb1JvZgZ47/nAbw2TlccPB1/rFxWbolfn4pw5PYA1Q2esnLQ0+HPM6dEZOZO+eK68/PdKYXojIbFCIY+bs74MPP654uPTZvlfEzTxN6MbyL/jPr8wpo97rxs/EiyGfvuIOqmVxcH0RkJig0IeN/NWAv/6HU+kmwD29gP5t4tti6CJPwAY2QIoj1WjFPAf+0qfMDTtI+rFVCVxZEu0iL9UWLOckFigkPtRyQTg0id0SuKetRNecGFqX0yv0PpN/GW/rDNPABSv8swDz98WXHPc1HfvPqL6IgBKZ7AkLLfF8XSfe0ejZ78GnP23TBMkpMGgkJuII8d5erf2snNvTR1z2hb2QKFyxRN/uaN6R54gfreneNGSs9aKt+925F+0Q5Ej2l+0Xj9RAPr859e5z1d5YN+3mCZISINBIfcSV45zUGZGUOVEe3VmWBGD5DTgDz+pV44G1Vqx+x5WHyY5zZ1TfmRYr1r1ovJ6ERSFnJCGgUJuY0fIv98fT46zX2YGUDxQ/G63jobFAt7eD+z5umc/TgMqDyz4mv6xbZQjw8Cr95r7HlYfxuuNz+g2CzkQPsgQQmoKhRxwR8iF8WIfudwcZ1NmhqnWtyS1iM9drgeUPV8vbuvd86cGmaB8cb+nAO/Akn9HC7LVaU7/m5kB5n3WUVDMxgreOIMQUnMo5CYrRSxA0jpVz7Ys7EnEUi0Fb2ZG2GIY0w73dk3xsAnYsPxs78ACBLdnr3A9cCeAhPbkz7qOtgohDYYob2pbDVi4cKHatWtXze9r5MBmYOca96SkNR34YB/wrtP15OHBoXhWPtpCHKXNSrJm4l5yzyqFhDQEIpJVSi30HmdE7hchnzRhH+xYHU9xJ9NuQCctC46uyxXNuPOzme9NSEOTCD+lxbHtCNOO80HFnUrBad/k3tK/Dw6VH+GOZvWTxGi29GsJIS0HI3KgtL0ry5n4jLPanx3ZA8XVGgkhbQmF3MZkH8RV3CmuAcFepOOsleK3cQUhpG2gkIcRR3GnuAaEZ9eaC169MhjcHicrCWlpKORRiGOyr9IBYTQLHNpqfk8sf5vGr9wAxZ2QloFCXksqGRDCStCabBpTjvzztwG/ew74t0cqT6nkYEBIQ0Ahj5tqiZvfEvtE2t9WMU2yFsaAkZ+6jz2/DkjNKK1MLTdPJqRhYPphnDzdB2xdrBcYbV2sX8eFN00ykQZOuRK47DF/AQ2rr2JTOKbL7kbtsymd8oUBpkMSUicYkcdFLXaGL2VbN2BK/J+/TUfiQdiWTZQ+c/NkQhoKCnlcVFvcSt3Wzea8fl1w64nVgPKIuXQYjkXoMzdPJqShqNhaEZFeEXleRJ4Tkf8VR6eakjhzxb2rNiu1MuatBM7+snv16rxrgAU3a4um1D4HrYYlhNSciiJyEfkkgKsAnKuUOi4i74unW01IHLnifhOIcUT7fumPuSPl9ZmbJxPSMFRU/VBEfgRgg1JqWynXNVT1w7gpN2tlNKsnG50ee7ILuHS7/tvvPVY3JKRt8Kt+WKm18gEAHxORHSLyCxH5UEAH1ojILhHZdfjw4Qpv28DMzOgNIkoVxLCou5pWRrl9JoQ0BKHWiohsAzDb8NaNE9fPBHARgA8B+JGInK4MYb5SagOADYCOyCvpdEsS5rHTyiCE+BAq5EqpJX7vicgXAdw7Idw7RaQA4EQALRxyV4koHjvrghNCDFSafngfgE8CeFREPgCgA8BvKu1U28KomxBSBpUK+Q8A/EBE9gAYA3CNyVYhJcComxBSIhUJuVJqDMDKmPpCCCGkDFhrhRBCmhwKOSGENDkUckIIaXIo5IQQ0uRUtES/7JuKHAbwSs1vPMWJaM40Sfa7trDftYX9Duf9SqlZ3oN1EfJ6IyK7TPUKGh32u7aw37WF/S4fWiuEENLkUMgJIaTJaVch31DvDpQJ+11b2O/awn6XSVt65IQQ0kq0a0ROCCEtA4WcEEKanLYV8mbeNFpErhcRJSIn1rsvURCRb05818+KyD+LyHvq3acgRGSZiLwgIi+KSF+9+xMFETlNRB4Vkb0T/6avq3efSkFEkiLylIj8tN59KQUReY+I3D3x73ufiFxcj360pZB7No3+IIBv1blLkRGR0wBcBuBgvftSAlsBzFdK/ScA/wrg7+rcH19EJAnguwD+M4BzAKwQkXPq26tI5ABcr5Q6B3rHrr9qkn7bXAdgX707UQbrADyolDoLwLmo02doSyEH8EUA/Uqp4wCglHqzzv0phdsAfAVA08xSK6UeUkrlJl4+AeDUevYnhEUAXlRK7Z8o07wZetBvaJRSbyilfj3x91vQgnJKfXsVDRE5FcDlAL5f776Ugoi8G8BiABsBXdZbKfW7evSlXYU88qbRjYSIXAXgdaXUM/XuSwVcC+Bf6t2JAE4B8Krj9WtoEkG0EZG5AM4HsKPOXYnK30MHJ4U696NU5kFva/mPE7bQ90XkXfXoSKU7BDUscW0aXWtC+n0DtK3ScAT1Wyl1/8Q5N0JbAHfWsm/thIicAOAeAH+jlDpS7/6EISJXAHhTKZUVkU/UuTulYgG4AECvUmqHiKwD0Afgpnp0pCVp1k2j/fotIgugI4BnRATQ9sSvRWSRUupQDbtoJOj7BgARWQXgCgCXNMKAGcDrAE5zvD514ljDIyIpaBG/Uyl1b737E5GPAPi0iPwxgDSAGSJyh1KqGXYeew3Aa0op+8nnbmghrzntaq3cB71pNJpl02il1G6l1PuUUnOVUnOh/xFd0AgiHoaILIN+dP60UupovfsTwpMAukVknoh0AFgO4Md17lMookf3jQD2KaVurXd/oqKU+jul1KkT/6aXA3ikSUQcE//3XhWRMycOXQJgbz360rIReQjcNLq2fAfANABbJ54mnlBK/WV9u2RGKZUTkb8GsAVAEsAPlFLP1blbUfgIgM8B2C0iT08cu0Ep9UD9utQW9AK4c2LQ3w/g8/XoBJfoE0JIk9Ou1gohhLQMFHJCCGlyKOSEENLkUMgJIaTJoZATQkiTQyEnhJAmh0JOCCFNzv8HlU6MaDVHsLIAAAAASUVORK5CYII=",
115 | "text/plain": [
116 | ""
117 | ]
118 | },
119 | "metadata": {
120 | "needs_background": "light"
121 | },
122 | "output_type": "display_data"
123 | }
124 | ],
125 | "source": [
126 | "data = test_tensor[-1,:].cpu().numpy()\n",
127 | "index = cluster_ids_stack[-1,:]\n",
128 | "cluster_centers = cluster_centers_stack[-1,:]\n",
129 | "print(data.shape)\n",
130 | "print(index.shape)\n",
131 | "print(cluster_centers.shape)\n",
132 | "color = ['orange', 'b', 'g', 'r', 'm', 'y', 'k','c'] * num_clusters\n",
133 | "for i in range(num_clusters):\n",
134 | " t_c = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))\n",
135 | " plt.scatter(data[index==i,0], data[index==i,1],marker='.',s=90,color=color[i])\n",
136 | " plt.scatter(data[int(cluster_centers[i]),0], data[int(cluster_centers[i]),1],marker='^',s=150,color=color[i])\n",
137 | "plt.show()"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": []
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": []
153 | }
154 | ],
155 | "metadata": {
156 | "interpreter": {
157 | "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
158 | },
159 | "kernelspec": {
160 | "display_name": "Python 3.6.8 64-bit",
161 | "name": "python3"
162 | },
163 | "language_info": {
164 | "codemirror_mode": {
165 | "name": "ipython",
166 | "version": 3
167 | },
168 | "file_extension": ".py",
169 | "mimetype": "text/x-python",
170 | "name": "python",
171 | "nbconvert_exporter": "python",
172 | "pygments_lexer": "ipython3",
173 | "version": "3.6.8"
174 | },
175 | "orig_nbformat": 4
176 | },
177 | "nbformat": 4,
178 | "nbformat_minor": 2
179 | }
180 |
--------------------------------------------------------------------------------
/config/SCRL_pretrain_default.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | SSL: SCRL
3 | Positive_Selection: cluster
4 | cluster: True
5 | cluster_num: 24
6 | soft_gamma: 0.5
7 | backbone: resnet50
8 | backbone_pretrain: ./pretrain/resnet50-19c8e357.pth
9 | fix_pred_lr: null
10 | SyncBatchNorm: False
11 | resume:
12 |
13 | MoCo:
14 | dim: 2048
15 | k: 65536
16 | m: 0.999
17 | t: 0.07
18 | mlp: True
19 | neighborhood_size: 8
20 | multi_positive: True
21 |
22 |
23 | data:
24 | name: movienet
25 | data_path: /tmpdata/compressed_shot_images
26 | shot_info: ./data/MovieNet_shot_num.json
27 | _T: train
28 | frame_size: 3
29 | clipshuffle: True
30 | clipshuffle_len: 16
31 | # aug_type: asymmetric # asymmetric or symmetry
32 | workers: 96
33 | fixed_aug_shot: True
34 | color_aug_for_q: False
35 | color_aug_for_k: True
36 |
37 |
38 | optim:
39 | epochs: 100
40 | bs: 1024
41 | momentum: 0.9
42 | optimizer: sgd
43 | lr: 0.03
44 | lr_cos: True
45 | schedule: # works when lr_cos is False
46 | - 50
47 | - 100
48 | - 150
49 | wd: 0.0001
50 | gradient_norm: -1 # off when <= 0
51 |
52 |
53 | log:
54 | dir: ./output/
55 | print_freq: 10
56 |
57 | DDP:
58 | multiprocessing_distributed: True
59 | machine_num: 1
60 | world_size: 8
61 | rank: 0
62 | dist_url: env://
63 | dist_backend: nccl
64 | seed: null
65 | gpu: null
66 | master_ip: localhost
67 | master_port: 10008
68 | node_num: 0
69 |
70 |
71 |
--------------------------------------------------------------------------------
/config/SCRL_pretrain_with_imagenet1k.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | SSL: SCRL
3 | Positive_Selection: cluster
4 | cluster: True
5 | cluster_num: 24
6 | soft_gamma: 0.5
7 | backbone: resnet50
8 | backbone_pretrain: ./pretrain/resnet50-19c8e357.pth
9 | fix_pred_lr: null
10 | SyncBatchNorm: False
11 | resume:
12 |
13 | MoCo:
14 | dim: 2048
15 | k: 65536
16 | m: 0.999
17 | t: 0.07
18 | mlp: True
19 | neighborhood_size: 8
20 | multi_positive: True
21 |
22 |
23 | data:
24 | name: movienet
25 | data_path: /tmpdata/compressed_shot_images
26 | shot_info: ./data/MovieNet_shot_num.json
27 | _T: train
28 | frame_size: 3
29 | clipshuffle: True
30 | clipshuffle_len: 16
31 | # aug_type: asymmetric # asymmetric or symmetry
32 | workers: 96
33 | fixed_aug_shot: True
34 | color_aug_for_q: False
35 | color_aug_for_k: True
36 |
37 |
38 | optim:
39 | epochs: 100
40 | bs: 1024
41 | momentum: 0.9
42 | optimizer: sgd
43 | lr: 0.03
44 | lr_cos: True
45 | schedule: # works when lr_cos is False
46 | - 50
47 | - 100
48 | - 150
49 | wd: 0.0001
50 | gradient_norm: -1 # off when <= 0
51 |
52 |
53 | log:
54 | dir: ./output/
55 | print_freq: 10
56 |
57 | DDP:
58 | multiprocessing_distributed: True
59 | machine_num: 1
60 | world_size: 8
61 | rank: 0
62 | dist_url: env://
63 | dist_backend: nccl
64 | seed: null
65 | gpu: null
66 | master_ip: localhost
67 | master_port: 10008
68 | node_num: 0
69 |
70 |
71 |
--------------------------------------------------------------------------------
/config/SCRL_pretrain_without_imagenet1k.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | SSL: SCRL
3 | Positive_Selection: cluster
4 | cluster: True
5 | cluster_num: 24
6 | soft_gamma: 0.5
7 | backbone: resnet50
8 | backbone_pretrain:
9 | fix_pred_lr: null
10 | SyncBatchNorm: False
11 | resume:
12 |
13 | MoCo:
14 | dim: 2048
15 | k: 65536
16 | m: 0.999
17 | t: 0.07
18 | mlp: True
19 | neighborhood_size: 8
20 | multi_positive: True
21 |
22 |
23 | data:
24 | name: movienet
25 | data_path: /tmpdata/compressed_shot_images
26 | shot_info: ./data/MovieNet_shot_num.json
27 | _T: train
28 | frame_size: 3
29 | clipshuffle: True
30 | clipshuffle_len: 16
31 | # aug_type: asymmetric # asymmetric or symmetry
32 | workers: 96
33 | fixed_aug_shot: True
34 | color_aug_for_q: False
35 | color_aug_for_k: True
36 |
37 |
38 | optim:
39 | epochs: 100
40 | bs: 1024
41 | momentum: 0.9
42 | optimizer: sgd
43 | lr: 0.06
44 | lr_cos: True
45 | schedule: # works when lr_cos is False
46 | - 50
47 | - 100
48 | - 150
49 | wd: 0.0001
50 | gradient_norm: -1 # off when <= 0
51 |
52 |
53 | log:
54 | dir: ./output/
55 | print_freq: 10
56 |
57 | DDP:
58 | multiprocessing_distributed: True
59 | machine_num: 1
60 | world_size: 8
61 | rank: 0
62 | dist_url: env://
63 | dist_backend: nccl
64 | seed: null
65 | gpu: null
66 | master_ip: localhost
67 | master_port: 10008
68 | node_num: 0
69 |
70 |
71 |
--------------------------------------------------------------------------------
/data/MovieNet_1.0_shotinfo.json:
--------------------------------------------------------------------------------
1 | {"train": {"0": 1374, "1": 284, "2": 793, "3": 260, "4": 654, "5": 311, "6": 1197, "7": 1437, "8": 1236, "9": 1208, "10": 978, "11": 706, "12": 333, "13": 871, "14": 442, "15": 1620, "16": 660, "17": 645, "18": 470, "19": 955, "20": 1271, "21": 2469, "22": 837, "23": 2423, "24": 384, "25": 724, "26": 1166, "27": 668, "28": 936, "29": 325, "30": 338, "31": 1144, "32": 1124, "33": 620, "34": 843, "35": 1079, "36": 668, "37": 852, "38": 529, "39": 675, "40": 762, "41": 747, "42": 888, "43": 974, "44": 726, "45": 567, "46": 897, "47": 1045, "48": 2935, "49": 903, "50": 1449, "51": 620, "52": 700, "53": 791, "54": 596, "55": 814, "56": 922, "57": 142, "58": 1445, "59": 774, "60": 1128, "61": 1526, "62": 866, "63": 969, "64": 682, "65": 1646, "66": 756, "67": 1489, "68": 1381, "69": 1849, "70": 378, "71": 1279, "72": 1022, "73": 1017, "74": 793, "75": 1560, "76": 1311, "77": 2268, "78": 825, "79": 2016, "80": 1885, "81": 1115, "82": 70, "83": 1363, "84": 1191, "85": 1258, "86": 1310, "87": 2932, "88": 1365, "89": 1976, "90": 508, "91": 194, "92": 857, "93": 1474, "94": 738, "95": 1079, "96": 811, "97": 970, "98": 1108, "99": 1449, "100": 1178, "101": 832, "102": 1817, "103": 1229, "104": 2130, "105": 630, "106": 1874, "107": 1420, "108": 1724, "109": 827, "110": 467, "111": 1258, "112": 876, "113": 713, "114": 591, "115": 2476, "116": 863, "117": 1465, "118": 1386, "119": 1263, "120": 1164, "121": 1244, "122": 380, "123": 1891, "124": 960, "125": 300, "126": 1604, "127": 1756, "128": 868, "129": 1110, "130": 1133, "131": 1064, "132": 1158, "133": 2778, "134": 1299, "135": 1306, "136": 799, "137": 1488, "138": 869, "139": 1082, "140": 1532, "141": 1616, "142": 1147, "143": 1601, "144": 291, "145": 944, "146": 266, "147": 1760, "148": 946, "149": 1142, "150": 1713, "151": 1123, "152": 1031, "153": 992, "154": 1082, "155": 1170, "156": 1145, "157": 1412, "158": 1875, "159": 2196, "160": 1157, "161": 1227, "162": 675, "163": 1086, "164": 1293, "165": 1184, "166": 1706, "167": 1564, "168": 1610, "169": 1238, "170": 1131, "171": 1159, "172": 765, "173": 1759, "174": 2165, "175": 2540, "176": 2001, "177": 3034, "178": 1923, "179": 370, "180": 1430, "181": 1417, "182": 1217, "183": 1486, "184": 1965, "185": 2166, "186": 1308, "187": 919, "188": 1142, "189": 659, "190": 1704, "191": 2522, "192": 2180, "193": 1473, "194": 1799, "195": 1091, "196": 488, "197": 2244, "198": 1202, "199": 1237, "200": 1101, "201": 1555, "202": 1300, "203": 1383, "204": 1049, "205": 24, "206": 2890, "207": 1621, "208": 2368, "209": 2233, "210": 2023, "211": 3459, "212": 2298, "213": 1048, "214": 2021, "215": 1099, "216": 3902, "217": 1699, "218": 2776, "219": 1348, "220": 2103, "221": 1391, "222": 1257, "223": 1698, "224": 1442, "225": 2259, "226": 2231, "227": 1877, "228": 1462, "229": 1652, "230": 941, "231": 848, "232": 1612, "233": 1093, "234": 2223, "235": 1515, "236": 1141, "237": 1872, "238": 1021, "239": 2520, "240": 963, "241": 1330, "242": 2104, "243": 1546, "244": 1536, "245": 1112, "246": 2198, "247": 1557, "248": 2394, "249": 4043, "250": 973, "251": 1037, "252": 1245, "253": 893, "254": 1320, "255": 1786, "256": 1502, "257": 2445, "258": 901, "259": 979, "260": 2528, "261": 2585, "262": 1101, "263": 1395, "264": 704, "265": 1575, "266": 3117, "267": 444, "268": 1818, "269": 1481, "270": 1847, "271": 236, "272": 793, "273": 1754, "274": 3582, "275": 1189, "276": 2366, "277": 2724, "278": 1182, "279": 2588, "280": 1436, "281": 1694, "282": 2643, "283": 1362, "284": 1278, "285": 1431, "286": 1655, "287": 1815, "288": 2043, "289": 1961, "290": 2099, "291": 781, "292": 1283, "293": 1378, "294": 1267, "295": 1437, "296": 1237, "297": 2424, "298": 821, "299": 639, "300": 2081, "301": 2564, "302": 1405, "303": 2335, "304": 427, "305": 1496, "306": 2712, "307": 1383, "308": 2087, "309": 1251, "310": 3161, "311": 1798, "312": 2147, "313": 3003, "314": 2670, "315": 3370, "316": 1662, "317": 1229, "318": 1833, "319": 902, "320": 3354, "321": 1317, "322": 2782, "323": 1601, "324": 245, "325": 1036, "326": 1935, "327": 1534, "328": 1243, "329": 1521, "330": 977, "331": 116, "332": 2973, "333": 1250, "334": 1768, "335": 954, "336": 1676, "337": 2300, "338": 3001, "339": 974, "340": 1355, "341": 875, "342": 1332, "343": 1907, "344": 924, "345": 1232, "346": 1011, "347": 2778, "348": 1513, "349": 1480, "350": 2600, "351": 1790, "352": 1911, "353": 955, "354": 1040, "355": 3265, "356": 1285, "357": 1716, "358": 1866, "359": 775, "360": 1724, "361": 1481, "362": 2662, "363": 1263, "364": 1177, "365": 649, "366": 1185, "367": 1079, "368": 1886, "369": 749, "370": 1431, "371": 2963, "372": 1531, "373": 1979, "374": 1703, "375": 1162, "376": 1360, "377": 2444, "378": 1468, "379": 1629, "380": 2229, "381": 1770, "382": 2237, "383": 1678, "384": 1536, "385": 1205, "386": 1804, "387": 2050, "388": 2972, "389": 1641, "390": 2144, "391": 1686, "392": 1241, "393": 1078, "394": 1919, "395": 972, "396": 1364, "397": 1098, "398": 1133, "399": 2578, "400": 1753, "401": 1782, "402": 1326, "403": 1550, "404": 1851, "405": 2321, "406": 2278, "407": 2724, "408": 1948, "409": 528, "410": 2545, "411": 913, "412": 824, "413": 1219, "414": 1468, "415": 2270, "416": 1973, "417": 1297, "418": 1514, "419": 1739, "420": 134, "421": 1664, "422": 1495, "423": 1766, "424": 2873, "425": 1155, "426": 1088, "427": 855, "428": 731, "429": 2365, "430": 1568, "431": 1246, "432": 922, "433": 1280, "434": 2593, "435": 1477, "436": 1571, "437": 1200, "438": 1261, "439": 2174, "440": 2058, "441": 424, "442": 1558, "443": 2769, "444": 2360, "445": 2205, "446": 895, "447": 1126, "448": 910, "449": 2115, "450": 1016, "451": 1706, "452": 1242, "453": 1037, "454": 1670, "455": 1037, "456": 1513, "457": 2646, "458": 1795, "459": 2514, "460": 646, "461": 1359, "462": 1544, "463": 2126, "464": 1197, "465": 1878, "466": 874, "467": 938, "468": 2540, "469": 2227, "470": 844, "471": 774, "472": 2591, "473": 1225, "474": 726, "475": 1101, "476": 1420, "477": 278, "478": 2135, "479": 2439, "480": 1608, "481": 2708, "482": 1533, "483": 2620, "484": 1486, "485": 1371, "486": 2348, "487": 2527, "488": 1129, "489": 251, "490": 1151, "491": 1462, "492": 478, "493": 1156, "494": 1762, "495": 1664, "496": 2204, "497": 1791, "498": 2190, "499": 2491, "500": 1650, "501": 943, "502": 1191, "503": 1304, "504": 2010, "505": 1013, "506": 2841, "507": 1142, "508": 3109, "509": 2011, "510": 1332, "511": 2157, "512": 1532, "513": 957, "514": 1458, "515": 737, "516": 2179, "517": 1585, "518": 370, "519": 1841, "520": 905, "521": 973, "522": 1969, "523": 1151, "524": 1395, "525": 3552, "526": 1078, "527": 565, "528": 925, "529": 1929, "530": 2134, "531": 233, "532": 1895, "533": 896, "534": 1568, "535": 1935, "536": 1939, "537": 1726, "538": 1853, "539": 1490, "540": 2992, "541": 1373, "542": 699, "543": 691, "544": 330, "545": 1272, "546": 1492, "547": 1408, "548": 1434, "549": 1170, "550": 1289, "551": 1526, "552": 1775, "553": 3052, "554": 1008, "555": 1022, "556": 1425, "557": 2363, "558": 2122, "559": 1125, "560": 934, "561": 1192, "562": 830, "563": 2369, "564": 792, "565": 1511, "566": 1553, "567": 906, "568": 349, "569": 2097, "570": 1758, "571": 1223, "572": 2680, "573": 105, "574": 2475, "575": 2566, "576": 2932, "577": 454, "578": 892, "579": 976, "580": 2808, "581": 1066, "582": 2825, "583": 1751, "584": 704, "585": 2231, "586": 1139, "587": 502, "588": 1108, "589": 1845, "590": 1516, "591": 1772, "592": 2321, "593": 897, "594": 756, "595": 1112, "596": 1239, "597": 1372, "598": 3007, "599": 1603, "600": 1080, "601": 1507, "602": 294, "603": 1229, "604": 835, "605": 394, "606": 1520, "607": 775, "608": 1383, "609": 39, "610": 906, "611": 1610, "612": 2957, "613": 2102, "614": 1260, "615": 275, "616": 964, "617": 1113, "618": 1600, "619": 616, "620": 779, "621": 751, "622": 1288, "623": 173, "624": 970, "625": 2015, "626": 1932, "627": 1498, "628": 1271, "629": 1918, "630": 1635, "631": 2090, "632": 381, "633": 1223, "634": 1808, "635": 1412, "636": 269, "637": 1165, "638": 1177, "639": 1592, "640": 2262, "641": 1235, "642": 984, "643": 1453, "644": 699, "645": 1567, "646": 1452, "647": 1180, "648": 690, "649": 2349, "650": 1599, "651": 1887, "652": 1586, "653": 659, "654": 1613, "655": 2072, "656": 1750, "657": 2457, "658": 3565, "659": 1872}, "val": {"0": 672, "1": 973, "2": 837, "3": 969, "4": 999, "5": 662, "6": 449, "7": 1011, "8": 995, "9": 1245, "10": 755, "11": 666, "12": 1445, "13": 144, "14": 1512, "15": 1575, "16": 667, "17": 2210, "18": 370, "19": 885, "20": 578, "21": 1770, "22": 1192, "23": 1181, "24": 833, "25": 817, "26": 1233, "27": 384, "28": 832, "29": 896, "30": 455, "31": 1415, "32": 1191, "33": 1215, "34": 1297, "35": 1201, "36": 1328, "37": 1222, "38": 1147, "39": 2386, "40": 1610, "41": 1248, "42": 471, "43": 1778, "44": 1041, "45": 2568, "46": 1159, "47": 1009, "48": 1094, "49": 2073, "50": 645, "51": 609, "52": 1074, "53": 1979, "54": 1495, "55": 707, "56": 1583, "57": 405, "58": 1326, "59": 1647, "60": 1928, "61": 1832, "62": 2114, "63": 862, "64": 2806, "65": 2039, "66": 1533, "67": 2034, "68": 1412, "69": 1746, "70": 931, "71": 3014, "72": 2363, "73": 1819, "74": 1738, "75": 1549, "76": 1346, "77": 1599, "78": 1210, "79": 2029, "80": 1195, "81": 1182, "82": 2160, "83": 1787, "84": 1877, "85": 2031, "86": 752, "87": 1295, "88": 3721, "89": 672, "90": 1354, "91": 1913, "92": 1595, "93": 389, "94": 1000, "95": 2045, "96": 1747, "97": 1738, "98": 833, "99": 1458, "100": 1577, "101": 1708, "102": 862, "103": 1654, "104": 2359, "105": 1005, "106": 1633, "107": 2162, "108": 1395, "109": 2470, "110": 384, "111": 1242, "112": 1208, "113": 2101, "114": 1527, "115": 2232, "116": 1701, "117": 819, "118": 1384, "119": 925, "120": 1678, "121": 515, "122": 1851, "123": 2288, "124": 1641, "125": 1580, "126": 1180, "127": 1109, "128": 361, "129": 1388, "130": 660, "131": 893, "132": 1410, "133": 1663, "134": 1965, "135": 1478, "136": 1650, "137": 2073, "138": 2232, "139": 1277, "140": 1855, "141": 1072, "142": 944, "143": 1437, "144": 1716, "145": 2538, "146": 1144, "147": 1382, "148": 1664, "149": 1895, "150": 2057, "151": 1855, "152": 1714, "153": 1271, "154": 1120, "155": 1433, "156": 1182, "157": 327, "158": 1082, "159": 2962, "160": 1381, "161": 670, "162": 2313, "163": 1881, "164": 1712, "165": 1263, "166": 1363, "167": 1791, "168": 1408, "169": 1854, "170": 1308, "171": 1137, "172": 844, "173": 2051, "174": 1339, "175": 1687, "176": 643, "177": 1159, "178": 507, "179": 1804, "180": 1447, "181": 758, "182": 1423, "183": 1352, "184": 2007, "185": 1089, "186": 2516, "187": 1068, "188": 819, "189": 1408, "190": 199, "191": 1659, "192": 3241, "193": 598, "194": 1286, "195": 2756, "196": 1441, "197": 2038, "198": 1795, "199": 2044, "200": 2497, "201": 812, "202": 1882, "203": 479, "204": 695, "205": 265, "206": 1461, "207": 2226, "208": 516, "209": 915, "210": 1080, "211": 1316, "212": 1155, "213": 1526, "214": 1797, "215": 1464, "216": 903, "217": 1130, "218": 2684, "219": 1415}, "test": {"0": 667, "1": 996, "2": 1089, "3": 423, "4": 1101, "5": 1521, "6": 820, "7": 1144, "8": 1030, "9": 1250, "10": 793, "11": 1354, "12": 1244, "13": 1192, "14": 1080, "15": 889, "16": 397, "17": 1589, "18": 798, "19": 2202, "20": 549, "21": 607, "22": 1416, "23": 1083, "24": 610, "25": 398, "26": 2178, "27": 1425, "28": 698, "29": 1851, "30": 1843, "31": 1555, "32": 973, "33": 549, "34": 981, "35": 1875, "36": 1582, "37": 903, "38": 1746, "39": 1138, "40": 2430, "41": 1132, "42": 1779, "43": 2151, "44": 285, "45": 1277, "46": 1458, "47": 1328, "48": 427, "49": 1941, "50": 1001, "51": 1127, "52": 1784, "53": 1507, "54": 855, "55": 1684, "56": 3096, "57": 1022, "58": 844, "59": 2000, "60": 1302, "61": 817, "62": 2034, "63": 1684, "64": 2346, "65": 760, "66": 388, "67": 1212, "68": 1537, "69": 1463, "70": 2032, "71": 1070, "72": 1981, "73": 968, "74": 881, "75": 1780, "76": 473, "77": 1397, "78": 2398, "79": 1699, "80": 2243, "81": 1661, "82": 1151, "83": 2147, "84": 1063, "85": 1177, "86": 1455, "87": 2466, "88": 448, "89": 2357, "90": 1566, "91": 1471, "92": 2609, "93": 665, "94": 1646, "95": 1728, "96": 1828, "97": 1805, "98": 1238, "99": 1212, "100": 758, "101": 1441, "102": 1189, "103": 1119, "104": 1838, "105": 1978, "106": 798, "107": 1448, "108": 2431, "109": 1345, "110": 1398, "111": 2017, "112": 1115, "113": 1660, "114": 2135, "115": 2357, "116": 1828, "117": 258, "118": 2711, "119": 1087, "120": 2766, "121": 1449, "122": 2761, "123": 919, "124": 1030, "125": 1395, "126": 2532, "127": 1597, "128": 2476, "129": 1792, "130": 530, "131": 1880, "132": 1853, "133": 1828, "134": 1117, "135": 1365, "136": 1027, "137": 2257, "138": 2098, "139": 1318, "140": 1550, "141": 771, "142": 2014, "143": 2967, "144": 1361, "145": 1964, "146": 1443, "147": 1979, "148": 1288, "149": 1195, "150": 1287, "151": 1127, "152": 789, "153": 1443, "154": 2013, "155": 1955, "156": 1626, "157": 2092, "158": 1791, "159": 1764, "160": 1806, "161": 652, "162": 1551, "163": 2489, "164": 1200, "165": 1494, "166": 1449, "167": 1381, "168": 1413, "169": 2335, "170": 397, "171": 1851, "172": 1816, "173": 2610, "174": 2454, "175": 2585, "176": 1256, "177": 2873, "178": 1488, "179": 2359, "180": 1586, "181": 2040, "182": 767, "183": 1975, "184": 1910, "185": 1203, "186": 2218, "187": 1681, "188": 1621, "189": 2666, "190": 2549, "191": 885, "192": 319, "193": 1795, "194": 2884, "195": 405, "196": 728, "197": 542, "198": 1047, "199": 1605, "200": 631, "201": 635, "202": 1096, "203": 2396, "204": 1173, "205": 1874, "206": 1098, "207": 2455, "208": 2569, "209": 1628, "210": 2425, "211": 774, "212": 1237, "213": 1460, "214": 1809, "215": 792, "216": 1400, "217": 1083, "218": 2030, "219": 784}}
--------------------------------------------------------------------------------
/data/MovieNet_shot_num.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "tt0035423": 1374,
4 | "tt0045537": 284,
5 | "tt0047396": 793,
6 | "tt0048605": 260,
7 | "tt0049730": 654,
8 | "tt0050706": 311,
9 | "tt0053125": 1197,
10 | "tt0056869": 1437,
11 | "tt0056923": 1236,
12 | "tt0057115": 1208,
13 | "tt0058461": 978,
14 | "tt0059043": 706,
15 | "tt0059592": 333,
16 | "tt0060522": 871,
17 | "tt0061138": 442,
18 | "tt0061418": 1620,
19 | "tt0061781": 660,
20 | "tt0062622": 645,
21 | "tt0064040": 470,
22 | "tt0064276": 955,
23 | "tt0064665": 1271,
24 | "tt0065214": 2469,
25 | "tt0065724": 837,
26 | "tt0065988": 2423,
27 | "tt0066249": 384,
28 | "tt0066921": 724,
29 | "tt0067116": 1166,
30 | "tt0067185": 668,
31 | "tt0068935": 936,
32 | "tt0069293": 325,
33 | "tt0069467": 338,
34 | "tt0069995": 1144,
35 | "tt0070047": 1124,
36 | "tt0070246": 620,
37 | "tt0070379": 843,
38 | "tt0070735": 1079,
39 | "tt0070849": 668,
40 | "tt0071129": 852,
41 | "tt0071315": 529,
42 | "tt0071360": 675,
43 | "tt0072684": 762,
44 | "tt0073440": 747,
45 | "tt0074119": 888,
46 | "tt0074285": 974,
47 | "tt0074686": 726,
48 | "tt0074749": 567,
49 | "tt0075148": 897,
50 | "tt0076729": 1045,
51 | "tt0077402": 2935,
52 | "tt0077405": 903,
53 | "tt0077416": 1449,
54 | "tt0077651": 620,
55 | "tt0078841": 700,
56 | "tt0078908": 791,
57 | "tt0079095": 596,
58 | "tt0079116": 814,
59 | "tt0079417": 922,
60 | "tt0079944": 142,
61 | "tt0079945": 1445,
62 | "tt0080339": 774,
63 | "tt0080453": 1128,
64 | "tt0080745": 1526,
65 | "tt0080958": 866,
66 | "tt0080979": 969,
67 | "tt0081505": 682,
68 | "tt0082186": 1646,
69 | "tt0082846": 756,
70 | "tt0082971": 1489,
71 | "tt0083658": 1381,
72 | "tt0083987": 1849,
73 | "tt0084549": 378,
74 | "tt0084628": 1279,
75 | "tt0084726": 1022,
76 | "tt0084787": 1017,
77 | "tt0085794": 793,
78 | "tt0086250": 1560,
79 | "tt0086837": 1311,
80 | "tt0086879": 2268,
81 | "tt0086969": 825,
82 | "tt0087182": 2016,
83 | "tt0087469": 1885,
84 | "tt0088170": 1115,
85 | "tt0088222": 70,
86 | "tt0088847": 1363,
87 | "tt0088939": 1191,
88 | "tt0088993": 1258,
89 | "tt0090022": 1310,
90 | "tt0090605": 2932,
91 | "tt0091042": 1365,
92 | "tt0091203": 1976,
93 | "tt0091251": 508,
94 | "tt0091406": 194,
95 | "tt0091738": 857,
96 | "tt0091763": 1474,
97 | "tt0092603": 738,
98 | "tt0093010": 1079,
99 | "tt0093209": 811,
100 | "tt0093565": 970,
101 | "tt0093748": 1108,
102 | "tt0093779": 1449,
103 | "tt0094226": 1178,
104 | "tt0094291": 832,
105 | "tt0095016": 1817,
106 | "tt0095497": 1229,
107 | "tt0095956": 2130,
108 | "tt0096463": 630,
109 | "tt0096754": 1874,
110 | "tt0096874": 1420,
111 | "tt0096895": 1724,
112 | "tt0097216": 827,
113 | "tt0097372": 467,
114 | "tt0097428": 1258,
115 | "tt0098258": 876,
116 | "tt0098635": 713,
117 | "tt0098724": 591,
118 | "tt0099348": 2476,
119 | "tt0099487": 863,
120 | "tt0099653": 1465,
121 | "tt0099674": 1386,
122 | "tt0099810": 1263,
123 | "tt0100150": 1164,
124 | "tt0100157": 1244,
125 | "tt0100234": 380,
126 | "tt0100802": 1891,
127 | "tt0100935": 960,
128 | "tt0100998": 300,
129 | "tt0101272": 1604,
130 | "tt0101393": 1756,
131 | "tt0101700": 868,
132 | "tt0101889": 1110,
133 | "tt0101921": 1133,
134 | "tt0102492": 1064,
135 | "tt0102926": 1158,
136 | "tt0103064": 2778,
137 | "tt0103772": 1299,
138 | "tt0103786": 1306,
139 | "tt0104036": 799,
140 | "tt0104257": 1488,
141 | "tt0104348": 869,
142 | "tt0105226": 1082,
143 | "tt0105665": 1532,
144 | "tt0105695": 1616,
145 | "tt0106332": 1147,
146 | "tt0106582": 1601,
147 | "tt0107507": 291,
148 | "tt0107653": 944,
149 | "tt0107736": 266,
150 | "tt0107808": 1760,
151 | "tt0107822": 946,
152 | "tt0108122": 1142,
153 | "tt0108289": 1713,
154 | "tt0108330": 1123,
155 | "tt0109686": 1031,
156 | "tt0109830": 992,
157 | "tt0109831": 1082,
158 | "tt0110074": 1170,
159 | "tt0110148": 1145,
160 | "tt0110201": 1412,
161 | "tt0110322": 1875,
162 | "tt0110632": 2196,
163 | "tt0110912": 1157,
164 | "tt0111003": 1227,
165 | "tt0112769": 675,
166 | "tt0113101": 1086,
167 | "tt0113243": 1293,
168 | "tt0113253": 1184,
169 | "tt0113497": 1706,
170 | "tt0114367": 1564,
171 | "tt0114369": 1610,
172 | "tt0114388": 1238,
173 | "tt0114814": 1131,
174 | "tt0115798": 1159,
175 | "tt0115964": 765,
176 | "tt0116209": 1759,
177 | "tt0116477": 2165,
178 | "tt0116629": 2540,
179 | "tt0116695": 2001,
180 | "tt0117500": 3034,
181 | "tt0117509": 1923,
182 | "tt0117666": 370,
183 | "tt0117731": 1430,
184 | "tt0117883": 1417,
185 | "tt0117951": 1217,
186 | "tt0118636": 1486,
187 | "tt0118655": 1965,
188 | "tt0118688": 2166,
189 | "tt0118689": 1308,
190 | "tt0118749": 919,
191 | "tt0118799": 1142,
192 | "tt0118845": 659,
193 | "tt0118883": 1704,
194 | "tt0118929": 2522,
195 | "tt0118971": 2180,
196 | "tt0119008": 1473,
197 | "tt0119081": 1799,
198 | "tt0119177": 1091,
199 | "tt0119250": 488,
200 | "tt0119314": 2244,
201 | "tt0119396": 1202,
202 | "tt0119528": 1237,
203 | "tt0119567": 1101,
204 | "tt0119643": 1555,
205 | "tt0119654": 1300,
206 | "tt0119670": 1383,
207 | "tt0119738": 1049,
208 | "tt0120263": 24,
209 | "tt0120338": 2890,
210 | "tt0120586": 1621,
211 | "tt0120591": 2368,
212 | "tt0120616": 2233,
213 | "tt0120655": 2023,
214 | "tt0120660": 3459,
215 | "tt0120667": 2298,
216 | "tt0120669": 1048,
217 | "tt0120696": 2021,
218 | "tt0120735": 1099,
219 | "tt0120737": 3902,
220 | "tt0120744": 1699,
221 | "tt0120755": 2776,
222 | "tt0120787": 1348,
223 | "tt0120804": 2103,
224 | "tt0120815": 1391,
225 | "tt0120885": 1257,
226 | "tt0120902": 1698,
227 | "tt0120912": 1442,
228 | "tt0120915": 2259,
229 | "tt0121766": 2231,
230 | "tt0122690": 1877,
231 | "tt0125439": 1462,
232 | "tt0125664": 1652,
233 | "tt0126886": 941,
234 | "tt0128445": 848,
235 | "tt0134119": 1612,
236 | "tt0134273": 1093,
237 | "tt0134847": 2223,
238 | "tt0137494": 1515,
239 | "tt0139134": 1141,
240 | "tt0139654": 1872,
241 | "tt0142688": 1021,
242 | "tt0143145": 2520,
243 | "tt0144084": 963,
244 | "tt0144117": 1330,
245 | "tt0159365": 2104,
246 | "tt0159784": 1546,
247 | "tt0160127": 1536,
248 | "tt0162346": 1112,
249 | "tt0162661": 2198,
250 | "tt0164052": 1557,
251 | "tt0167190": 2394,
252 | "tt0167260": 4043,
253 | "tt0167331": 973,
254 | "tt0169547": 1037,
255 | "tt0171363": 1245,
256 | "tt0175880": 893,
257 | "tt0180073": 1320,
258 | "tt0180093": 1786,
259 | "tt0181689": 1502,
260 | "tt0181875": 2445,
261 | "tt0183523": 901,
262 | "tt0183649": 979,
263 | "tt0187078": 2528,
264 | "tt0187393": 2585,
265 | "tt0190590": 1101,
266 | "tt0195685": 1395,
267 | "tt0199354": 704,
268 | "tt0199753": 1575,
269 | "tt0203009": 3117,
270 | "tt0206634": 444,
271 | "tt0207201": 1818,
272 | "tt0208092": 1481,
273 | "tt0209144": 1847,
274 | "tt0209463": 236,
275 | "tt0210727": 793,
276 | "tt0212338": 1754,
277 | "tt0213149": 3582,
278 | "tt0227445": 1189,
279 | "tt0232500": 2366,
280 | "tt0234215": 2724,
281 | "tt0240772": 1182,
282 | "tt0242653": 2588,
283 | "tt0243876": 1436,
284 | "tt0244244": 1694,
285 | "tt0244353": 2643,
286 | "tt0245844": 1362,
287 | "tt0246578": 1278,
288 | "tt0250494": 1431,
289 | "tt0250797": 1655,
290 | "tt0251160": 1815,
291 | "tt0253754": 2043,
292 | "tt0258000": 1961,
293 | "tt0264395": 2099,
294 | "tt0264616": 781,
295 | "tt0266697": 1283,
296 | "tt0266915": 1378,
297 | "tt0268695": 1267,
298 | "tt0272152": 1437,
299 | "tt0275719": 1237,
300 | "tt0278504": 2424,
301 | "tt0283509": 821,
302 | "tt0286106": 639,
303 | "tt0288477": 2081,
304 | "tt0290334": 2564,
305 | "tt0294870": 1405,
306 | "tt0299658": 2335,
307 | "tt0308476": 427,
308 | "tt0309698": 1496,
309 | "tt0313542": 2712,
310 | "tt0315327": 1383,
311 | "tt0316654": 2087,
312 | "tt0317198": 1251,
313 | "tt0317919": 3161,
314 | "tt0318627": 1798,
315 | "tt0318974": 2147,
316 | "tt0325710": 3003,
317 | "tt0325980": 2670,
318 | "tt0328107": 3370,
319 | "tt0329101": 1662,
320 | "tt0331811": 1229,
321 | "tt0332452": 1833,
322 | "tt0335266": 902,
323 | "tt0337978": 3354,
324 | "tt0338013": 1317,
325 | "tt0338751": 2782,
326 | "tt0343660": 1601,
327 | "tt0346094": 245,
328 | "tt0349903": 1036,
329 | "tt0350258": 1935,
330 | "tt0351977": 1534,
331 | "tt0357413": 1243,
332 | "tt0359950": 1521,
333 | "tt0362227": 977,
334 | "tt0363589": 116,
335 | "tt0363771": 2973,
336 | "tt0365907": 1250,
337 | "tt0369339": 1768,
338 | "tt0369702": 954,
339 | "tt0371257": 1676,
340 | "tt0372183": 2300,
341 | "tt0372784": 3001,
342 | "tt0372824": 974,
343 | "tt0373074": 1355,
344 | "tt0374546": 875,
345 | "tt0375679": 1332,
346 | "tt0376994": 1907,
347 | "tt0377713": 924,
348 | "tt0378194": 1232,
349 | "tt0379306": 1011,
350 | "tt0382625": 2778,
351 | "tt0383028": 1513,
352 | "tt0383216": 1480,
353 | "tt0383574": 2600,
354 | "tt0385004": 1790,
355 | "tt0387564": 1911,
356 | "tt0387877": 955,
357 | "tt0388795": 1040,
358 | "tt0390022": 3265,
359 | "tt0393109": 1285,
360 | "tt0395699": 1716,
361 | "tt0397078": 1866,
362 | "tt0398027": 775,
363 | "tt0399295": 1724,
364 | "tt0405159": 1481,
365 | "tt0407887": 2662,
366 | "tt0408306": 1263,
367 | "tt0408790": 1177,
368 | "tt0413893": 649,
369 | "tt0414055": 1185,
370 | "tt0414387": 1079,
371 | "tt0414982": 1886,
372 | "tt0415380": 749,
373 | "tt0417741": 1431,
374 | "tt0418279": 2963,
375 | "tt0418819": 1531,
376 | "tt0419887": 1979,
377 | "tt0420223": 1703,
378 | "tt0421715": 1162,
379 | "tt0424345": 1360,
380 | "tt0425061": 2444,
381 | "tt0425210": 1468,
382 | "tt0427309": 1629,
383 | "tt0427954": 2229,
384 | "tt0430357": 1770,
385 | "tt0433035": 2237,
386 | "tt0435705": 1678,
387 | "tt0439815": 1536,
388 | "tt0443453": 1205,
389 | "tt0443706": 1804,
390 | "tt0448157": 2050,
391 | "tt0449088": 2972,
392 | "tt0450259": 1641,
393 | "tt0450385": 2144,
394 | "tt0454841": 1686,
395 | "tt0454876": 1241,
396 | "tt0454921": 1078,
397 | "tt0457939": 1919,
398 | "tt0458413": 972,
399 | "tt0462200": 1364,
400 | "tt0467406": 1098,
401 | "tt0468565": 1133,
402 | "tt0468569": 2578,
403 | "tt0473705": 1753,
404 | "tt0475293": 1782,
405 | "tt0477348": 1326,
406 | "tt0479997": 1550,
407 | "tt0489018": 1851,
408 | "tt0493464": 2321,
409 | "tt0499448": 2278,
410 | "tt0499549": 2724,
411 | "tt0758774": 1948,
412 | "tt0765128": 528,
413 | "tt0765429": 2545,
414 | "tt0765447": 913,
415 | "tt0780504": 824,
416 | "tt0790636": 1219,
417 | "tt0790686": 1468,
418 | "tt0796366": 2270,
419 | "tt0800320": 1973,
420 | "tt0810819": 1297,
421 | "tt0815236": 1514,
422 | "tt0824747": 1739,
423 | "tt0826711": 134,
424 | "tt0829482": 1664,
425 | "tt0844286": 1495,
426 | "tt0846308": 1766,
427 | "tt0848228": 2873,
428 | "tt0862846": 1155,
429 | "tt0887883": 1088,
430 | "tt0913425": 855,
431 | "tt0914798": 731,
432 | "tt0942385": 2365,
433 | "tt0947798": 1568,
434 | "tt0958860": 1246,
435 | "tt0959337": 922,
436 | "tt0963794": 1280,
437 | "tt0963966": 2593,
438 | "tt0970416": 1477,
439 | "tt0974661": 1571,
440 | "tt0975645": 1200,
441 | "tt0977855": 1261,
442 | "tt0985694": 2174,
443 | "tt0985699": 2058,
444 | "tt0986233": 424,
445 | "tt0986263": 1558,
446 | "tt0988045": 2769,
447 | "tt0993846": 2360,
448 | "tt1010048": 2205,
449 | "tt1013753": 895,
450 | "tt1016268": 1126,
451 | "tt1022603": 910,
452 | "tt1024648": 2115,
453 | "tt1027718": 1016,
454 | "tt1029234": 1706,
455 | "tt1029360": 1242,
456 | "tt1037705": 1037,
457 | "tt1045658": 1670,
458 | "tt1045772": 1037,
459 | "tt1054606": 1513,
460 | "tt1055369": 2646,
461 | "tt1057500": 1795,
462 | "tt1059786": 2514,
463 | "tt1068649": 646,
464 | "tt1068680": 1359,
465 | "tt1072748": 1544,
466 | "tt1074638": 2126,
467 | "tt1084950": 1197,
468 | "tt1104001": 1878,
469 | "tt1124035": 874,
470 | "tt1125849": 938,
471 | "tt1131729": 2540,
472 | "tt1133985": 2227,
473 | "tt1135952": 844,
474 | "tt1139797": 774,
475 | "tt1148204": 2591,
476 | "tt1156466": 1225,
477 | "tt1158278": 726,
478 | "tt1174732": 1101,
479 | "tt1179031": 1420,
480 | "tt1179904": 278,
481 | "tt1186367": 2135,
482 | "tt1188729": 2439,
483 | "tt1193138": 1608,
484 | "tt1194173": 2708,
485 | "tt1210166": 1533,
486 | "tt1217613": 2620,
487 | "tt1219289": 1486,
488 | "tt1220719": 1371,
489 | "tt1228705": 2348,
490 | "tt1229340": 2527,
491 | "tt1229822": 1129,
492 | "tt1233381": 251,
493 | "tt1244754": 1151,
494 | "tt1253863": 1462,
495 | "tt1255953": 478,
496 | "tt1274586": 1156,
497 | "tt1276104": 1762,
498 | "tt1282140": 1664,
499 | "tt1285016": 2204,
500 | "tt1291150": 1791,
501 | "tt1291584": 2190,
502 | "tt1298650": 2491,
503 | "tt1300851": 1650,
504 | "tt1305806": 943,
505 | "tt1306980": 1191,
506 | "tt1322269": 1304,
507 | "tt1324999": 2010,
508 | "tt1340800": 1013,
509 | "tt1343092": 2841,
510 | "tt1360860": 1142,
511 | "tt1371111": 3109,
512 | "tt1375670": 2011,
513 | "tt1396218": 1332,
514 | "tt1401152": 2157,
515 | "tt1403865": 1532,
516 | "tt1411238": 957,
517 | "tt1438176": 1458,
518 | "tt1439572": 737,
519 | "tt1446714": 2179,
520 | "tt1454029": 1585,
521 | "tt1454468": 370,
522 | "tt1458175": 1841,
523 | "tt1462758": 905,
524 | "tt1468846": 973,
525 | "tt1478338": 1969,
526 | "tt1486190": 1151,
527 | "tt1502712": 1395,
528 | "tt1533117": 3552,
529 | "tt1535970": 1078,
530 | "tt1560747": 565,
531 | "tt1563738": 925,
532 | "tt1564367": 1929,
533 | "tt1568346": 2134,
534 | "tt1602620": 233,
535 | "tt1606378": 1895,
536 | "tt1615147": 896,
537 | "tt1616195": 1568,
538 | "tt1628841": 1935,
539 | "tt1637725": 1939,
540 | "tt1646987": 1726,
541 | "tt1649443": 1853,
542 | "tt1655420": 1490,
543 | "tt1670345": 2992,
544 | "tt1675434": 1373,
545 | "tt1692486": 699,
546 | "tt1706593": 691,
547 | "tt1723811": 330,
548 | "tt1747958": 1272,
549 | "tt1757746": 1492,
550 | "tt1781769": 1408,
551 | "tt1800241": 1434,
552 | "tt1800246": 1170,
553 | "tt1809398": 1289,
554 | "tt1832382": 1526,
555 | "tt1855325": 1775,
556 | "tt1877832": 3052,
557 | "tt1907668": 1008,
558 | "tt1951266": 1022,
559 | "tt1971325": 1425,
560 | "tt1979320": 2363,
561 | "tt1981115": 2122,
562 | "tt2017561": 1125,
563 | "tt2053463": 934,
564 | "tt2056771": 1192,
565 | "tt2058107": 830,
566 | "tt2058673": 2369,
567 | "tt2059255": 792,
568 | "tt2070649": 1511,
569 | "tt2084970": 1553,
570 | "tt2103281": 906,
571 | "tt2109184": 349,
572 | "tt2118775": 2097,
573 | "tt2140373": 1758,
574 | "tt2167266": 1223,
575 | "tt2238032": 2680,
576 | "tt2258281": 105,
577 | "tt2267998": 2475,
578 | "tt2294449": 2566,
579 | "tt2310332": 2932,
580 | "tt2334873": 454,
581 | "tt2345567": 892,
582 | "tt2366450": 976,
583 | "tt2381249": 2808,
584 | "tt2382298": 1066,
585 | "tt2404435": 2825,
586 | "tt2463288": 1751,
587 | "tt2473794": 704,
588 | "tt2567026": 2231,
589 | "tt2582802": 1139,
590 | "tt2639344": 502,
591 | "tt2675914": 1108,
592 | "tt2713180": 1845,
593 | "tt2717822": 1516,
594 | "tt2800240": 1772,
595 | "tt2823054": 2321,
596 | "tt2884018": 897,
597 | "tt2908856": 756,
598 | "tt2911666": 1112,
599 | "tt2923316": 1239,
600 | "tt2980516": 1372,
601 | "tt3062096": 3007,
602 | "tt3064298": 1603,
603 | "tt3077214": 1080,
604 | "tt3289956": 1507,
605 | "tt3296658": 294,
606 | "tt3312830": 1229,
607 | "tt3319920": 835,
608 | "tt3395184": 394,
609 | "tt3410834": 1520,
610 | "tt3416744": 775,
611 | "tt3439114": 1383,
612 | "tt3465916": 39,
613 | "tt3474602": 906,
614 | "tt3478232": 1610,
615 | "tt3498820": 2957,
616 | "tt3501416": 2102,
617 | "tt3531578": 1260,
618 | "tt3630276": 275,
619 | "tt3659786": 964,
620 | "tt3671542": 1113,
621 | "tt3700392": 1600,
622 | "tt3700804": 616,
623 | "tt3707106": 779,
624 | "tt3714720": 751,
625 | "tt3766394": 1288,
626 | "tt3808342": 173,
627 | "tt3860916": 970,
628 | "tt3960412": 2015,
629 | "tt4046784": 1932,
630 | "tt4052882": 1498,
631 | "tt4136084": 1271,
632 | "tt4151192": 1918,
633 | "tt4176826": 1635,
634 | "tt4242158": 2090,
635 | "tt4273292": 381,
636 | "tt4501454": 1223,
637 | "tt4651520": 1808,
638 | "tt4698684": 1412,
639 | "tt4721400": 269,
640 | "tt4781612": 1165,
641 | "tt4786282": 1177,
642 | "tt4824302": 1592,
643 | "tt4939066": 2262,
644 | "tt5052448": 1235,
645 | "tt5065810": 984,
646 | "tt5294550": 1453,
647 | "tt5564148": 699,
648 | "tt5576318": 1567,
649 | "tt5580036": 1452,
650 | "tt5593416": 1180,
651 | "tt5649144": 690,
652 | "tt5688868": 2349,
653 | "tt5827496": 1599,
654 | "tt5866930": 1887,
655 | "tt6133130": 1586,
656 | "tt6298600": 659,
657 | "tt6466464": 1613,
658 | "tt6513406": 2072,
659 | "tt6788942": 1750,
660 | "tt7055592": 2457,
661 | "tt7131870": 3565,
662 | "tt7180392": 1872
663 | },
664 | "val": {
665 | "tt0032138": 672,
666 | "tt0038650": 973,
667 | "tt0048545": 837,
668 | "tt0053221": 969,
669 | "tt0053579": 999,
670 | "tt0054167": 662,
671 | "tt0061722": 449,
672 | "tt0064115": 1011,
673 | "tt0066026": 995,
674 | "tt0067140": 1245,
675 | "tt0069762": 755,
676 | "tt0070245": 666,
677 | "tt0071562": 1445,
678 | "tt0072443": 144,
679 | "tt0072890": 1512,
680 | "tt0073486": 1575,
681 | "tt0074811": 667,
682 | "tt0076759": 2210,
683 | "tt0079182": 370,
684 | "tt0079470": 885,
685 | "tt0080610": 578,
686 | "tt0080684": 1770,
687 | "tt0083866": 1192,
688 | "tt0083929": 1181,
689 | "tt0084899": 833,
690 | "tt0085991": 817,
691 | "tt0087332": 1233,
692 | "tt0089853": 384,
693 | "tt0089907": 832,
694 | "tt0090756": 896,
695 | "tt0091355": 455,
696 | "tt0091369": 1415,
697 | "tt0092699": 1191,
698 | "tt0092991": 1215,
699 | "tt0094737": 1297,
700 | "tt0094761": 1201,
701 | "tt0095765": 1328,
702 | "tt0095953": 1222,
703 | "tt0096256": 1147,
704 | "tt0096446": 2386,
705 | "tt0097576": 1610,
706 | "tt0099685": 1248,
707 | "tt0100112": 471,
708 | "tt0100403": 1778,
709 | "tt0101410": 1041,
710 | "tt0102138": 2568,
711 | "tt0103074": 1159,
712 | "tt0103241": 1009,
713 | "tt0103292": 1094,
714 | "tt0104797": 2073,
715 | "tt0105236": 645,
716 | "tt0105652": 609,
717 | "tt0106226": 1074,
718 | "tt0106977": 1979,
719 | "tt0107614": 1495,
720 | "tt0108160": 707,
721 | "tt0108656": 1583,
722 | "tt0109020": 405,
723 | "tt0110475": 1326,
724 | "tt0110932": 1647,
725 | "tt0112462": 1928,
726 | "tt0112641": 1832,
727 | "tt0112740": 2114,
728 | "tt0113870": 862,
729 | "tt0114558": 2806,
730 | "tt0116367": 2039,
731 | "tt0116996": 1533,
732 | "tt0117381": 2034,
733 | "tt0118548": 1412,
734 | "tt0118571": 1746,
735 | "tt0118842": 931,
736 | "tt0119094": 3014,
737 | "tt0119116": 2363,
738 | "tt0119174": 1819,
739 | "tt0119822": 1738,
740 | "tt0120483": 1549,
741 | "tt0120601": 1346,
742 | "tt0120780": 1599,
743 | "tt0120863": 1210,
744 | "tt0121765": 2029,
745 | "tt0122933": 1195,
746 | "tt0129387": 1182,
747 | "tt0133093": 2160,
748 | "tt0138097": 1787,
749 | "tt0140352": 1877,
750 | "tt0145487": 2031,
751 | "tt0166896": 752,
752 | "tt0166924": 1295,
753 | "tt0167261": 3721,
754 | "tt0167404": 672,
755 | "tt0182789": 1354,
756 | "tt0186151": 1913,
757 | "tt0209958": 1595,
758 | "tt0217869": 389,
759 | "tt0240890": 1000,
760 | "tt0248667": 2045,
761 | "tt0258463": 1747,
762 | "tt0261392": 1738,
763 | "tt0265666": 833,
764 | "tt0268126": 1458,
765 | "tt0268978": 1577,
766 | "tt0277027": 1708,
767 | "tt0285742": 862,
768 | "tt0289879": 1654,
769 | "tt0290002": 2359,
770 | "tt0298228": 1005,
771 | "tt0311113": 1633,
772 | "tt0317740": 2162,
773 | "tt0319262": 1395,
774 | "tt0322259": 2470,
775 | "tt0324197": 384,
776 | "tt0337921": 1242,
777 | "tt0341495": 1208,
778 | "tt0343818": 2101,
779 | "tt0360486": 1527,
780 | "tt0370263": 2232,
781 | "tt0371724": 1701,
782 | "tt0375063": 819,
783 | "tt0395169": 1384,
784 | "tt0401383": 925,
785 | "tt0408236": 1678,
786 | "tt0416320": 515,
787 | "tt0432021": 1851,
788 | "tt0434409": 2288,
789 | "tt0454848": 1641,
790 | "tt0455760": 1580,
791 | "tt0457297": 1180,
792 | "tt0457430": 1109,
793 | "tt0457513": 361,
794 | "tt0467200": 1388,
795 | "tt0469494": 660,
796 | "tt0470752": 893,
797 | "tt0480025": 1410,
798 | "tt0758730": 1663,
799 | "tt0758758": 1965,
800 | "tt0780653": 1478,
801 | "tt0790628": 1650,
802 | "tt0808151": 2073,
803 | "tt0816692": 2232,
804 | "tt0824758": 1277,
805 | "tt0838232": 1855,
806 | "tt0898367": 1072,
807 | "tt0940709": 944,
808 | "tt0964517": 1437,
809 | "tt0993842": 1716,
810 | "tt1000774": 2538,
811 | "tt1019452": 1144,
812 | "tt1032755": 1382,
813 | "tt1041829": 1664,
814 | "tt1055292": 1895,
815 | "tt1065073": 2057,
816 | "tt1071875": 1855,
817 | "tt1073498": 1714,
818 | "tt1093906": 1271,
819 | "tt1100089": 1120,
820 | "tt1144884": 1433,
821 | "tt1172049": 1182,
822 | "tt1178663": 327,
823 | "tt1182345": 1082,
824 | "tt1190080": 2962,
825 | "tt1211837": 1381,
826 | "tt1216496": 670,
827 | "tt1232829": 2313,
828 | "tt1284575": 1881,
829 | "tt1341167": 1712,
830 | "tt1355683": 1263,
831 | "tt1385826": 1363,
832 | "tt1409024": 1791,
833 | "tt1441953": 1408,
834 | "tt1462900": 1854,
835 | "tt1504320": 1308,
836 | "tt1540133": 1137,
837 | "tt1582248": 844,
838 | "tt1586752": 2051,
839 | "tt1591095": 1339,
840 | "tt1596363": 1687,
841 | "tt1602613": 643,
842 | "tt1611840": 1159,
843 | "tt1619029": 507,
844 | "tt1645170": 1804,
845 | "tt1659337": 1447,
846 | "tt1703957": 758,
847 | "tt1722484": 1423,
848 | "tt1725986": 1352,
849 | "tt1731141": 2007,
850 | "tt1742683": 1089,
851 | "tt1840309": 2516,
852 | "tt1895587": 1068,
853 | "tt1974419": 819,
854 | "tt2032557": 1408,
855 | "tt2076220": 199,
856 | "tt2078768": 1659,
857 | "tt2109248": 3241,
858 | "tt2132285": 598,
859 | "tt2381991": 1286,
860 | "tt2645044": 2756,
861 | "tt2788732": 1441,
862 | "tt2832470": 2038,
863 | "tt2872732": 1795,
864 | "tt2978462": 2044,
865 | "tt3110958": 2497,
866 | "tt3316960": 812,
867 | "tt3421514": 1882,
868 | "tt3464902": 479,
869 | "tt3488710": 695,
870 | "tt3508840": 265,
871 | "tt3553442": 1461,
872 | "tt3672840": 2226,
873 | "tt3726704": 516,
874 | "tt3824458": 915,
875 | "tt3882082": 1080,
876 | "tt3922798": 1316,
877 | "tt4160708": 1155,
878 | "tt4647900": 1526,
879 | "tt4967094": 1797,
880 | "tt5726086": 1464,
881 | "tt6121428": 903,
882 | "tt6190198": 1130,
883 | "tt7160070": 2684,
884 | "tt7672188": 1415
885 | },
886 | "test": {
887 | "tt0048028": 667,
888 | "tt0049470": 996,
889 | "tt0049833": 1089,
890 | "tt0050419": 423,
891 | "tt0052357": 1101,
892 | "tt0058331": 1521,
893 | "tt0061811": 820,
894 | "tt0063442": 1144,
895 | "tt0066206": 1030,
896 | "tt0068646": 1250,
897 | "tt0070291": 793,
898 | "tt0070511": 1354,
899 | "tt0073195": 1244,
900 | "tt0073582": 1192,
901 | "tt0073629": 1080,
902 | "tt0075314": 889,
903 | "tt0075686": 397,
904 | "tt0078788": 1589,
905 | "tt0079672": 798,
906 | "tt0080455": 2202,
907 | "tt0080761": 549,
908 | "tt0082089": 607,
909 | "tt0082198": 1416,
910 | "tt0083907": 1083,
911 | "tt0083946": 610,
912 | "tt0084390": 398,
913 | "tt0086190": 2178,
914 | "tt0086856": 1425,
915 | "tt0087921": 698,
916 | "tt0088247": 1851,
917 | "tt0088944": 1843,
918 | "tt0089218": 1555,
919 | "tt0089881": 973,
920 | "tt0090257": 549,
921 | "tt0091867": 981,
922 | "tt0092099": 1875,
923 | "tt0093773": 1582,
924 | "tt0094964": 903,
925 | "tt0095250": 1746,
926 | "tt0096320": 1138,
927 | "tt0099423": 2430,
928 | "tt0100405": 1132,
929 | "tt0103776": 1779,
930 | "tt0103855": 2151,
931 | "tt0104466": 285,
932 | "tt0104553": 1277,
933 | "tt0104691": 1458,
934 | "tt0107290": 1328,
935 | "tt0107617": 427,
936 | "tt0108399": 1941,
937 | "tt0110116": 1001,
938 | "tt0110167": 1127,
939 | "tt0110604": 1784,
940 | "tt0111280": 1507,
941 | "tt0111797": 855,
942 | "tt0112384": 1684,
943 | "tt0112573": 3096,
944 | "tt0112818": 1022,
945 | "tt0112883": 844,
946 | "tt0113277": 2000,
947 | "tt0114746": 1302,
948 | "tt0115734": 817,
949 | "tt0115759": 2034,
950 | "tt0115956": 1684,
951 | "tt0116213": 2346,
952 | "tt0116282": 760,
953 | "tt0116767": 388,
954 | "tt0116922": 1212,
955 | "tt0117060": 1537,
956 | "tt0117571": 1463,
957 | "tt0118583": 2032,
958 | "tt0118715": 1070,
959 | "tt0119303": 1981,
960 | "tt0119349": 968,
961 | "tt0119375": 881,
962 | "tt0119488": 1780,
963 | "tt0120255": 473,
964 | "tt0120382": 1397,
965 | "tt0120689": 2398,
966 | "tt0120731": 1699,
967 | "tt0120738": 2243,
968 | "tt0120812": 1661,
969 | "tt0120890": 1151,
970 | "tt0120903": 2147,
971 | "tt0123755": 1063,
972 | "tt0124315": 1177,
973 | "tt0127536": 1455,
974 | "tt0133152": 2466,
975 | "tt0137439": 448,
976 | "tt0137523": 2357,
977 | "tt0142342": 1566,
978 | "tt0163025": 1471,
979 | "tt0172495": 2609,
980 | "tt0178868": 665,
981 | "tt0190332": 1646,
982 | "tt0195714": 1728,
983 | "tt0212985": 1828,
984 | "tt0217505": 1805,
985 | "tt0219822": 1238,
986 | "tt0253474": 1212,
987 | "tt0257360": 758,
988 | "tt0280609": 1441,
989 | "tt0281358": 1189,
990 | "tt0281686": 1119,
991 | "tt0319061": 1838,
992 | "tt0330373": 1978,
993 | "tt0335119": 798,
994 | "tt0361748": 1448,
995 | "tt0368891": 2431,
996 | "tt0368933": 1345,
997 | "tt0369441": 1398,
998 | "tt0370032": 2017,
999 | "tt0373051": 1115,
1000 | "tt0373469": 1660,
1001 | "tt0379786": 2135,
1002 | "tt0381061": 2357,
1003 | "tt0386588": 1828,
1004 | "tt0387898": 258,
1005 | "tt0399201": 2711,
1006 | "tt0404978": 1087,
1007 | "tt0409459": 2766,
1008 | "tt0416508": 1449,
1009 | "tt0440963": 2761,
1010 | "tt0443272": 919,
1011 | "tt0443680": 1030,
1012 | "tt0452625": 1395,
1013 | "tt0455824": 2532,
1014 | "tt0458352": 1597,
1015 | "tt0458525": 2476,
1016 | "tt0460791": 1792,
1017 | "tt0460989": 530,
1018 | "tt0462499": 1880,
1019 | "tt0477347": 1853,
1020 | "tt0479884": 1828,
1021 | "tt0481369": 1117,
1022 | "tt0780571": 1365,
1023 | "tt0783233": 1027,
1024 | "tt0800080": 2257,
1025 | "tt0800369": 2098,
1026 | "tt0815245": 1318,
1027 | "tt0822832": 1550,
1028 | "tt0844347": 771,
1029 | "tt0878804": 2014,
1030 | "tt0903624": 2967,
1031 | "tt0905372": 1361,
1032 | "tt0944835": 1964,
1033 | "tt0945513": 1443,
1034 | "tt0970179": 1979,
1035 | "tt0976051": 1288,
1036 | "tt1001508": 1195,
1037 | "tt1007029": 1287,
1038 | "tt1017460": 1127,
1039 | "tt1033575": 789,
1040 | "tt1034314": 1443,
1041 | "tt1038919": 2013,
1042 | "tt1046173": 1955,
1043 | "tt1063669": 1626,
1044 | "tt1086772": 2092,
1045 | "tt1092026": 1791,
1046 | "tt1099212": 1764,
1047 | "tt1119646": 1806,
1048 | "tt1120985": 652,
1049 | "tt1124037": 1551,
1050 | "tt1170358": 2489,
1051 | "tt1181614": 1200,
1052 | "tt1189340": 1494,
1053 | "tt1201607": 1449,
1054 | "tt1205489": 1381,
1055 | "tt1220634": 1413,
1056 | "tt1229238": 2335,
1057 | "tt1287878": 397,
1058 | "tt1292566": 1851,
1059 | "tt1318514": 1816,
1060 | "tt1375666": 2610,
1061 | "tt1386932": 2454,
1062 | "tt1392190": 2585,
1063 | "tt1397514": 1256,
1064 | "tt1399103": 2873,
1065 | "tt1412386": 1488,
1066 | "tt1413492": 2359,
1067 | "tt1424381": 1586,
1068 | "tt1431045": 2040,
1069 | "tt1440728": 767,
1070 | "tt1446147": 1975,
1071 | "tt1483013": 1910,
1072 | "tt1510906": 1203,
1073 | "tt1524137": 2218,
1074 | "tt1570728": 1681,
1075 | "tt1623205": 1621,
1076 | "tt1663662": 2666,
1077 | "tt1707386": 2549,
1078 | "tt1748122": 885,
1079 | "tt1843287": 319,
1080 | "tt1853728": 1795,
1081 | "tt1872181": 2884,
1082 | "tt2011351": 405,
1083 | "tt2024544": 728,
1084 | "tt2099556": 542,
1085 | "tt2115388": 1047,
1086 | "tt2194499": 1605,
1087 | "tt2402927": 631,
1088 | "tt2409818": 635,
1089 | "tt2446980": 1096,
1090 | "tt2488496": 2396,
1091 | "tt2567712": 1173,
1092 | "tt2582846": 1874,
1093 | "tt2614684": 1098,
1094 | "tt2802144": 2455,
1095 | "tt3385516": 2569,
1096 | "tt3480796": 1628,
1097 | "tt3495026": 2425,
1098 | "tt4008652": 774,
1099 | "tt4034354": 1237,
1100 | "tt4520364": 1460,
1101 | "tt4915672": 1809,
1102 | "tt4972062": 792,
1103 | "tt5140878": 1400,
1104 | "tt6157626": 1083,
1105 | "tt6518634": 2030,
1106 | "tt6644200": 784
1107 | }
1108 | }
--------------------------------------------------------------------------------
/data/data_preparation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import json
5 |
6 | # Concate 16 shot images into a single image,
7 | # the concated images are used for speeding up pre-training.
8 | # Matrix size of the concated image: [16x3]
9 | def concate_pic(shot_info, img_path, save_path, row=16):
10 | for imdb, shot_num in shot_info.items():
11 | pic_num = shot_num // row
12 | for item in range(pic_num):
13 | img_list = []
14 | for idx in range(row):
15 | shot_id = item * row + idx
16 | img_name_0 = f"{img_path}/{imdb}/shot_{str(shot_id).zfill(4)}_img_0.jpg"
17 | img_name_1 = f"{img_path}/{imdb}/shot_{str(shot_id).zfill(4)}_img_1.jpg"
18 | img_name_2 = f"{img_path}/{imdb}/shot_{str(shot_id).zfill(4)}_img_2.jpg"
19 | img_0 = cv2.imread(img_name_0)
20 | img_1 = cv2.imread(img_name_1)
21 | img_2 = cv2.imread(img_name_2)
22 | img = np.concatenate([img_0,img_1,img_2],axis=1)
23 | img_list.append(img)
24 | full_img = np.concatenate(img_list,axis=0)
25 | # print(img.shape)
26 | # print(full_img.shape)
27 | new_pic_dir = f"{save_path}/{imdb}/"
28 | if not os.path.isdir(new_pic_dir):
29 | os.makedirs(new_pic_dir)
30 | filename = new_pic_dir + str(item).zfill(4) + '.jpg'
31 | cv2.imwrite(filename, full_img)
32 |
33 | # Number of shot in each movie
34 | def _generate_shot_num(new_shot_info='./MovieNet_shot_num.json'):
35 | shot_info = './MovieNet_1.0_shotinfo.json'
36 | shot_split = './movie1K.split.v1.json'
37 | with open(shot_info, 'rb') as f:
38 | shot_info_data = json.load(f)
39 | with open(shot_split, 'rb') as f:
40 | shot_split_data = json.load(f)
41 | new_shot_info_data = {}
42 | _type = ['train','val','test']
43 | for _t in _type:
44 | new_shot_info_data[_t] = {}
45 | _movie_list = shot_split_data[_t]
46 | for idx, imdb_id in enumerate(_movie_list):
47 | shot_num = shot_info_data[_t][str(idx)]
48 | new_shot_info_data[_t][imdb_id] = shot_num
49 | with open(new_shot_info, 'w') as f:
50 | json.dump(new_shot_info_data, f, indent=4)
51 |
52 |
53 | def process_raw_label(_T = 'train', raw_root_dir = './'):
54 | split = 'movie1K.split.v1.json'
55 | data_dict = json.load(open(os.path.join(raw_root_dir,split)))
56 |
57 | # print(data_dict.keys())
58 | # dict_keys(['train', 'val', 'test', 'full'])
59 | # print(len(data_dict['train'])) # 660
60 | # print(len(data_dict['val'])) # 220
61 | # print(len(data_dict['test'])) # 220
62 | # print(len(data_dict['full'])) # 1100
63 |
64 | data_list = data_dict[_T]
65 |
66 | # annotation
67 | annotation_path = 'annotation'
68 | count = 0
69 | video_list = []
70 | # all annotations
71 | for index,name in enumerate(data_list):
72 | # print(name)
73 | annotation_file = os.path.join(raw_root_dir, annotation_path, name+'.json')
74 | data = json.load(open(annotation_file))
75 | # only need sence seg labels
76 | if data['scene'] is not None:
77 | video_list.append({'name':name,'index':index})
78 | count += 1
79 | print(f'scene annotations num: {count}')
80 | return video_list
81 |
82 |
83 |
84 | # GT generation
85 | def process_scene_seg_lable(scene_seg_path = './CVPR20SceneSeg/data/scene318/label318',
86 | scene_seg_label_json_name = './movie1K.scene_seg_318_name_index_shotnum_label.v1.json',
87 | raw_root_dir = './MovieNet'):
88 | def _process(data):
89 | seg_label = []
90 | for i in data:
91 | name = i['name']
92 | index = i['index']
93 | label = []
94 | with open (os.path.join(scene_seg_path,name+'.txt'), 'r') as f:
95 | shotnum_label = f.readlines()
96 | for i in shotnum_label:
97 | if ' ' in i:
98 | shot_id = i.split(' ')[0].strip()
99 | l = i.split(' ')[1].strip()
100 | label.append((shot_id,l))
101 | shot_count = len(label) + 1
102 | seg_label.append({"name":name, "index":index, "shot_count":shot_count, "label":label })
103 | return seg_label
104 |
105 | train_list = process_raw_label('train',raw_root_dir)
106 | val_list = process_raw_label('val',raw_root_dir)
107 | test_list = process_raw_label('test',raw_root_dir)
108 | data = {'train':train_list, 'val':val_list, 'test':test_list}
109 |
110 | # CVPR20SceneSeg GT
111 | train = _process(data['train'])
112 | test = _process(data['test'])
113 | val = _process(data['val'])
114 | d_all = {'train':train, 'val':val, 'test':test}
115 |
116 | with open(scene_seg_label_json_name,'w') as f:
117 | f.write(json.dumps(d_all))
118 |
119 |
120 |
121 | if __name__ == '__main__':
122 | # Path of movienet images
123 | img_path = '/MovieNet_unzip/240P'
124 |
125 | # Shot number
126 | shot_info = './MovieNet_shot_num.json'
127 | _generate_shot_num(shot_info)
128 |
129 | # GT label
130 | scene_seg_label_json_name = './movie1K.scene_seg_318_name_index_shotnum_label.v1.json'
131 | ## Download LGSS Annotation from: https://github.com/AnyiRao/SceneSeg/blob/master/docs/INSTALL.md
132 | ## 'scene_seg_path' is the path of the downloaded annotations
133 | scene_seg_path = './CVPR20SceneSeg/data/scene318/label318'
134 | ## Path of raw MovieNet
135 | raw_root_dir = './MovieNet/MovieNet_Ori'
136 | process_scene_seg_lable(scene_seg_path ,scene_seg_label_json_name, raw_root_dir)
137 |
138 | # Concate images
139 | save_path = './compressed_shot_images'
140 | with open(shot_info, 'rb') as f:
141 | shot_info_data = json.load(f)
142 | concate_pic(shot_info_data['train'], img_path, save_path)
143 |
144 |
--------------------------------------------------------------------------------
/data/movie1K.split.v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": [
3 | "tt0035423",
4 | "tt0045537",
5 | "tt0047396",
6 | "tt0048605",
7 | "tt0049730",
8 | "tt0050706",
9 | "tt0053125",
10 | "tt0056869",
11 | "tt0056923",
12 | "tt0057115",
13 | "tt0058461",
14 | "tt0059043",
15 | "tt0059592",
16 | "tt0060522",
17 | "tt0061138",
18 | "tt0061418",
19 | "tt0061781",
20 | "tt0062622",
21 | "tt0064040",
22 | "tt0064276",
23 | "tt0064665",
24 | "tt0065214",
25 | "tt0065724",
26 | "tt0065988",
27 | "tt0066249",
28 | "tt0066921",
29 | "tt0067116",
30 | "tt0067185",
31 | "tt0068935",
32 | "tt0069293",
33 | "tt0069467",
34 | "tt0069995",
35 | "tt0070047",
36 | "tt0070246",
37 | "tt0070379",
38 | "tt0070735",
39 | "tt0070849",
40 | "tt0071129",
41 | "tt0071315",
42 | "tt0071360",
43 | "tt0072684",
44 | "tt0073440",
45 | "tt0074119",
46 | "tt0074285",
47 | "tt0074686",
48 | "tt0074749",
49 | "tt0075148",
50 | "tt0076729",
51 | "tt0077402",
52 | "tt0077405",
53 | "tt0077416",
54 | "tt0077651",
55 | "tt0078841",
56 | "tt0078908",
57 | "tt0079095",
58 | "tt0079116",
59 | "tt0079417",
60 | "tt0079944",
61 | "tt0079945",
62 | "tt0080339",
63 | "tt0080453",
64 | "tt0080745",
65 | "tt0080958",
66 | "tt0080979",
67 | "tt0081505",
68 | "tt0082186",
69 | "tt0082846",
70 | "tt0082971",
71 | "tt0083658",
72 | "tt0083987",
73 | "tt0084549",
74 | "tt0084628",
75 | "tt0084726",
76 | "tt0084787",
77 | "tt0085794",
78 | "tt0086250",
79 | "tt0086837",
80 | "tt0086879",
81 | "tt0086969",
82 | "tt0087182",
83 | "tt0087469",
84 | "tt0088170",
85 | "tt0088222",
86 | "tt0088847",
87 | "tt0088939",
88 | "tt0088993",
89 | "tt0090022",
90 | "tt0090605",
91 | "tt0091042",
92 | "tt0091203",
93 | "tt0091251",
94 | "tt0091406",
95 | "tt0091738",
96 | "tt0091763",
97 | "tt0092603",
98 | "tt0093010",
99 | "tt0093209",
100 | "tt0093565",
101 | "tt0093748",
102 | "tt0093779",
103 | "tt0094226",
104 | "tt0094291",
105 | "tt0095016",
106 | "tt0095497",
107 | "tt0095956",
108 | "tt0096463",
109 | "tt0096754",
110 | "tt0096874",
111 | "tt0096895",
112 | "tt0097216",
113 | "tt0097372",
114 | "tt0097428",
115 | "tt0098258",
116 | "tt0098635",
117 | "tt0098724",
118 | "tt0099348",
119 | "tt0099487",
120 | "tt0099653",
121 | "tt0099674",
122 | "tt0099810",
123 | "tt0100150",
124 | "tt0100157",
125 | "tt0100234",
126 | "tt0100802",
127 | "tt0100935",
128 | "tt0100998",
129 | "tt0101272",
130 | "tt0101393",
131 | "tt0101700",
132 | "tt0101889",
133 | "tt0101921",
134 | "tt0102492",
135 | "tt0102926",
136 | "tt0103064",
137 | "tt0103772",
138 | "tt0103786",
139 | "tt0104036",
140 | "tt0104257",
141 | "tt0104348",
142 | "tt0105226",
143 | "tt0105665",
144 | "tt0105695",
145 | "tt0106332",
146 | "tt0106582",
147 | "tt0107507",
148 | "tt0107653",
149 | "tt0107736",
150 | "tt0107808",
151 | "tt0107822",
152 | "tt0108122",
153 | "tt0108289",
154 | "tt0108330",
155 | "tt0109686",
156 | "tt0109830",
157 | "tt0109831",
158 | "tt0110074",
159 | "tt0110148",
160 | "tt0110201",
161 | "tt0110322",
162 | "tt0110632",
163 | "tt0110912",
164 | "tt0111003",
165 | "tt0112769",
166 | "tt0113101",
167 | "tt0113243",
168 | "tt0113253",
169 | "tt0113497",
170 | "tt0114367",
171 | "tt0114369",
172 | "tt0114388",
173 | "tt0114814",
174 | "tt0115798",
175 | "tt0115964",
176 | "tt0116209",
177 | "tt0116477",
178 | "tt0116629",
179 | "tt0116695",
180 | "tt0117500",
181 | "tt0117509",
182 | "tt0117666",
183 | "tt0117731",
184 | "tt0117883",
185 | "tt0117951",
186 | "tt0118636",
187 | "tt0118655",
188 | "tt0118688",
189 | "tt0118689",
190 | "tt0118749",
191 | "tt0118799",
192 | "tt0118845",
193 | "tt0118883",
194 | "tt0118929",
195 | "tt0118971",
196 | "tt0119008",
197 | "tt0119081",
198 | "tt0119177",
199 | "tt0119250",
200 | "tt0119314",
201 | "tt0119396",
202 | "tt0119528",
203 | "tt0119567",
204 | "tt0119643",
205 | "tt0119654",
206 | "tt0119670",
207 | "tt0119738",
208 | "tt0120263",
209 | "tt0120338",
210 | "tt0120586",
211 | "tt0120591",
212 | "tt0120616",
213 | "tt0120655",
214 | "tt0120660",
215 | "tt0120667",
216 | "tt0120669",
217 | "tt0120696",
218 | "tt0120735",
219 | "tt0120737",
220 | "tt0120744",
221 | "tt0120755",
222 | "tt0120787",
223 | "tt0120804",
224 | "tt0120815",
225 | "tt0120885",
226 | "tt0120902",
227 | "tt0120912",
228 | "tt0120915",
229 | "tt0121766",
230 | "tt0122690",
231 | "tt0125439",
232 | "tt0125664",
233 | "tt0126886",
234 | "tt0128445",
235 | "tt0134119",
236 | "tt0134273",
237 | "tt0134847",
238 | "tt0137494",
239 | "tt0139134",
240 | "tt0139654",
241 | "tt0142688",
242 | "tt0143145",
243 | "tt0144084",
244 | "tt0144117",
245 | "tt0159365",
246 | "tt0159784",
247 | "tt0160127",
248 | "tt0162346",
249 | "tt0162661",
250 | "tt0164052",
251 | "tt0167190",
252 | "tt0167260",
253 | "tt0167331",
254 | "tt0169547",
255 | "tt0171363",
256 | "tt0175880",
257 | "tt0180073",
258 | "tt0180093",
259 | "tt0181689",
260 | "tt0181875",
261 | "tt0183523",
262 | "tt0183649",
263 | "tt0187078",
264 | "tt0187393",
265 | "tt0190590",
266 | "tt0195685",
267 | "tt0199354",
268 | "tt0199753",
269 | "tt0203009",
270 | "tt0206634",
271 | "tt0207201",
272 | "tt0208092",
273 | "tt0209144",
274 | "tt0209463",
275 | "tt0210727",
276 | "tt0212338",
277 | "tt0213149",
278 | "tt0227445",
279 | "tt0232500",
280 | "tt0234215",
281 | "tt0240772",
282 | "tt0242653",
283 | "tt0243876",
284 | "tt0244244",
285 | "tt0244353",
286 | "tt0245844",
287 | "tt0246578",
288 | "tt0250494",
289 | "tt0250797",
290 | "tt0251160",
291 | "tt0253754",
292 | "tt0258000",
293 | "tt0264395",
294 | "tt0264616",
295 | "tt0266697",
296 | "tt0266915",
297 | "tt0268695",
298 | "tt0272152",
299 | "tt0275719",
300 | "tt0278504",
301 | "tt0283509",
302 | "tt0286106",
303 | "tt0288477",
304 | "tt0290334",
305 | "tt0294870",
306 | "tt0299658",
307 | "tt0308476",
308 | "tt0309698",
309 | "tt0313542",
310 | "tt0315327",
311 | "tt0316654",
312 | "tt0317198",
313 | "tt0317919",
314 | "tt0318627",
315 | "tt0318974",
316 | "tt0325710",
317 | "tt0325980",
318 | "tt0328107",
319 | "tt0329101",
320 | "tt0331811",
321 | "tt0332452",
322 | "tt0335266",
323 | "tt0337978",
324 | "tt0338013",
325 | "tt0338751",
326 | "tt0343660",
327 | "tt0346094",
328 | "tt0349903",
329 | "tt0350258",
330 | "tt0351977",
331 | "tt0357413",
332 | "tt0359950",
333 | "tt0362227",
334 | "tt0363589",
335 | "tt0363771",
336 | "tt0365907",
337 | "tt0369339",
338 | "tt0369702",
339 | "tt0371257",
340 | "tt0372183",
341 | "tt0372784",
342 | "tt0372824",
343 | "tt0373074",
344 | "tt0374546",
345 | "tt0375679",
346 | "tt0376994",
347 | "tt0377713",
348 | "tt0378194",
349 | "tt0379306",
350 | "tt0382625",
351 | "tt0383028",
352 | "tt0383216",
353 | "tt0383574",
354 | "tt0385004",
355 | "tt0387564",
356 | "tt0387877",
357 | "tt0388795",
358 | "tt0390022",
359 | "tt0393109",
360 | "tt0395699",
361 | "tt0397078",
362 | "tt0398027",
363 | "tt0399295",
364 | "tt0405159",
365 | "tt0407887",
366 | "tt0408306",
367 | "tt0408790",
368 | "tt0413893",
369 | "tt0414055",
370 | "tt0414387",
371 | "tt0414982",
372 | "tt0415380",
373 | "tt0417741",
374 | "tt0418279",
375 | "tt0418819",
376 | "tt0419887",
377 | "tt0420223",
378 | "tt0421715",
379 | "tt0424345",
380 | "tt0425061",
381 | "tt0425210",
382 | "tt0427309",
383 | "tt0427954",
384 | "tt0430357",
385 | "tt0433035",
386 | "tt0435705",
387 | "tt0439815",
388 | "tt0443453",
389 | "tt0443706",
390 | "tt0448157",
391 | "tt0449088",
392 | "tt0450259",
393 | "tt0450385",
394 | "tt0454841",
395 | "tt0454876",
396 | "tt0454921",
397 | "tt0457939",
398 | "tt0458413",
399 | "tt0462200",
400 | "tt0467406",
401 | "tt0468565",
402 | "tt0468569",
403 | "tt0473705",
404 | "tt0475293",
405 | "tt0477348",
406 | "tt0479997",
407 | "tt0489018",
408 | "tt0493464",
409 | "tt0499448",
410 | "tt0499549",
411 | "tt0758774",
412 | "tt0765128",
413 | "tt0765429",
414 | "tt0765447",
415 | "tt0780504",
416 | "tt0790636",
417 | "tt0790686",
418 | "tt0796366",
419 | "tt0800320",
420 | "tt0810819",
421 | "tt0815236",
422 | "tt0824747",
423 | "tt0826711",
424 | "tt0829482",
425 | "tt0844286",
426 | "tt0846308",
427 | "tt0848228",
428 | "tt0862846",
429 | "tt0887883",
430 | "tt0913425",
431 | "tt0914798",
432 | "tt0942385",
433 | "tt0947798",
434 | "tt0958860",
435 | "tt0959337",
436 | "tt0963794",
437 | "tt0963966",
438 | "tt0970416",
439 | "tt0974661",
440 | "tt0975645",
441 | "tt0977855",
442 | "tt0985694",
443 | "tt0985699",
444 | "tt0986233",
445 | "tt0986263",
446 | "tt0988045",
447 | "tt0993846",
448 | "tt1010048",
449 | "tt1013753",
450 | "tt1016268",
451 | "tt1022603",
452 | "tt1024648",
453 | "tt1027718",
454 | "tt1029234",
455 | "tt1029360",
456 | "tt1037705",
457 | "tt1045658",
458 | "tt1045772",
459 | "tt1054606",
460 | "tt1055369",
461 | "tt1057500",
462 | "tt1059786",
463 | "tt1068649",
464 | "tt1068680",
465 | "tt1072748",
466 | "tt1074638",
467 | "tt1084950",
468 | "tt1104001",
469 | "tt1124035",
470 | "tt1125849",
471 | "tt1131729",
472 | "tt1133985",
473 | "tt1135952",
474 | "tt1139797",
475 | "tt1148204",
476 | "tt1156466",
477 | "tt1158278",
478 | "tt1174732",
479 | "tt1179031",
480 | "tt1179904",
481 | "tt1186367",
482 | "tt1188729",
483 | "tt1193138",
484 | "tt1194173",
485 | "tt1210166",
486 | "tt1217613",
487 | "tt1219289",
488 | "tt1220719",
489 | "tt1228705",
490 | "tt1229340",
491 | "tt1229822",
492 | "tt1233381",
493 | "tt1244754",
494 | "tt1253863",
495 | "tt1255953",
496 | "tt1274586",
497 | "tt1276104",
498 | "tt1282140",
499 | "tt1285016",
500 | "tt1291150",
501 | "tt1291584",
502 | "tt1298650",
503 | "tt1300851",
504 | "tt1305806",
505 | "tt1306980",
506 | "tt1322269",
507 | "tt1324999",
508 | "tt1340800",
509 | "tt1343092",
510 | "tt1360860",
511 | "tt1371111",
512 | "tt1375670",
513 | "tt1396218",
514 | "tt1401152",
515 | "tt1403865",
516 | "tt1411238",
517 | "tt1438176",
518 | "tt1439572",
519 | "tt1446714",
520 | "tt1454029",
521 | "tt1454468",
522 | "tt1458175",
523 | "tt1462758",
524 | "tt1468846",
525 | "tt1478338",
526 | "tt1486190",
527 | "tt1502712",
528 | "tt1533117",
529 | "tt1535970",
530 | "tt1560747",
531 | "tt1563738",
532 | "tt1564367",
533 | "tt1568346",
534 | "tt1602620",
535 | "tt1606378",
536 | "tt1615147",
537 | "tt1616195",
538 | "tt1628841",
539 | "tt1637725",
540 | "tt1646987",
541 | "tt1649443",
542 | "tt1655420",
543 | "tt1670345",
544 | "tt1675434",
545 | "tt1692486",
546 | "tt1706593",
547 | "tt1723811",
548 | "tt1747958",
549 | "tt1757746",
550 | "tt1781769",
551 | "tt1800241",
552 | "tt1800246",
553 | "tt1809398",
554 | "tt1832382",
555 | "tt1855325",
556 | "tt1877832",
557 | "tt1907668",
558 | "tt1951266",
559 | "tt1971325",
560 | "tt1979320",
561 | "tt1981115",
562 | "tt2017561",
563 | "tt2053463",
564 | "tt2056771",
565 | "tt2058107",
566 | "tt2058673",
567 | "tt2059255",
568 | "tt2070649",
569 | "tt2084970",
570 | "tt2103281",
571 | "tt2109184",
572 | "tt2118775",
573 | "tt2140373",
574 | "tt2167266",
575 | "tt2238032",
576 | "tt2258281",
577 | "tt2267998",
578 | "tt2294449",
579 | "tt2310332",
580 | "tt2334873",
581 | "tt2345567",
582 | "tt2366450",
583 | "tt2381249",
584 | "tt2382298",
585 | "tt2404435",
586 | "tt2463288",
587 | "tt2473794",
588 | "tt2567026",
589 | "tt2582802",
590 | "tt2639344",
591 | "tt2675914",
592 | "tt2713180",
593 | "tt2717822",
594 | "tt2800240",
595 | "tt2823054",
596 | "tt2884018",
597 | "tt2908856",
598 | "tt2911666",
599 | "tt2923316",
600 | "tt2980516",
601 | "tt3062096",
602 | "tt3064298",
603 | "tt3077214",
604 | "tt3289956",
605 | "tt3296658",
606 | "tt3312830",
607 | "tt3319920",
608 | "tt3395184",
609 | "tt3410834",
610 | "tt3416744",
611 | "tt3439114",
612 | "tt3465916",
613 | "tt3474602",
614 | "tt3478232",
615 | "tt3498820",
616 | "tt3501416",
617 | "tt3531578",
618 | "tt3630276",
619 | "tt3659786",
620 | "tt3671542",
621 | "tt3700392",
622 | "tt3700804",
623 | "tt3707106",
624 | "tt3714720",
625 | "tt3766394",
626 | "tt3808342",
627 | "tt3860916",
628 | "tt3960412",
629 | "tt4046784",
630 | "tt4052882",
631 | "tt4136084",
632 | "tt4151192",
633 | "tt4176826",
634 | "tt4242158",
635 | "tt4273292",
636 | "tt4501454",
637 | "tt4651520",
638 | "tt4698684",
639 | "tt4721400",
640 | "tt4781612",
641 | "tt4786282",
642 | "tt4824302",
643 | "tt4939066",
644 | "tt5052448",
645 | "tt5065810",
646 | "tt5294550",
647 | "tt5564148",
648 | "tt5576318",
649 | "tt5580036",
650 | "tt5593416",
651 | "tt5649144",
652 | "tt5688868",
653 | "tt5827496",
654 | "tt5866930",
655 | "tt6133130",
656 | "tt6298600",
657 | "tt6466464",
658 | "tt6513406",
659 | "tt6788942",
660 | "tt7055592",
661 | "tt7131870",
662 | "tt7180392"
663 | ],
664 | "val": [
665 | "tt0032138",
666 | "tt0038650",
667 | "tt0048545",
668 | "tt0053221",
669 | "tt0053579",
670 | "tt0054167",
671 | "tt0061722",
672 | "tt0064115",
673 | "tt0066026",
674 | "tt0067140",
675 | "tt0069762",
676 | "tt0070245",
677 | "tt0071562",
678 | "tt0072443",
679 | "tt0072890",
680 | "tt0073486",
681 | "tt0074811",
682 | "tt0076759",
683 | "tt0079182",
684 | "tt0079470",
685 | "tt0080610",
686 | "tt0080684",
687 | "tt0083866",
688 | "tt0083929",
689 | "tt0084899",
690 | "tt0085991",
691 | "tt0087332",
692 | "tt0089853",
693 | "tt0089907",
694 | "tt0090756",
695 | "tt0091355",
696 | "tt0091369",
697 | "tt0092699",
698 | "tt0092991",
699 | "tt0094737",
700 | "tt0094761",
701 | "tt0095765",
702 | "tt0095953",
703 | "tt0096256",
704 | "tt0096446",
705 | "tt0097576",
706 | "tt0099685",
707 | "tt0100112",
708 | "tt0100403",
709 | "tt0101410",
710 | "tt0102138",
711 | "tt0103074",
712 | "tt0103241",
713 | "tt0103292",
714 | "tt0104797",
715 | "tt0105236",
716 | "tt0105652",
717 | "tt0106226",
718 | "tt0106977",
719 | "tt0107614",
720 | "tt0108160",
721 | "tt0108656",
722 | "tt0109020",
723 | "tt0110475",
724 | "tt0110932",
725 | "tt0112462",
726 | "tt0112641",
727 | "tt0112740",
728 | "tt0113870",
729 | "tt0114558",
730 | "tt0116367",
731 | "tt0116996",
732 | "tt0117381",
733 | "tt0118548",
734 | "tt0118571",
735 | "tt0118842",
736 | "tt0119094",
737 | "tt0119116",
738 | "tt0119174",
739 | "tt0119822",
740 | "tt0120483",
741 | "tt0120601",
742 | "tt0120780",
743 | "tt0120863",
744 | "tt0121765",
745 | "tt0122933",
746 | "tt0129387",
747 | "tt0133093",
748 | "tt0138097",
749 | "tt0140352",
750 | "tt0145487",
751 | "tt0166896",
752 | "tt0166924",
753 | "tt0167261",
754 | "tt0167404",
755 | "tt0182789",
756 | "tt0186151",
757 | "tt0209958",
758 | "tt0217869",
759 | "tt0240890",
760 | "tt0248667",
761 | "tt0258463",
762 | "tt0261392",
763 | "tt0265666",
764 | "tt0268126",
765 | "tt0268978",
766 | "tt0277027",
767 | "tt0285742",
768 | "tt0289879",
769 | "tt0290002",
770 | "tt0298228",
771 | "tt0311113",
772 | "tt0317740",
773 | "tt0319262",
774 | "tt0322259",
775 | "tt0324197",
776 | "tt0337921",
777 | "tt0341495",
778 | "tt0343818",
779 | "tt0360486",
780 | "tt0370263",
781 | "tt0371724",
782 | "tt0375063",
783 | "tt0395169",
784 | "tt0401383",
785 | "tt0408236",
786 | "tt0416320",
787 | "tt0432021",
788 | "tt0434409",
789 | "tt0454848",
790 | "tt0455760",
791 | "tt0457297",
792 | "tt0457430",
793 | "tt0457513",
794 | "tt0467200",
795 | "tt0469494",
796 | "tt0470752",
797 | "tt0480025",
798 | "tt0758730",
799 | "tt0758758",
800 | "tt0780653",
801 | "tt0790628",
802 | "tt0808151",
803 | "tt0816692",
804 | "tt0824758",
805 | "tt0838232",
806 | "tt0898367",
807 | "tt0940709",
808 | "tt0964517",
809 | "tt0993842",
810 | "tt1000774",
811 | "tt1019452",
812 | "tt1032755",
813 | "tt1041829",
814 | "tt1055292",
815 | "tt1065073",
816 | "tt1071875",
817 | "tt1073498",
818 | "tt1093906",
819 | "tt1100089",
820 | "tt1144884",
821 | "tt1172049",
822 | "tt1178663",
823 | "tt1182345",
824 | "tt1190080",
825 | "tt1211837",
826 | "tt1216496",
827 | "tt1232829",
828 | "tt1284575",
829 | "tt1341167",
830 | "tt1355683",
831 | "tt1385826",
832 | "tt1409024",
833 | "tt1441953",
834 | "tt1462900",
835 | "tt1504320",
836 | "tt1540133",
837 | "tt1582248",
838 | "tt1586752",
839 | "tt1591095",
840 | "tt1596363",
841 | "tt1602613",
842 | "tt1611840",
843 | "tt1619029",
844 | "tt1645170",
845 | "tt1659337",
846 | "tt1703957",
847 | "tt1722484",
848 | "tt1725986",
849 | "tt1731141",
850 | "tt1742683",
851 | "tt1840309",
852 | "tt1895587",
853 | "tt1974419",
854 | "tt2032557",
855 | "tt2076220",
856 | "tt2078768",
857 | "tt2109248",
858 | "tt2132285",
859 | "tt2381991",
860 | "tt2645044",
861 | "tt2788732",
862 | "tt2832470",
863 | "tt2872732",
864 | "tt2978462",
865 | "tt3110958",
866 | "tt3316960",
867 | "tt3421514",
868 | "tt3464902",
869 | "tt3488710",
870 | "tt3508840",
871 | "tt3553442",
872 | "tt3672840",
873 | "tt3726704",
874 | "tt3824458",
875 | "tt3882082",
876 | "tt3922798",
877 | "tt4160708",
878 | "tt4647900",
879 | "tt4967094",
880 | "tt5726086",
881 | "tt6121428",
882 | "tt6190198",
883 | "tt7160070",
884 | "tt7672188"
885 | ],
886 | "test": [
887 | "tt0048028",
888 | "tt0049470",
889 | "tt0049833",
890 | "tt0050419",
891 | "tt0052357",
892 | "tt0058331",
893 | "tt0061811",
894 | "tt0063442",
895 | "tt0066206",
896 | "tt0068646",
897 | "tt0070291",
898 | "tt0070511",
899 | "tt0073195",
900 | "tt0073582",
901 | "tt0073629",
902 | "tt0075314",
903 | "tt0075686",
904 | "tt0078788",
905 | "tt0079672",
906 | "tt0080455",
907 | "tt0080761",
908 | "tt0082089",
909 | "tt0082198",
910 | "tt0083907",
911 | "tt0083946",
912 | "tt0084390",
913 | "tt0086190",
914 | "tt0086856",
915 | "tt0087921",
916 | "tt0088247",
917 | "tt0088944",
918 | "tt0089218",
919 | "tt0089881",
920 | "tt0090257",
921 | "tt0091867",
922 | "tt0092099",
923 | "tt0093773",
924 | "tt0094964",
925 | "tt0095250",
926 | "tt0096320",
927 | "tt0099423",
928 | "tt0100405",
929 | "tt0103776",
930 | "tt0103855",
931 | "tt0104466",
932 | "tt0104553",
933 | "tt0104691",
934 | "tt0107290",
935 | "tt0107617",
936 | "tt0108399",
937 | "tt0110116",
938 | "tt0110167",
939 | "tt0110604",
940 | "tt0111280",
941 | "tt0111797",
942 | "tt0112384",
943 | "tt0112573",
944 | "tt0112818",
945 | "tt0112883",
946 | "tt0113277",
947 | "tt0114746",
948 | "tt0115734",
949 | "tt0115759",
950 | "tt0115956",
951 | "tt0116213",
952 | "tt0116282",
953 | "tt0116767",
954 | "tt0116922",
955 | "tt0117060",
956 | "tt0117571",
957 | "tt0118583",
958 | "tt0118715",
959 | "tt0119303",
960 | "tt0119349",
961 | "tt0119375",
962 | "tt0119488",
963 | "tt0120255",
964 | "tt0120382",
965 | "tt0120689",
966 | "tt0120731",
967 | "tt0120738",
968 | "tt0120812",
969 | "tt0120890",
970 | "tt0120903",
971 | "tt0123755",
972 | "tt0124315",
973 | "tt0127536",
974 | "tt0133152",
975 | "tt0137439",
976 | "tt0137523",
977 | "tt0142342",
978 | "tt0163025",
979 | "tt0172495",
980 | "tt0178868",
981 | "tt0190332",
982 | "tt0195714",
983 | "tt0212985",
984 | "tt0217505",
985 | "tt0219822",
986 | "tt0253474",
987 | "tt0257360",
988 | "tt0280609",
989 | "tt0281358",
990 | "tt0281686",
991 | "tt0319061",
992 | "tt0330373",
993 | "tt0335119",
994 | "tt0361748",
995 | "tt0368891",
996 | "tt0368933",
997 | "tt0369441",
998 | "tt0370032",
999 | "tt0373051",
1000 | "tt0373469",
1001 | "tt0379786",
1002 | "tt0381061",
1003 | "tt0386588",
1004 | "tt0387898",
1005 | "tt0399201",
1006 | "tt0404978",
1007 | "tt0409459",
1008 | "tt0416508",
1009 | "tt0440963",
1010 | "tt0443272",
1011 | "tt0443680",
1012 | "tt0452625",
1013 | "tt0455824",
1014 | "tt0458352",
1015 | "tt0458525",
1016 | "tt0460791",
1017 | "tt0460989",
1018 | "tt0462499",
1019 | "tt0477347",
1020 | "tt0479884",
1021 | "tt0481369",
1022 | "tt0780571",
1023 | "tt0783233",
1024 | "tt0800080",
1025 | "tt0800369",
1026 | "tt0815245",
1027 | "tt0822832",
1028 | "tt0844347",
1029 | "tt0878804",
1030 | "tt0903624",
1031 | "tt0905372",
1032 | "tt0944835",
1033 | "tt0945513",
1034 | "tt0970179",
1035 | "tt0976051",
1036 | "tt1001508",
1037 | "tt1007029",
1038 | "tt1017460",
1039 | "tt1033575",
1040 | "tt1034314",
1041 | "tt1038919",
1042 | "tt1046173",
1043 | "tt1063669",
1044 | "tt1086772",
1045 | "tt1092026",
1046 | "tt1099212",
1047 | "tt1119646",
1048 | "tt1120985",
1049 | "tt1124037",
1050 | "tt1170358",
1051 | "tt1181614",
1052 | "tt1189340",
1053 | "tt1201607",
1054 | "tt1205489",
1055 | "tt1220634",
1056 | "tt1229238",
1057 | "tt1287878",
1058 | "tt1292566",
1059 | "tt1318514",
1060 | "tt1375666",
1061 | "tt1386932",
1062 | "tt1392190",
1063 | "tt1397514",
1064 | "tt1399103",
1065 | "tt1412386",
1066 | "tt1413492",
1067 | "tt1424381",
1068 | "tt1431045",
1069 | "tt1440728",
1070 | "tt1446147",
1071 | "tt1483013",
1072 | "tt1510906",
1073 | "tt1524137",
1074 | "tt1570728",
1075 | "tt1623205",
1076 | "tt1663662",
1077 | "tt1707386",
1078 | "tt1748122",
1079 | "tt1843287",
1080 | "tt1853728",
1081 | "tt1872181",
1082 | "tt2011351",
1083 | "tt2024544",
1084 | "tt2099556",
1085 | "tt2115388",
1086 | "tt2194499",
1087 | "tt2402927",
1088 | "tt2409818",
1089 | "tt2446980",
1090 | "tt2488496",
1091 | "tt2567712",
1092 | "tt2582846",
1093 | "tt2614684",
1094 | "tt2802144",
1095 | "tt3385516",
1096 | "tt3480796",
1097 | "tt3495026",
1098 | "tt4008652",
1099 | "tt4034354",
1100 | "tt4520364",
1101 | "tt4915672",
1102 | "tt4972062",
1103 | "tt5140878",
1104 | "tt6157626",
1105 | "tt6518634",
1106 | "tt6644200"
1107 | ],
1108 | "full": [
1109 | "tt0032138",
1110 | "tt0035423",
1111 | "tt0038650",
1112 | "tt0045537",
1113 | "tt0047396",
1114 | "tt0048028",
1115 | "tt0048545",
1116 | "tt0048605",
1117 | "tt0049470",
1118 | "tt0049730",
1119 | "tt0049833",
1120 | "tt0050419",
1121 | "tt0050706",
1122 | "tt0052357",
1123 | "tt0053125",
1124 | "tt0053221",
1125 | "tt0053579",
1126 | "tt0054167",
1127 | "tt0056869",
1128 | "tt0056923",
1129 | "tt0057115",
1130 | "tt0058331",
1131 | "tt0058461",
1132 | "tt0059043",
1133 | "tt0059592",
1134 | "tt0060522",
1135 | "tt0061138",
1136 | "tt0061418",
1137 | "tt0061722",
1138 | "tt0061781",
1139 | "tt0061811",
1140 | "tt0062622",
1141 | "tt0063442",
1142 | "tt0064040",
1143 | "tt0064115",
1144 | "tt0064276",
1145 | "tt0064665",
1146 | "tt0065214",
1147 | "tt0065724",
1148 | "tt0065988",
1149 | "tt0066026",
1150 | "tt0066206",
1151 | "tt0066249",
1152 | "tt0066921",
1153 | "tt0067116",
1154 | "tt0067140",
1155 | "tt0067185",
1156 | "tt0068646",
1157 | "tt0068935",
1158 | "tt0069293",
1159 | "tt0069467",
1160 | "tt0069762",
1161 | "tt0069995",
1162 | "tt0070047",
1163 | "tt0070245",
1164 | "tt0070246",
1165 | "tt0070291",
1166 | "tt0070379",
1167 | "tt0070511",
1168 | "tt0070735",
1169 | "tt0070849",
1170 | "tt0071129",
1171 | "tt0071315",
1172 | "tt0071360",
1173 | "tt0071562",
1174 | "tt0072443",
1175 | "tt0072684",
1176 | "tt0072890",
1177 | "tt0073195",
1178 | "tt0073440",
1179 | "tt0073486",
1180 | "tt0073582",
1181 | "tt0073629",
1182 | "tt0074119",
1183 | "tt0074285",
1184 | "tt0074686",
1185 | "tt0074749",
1186 | "tt0074811",
1187 | "tt0075148",
1188 | "tt0075314",
1189 | "tt0075686",
1190 | "tt0076729",
1191 | "tt0076759",
1192 | "tt0077402",
1193 | "tt0077405",
1194 | "tt0077416",
1195 | "tt0077651",
1196 | "tt0078788",
1197 | "tt0078841",
1198 | "tt0078908",
1199 | "tt0079095",
1200 | "tt0079116",
1201 | "tt0079182",
1202 | "tt0079417",
1203 | "tt0079470",
1204 | "tt0079672",
1205 | "tt0079944",
1206 | "tt0079945",
1207 | "tt0080339",
1208 | "tt0080453",
1209 | "tt0080455",
1210 | "tt0080610",
1211 | "tt0080684",
1212 | "tt0080745",
1213 | "tt0080761",
1214 | "tt0080958",
1215 | "tt0080979",
1216 | "tt0081505",
1217 | "tt0082089",
1218 | "tt0082186",
1219 | "tt0082198",
1220 | "tt0082846",
1221 | "tt0082971",
1222 | "tt0083658",
1223 | "tt0083866",
1224 | "tt0083907",
1225 | "tt0083929",
1226 | "tt0083946",
1227 | "tt0083987",
1228 | "tt0084390",
1229 | "tt0084549",
1230 | "tt0084628",
1231 | "tt0084726",
1232 | "tt0084787",
1233 | "tt0084899",
1234 | "tt0085794",
1235 | "tt0085991",
1236 | "tt0086190",
1237 | "tt0086250",
1238 | "tt0086837",
1239 | "tt0086856",
1240 | "tt0086879",
1241 | "tt0086969",
1242 | "tt0087182",
1243 | "tt0087332",
1244 | "tt0087469",
1245 | "tt0087921",
1246 | "tt0088170",
1247 | "tt0088222",
1248 | "tt0088247",
1249 | "tt0088847",
1250 | "tt0088939",
1251 | "tt0088944",
1252 | "tt0088993",
1253 | "tt0089218",
1254 | "tt0089853",
1255 | "tt0089881",
1256 | "tt0089907",
1257 | "tt0090022",
1258 | "tt0090257",
1259 | "tt0090605",
1260 | "tt0090756",
1261 | "tt0091042",
1262 | "tt0091203",
1263 | "tt0091251",
1264 | "tt0091355",
1265 | "tt0091369",
1266 | "tt0091406",
1267 | "tt0091738",
1268 | "tt0091763",
1269 | "tt0091867",
1270 | "tt0092099",
1271 | "tt0092603",
1272 | "tt0092699",
1273 | "tt0092991",
1274 | "tt0093010",
1275 | "tt0093209",
1276 | "tt0093565",
1277 | "tt0093748",
1278 | "tt0093773",
1279 | "tt0093779",
1280 | "tt0094226",
1281 | "tt0094291",
1282 | "tt0094737",
1283 | "tt0094761",
1284 | "tt0094964",
1285 | "tt0095016",
1286 | "tt0095250",
1287 | "tt0095497",
1288 | "tt0095765",
1289 | "tt0095953",
1290 | "tt0095956",
1291 | "tt0096256",
1292 | "tt0096320",
1293 | "tt0096446",
1294 | "tt0096463",
1295 | "tt0096754",
1296 | "tt0096874",
1297 | "tt0096895",
1298 | "tt0097216",
1299 | "tt0097372",
1300 | "tt0097428",
1301 | "tt0097576",
1302 | "tt0098258",
1303 | "tt0098635",
1304 | "tt0098724",
1305 | "tt0099348",
1306 | "tt0099423",
1307 | "tt0099487",
1308 | "tt0099653",
1309 | "tt0099674",
1310 | "tt0099685",
1311 | "tt0099810",
1312 | "tt0100112",
1313 | "tt0100150",
1314 | "tt0100157",
1315 | "tt0100234",
1316 | "tt0100403",
1317 | "tt0100405",
1318 | "tt0100802",
1319 | "tt0100935",
1320 | "tt0100998",
1321 | "tt0101272",
1322 | "tt0101393",
1323 | "tt0101410",
1324 | "tt0101700",
1325 | "tt0101889",
1326 | "tt0101921",
1327 | "tt0102138",
1328 | "tt0102492",
1329 | "tt0102926",
1330 | "tt0103064",
1331 | "tt0103074",
1332 | "tt0103241",
1333 | "tt0103292",
1334 | "tt0103772",
1335 | "tt0103776",
1336 | "tt0103786",
1337 | "tt0103855",
1338 | "tt0104036",
1339 | "tt0104257",
1340 | "tt0104348",
1341 | "tt0104466",
1342 | "tt0104553",
1343 | "tt0104691",
1344 | "tt0104797",
1345 | "tt0105226",
1346 | "tt0105236",
1347 | "tt0105652",
1348 | "tt0105665",
1349 | "tt0105695",
1350 | "tt0106226",
1351 | "tt0106332",
1352 | "tt0106582",
1353 | "tt0106977",
1354 | "tt0107290",
1355 | "tt0107507",
1356 | "tt0107614",
1357 | "tt0107617",
1358 | "tt0107653",
1359 | "tt0107736",
1360 | "tt0107808",
1361 | "tt0107822",
1362 | "tt0108122",
1363 | "tt0108160",
1364 | "tt0108289",
1365 | "tt0108330",
1366 | "tt0108399",
1367 | "tt0108656",
1368 | "tt0109020",
1369 | "tt0109686",
1370 | "tt0109830",
1371 | "tt0109831",
1372 | "tt0110074",
1373 | "tt0110116",
1374 | "tt0110148",
1375 | "tt0110167",
1376 | "tt0110201",
1377 | "tt0110322",
1378 | "tt0110475",
1379 | "tt0110604",
1380 | "tt0110632",
1381 | "tt0110912",
1382 | "tt0110932",
1383 | "tt0111003",
1384 | "tt0111280",
1385 | "tt0111797",
1386 | "tt0112384",
1387 | "tt0112462",
1388 | "tt0112573",
1389 | "tt0112641",
1390 | "tt0112740",
1391 | "tt0112769",
1392 | "tt0112818",
1393 | "tt0112883",
1394 | "tt0113101",
1395 | "tt0113243",
1396 | "tt0113253",
1397 | "tt0113277",
1398 | "tt0113497",
1399 | "tt0113870",
1400 | "tt0114367",
1401 | "tt0114369",
1402 | "tt0114388",
1403 | "tt0114558",
1404 | "tt0114746",
1405 | "tt0114814",
1406 | "tt0115734",
1407 | "tt0115759",
1408 | "tt0115798",
1409 | "tt0115956",
1410 | "tt0115964",
1411 | "tt0116209",
1412 | "tt0116213",
1413 | "tt0116282",
1414 | "tt0116367",
1415 | "tt0116477",
1416 | "tt0116629",
1417 | "tt0116695",
1418 | "tt0116767",
1419 | "tt0116922",
1420 | "tt0116996",
1421 | "tt0117060",
1422 | "tt0117381",
1423 | "tt0117500",
1424 | "tt0117509",
1425 | "tt0117571",
1426 | "tt0117666",
1427 | "tt0117731",
1428 | "tt0117883",
1429 | "tt0117951",
1430 | "tt0118548",
1431 | "tt0118571",
1432 | "tt0118583",
1433 | "tt0118636",
1434 | "tt0118655",
1435 | "tt0118688",
1436 | "tt0118689",
1437 | "tt0118715",
1438 | "tt0118749",
1439 | "tt0118799",
1440 | "tt0118842",
1441 | "tt0118845",
1442 | "tt0118883",
1443 | "tt0118929",
1444 | "tt0118971",
1445 | "tt0119008",
1446 | "tt0119081",
1447 | "tt0119094",
1448 | "tt0119116",
1449 | "tt0119174",
1450 | "tt0119177",
1451 | "tt0119250",
1452 | "tt0119303",
1453 | "tt0119314",
1454 | "tt0119349",
1455 | "tt0119375",
1456 | "tt0119396",
1457 | "tt0119488",
1458 | "tt0119528",
1459 | "tt0119567",
1460 | "tt0119643",
1461 | "tt0119654",
1462 | "tt0119670",
1463 | "tt0119738",
1464 | "tt0119822",
1465 | "tt0120255",
1466 | "tt0120263",
1467 | "tt0120338",
1468 | "tt0120382",
1469 | "tt0120483",
1470 | "tt0120586",
1471 | "tt0120591",
1472 | "tt0120601",
1473 | "tt0120616",
1474 | "tt0120655",
1475 | "tt0120660",
1476 | "tt0120667",
1477 | "tt0120669",
1478 | "tt0120689",
1479 | "tt0120696",
1480 | "tt0120731",
1481 | "tt0120735",
1482 | "tt0120737",
1483 | "tt0120738",
1484 | "tt0120744",
1485 | "tt0120755",
1486 | "tt0120780",
1487 | "tt0120787",
1488 | "tt0120804",
1489 | "tt0120812",
1490 | "tt0120815",
1491 | "tt0120863",
1492 | "tt0120885",
1493 | "tt0120890",
1494 | "tt0120902",
1495 | "tt0120903",
1496 | "tt0120912",
1497 | "tt0120915",
1498 | "tt0121765",
1499 | "tt0121766",
1500 | "tt0122690",
1501 | "tt0122933",
1502 | "tt0123755",
1503 | "tt0124315",
1504 | "tt0125439",
1505 | "tt0125664",
1506 | "tt0126886",
1507 | "tt0127536",
1508 | "tt0128445",
1509 | "tt0129387",
1510 | "tt0133093",
1511 | "tt0133152",
1512 | "tt0134119",
1513 | "tt0134273",
1514 | "tt0134847",
1515 | "tt0137439",
1516 | "tt0137494",
1517 | "tt0137523",
1518 | "tt0138097",
1519 | "tt0139134",
1520 | "tt0139654",
1521 | "tt0140352",
1522 | "tt0142342",
1523 | "tt0142688",
1524 | "tt0143145",
1525 | "tt0144084",
1526 | "tt0144117",
1527 | "tt0145487",
1528 | "tt0159365",
1529 | "tt0159784",
1530 | "tt0160127",
1531 | "tt0162346",
1532 | "tt0162661",
1533 | "tt0163025",
1534 | "tt0164052",
1535 | "tt0166896",
1536 | "tt0166924",
1537 | "tt0167190",
1538 | "tt0167260",
1539 | "tt0167261",
1540 | "tt0167331",
1541 | "tt0167404",
1542 | "tt0169547",
1543 | "tt0171363",
1544 | "tt0172495",
1545 | "tt0175880",
1546 | "tt0178868",
1547 | "tt0180073",
1548 | "tt0180093",
1549 | "tt0181689",
1550 | "tt0181875",
1551 | "tt0182789",
1552 | "tt0183523",
1553 | "tt0183649",
1554 | "tt0186151",
1555 | "tt0187078",
1556 | "tt0187393",
1557 | "tt0190332",
1558 | "tt0190590",
1559 | "tt0195685",
1560 | "tt0195714",
1561 | "tt0199354",
1562 | "tt0199753",
1563 | "tt0203009",
1564 | "tt0206634",
1565 | "tt0207201",
1566 | "tt0208092",
1567 | "tt0209144",
1568 | "tt0209463",
1569 | "tt0209958",
1570 | "tt0210727",
1571 | "tt0212338",
1572 | "tt0212985",
1573 | "tt0213149",
1574 | "tt0217505",
1575 | "tt0217869",
1576 | "tt0219822",
1577 | "tt0227445",
1578 | "tt0232500",
1579 | "tt0234215",
1580 | "tt0240772",
1581 | "tt0240890",
1582 | "tt0242653",
1583 | "tt0243876",
1584 | "tt0244244",
1585 | "tt0244353",
1586 | "tt0245844",
1587 | "tt0246578",
1588 | "tt0248667",
1589 | "tt0250494",
1590 | "tt0250797",
1591 | "tt0251160",
1592 | "tt0253474",
1593 | "tt0253754",
1594 | "tt0257360",
1595 | "tt0258000",
1596 | "tt0258463",
1597 | "tt0261392",
1598 | "tt0264395",
1599 | "tt0264616",
1600 | "tt0265666",
1601 | "tt0266697",
1602 | "tt0266915",
1603 | "tt0268126",
1604 | "tt0268695",
1605 | "tt0268978",
1606 | "tt0272152",
1607 | "tt0275719",
1608 | "tt0277027",
1609 | "tt0278504",
1610 | "tt0280609",
1611 | "tt0281358",
1612 | "tt0281686",
1613 | "tt0283509",
1614 | "tt0285742",
1615 | "tt0286106",
1616 | "tt0288477",
1617 | "tt0289879",
1618 | "tt0290002",
1619 | "tt0290334",
1620 | "tt0294870",
1621 | "tt0298228",
1622 | "tt0299658",
1623 | "tt0308476",
1624 | "tt0309698",
1625 | "tt0311113",
1626 | "tt0313542",
1627 | "tt0315327",
1628 | "tt0316654",
1629 | "tt0317198",
1630 | "tt0317740",
1631 | "tt0317919",
1632 | "tt0318627",
1633 | "tt0318974",
1634 | "tt0319061",
1635 | "tt0319262",
1636 | "tt0322259",
1637 | "tt0324197",
1638 | "tt0325710",
1639 | "tt0325980",
1640 | "tt0328107",
1641 | "tt0329101",
1642 | "tt0330373",
1643 | "tt0331811",
1644 | "tt0332452",
1645 | "tt0335119",
1646 | "tt0335266",
1647 | "tt0337921",
1648 | "tt0337978",
1649 | "tt0338013",
1650 | "tt0338751",
1651 | "tt0341495",
1652 | "tt0343660",
1653 | "tt0343818",
1654 | "tt0346094",
1655 | "tt0349903",
1656 | "tt0350258",
1657 | "tt0351977",
1658 | "tt0357413",
1659 | "tt0359950",
1660 | "tt0360486",
1661 | "tt0361748",
1662 | "tt0362227",
1663 | "tt0363589",
1664 | "tt0363771",
1665 | "tt0365907",
1666 | "tt0368891",
1667 | "tt0368933",
1668 | "tt0369339",
1669 | "tt0369441",
1670 | "tt0369702",
1671 | "tt0370032",
1672 | "tt0370263",
1673 | "tt0371257",
1674 | "tt0371724",
1675 | "tt0372183",
1676 | "tt0372784",
1677 | "tt0372824",
1678 | "tt0373051",
1679 | "tt0373074",
1680 | "tt0373469",
1681 | "tt0374546",
1682 | "tt0375063",
1683 | "tt0375679",
1684 | "tt0376994",
1685 | "tt0377713",
1686 | "tt0378194",
1687 | "tt0379306",
1688 | "tt0379786",
1689 | "tt0381061",
1690 | "tt0382625",
1691 | "tt0383028",
1692 | "tt0383216",
1693 | "tt0383574",
1694 | "tt0385004",
1695 | "tt0386588",
1696 | "tt0387564",
1697 | "tt0387877",
1698 | "tt0387898",
1699 | "tt0388795",
1700 | "tt0390022",
1701 | "tt0393109",
1702 | "tt0395169",
1703 | "tt0395699",
1704 | "tt0397078",
1705 | "tt0398027",
1706 | "tt0399201",
1707 | "tt0399295",
1708 | "tt0401383",
1709 | "tt0404978",
1710 | "tt0405159",
1711 | "tt0407887",
1712 | "tt0408236",
1713 | "tt0408306",
1714 | "tt0408790",
1715 | "tt0409459",
1716 | "tt0413893",
1717 | "tt0414055",
1718 | "tt0414387",
1719 | "tt0414982",
1720 | "tt0415380",
1721 | "tt0416320",
1722 | "tt0416508",
1723 | "tt0417741",
1724 | "tt0418279",
1725 | "tt0418819",
1726 | "tt0419887",
1727 | "tt0420223",
1728 | "tt0421715",
1729 | "tt0424345",
1730 | "tt0425061",
1731 | "tt0425210",
1732 | "tt0427309",
1733 | "tt0427954",
1734 | "tt0430357",
1735 | "tt0432021",
1736 | "tt0433035",
1737 | "tt0434409",
1738 | "tt0435705",
1739 | "tt0439815",
1740 | "tt0440963",
1741 | "tt0443272",
1742 | "tt0443453",
1743 | "tt0443680",
1744 | "tt0443706",
1745 | "tt0448157",
1746 | "tt0449088",
1747 | "tt0450259",
1748 | "tt0450385",
1749 | "tt0452625",
1750 | "tt0454841",
1751 | "tt0454848",
1752 | "tt0454876",
1753 | "tt0454921",
1754 | "tt0455760",
1755 | "tt0455824",
1756 | "tt0457297",
1757 | "tt0457430",
1758 | "tt0457513",
1759 | "tt0457939",
1760 | "tt0458352",
1761 | "tt0458413",
1762 | "tt0458525",
1763 | "tt0460791",
1764 | "tt0460989",
1765 | "tt0462200",
1766 | "tt0462499",
1767 | "tt0467200",
1768 | "tt0467406",
1769 | "tt0468565",
1770 | "tt0468569",
1771 | "tt0469494",
1772 | "tt0470752",
1773 | "tt0473705",
1774 | "tt0475293",
1775 | "tt0477347",
1776 | "tt0477348",
1777 | "tt0479884",
1778 | "tt0479997",
1779 | "tt0480025",
1780 | "tt0481369",
1781 | "tt0489018",
1782 | "tt0493464",
1783 | "tt0499448",
1784 | "tt0499549",
1785 | "tt0758730",
1786 | "tt0758758",
1787 | "tt0758774",
1788 | "tt0765128",
1789 | "tt0765429",
1790 | "tt0765447",
1791 | "tt0780504",
1792 | "tt0780571",
1793 | "tt0780653",
1794 | "tt0783233",
1795 | "tt0790628",
1796 | "tt0790636",
1797 | "tt0790686",
1798 | "tt0796366",
1799 | "tt0800080",
1800 | "tt0800320",
1801 | "tt0800369",
1802 | "tt0808151",
1803 | "tt0810819",
1804 | "tt0815236",
1805 | "tt0815245",
1806 | "tt0816692",
1807 | "tt0822832",
1808 | "tt0824747",
1809 | "tt0824758",
1810 | "tt0826711",
1811 | "tt0829482",
1812 | "tt0838232",
1813 | "tt0844286",
1814 | "tt0844347",
1815 | "tt0846308",
1816 | "tt0848228",
1817 | "tt0862846",
1818 | "tt0878804",
1819 | "tt0887883",
1820 | "tt0898367",
1821 | "tt0903624",
1822 | "tt0905372",
1823 | "tt0913425",
1824 | "tt0914798",
1825 | "tt0940709",
1826 | "tt0942385",
1827 | "tt0944835",
1828 | "tt0945513",
1829 | "tt0947798",
1830 | "tt0958860",
1831 | "tt0959337",
1832 | "tt0963794",
1833 | "tt0963966",
1834 | "tt0964517",
1835 | "tt0970179",
1836 | "tt0970416",
1837 | "tt0974661",
1838 | "tt0975645",
1839 | "tt0976051",
1840 | "tt0977855",
1841 | "tt0985694",
1842 | "tt0985699",
1843 | "tt0986233",
1844 | "tt0986263",
1845 | "tt0988045",
1846 | "tt0993842",
1847 | "tt0993846",
1848 | "tt1000774",
1849 | "tt1001508",
1850 | "tt1007029",
1851 | "tt1010048",
1852 | "tt1013753",
1853 | "tt1016268",
1854 | "tt1017460",
1855 | "tt1019452",
1856 | "tt1022603",
1857 | "tt1024648",
1858 | "tt1027718",
1859 | "tt1029234",
1860 | "tt1029360",
1861 | "tt1032755",
1862 | "tt1033575",
1863 | "tt1034314",
1864 | "tt1037705",
1865 | "tt1038919",
1866 | "tt1041829",
1867 | "tt1045658",
1868 | "tt1045772",
1869 | "tt1046173",
1870 | "tt1054606",
1871 | "tt1055292",
1872 | "tt1055369",
1873 | "tt1057500",
1874 | "tt1059786",
1875 | "tt1063669",
1876 | "tt1065073",
1877 | "tt1068649",
1878 | "tt1068680",
1879 | "tt1071875",
1880 | "tt1072748",
1881 | "tt1073498",
1882 | "tt1074638",
1883 | "tt1084950",
1884 | "tt1086772",
1885 | "tt1092026",
1886 | "tt1093906",
1887 | "tt1099212",
1888 | "tt1100089",
1889 | "tt1104001",
1890 | "tt1119646",
1891 | "tt1120985",
1892 | "tt1124035",
1893 | "tt1124037",
1894 | "tt1125849",
1895 | "tt1131729",
1896 | "tt1133985",
1897 | "tt1135952",
1898 | "tt1139797",
1899 | "tt1144884",
1900 | "tt1148204",
1901 | "tt1156466",
1902 | "tt1158278",
1903 | "tt1170358",
1904 | "tt1172049",
1905 | "tt1174732",
1906 | "tt1178663",
1907 | "tt1179031",
1908 | "tt1179904",
1909 | "tt1181614",
1910 | "tt1182345",
1911 | "tt1186367",
1912 | "tt1188729",
1913 | "tt1189340",
1914 | "tt1190080",
1915 | "tt1193138",
1916 | "tt1194173",
1917 | "tt1201607",
1918 | "tt1205489",
1919 | "tt1210166",
1920 | "tt1211837",
1921 | "tt1216496",
1922 | "tt1217613",
1923 | "tt1219289",
1924 | "tt1220634",
1925 | "tt1220719",
1926 | "tt1228705",
1927 | "tt1229238",
1928 | "tt1229340",
1929 | "tt1229822",
1930 | "tt1232829",
1931 | "tt1233381",
1932 | "tt1244754",
1933 | "tt1253863",
1934 | "tt1255953",
1935 | "tt1274586",
1936 | "tt1276104",
1937 | "tt1282140",
1938 | "tt1284575",
1939 | "tt1285016",
1940 | "tt1287878",
1941 | "tt1291150",
1942 | "tt1291584",
1943 | "tt1292566",
1944 | "tt1298650",
1945 | "tt1300851",
1946 | "tt1305806",
1947 | "tt1306980",
1948 | "tt1318514",
1949 | "tt1322269",
1950 | "tt1324999",
1951 | "tt1340800",
1952 | "tt1341167",
1953 | "tt1343092",
1954 | "tt1355683",
1955 | "tt1360860",
1956 | "tt1371111",
1957 | "tt1375666",
1958 | "tt1375670",
1959 | "tt1385826",
1960 | "tt1386932",
1961 | "tt1392190",
1962 | "tt1396218",
1963 | "tt1397514",
1964 | "tt1399103",
1965 | "tt1401152",
1966 | "tt1403865",
1967 | "tt1409024",
1968 | "tt1411238",
1969 | "tt1412386",
1970 | "tt1413492",
1971 | "tt1424381",
1972 | "tt1431045",
1973 | "tt1438176",
1974 | "tt1439572",
1975 | "tt1440728",
1976 | "tt1441953",
1977 | "tt1446147",
1978 | "tt1446714",
1979 | "tt1454029",
1980 | "tt1454468",
1981 | "tt1458175",
1982 | "tt1462758",
1983 | "tt1462900",
1984 | "tt1468846",
1985 | "tt1478338",
1986 | "tt1483013",
1987 | "tt1486190",
1988 | "tt1502712",
1989 | "tt1504320",
1990 | "tt1510906",
1991 | "tt1524137",
1992 | "tt1533117",
1993 | "tt1535970",
1994 | "tt1540133",
1995 | "tt1560747",
1996 | "tt1563738",
1997 | "tt1564367",
1998 | "tt1568346",
1999 | "tt1570728",
2000 | "tt1582248",
2001 | "tt1586752",
2002 | "tt1591095",
2003 | "tt1596363",
2004 | "tt1602613",
2005 | "tt1602620",
2006 | "tt1606378",
2007 | "tt1611840",
2008 | "tt1615147",
2009 | "tt1616195",
2010 | "tt1619029",
2011 | "tt1623205",
2012 | "tt1628841",
2013 | "tt1637725",
2014 | "tt1645170",
2015 | "tt1646987",
2016 | "tt1649443",
2017 | "tt1655420",
2018 | "tt1659337",
2019 | "tt1663662",
2020 | "tt1670345",
2021 | "tt1675434",
2022 | "tt1692486",
2023 | "tt1703957",
2024 | "tt1706593",
2025 | "tt1707386",
2026 | "tt1722484",
2027 | "tt1723811",
2028 | "tt1725986",
2029 | "tt1731141",
2030 | "tt1742683",
2031 | "tt1747958",
2032 | "tt1748122",
2033 | "tt1757746",
2034 | "tt1781769",
2035 | "tt1800241",
2036 | "tt1800246",
2037 | "tt1809398",
2038 | "tt1832382",
2039 | "tt1840309",
2040 | "tt1843287",
2041 | "tt1853728",
2042 | "tt1855325",
2043 | "tt1872181",
2044 | "tt1877832",
2045 | "tt1895587",
2046 | "tt1907668",
2047 | "tt1951266",
2048 | "tt1971325",
2049 | "tt1974419",
2050 | "tt1979320",
2051 | "tt1981115",
2052 | "tt2011351",
2053 | "tt2017561",
2054 | "tt2024544",
2055 | "tt2032557",
2056 | "tt2053463",
2057 | "tt2056771",
2058 | "tt2058107",
2059 | "tt2058673",
2060 | "tt2059255",
2061 | "tt2070649",
2062 | "tt2076220",
2063 | "tt2078768",
2064 | "tt2084970",
2065 | "tt2099556",
2066 | "tt2103281",
2067 | "tt2109184",
2068 | "tt2109248",
2069 | "tt2115388",
2070 | "tt2118775",
2071 | "tt2132285",
2072 | "tt2140373",
2073 | "tt2167266",
2074 | "tt2194499",
2075 | "tt2238032",
2076 | "tt2258281",
2077 | "tt2267998",
2078 | "tt2294449",
2079 | "tt2310332",
2080 | "tt2334873",
2081 | "tt2345567",
2082 | "tt2366450",
2083 | "tt2381249",
2084 | "tt2381991",
2085 | "tt2382298",
2086 | "tt2402927",
2087 | "tt2404435",
2088 | "tt2409818",
2089 | "tt2446980",
2090 | "tt2463288",
2091 | "tt2473794",
2092 | "tt2488496",
2093 | "tt2567026",
2094 | "tt2567712",
2095 | "tt2582802",
2096 | "tt2582846",
2097 | "tt2614684",
2098 | "tt2639344",
2099 | "tt2645044",
2100 | "tt2675914",
2101 | "tt2713180",
2102 | "tt2717822",
2103 | "tt2788732",
2104 | "tt2800240",
2105 | "tt2802144",
2106 | "tt2823054",
2107 | "tt2832470",
2108 | "tt2872732",
2109 | "tt2884018",
2110 | "tt2908856",
2111 | "tt2911666",
2112 | "tt2923316",
2113 | "tt2978462",
2114 | "tt2980516",
2115 | "tt3062096",
2116 | "tt3064298",
2117 | "tt3077214",
2118 | "tt3110958",
2119 | "tt3289956",
2120 | "tt3296658",
2121 | "tt3312830",
2122 | "tt3316960",
2123 | "tt3319920",
2124 | "tt3385516",
2125 | "tt3395184",
2126 | "tt3410834",
2127 | "tt3416744",
2128 | "tt3421514",
2129 | "tt3439114",
2130 | "tt3464902",
2131 | "tt3465916",
2132 | "tt3474602",
2133 | "tt3478232",
2134 | "tt3480796",
2135 | "tt3488710",
2136 | "tt3495026",
2137 | "tt3498820",
2138 | "tt3501416",
2139 | "tt3508840",
2140 | "tt3531578",
2141 | "tt3553442",
2142 | "tt3630276",
2143 | "tt3659786",
2144 | "tt3671542",
2145 | "tt3672840",
2146 | "tt3700392",
2147 | "tt3700804",
2148 | "tt3707106",
2149 | "tt3714720",
2150 | "tt3726704",
2151 | "tt3766394",
2152 | "tt3808342",
2153 | "tt3824458",
2154 | "tt3860916",
2155 | "tt3882082",
2156 | "tt3922798",
2157 | "tt3960412",
2158 | "tt4008652",
2159 | "tt4034354",
2160 | "tt4046784",
2161 | "tt4052882",
2162 | "tt4136084",
2163 | "tt4151192",
2164 | "tt4160708",
2165 | "tt4176826",
2166 | "tt4242158",
2167 | "tt4273292",
2168 | "tt4501454",
2169 | "tt4520364",
2170 | "tt4647900",
2171 | "tt4651520",
2172 | "tt4698684",
2173 | "tt4721400",
2174 | "tt4781612",
2175 | "tt4786282",
2176 | "tt4824302",
2177 | "tt4915672",
2178 | "tt4939066",
2179 | "tt4967094",
2180 | "tt4972062",
2181 | "tt5052448",
2182 | "tt5065810",
2183 | "tt5140878",
2184 | "tt5294550",
2185 | "tt5564148",
2186 | "tt5576318",
2187 | "tt5580036",
2188 | "tt5593416",
2189 | "tt5649144",
2190 | "tt5688868",
2191 | "tt5726086",
2192 | "tt5827496",
2193 | "tt5866930",
2194 | "tt6121428",
2195 | "tt6133130",
2196 | "tt6157626",
2197 | "tt6190198",
2198 | "tt6298600",
2199 | "tt6466464",
2200 | "tt6513406",
2201 | "tt6518634",
2202 | "tt6644200",
2203 | "tt6788942",
2204 | "tt7055592",
2205 | "tt7131870",
2206 | "tt7160070",
2207 | "tt7180392",
2208 | "tt7672188"
2209 | ]
2210 | }
--------------------------------------------------------------------------------
/data/movienet_data.py:
--------------------------------------------------------------------------------
1 | from PIL import ImageFilter
2 | import random
3 | import torch
4 | import torchvision.transforms as transforms
5 | import json
6 | import cv2
7 | import numpy as np
8 | from torchvision import utils as vutils
9 |
10 | class TwoWayTransform:
11 | def __init__(self, base_transform_a,
12 | base_transform_b, fixed_aug_shot=True):
13 | self.base_transform_a = base_transform_a
14 | self.base_transform_b = base_transform_b
15 | self.fixed = fixed_aug_shot
16 |
17 | def __call__(self, x):
18 | frame_num = len(x)
19 | if self.fixed:
20 | seed = np.random.randint(2147483647)
21 | q, k = [], []
22 | for i in range(frame_num):
23 | random.seed(seed)
24 | q.append(self.base_transform_a(x[i]))
25 | seed = np.random.randint(2147483647)
26 | for i in range(frame_num):
27 | random.seed(seed)
28 | k.append(self.base_transform_b(x[i]))
29 | else:
30 | q = [self.base_transform_a(x[i]) for i in range(frame_num)]
31 | k = [self.base_transform_b(x[i]) for i in range(frame_num)]
32 | q = torch.cat(q, axis = 0)
33 | k = torch.cat(k, axis = 0)
34 | return [q, k]
35 |
36 |
37 | class MovieNet_Shot_Dataset(torch.utils.data.Dataset):
38 | def __init__(self, img_path, shot_info_path, transform,
39 | shot_len = 16, frame_per_shot = 3, _Type='train'):
40 | self.img_path = img_path
41 | with open(shot_info_path, 'rb') as f:
42 | self.shot_info = json.load(f)
43 | self.img_path = img_path
44 | self.shot_len = shot_len
45 | self.frame_per_shot = frame_per_shot
46 | self.transform = transform
47 | self._Type = _Type.lower()
48 | assert self._Type in ['train','val','test']
49 | self.idx_imdb_map = {}
50 | data_length = 0
51 | for imdb, shot_num in self.shot_info[_Type].items():
52 | for i in range(shot_num // shot_len):
53 | self.idx_imdb_map[data_length] = (imdb, i)
54 | data_length += 1
55 |
56 |
57 | def __len__(self):
58 | return len(self.idx_imdb_map.keys())
59 |
60 |
61 | def _transform(self, img_list):
62 | q, k = [], []
63 | for item in img_list:
64 | out = self.transform(item)
65 | q.append(out[0])
66 | k.append(out[1])
67 | out_q = torch.stack(q, axis=0)
68 | out_k = torch.stack(k, axis=0)
69 | return [out_q, out_k]
70 |
71 |
72 | def _process_puzzle(self, idx):
73 | imdb, puzzle_id = self.idx_imdb_map[idx]
74 | img_path = f'{self.img_path}/{imdb}/{str(puzzle_id).zfill(4)}.jpg'
75 | img = cv2.imread(img_path)
76 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
77 | img = np.vsplit(img, self.shot_len)
78 | img = [np.hsplit(i, self.frame_per_shot) for i in img]
79 | data = self._transform(img)
80 | return data
81 |
82 |
83 | def __getitem__(self, idx):
84 | return self._process_puzzle(idx)
85 |
86 |
87 |
88 | class GaussianBlur:
89 | def __init__(self, sigma=[.1, 2.]):
90 | self.sigma = sigma
91 |
92 | def __call__(self, x):
93 | sigma = random.uniform(self.sigma[0], self.sigma[1])
94 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
95 | return x
96 |
97 | def get_train_loader(cfg):
98 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
99 | std=[0.229, 0.224, 0.225])
100 | augmentation_base = [
101 | transforms.ToPILImage(),
102 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
103 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
104 | transforms.RandomHorizontalFlip(),
105 | transforms.ToTensor(),
106 | normalize
107 | ]
108 | augmentation_color = [
109 | transforms.ToPILImage(),
110 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
111 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.5),
112 | transforms.RandomGrayscale(p=0.2),
113 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
114 | transforms.RandomHorizontalFlip(),
115 | transforms.ToTensor(),
116 | normalize
117 | ]
118 | augmentation_q = augmentation_color if cfg['data']['color_aug_for_q'] else augmentation_base
119 | augmentation_k = augmentation_color if cfg['data']['color_aug_for_k'] else augmentation_base
120 |
121 | train_transform = TwoWayTransform(
122 | transforms.Compose(augmentation_q),
123 | transforms.Compose(augmentation_k),
124 | fixed_aug_shot=cfg['data']['fixed_aug_shot'])
125 |
126 | img_path = cfg['data']['data_path']
127 | shot_info_path = cfg['data']['shot_info']
128 | train_dataset = MovieNet_Shot_Dataset(img_path, shot_info_path, train_transform)
129 | train_sampler = None
130 | if cfg['DDP']['multiprocessing_distributed']:
131 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
132 | train_loader = torch.utils.data.DataLoader(train_dataset,
133 | batch_size=cfg['optim']['bs'], num_workers=cfg['data']['workers'],
134 | sampler=train_sampler, shuffle=(train_sampler is None), pin_memory=True, drop_last=True)
135 | return train_loader, train_sampler
136 |
137 | if __name__ == '__main__':
138 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
139 | std=[0.229, 0.224, 0.225])
140 | augmentation_base = [
141 | transforms.ToPILImage(),
142 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
143 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
144 | transforms.RandomHorizontalFlip(),
145 | transforms.ToTensor(),
146 | # normalize
147 | ]
148 | augmentation_color = [
149 | transforms.ToPILImage(),
150 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
151 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.5),
152 | transforms.RandomGrayscale(p=0.2),
153 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
154 | transforms.RandomHorizontalFlip(),
155 | transforms.ToTensor(),
156 | # normalize
157 | ]
158 | train_transform = TwoWayTransform(
159 | transforms.Compose(augmentation_base),
160 | transforms.Compose(augmentation_color),
161 | fixed_aug_shot=False)
162 | img_path = './compressed_shot_images'
163 | shot_info_path = './MovieNet_shot_num.json'
164 | train_dataset = MovieNet_Shot_Dataset(img_path, shot_info_path, train_transform)
165 | print(f'len: {len(train_dataset)}')
166 | i = train_dataset[0]
167 | print(i[0].size())
168 |
169 |
170 |
171 |
172 |
173 |
174 |
--------------------------------------------------------------------------------
/extract_embeddings.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import torch
4 | import argparse
5 | import time
6 | from models.backbones.visual.resnet import encoder_resnet50
7 | import json
8 | import cv2
9 | from torchvision import transforms
10 | from torch.utils.data import DataLoader
11 |
12 | class MovieNet_SingleShot_Dataset(torch.utils.data.Dataset):
13 | def __init__(self, img_path, shot_info_path, transform,
14 | frame_per_shot = 3, _Type='train'):
15 | self.img_path = img_path
16 | with open(shot_info_path, 'rb') as f:
17 | self.shot_info = json.load(f)
18 | self.img_path = img_path
19 | self.frame_per_shot = frame_per_shot
20 | self.transform = transform
21 | self._Type = _Type.lower()
22 | assert self._Type in ['train','val','test']
23 | self.idx_imdb_map = {}
24 | data_length = 0
25 | for info in self.shot_info[_Type]:
26 | imdb = info['name']
27 | for shot in info['label']:
28 | self.idx_imdb_map[data_length] = (imdb, shot[0], shot[1])
29 | data_length += 1
30 |
31 | def __len__(self):
32 | return len(self.idx_imdb_map.keys())
33 |
34 | def _process(self, idx):
35 | imdb, _id, label = self.idx_imdb_map[idx]
36 | img_path_0 = f'{self.img_path}/{imdb}/shot_{_id}_img_0.jpg'
37 | img_path_1 = f'{self.img_path}/{imdb}/shot_{_id}_img_1.jpg'
38 | img_path_2 = f'{self.img_path}/{imdb}/shot_{_id}_img_2.jpg'
39 | img_0 = cv2.cvtColor(cv2.imread(img_path_0), cv2.COLOR_BGR2RGB)
40 | img_1 = cv2.cvtColor(cv2.imread(img_path_1), cv2.COLOR_BGR2RGB)
41 | img_2 = cv2.cvtColor(cv2.imread(img_path_2), cv2.COLOR_BGR2RGB)
42 | data_0 = self.transform(img_0)
43 | data_1 = self.transform(img_1)
44 | data_2 = self.transform(img_2)
45 | data = torch.cat([data_0, data_1, data_2], axis=0)
46 | label = int(label)
47 | # According to LGSS[1]
48 | # [1] https://arxiv.org/abs/2004.02678
49 | if label == -1:
50 | label = 1
51 | return data, label, (imdb, _id)
52 |
53 |
54 | def __getitem__(self, idx):
55 | return self._process(idx)
56 |
57 | def get_loader(cfg, _Type='train'):
58 | normalize = transforms.Normalize(
59 | mean=[0.485, 0.456, 0.406],
60 | std=[0.229, 0.224, 0.225]
61 | )
62 |
63 | _transform = transforms.Compose([
64 | transforms.ToPILImage(),
65 | transforms.Resize(224),
66 | transforms.CenterCrop(224),
67 | transforms.ToTensor(),
68 | normalize,
69 | ])
70 | dataset = MovieNet_SingleShot_Dataset(
71 | img_path = cfg.shot_img_path,
72 | shot_info_path = cfg.shot_info_path,
73 | transform = _transform,
74 | frame_per_shot = cfg.frame_per_shot,
75 | _Type=_Type,
76 | )
77 | loader = DataLoader(
78 | dataset, batch_size=cfg.bs, drop_last=False,
79 | shuffle=False, num_workers=cfg.worker_num, pin_memory=True
80 | )
81 | return loader
82 |
83 | def get_encoder(model_name='resnet50', weight_path='', input_channel=9):
84 | encoder = None
85 | model_name = model_name.lower()
86 | if model_name == 'resnet50':
87 | encoder = encoder_resnet50(weight_path='',input_channel=input_channel)
88 | model_weight = torch.load(weight_path,map_location=torch.device('cpu'))['state_dict']
89 | pretrained_dict = {}
90 | for k, v in model_weight.items():
91 | # moco loading
92 | if k.startswith('module.encoder_k'):
93 | continue
94 | if k == 'module.queue' or k == 'module.queue_ptr':
95 | continue
96 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
97 | k = k[17:]
98 |
99 | pretrained_dict[k] = v
100 | encoder.load_state_dict(pretrained_dict, strict = False)
101 | print(f'loaded from {weight_path}')
102 | return encoder
103 |
104 |
105 | @torch.no_grad()
106 | def get_save_embeddings(model, loader, shot_num, filename, log_interval=100):
107 | # dict
108 | # key: index, value: [(embeddings, label), ...]
109 | embeddings = {}
110 | model.eval()
111 |
112 | print(f'total length of dataset: {len(loader.dataset)}')
113 | print(f'total length of loader: {len(loader)}')
114 |
115 | for batch_idx, (data, target, index) in enumerate(loader):
116 | if batch_idx % log_interval == 0:
117 | print(f'processed: {batch_idx}')
118 |
119 | data = data.cuda(non_blocking=True) # ([bs, shot_num, 9, 224, 224])
120 | data = data.view(-1, 9, 224, 224)
121 |
122 | target = target.view(-1).cuda()
123 | output = model(data, False) # ([bs * shot_num, 2048])
124 | for i, key in enumerate(index[0]):
125 | if key not in embeddings:
126 | embeddings[key] = []
127 | t_emb = output[i*shot_num:(i+1)*shot_num].cpu().numpy()
128 | t_label = target[i].cpu().numpy()
129 | embeddings[key].append((t_emb.copy() ,t_label.copy()))
130 | pickle.dump(embeddings, open(filename, 'wb'))
131 |
132 |
133 | def extract_features(cfg):
134 | time_str = time.strftime("%Y-%m-%d_%H_%M_%S", time.localtime())
135 | save_dir = os.path.join(cfg.save_dir, time_str)
136 | if not os.path.exists(save_dir):
137 | os.makedirs(save_dir)
138 | cfg.log_file = save_dir + '/extraction.log'
139 | encoder = get_encoder(
140 | model_name=cfg.model_name,
141 | weight_path=cfg.model_path,
142 | input_channel=cfg.frame_per_shot * 3
143 | ).cuda()
144 | dataType = [cfg.Type]
145 | if dataType[0] == 'all':
146 | dataType = ['train','test','val']
147 | for _T in dataType:
148 | to_log(cfg, f'processing: {_T} \n')
149 | loader = get_loader(cfg, _Type = _T)
150 | filename = os.path.join(save_dir, _T+'.pkl')
151 | get_save_embeddings(encoder,
152 | loader,
153 | cfg.shot_num,
154 | filename,
155 | log_interval=100
156 | )
157 | to_log(cfg, f'{_T} embeddings are saved in {filename}!\n')
158 |
159 |
160 | def to_log(cfg, content, echo=True):
161 | with open(cfg.log_file, 'a') as f:
162 | f.writelines(content+'\n')
163 | if echo: print(content)
164 |
165 |
166 | def get_config():
167 | parser = argparse.ArgumentParser()
168 | parser.add_argument('model_path', type=str)
169 | parser.add_argument('--shot_info_path', type=str,
170 | default='./data/movie1K.scene_seg_318_name_index_shotnum_label.v1.json')
171 | parser.add_argument('--shot_img_path', type=str, default='./MovieNet_unzip/240P/')
172 | parser.add_argument('--Type', type=str, default='train', choices=['train','test','val','all'])
173 | parser.add_argument('--model_name', type=str, default='resnet50')
174 | parser.add_argument('--frame_per_shot', type=int, default=3)
175 | parser.add_argument('--shot_num', type=int, default=1)
176 | parser.add_argument('--worker_num', type=int, default=16)
177 | parser.add_argument('--bs', type=int, default=64)
178 | parser.add_argument('--save_dir', type=str, default='./embeddings/')
179 | parser.add_argument('--gpu-id', type=str, default='0')
180 | cfg = parser.parse_args()
181 |
182 | # select GPU
183 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
184 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_id
185 |
186 | return cfg
187 |
188 |
189 | if __name__ == '__main__':
190 | cfg = get_config()
191 | extract_features(cfg)
--------------------------------------------------------------------------------
/figures/puzzle_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/figures/puzzle_example.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/models/__init__.py
--------------------------------------------------------------------------------
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/models/backbones/__init__.py
--------------------------------------------------------------------------------
/models/backbones/visual/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | from typing import Type, Any, Callable, Union, List, Optional
5 |
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
9 | 'wide_resnet50_2', 'wide_resnet101_2']
10 |
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
22 | }
23 |
24 |
25 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
26 | """3x3 convolution with padding"""
27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
28 | padding=dilation, groups=groups, bias=False, dilation=dilation)
29 |
30 |
31 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
32 | """1x1 convolution"""
33 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
34 |
35 |
36 | class BasicBlock(nn.Module):
37 | expansion: int = 1
38 |
39 | def __init__(
40 | self,
41 | inplanes: int,
42 | planes: int,
43 | stride: int = 1,
44 | downsample: Optional[nn.Module] = None,
45 | groups: int = 1,
46 | base_width: int = 64,
47 | dilation: int = 1,
48 | norm_layer: Optional[Callable[..., nn.Module]] = None
49 | ) -> None:
50 | super(BasicBlock, self).__init__()
51 | if norm_layer is None:
52 | norm_layer = nn.BatchNorm2d
53 | if groups != 1 or base_width != 64:
54 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
55 | if dilation > 1:
56 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
57 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
58 | self.conv1 = conv3x3(inplanes, planes, stride)
59 | self.bn1 = norm_layer(planes)
60 | self.relu = nn.ReLU(inplace=True)
61 | self.conv2 = conv3x3(planes, planes)
62 | self.bn2 = norm_layer(planes)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x: Tensor) -> Tensor:
67 | identity = x
68 |
69 | out = self.conv1(x)
70 | out = self.bn1(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv2(out)
74 | out = self.bn2(out)
75 |
76 | if self.downsample is not None:
77 | identity = self.downsample(x)
78 |
79 | out += identity
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class Bottleneck(nn.Module):
86 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
87 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
88 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
89 | # This variant is also known as ResNet V1.5 and improves accuracy according to
90 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
91 |
92 | expansion: int = 4
93 |
94 | def __init__(
95 | self,
96 | inplanes: int,
97 | planes: int,
98 | stride: int = 1,
99 | downsample: Optional[nn.Module] = None,
100 | groups: int = 1,
101 | base_width: int = 64,
102 | dilation: int = 1,
103 | norm_layer: Optional[Callable[..., nn.Module]] = None
104 | ) -> None:
105 | super(Bottleneck, self).__init__()
106 | if norm_layer is None:
107 | norm_layer = nn.BatchNorm2d
108 | width = int(planes * (base_width / 64.)) * groups
109 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
110 | self.conv1 = conv1x1(inplanes, width)
111 | self.bn1 = norm_layer(width)
112 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
113 | self.bn2 = norm_layer(width)
114 | self.conv3 = conv1x1(width, planes * self.expansion)
115 | self.bn3 = norm_layer(planes * self.expansion)
116 | self.relu = nn.ReLU(inplace=True)
117 | self.downsample = downsample
118 | self.stride = stride
119 |
120 | def forward(self, x: Tensor) -> Tensor:
121 | identity = x
122 |
123 | out = self.conv1(x)
124 | out = self.bn1(out)
125 | out = self.relu(out)
126 |
127 | out = self.conv2(out)
128 | out = self.bn2(out)
129 | out = self.relu(out)
130 |
131 | out = self.conv3(out)
132 | out = self.bn3(out)
133 |
134 | if self.downsample is not None:
135 | identity = self.downsample(x)
136 |
137 | out += identity
138 | out = self.relu(out)
139 |
140 | return out
141 |
142 |
143 | class ResNet(nn.Module):
144 |
145 | def __init__(
146 | self,
147 | block: Type[Union[BasicBlock, Bottleneck]],
148 | layers: List[int],
149 | input_channel:int = 3,
150 | num_classes: int = 1000,
151 | zero_init_residual: bool = True,
152 | groups: int = 1,
153 | width_per_group: int = 64,
154 | replace_stride_with_dilation: Optional[List[bool]] = None,
155 | norm_layer: Optional[Callable[..., nn.Module]] = None
156 | ) -> None:
157 | super(ResNet, self).__init__()
158 | if norm_layer is None:
159 | norm_layer = nn.BatchNorm2d
160 | self._norm_layer = norm_layer
161 |
162 | self.inplanes = 64
163 | self.dilation = 1
164 | if replace_stride_with_dilation is None:
165 | # each element in the tuple indicates if we should replace
166 | # the 2x2 stride with a dilated convolution instead
167 | replace_stride_with_dilation = [False, False, False]
168 | if len(replace_stride_with_dilation) != 3:
169 | raise ValueError("replace_stride_with_dilation should be None "
170 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
171 | self.groups = groups
172 | self.base_width = width_per_group
173 | self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=7, stride=2, padding=3,
174 | bias=False)
175 | self.bn1 = norm_layer(self.inplanes)
176 | self.relu = nn.ReLU(inplace=True)
177 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
178 | self.layer1 = self._make_layer(block, 64, layers[0])
179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
180 | dilate=replace_stride_with_dilation[0])
181 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
182 | dilate=replace_stride_with_dilation[1])
183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
184 | dilate=replace_stride_with_dilation[2])
185 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
186 | self.fc = nn.Linear(512 * block.expansion, num_classes)
187 |
188 | for m in self.modules():
189 | if isinstance(m, nn.Conv2d):
190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
192 | nn.init.constant_(m.weight, 1)
193 | nn.init.constant_(m.bias, 0)
194 |
195 | # Zero-initialize the last BN in each residual branch,
196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
198 | if zero_init_residual:
199 | for m in self.modules():
200 | if isinstance(m, Bottleneck):
201 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
202 | elif isinstance(m, BasicBlock):
203 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
204 |
205 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
206 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
207 | norm_layer = self._norm_layer
208 | downsample = None
209 | previous_dilation = self.dilation
210 | if dilate:
211 | self.dilation *= stride
212 | stride = 1
213 | if stride != 1 or self.inplanes != planes * block.expansion:
214 | downsample = nn.Sequential(
215 | conv1x1(self.inplanes, planes * block.expansion, stride),
216 | norm_layer(planes * block.expansion),
217 | )
218 |
219 | layers = []
220 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
221 | self.base_width, previous_dilation, norm_layer))
222 | self.inplanes = planes * block.expansion
223 | for _ in range(1, blocks):
224 | layers.append(block(self.inplanes, planes, groups=self.groups,
225 | base_width=self.base_width, dilation=self.dilation,
226 | norm_layer=norm_layer))
227 |
228 | return nn.Sequential(*layers)
229 |
230 | def _forward_impl(self, x: Tensor, is_fc: bool) -> Tensor:
231 | # See note [TorchScript super()]
232 | x = self.conv1(x)
233 | x = self.bn1(x)
234 | x = self.relu(x)
235 | x = self.maxpool(x)
236 |
237 | x = self.layer1(x)
238 | x = self.layer2(x)
239 | x = self.layer3(x)
240 | x = self.layer4(x)
241 |
242 | x = self.avgpool(x)
243 | x = torch.flatten(x, 1)
244 | if is_fc:
245 | x = self.fc(x)
246 | return x
247 |
248 | def forward(self, x: Tensor, is_fc=True) -> Tensor:
249 | return self._forward_impl(x, is_fc)
250 |
251 |
252 | class Encoder(ResNet):
253 | def __init__(self,
254 | input_channel:int,
255 | block: Type[Union[BasicBlock, Bottleneck]],
256 | layers: List[int],
257 | weight_path: str,
258 | num_classes: int = 2048,
259 | **kwargs: Any
260 | ) -> None:
261 | super(Encoder, self).__init__(block, layers, input_channel, num_classes, **kwargs)
262 | self.input_channel = input_channel
263 | if weight_path is not None and len(weight_path) > 1:
264 | print(f'loading weight from {weight_path}')
265 | self._load_from_weight(weight_path)
266 |
267 | def _load_from_weight(self, weight_path: str):
268 | model_weight = torch.load(weight_path)
269 | pretrained_dict = {}
270 | for k, v in model_weight.items():
271 | if k.startswith('conv1') or k.startswith('fc'):
272 | continue
273 | pretrained_dict[k] = v
274 | self.load_state_dict(pretrained_dict, strict = False)
275 |
276 |
277 | def _forward_fc(self, x: Tensor):
278 | x = self.fc(x)
279 | return x
280 |
281 |
282 |
283 |
284 | def _resnet(
285 | arch: str,
286 | block: Type[Union[BasicBlock, Bottleneck]],
287 | layers: List[int],
288 | pretrained: bool,
289 | progress: bool,
290 | **kwargs: Any
291 | ) -> ResNet:
292 | model = ResNet(block, layers, **kwargs)
293 | if pretrained:
294 | state_dict = load_state_dict_from_url(model_urls[arch],
295 | progress=progress)
296 | model.load_state_dict(state_dict)
297 | return model
298 |
299 |
300 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
301 | r"""ResNet-18 model from
302 | `"Deep Residual Learning for Image Recognition" `_.
303 |
304 | Args:
305 | pretrained (bool): If True, returns a model pre-trained on ImageNet
306 | progress (bool): If True, displays a progress bar of the download to stderr
307 | """
308 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
309 | **kwargs)
310 |
311 |
312 |
313 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
314 | r"""ResNet-34 model from
315 | `"Deep Residual Learning for Image Recognition" `_.
316 |
317 | Args:
318 | pretrained (bool): If True, returns a model pre-trained on ImageNet
319 | progress (bool): If True, displays a progress bar of the download to stderr
320 | """
321 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
322 | **kwargs)
323 |
324 |
325 |
326 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
327 | r"""ResNet-50 model from
328 | `"Deep Residual Learning for Image Recognition" `_.
329 |
330 | Args:
331 | pretrained (bool): If True, returns a model pre-trained on ImageNet
332 | progress (bool): If True, displays a progress bar of the download to stderr
333 | """
334 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
335 | **kwargs)
336 |
337 |
338 |
339 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
340 | r"""ResNet-101 model from
341 | `"Deep Residual Learning for Image Recognition" `_.
342 |
343 | Args:
344 | pretrained (bool): If True, returns a model pre-trained on ImageNet
345 | progress (bool): If True, displays a progress bar of the download to stderr
346 | """
347 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
348 | **kwargs)
349 |
350 |
351 |
352 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
353 | r"""ResNet-152 model from
354 | `"Deep Residual Learning for Image Recognition" `_.
355 |
356 | Args:
357 | pretrained (bool): If True, returns a model pre-trained on ImageNet
358 | progress (bool): If True, displays a progress bar of the download to stderr
359 | """
360 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
361 | **kwargs)
362 |
363 |
364 |
365 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
366 | r"""ResNeXt-50 32x4d model from
367 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
368 |
369 | Args:
370 | pretrained (bool): If True, returns a model pre-trained on ImageNet
371 | progress (bool): If True, displays a progress bar of the download to stderr
372 | """
373 | kwargs['groups'] = 32
374 | kwargs['width_per_group'] = 4
375 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
376 | pretrained, progress, **kwargs)
377 |
378 |
379 |
380 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
381 | r"""ResNeXt-101 32x8d model from
382 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
383 |
384 | Args:
385 | pretrained (bool): If True, returns a model pre-trained on ImageNet
386 | progress (bool): If True, displays a progress bar of the download to stderr
387 | """
388 | kwargs['groups'] = 32
389 | kwargs['width_per_group'] = 8
390 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
391 | pretrained, progress, **kwargs)
392 |
393 |
394 |
395 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
396 | r"""Wide ResNet-50-2 model from
397 | `"Wide Residual Networks" `_.
398 |
399 | The model is the same as ResNet except for the bottleneck number of channels
400 | which is twice larger in every block. The number of channels in outer 1x1
401 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
402 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
403 |
404 | Args:
405 | pretrained (bool): If True, returns a model pre-trained on ImageNet
406 | progress (bool): If True, displays a progress bar of the download to stderr
407 | """
408 | kwargs['width_per_group'] = 64 * 2
409 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
410 | pretrained, progress, **kwargs)
411 |
412 |
413 |
414 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
415 | r"""Wide ResNet-101-2 model from
416 | `"Wide Residual Networks" `_.
417 |
418 | The model is the same as ResNet except for the bottleneck number of channels
419 | which is twice larger in every block. The number of channels in outer 1x1
420 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
421 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
422 |
423 | Args:
424 | pretrained (bool): If True, returns a model pre-trained on ImageNet
425 | progress (bool): If True, displays a progress bar of the download to stderr
426 | """
427 | kwargs['width_per_group'] = 64 * 2
428 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
429 | pretrained, progress, **kwargs)
430 |
431 |
432 | def encoder_resnet50(input_channel:int = 9, weight_path: str = '', progress: bool = True, num_classes=2048,
433 | **kwargs: Any) -> Encoder:
434 |
435 | return Encoder(input_channel, Bottleneck, [3, 4, 6, 3], weight_path, num_classes, **kwargs)
436 |
437 |
438 | def get_encoder(model_name='resnet50', weight_path='', modal='v', input_channel=9, ssl_type='moco'):
439 | encoder = None
440 | model_name = model_name.lower()
441 | modal = modal.lower()
442 | ssl_type = ssl_type.lower()
443 |
444 | if modal.startswith('v'):
445 | if model_name == 'resnet50':
446 | if 'resnet50-19c8e357' in weight_path or len(weight_path) < 1:
447 | encoder = encoder_resnet50(weight_path=weight_path,input_channel=input_channel)
448 | else:
449 | encoder = encoder_resnet50(weight_path='',input_channel=input_channel)
450 | # model_weight = torch.load(weight_path)['state_dict']
451 | model_weight = torch.load(weight_path,map_location=torch.device('cpu'))['state_dict']
452 | pretrained_dict = {}
453 | for k, v in model_weight.items():
454 | # print(k)
455 | if ssl_type == 'moco':
456 | # moco loading
457 | if k.startswith('module.encoder_k'):
458 | continue
459 | if k == 'module.queue' or k == 'module.queue_ptr':
460 | continue
461 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
462 | k = k[17:]
463 | else:
464 | # simsiam loading
465 | if k.startswith('module.encoder') and not k.startswith('module.encoder.fc'):
466 | k = k[15:]
467 |
468 | pretrained_dict[k] = v
469 | encoder.load_state_dict(pretrained_dict, strict = False)
470 |
471 | return encoder
472 |
473 |
474 |
475 | if __name__ == '__main__':
476 | model = encoder_resnet50(input_channel=9, weight_path='./pretrained/resnet50-19c8e357.pth')
477 | x = torch.randn(2,9,224,224)
478 | out = model(x)
479 | print(out.size())
480 |
--------------------------------------------------------------------------------
/models/core/SCRL_MoCo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from cluster.Group import Cluster_GPU
4 |
5 | class SCRL(nn.Module):
6 | """
7 | Referenced from MoCo[1] and SCRL[2].
8 | [1] https://arxiv.org/abs/1911.05722
9 | [2] https://arxiv.org/abs/2205.05487
10 | """
11 | def __init__(self, base_encoder, dim=2048, K=65536,
12 | m=0.999, T=0.07, mlp=False,
13 | encoder_pretrained_path: str ='',
14 | multi_positive = False,
15 | positive_selection = 'cluster',
16 | cluster_num = 10,
17 | soft_gamma=0.5):
18 | super(SCRL, self).__init__()
19 |
20 | self.K = K
21 | self.m = m
22 | self.T = T
23 | self.dim = dim
24 | self.multi_positive = multi_positive
25 | self.forward_fn = self.forward_SCRL
26 | self.cluster_num = cluster_num
27 | self.soft_gamma = soft_gamma
28 | assert self.cluster_num > 0
29 |
30 | # positive selection strategy
31 | if 'cluster' in positive_selection:
32 | self.selection_fn = self.get_q_and_k_index_cluster
33 | self.cluster_obj = Cluster_GPU(self.cluster_num)
34 | else:
35 | raise NotImplementedError
36 |
37 | self.encoder_q = base_encoder(weight_path = encoder_pretrained_path)
38 | self.encoder_k = base_encoder(weight_path = encoder_pretrained_path)
39 | self.mlp = mlp
40 |
41 | # hack: brute-force replacement
42 | if mlp:
43 | dim_mlp = self.encoder_q.fc.weight.shape[1]
44 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
45 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
46 |
47 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
48 | param_k.data.copy_(param_q.data)
49 | param_k.requires_grad = False
50 |
51 | # create the queue
52 | self.register_buffer("queue", torch.randn(dim, K))
53 | self.queue = nn.functional.normalize(self.queue, dim=0)
54 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
55 |
56 | @torch.no_grad()
57 | def _momentum_update_key_encoder(self):
58 | """
59 | Momentum update of the key encoder
60 | """
61 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
62 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
63 |
64 | @torch.no_grad()
65 | def _dequeue_and_enqueue(self, keys):
66 | # gather keys before updating queue
67 | keys = concat_all_gather(keys)
68 |
69 | batch_size = keys.shape[0]
70 |
71 | ptr = int(self.queue_ptr)
72 |
73 | assert self.K % batch_size == 0 # for simplicity
74 |
75 | # replace the keys at ptr (dequeue and enqueue)
76 | self.queue[:, ptr:ptr + batch_size] = keys.T
77 | ptr = (ptr + batch_size) % self.K # move pointer
78 |
79 | self.queue_ptr[0] = ptr
80 |
81 | @torch.no_grad()
82 | def _batch_shuffle_ddp(self, x):
83 | """
84 | Batch shuffle, for making use of BatchNorm.
85 | *** Only support DistributedDataParallel (DDP) model. ***
86 | """
87 | # gather from all gpus
88 | batch_size_this = x.shape[0]
89 | x_gather = concat_all_gather(x)
90 | batch_size_all = x_gather.shape[0]
91 |
92 | num_gpus = batch_size_all // batch_size_this
93 |
94 | # random shuffle index
95 | idx_shuffle = torch.randperm(batch_size_all).cuda()
96 |
97 | # broadcast to all gpus
98 | torch.distributed.broadcast(idx_shuffle, src=0)
99 |
100 | # index for restoring
101 | idx_unshuffle = torch.argsort(idx_shuffle)
102 |
103 | # shuffled index for this gpu
104 | gpu_idx = torch.distributed.get_rank()
105 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
106 |
107 | return x_gather[idx_this], idx_unshuffle
108 |
109 |
110 | @torch.no_grad()
111 | def _batch_unshuffle_ddp(self, x, idx_unshuffle):
112 | """
113 | Undo batch shuffle.
114 | *** Only support DistributedDataParallel (DDP) model. ***
115 | """
116 | # gather from all gpus
117 | batch_size_this = x.shape[0]
118 | x_gather = concat_all_gather(x)
119 | batch_size_all = x_gather.shape[0]
120 |
121 | num_gpus = batch_size_all // batch_size_this
122 |
123 | # restored index for this gpu
124 | gpu_idx = torch.distributed.get_rank()
125 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
126 |
127 | return x_gather[idx_this]
128 |
129 | @torch.no_grad()
130 | def get_q_and_k_index_cluster(self, embeddings, return_group=False) -> tuple:
131 |
132 | B = embeddings.size(0)
133 | target_index = list(range(0, B))
134 | q_index = target_index
135 |
136 | choice_cluster, choice_points = self.cluster_obj(embeddings)
137 | k_index = []
138 | for c in choice_cluster:
139 | k_index.append(int(choice_points[c]))
140 | if return_group:
141 | return (q_index, k_index, choice_cluster, choice_points)
142 | else:
143 | return (q_index, k_index)
144 |
145 |
146 | def forward(self, img_q, img_k):
147 | """
148 | Input:
149 | query , key (images)
150 | Output:
151 | logits, targets
152 | """
153 | return self.forward_fn(img_q, img_k)
154 |
155 |
156 | def forward_SCRL(self, img_q, img_k):
157 | # compute query features
158 | embeddings = self.encoder_q(img_q, self.mlp)
159 | embeddings = nn.functional.normalize(embeddings, dim=1)
160 |
161 | # get q and k index
162 | index_q, index_k = self.selection_fn(embeddings)
163 |
164 | # features of q
165 | q = embeddings[index_q]
166 |
167 | # compute key features
168 | with torch.no_grad():
169 | # update the key encoder
170 | self._momentum_update_key_encoder()
171 |
172 | # shuffle for making use of BN
173 | img_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)
174 |
175 | k = self.encoder_k(img_k, self.mlp)
176 | k = nn.functional.normalize(k, dim=1)
177 |
178 | # undo shuffle
179 | k = self._batch_unshuffle_ddp(k, idx_unshuffle)
180 |
181 | k_ori = k
182 | k = k[index_k]
183 |
184 | # compute logits
185 | # positive logits: Nx1
186 | if self.multi_positive:
187 | # SCRL Soft-SC
188 | k = (k + k_ori) * self.soft_gamma
189 |
190 |
191 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
192 |
193 | # negative logits: NxK
194 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
195 |
196 | # logits: Nx(1+K)
197 | logits = torch.cat([l_pos, l_neg], dim=1)
198 |
199 | # apply temperature
200 | logits /= self.T
201 |
202 | # labels: positive key indicators
203 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
204 |
205 | # dequeue and enqueue
206 | self._dequeue_and_enqueue(k)
207 |
208 | return logits, labels
209 |
210 | # the old moco forward func
211 | def forward_moco_old(self, im_q, im_k):
212 | """
213 | Input:
214 | im_q: a batch of query images
215 | im_k: a batch of key images
216 | Output:
217 | logits, targets
218 | """
219 |
220 | # compute query features
221 | q = self.encoder_q(im_q) # queries: NxC
222 | q = nn.functional.normalize(q, dim=1)
223 |
224 | # compute key features
225 | with torch.no_grad(): # no gradient to keys
226 | self._momentum_update_key_encoder() # update the key encoder
227 |
228 | # shuffle for making use of BN
229 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
230 |
231 | k = self.encoder_k(im_k) # keys: NxC
232 | k = nn.functional.normalize(k, dim=1)
233 |
234 | # undo shuffle
235 | k = self._batch_unshuffle_ddp(k, idx_unshuffle)
236 |
237 | # compute logits
238 | # Einstein sum is more intuitive
239 | # positive logits: Nx1
240 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
241 | # negative logits: NxK
242 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
243 |
244 | # logits: Nx(1+K)
245 | logits = torch.cat([l_pos, l_neg], dim=1)
246 |
247 | # apply temperature
248 | logits /= self.T
249 |
250 | # labels: positive key indicators
251 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
252 |
253 | # dequeue and enqueue
254 | self._dequeue_and_enqueue(k)
255 |
256 | return logits, labels
257 |
258 |
259 | # utils
260 | @torch.no_grad()
261 | def concat_all_gather(tensor):
262 | """
263 | Performs all_gather operation on the provided tensors.
264 | *** Warning ***: torch.distributed.all_gather has no gradient.
265 | """
266 | tensors_gather = [torch.ones_like(tensor)
267 | for _ in range(torch.distributed.get_world_size())]
268 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
269 |
270 | output = torch.cat(tensors_gather, dim=0)
271 | return output
272 |
--------------------------------------------------------------------------------
/models/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/models/core/__init__.py
--------------------------------------------------------------------------------
/models/factory.py:
--------------------------------------------------------------------------------
1 | import models.backbones.visual.resnet as resnet
2 | from models.core.SCRL_MoCo import SCRL
3 | from data.movienet_data import get_train_loader
4 | import torch, os
5 | from utils import to_log
6 |
7 | def get_model(cfg):
8 | encoder = None
9 | model = None
10 | if 'multimodal' not in cfg or cfg['multimodal']['using_audio'] == False:
11 | encoder = resnet.encoder_resnet50
12 | else:
13 | raise NotImplementedError
14 | assert encoder is not None
15 |
16 | to_log(cfg, 'backbone init: ' + cfg['model']['backbone'], True)
17 |
18 | if cfg['model']['SSL'] == 'SCRL':
19 | model = SCRL(
20 | base_encoder = encoder,
21 | dim = cfg['MoCo']['dim'],
22 | K = cfg['MoCo']['k'],
23 | m = cfg['MoCo']['m'],
24 | T = cfg['MoCo']['t'],
25 | mlp = cfg['MoCo']['mlp'],
26 | encoder_pretrained_path = cfg['model']['backbone_pretrain'],
27 | multi_positive = cfg['MoCo']['multi_positive'],
28 | positive_selection = cfg['model']['Positive_Selection'],
29 | cluster_num = cfg['model']['cluster_num'],
30 | soft_gamma = cfg['model']['soft_gamma'],
31 | )
32 | else:
33 | raise NotImplementedError
34 | to_log(cfg, 'model init: ' + cfg['model']['SSL'], True)
35 |
36 | if cfg['model']['SyncBatchNorm']:
37 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
38 | to_log(cfg, 'SyncBatchNorm: on' if cfg['model']['SyncBatchNorm'] else 'SyncBatchNorm: off', True)
39 | return model
40 |
41 | def get_loader(cfg):
42 | train_loader, train_sampler = get_train_loader(cfg)
43 | return train_loader, train_sampler
44 |
45 |
46 | def get_criterion(cfg):
47 | criterion = None
48 | if cfg['model']['SSL'] == 'simsiam':
49 | criterion = torch.nn.CosineSimilarity(dim=1)
50 | elif cfg['model']['SSL'] == 'SCRL':
51 | criterion = torch.nn.CrossEntropyLoss()
52 | else:
53 | raise NotImplementedError
54 | to_log(cfg, 'criterion init: ' + str(criterion), True)
55 | return criterion
56 |
57 | def get_optimizer(cfg, model):
58 | optimizer = None
59 | if cfg['optim']['optimizer'] == 'sgd':
60 | if cfg['model']['SSL'] == 'simsiam':
61 | if cfg['model']['fix_pred_lr']:
62 | optim_params = [{'params': model.module.encoder.parameters(), 'fix_lr': False},
63 | {'params': model.module.predictor.parameters(), 'fix_lr': True}]
64 | else:
65 | optim_params = model.parameters()
66 | elif cfg['model']['SSL'] == 'SCRL':
67 | optim_params = model.parameters()
68 | else:
69 | raise NotImplementedError
70 |
71 | optimizer = torch.optim.SGD(optim_params, cfg['optim']['lr'],
72 | momentum=cfg['optim']['momentum'],
73 | weight_decay=cfg['optim']['wd'])
74 | else:
75 | raise NotImplementedError
76 | return optimizer
77 |
78 | def get_training_stuff(cfg, gpu, ngpus_per_node):
79 | cfg['optim']['bs'] = int(cfg['optim']['bs'] / ngpus_per_node)
80 | to_log(cfg, 'shot per GPU: ' + str(cfg['optim']['bs']), True)
81 |
82 | if cfg['data']['clipshuffle']:
83 | len_per_data = cfg['data']['clipshuffle_len']
84 | else:
85 | len_per_data = 1
86 | assert cfg['optim']['bs'] % len_per_data == 0
87 | cfg['optim']['bs'] = int(cfg['optim']['bs'] / len_per_data )
88 | cfg['data']['workers'] = int(( cfg['data']['workers'] + ngpus_per_node - 1) / ngpus_per_node)
89 | to_log(cfg, 'batch size per GPU: ' + str(cfg['optim']['bs']), True)
90 | to_log(cfg, 'worker per GPU: ' + str(cfg['data']['workers']) , True)
91 |
92 | train_loader, train_sampler = get_train_loader(cfg)
93 | model = get_model(cfg)
94 | model.cuda(gpu)
95 | model = torch.nn.parallel.DistributedDataParallel(model,
96 | device_ids=[gpu],
97 | output_device=gpu,
98 | find_unused_parameters=True)
99 |
100 | criterion = get_criterion(cfg).cuda(gpu)
101 | optimizer = get_optimizer(cfg, model)
102 | cfg['optim']['start_epoch'] = 0
103 | resume = cfg['model']['resume']
104 | if resume is not None and len(resume) > 1:
105 | if os.path.isfile(resume):
106 | to_log(cfg, "=> loading checkpoint '{}'".format(resume), True)
107 | if gpu is None:
108 | checkpoint = torch.load(resume)
109 | else:
110 | loc = f'cuda:{gpu}'
111 | checkpoint = torch.load(resume, map_location=loc)
112 | start_epoch = checkpoint['epoch']
113 | cfg['optim']['start_epoch'] = start_epoch
114 | model.load_state_dict(checkpoint['state_dict'])
115 | optimizer.load_state_dict(checkpoint['optimizer'])
116 | to_log(cfg, "=> loaded checkpoint '{}' (epoch {})"
117 | .format(resume, checkpoint['epoch']), True)
118 | else:
119 | to_log(cfg, "=> no checkpoint found at '{}'".format(resume), True)
120 | raise FileNotFoundError
121 |
122 |
123 | assert model is not None \
124 | and train_loader is not None \
125 | and criterion is not None \
126 | and optimizer is not None
127 |
128 | return (model, train_loader, train_sampler, criterion, optimizer)
129 |
--------------------------------------------------------------------------------
/pretrain_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | import os, shutil
4 | import builtins
5 | import math
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.distributed as dist
12 | import torch.optim
13 | import torch.multiprocessing as mp
14 | import random
15 | import numpy as np
16 |
17 | from models import factory
18 | from utils import to_log, set_log
19 | from pretrain_trainer import train_SCRL
20 |
21 |
22 | def start_training(cfg):
23 | # only multiprocessing_distributed is supported
24 | if cfg['DDP']['multiprocessing_distributed']:
25 |
26 | ngpus_per_node = torch.cuda.device_count()
27 |
28 | if cfg['DDP']['dist_url'] == "env://":
29 | os.environ['MASTER_ADDR'] = cfg['DDP']['master_ip']
30 | os.environ['MASTER_PORT'] = str(cfg['DDP']['master_port'])
31 | os.environ['WORLD_SIZE'] = str(ngpus_per_node * cfg['DDP']['machine_num'])
32 | os.environ['NODE_RANK'] = str(cfg['DDP']['node_num'])
33 | os.environ['NUM_NODES'] = str(cfg['DDP']['machine_num'])
34 | os.environ['NUM_GPUS_PER_NODE'] = str(ngpus_per_node)
35 | # os.environ['NCCL_IB_DISABLE'] = "1"
36 |
37 | cfg['DDP']['world_size'] = ngpus_per_node * cfg['DDP']['machine_num']
38 | print(cfg['DDP']['world_size'], ngpus_per_node)
39 |
40 | mp.spawn(task_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, cfg))
41 |
42 |
43 | def setup_worker(seed, gpu):
44 | torch.manual_seed(seed)
45 | torch.cuda.manual_seed_all(seed)
46 | np.random.seed(seed)
47 | random.seed(seed)
48 | cudnn.benchmark = True
49 | torch.cuda.set_device(gpu)
50 |
51 | def task_worker(gpu, ngpus_per_node, cfg):
52 | setup_worker(seed = 100, gpu = gpu)
53 | if gpu != 0:
54 | def print_pass(*args):
55 | pass
56 | builtins.print = print_pass
57 |
58 | cfg['DDP']['rank']= cfg['DDP']['node_num'] * ngpus_per_node + gpu
59 | cfg['DDP']['gpu'] = gpu
60 | if cfg['DDP']['dist_url'] == 'env://':
61 | os.environ['RANK'] = str(cfg['DDP']['rank'])
62 |
63 | print(cfg['DDP']['dist_backend'], cfg['DDP']['dist_url'], cfg['DDP']['world_size'],cfg['DDP']['rank'] )
64 | dist.init_process_group(backend=cfg['DDP']['dist_backend'], init_method=cfg['DDP']['dist_url'])
65 |
66 | if gpu == 0:
67 | to_log(cfg, 'DDP init succeed!', True)
68 |
69 | model, train_loader, train_sampler, criterion, optimizer \
70 | = factory.get_training_stuff(cfg, gpu, ngpus_per_node)
71 |
72 | # training function
73 | if cfg['model']['SSL'] == 'SCRL':
74 | train_fun = train_SCRL
75 | else:
76 | raise NotImplementedError
77 |
78 | start_epoch = cfg['optim']['start_epoch']
79 | end_epoch = cfg['optim']['epochs']
80 |
81 | assert train_fun is not None
82 | for epoch in range(start_epoch, end_epoch):
83 | train_sampler.set_epoch(epoch)
84 | adjust_learning_rate(optimizer, cfg['optim']['lr'], epoch, cfg)
85 | train_fun(gpu, train_loader, model, criterion, optimizer, epoch, cfg)
86 | if cfg['DDP']['rank'] == 0 and (epoch + 1) % 4 == 0:
87 | save_checkpoint(cfg,{
88 | 'epoch': epoch + 1,
89 | 'arch': cfg['model']['backbone'],
90 | 'state_dict': model.state_dict(),
91 | 'optimizer' : optimizer.state_dict(),
92 | }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch))
93 |
94 |
95 | def adjust_learning_rate(optimizer, init_lr, epoch, cfg):
96 | """Decay the learning rate based on schedule"""
97 | if cfg['optim']['lr_cos'] == True:
98 | cur_lr = init_lr * 0.5 * (1. + math.cos(0.5 * math.pi * epoch / cfg['optim']['epochs']))
99 | else:
100 | cur_lr = init_lr
101 | for milestone in cfg['optim']['schedule']:
102 | cur_lr *= 0.1 if epoch >= milestone else 1.
103 | for param_group in optimizer.param_groups:
104 | if 'fix_lr' in param_group and param_group['fix_lr']:
105 | param_group['lr'] = init_lr
106 | else:
107 | param_group['lr'] = cur_lr
108 |
109 |
110 |
111 | def save_checkpoint(cfg, state, is_best, filename='checkpoint.pth.tar'):
112 | p = os.path.join(cfg['log']['dir'], 'checkpoints')
113 | if not os.path.exists(p):
114 | os.makedirs(p)
115 |
116 | torch.save(state, os.path.join(p, filename))
117 | if is_best:
118 | shutil.copyfile(os.path.join(p, filename), os.path.join(p, 'model_best.pth.tar'))
119 |
120 |
121 | def get_config():
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument('--config', type=str, default='./config/SCRL_pretrain_default.yaml')
124 | args = parser.parse_args()
125 | cfg = yaml.safe_load(open(args.config, encoding='utf8'))
126 | cfg = set_log(cfg)
127 | shutil.copy(args.config, cfg['log']['dir'])
128 | return cfg
129 |
130 |
131 | def main():
132 | cfg = get_config()
133 | start_training(cfg)
134 |
135 | if __name__ == '__main__':
136 | main()
137 |
--------------------------------------------------------------------------------
/pretrain_trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn.parallel
4 | import torch.optim
5 | from utils import AverageMeter, ProgressMeter, to_log, accuracy
6 |
7 |
8 |
9 | def train_SCRL(gpu, train_loader, model, criterion, optimizer, epoch, cfg):
10 | batch_time = AverageMeter('Time', ':6.3f')
11 | data_time = AverageMeter('Data', ':6.3f')
12 | losses = AverageMeter('Loss', ':.4e')
13 | top1 = AverageMeter('Acc@1', ':6.2f')
14 | top5 = AverageMeter('Acc@5', ':6.2f')
15 |
16 | progress = ProgressMeter(
17 | len(train_loader),
18 | [batch_time, data_time, losses, top1, top5],
19 | prefix="Epoch: [{}]".format(epoch))
20 |
21 | gradient_clip_val = cfg['optim']['gradient_norm']
22 |
23 | model.train()
24 | view_size = (-1, 3 * cfg['data']['frame_size'], 224, 224)
25 | pivot = time.time()
26 | for i, data in enumerate(train_loader):
27 | if gpu is not None:
28 | data_q = data[0].cuda(gpu, non_blocking=True)
29 | data_k = data[1].cuda(gpu, non_blocking=True)
30 | data_time.update(time.time() - pivot)
31 | data_q = data_q.view(view_size)
32 | data_k = data_k.view(view_size)
33 |
34 | output, target = model(data_q, data_k)
35 |
36 | loss = criterion(output, target)
37 |
38 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
39 |
40 | losses.update(loss.item(), target.size(0))
41 | top1.update(acc1[0], target.size(0))
42 | top5.update(acc5[0], target.size(0))
43 |
44 | optimizer.zero_grad()
45 | loss.backward()
46 |
47 | # gradient clipping
48 | if gradient_clip_val > 0:
49 | torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)
50 |
51 | optimizer.step()
52 |
53 | batch_time.update(time.time() - pivot)
54 | pivot = time.time()
55 |
56 | if gpu == 0 and i % cfg['log']['print_freq'] == 0:
57 | _out = progress.display(i)
58 | to_log(cfg, _out, True)
59 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import torch
4 |
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value"""
7 | def __init__(self, name, fmt=':f'):
8 | self.name = name
9 | self.fmt = fmt
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
24 | def __str__(self):
25 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
26 | return fmtstr.format(**self.__dict__)
27 |
28 |
29 | class ProgressMeter(object):
30 | def __init__(self, num_batches, meters, prefix=""):
31 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
32 | self.meters = meters
33 | self.prefix = prefix
34 |
35 | def display(self, batch):
36 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
37 | entries += [str(meter) for meter in self.meters]
38 | out = '\t'.join(entries)
39 | return out
40 |
41 | def _get_batch_fmtstr(self, num_batches):
42 | num_digits = len(str(num_batches // 1))
43 | fmt = '{:' + str(num_digits) + 'd}'
44 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
45 |
46 | def set_log(cfg):
47 | time_str = time.strftime("%Y-%m-%d_%H_%M_%S", time.localtime())
48 | cfg['log']['dir'] = cfg['log']['dir'] + time_str
49 | if not os.path.exists(cfg['log']['dir']):
50 | os.makedirs(cfg['log']['dir'])
51 | return cfg
52 |
53 | def to_log(cfg, content, echo=False, gpu_print_id=0):
54 | # gpu_print_id < 0 force to print
55 | if cfg['DDP']['gpu'] == gpu_print_id and gpu_print_id >= 0:
56 | log_path = os.path.join(cfg['log']['dir'], 'log.txt')
57 | with open(log_path, 'a') as f:
58 | f.writelines(content+'\n')
59 | if echo:
60 | print(content)
61 |
62 | def accuracy(output, target, topk=(1,)):
63 | """Computes the accuracy over the k top predictions for the specified values of k"""
64 | with torch.no_grad():
65 | maxk = max(topk)
66 | batch_size = target.size(0)
67 |
68 | _, pred = output.topk(maxk, 1, True, True)
69 | pred = pred.t()
70 | correct = pred.eq(target.view(1, -1).expand_as(pred))
71 |
72 | res = []
73 | for k in topk:
74 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
75 | res.append(correct_k.mul_(100.0 / batch_size))
76 | return res
--------------------------------------------------------------------------------