├── .gitignore
├── LICENSE
├── README.md
├── SceneSeg
│   ├── BiLSTM_protocol.py
│   ├── __init__.py
│   ├── main.py
│   └── movienet_seg_data.py
├── cluster
│   ├── Group.py
│   └── cluster_test.ipynb
├── config
│   ├── SCRL_pretrain_default.yaml
│   ├── SCRL_pretrain_with_imagenet1k.yaml
│   └── SCRL_pretrain_without_imagenet1k.yaml
├── data
│   ├── MovieNet_1.0_shotinfo.json
│   ├── MovieNet_shot_num.json
│   ├── data_preparation.py
│   ├── movie1K.scene_seg_318_name_index_shotnum_label.v1.json
│   ├── movie1K.split.v1.json
│   └── movienet_data.py
├── extract_embeddings.py
├── figures
│   └── puzzle_example.jpg
├── models
│   ├── __init__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   └── visual
│   │       └── resnet.py
│   ├── core
│   │   ├── SCRL_MoCo.py
│   │   └── __init__.py
│   └── factory.py
├── pretrain_main.py
├── pretrain_trainer.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | output/
2 | compressed_shot_images/
3 | embeddings/
4 | checkpoints/
5 | SceneSeg/output/
6 | pretrain/
7 | __pycache__/
8 | *.pkl
9 | *.log
10 | *.txt
11 | *.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Scene Consistency Representation Learning for Video Scene Segmentation (CVPR2022)
2 | This is an official PyTorch implementation of SCRL; the CVPR 2022 paper is available [here](https://openaccess.thecvf.com/content/CVPR2022/html/Wu_Scene_Consistency_Representation_Learning_for_Video_Scene_Segmentation_CVPR_2022_paper.html).
3 | 
4 | # Getting Started
5 | 
6 | ## Data Preparation
7 | ### MovieNet Dataset
8 | Download the MovieNet dataset from its [Official Website](https://movienet.github.io/).
9 | ### SceneSeg318 Dataset
10 | Download the annotations of [SceneSeg318](https://drive.google.com/drive/folders/1NFyL_IZvr1mQR3vR63XMYITU7rq9geY_?usp=sharing); the download instructions can be found in the [LGSS](https://github.com/AnyiRao/SceneSeg/blob/master/docs/INSTALL.md) repository.
11 | 
12 | ### Make Puzzles for pre-training
13 | To reduce the number of IO accesses and perform data augmentation (a.k.a. *Scene Agnostic Clip-Shuffling* in the paper) at the same time, we suggest stitching 16 shots into one image (a puzzle) during the pre-training stage. You can prepare the data yourself:
14 | ```
15 | python ./data/data_preparation.py
16 | ```
17 | The processed data will be saved in `./compressed_shot_images/`; see an example puzzle [figure](./figures/puzzle_example.jpg).
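For intuition, the sketch below shows what a puzzle amounts to: 16 shot keyframes tiled into a 4×4 grid, so one file read yields 16 shots. This is an illustrative stand-in, not the repository's `data_preparation.py`; the per-shot size and the use of PIL are assumptions, only the 16-shots-per-puzzle layout follows the text above.
```
# Illustrative sketch of puzzle stitching (not ./data/data_preparation.py).
# The 256x128 per-shot size and PIL are assumptions; only the 4x4 layout
# follows the README.
from PIL import Image

def make_puzzle(keyframe_paths, shot_w=256, shot_h=128, grid=4):
    assert len(keyframe_paths) == grid * grid  # 16 shots per puzzle
    puzzle = Image.new('RGB', (shot_w * grid, shot_h * grid))
    for i, path in enumerate(keyframe_paths):
        shot = Image.open(path).convert('RGB').resize((shot_w, shot_h))
        row, col = divmod(i, grid)
        puzzle.paste(shot, (col * shot_w, row * shot_h))
    return puzzle
```
Since each puzzle is a single image on disk, choosing which shots share a puzzle is presumably how the stitching doubles as the clip-shuffling augmentation without extra reads.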
18 | 
19 | 
20 | 
21 | ### Load the Data into Memory [Optional]
22 | We **strongly recommend** loading the data into memory to speed up pre-training, which additionally requires your device to have at least 100GB of RAM.
23 | ```
24 | mkdir /tmpdata
25 | mount tmpfs /tmpdata -t tmpfs -o size=100G
26 | cp -r ./compressed_shot_images/ /tmpdata/
27 | ```
28 | 
29 | 
30 | ## Initialization Weights Preparation
31 | Download the ResNet-50 weights trained on ImageNet-1k ([resnet50-19c8e357.pth](https://download.pytorch.org/models/resnet50-19c8e357.pth)) and save the file in the `./pretrain/` folder.
32 | 
33 | ## Prerequisites
34 | ### Requirements
35 | * python >= 3.6
36 | * pytorch >= 1.6
37 | * cv2
38 | * pickle
39 | * numpy
40 | * yaml
41 | * sklearn
42 | 
43 | ### Hardware
44 | * 8 NVIDIA V100 (32GB) GPUs
45 | 
46 | # Usage
47 | ### STEP 1: Encoder Pre-training
48 | Use the default configuration to pretrain the model. Make sure the data path is correct and the GPUs are sufficient (e.g. 8 NVIDIA V100 GPUs):
49 | ```
50 | python pretrain_main.py --config ./config/SCRL_pretrain_default.yaml
51 | ```
52 | The checkpoint, a copy of the config, and the log will be saved in `./output/`.
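Before launching, it is worth verifying the environment-dependent fields of the config. A minimal pre-flight check, using keys taken from `config/SCRL_pretrain_default.yaml` (PyYAML is already listed under Requirements):
```
# Pre-flight check of the pre-training config. The keys below come from
# config/SCRL_pretrain_default.yaml; adjust data_path if you skipped the
# optional tmpfs step above.
import yaml

with open('./config/SCRL_pretrain_default.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['data']['data_path'])                      # /tmpdata/compressed_shot_images
print(cfg['optim']['bs'], cfg['DDP']['world_size'])  # batch size 1024 across 8 GPUs
print(cfg['model']['backbone_pretrain'])             # ./pretrain/resnet50-19c8e357.pth
```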
53 | 
54 | ### STEP 2: Feature Extraction
55 | 
56 | ```
57 | python extract_embeddings.py $CKP_PATH --shot_img_path $SHOT_PATH --Type all --gpu-id 0
58 | ```
59 | `$CKP_PATH` is the path of an encoder checkpoint, and `$SHOT_PATH` is the keyframe path of MovieNet.
60 | The extracted embeddings (in pickle format) and log will be saved in `./embeddings/`.
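Judging from how `SceneSeg/movienet_seg_data.py` indexes these files, each pickle maps a movie ID to its shot sequence, and each shot entry holds a `(1, 2048)` embedding followed by a 0/1 scene-boundary label. A quick sanity check of that inferred layout (the file name is an example):
```
# Inspect an extracted embedding pickle. The layout is inferred from
# SceneSeg/movienet_seg_data.py and should be treated as an assumption;
# './embeddings/train.pkl' is an example path.
import pickle

with open('./embeddings/train.pkl', 'rb') as f:
    data = pickle.load(f)         # dict: movie id -> list of shot entries

movie_id, shots = next(iter(data.items()))
embedding, label = shots[0][0], shots[0][1]
print(movie_id, len(shots))       # movie id and its shot count
print(embedding.shape, label)     # expected: (1, 2048) and 0 or 1
```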
61 | 
62 | ### STEP 3: Video Scene Segmentation Evaluation
63 | 
64 | ```
65 | cd SceneSeg
66 | 
67 | python main.py \
68 |     -train $TRAIN_PKL_PATH \
69 |     -test $TEST_PKL_PATH \
70 |     -val $VAL_PKL_PATH \
71 |     --seq-len 40 \
72 |     --gpu-id 0
73 | ```
74 | 
75 | The checkpoints and log will be saved in `./SceneSeg/output/`.
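For reference, `SceneSeg/main.py` does not score a movie in a single pass: after padding, the shot sequence is slid over with windows of `--seq-len` shots at a stride of half that, and only the central half of each window's predictions is kept. A distilled, single-sequence sketch of that windowing (assuming the model is already in eval mode, where it returns softmax probabilities):
```
# Distilled from inference() in SceneSeg/main.py: overlapping windows with
# stride seq_len // 2; only the central half of each window's boundary
# probabilities is kept. main.py pads the sequence first so these central
# slices tile the whole movie.
import torch

@torch.no_grad()
def window_scores(data, model, seq_len=40):   # data: (num_shots, dim)
    stride = seq_len // 2
    probs = []
    for w_id in range(data.size(0) // stride):
        start = w_id * stride
        window = data[start:start + seq_len].unsqueeze(0)
        out = model(window, None).view(-1, 2)[:, 1]          # P(boundary)
        probs.append(out[stride // 2:stride + stride // 2])  # central half
    return torch.cat(probs, dim=0)
```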
76 | 
77 | ## Models
78 | We provide checkpoints, logs, and results under two pre-training settings: with and without ImageNet-1k initialization.
79 | 
80 | | Initialization | AP | F1 | Config File | STEP 1<br>Pre-training | STEP 2<br>Embeddings | STEP 3<br>Fine-tuning |
81 | | :-----| :---- | :---- | :---- | :-----| :---- | :---- |
82 | | w/o ImageNet-1k | 55.16 | 51.32 | SCRL_pretrain<br>_without_imagenet1k.yaml | [ckp and log](https://drive.google.com/drive/folders/1ZYg9PFRU_lt3G5qJrldkguA52T2oxErR?usp=sharing) | [embeddings](https://drive.google.com/drive/folders/1uen_HP3BZu8bcrPBikkgV3j9wzUjQ0C1?usp=sharing) | [ckps and log](https://drive.google.com/drive/folders/1rJbOnVbqTdPmnh2grIkePXOmwpNELnrK?usp=sharing) |
83 | | w/ ImageNet-1k | 56.65 | 52.45 | SCRL_pretrain<br>
_with_imagenet1k.yaml | [ckp and log](https://drive.google.com/drive/folders/1BG5ZLqrPKKGTtDIZj8aps_QuWc6K3c3V?usp=sharing) | [embedings](https://drive.google.com/drive/folders/1NFvGhkvRxpmEJYNjRnwp3ybuHQaG25gW?usp=sharing) | [ckps and log](https://drive.google.com/drive/folders/1dE0JFi-MDua70_CgI1CvyLNRnhwLjaUV?usp=sharing) | 84 | 85 | 86 | ## License 87 | Please see [LICENSE](./LICENSE) file for the details. 88 | 89 | ## Acknowledgments 90 | Part of codes are borrowed from the following repositories: 91 | * [MoCo](https://github.com/facebookresearch/moco) 92 | * [LGSS](https://github.com/AnyiRao/SceneSeg) 93 | 94 | ## Citation 95 | Please cite our work if it's useful for your research. 96 | ``` 97 | @InProceedings{Wu_2022_CVPR, 98 | author = {Wu, Haoqian and Chen, Keyu and Luo, Yanan and Qiao, Ruizhi and Ren, Bo and Liu, Haozhe and Xie, Weicheng and Shen, Linlin}, 99 | title = {Scene Consistency Representation Learning for Video Scene Segmentation}, 100 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 101 | month = {June}, 102 | year = {2022}, 103 | pages = {14021-14030} 104 | } 105 | ``` -------------------------------------------------------------------------------- /SceneSeg/BiLSTM_protocol.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BiLSTM(nn.Module): 7 | def __init__(self, input_feature_dim=2048, fc_dim=1024, hidden_size=512, 8 | input_drop_rate=0.3, lstm_drop_rate=0.6, fc_drop_rate=0.7, use_bn=True): 9 | super(BiLSTM, self).__init__() 10 | 11 | input_size = input_feature_dim 12 | output_size = fc_dim 13 | self.embed_sizes = input_feature_dim 14 | self.embed_fc = nn.Linear(input_size, output_size) 15 | self.hidden_size = hidden_size 16 | self.lstm = nn.LSTM( 17 | input_size=output_size, 18 | hidden_size=self.hidden_size, 19 | num_layers=2, 20 | batch_first=True, 21 | dropout=lstm_drop_rate, 22 | bidirectional=True 23 | ) 24 | # The probability is set to 0 by default 25 | self.input_shotmask = ShotMask(p=0) 26 | self.input_dropout = nn.Dropout(p=input_drop_rate) 27 | self.fc_dropout = nn.Dropout(p=fc_drop_rate) 28 | self.fc1 = nn.Linear(self.hidden_size*2, hidden_size) 29 | self.fc2 = nn.Linear(hidden_size, 2) 30 | self.softmax = nn.Softmax(2) 31 | self.use_bn = use_bn 32 | 33 | if self.use_bn: 34 | self.bn1 = nn.BatchNorm1d(output_size) 35 | self.bn2 = nn.BatchNorm1d(hidden_size) 36 | 37 | 38 | def forward(self, x, y): 39 | if self.training: 40 | x = self.input_shotmask(x, y) 41 | x = self.input_dropout(x) 42 | x = self.embed_fc(x) 43 | 44 | if self.use_bn: 45 | seq_len, C = x.shape[1:3] 46 | x = x.view(-1, C) 47 | x = self.bn1(x) 48 | x = x.view(-1, seq_len, C) 49 | 50 | x = self.fc_dropout(x) 51 | self.lstm.flatten_parameters() 52 | out, (_, _) = self.lstm(x, None) 53 | out = self.fc1(out) 54 | if self.use_bn: 55 | seq_len, C = out.shape[1:3] 56 | out = out.view(-1, C) 57 | out = self.bn2(out) 58 | out = out.view(-1, seq_len, C) 59 | out = self.fc_dropout(out) 60 | out = F.relu(out) 61 | out = self.fc2(out) 62 | if not self.training: 63 | out = self.softmax(out) 64 | return out 65 | 66 | 67 | class ShotMask(nn.Module): 68 | ''' 69 | Drop the shot from the middle of a scene 70 | ''' 71 | def __init__(self, p=0.2): 72 | super(ShotMask, self).__init__() 73 | self.p = p 74 | 75 | def forward(self, x, y): 76 | # keep the cue 77 | B, L , _ = x.size() 78 | y_shift = 
torch.cat([torch.zeros(B,1,1).bool().to(y.device), y.bool()],dim=1)[:,:L,:] 79 | self.mask = torch.rand(*y.size()) >= self.p 80 | self.mask = self.mask.bool().to(x.device) | y.bool() | y_shift 81 | out = x.mul(self.mask) 82 | return out 83 | 84 | if __name__ == '__main__': 85 | B, seq_len, C = 10, 20, 2048 86 | input = torch.randn(B, seq_len, C) 87 | model = BiLSTM() 88 | out = model(input) 89 | # torch.Size([10, 20, 2]) 90 | print(out.size()) -------------------------------------------------------------------------------- /SceneSeg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/SceneSeg/__init__.py -------------------------------------------------------------------------------- /SceneSeg/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.backends.cudnn as cudnn 5 | import torch.optim 6 | import numpy as np 7 | import os 8 | import argparse 9 | import random 10 | import time 11 | from sklearn.metrics import average_precision_score 12 | import shutil 13 | import os.path as osp 14 | from BiLSTM_protocol import BiLSTM 15 | from movienet_seg_data import MovieNet_SceneSeg_Dataset_Embeddings_Train, MovieNet_SceneSeg_Dataset_Embeddings_Val 16 | 17 | def main(args): 18 | setup_seed(100) 19 | model = BiLSTM( 20 | input_feature_dim=args.dim, 21 | input_drop_rate=args.input_drop_rate 22 | ).cuda() 23 | 24 | label_weights = torch.Tensor([args.loss_weight[0], args.loss_weight[1]]).cuda() 25 | criterion = nn.CrossEntropyLoss(label_weights).cuda() 26 | 27 | 28 | optimizer = torch.optim.SGD(model.parameters(), 29 | args.lr, 30 | momentum=args.momentum, 31 | weight_decay=args.weight_decay 32 | ) 33 | 34 | train_dataset = MovieNet_SceneSeg_Dataset_Embeddings_Train( 35 | pkl_path=args.pkl_path_train, 36 | sampled_shot_num=args.seq_len, 37 | shuffle_p=args.sample_shulle_rate 38 | ) 39 | val_dataset = MovieNet_SceneSeg_Dataset_Embeddings_Val( 40 | pkl_path=args.pkl_path_val, 41 | sampled_shot_num=args.seq_len 42 | ) 43 | 44 | test_dataset = MovieNet_SceneSeg_Dataset_Embeddings_Val( 45 | pkl_path=args.pkl_path_test, 46 | sampled_shot_num=args.seq_len 47 | ) 48 | 49 | train_loader = torch.utils.data.DataLoader(train_dataset, args.train_bs, num_workers=args.workers, 50 | shuffle=True, pin_memory=True, drop_last=False) 51 | test_loader = torch.utils.data.DataLoader(test_dataset, args.test_bs, num_workers=args.workers, 52 | shuffle=False, pin_memory=True, drop_last=False) 53 | val_loader = torch.utils.data.DataLoader(val_dataset, args.test_bs, num_workers=args.workers, 54 | shuffle=False, pin_memory=True, drop_last=False) 55 | 56 | train_fun = train 57 | test_fun = inference 58 | 59 | val_max_F1 = 0 60 | is_best = False 61 | test_info = {'mAP': 0, 'F1': 0} 62 | for epoch in range(1, args.epochs + 1): 63 | train_loader.dataset._shuffle_offset() 64 | adjust_learning_rate(args, optimizer, epoch) 65 | train_fun(args, model, train_loader, optimizer, epoch, criterion) 66 | if epoch % args.test_interval == 0 and epoch >= args.test_milestone: 67 | f1, map, acc_all = test_fun(args, model, val_loader) 68 | to_log(args, f'val set: {map, f1, acc_all}', True) 69 | if val_max_F1 < f1: 70 | val_max_F1 = f1 71 | f1_t, map_t, acc_all_t = test_fun(args, model, test_loader) 72 | test_info['mAP'] = map_t 73 | test_info['F1'] = f1_t 74 | is_best = 
True 75 | to_log(args, f'now best F1 on val is: {val_max_F1}', True) 76 | to_log(args, f'test set: {map_t, f1_t, acc_all_t}', True) 77 | else: 78 | is_best = False 79 | save_checkpoint({ 80 | 'state_dict': model.state_dict(), 'epoch': epoch, 81 | }, is_best=is_best, fpath=os.path.join(args.save_dir, 'checkpoint.pth.tar')) 82 | 83 | to_log(args, f'best F1 on val: {val_max_F1}', True) 84 | to_log(args, f"the test set mAP: {test_info['mAP']}, F1: {test_info['F1']}", True) 85 | 86 | 87 | def train(args, model, train_loader, optimizer, epoch, criterion, log_interval=30): 88 | model.train() 89 | for batch_idx, (data, target, _) in enumerate(train_loader): 90 | data = data.cuda(non_blocking=True) 91 | target = target.unsqueeze(-1).cuda(non_blocking=True) 92 | output = model(data, target) 93 | output = output.view(-1, 2) 94 | target = target.view(-1) 95 | loss = criterion(output, target) 96 | 97 | optimizer.zero_grad() 98 | loss.backward() 99 | optimizer.step() 100 | if batch_idx % log_interval == 0: 101 | log = 'Train Epoch: {} [{}/{} ({:.0f}%)]'.format(epoch, 102 | int(batch_idx * len(data)), len(train_loader.dataset), 103 | 100. * batch_idx / len(train_loader)).ljust(40) + \ 104 | 'Loss: {:.6f}'.format( loss.item()) 105 | to_log(args, log, True) 106 | 107 | @torch.no_grad() 108 | def inference(args, model, loader, threshhold=0.5): 109 | model.eval() 110 | corr = 0 111 | total = 0 112 | stride = args.seq_len // 2 113 | result_all = {} 114 | for batch_idx, (data, target, imdb) in enumerate(loader): 115 | imdb = imdb[0] 116 | result_all[imdb] = None 117 | data = data.view(-1, args.dim).cuda(non_blocking=True) 118 | target = target.view(-1) 119 | data_len = data.size(0) 120 | gt_len = target.size(0) 121 | prob_all = [] 122 | for w_id in range(data_len//stride): 123 | start_pos = w_id*stride 124 | _data = data[start_pos:start_pos + args.seq_len].unsqueeze(0) 125 | output = model(_data, None) 126 | output = output.view(-1, 2) 127 | prob = output[:, 1] 128 | prob = prob[stride//2:stride+stride//2].squeeze() 129 | prob_all.append(prob.cpu()) 130 | 131 | # metrics 132 | preb_all = torch.cat(prob_all,axis=0)[:gt_len].numpy() 133 | pre = np.nan_to_num(preb_all) > threshhold 134 | gt = target.cpu().numpy().astype(int) 135 | pre = pre.astype(int) 136 | idx1 = np.where(gt == 1)[0] 137 | idx0 = np.where(gt == 0)[0] 138 | idx1_p = np.where(pre == 1)[0] 139 | idx0_p = np.where(pre == 0)[0] 140 | TP = len(np.where(gt[idx1] == pre[idx1])[0]) 141 | FP = len(np.where(gt[idx1_p] != pre[idx1_p])[0]) 142 | TN = len(np.where(gt[idx0] == pre[idx0])[0]) 143 | FN = len(np.where(gt[idx0_p] != pre[idx0_p])[0]) 144 | ap = get_ap(gt, preb_all, False) 145 | correct = len(np.where(gt == pre)[0]) 146 | corr += correct 147 | total += gt_len 148 | recall = TP / (TP + FN + 1e-5) 149 | precision = TP / (TP + FP + 1e-5) 150 | f1 = 2 * recall * precision / (recall + precision + 1e-5) 151 | result_all[imdb] = (f1, ap, recall, precision) 152 | mAP_all_avg = 0 153 | F1_all_avg = 0 154 | for k, v in result_all.items(): 155 | F1_all_avg += v[0] 156 | mAP_all_avg += v[1] 157 | F1_all_avg /= len(result_all.keys()) 158 | mAP_all_avg /= len(result_all.keys()) 159 | return F1_all_avg, mAP_all_avg, corr / total 160 | 161 | 162 | def setup_seed(seed): 163 | torch.manual_seed(seed) 164 | torch.cuda.manual_seed_all(seed) 165 | np.random.seed(seed) 166 | random.seed(seed) 167 | cudnn.benchmark = True 168 | 169 | 170 | def set_log(args): 171 | time_str = time.strftime("%Y-%m-%d_%H_%M_%S", time.localtime()) 172 | 173 | args.log_file = './output/log_' 
+ time_str + '.txt' 174 | args.save_dir = args.save_dir + 'seg_checkpoints/' + time_str + '/' 175 | 176 | if not os.path.exists(args.save_dir): 177 | os.makedirs(args.save_dir) 178 | 179 | if not os.path.exists('./output/'): 180 | os.makedirs('./output/') 181 | 182 | def to_log(args, content, echo=False): 183 | with open(args.log_file, 'a') as f: 184 | f.writelines(content+'\n') 185 | if echo: 186 | print(content) 187 | 188 | def adjust_learning_rate(args, optimizer, epoch): 189 | """Decay the learning rate based on schedule""" 190 | lr = args.lr 191 | for milestone in args.schedule: 192 | lr *= 0.1 if epoch >= milestone else 1. 193 | for param_group in optimizer.param_groups: 194 | param_group['lr'] = lr 195 | 196 | def get_ap(gts_raw,preds_raw,is_list=True): 197 | if is_list: 198 | gts,preds = [],[] 199 | for gt_raw in gts_raw: 200 | gts.extend(gt_raw.tolist()) 201 | for pred_raw in preds_raw: 202 | preds.extend(pred_raw.tolist()) 203 | else: 204 | gts = np.array(gts_raw) 205 | preds = np.array(preds_raw) 206 | # print ("AP ",average_precision_score(gts, preds)) 207 | return average_precision_score(np.nan_to_num(gts), np.nan_to_num(preds)) 208 | # return average_precision_score(gts, preds) 209 | 210 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 211 | os.makedirs(osp.dirname(fpath),exist_ok=True) 212 | torch.save(state, fpath) 213 | if is_best: 214 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 215 | 216 | 217 | def get_config(): 218 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 219 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 220 | help='number of total epochs to run') 221 | # data 222 | parser.add_argument('-train', '--pkl-path-train', default='', type=str, 223 | help='the path of pickle train data') 224 | 225 | parser.add_argument('-test', '--pkl-path-test', default='', type=str, 226 | help='the path of pickle test data') 227 | 228 | parser.add_argument('-val', '--pkl-path-val', default='', type=str, 229 | help='the path of pickle val data') 230 | 231 | parser.add_argument('--train-bs', default=12, type=int) 232 | parser.add_argument('--test-bs', default=1, type=int) 233 | parser.add_argument('--shot-num', default=10, type=int) 234 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 235 | metavar='LR', help='initial learning rate', dest='lr') 236 | parser.add_argument('--gpu-id', type=str, default='0', help='gpu id') 237 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 238 | help='momentum of SGD solver') 239 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 240 | metavar='W', help='weight decay', dest='weight_decay') 241 | 242 | parser.add_argument('--save-dir', default='./output/', type=str, 243 | help='the path of checkpoints') 244 | # loss weight 245 | parser.add_argument('--loss-weight', default=[1, 4], nargs='+', type=float, 246 | help='loss weight') 247 | parser.add_argument('--sample-shulle-rate', default=1.0, type=float) 248 | parser.add_argument('--input-drop-rate', default=0.2, type=float) 249 | # lr schedule 250 | parser.add_argument('--schedule', default=[160, 180], nargs='+', 251 | help='learning rate schedule (when to drop lr by a ratio)') 252 | 253 | parser.add_argument('-j', '--workers', default=16, type=int, 254 | help='number of workers') 255 | parser.add_argument('--dim', default=2048, type=int) 256 | parser.add_argument('--seq-len', default=40, type=int) 257 | 
parser.add_argument('--test-interval', default=1, type=int) 258 | parser.add_argument('--test-milestone', default=100, type=int) 259 | 260 | args = parser.parse_args() 261 | 262 | # assert 263 | assert args.seq_len % 4 == 0 264 | 265 | # select GPU 266 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 267 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 268 | 269 | set_log(args) 270 | for arg in vars(args): 271 | to_log(args,arg.ljust(20)+':'+str(getattr(args, arg)), True) 272 | return args 273 | 274 | if __name__ == '__main__': 275 | args = get_config() 276 | main(args) -------------------------------------------------------------------------------- /SceneSeg/movienet_seg_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import torch.utils.data as data 4 | import numpy as np 5 | import random 6 | 7 | class MovieNet_SceneSeg_Dataset_Embeddings_Train(data.Dataset): 8 | def __init__(self, pkl_path, frame_size=3, shot_num=1, 9 | sampled_shot_num=10, shuffle_p=0.5,random_cat=False): 10 | self.shot_num = shot_num 11 | self.pkl_path = pkl_path 12 | self.frame_size = frame_size 13 | self.sampled_shot_num = sampled_shot_num 14 | self.shuffle_p = shuffle_p 15 | self.dict_idx_shot = {} 16 | self.data_length = 0 17 | self.random_cat = random_cat 18 | fileObject = open(self.pkl_path, 'rb') 19 | self.pickle_data = pickle.load(fileObject) 20 | fileObject.close() 21 | self.total_video_num = len(self.pickle_data.keys()) 22 | idx = 0 23 | self.shuffle_map = {} 24 | self.shuffle_offset = {} 25 | for k, v in self.pickle_data.items(): 26 | video_shot_group_num = (len(v) // self.sampled_shot_num) - 1 27 | self.shuffle_map[k] = (len(v) - self.sampled_shot_num * video_shot_group_num) 28 | self.shuffle_offset[k] = 0 29 | for i in range(video_shot_group_num): 30 | self.dict_idx_shot[idx] = (k, i) 31 | idx += 1 32 | self._shuffle_offset() 33 | print(f'Train video num: {self.total_video_num}') 34 | print(f'total shot group: {idx}') 35 | self.data_length = idx 36 | 37 | def _shuffle_offset(self): 38 | for k, offset_upper_bound in self.shuffle_map.items(): 39 | offset = random.randint(0, offset_upper_bound-1) 40 | offset = 0 if offset < 0 else offset 41 | self.shuffle_offset[k] = offset 42 | 43 | def _get_randomly_cat_clip(self, idx): 44 | k, i = self.dict_idx_shot[idx] 45 | sampled_len = self.sampled_shot_num // 2 46 | # randomly cat an another clip 47 | data1, label1, _ = self._get_clip_by_idx(idx, sampled_len) 48 | # fix last shot label 49 | label1[-1] = 1 50 | # random the index 51 | length = len(self.pickle_data[k]) 52 | start = random.randint(0, length - sampled_len - 1) 53 | 54 | p = self.pickle_data[k][start : start + sampled_len] 55 | data = np.array([p[i][0] for i in range(sampled_len)]) 56 | label = np.array([p[i][1] for i in range(sampled_len)]) 57 | data2 = torch.from_numpy(data).squeeze(1) 58 | label2 = torch.from_numpy(label).long() 59 | 60 | data = torch.cat([data1, data2],dim=0) 61 | label = torch.cat([label1, label2],dim=0) 62 | return data, label, k 63 | 64 | 65 | def _seg_shuffle(self, data, label): 66 | new_d, new_l = [], [] 67 | clips = [] 68 | # find positive pos 69 | p_index = torch.where(label>=1)[0] 70 | start, end = 0, len(label) 71 | for i in p_index: 72 | i = i.item() 73 | clips.append((start, i+1)) 74 | start = i+1 75 | if start != end: 76 | clips.append((start, end)) 77 | # if the last clip is used for shulling 78 | # the label of the last shot might be changed 79 | label[-1] = 1 80 | clips_len = len(clips) 
81 | index_list = random.sample(range(0, clips_len), clips_len) 82 | for i in index_list: 83 | s, e = clips[i] 84 | new_d.append(data[s:e]) 85 | new_l.append(label[s:e]) 86 | d = torch.cat(new_d,dim=0) 87 | l = torch.cat(new_l,dim=0) 88 | # when shuffling is done, fix the last shot label 89 | l[-1] = 0 90 | return d, l 91 | 92 | def _get_clip_by_idx(self, idx, length): 93 | k , i = self.dict_idx_shot[idx] 94 | offset = self.shuffle_offset[k] 95 | s = self.sampled_shot_num 96 | p = self.pickle_data[k][i*s+offset:(i+1)*s+offset][:length] 97 | data = np.array([p[i][0] for i in range(length)]) 98 | label = np.array([p[i][1] for i in range(length)]) 99 | data = torch.from_numpy(data).squeeze(1) 100 | label = torch.from_numpy(label).long() 101 | # fix last shot label 102 | label[-1] = 0 103 | return data, label, k 104 | 105 | 106 | def __getitem__(self, idx): 107 | if not self.random_cat: 108 | data, label, k = self._get_clip_by_idx(idx, self.sampled_shot_num) 109 | else: 110 | data, label, k = self._get_randomly_cat_clip(idx) 111 | if random.random() < self.shuffle_p: 112 | data, label = self._seg_shuffle(data, label) 113 | return data, label, k 114 | 115 | def __len__(self): 116 | return self.data_length 117 | 118 | class MovieNet_SceneSeg_Dataset_Embeddings_Val(data.Dataset): 119 | def __init__(self, pkl_path, frame_size=3, shot_num=1, 120 | sampled_shot_num=100): 121 | self.shot_num = shot_num 122 | self.pkl_path = pkl_path 123 | self.frame_size = frame_size 124 | self.sampled_shot_num = sampled_shot_num 125 | self.dict_idx_shot = {} 126 | self.data_length = 0 127 | fileObject = open(self.pkl_path, 'rb') 128 | self.pickle_data = pickle.load(fileObject) 129 | fileObject.close() 130 | self.total_video_num = len(self.pickle_data.keys()) 131 | idx = 0 132 | for k, v in self.pickle_data.items(): 133 | self.dict_idx_shot[idx] = (k, v) 134 | idx += 1 135 | print(f'video num: {self.total_video_num}') 136 | self.data_length = idx 137 | 138 | def _padding(self, data): 139 | stride = self.sampled_shot_num // 2 140 | shot_len = data.size(0) 141 | p_l = data[0].repeat(self.sampled_shot_num // 4, 1) 142 | p_r_len = self.sampled_shot_num // 4 143 | res = shot_len % (stride) 144 | if res != 0: 145 | p_r_len += (stride) - res 146 | p_r = data[-1].repeat(p_r_len, 1) 147 | pad_data = torch.cat((p_l, data, p_r),0) 148 | assert pad_data.size(0) % stride == 0 149 | return pad_data 150 | 151 | def __getitem__(self, idx): 152 | k, v = self.dict_idx_shot[idx] 153 | num_shot = len(v) 154 | data = np.array([v[i][0] for i in range(num_shot)]) 155 | label = np.array([v[i][1] for i in range(num_shot)]) 156 | data = torch.from_numpy(data).squeeze(1) 157 | data = self._padding(data) 158 | label = torch.from_numpy(label) 159 | return data, label, k 160 | 161 | def __len__(self): 162 | return self.data_length -------------------------------------------------------------------------------- /cluster/Group.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import numpy as np 4 | import time 5 | 6 | 7 | class Cluster_GPU(): 8 | ''' 9 | A pytorch GPU implementation for K-Means algorithm, 10 | which is used for real-time clutering in SCRL. 
11 | ''' 12 | def __init__(self, 13 | num_clusters, 14 | shift_threshold=1e-2, 15 | max_iter=20, 16 | device=torch.device('cuda'), 17 | debug=False): 18 | self.cluster_func = KMeans_Mixed( 19 | num_clusters=num_clusters, 20 | shift_threshold=shift_threshold, 21 | max_iter=max_iter, 22 | device=device 23 | ) 24 | self.device = device 25 | self.debug = debug 26 | 27 | def __call__(self, x): 28 | dimension = len(x.size()) 29 | x = x.to(self.device) 30 | B = x.size(0) 31 | output_vector = x.clone().detach() 32 | # D == 2 33 | if dimension == 2: 34 | _, choice_cluster, choice_points = self.cluster_func(output_vector, debug=self.debug) 35 | # D >= 3 36 | elif dimension == 3: 37 | choice_cluster_list, cluster_points_list = [], [] 38 | for batch in range(B): 39 | y = output_vector.narrow(dim=0, start=batch, length=1).squeeze(0) 40 | _, choice_cluster, choice_points = self.cluster_func(y, debug=self.debug) 41 | choice_cluster_list.append(choice_cluster) 42 | cluster_points_list.append(choice_points) 43 | choice_cluster = np.stack(choice_cluster_list) 44 | choice_points = np.stack(cluster_points_list) 45 | else: 46 | raise ValueError('Dimension of input must <= 3, got {dimension} instead') 47 | return choice_cluster, choice_points 48 | 49 | 50 | class KMeans_Mixed(): 51 | ''' 52 | This version uses GPU for tensor computation and 53 | CPU for indexing to improve the speed of the algorithm. 54 | ''' 55 | def __init__(self, 56 | num_clusters, 57 | shift_threshold, 58 | max_iter, 59 | cluster_centers = [], 60 | device=torch.device('cuda')): 61 | 62 | self.num_clusters = num_clusters 63 | self.shift_threshold = shift_threshold 64 | self.max_iter = max_iter 65 | self.cluster_centers = cluster_centers 66 | self.device = device 67 | self.pairwise_distance_func = pairwise_distance 68 | 69 | def initialize(self, X): 70 | num_samples = len(X) 71 | initial_indices = np.random.choice(num_samples, self.num_clusters, replace=False) 72 | initial_state = X[initial_indices] 73 | return initial_state 74 | 75 | def __call__(self, tensor_input, debug=False): 76 | if debug: 77 | time_start=time.time() 78 | 79 | X = tensor_input 80 | X = X.to(self.device) 81 | choice_points = np.ones(self.num_clusters) 82 | # init cluster center 83 | if type(self.cluster_centers) == list: 84 | initial_state = self.initialize(X) 85 | else: 86 | if debug: 87 | print('resuming cluster') 88 | initial_state = self.cluster_centers 89 | dis = self.pairwise_distance_func(X, initial_state, self.device) 90 | choice_points = torch.argmin(dis, dim=0) 91 | initial_state = X[choice_points] 92 | initial_state = initial_state.to(self.device) 93 | iteration = 0 94 | status = 0 95 | while status == 0: 96 | # CPU is better at indexing, so transfer the data to the cpu 97 | dis = self.pairwise_distance_func(X, initial_state, self.device).cpu().numpy() 98 | choice_cluster = np.argmin(dis, axis=1) 99 | initial_state_pre = initial_state.clone() 100 | for index in range(self.num_clusters): 101 | selected = np.where(choice_cluster == index) 102 | selected = X[selected] 103 | initial_state[index] = selected.mean(dim=0) 104 | dis_new = self.pairwise_distance_func(X, 105 | initial_state[index].unsqueeze(0), 106 | self.device).cpu().numpy() 107 | culuster_pos = np.argmin(dis_new, axis=0) 108 | # a cluster has at least one sample 109 | while culuster_pos in choice_points[:index]: 110 | dis_new[culuster_pos] = np.inf 111 | culuster_pos = np.argmin(dis_new, axis=0) 112 | 113 | choice_points[index] = culuster_pos 114 | initial_state = X[choice_points] 115 | 116 | 
center_shift = torch.sum(torch.sum((initial_state - initial_state_pre) ** 2, dim=1)) 117 | 118 | iteration = iteration + 1 119 | 120 | if center_shift **2 < self.shift_threshold: 121 | status = 1 122 | if iteration >= self.max_iter: 123 | status = 2 124 | 125 | if debug: 126 | print("Iter: {} center_shift: {:.5f}".format(iteration, center_shift)) 127 | 128 | if debug: 129 | if status == 1: 130 | time_end=time.time() 131 | print('Time cost: {:.3f}'.format(time_end-time_start)) 132 | print("Stopped for the center_shift!") 133 | else: 134 | time_end=time.time() 135 | print('Time cost: {:.3f}'.format(time_end-time_start)) 136 | print("Stopped for the max_iter!") 137 | return initial_state, choice_cluster, choice_points 138 | 139 | # utils 140 | def pairwise_distance(data1, data2, device=torch.device('cuda')): 141 | data1, data2 = data1.to(device), data2.to(device) 142 | # N*1*M 143 | A = data1.unsqueeze(dim=1) 144 | # 1*N*M 145 | B = data2.unsqueeze(dim=0) 146 | dis = (A - B) ** 2.0 147 | dis = dis.sum(dim=-1) 148 | return dis 149 | -------------------------------------------------------------------------------- /cluster/cluster_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from Group import Cluster_GPU as Cluster\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import torch\n", 12 | "import random" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "num_clusters = 3\n", 22 | "cluster = Cluster(num_clusters=num_clusters, \n", 23 | " max_iter=30, \n", 24 | " shift_threshold=1e-3, \n", 25 | " device='cuda',\n", 26 | " debug=True\n", 27 | ")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "torch.Size([5, 300, 2])\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "num_point = 300\n", 45 | "batch = 5\n", 46 | "test_tensor = torch.cat([\n", 47 | " torch.randn(batch, num_point//3, 2) + 4, \n", 48 | " torch.randn(batch, num_point//3, 2) - 4, \n", 49 | " torch.randn(batch, num_point//3, 2)],\n", 50 | " dim=1\n", 51 | ")\n", 52 | "test_tensor = test_tensor.cuda()\n", 53 | "print(test_tensor.size())\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "Iter: 1 center_shift: 32.28633\n", 66 | "Iter: 2 center_shift: 6.91954\n", 67 | "Iter: 3 center_shift: 0.12403\n", 68 | "Iter: 4 center_shift: 0.00000\n", 69 | "Time cost: 0.025\n", 70 | "Stopped for the center_shift!\n", 71 | "Iter: 1 center_shift: 13.87646\n", 72 | "Iter: 2 center_shift: 4.93253\n", 73 | "Iter: 3 center_shift: 0.00000\n", 74 | "Time cost: 0.008\n", 75 | "Stopped for the center_shift!\n", 76 | "Iter: 1 center_shift: 3.94572\n", 77 | "Iter: 2 center_shift: 0.00000\n", 78 | "Time cost: 0.006\n", 79 | "Stopped for the center_shift!\n", 80 | "Iter: 1 center_shift: 5.07124\n", 81 | "Iter: 2 center_shift: 1.16594\n", 82 | "Iter: 3 center_shift: 0.00000\n", 83 | "Time cost: 0.008\n", 84 | "Stopped for the center_shift!\n", 85 | "Iter: 1 center_shift: 14.44800\n", 86 | "Iter: 2 center_shift: 3.34260\n", 87 | "Iter: 3 center_shift: 0.21062\n", 88 | "Iter: 4 center_shift: 0.00000\n", 89 | "Time 
cost: 0.011\n", 90 | "Stopped for the center_shift!\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "cluster_ids_stack, cluster_centers_stack = cluster(test_tensor)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "(300, 2)\n", 108 | "(300,)\n", 109 | "(3,)\n" 110 | ] 111 | }, 112 | { 113 | "data": { 114 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAx90lEQVR4nO2de5RU1Z3vv7+qU011R0jCSAZfBHQ6PgLXRxHUPMhDFO6ocWat9Ag35Epkhkwy08vJdW7So8u4WN5k9b1JdLidrFyYkMEbnca0OpqHI4KaEHMVpOIDBJ1WRNSWkUwngxGhu6r2/WP36T7n1D6Pqjr1/n7W6tVdp87ZZ1ct+O7f+e7f/m1RSoEQQkjzkqh3BwghhFQGhZwQQpocCjkhhDQ5FHJCCGlyKOSEENLkWPW46Yknnqjmzp1bj1sTQkjTks1mf6OUmuU9Xhchnzt3Lnbt2lWPWxNCSNMiIq+YjtNaIYSQJodCTgghTQ6FnBBCmhwKOSGENDkUckIIaXLqkrVCCKkP2ZEshkeH0T2zG5mTM/XuDokJCjkhbULftj4M7BxAUpLIqzx6F/Wif0l/vbtFYoDWCiFtQHYki4GdAzg6fhRvjb2Fo+NHMbBzANmRbL27RmKAQk5IGzA8OoykJF3HkpLE8OhwnXpE4oRCTkgb0D2zG3mVdx3Lqzy6Z3bXqUckTijkhLQBmZMz6F3Ui65UF6Z3TEdXqgu9i3o54dkicLKTkDahf0k/es7pYdZKC0IhJ6SNyJycqYmAM82xtlDICSGxEneaIweFcCjkhJDYcKY52gzsHEDPOT1liTBz36PByU5CSGzEmebI3PfoUMgJIbERZ5ojc9+jE4uQi8h7RORuEXleRPaJyMVxtEsIaS7iTHNk7nt04vLI1wF4UCn1GRHpANAVU7uEkCYjrjRHe1DweuSc8CxGlFKVNSDybgBPAzhdRWxs4cKFint2EkKiwKyVKUQkq5Ra6D0eR0Q+D8BhAP8oIucCyAK4Tin1dgxtE0LanFrlvjczcXjkFoALAHxPKXU+gLcB9HlPEpE1IrJLRHYdPnw4htsSQggB4hHy1wC8ppTaMfH6bmhhd6GU2qCUWqiUWjhr1qwYbksIIQSIQciVUocAvCoiZ04cugTA3krbJYTEQ3Yki817NjP/uoWJK2ulF8CdExkr+wF8PqZ2CSEV0G4rI9t1YjQWIVdKPQ2gaCaVEFI/4l4uX0vKEeR2G7ScsNYKIS1K2MrIRo1cyxHkZh604oBCTkiL4rcy8sEXH8TqvavrHrmaou4oguy8DtAD0v7f7vcdtCjkhJCmxbQysuecHgztHap75OqNunvO6cGyP1oWKsjO694ZfwcQoNPqxPH8ceQKOdd11VjO36gePIWckBbGu1x+eHQY9+6713VOrSNXU9R9+zO3Y+i5IRRQQC5vFmTTdVDAW2Nvuc5PJVJIJVOxL+dvZA+eQk5Ii+NdGRlHIapKIlOTdw8AR3NaoJOSRNpKI5VIueqrbN6z2XidFxHB+svXY+W5K0vqVxCN7sFTyAlpI+IoRFVpZGry7p3kVR4fOfkj+OKHvugaKLpndms7JYRpyWmwklPSZg86uXwOVtKKbfBpJA+eQk5Im1FJdcK4ItMl85bgof0PISEJt1Uywc6Rnbh15q3FbQqAkNJ8zicMe9AZz49jvDCOjkQHrKQVy+DTSCV1ubEEIS1CKSs4MydnsHz+8rIi0/H8uOvYeH488mYPfdv6sHjTYjx64FEAwCXzLsHiOYuLzkslUkVtDo8Oo9PqdB3rSnXhgpMuQNpKF9U/dw464wXd57HCWFk7DcVZZ70aMCInpAWo1URcLp+bFEWb8cJ40QSliexIFut2rMOx3LHJYw+//DDWX74eO0d2uo6bol0/S2bDFRsAFOfF+3nxQHm2SFx11qsBI3JCmpxa7m1pJS10JDrcxxIWdo7sNN7P+ZSw9hdrXWINYNK3vu7C63yjXbsNAK6oOG2lccm8SwCYnzC6Z3YXDTo25doi5T7JVBtG5IQ0OdWYiPPLSume2Q0raWGsMDZ5LFfIYeOvN2LjUxtdTwLOp4TxwniRJQNoqyOXz/lGu6Ynje2rtmPtL9Zi6/6t+PmBn+Phlx82PoEM7R0qyi13euS1FONq559TyAlpcuKeiAuyaZxZLwAmJyrt1EF74tP+2zSR6SQpyckME2+apN/E6vxZ8/Hwyw/jWO4YjuGY677eFaJOIe9IduDmj9+MpWcsrevip2rYXrRWCGly4pyIi2LT9C/px/ZV27H6/NXostzb89pPAkH+tJNUMuU74Pg9aewc2RlYQ8bv2mnJaTj9vafXPBKvhe3FiJyQFiCuibioNo3998anNrrOdT4JeJ8SLLGgoCaPW2Lh6g9ePSnA3j77PWksOnlR4H39rh0vjNc8XbBW+eeMyAlpEeKYiCvVprlk3iXG1D/TU0Lm5AxSyZT2qRMWzjrxLAztHcKan6zB4k2L0bfNvUOk35PGynNXhj6B2NdaMhWr5go5DO0divxdxLEhR63yzyXixvexsnDhQrVr166a35cQEk4UT9c7kXnZ6Zfhax//WlGFQjsrJZfP4Qs/+0KgZ96V6sL2VduLBiK/icKgCcTsSBZbXtqCtb9Yi7H81MSs3z3K+Q6iEmdbIpJVShXt/UAhJ4QUESSeW17aglu23+JKJXQKpEm4zpt9Htb8ZE1RgSsn0zumY8OVG7B8/vKK+m7fv1Ao4Fjene4Y5R7ZkSwWb1rsGnSiDgBBbcaRteIn5PTICSFFeDNIgGCBdE42mjJN1l++PrC+ChCP5WCskFjiParha5u+zzihR04ICSQ7ksU3fvkNrNuxDkfHjxaJODAlkH4iaOduO33tC0+5EF2pLpyQOiG2Je9+2TJpK+270MjpgWdHstj/2/1FC4kaqa6KCUbkhBBfJqNwVShalQlogUxIwiWQ7+TcFQrfyb2D7pndWD5/eVFmzX377sPye5bjrs/chavOuqri/pomF9NWGjctvsmVP26yfw79/hAG9wzCSljIFXKwxEJnqrOsCpG1hkJOCDESZlOYBBJAcXVCx2uvxTD43CDG8+MY
3DMYi5D7lem94WM3BH6ub/3qW8hDDwD25GjaSqPvo301X0BUDhRyQogRX5simUYikXAJpD2Zt/+3+9GZ6nRNanamOo3+8oujL+LHL/wYBRRw/wv346XRl3DGzDMm3y93gjAsp970uWwRd5KQRM0XEJVLbEIuIkkAuwC8rpS6Iq52CSH1wWRTpBIpXHv+tbj2/GuNNoWpEqKfv3zjIzdO1l8Zz4/jxkduxObPbC5qMyhlz0/sgyYXwza2sMkVck2z52eck53XAdgXY3uEkDriXJDjrHi46ZlNkwtrvEvQj+WOQUEZFwk5eXH0Rdyz955JQc2rPO7eezdeGn0p8rJ2u7a534KiKJ/LrqJoJYpj2hXzV8S+52c5/Y1CLEIuIqcCuBzA9+NojxDSGPQv6cf6y9frnXmgl7k7hdVoU0wso99w5QZsX7W9KJLOjmTxp3f9qXHF45d+9qXA9D9nG5XUMLHrxWy4cgMe+/xjuP7i69GV6kKX1YWOZAeuOfcabPqTTRG/pXCqXXMlLmvl7wF8BcB0vxNEZA2ANQAwZ86cmG5LCKk2VtLCtOQ01wpJW1j9an77bdXWt60Ptz5+q2+d8EcPPIovfehLocva48j1dtovmZMzRb56nDZItWuuVByRi8gVAN5USgUOLUqpDUqphUqphbNmzar0toSQKuHNrw6qF5I5OYNLT7+0qA3TVm3ZkSy++f++6SviAJAv5DG4Z3DS+uhK6QjZuydoNWqYOGvVxG2DVLvmShzWykcAfFpEDgDYDOBTInJHDO0SQmqMScDCyuTe/PGbkbbSrnaO5Y4VTXr+4KkfoKAKgfcvoIC7996Nv7jgL9BzTs9kPvfQ3iGXmFZzD81q2CDV3vMz1lorIvIJAH8blrXCWiuE1IawwlJeKyGoxkhQW33b+vDtx7/t2sjBSli4/uLrJz3y5Xcvx13P3RWp3xedchGeffPZ0Hon1cgC2bxnc1FdmKAaLaX0odL+stYKIW1GUAqfX2Er72Iep48blNI3f9Z8eIPCXCGHdTvWTdois981O3Lfnxx5sijK96uL7tencgW2FBvE7zsuJy2yEmIVcqXUzwH8PM42CSGl47dNmt82bAM7B/CpuZ+a3LLNxitgJoFadd8q3PHsHb652bb47nh9R+T+m0oClOIpl1I61nSuaXWo6YnG9D0e+v0hDO0dqurWbl4YkRPSgoSl8JlWbD60/6GiY85JRlMN8oIq4KfDPw3sS/fMbhwdP4qdIzsj9z+VTGE8P45OqxNWwopU78RZA91vEIsqxttXbQ/dccm48lUBg3sGXRk+fveOEwo5IS1ImD3gLWx1PHdcpxjCvQnDsj9aBsAseD/+1x8H9iEpSVz9wasnBez1//Y61v58LTY9s2kyWr28+3L8bPhnrug7baXxo8/8CB3JDmTfyOL1I6/jlBmnYOkZS33v5RxkjuePB1pEToIGPDuDxc7i8Qq66TvOqRyshGVM1aSQE0JKwq94lC1MRYWtoDNGvOTyOWzesxn7f7s/0mbKACAQJCWJjmQHhvYOYfYJs9G/pB+zT5iN713xPfz5BX/uinRN1savXv0VBnYOIJfPYawwhlQiha//8utGmyKsuBfgb8uEDXir7ls1WRERgOv+pu+455yeou3kalEClzsEEdLCmDxtv6yMT879JLa9vG1SlBa8bwF2v7l7qoZKIefKSvEjiaSrCFWU3XWc/QRQlD1jk7bSeOzzj7naMn2eVCIFEcG05LSyPPL+Jf1Ydd8q3P7M7a5zo2TOxLm1mxdu9UYIARC8lRmASZ/Zu8emJRYSiYTLNrBJW2kUVAEXnXIRnjr0VOTUPRMmYXby6Q98GvevuD/086y/fD2spFVWWmB2JIsP/+DDRZ+1K9WFjZ/eGPpZqlUci+mHhLQRQUISZLvY72/es7nISulMdaLvo3144rUnsHX/VqQSqUk7YdkfLXNF005KtRa6Z3YXefhOHtr/ELIjWVd/7c8DpX3qnnN6sPLclZHv6U0LHB4dhiWWa84AiF4Rsdpbu3mhkBPSYkR5tA+r2e3nHS89Yylu+NgNgQNF76JerNuxTvvwgvJWMAYYBfbyf2eb/Uv6XTv8OL35cuie2T1ZKMxJ3BUR44J7dhLSQpSyvNxZW8T0XtCS8qBrJzEIYRSGR4ch4n/x8fzxouX/2ZEshvYOYSw/hqPjRyteVu/8/NWqiBgnjMgJaSHirLIXFrWbsAcSZzphqXnUuXzOWFjLnkRVSuELP/sC9hzeMxlxV6O6YDmfv15QyAlpIeKusleq1xuHoFpJCx2JDowV3LnYAACla6KPF8ZdA0S1qgvW2usuF1orhLQQ1a6y56WUkrdR6Z7ZDSvpjjETksC05DTXMedK1Vp/7kaDETkhLUatLAG/SdUodUqCsEXZWU2xoApFqYDeAaKZrJC4YR45ISSUOEveRm3/o//4UZfXbiUsWAlrMu3RmY0Tdr8487qrlSMeBeaRE0LKwq/kbZAXXoq37Nd+wuP8dlo6j/30957uEtGwdMs4V1pWc9VmJdAjJ4T44pfOmMvnYplczI5ksW7HuqL273z2TmNJ3aVnLHUVs/rGL79hvN727OPc7afaGyhXAoWcEDKJd/LSLwvFSlqxTC6u/cXaorrjUNFK6i7etBi3/OKWouudk6Bh5XxLIc624obWCiEEgNk26DmnpyjythfkVDq5mB3JYuv+rUXHxwvjSCVTrsnNsJK6TpxPBnGmJVZ7A+VKYEROCPG1DQBMRt4diQ4AmFyQY2/MHLrC04fh0WGkEqmi4xeferHxfFswjRs6QBfuMq1AjSstsZFTHBmRE9LmZEeyulyrz2YM/Uv6MX/WfKz+yWoA5gU55WCKcNNWGrcuvRVDe4d8Uxj9rrtp8U1YesbSov7EmZbYqCmOFHJC2hjbTgEQuF+nlbT0DkIx7nwTVIUxc3LGVzD9rrvhYzcE3isu0W3E1Z4UckLaFD+vucvqKqpaWC1/OCjCDRLMRo2M6wWFnJA2xeQ1d6W6sPr81bjm3GsiRcFxCGi5EW4jRsb1omIhF5HTAPxfAH8I7bJtUEqtq7RdQkh1MUXZAIpE3IZRcOMSR0SeA3C9UurXIjIdQFZEtiql9sbQNiElkc0Cw8NAdzeQCdCZqOe1MuVE2YyCG5OKhVwp9QaANyb+fktE9gE4BQCFnNSUvj5gYABIJoF8HujtBfoNq6ejnlcPaj3AMMpuDWItmiUicwFsBzBfKXXE894aAGsAYM6cOZlXXnkltvsSks0CixcDRx3zdl1dwPbtbkGMel49aOQBplLqWWiqlfArmhXbgiAROQHAPQD+xiviAKCU2qCUWqiUWjhr1qy4bksIAB3FJj1rRJJJfbyc82pNNqtF/OhR4K239O+BAX282bGX06/5yRos3rQYfdv66t2lliMWIReRFLSI36mUujeONgkphe5uHcU6yef18XLOqzWNOsBUSiMXmmolKhZy0bukbgSwTyl1a+VdIqR0MhltRXR1AdOn69+9vcV2SdTzak0tBxhvYaxq0siFplqJOLJWPgLgcwB2i8jTE8d
uUEo9EEPbhESmvx/o6QmfLIx6Xi2xBxivRx5332pdT7uRC021EtwhiJAGoppZK2G7+lSLRt2MoRnhDkGENAGZTHGWTVzCHscO9+XAFMfqQyEnpEGJOx2xnjYHFxJVF9YjJ01HNgts3twaqXl+VCMdsZHraZPKYEROmgq/KLXVltwHpSNW8vloc7QmFHLSNDijVJuBAeDQIWBoKJoF0SyCX810RNocrQetFdI0mKJUABgcjGZB9PXp5flr1ujffSUsMKy1ndOo+e6kMWFETpoGU5SaywGWBYxNbVzjWhFpR99AcTS/bh0wYwawdKkWSL9ovV41UBox3500JswjJ02FV1R7erSt4i2CZR+3z1uyBHj0UR21O0mngUQCWLAA2L3b7L03apGtZrGJSHwwj5xUlVqJiilKnT07XNwfesjc3rFj+veOHe7jt90GzJ+vo/1qTDpWSitXSiSlw4icVEwjiIpzIBke1j64M/qePh34xCeAhx8GCoUpAQ+iowNYscIc8UeNyKsxwDXyUwKpLlUvY0vakzvu0NFrvcuvZjLA8uX6t1/Gx803a7G76SZtqYQxNqZFvKcneNLRbyJ01Srgwx8GVq8ufXI1iFatlEjKh9YKKZu+Pi3izolGoP7WQ1gBqkwGOHLE/f6CBcBTT5k/y7Jl+npnZG1H2g8+aE59vPJK4Kc/1W3YbQ4M6EGh0u+lUUvxkvpBa4WUhenx3qZRHvPDbA37fTvzZXgY+B//wy3mps9iW0lA8efv6gI+9akpEfe+t3GjfnKolEaws0jt4WQniRW/nO6OjsbJd/YWoLJxCvzTT7sF8fzzi7NXvDaKN43Ry4MPmo/ncvFFzUxNJE4o5KQsuruBd95xH0smdcS5cmVpbVVrQjAsJ3x8XItrLjf1/u7dwPr1OkI39cdvALPJ5XQ6o4kVK+IVXL+BirQfFHISGyLA2WeXdk1Ui6AUsQ+qxxIWTSeTWsT97A+TP+3ktNOAl14qPn7FFcCmTcH9JqRcmLVCymJ4GOjsdB/r7CwtcyJqxkspS+uDqgaGRdOAedLQmZViT6R2dBRfm04Dr75afPyKK4Cf/CT4voRUAoWclEWlmRN9fTotzy/jxabUcq5BqXmmPluWFmC/1EJTCmF/v7aQvGJeKOj2nHR1AZ/9bOjX4Us7lOwllUMhJ2VRSVEnW5y9Ig4UDwal5kwHDTCmPl9/PfDYY8CGDTo7xWnrfOxjwO23634ePeoeRFauBL78ZXdbK1b496kcKinyRdoLeuSkbMrNnCgl46XUyN8W63Xrpo452+zp0YWygKliWfZ1ThYsAPbs8e9/JhOtXEC5GTx+JXtLyUNnLZb2gUJOKqKczAmTOHd0TGW8eAUojt3ls1lg7Vpg61YgldLtHDlibueOO/xF3JtC6P38TnG389Ntb70UKt1YgnnmbYZSquY/mUxGkfbmq19VqqtLqenT9e+vfjX4+K5dSg0O6t82fse6upQCpn4sS6mODvcxQJ/nvNamt7f4XPvnmmsq+3xh2J/phz8s/hx+/TW1Ue61pLEBsEsZNDWWiFxElgFYByAJ4PtKKY79JBCTLRHFTrC98aEhc8RpimSdeeJOCgVgy5biCHfRIvP5H/1otBTCUm0Rv+X+ptK6UaLxam0TRxqXioVcRJIAvgvgUgCvAXhSRH6slNpbaduktfHaEkEC5BRu00IeWyjD8rydHDsG3HKLtlictsPKlcB3vuMubZtIuEU8yH8uRUhtC8RUkTFscZIfrMXSfsSRtbIIwItKqf1KqTEAmwFcFUO7pM3wE6Bczp2CeOxYcZTtFMreXp1SaMr19nLsmDmd8YkngB/+ULd14YX62I036t9h2SS5HHD8ePHnMOWnr1unP5eprK5zcVIpkXQmowe1jo6p7yGOYl2kcYnDWjkFgHMZxGsALvSeJCJrAKwBgDlz5sRwW9JqeCc2x8eBSy4BDh6MtpAnl9M514cO6WN+9eBE3O/5RcsrVwIXXaQtjkIBuP9+4NvfLq746Nwyzn5ysEml9I/JFlm7NrguerlRdF+f7kehoPtpWfr17Nmc8GxZTMZ5KT8APgPti9uvPwfgO0HXcLKTBLFrl1JXXqlUOq0nC9NpPWHpnLxLJpVKpfQkXleXUhdeOPW330QloCc90+noE4F/9mf6XoBSIvrH1G46be5nR4eeuDR9Rm8/vNdFnSD1tuv3HXDCs/lBFSc7XwdwmuP1qRPHCCmbhx/W0aodsdorMFOpqWJdqZSOwi+7DHjkkfAaKtOm6cgYKN4azp5EdUbNL74I3HPPlN0TVPHZL7KeNq14tSeg75dKma9zpmJ6CcsNDypDwAnP1iUOIX8SQLeIzIMW8OUA/ksM7ZI2xSRGnZ1TXvQtt7h98oceMoulk0RCTxza4miLt50pcu+9xfnWX/pS9IlTP/zsEVP1SECL+4oVZhGPkhseNNnLCc/WpeLJTqVUDsBfA9gCYB+AHymlnqu0XdK+mMRofHzq71TK/Z5l+acY2qTTbrG3t4Sz9+P01nG57z4d5ZdKWO2WMJJJ3SfvBGrUmjPOMgT299TRUV5fSPMQSx65UuoBAA/E0RYh3knPd97RQt3fP5V66OWyy8y78tiYolG/NMG1a4EHHigvGs9kgO9+N3xpvF090rlBNDBltZjy56OmNJpWl3KZfmvDJfqkIbHFaMuWKSvFFj2nX25bDOedZ/bJ02ltq5iiUb/If8uW8i2V3buBffvCrZ6wfHevSJdTc4bC3T5QyEnDksmYJwU7O4HPfU6L3aJFU/VZvKTTwE03uYtjedtfsMC98OfEE4E33ii/z2NjuuTttGlmH9s5WWk/dQDFA5BXpOOqOUNaFFMqS7V/mH5IouJXO8VOTYxSp6WUtuP+cab8mfpn11a55ppofTfVlyHtA3zSD0UF5VRViYULF6pdu3bV/L6kOQnbZ9O5030ppVs3b9bRc1DaYqVMn65rnXd361Wgzns5+w2w7CwJR0SySqmF3uO0VkjD45y8279fv3ZOEjr95FK84e7u8GwXPzo69LWFQvB5tkUyPOzOvAH0a6cPTl+blAt3CGpnRrPAgc36d4OTyeiaI0uXxlcQKpPx39XHJpEortliZ490dupJzYThf1E67U75y+XMQl7uQEKIEwp5u/J0H7B1MbBzjf79dHPsI1bJFnMmenuD67gkElrs7ful07pWy9gY8Pbb5qjcnmR1bh1nWcUDQkdHeHYLIVHgP6N2ZDQLvDAA5B2G7QsDwJweYGbjP9uXu8WcieFhLdLefG6bzk5g2TIt+H7WjperrwZuuMF9rLtbi7az2Jad301IpTAib0eODAPiCUMlqY83CbbVUqmnHJbP7dy42c/acdLVpYXf1N84nyQIcUIhb0dmdAPKo0Yqr483ONmszjYx5Y2Xg1dgLUv/+Imt8/yuLnObflF2f7+2WzZscNsuhFQK0w/blaf7tJ0iSS3iZ/YC51WuLNVMoavmhsLOfgPmv52fx297Nm5yTKqJX/ohhbydGc1qO2VGdyzeeKVCGzQIZLPhedhx4OyD376gpfSbkDjxE3
Ku7CSxUOnO7WGrMgcH9Xve9nt741vl6OyDaZMIbsxA6g18VnbSIyexEFSdL4woJVpNk5JHjwL/8A/Ahz8MrFpVWf+9fQjaF5SQRoNCTmKhkp3bowwCfpOMx47plL7bb69MzIN21rHhxgykUaGQk1ioJL0u6iBgZ32sXq0X3XgZHCw/m8XUh0o3iSCkVnBBEImNchfqlFKi1T62fn3xe5ZV/p6Ufn2Ia+ERIdWEWSuNTBxZJTFnplSTUrI/Vq3SdoqTOLJYmIFCGhlWP2w24sjzrlKueLUopfrfpk369+DgVL2SOKwPViAkzQgj8kZkNKsLWTlroSS7gEu3R4+q42ijCWAETdoJRuTNRFAtlKgiHEcbTUCjRdAcWEg9oJA3InHUQimljSby0etFFIGuZgkBQoKoKP1QRL4pIs+LyLMi8s8i8p6Y+tXezMxoPzvZBVjT9e8ze0sT2ahtNGld8lrS16fLA6xZo3/3Gb6iKIuaCKkWFXnkInIZgEeUUjkR+Z8AoJT6ath19MgjUu2slSAfHWCUjug1XjZv1kLvrFNu79e5fHnt+hsHtIcal6p45EqphxwvnwDwmUraIx5mZioX0aA2/Hz03WuBQw+Xlu3SovZM0KpTp8iZFhQdP958W7nRHmpO4lzZeS2Af/F7U0TWiMguEdl1+PDhGG9LysbkoxfGgTe26ig995b+/cJA8L6eNbRn4q5HHkbUVafOla2p1NTxL3zBbMU0IrSHmpdQIReRbSKyx/BzleOcGwHkANzp145SaoNSaqFSauGsWbPi6T0JJmxzZZOPftJlQCLlPi9o9yDntnFRhb9MonjVcVNK6YH+fr3iVES/HhtrLjGspPAZqS+h1opSaknQ+yKyCsAVAC5R9UhKJ2aiLgY6r1/v1WnbIgBwaJv7nKCMmRqlOTqjRZuBAb2Evto+bimlBywLmDbNvTenyYppRCopfEbqS0UeuYgsA/AVAB9XSh0NO5/UCNPmys+vA1IzgJOWFgus10c/s7d4EPAT5RptGxfVq64WUfPVm1kMS6l5QxqLSvPIvwNgGoCtop8nn1BK/WXFvSKVYYqSC8eA3bcAe74ePnnpjdJtETdNaNr2TFThL5NmEchmF8NyC5+R+sIl+q2IKa3QSTlL9cOsmhpkrdgZFYDOBlmxYqrmSqPBFD5SDfzSD1mPvBVxTmImDIW7gyYvTUSZ0JyZAeYur2rqoR0t5nLaix4aatyMkExG549TxEktoJC3Kuf166h7/k3FYm7ysIMyXIImNGtINqvF284GaaaMkKjUOr2StAastdLK2JOYuSPBHvbjq4ADg0Bi4p+D1zap0YRmGPWe8Kw2XIxDyoUReaMSlgNeCnZ0vmiD/u0U6cdXAS/fDqgxbZn42SaV1n6JgWaZ8CwHLsYhlcCIvBGpxoYQpqX6o1nglUHz+d48cL9MlhrS7BkhQbT60wapLhTyRsOUA/7CgBbRuMXzyDAgFoAx9/FCzmybxFH7pUJaNT2ulZ82SPWhtdJoVGNi0c+m8fO4566ou2AH0YoZIaWUAiDECyPyRsNYyOo4oMoso2fbNICOtOeuAC7epF87F/MA+h7vd7wfRItWO6wnrfq0QaoPFwTVglJFzxbfQk5PQiKpC1mddV1pXrnfwqB517jFutz+NcmmzoS0ClwQVC/KKfF6Xj+waL0jCs/rJfZ7v11aFoufHfPKYPmLeWpY7ZAQEg0KeTUZzepiVeWI3tsHARQ8B3PAyJbo95/RreuLexGrfM+9QRYHEUKmoEdeTZ5dqyNpJ94Sr362hvi06T0eZIscHAKUdzCYoJTFPM57RF0cRA+dkJpBIY8Dk2iNZoFDW4vPLYxPiV6Q13zSUuDZm92TnGLp4zZB19sWCDyim0iXtpjHdI+waof00AmpKRTySvETrSPDeoLSG5GfdKkWvbB88ZkZYO5ngQP/NLHlTEJPdjoHiqDrTRZIIq1rr8y/YaqNoKjZ7x6XbvdfHFTLPHhCCIBm88jjXLYeB0ETfyZ/OpEGFtys/w7zmp/u09ZIIgUoAO+/2h3Vhl1vskAkAZy8dKr9sEnYsN1/TBOk9NAJqTnNI+Q13OA3MibRso8fHNLpgzZiTUXUo1ng9/uLhd72mp0DRP6oTkE8OOQewMK86qD6KFEzT8opltUgBbYIaSeaw1pp1Md1k2jljwIH7gT+7RHo/agnEEv39+k+nckCaCEXC0h2ur3mA5sNA0TCPUk6M6PbOzCoI20Uir+POT16ezcFHYnb75Wyz+bsJcAbD008GUTY/adGOwYRQqZoDiGv0Qa/JWOL6cu3u4+PPAAkPTXAEyngjS06F9wp8CoJfLAPKOSB44eBl+8A3n22YYD4PTD65JQnbVsvKACFMQCWft05W1swXu8+d2Tqu4oSNTtXhKo88N4PAZnbon3fDVBgi5B2ojmEvJEf16d/wHDQFlcHKg+8/TpcIg4AyAMvfA849trUoT+4UEexk5H7xITpC/8bGP4/WiQPDnlWbOaAfE6L77vnh0+kepfmO6N50xPQ4V/qY1GW7wMNUWCLkHahOTzyBqmHbcQv3/vEDxf3912nmM91ijgA/PsO4HfP6b8LjgFM5bS4+pWeBXQE/u87wycc7ai5kNP2zsGhqXmHqCtCCSENQXNE5EDjPq775XtnbtV/O/s7mi0+FxP+tpeRn0Gb2wbEck+kOlF54A8WAS9tLD7ufIIZzU4sGBoD8hNPD3bUPqPb3L69IrRRvntCCIBmichtarDBb8nMzABnX69TC+2fs6935IIvd09Qes89+Y99Gg4oZqYKuoph0Ths6ch/3srwJ5iweYcZJssIjWFnEUJcxBKRi8j1AL4FYJZS6jdxtNkU2Atq5vTonze2TGWI+GF6sthykbZTotI5WwvzK3e5I+fERGaM332cBM072OmRXuqdJUQIMVKxkIvIaQAuA3Cw8u40Ed6skPcsAH63W79+7uvBy9KdE4GjWeDMvwHeGgYO/0qnLSpDoSsnR18H3nyseOVoIjXlb9sCPne5fx/80gRN6Y/JLuCkZaFfCyGk9sQRkd8G4CsA7o+hrfrhVy8l6jJ0b0QdJc/dOxjM6Zn4O0TIVUGnOJoi6jceBHasnnht2CjC+Zn8onZTtG4fJ4Q0HBUJuYhcBeB1pdQzIn7pG5PnrgGwBgDmzJlTyW3j5/FVOiNDJr6OM3v1b7/CT34rOp2E5bmbBoODQ3op/st3oKjYlYs8cHg7MO+/6nOdA4E3LdHOcb94k39dGG8fuaiHkKYiVMhFZBuA2Ya3bgRwA7StEopSagOADYDeIaiEPlaXx1c5FvRMZG9487cBd4TtF7E6yR8L3p7Nb7LxpGXAkefDPfPCODD2H7qAle3N+42lBwb1Cs1SVsc2apYQIaSIUCFXSi0xHReRBQDmAbCj8VMB/FpEFimlDsXay2oxmtUi50UVJpa9O3BG2M6IVRWKKxwCWsSfuBY4tM28iMZvsjGZBkZ/Hd53lQdevx/oePdUVF4YN9syCUvnlntruxTGg58aqr2ohzXLCYmFsq0VpdRuAO+zX4vIAQALm
ypr5ciwFrm8ZxUmCijKzPTmYdsR6xtbgN23GMRcaVF1WhtO/OyLVwZhzCs3kR8DXtwItw3jY/lMm1Us8mq8/E2dK4U1ywmJjeZZEFQN/Cbv5n5Wp/iFecR2xDp+xLNZsocD/2S+3mtfzDgbGJoBJDom2gqZ9BQpjuqtLuC95wG/2aEHKUDfe3o3IB3u/knH1LxALWnUImiENCmx/S9WSs2Nq62a4a05UsjphTZ29BzVI7YFee83gYN3Fb+vAiwMr33xJ69pfx3Q6Yy/PwC8dh/w5iPu6xJp4NxvAM/cCBTecdwrr4tbAcWrSr1PH5KYyhuvpcXRqEXQCGlS2jsiB4In9UrxiGdmgHP+O3DwRzCuynwr4sYKnY555RPmapF9xlB7/f1XA2d/GTj+b/5PDt7PcmYvsO/bU3ZKIQc8+VdT+e+1sjgauQgaIU0IhRyIb1JvZgZ47/nAbw2TlccPB1/rFxWbolfn4pw5PYA1Q2esnLQ0+HPM6dEZOZO+eK68/PdKYXojIbFCIY+bs74MPP654uPTZvlfEzTxN6MbyL/jPr8wpo97rxs/EiyGfvuIOqmVxcH0RkJig0IeN/NWAv/6HU+kmwD29gP5t4tti6CJPwAY2QIoj1WjFPAf+0qfMDTtI+rFVCVxZEu0iL9UWLOckFigkPtRyQTg0id0SuKetRNecGFqX0yv0PpN/GW/rDNPABSv8swDz98WXHPc1HfvPqL6IgBKZ7AkLLfF8XSfe0ejZ78GnP23TBMkpMGgkJuII8d5erf2snNvTR1z2hb2QKFyxRN/uaN6R54gfreneNGSs9aKt+925F+0Q5Ej2l+0Xj9RAPr859e5z1d5YN+3mCZISINBIfcSV45zUGZGUOVEe3VmWBGD5DTgDz+pV44G1Vqx+x5WHyY5zZ1TfmRYr1r1ovJ6ERSFnJCGgUJuY0fIv98fT46zX2YGUDxQ/G63jobFAt7eD+z5umc/TgMqDyz4mv6xbZQjw8Cr95r7HlYfxuuNz+g2CzkQPsgQQmoKhRxwR8iF8WIfudwcZ1NmhqnWtyS1iM9drgeUPV8vbuvd86cGmaB8cb+nAO/Akn9HC7LVaU7/m5kB5n3WUVDMxgreOIMQUnMo5CYrRSxA0jpVz7Ys7EnEUi0Fb2ZG2GIY0w73dk3xsAnYsPxs78ACBLdnr3A9cCeAhPbkz7qOtgohDYYob2pbDVi4cKHatWtXze9r5MBmYOca96SkNR34YB/wrtP15OHBoXhWPtpCHKXNSrJm4l5yzyqFhDQEIpJVSi30HmdE7hchnzRhH+xYHU9xJ9NuQCctC46uyxXNuPOzme9NSEOTCD+lxbHtCNOO80HFnUrBad/k3tK/Dw6VH+GOZvWTxGi29GsJIS0HI3KgtL0ry5n4jLPanx3ZA8XVGgkhbQmF3MZkH8RV3CmuAcFepOOsleK3cQUhpG2gkIcRR3GnuAaEZ9eaC169MhjcHicrCWlpKORRiGOyr9IBYTQLHNpqfk8sf5vGr9wAxZ2QloFCXksqGRDCStCabBpTjvzztwG/ew74t0cqT6nkYEBIQ0Ahj5tqiZvfEvtE2t9WMU2yFsaAkZ+6jz2/DkjNKK1MLTdPJqRhYPphnDzdB2xdrBcYbV2sX8eFN00ykQZOuRK47DF/AQ2rr2JTOKbL7kbtsymd8oUBpkMSUicYkcdFLXaGL2VbN2BK/J+/TUfiQdiWTZQ+c/NkQhoKCnlcVFvcSt3Wzea8fl1w64nVgPKIuXQYjkXoMzdPJqShqNhaEZFeEXleRJ4Tkf8VR6eakjhzxb2rNiu1MuatBM7+snv16rxrgAU3a4um1D4HrYYlhNSciiJyEfkkgKsAnKuUOi4i74unW01IHLnifhOIcUT7fumPuSPl9ZmbJxPSMFRU/VBEfgRgg1JqWynXNVT1w7gpN2tlNKsnG50ee7ILuHS7/tvvPVY3JKRt8Kt+WKm18gEAHxORHSLyCxH5UEAH1ojILhHZdfjw4Qpv28DMzOgNIkoVxLCou5pWRrl9JoQ0BKHWiohsAzDb8NaNE9fPBHARgA8B+JGInK4MYb5SagOADYCOyCvpdEsS5rHTyiCE+BAq5EqpJX7vicgXAdw7Idw7RaQA4EQALRxyV4koHjvrghNCDFSafngfgE8CeFREPgCgA8BvKu1U28KomxBSBpUK+Q8A/EBE9gAYA3CNyVYhJcComxBSIhUJuVJqDMDKmPpCCCGkDFhrhRBCmhwKOSGENDkUckIIaXIo5IQQ0uRUtES/7JuKHAbwSs1vPMWJaM40Sfa7trDftYX9Duf9SqlZ3oN1EfJ6IyK7TPUKGh32u7aw37WF/S4fWiuEENLkUMgJIaTJaVch31DvDpQJ+11b2O/awn6XSVt65IQQ0kq0a0ROCCEtA4WcEEKanLYV8mbeNFpErhcRJSIn1rsvURCRb05818+KyD+LyHvq3acgRGSZiLwgIi+KSF+9+xMFETlNRB4Vkb0T/6avq3efSkFEkiLylIj8tN59KQUReY+I3D3x73ufiFxcj360pZB7No3+IIBv1blLkRGR0wBcBuBgvftSAlsBzFdK/ScA/wrg7+rcH19EJAnguwD+M4BzAKwQkXPq26tI5ABcr5Q6B3rHrr9qkn7bXAdgX707UQbrADyolDoLwLmo02doSyEH8EUA/Uqp4wCglHqzzv0phdsAfAVA08xSK6UeUkrlJl4+AeDUevYnhEUAXlRK7Z8o07wZetBvaJRSbyilfj3x91vQgnJKfXsVDRE5FcDlAL5f776Ugoi8G8BiABsBXdZbKfW7evSlXYU88qbRjYSIXAXgdaXUM/XuSwVcC+Bf6t2JAE4B8Krj9WtoEkG0EZG5AM4HsKPOXYnK30MHJ4U696NU5kFva/mPE7bQ90XkXfXoSKU7BDUscW0aXWtC+n0DtK3ScAT1Wyl1/8Q5N0JbAHfWsm/thIicAOAeAH+jlDpS7/6EISJXAHhTKZUVkU/UuTulYgG4AECvUmqHiKwD0Afgpnp0pCVp1k2j/fotIgugI4BnRATQ9sSvRWSRUupQDbtoJOj7BgARWQXgCgCXNMKAGcDrAE5zvD514ljDIyIpaBG/Uyl1b737E5GPAPi0iPwxgDSAGSJyh1KqGXYeew3Aa0op+8nnbmghrzntaq3cB71pNJpl02il1G6l1PuUUnOVUnOh/xFd0AgiHoaILIN+dP60UupovfsTwpMAukVknoh0AFgO4Md17lMookf3jQD2KaVurXd/oqKU+jul1KkT/6aXA3ikSUQcE//3XhWRMycOXQJgbz360rIReQjcNLq2fAfANABbJ54mnlBK/WV9u2RGKZUTkb8GsAVAEsAPlFLP1blbUfgIgM8B2C0iT08cu0Ep9UD9utQW9AK4c2LQ3w/g8/XoBJfoE0JIk9Ou1gohhLQMFHJCCGlyKOSEENLkUMgJIaTJoZATQkiTQyEnhJAmh0JOCCFNzv8HlU6MaDVHsLIAAAAASUVORK5CYII=", 115 | "text/plain": [ 116 | "
" 117 | ] 118 | }, 119 | "metadata": { 120 | "needs_background": "light" 121 | }, 122 | "output_type": "display_data" 123 | } 124 | ], 125 | "source": [ 126 | "data = test_tensor[-1,:].cpu().numpy()\n", 127 | "index = cluster_ids_stack[-1,:]\n", 128 | "cluster_centers = cluster_centers_stack[-1,:]\n", 129 | "print(data.shape)\n", 130 | "print(index.shape)\n", 131 | "print(cluster_centers.shape)\n", 132 | "color = ['orange', 'b', 'g', 'r', 'm', 'y', 'k','c'] * num_clusters\n", 133 | "for i in range(num_clusters):\n", 134 | " t_c = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))\n", 135 | " plt.scatter(data[index==i,0], data[index==i,1],marker='.',s=90,color=color[i])\n", 136 | " plt.scatter(data[int(cluster_centers[i]),0], data[int(cluster_centers[i]),1],marker='^',s=150,color=color[i])\n", 137 | "plt.show()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [] 153 | } 154 | ], 155 | "metadata": { 156 | "interpreter": { 157 | "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" 158 | }, 159 | "kernelspec": { 160 | "display_name": "Python 3.6.8 64-bit", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.6.8" 174 | }, 175 | "orig_nbformat": 4 176 | }, 177 | "nbformat": 4, 178 | "nbformat_minor": 2 179 | } 180 | -------------------------------------------------------------------------------- /config/SCRL_pretrain_default.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | SSL: SCRL 3 | Positive_Selection: cluster 4 | cluster: True 5 | cluster_num: 24 6 | soft_gamma: 0.5 7 | backbone: resnet50 8 | backbone_pretrain: ./pretrain/resnet50-19c8e357.pth 9 | fix_pred_lr: null 10 | SyncBatchNorm: False 11 | resume: 12 | 13 | MoCo: 14 | dim: 2048 15 | k: 65536 16 | m: 0.999 17 | t: 0.07 18 | mlp: True 19 | neighborhood_size: 8 20 | multi_positive: True 21 | 22 | 23 | data: 24 | name: movienet 25 | data_path: /tmpdata/compressed_shot_images 26 | shot_info: ./data/MovieNet_shot_num.json 27 | _T: train 28 | frame_size: 3 29 | clipshuffle: True 30 | clipshuffle_len: 16 31 | # aug_type: asymmetric # asymmetric or symmetry 32 | workers: 96 33 | fixed_aug_shot: True 34 | color_aug_for_q: False 35 | color_aug_for_k: True 36 | 37 | 38 | optim: 39 | epochs: 100 40 | bs: 1024 41 | momentum: 0.9 42 | optimizer: sgd 43 | lr: 0.03 44 | lr_cos: True 45 | schedule: # works when lr_cos is False 46 | - 50 47 | - 100 48 | - 150 49 | wd: 0.0001 50 | gradient_norm: -1 # off when <= 0 51 | 52 | 53 | log: 54 | dir: ./output/ 55 | print_freq: 10 56 | 57 | DDP: 58 | multiprocessing_distributed: True 59 | machine_num: 1 60 | world_size: 8 61 | rank: 0 62 | dist_url: env:// 63 | dist_backend: nccl 64 | seed: null 65 | gpu: null 66 | master_ip: localhost 67 | master_port: 10008 68 | node_num: 0 69 | 70 | 71 | -------------------------------------------------------------------------------- /config/SCRL_pretrain_with_imagenet1k.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | SSL: SCRL 3 | 
Positive_Selection: cluster 4 | cluster: True 5 | cluster_num: 24 6 | soft_gamma: 0.5 7 | backbone: resnet50 8 | backbone_pretrain: ./pretrain/resnet50-19c8e357.pth 9 | fix_pred_lr: null 10 | SyncBatchNorm: False 11 | resume: 12 | 13 | MoCo: 14 | dim: 2048 15 | k: 65536 16 | m: 0.999 17 | t: 0.07 18 | mlp: True 19 | neighborhood_size: 8 20 | multi_positive: True 21 | 22 | 23 | data: 24 | name: movienet 25 | data_path: /tmpdata/compressed_shot_images 26 | shot_info: ./data/MovieNet_shot_num.json 27 | _T: train 28 | frame_size: 3 29 | clipshuffle: True 30 | clipshuffle_len: 16 31 | # aug_type: asymmetric # asymmetric or symmetry 32 | workers: 96 33 | fixed_aug_shot: True 34 | color_aug_for_q: False 35 | color_aug_for_k: True 36 | 37 | 38 | optim: 39 | epochs: 100 40 | bs: 1024 41 | momentum: 0.9 42 | optimizer: sgd 43 | lr: 0.03 44 | lr_cos: True 45 | schedule: # works when lr_cos is False 46 | - 50 47 | - 100 48 | - 150 49 | wd: 0.0001 50 | gradient_norm: -1 # off when <= 0 51 | 52 | 53 | log: 54 | dir: ./output/ 55 | print_freq: 10 56 | 57 | DDP: 58 | multiprocessing_distributed: True 59 | machine_num: 1 60 | world_size: 8 61 | rank: 0 62 | dist_url: env:// 63 | dist_backend: nccl 64 | seed: null 65 | gpu: null 66 | master_ip: localhost 67 | master_port: 10008 68 | node_num: 0 69 | 70 | 71 | -------------------------------------------------------------------------------- /config/SCRL_pretrain_without_imagenet1k.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | SSL: SCRL 3 | Positive_Selection: cluster 4 | cluster: True 5 | cluster_num: 24 6 | soft_gamma: 0.5 7 | backbone: resnet50 8 | backbone_pretrain: 9 | fix_pred_lr: null 10 | SyncBatchNorm: False 11 | resume: 12 | 13 | MoCo: 14 | dim: 2048 15 | k: 65536 16 | m: 0.999 17 | t: 0.07 18 | mlp: True 19 | neighborhood_size: 8 20 | multi_positive: True 21 | 22 | 23 | data: 24 | name: movienet 25 | data_path: /tmpdata/compressed_shot_images 26 | shot_info: ./data/MovieNet_shot_num.json 27 | _T: train 28 | frame_size: 3 29 | clipshuffle: True 30 | clipshuffle_len: 16 31 | # aug_type: asymmetric # asymmetric or symmetry 32 | workers: 96 33 | fixed_aug_shot: True 34 | color_aug_for_q: False 35 | color_aug_for_k: True 36 | 37 | 38 | optim: 39 | epochs: 100 40 | bs: 1024 41 | momentum: 0.9 42 | optimizer: sgd 43 | lr: 0.06 44 | lr_cos: True 45 | schedule: # works when lr_cos is False 46 | - 50 47 | - 100 48 | - 150 49 | wd: 0.0001 50 | gradient_norm: -1 # off when <= 0 51 | 52 | 53 | log: 54 | dir: ./output/ 55 | print_freq: 10 56 | 57 | DDP: 58 | multiprocessing_distributed: True 59 | machine_num: 1 60 | world_size: 8 61 | rank: 0 62 | dist_url: env:// 63 | dist_backend: nccl 64 | seed: null 65 | gpu: null 66 | master_ip: localhost 67 | master_port: 10008 68 | node_num: 0 69 | 70 | 71 | -------------------------------------------------------------------------------- /data/MovieNet_1.0_shotinfo.json: -------------------------------------------------------------------------------- 1 | {"train": {"0": 1374, "1": 284, "2": 793, "3": 260, "4": 654, "5": 311, "6": 1197, "7": 1437, "8": 1236, "9": 1208, "10": 978, "11": 706, "12": 333, "13": 871, "14": 442, "15": 1620, "16": 660, "17": 645, "18": 470, "19": 955, "20": 1271, "21": 2469, "22": 837, "23": 2423, "24": 384, "25": 724, "26": 1166, "27": 668, "28": 936, "29": 325, "30": 338, "31": 1144, "32": 1124, "33": 620, "34": 843, "35": 1079, "36": 668, "37": 852, "38": 529, "39": 675, "40": 762, "41": 747, "42": 888, "43": 974, "44": 726, "45": 567, 
"46": 897, "47": 1045, "48": 2935, "49": 903, "50": 1449, "51": 620, "52": 700, "53": 791, "54": 596, "55": 814, "56": 922, "57": 142, "58": 1445, "59": 774, "60": 1128, "61": 1526, "62": 866, "63": 969, "64": 682, "65": 1646, "66": 756, "67": 1489, "68": 1381, "69": 1849, "70": 378, "71": 1279, "72": 1022, "73": 1017, "74": 793, "75": 1560, "76": 1311, "77": 2268, "78": 825, "79": 2016, "80": 1885, "81": 1115, "82": 70, "83": 1363, "84": 1191, "85": 1258, "86": 1310, "87": 2932, "88": 1365, "89": 1976, "90": 508, "91": 194, "92": 857, "93": 1474, "94": 738, "95": 1079, "96": 811, "97": 970, "98": 1108, "99": 1449, "100": 1178, "101": 832, "102": 1817, "103": 1229, "104": 2130, "105": 630, "106": 1874, "107": 1420, "108": 1724, "109": 827, "110": 467, "111": 1258, "112": 876, "113": 713, "114": 591, "115": 2476, "116": 863, "117": 1465, "118": 1386, "119": 1263, "120": 1164, "121": 1244, "122": 380, "123": 1891, "124": 960, "125": 300, "126": 1604, "127": 1756, "128": 868, "129": 1110, "130": 1133, "131": 1064, "132": 1158, "133": 2778, "134": 1299, "135": 1306, "136": 799, "137": 1488, "138": 869, "139": 1082, "140": 1532, "141": 1616, "142": 1147, "143": 1601, "144": 291, "145": 944, "146": 266, "147": 1760, "148": 946, "149": 1142, "150": 1713, "151": 1123, "152": 1031, "153": 992, "154": 1082, "155": 1170, "156": 1145, "157": 1412, "158": 1875, "159": 2196, "160": 1157, "161": 1227, "162": 675, "163": 1086, "164": 1293, "165": 1184, "166": 1706, "167": 1564, "168": 1610, "169": 1238, "170": 1131, "171": 1159, "172": 765, "173": 1759, "174": 2165, "175": 2540, "176": 2001, "177": 3034, "178": 1923, "179": 370, "180": 1430, "181": 1417, "182": 1217, "183": 1486, "184": 1965, "185": 2166, "186": 1308, "187": 919, "188": 1142, "189": 659, "190": 1704, "191": 2522, "192": 2180, "193": 1473, "194": 1799, "195": 1091, "196": 488, "197": 2244, "198": 1202, "199": 1237, "200": 1101, "201": 1555, "202": 1300, "203": 1383, "204": 1049, "205": 24, "206": 2890, "207": 1621, "208": 2368, "209": 2233, "210": 2023, "211": 3459, "212": 2298, "213": 1048, "214": 2021, "215": 1099, "216": 3902, "217": 1699, "218": 2776, "219": 1348, "220": 2103, "221": 1391, "222": 1257, "223": 1698, "224": 1442, "225": 2259, "226": 2231, "227": 1877, "228": 1462, "229": 1652, "230": 941, "231": 848, "232": 1612, "233": 1093, "234": 2223, "235": 1515, "236": 1141, "237": 1872, "238": 1021, "239": 2520, "240": 963, "241": 1330, "242": 2104, "243": 1546, "244": 1536, "245": 1112, "246": 2198, "247": 1557, "248": 2394, "249": 4043, "250": 973, "251": 1037, "252": 1245, "253": 893, "254": 1320, "255": 1786, "256": 1502, "257": 2445, "258": 901, "259": 979, "260": 2528, "261": 2585, "262": 1101, "263": 1395, "264": 704, "265": 1575, "266": 3117, "267": 444, "268": 1818, "269": 1481, "270": 1847, "271": 236, "272": 793, "273": 1754, "274": 3582, "275": 1189, "276": 2366, "277": 2724, "278": 1182, "279": 2588, "280": 1436, "281": 1694, "282": 2643, "283": 1362, "284": 1278, "285": 1431, "286": 1655, "287": 1815, "288": 2043, "289": 1961, "290": 2099, "291": 781, "292": 1283, "293": 1378, "294": 1267, "295": 1437, "296": 1237, "297": 2424, "298": 821, "299": 639, "300": 2081, "301": 2564, "302": 1405, "303": 2335, "304": 427, "305": 1496, "306": 2712, "307": 1383, "308": 2087, "309": 1251, "310": 3161, "311": 1798, "312": 2147, "313": 3003, "314": 2670, "315": 3370, "316": 1662, "317": 1229, "318": 1833, "319": 902, "320": 3354, "321": 1317, "322": 2782, "323": 1601, "324": 245, "325": 1036, "326": 1935, "327": 1534, "328": 
1243, "329": 1521, "330": 977, "331": 116, "332": 2973, "333": 1250, "334": 1768, "335": 954, "336": 1676, "337": 2300, "338": 3001, "339": 974, "340": 1355, "341": 875, "342": 1332, "343": 1907, "344": 924, "345": 1232, "346": 1011, "347": 2778, "348": 1513, "349": 1480, "350": 2600, "351": 1790, "352": 1911, "353": 955, "354": 1040, "355": 3265, "356": 1285, "357": 1716, "358": 1866, "359": 775, "360": 1724, "361": 1481, "362": 2662, "363": 1263, "364": 1177, "365": 649, "366": 1185, "367": 1079, "368": 1886, "369": 749, "370": 1431, "371": 2963, "372": 1531, "373": 1979, "374": 1703, "375": 1162, "376": 1360, "377": 2444, "378": 1468, "379": 1629, "380": 2229, "381": 1770, "382": 2237, "383": 1678, "384": 1536, "385": 1205, "386": 1804, "387": 2050, "388": 2972, "389": 1641, "390": 2144, "391": 1686, "392": 1241, "393": 1078, "394": 1919, "395": 972, "396": 1364, "397": 1098, "398": 1133, "399": 2578, "400": 1753, "401": 1782, "402": 1326, "403": 1550, "404": 1851, "405": 2321, "406": 2278, "407": 2724, "408": 1948, "409": 528, "410": 2545, "411": 913, "412": 824, "413": 1219, "414": 1468, "415": 2270, "416": 1973, "417": 1297, "418": 1514, "419": 1739, "420": 134, "421": 1664, "422": 1495, "423": 1766, "424": 2873, "425": 1155, "426": 1088, "427": 855, "428": 731, "429": 2365, "430": 1568, "431": 1246, "432": 922, "433": 1280, "434": 2593, "435": 1477, "436": 1571, "437": 1200, "438": 1261, "439": 2174, "440": 2058, "441": 424, "442": 1558, "443": 2769, "444": 2360, "445": 2205, "446": 895, "447": 1126, "448": 910, "449": 2115, "450": 1016, "451": 1706, "452": 1242, "453": 1037, "454": 1670, "455": 1037, "456": 1513, "457": 2646, "458": 1795, "459": 2514, "460": 646, "461": 1359, "462": 1544, "463": 2126, "464": 1197, "465": 1878, "466": 874, "467": 938, "468": 2540, "469": 2227, "470": 844, "471": 774, "472": 2591, "473": 1225, "474": 726, "475": 1101, "476": 1420, "477": 278, "478": 2135, "479": 2439, "480": 1608, "481": 2708, "482": 1533, "483": 2620, "484": 1486, "485": 1371, "486": 2348, "487": 2527, "488": 1129, "489": 251, "490": 1151, "491": 1462, "492": 478, "493": 1156, "494": 1762, "495": 1664, "496": 2204, "497": 1791, "498": 2190, "499": 2491, "500": 1650, "501": 943, "502": 1191, "503": 1304, "504": 2010, "505": 1013, "506": 2841, "507": 1142, "508": 3109, "509": 2011, "510": 1332, "511": 2157, "512": 1532, "513": 957, "514": 1458, "515": 737, "516": 2179, "517": 1585, "518": 370, "519": 1841, "520": 905, "521": 973, "522": 1969, "523": 1151, "524": 1395, "525": 3552, "526": 1078, "527": 565, "528": 925, "529": 1929, "530": 2134, "531": 233, "532": 1895, "533": 896, "534": 1568, "535": 1935, "536": 1939, "537": 1726, "538": 1853, "539": 1490, "540": 2992, "541": 1373, "542": 699, "543": 691, "544": 330, "545": 1272, "546": 1492, "547": 1408, "548": 1434, "549": 1170, "550": 1289, "551": 1526, "552": 1775, "553": 3052, "554": 1008, "555": 1022, "556": 1425, "557": 2363, "558": 2122, "559": 1125, "560": 934, "561": 1192, "562": 830, "563": 2369, "564": 792, "565": 1511, "566": 1553, "567": 906, "568": 349, "569": 2097, "570": 1758, "571": 1223, "572": 2680, "573": 105, "574": 2475, "575": 2566, "576": 2932, "577": 454, "578": 892, "579": 976, "580": 2808, "581": 1066, "582": 2825, "583": 1751, "584": 704, "585": 2231, "586": 1139, "587": 502, "588": 1108, "589": 1845, "590": 1516, "591": 1772, "592": 2321, "593": 897, "594": 756, "595": 1112, "596": 1239, "597": 1372, "598": 3007, "599": 1603, "600": 1080, "601": 1507, "602": 294, "603": 1229, "604": 835, "605": 394, "606": 
1520, "607": 775, "608": 1383, "609": 39, "610": 906, "611": 1610, "612": 2957, "613": 2102, "614": 1260, "615": 275, "616": 964, "617": 1113, "618": 1600, "619": 616, "620": 779, "621": 751, "622": 1288, "623": 173, "624": 970, "625": 2015, "626": 1932, "627": 1498, "628": 1271, "629": 1918, "630": 1635, "631": 2090, "632": 381, "633": 1223, "634": 1808, "635": 1412, "636": 269, "637": 1165, "638": 1177, "639": 1592, "640": 2262, "641": 1235, "642": 984, "643": 1453, "644": 699, "645": 1567, "646": 1452, "647": 1180, "648": 690, "649": 2349, "650": 1599, "651": 1887, "652": 1586, "653": 659, "654": 1613, "655": 2072, "656": 1750, "657": 2457, "658": 3565, "659": 1872}, "val": {"0": 672, "1": 973, "2": 837, "3": 969, "4": 999, "5": 662, "6": 449, "7": 1011, "8": 995, "9": 1245, "10": 755, "11": 666, "12": 1445, "13": 144, "14": 1512, "15": 1575, "16": 667, "17": 2210, "18": 370, "19": 885, "20": 578, "21": 1770, "22": 1192, "23": 1181, "24": 833, "25": 817, "26": 1233, "27": 384, "28": 832, "29": 896, "30": 455, "31": 1415, "32": 1191, "33": 1215, "34": 1297, "35": 1201, "36": 1328, "37": 1222, "38": 1147, "39": 2386, "40": 1610, "41": 1248, "42": 471, "43": 1778, "44": 1041, "45": 2568, "46": 1159, "47": 1009, "48": 1094, "49": 2073, "50": 645, "51": 609, "52": 1074, "53": 1979, "54": 1495, "55": 707, "56": 1583, "57": 405, "58": 1326, "59": 1647, "60": 1928, "61": 1832, "62": 2114, "63": 862, "64": 2806, "65": 2039, "66": 1533, "67": 2034, "68": 1412, "69": 1746, "70": 931, "71": 3014, "72": 2363, "73": 1819, "74": 1738, "75": 1549, "76": 1346, "77": 1599, "78": 1210, "79": 2029, "80": 1195, "81": 1182, "82": 2160, "83": 1787, "84": 1877, "85": 2031, "86": 752, "87": 1295, "88": 3721, "89": 672, "90": 1354, "91": 1913, "92": 1595, "93": 389, "94": 1000, "95": 2045, "96": 1747, "97": 1738, "98": 833, "99": 1458, "100": 1577, "101": 1708, "102": 862, "103": 1654, "104": 2359, "105": 1005, "106": 1633, "107": 2162, "108": 1395, "109": 2470, "110": 384, "111": 1242, "112": 1208, "113": 2101, "114": 1527, "115": 2232, "116": 1701, "117": 819, "118": 1384, "119": 925, "120": 1678, "121": 515, "122": 1851, "123": 2288, "124": 1641, "125": 1580, "126": 1180, "127": 1109, "128": 361, "129": 1388, "130": 660, "131": 893, "132": 1410, "133": 1663, "134": 1965, "135": 1478, "136": 1650, "137": 2073, "138": 2232, "139": 1277, "140": 1855, "141": 1072, "142": 944, "143": 1437, "144": 1716, "145": 2538, "146": 1144, "147": 1382, "148": 1664, "149": 1895, "150": 2057, "151": 1855, "152": 1714, "153": 1271, "154": 1120, "155": 1433, "156": 1182, "157": 327, "158": 1082, "159": 2962, "160": 1381, "161": 670, "162": 2313, "163": 1881, "164": 1712, "165": 1263, "166": 1363, "167": 1791, "168": 1408, "169": 1854, "170": 1308, "171": 1137, "172": 844, "173": 2051, "174": 1339, "175": 1687, "176": 643, "177": 1159, "178": 507, "179": 1804, "180": 1447, "181": 758, "182": 1423, "183": 1352, "184": 2007, "185": 1089, "186": 2516, "187": 1068, "188": 819, "189": 1408, "190": 199, "191": 1659, "192": 3241, "193": 598, "194": 1286, "195": 2756, "196": 1441, "197": 2038, "198": 1795, "199": 2044, "200": 2497, "201": 812, "202": 1882, "203": 479, "204": 695, "205": 265, "206": 1461, "207": 2226, "208": 516, "209": 915, "210": 1080, "211": 1316, "212": 1155, "213": 1526, "214": 1797, "215": 1464, "216": 903, "217": 1130, "218": 2684, "219": 1415}, "test": {"0": 667, "1": 996, "2": 1089, "3": 423, "4": 1101, "5": 1521, "6": 820, "7": 1144, "8": 1030, "9": 1250, "10": 793, "11": 1354, "12": 1244, "13": 1192, "14": 1080, 
"15": 889, "16": 397, "17": 1589, "18": 798, "19": 2202, "20": 549, "21": 607, "22": 1416, "23": 1083, "24": 610, "25": 398, "26": 2178, "27": 1425, "28": 698, "29": 1851, "30": 1843, "31": 1555, "32": 973, "33": 549, "34": 981, "35": 1875, "36": 1582, "37": 903, "38": 1746, "39": 1138, "40": 2430, "41": 1132, "42": 1779, "43": 2151, "44": 285, "45": 1277, "46": 1458, "47": 1328, "48": 427, "49": 1941, "50": 1001, "51": 1127, "52": 1784, "53": 1507, "54": 855, "55": 1684, "56": 3096, "57": 1022, "58": 844, "59": 2000, "60": 1302, "61": 817, "62": 2034, "63": 1684, "64": 2346, "65": 760, "66": 388, "67": 1212, "68": 1537, "69": 1463, "70": 2032, "71": 1070, "72": 1981, "73": 968, "74": 881, "75": 1780, "76": 473, "77": 1397, "78": 2398, "79": 1699, "80": 2243, "81": 1661, "82": 1151, "83": 2147, "84": 1063, "85": 1177, "86": 1455, "87": 2466, "88": 448, "89": 2357, "90": 1566, "91": 1471, "92": 2609, "93": 665, "94": 1646, "95": 1728, "96": 1828, "97": 1805, "98": 1238, "99": 1212, "100": 758, "101": 1441, "102": 1189, "103": 1119, "104": 1838, "105": 1978, "106": 798, "107": 1448, "108": 2431, "109": 1345, "110": 1398, "111": 2017, "112": 1115, "113": 1660, "114": 2135, "115": 2357, "116": 1828, "117": 258, "118": 2711, "119": 1087, "120": 2766, "121": 1449, "122": 2761, "123": 919, "124": 1030, "125": 1395, "126": 2532, "127": 1597, "128": 2476, "129": 1792, "130": 530, "131": 1880, "132": 1853, "133": 1828, "134": 1117, "135": 1365, "136": 1027, "137": 2257, "138": 2098, "139": 1318, "140": 1550, "141": 771, "142": 2014, "143": 2967, "144": 1361, "145": 1964, "146": 1443, "147": 1979, "148": 1288, "149": 1195, "150": 1287, "151": 1127, "152": 789, "153": 1443, "154": 2013, "155": 1955, "156": 1626, "157": 2092, "158": 1791, "159": 1764, "160": 1806, "161": 652, "162": 1551, "163": 2489, "164": 1200, "165": 1494, "166": 1449, "167": 1381, "168": 1413, "169": 2335, "170": 397, "171": 1851, "172": 1816, "173": 2610, "174": 2454, "175": 2585, "176": 1256, "177": 2873, "178": 1488, "179": 2359, "180": 1586, "181": 2040, "182": 767, "183": 1975, "184": 1910, "185": 1203, "186": 2218, "187": 1681, "188": 1621, "189": 2666, "190": 2549, "191": 885, "192": 319, "193": 1795, "194": 2884, "195": 405, "196": 728, "197": 542, "198": 1047, "199": 1605, "200": 631, "201": 635, "202": 1096, "203": 2396, "204": 1173, "205": 1874, "206": 1098, "207": 2455, "208": 2569, "209": 1628, "210": 2425, "211": 774, "212": 1237, "213": 1460, "214": 1809, "215": 792, "216": 1400, "217": 1083, "218": 2030, "219": 784}} -------------------------------------------------------------------------------- /data/MovieNet_shot_num.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "tt0035423": 1374, 4 | "tt0045537": 284, 5 | "tt0047396": 793, 6 | "tt0048605": 260, 7 | "tt0049730": 654, 8 | "tt0050706": 311, 9 | "tt0053125": 1197, 10 | "tt0056869": 1437, 11 | "tt0056923": 1236, 12 | "tt0057115": 1208, 13 | "tt0058461": 978, 14 | "tt0059043": 706, 15 | "tt0059592": 333, 16 | "tt0060522": 871, 17 | "tt0061138": 442, 18 | "tt0061418": 1620, 19 | "tt0061781": 660, 20 | "tt0062622": 645, 21 | "tt0064040": 470, 22 | "tt0064276": 955, 23 | "tt0064665": 1271, 24 | "tt0065214": 2469, 25 | "tt0065724": 837, 26 | "tt0065988": 2423, 27 | "tt0066249": 384, 28 | "tt0066921": 724, 29 | "tt0067116": 1166, 30 | "tt0067185": 668, 31 | "tt0068935": 936, 32 | "tt0069293": 325, 33 | "tt0069467": 338, 34 | "tt0069995": 1144, 35 | "tt0070047": 1124, 36 | "tt0070246": 620, 37 | "tt0070379": 
843, 38 | "tt0070735": 1079, 39 | "tt0070849": 668, 40 | "tt0071129": 852, 41 | "tt0071315": 529, 42 | "tt0071360": 675, 43 | "tt0072684": 762, 44 | "tt0073440": 747, 45 | "tt0074119": 888, 46 | "tt0074285": 974, 47 | "tt0074686": 726, 48 | "tt0074749": 567, 49 | "tt0075148": 897, 50 | "tt0076729": 1045, 51 | "tt0077402": 2935, 52 | "tt0077405": 903, 53 | "tt0077416": 1449, 54 | "tt0077651": 620, 55 | "tt0078841": 700, 56 | "tt0078908": 791, 57 | "tt0079095": 596, 58 | "tt0079116": 814, 59 | "tt0079417": 922, 60 | "tt0079944": 142, 61 | "tt0079945": 1445, 62 | "tt0080339": 774, 63 | "tt0080453": 1128, 64 | "tt0080745": 1526, 65 | "tt0080958": 866, 66 | "tt0080979": 969, 67 | "tt0081505": 682, 68 | "tt0082186": 1646, 69 | "tt0082846": 756, 70 | "tt0082971": 1489, 71 | "tt0083658": 1381, 72 | "tt0083987": 1849, 73 | "tt0084549": 378, 74 | "tt0084628": 1279, 75 | "tt0084726": 1022, 76 | "tt0084787": 1017, 77 | "tt0085794": 793, 78 | "tt0086250": 1560, 79 | "tt0086837": 1311, 80 | "tt0086879": 2268, 81 | "tt0086969": 825, 82 | "tt0087182": 2016, 83 | "tt0087469": 1885, 84 | "tt0088170": 1115, 85 | "tt0088222": 70, 86 | "tt0088847": 1363, 87 | "tt0088939": 1191, 88 | "tt0088993": 1258, 89 | "tt0090022": 1310, 90 | "tt0090605": 2932, 91 | "tt0091042": 1365, 92 | "tt0091203": 1976, 93 | "tt0091251": 508, 94 | "tt0091406": 194, 95 | "tt0091738": 857, 96 | "tt0091763": 1474, 97 | "tt0092603": 738, 98 | "tt0093010": 1079, 99 | "tt0093209": 811, 100 | "tt0093565": 970, 101 | "tt0093748": 1108, 102 | "tt0093779": 1449, 103 | "tt0094226": 1178, 104 | "tt0094291": 832, 105 | "tt0095016": 1817, 106 | "tt0095497": 1229, 107 | "tt0095956": 2130, 108 | "tt0096463": 630, 109 | "tt0096754": 1874, 110 | "tt0096874": 1420, 111 | "tt0096895": 1724, 112 | "tt0097216": 827, 113 | "tt0097372": 467, 114 | "tt0097428": 1258, 115 | "tt0098258": 876, 116 | "tt0098635": 713, 117 | "tt0098724": 591, 118 | "tt0099348": 2476, 119 | "tt0099487": 863, 120 | "tt0099653": 1465, 121 | "tt0099674": 1386, 122 | "tt0099810": 1263, 123 | "tt0100150": 1164, 124 | "tt0100157": 1244, 125 | "tt0100234": 380, 126 | "tt0100802": 1891, 127 | "tt0100935": 960, 128 | "tt0100998": 300, 129 | "tt0101272": 1604, 130 | "tt0101393": 1756, 131 | "tt0101700": 868, 132 | "tt0101889": 1110, 133 | "tt0101921": 1133, 134 | "tt0102492": 1064, 135 | "tt0102926": 1158, 136 | "tt0103064": 2778, 137 | "tt0103772": 1299, 138 | "tt0103786": 1306, 139 | "tt0104036": 799, 140 | "tt0104257": 1488, 141 | "tt0104348": 869, 142 | "tt0105226": 1082, 143 | "tt0105665": 1532, 144 | "tt0105695": 1616, 145 | "tt0106332": 1147, 146 | "tt0106582": 1601, 147 | "tt0107507": 291, 148 | "tt0107653": 944, 149 | "tt0107736": 266, 150 | "tt0107808": 1760, 151 | "tt0107822": 946, 152 | "tt0108122": 1142, 153 | "tt0108289": 1713, 154 | "tt0108330": 1123, 155 | "tt0109686": 1031, 156 | "tt0109830": 992, 157 | "tt0109831": 1082, 158 | "tt0110074": 1170, 159 | "tt0110148": 1145, 160 | "tt0110201": 1412, 161 | "tt0110322": 1875, 162 | "tt0110632": 2196, 163 | "tt0110912": 1157, 164 | "tt0111003": 1227, 165 | "tt0112769": 675, 166 | "tt0113101": 1086, 167 | "tt0113243": 1293, 168 | "tt0113253": 1184, 169 | "tt0113497": 1706, 170 | "tt0114367": 1564, 171 | "tt0114369": 1610, 172 | "tt0114388": 1238, 173 | "tt0114814": 1131, 174 | "tt0115798": 1159, 175 | "tt0115964": 765, 176 | "tt0116209": 1759, 177 | "tt0116477": 2165, 178 | "tt0116629": 2540, 179 | "tt0116695": 2001, 180 | "tt0117500": 3034, 181 | "tt0117509": 1923, 182 | "tt0117666": 370, 183 | "tt0117731": 1430, 184 | "tt0117883": 
1417, 185 | "tt0117951": 1217, 186 | "tt0118636": 1486, 187 | "tt0118655": 1965, 188 | "tt0118688": 2166, 189 | "tt0118689": 1308, 190 | "tt0118749": 919, 191 | "tt0118799": 1142, 192 | "tt0118845": 659, 193 | "tt0118883": 1704, 194 | "tt0118929": 2522, 195 | "tt0118971": 2180, 196 | "tt0119008": 1473, 197 | "tt0119081": 1799, 198 | "tt0119177": 1091, 199 | "tt0119250": 488, 200 | "tt0119314": 2244, 201 | "tt0119396": 1202, 202 | "tt0119528": 1237, 203 | "tt0119567": 1101, 204 | "tt0119643": 1555, 205 | "tt0119654": 1300, 206 | "tt0119670": 1383, 207 | "tt0119738": 1049, 208 | "tt0120263": 24, 209 | "tt0120338": 2890, 210 | "tt0120586": 1621, 211 | "tt0120591": 2368, 212 | "tt0120616": 2233, 213 | "tt0120655": 2023, 214 | "tt0120660": 3459, 215 | "tt0120667": 2298, 216 | "tt0120669": 1048, 217 | "tt0120696": 2021, 218 | "tt0120735": 1099, 219 | "tt0120737": 3902, 220 | "tt0120744": 1699, 221 | "tt0120755": 2776, 222 | "tt0120787": 1348, 223 | "tt0120804": 2103, 224 | "tt0120815": 1391, 225 | "tt0120885": 1257, 226 | "tt0120902": 1698, 227 | "tt0120912": 1442, 228 | "tt0120915": 2259, 229 | "tt0121766": 2231, 230 | "tt0122690": 1877, 231 | "tt0125439": 1462, 232 | "tt0125664": 1652, 233 | "tt0126886": 941, 234 | "tt0128445": 848, 235 | "tt0134119": 1612, 236 | "tt0134273": 1093, 237 | "tt0134847": 2223, 238 | "tt0137494": 1515, 239 | "tt0139134": 1141, 240 | "tt0139654": 1872, 241 | "tt0142688": 1021, 242 | "tt0143145": 2520, 243 | "tt0144084": 963, 244 | "tt0144117": 1330, 245 | "tt0159365": 2104, 246 | "tt0159784": 1546, 247 | "tt0160127": 1536, 248 | "tt0162346": 1112, 249 | "tt0162661": 2198, 250 | "tt0164052": 1557, 251 | "tt0167190": 2394, 252 | "tt0167260": 4043, 253 | "tt0167331": 973, 254 | "tt0169547": 1037, 255 | "tt0171363": 1245, 256 | "tt0175880": 893, 257 | "tt0180073": 1320, 258 | "tt0180093": 1786, 259 | "tt0181689": 1502, 260 | "tt0181875": 2445, 261 | "tt0183523": 901, 262 | "tt0183649": 979, 263 | "tt0187078": 2528, 264 | "tt0187393": 2585, 265 | "tt0190590": 1101, 266 | "tt0195685": 1395, 267 | "tt0199354": 704, 268 | "tt0199753": 1575, 269 | "tt0203009": 3117, 270 | "tt0206634": 444, 271 | "tt0207201": 1818, 272 | "tt0208092": 1481, 273 | "tt0209144": 1847, 274 | "tt0209463": 236, 275 | "tt0210727": 793, 276 | "tt0212338": 1754, 277 | "tt0213149": 3582, 278 | "tt0227445": 1189, 279 | "tt0232500": 2366, 280 | "tt0234215": 2724, 281 | "tt0240772": 1182, 282 | "tt0242653": 2588, 283 | "tt0243876": 1436, 284 | "tt0244244": 1694, 285 | "tt0244353": 2643, 286 | "tt0245844": 1362, 287 | "tt0246578": 1278, 288 | "tt0250494": 1431, 289 | "tt0250797": 1655, 290 | "tt0251160": 1815, 291 | "tt0253754": 2043, 292 | "tt0258000": 1961, 293 | "tt0264395": 2099, 294 | "tt0264616": 781, 295 | "tt0266697": 1283, 296 | "tt0266915": 1378, 297 | "tt0268695": 1267, 298 | "tt0272152": 1437, 299 | "tt0275719": 1237, 300 | "tt0278504": 2424, 301 | "tt0283509": 821, 302 | "tt0286106": 639, 303 | "tt0288477": 2081, 304 | "tt0290334": 2564, 305 | "tt0294870": 1405, 306 | "tt0299658": 2335, 307 | "tt0308476": 427, 308 | "tt0309698": 1496, 309 | "tt0313542": 2712, 310 | "tt0315327": 1383, 311 | "tt0316654": 2087, 312 | "tt0317198": 1251, 313 | "tt0317919": 3161, 314 | "tt0318627": 1798, 315 | "tt0318974": 2147, 316 | "tt0325710": 3003, 317 | "tt0325980": 2670, 318 | "tt0328107": 3370, 319 | "tt0329101": 1662, 320 | "tt0331811": 1229, 321 | "tt0332452": 1833, 322 | "tt0335266": 902, 323 | "tt0337978": 3354, 324 | "tt0338013": 1317, 325 | "tt0338751": 2782, 326 | "tt0343660": 1601, 327 | "tt0346094": 
245, 328 | "tt0349903": 1036, 329 | "tt0350258": 1935, 330 | "tt0351977": 1534, 331 | "tt0357413": 1243, 332 | "tt0359950": 1521, 333 | "tt0362227": 977, 334 | "tt0363589": 116, 335 | "tt0363771": 2973, 336 | "tt0365907": 1250, 337 | "tt0369339": 1768, 338 | "tt0369702": 954, 339 | "tt0371257": 1676, 340 | "tt0372183": 2300, 341 | "tt0372784": 3001, 342 | "tt0372824": 974, 343 | "tt0373074": 1355, 344 | "tt0374546": 875, 345 | "tt0375679": 1332, 346 | "tt0376994": 1907, 347 | "tt0377713": 924, 348 | "tt0378194": 1232, 349 | "tt0379306": 1011, 350 | "tt0382625": 2778, 351 | "tt0383028": 1513, 352 | "tt0383216": 1480, 353 | "tt0383574": 2600, 354 | "tt0385004": 1790, 355 | "tt0387564": 1911, 356 | "tt0387877": 955, 357 | "tt0388795": 1040, 358 | "tt0390022": 3265, 359 | "tt0393109": 1285, 360 | "tt0395699": 1716, 361 | "tt0397078": 1866, 362 | "tt0398027": 775, 363 | "tt0399295": 1724, 364 | "tt0405159": 1481, 365 | "tt0407887": 2662, 366 | "tt0408306": 1263, 367 | "tt0408790": 1177, 368 | "tt0413893": 649, 369 | "tt0414055": 1185, 370 | "tt0414387": 1079, 371 | "tt0414982": 1886, 372 | "tt0415380": 749, 373 | "tt0417741": 1431, 374 | "tt0418279": 2963, 375 | "tt0418819": 1531, 376 | "tt0419887": 1979, 377 | "tt0420223": 1703, 378 | "tt0421715": 1162, 379 | "tt0424345": 1360, 380 | "tt0425061": 2444, 381 | "tt0425210": 1468, 382 | "tt0427309": 1629, 383 | "tt0427954": 2229, 384 | "tt0430357": 1770, 385 | "tt0433035": 2237, 386 | "tt0435705": 1678, 387 | "tt0439815": 1536, 388 | "tt0443453": 1205, 389 | "tt0443706": 1804, 390 | "tt0448157": 2050, 391 | "tt0449088": 2972, 392 | "tt0450259": 1641, 393 | "tt0450385": 2144, 394 | "tt0454841": 1686, 395 | "tt0454876": 1241, 396 | "tt0454921": 1078, 397 | "tt0457939": 1919, 398 | "tt0458413": 972, 399 | "tt0462200": 1364, 400 | "tt0467406": 1098, 401 | "tt0468565": 1133, 402 | "tt0468569": 2578, 403 | "tt0473705": 1753, 404 | "tt0475293": 1782, 405 | "tt0477348": 1326, 406 | "tt0479997": 1550, 407 | "tt0489018": 1851, 408 | "tt0493464": 2321, 409 | "tt0499448": 2278, 410 | "tt0499549": 2724, 411 | "tt0758774": 1948, 412 | "tt0765128": 528, 413 | "tt0765429": 2545, 414 | "tt0765447": 913, 415 | "tt0780504": 824, 416 | "tt0790636": 1219, 417 | "tt0790686": 1468, 418 | "tt0796366": 2270, 419 | "tt0800320": 1973, 420 | "tt0810819": 1297, 421 | "tt0815236": 1514, 422 | "tt0824747": 1739, 423 | "tt0826711": 134, 424 | "tt0829482": 1664, 425 | "tt0844286": 1495, 426 | "tt0846308": 1766, 427 | "tt0848228": 2873, 428 | "tt0862846": 1155, 429 | "tt0887883": 1088, 430 | "tt0913425": 855, 431 | "tt0914798": 731, 432 | "tt0942385": 2365, 433 | "tt0947798": 1568, 434 | "tt0958860": 1246, 435 | "tt0959337": 922, 436 | "tt0963794": 1280, 437 | "tt0963966": 2593, 438 | "tt0970416": 1477, 439 | "tt0974661": 1571, 440 | "tt0975645": 1200, 441 | "tt0977855": 1261, 442 | "tt0985694": 2174, 443 | "tt0985699": 2058, 444 | "tt0986233": 424, 445 | "tt0986263": 1558, 446 | "tt0988045": 2769, 447 | "tt0993846": 2360, 448 | "tt1010048": 2205, 449 | "tt1013753": 895, 450 | "tt1016268": 1126, 451 | "tt1022603": 910, 452 | "tt1024648": 2115, 453 | "tt1027718": 1016, 454 | "tt1029234": 1706, 455 | "tt1029360": 1242, 456 | "tt1037705": 1037, 457 | "tt1045658": 1670, 458 | "tt1045772": 1037, 459 | "tt1054606": 1513, 460 | "tt1055369": 2646, 461 | "tt1057500": 1795, 462 | "tt1059786": 2514, 463 | "tt1068649": 646, 464 | "tt1068680": 1359, 465 | "tt1072748": 1544, 466 | "tt1074638": 2126, 467 | "tt1084950": 1197, 468 | "tt1104001": 1878, 469 | "tt1124035": 874, 470 | "tt1125849": 
938, 471 | "tt1131729": 2540, 472 | "tt1133985": 2227, 473 | "tt1135952": 844, 474 | "tt1139797": 774, 475 | "tt1148204": 2591, 476 | "tt1156466": 1225, 477 | "tt1158278": 726, 478 | "tt1174732": 1101, 479 | "tt1179031": 1420, 480 | "tt1179904": 278, 481 | "tt1186367": 2135, 482 | "tt1188729": 2439, 483 | "tt1193138": 1608, 484 | "tt1194173": 2708, 485 | "tt1210166": 1533, 486 | "tt1217613": 2620, 487 | "tt1219289": 1486, 488 | "tt1220719": 1371, 489 | "tt1228705": 2348, 490 | "tt1229340": 2527, 491 | "tt1229822": 1129, 492 | "tt1233381": 251, 493 | "tt1244754": 1151, 494 | "tt1253863": 1462, 495 | "tt1255953": 478, 496 | "tt1274586": 1156, 497 | "tt1276104": 1762, 498 | "tt1282140": 1664, 499 | "tt1285016": 2204, 500 | "tt1291150": 1791, 501 | "tt1291584": 2190, 502 | "tt1298650": 2491, 503 | "tt1300851": 1650, 504 | "tt1305806": 943, 505 | "tt1306980": 1191, 506 | "tt1322269": 1304, 507 | "tt1324999": 2010, 508 | "tt1340800": 1013, 509 | "tt1343092": 2841, 510 | "tt1360860": 1142, 511 | "tt1371111": 3109, 512 | "tt1375670": 2011, 513 | "tt1396218": 1332, 514 | "tt1401152": 2157, 515 | "tt1403865": 1532, 516 | "tt1411238": 957, 517 | "tt1438176": 1458, 518 | "tt1439572": 737, 519 | "tt1446714": 2179, 520 | "tt1454029": 1585, 521 | "tt1454468": 370, 522 | "tt1458175": 1841, 523 | "tt1462758": 905, 524 | "tt1468846": 973, 525 | "tt1478338": 1969, 526 | "tt1486190": 1151, 527 | "tt1502712": 1395, 528 | "tt1533117": 3552, 529 | "tt1535970": 1078, 530 | "tt1560747": 565, 531 | "tt1563738": 925, 532 | "tt1564367": 1929, 533 | "tt1568346": 2134, 534 | "tt1602620": 233, 535 | "tt1606378": 1895, 536 | "tt1615147": 896, 537 | "tt1616195": 1568, 538 | "tt1628841": 1935, 539 | "tt1637725": 1939, 540 | "tt1646987": 1726, 541 | "tt1649443": 1853, 542 | "tt1655420": 1490, 543 | "tt1670345": 2992, 544 | "tt1675434": 1373, 545 | "tt1692486": 699, 546 | "tt1706593": 691, 547 | "tt1723811": 330, 548 | "tt1747958": 1272, 549 | "tt1757746": 1492, 550 | "tt1781769": 1408, 551 | "tt1800241": 1434, 552 | "tt1800246": 1170, 553 | "tt1809398": 1289, 554 | "tt1832382": 1526, 555 | "tt1855325": 1775, 556 | "tt1877832": 3052, 557 | "tt1907668": 1008, 558 | "tt1951266": 1022, 559 | "tt1971325": 1425, 560 | "tt1979320": 2363, 561 | "tt1981115": 2122, 562 | "tt2017561": 1125, 563 | "tt2053463": 934, 564 | "tt2056771": 1192, 565 | "tt2058107": 830, 566 | "tt2058673": 2369, 567 | "tt2059255": 792, 568 | "tt2070649": 1511, 569 | "tt2084970": 1553, 570 | "tt2103281": 906, 571 | "tt2109184": 349, 572 | "tt2118775": 2097, 573 | "tt2140373": 1758, 574 | "tt2167266": 1223, 575 | "tt2238032": 2680, 576 | "tt2258281": 105, 577 | "tt2267998": 2475, 578 | "tt2294449": 2566, 579 | "tt2310332": 2932, 580 | "tt2334873": 454, 581 | "tt2345567": 892, 582 | "tt2366450": 976, 583 | "tt2381249": 2808, 584 | "tt2382298": 1066, 585 | "tt2404435": 2825, 586 | "tt2463288": 1751, 587 | "tt2473794": 704, 588 | "tt2567026": 2231, 589 | "tt2582802": 1139, 590 | "tt2639344": 502, 591 | "tt2675914": 1108, 592 | "tt2713180": 1845, 593 | "tt2717822": 1516, 594 | "tt2800240": 1772, 595 | "tt2823054": 2321, 596 | "tt2884018": 897, 597 | "tt2908856": 756, 598 | "tt2911666": 1112, 599 | "tt2923316": 1239, 600 | "tt2980516": 1372, 601 | "tt3062096": 3007, 602 | "tt3064298": 1603, 603 | "tt3077214": 1080, 604 | "tt3289956": 1507, 605 | "tt3296658": 294, 606 | "tt3312830": 1229, 607 | "tt3319920": 835, 608 | "tt3395184": 394, 609 | "tt3410834": 1520, 610 | "tt3416744": 775, 611 | "tt3439114": 1383, 612 | "tt3465916": 39, 613 | "tt3474602": 906, 614 | 
"tt3478232": 1610, 615 | "tt3498820": 2957, 616 | "tt3501416": 2102, 617 | "tt3531578": 1260, 618 | "tt3630276": 275, 619 | "tt3659786": 964, 620 | "tt3671542": 1113, 621 | "tt3700392": 1600, 622 | "tt3700804": 616, 623 | "tt3707106": 779, 624 | "tt3714720": 751, 625 | "tt3766394": 1288, 626 | "tt3808342": 173, 627 | "tt3860916": 970, 628 | "tt3960412": 2015, 629 | "tt4046784": 1932, 630 | "tt4052882": 1498, 631 | "tt4136084": 1271, 632 | "tt4151192": 1918, 633 | "tt4176826": 1635, 634 | "tt4242158": 2090, 635 | "tt4273292": 381, 636 | "tt4501454": 1223, 637 | "tt4651520": 1808, 638 | "tt4698684": 1412, 639 | "tt4721400": 269, 640 | "tt4781612": 1165, 641 | "tt4786282": 1177, 642 | "tt4824302": 1592, 643 | "tt4939066": 2262, 644 | "tt5052448": 1235, 645 | "tt5065810": 984, 646 | "tt5294550": 1453, 647 | "tt5564148": 699, 648 | "tt5576318": 1567, 649 | "tt5580036": 1452, 650 | "tt5593416": 1180, 651 | "tt5649144": 690, 652 | "tt5688868": 2349, 653 | "tt5827496": 1599, 654 | "tt5866930": 1887, 655 | "tt6133130": 1586, 656 | "tt6298600": 659, 657 | "tt6466464": 1613, 658 | "tt6513406": 2072, 659 | "tt6788942": 1750, 660 | "tt7055592": 2457, 661 | "tt7131870": 3565, 662 | "tt7180392": 1872 663 | }, 664 | "val": { 665 | "tt0032138": 672, 666 | "tt0038650": 973, 667 | "tt0048545": 837, 668 | "tt0053221": 969, 669 | "tt0053579": 999, 670 | "tt0054167": 662, 671 | "tt0061722": 449, 672 | "tt0064115": 1011, 673 | "tt0066026": 995, 674 | "tt0067140": 1245, 675 | "tt0069762": 755, 676 | "tt0070245": 666, 677 | "tt0071562": 1445, 678 | "tt0072443": 144, 679 | "tt0072890": 1512, 680 | "tt0073486": 1575, 681 | "tt0074811": 667, 682 | "tt0076759": 2210, 683 | "tt0079182": 370, 684 | "tt0079470": 885, 685 | "tt0080610": 578, 686 | "tt0080684": 1770, 687 | "tt0083866": 1192, 688 | "tt0083929": 1181, 689 | "tt0084899": 833, 690 | "tt0085991": 817, 691 | "tt0087332": 1233, 692 | "tt0089853": 384, 693 | "tt0089907": 832, 694 | "tt0090756": 896, 695 | "tt0091355": 455, 696 | "tt0091369": 1415, 697 | "tt0092699": 1191, 698 | "tt0092991": 1215, 699 | "tt0094737": 1297, 700 | "tt0094761": 1201, 701 | "tt0095765": 1328, 702 | "tt0095953": 1222, 703 | "tt0096256": 1147, 704 | "tt0096446": 2386, 705 | "tt0097576": 1610, 706 | "tt0099685": 1248, 707 | "tt0100112": 471, 708 | "tt0100403": 1778, 709 | "tt0101410": 1041, 710 | "tt0102138": 2568, 711 | "tt0103074": 1159, 712 | "tt0103241": 1009, 713 | "tt0103292": 1094, 714 | "tt0104797": 2073, 715 | "tt0105236": 645, 716 | "tt0105652": 609, 717 | "tt0106226": 1074, 718 | "tt0106977": 1979, 719 | "tt0107614": 1495, 720 | "tt0108160": 707, 721 | "tt0108656": 1583, 722 | "tt0109020": 405, 723 | "tt0110475": 1326, 724 | "tt0110932": 1647, 725 | "tt0112462": 1928, 726 | "tt0112641": 1832, 727 | "tt0112740": 2114, 728 | "tt0113870": 862, 729 | "tt0114558": 2806, 730 | "tt0116367": 2039, 731 | "tt0116996": 1533, 732 | "tt0117381": 2034, 733 | "tt0118548": 1412, 734 | "tt0118571": 1746, 735 | "tt0118842": 931, 736 | "tt0119094": 3014, 737 | "tt0119116": 2363, 738 | "tt0119174": 1819, 739 | "tt0119822": 1738, 740 | "tt0120483": 1549, 741 | "tt0120601": 1346, 742 | "tt0120780": 1599, 743 | "tt0120863": 1210, 744 | "tt0121765": 2029, 745 | "tt0122933": 1195, 746 | "tt0129387": 1182, 747 | "tt0133093": 2160, 748 | "tt0138097": 1787, 749 | "tt0140352": 1877, 750 | "tt0145487": 2031, 751 | "tt0166896": 752, 752 | "tt0166924": 1295, 753 | "tt0167261": 3721, 754 | "tt0167404": 672, 755 | "tt0182789": 1354, 756 | "tt0186151": 1913, 757 | "tt0209958": 1595, 758 | "tt0217869": 389, 759 | 
"tt0240890": 1000, 760 | "tt0248667": 2045, 761 | "tt0258463": 1747, 762 | "tt0261392": 1738, 763 | "tt0265666": 833, 764 | "tt0268126": 1458, 765 | "tt0268978": 1577, 766 | "tt0277027": 1708, 767 | "tt0285742": 862, 768 | "tt0289879": 1654, 769 | "tt0290002": 2359, 770 | "tt0298228": 1005, 771 | "tt0311113": 1633, 772 | "tt0317740": 2162, 773 | "tt0319262": 1395, 774 | "tt0322259": 2470, 775 | "tt0324197": 384, 776 | "tt0337921": 1242, 777 | "tt0341495": 1208, 778 | "tt0343818": 2101, 779 | "tt0360486": 1527, 780 | "tt0370263": 2232, 781 | "tt0371724": 1701, 782 | "tt0375063": 819, 783 | "tt0395169": 1384, 784 | "tt0401383": 925, 785 | "tt0408236": 1678, 786 | "tt0416320": 515, 787 | "tt0432021": 1851, 788 | "tt0434409": 2288, 789 | "tt0454848": 1641, 790 | "tt0455760": 1580, 791 | "tt0457297": 1180, 792 | "tt0457430": 1109, 793 | "tt0457513": 361, 794 | "tt0467200": 1388, 795 | "tt0469494": 660, 796 | "tt0470752": 893, 797 | "tt0480025": 1410, 798 | "tt0758730": 1663, 799 | "tt0758758": 1965, 800 | "tt0780653": 1478, 801 | "tt0790628": 1650, 802 | "tt0808151": 2073, 803 | "tt0816692": 2232, 804 | "tt0824758": 1277, 805 | "tt0838232": 1855, 806 | "tt0898367": 1072, 807 | "tt0940709": 944, 808 | "tt0964517": 1437, 809 | "tt0993842": 1716, 810 | "tt1000774": 2538, 811 | "tt1019452": 1144, 812 | "tt1032755": 1382, 813 | "tt1041829": 1664, 814 | "tt1055292": 1895, 815 | "tt1065073": 2057, 816 | "tt1071875": 1855, 817 | "tt1073498": 1714, 818 | "tt1093906": 1271, 819 | "tt1100089": 1120, 820 | "tt1144884": 1433, 821 | "tt1172049": 1182, 822 | "tt1178663": 327, 823 | "tt1182345": 1082, 824 | "tt1190080": 2962, 825 | "tt1211837": 1381, 826 | "tt1216496": 670, 827 | "tt1232829": 2313, 828 | "tt1284575": 1881, 829 | "tt1341167": 1712, 830 | "tt1355683": 1263, 831 | "tt1385826": 1363, 832 | "tt1409024": 1791, 833 | "tt1441953": 1408, 834 | "tt1462900": 1854, 835 | "tt1504320": 1308, 836 | "tt1540133": 1137, 837 | "tt1582248": 844, 838 | "tt1586752": 2051, 839 | "tt1591095": 1339, 840 | "tt1596363": 1687, 841 | "tt1602613": 643, 842 | "tt1611840": 1159, 843 | "tt1619029": 507, 844 | "tt1645170": 1804, 845 | "tt1659337": 1447, 846 | "tt1703957": 758, 847 | "tt1722484": 1423, 848 | "tt1725986": 1352, 849 | "tt1731141": 2007, 850 | "tt1742683": 1089, 851 | "tt1840309": 2516, 852 | "tt1895587": 1068, 853 | "tt1974419": 819, 854 | "tt2032557": 1408, 855 | "tt2076220": 199, 856 | "tt2078768": 1659, 857 | "tt2109248": 3241, 858 | "tt2132285": 598, 859 | "tt2381991": 1286, 860 | "tt2645044": 2756, 861 | "tt2788732": 1441, 862 | "tt2832470": 2038, 863 | "tt2872732": 1795, 864 | "tt2978462": 2044, 865 | "tt3110958": 2497, 866 | "tt3316960": 812, 867 | "tt3421514": 1882, 868 | "tt3464902": 479, 869 | "tt3488710": 695, 870 | "tt3508840": 265, 871 | "tt3553442": 1461, 872 | "tt3672840": 2226, 873 | "tt3726704": 516, 874 | "tt3824458": 915, 875 | "tt3882082": 1080, 876 | "tt3922798": 1316, 877 | "tt4160708": 1155, 878 | "tt4647900": 1526, 879 | "tt4967094": 1797, 880 | "tt5726086": 1464, 881 | "tt6121428": 903, 882 | "tt6190198": 1130, 883 | "tt7160070": 2684, 884 | "tt7672188": 1415 885 | }, 886 | "test": { 887 | "tt0048028": 667, 888 | "tt0049470": 996, 889 | "tt0049833": 1089, 890 | "tt0050419": 423, 891 | "tt0052357": 1101, 892 | "tt0058331": 1521, 893 | "tt0061811": 820, 894 | "tt0063442": 1144, 895 | "tt0066206": 1030, 896 | "tt0068646": 1250, 897 | "tt0070291": 793, 898 | "tt0070511": 1354, 899 | "tt0073195": 1244, 900 | "tt0073582": 1192, 901 | "tt0073629": 1080, 902 | "tt0075314": 889, 903 | "tt0075686": 
397, 904 | "tt0078788": 1589, 905 | "tt0079672": 798, 906 | "tt0080455": 2202, 907 | "tt0080761": 549, 908 | "tt0082089": 607, 909 | "tt0082198": 1416, 910 | "tt0083907": 1083, 911 | "tt0083946": 610, 912 | "tt0084390": 398, 913 | "tt0086190": 2178, 914 | "tt0086856": 1425, 915 | "tt0087921": 698, 916 | "tt0088247": 1851, 917 | "tt0088944": 1843, 918 | "tt0089218": 1555, 919 | "tt0089881": 973, 920 | "tt0090257": 549, 921 | "tt0091867": 981, 922 | "tt0092099": 1875, 923 | "tt0093773": 1582, 924 | "tt0094964": 903, 925 | "tt0095250": 1746, 926 | "tt0096320": 1138, 927 | "tt0099423": 2430, 928 | "tt0100405": 1132, 929 | "tt0103776": 1779, 930 | "tt0103855": 2151, 931 | "tt0104466": 285, 932 | "tt0104553": 1277, 933 | "tt0104691": 1458, 934 | "tt0107290": 1328, 935 | "tt0107617": 427, 936 | "tt0108399": 1941, 937 | "tt0110116": 1001, 938 | "tt0110167": 1127, 939 | "tt0110604": 1784, 940 | "tt0111280": 1507, 941 | "tt0111797": 855, 942 | "tt0112384": 1684, 943 | "tt0112573": 3096, 944 | "tt0112818": 1022, 945 | "tt0112883": 844, 946 | "tt0113277": 2000, 947 | "tt0114746": 1302, 948 | "tt0115734": 817, 949 | "tt0115759": 2034, 950 | "tt0115956": 1684, 951 | "tt0116213": 2346, 952 | "tt0116282": 760, 953 | "tt0116767": 388, 954 | "tt0116922": 1212, 955 | "tt0117060": 1537, 956 | "tt0117571": 1463, 957 | "tt0118583": 2032, 958 | "tt0118715": 1070, 959 | "tt0119303": 1981, 960 | "tt0119349": 968, 961 | "tt0119375": 881, 962 | "tt0119488": 1780, 963 | "tt0120255": 473, 964 | "tt0120382": 1397, 965 | "tt0120689": 2398, 966 | "tt0120731": 1699, 967 | "tt0120738": 2243, 968 | "tt0120812": 1661, 969 | "tt0120890": 1151, 970 | "tt0120903": 2147, 971 | "tt0123755": 1063, 972 | "tt0124315": 1177, 973 | "tt0127536": 1455, 974 | "tt0133152": 2466, 975 | "tt0137439": 448, 976 | "tt0137523": 2357, 977 | "tt0142342": 1566, 978 | "tt0163025": 1471, 979 | "tt0172495": 2609, 980 | "tt0178868": 665, 981 | "tt0190332": 1646, 982 | "tt0195714": 1728, 983 | "tt0212985": 1828, 984 | "tt0217505": 1805, 985 | "tt0219822": 1238, 986 | "tt0253474": 1212, 987 | "tt0257360": 758, 988 | "tt0280609": 1441, 989 | "tt0281358": 1189, 990 | "tt0281686": 1119, 991 | "tt0319061": 1838, 992 | "tt0330373": 1978, 993 | "tt0335119": 798, 994 | "tt0361748": 1448, 995 | "tt0368891": 2431, 996 | "tt0368933": 1345, 997 | "tt0369441": 1398, 998 | "tt0370032": 2017, 999 | "tt0373051": 1115, 1000 | "tt0373469": 1660, 1001 | "tt0379786": 2135, 1002 | "tt0381061": 2357, 1003 | "tt0386588": 1828, 1004 | "tt0387898": 258, 1005 | "tt0399201": 2711, 1006 | "tt0404978": 1087, 1007 | "tt0409459": 2766, 1008 | "tt0416508": 1449, 1009 | "tt0440963": 2761, 1010 | "tt0443272": 919, 1011 | "tt0443680": 1030, 1012 | "tt0452625": 1395, 1013 | "tt0455824": 2532, 1014 | "tt0458352": 1597, 1015 | "tt0458525": 2476, 1016 | "tt0460791": 1792, 1017 | "tt0460989": 530, 1018 | "tt0462499": 1880, 1019 | "tt0477347": 1853, 1020 | "tt0479884": 1828, 1021 | "tt0481369": 1117, 1022 | "tt0780571": 1365, 1023 | "tt0783233": 1027, 1024 | "tt0800080": 2257, 1025 | "tt0800369": 2098, 1026 | "tt0815245": 1318, 1027 | "tt0822832": 1550, 1028 | "tt0844347": 771, 1029 | "tt0878804": 2014, 1030 | "tt0903624": 2967, 1031 | "tt0905372": 1361, 1032 | "tt0944835": 1964, 1033 | "tt0945513": 1443, 1034 | "tt0970179": 1979, 1035 | "tt0976051": 1288, 1036 | "tt1001508": 1195, 1037 | "tt1007029": 1287, 1038 | "tt1017460": 1127, 1039 | "tt1033575": 789, 1040 | "tt1034314": 1443, 1041 | "tt1038919": 2013, 1042 | "tt1046173": 1955, 1043 | "tt1063669": 1626, 1044 | "tt1086772": 2092, 1045 | 
"tt1092026": 1791, 1046 | "tt1099212": 1764, 1047 | "tt1119646": 1806, 1048 | "tt1120985": 652, 1049 | "tt1124037": 1551, 1050 | "tt1170358": 2489, 1051 | "tt1181614": 1200, 1052 | "tt1189340": 1494, 1053 | "tt1201607": 1449, 1054 | "tt1205489": 1381, 1055 | "tt1220634": 1413, 1056 | "tt1229238": 2335, 1057 | "tt1287878": 397, 1058 | "tt1292566": 1851, 1059 | "tt1318514": 1816, 1060 | "tt1375666": 2610, 1061 | "tt1386932": 2454, 1062 | "tt1392190": 2585, 1063 | "tt1397514": 1256, 1064 | "tt1399103": 2873, 1065 | "tt1412386": 1488, 1066 | "tt1413492": 2359, 1067 | "tt1424381": 1586, 1068 | "tt1431045": 2040, 1069 | "tt1440728": 767, 1070 | "tt1446147": 1975, 1071 | "tt1483013": 1910, 1072 | "tt1510906": 1203, 1073 | "tt1524137": 2218, 1074 | "tt1570728": 1681, 1075 | "tt1623205": 1621, 1076 | "tt1663662": 2666, 1077 | "tt1707386": 2549, 1078 | "tt1748122": 885, 1079 | "tt1843287": 319, 1080 | "tt1853728": 1795, 1081 | "tt1872181": 2884, 1082 | "tt2011351": 405, 1083 | "tt2024544": 728, 1084 | "tt2099556": 542, 1085 | "tt2115388": 1047, 1086 | "tt2194499": 1605, 1087 | "tt2402927": 631, 1088 | "tt2409818": 635, 1089 | "tt2446980": 1096, 1090 | "tt2488496": 2396, 1091 | "tt2567712": 1173, 1092 | "tt2582846": 1874, 1093 | "tt2614684": 1098, 1094 | "tt2802144": 2455, 1095 | "tt3385516": 2569, 1096 | "tt3480796": 1628, 1097 | "tt3495026": 2425, 1098 | "tt4008652": 774, 1099 | "tt4034354": 1237, 1100 | "tt4520364": 1460, 1101 | "tt4915672": 1809, 1102 | "tt4972062": 792, 1103 | "tt5140878": 1400, 1104 | "tt6157626": 1083, 1105 | "tt6518634": 2030, 1106 | "tt6644200": 784 1107 | } 1108 | } -------------------------------------------------------------------------------- /data/data_preparation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import json 5 | 6 | # Concate 16 shot images into a single image, 7 | # the concated images are used for speeding up pre-training. 
8 | # Layout of the concatenated image (puzzle): [16x3], i.e. 16 shots stacked vertically, each with its 3 keyframes side by side 9 | def concate_pic(shot_info, img_path, save_path, row=16): 10 | for imdb, shot_num in shot_info.items(): 11 | pic_num = shot_num // row # number of full 16-shot puzzles; leftover shots are dropped 12 | for item in range(pic_num): 13 | img_list = [] 14 | for idx in range(row): 15 | shot_id = item * row + idx 16 | img_name_0 = f"{img_path}/{imdb}/shot_{str(shot_id).zfill(4)}_img_0.jpg" 17 | img_name_1 = f"{img_path}/{imdb}/shot_{str(shot_id).zfill(4)}_img_1.jpg" 18 | img_name_2 = f"{img_path}/{imdb}/shot_{str(shot_id).zfill(4)}_img_2.jpg" 19 | img_0 = cv2.imread(img_name_0) 20 | img_1 = cv2.imread(img_name_1) 21 | img_2 = cv2.imread(img_name_2) 22 | img = np.concatenate([img_0,img_1,img_2],axis=1) # one shot: 3 keyframes in a row 23 | img_list.append(img) 24 | full_img = np.concatenate(img_list,axis=0) # one puzzle: 16 shots stacked vertically 25 | # print(img.shape) 26 | # print(full_img.shape) 27 | new_pic_dir = f"{save_path}/{imdb}/" 28 | if not os.path.isdir(new_pic_dir): 29 | os.makedirs(new_pic_dir) 30 | filename = new_pic_dir + str(item).zfill(4) + '.jpg' 31 | cv2.imwrite(filename, full_img) 32 | 33 | # Number of shots in each movie 34 | def _generate_shot_num(new_shot_info='./MovieNet_shot_num.json'): 35 | shot_info = './MovieNet_1.0_shotinfo.json' 36 | shot_split = './movie1K.split.v1.json' 37 | with open(shot_info, 'rb') as f: 38 | shot_info_data = json.load(f) 39 | with open(shot_split, 'rb') as f: 40 | shot_split_data = json.load(f) 41 | new_shot_info_data = {} 42 | _type = ['train','val','test'] 43 | for _t in _type: 44 | new_shot_info_data[_t] = {} 45 | _movie_list = shot_split_data[_t] 46 | for idx, imdb_id in enumerate(_movie_list): 47 | shot_num = shot_info_data[_t][str(idx)] 48 | new_shot_info_data[_t][imdb_id] = shot_num 49 | with open(new_shot_info, 'w') as f: 50 | json.dump(new_shot_info_data, f, indent=4) 51 | 52 | 53 | def process_raw_label(_T = 'train', raw_root_dir = './'): 54 | split = 'movie1K.split.v1.json' 55 | data_dict = json.load(open(os.path.join(raw_root_dir,split))) 56 | 57 | # print(data_dict.keys()) 58 | # dict_keys(['train', 'val', 'test', 'full']) 59 | # print(len(data_dict['train'])) # 660 60 | # print(len(data_dict['val'])) # 220 61 | # print(len(data_dict['test'])) # 220 62 | # print(len(data_dict['full'])) # 1100 63 | 64 | data_list = data_dict[_T] 65 | 66 | # annotation 67 | annotation_path = 'annotation' 68 | count = 0 69 | video_list = [] 70 | # all annotations 71 | for index,name in enumerate(data_list): 72 | # print(name) 73 | annotation_file = os.path.join(raw_root_dir, annotation_path, name+'.json') 74 | data = json.load(open(annotation_file)) 75 | # only need scene seg labels 76 | if data['scene'] is not None: 77 | video_list.append({'name':name,'index':index}) 78 | count += 1 79 | print(f'scene annotations num: {count}') 80 | return video_list 81 | 82 | 83 | 84 | # GT generation 85 | def process_scene_seg_lable(scene_seg_path = './CVPR20SceneSeg/data/scene318/label318', 86 | scene_seg_label_json_name = './movie1K.scene_seg_318_name_index_shotnum_label.v1.json', 87 | raw_root_dir = './MovieNet'): 88 | def _process(data): 89 | seg_label = [] 90 | for i in data: 91 | name = i['name'] 92 | index = i['index'] 93 | label = [] 94 | with open(os.path.join(scene_seg_path,name+'.txt'), 'r') as f: 95 | shotnum_label = f.readlines() 96 | for line in shotnum_label: # each line: "<shot_id> <label>" 97 | if ' ' in line: 98 | shot_id = line.split(' ')[0].strip() 99 | l = line.split(' ')[1].strip() 100 | label.append((shot_id,l)) 101 | shot_count = len(label) + 1 102 | seg_label.append({"name":name, "index":index, "shot_count":shot_count, "label":label }) 103 | return seg_label 104 | 
105 | train_list = process_raw_label('train',raw_root_dir) 106 | val_list = process_raw_label('val',raw_root_dir) 107 | test_list = process_raw_label('test',raw_root_dir) 108 | data = {'train':train_list, 'val':val_list, 'test':test_list} 109 | 110 | # CVPR20SceneSeg GT 111 | train = _process(data['train']) 112 | test = _process(data['test']) 113 | val = _process(data['val']) 114 | d_all = {'train':train, 'val':val, 'test':test} 115 | 116 | with open(scene_seg_label_json_name,'w') as f: 117 | f.write(json.dumps(d_all)) 118 | 119 | 120 | 121 | if __name__ == '__main__': 122 | # Path of movienet images 123 | img_path = '/MovieNet_unzip/240P' 124 | 125 | # Shot number 126 | shot_info = './MovieNet_shot_num.json' 127 | _generate_shot_num(shot_info) 128 | 129 | # GT label 130 | scene_seg_label_json_name = './movie1K.scene_seg_318_name_index_shotnum_label.v1.json' 131 | ## Download LGSS Annotation from: https://github.com/AnyiRao/SceneSeg/blob/master/docs/INSTALL.md 132 | ## 'scene_seg_path' is the path of the downloaded annotations 133 | scene_seg_path = './CVPR20SceneSeg/data/scene318/label318' 134 | ## Path of raw MovieNet 135 | raw_root_dir = './MovieNet/MovieNet_Ori' 136 | process_scene_seg_lable(scene_seg_path, scene_seg_label_json_name, raw_root_dir) 137 | 138 | # Concatenate images into 16-shot puzzles 139 | save_path = './compressed_shot_images' 140 | with open(shot_info, 'rb') as f: 141 | shot_info_data = json.load(f) 142 | concate_pic(shot_info_data['train'], img_path, save_path) 143 | 144 | -------------------------------------------------------------------------------- /data/movie1K.split.v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | "tt0035423", 4 | "tt0045537", 5 | "tt0047396", 6 | "tt0048605", 7 | "tt0049730", 8 | "tt0050706", 9 | "tt0053125", 10 | "tt0056869", 11 | "tt0056923", 12 | "tt0057115", 13 | "tt0058461", 14 | "tt0059043", 15 | "tt0059592", 16 | "tt0060522", 17 | "tt0061138", 18 | "tt0061418", 19 | "tt0061781", 20 | "tt0062622", 21 | "tt0064040", 22 | "tt0064276", 23 | "tt0064665", 24 | "tt0065214", 25 | "tt0065724", 26 | "tt0065988", 27 | "tt0066249", 28 | "tt0066921", 29 | "tt0067116", 30 | "tt0067185", 31 | "tt0068935", 32 | "tt0069293", 33 | "tt0069467", 34 | "tt0069995", 35 | "tt0070047", 36 | "tt0070246", 37 | "tt0070379", 38 | "tt0070735", 39 | "tt0070849", 40 | "tt0071129", 41 | "tt0071315", 42 | "tt0071360", 43 | "tt0072684", 44 | "tt0073440", 45 | "tt0074119", 46 | "tt0074285", 47 | "tt0074686", 48 | "tt0074749", 49 | "tt0075148", 50 | "tt0076729", 51 | "tt0077402", 52 | "tt0077405", 53 | "tt0077416", 54 | "tt0077651", 55 | "tt0078841", 56 | "tt0078908", 57 | "tt0079095", 58 | "tt0079116", 59 | "tt0079417", 60 | "tt0079944", 61 | "tt0079945", 62 | "tt0080339", 63 | "tt0080453", 64 | "tt0080745", 65 | "tt0080958", 66 | "tt0080979", 67 | "tt0081505", 68 | "tt0082186", 69 | "tt0082846", 70 | "tt0082971", 71 | "tt0083658", 72 | "tt0083987", 73 | "tt0084549", 74 | "tt0084628", 75 | "tt0084726", 76 | "tt0084787", 77 | "tt0085794", 78 | "tt0086250", 79 | "tt0086837", 80 | "tt0086879", 81 | "tt0086969", 82 | "tt0087182", 83 | "tt0087469", 84 | "tt0088170", 85 | "tt0088222", 86 | "tt0088847", 87 | "tt0088939", 88 | "tt0088993", 89 | "tt0090022", 90 | "tt0090605", 91 | "tt0091042", 92 | "tt0091203", 93 | "tt0091251", 94 | "tt0091406", 95 | "tt0091738", 96 | "tt0091763", 97 | "tt0092603", 98 | "tt0093010", 99 | "tt0093209", 100 | "tt0093565", 101 | "tt0093748", 102 | "tt0093779", 103 | "tt0094226", 104 | "tt0094291", 105 | 
"tt0095016", 106 | "tt0095497", 107 | "tt0095956", 108 | "tt0096463", 109 | "tt0096754", 110 | "tt0096874", 111 | "tt0096895", 112 | "tt0097216", 113 | "tt0097372", 114 | "tt0097428", 115 | "tt0098258", 116 | "tt0098635", 117 | "tt0098724", 118 | "tt0099348", 119 | "tt0099487", 120 | "tt0099653", 121 | "tt0099674", 122 | "tt0099810", 123 | "tt0100150", 124 | "tt0100157", 125 | "tt0100234", 126 | "tt0100802", 127 | "tt0100935", 128 | "tt0100998", 129 | "tt0101272", 130 | "tt0101393", 131 | "tt0101700", 132 | "tt0101889", 133 | "tt0101921", 134 | "tt0102492", 135 | "tt0102926", 136 | "tt0103064", 137 | "tt0103772", 138 | "tt0103786", 139 | "tt0104036", 140 | "tt0104257", 141 | "tt0104348", 142 | "tt0105226", 143 | "tt0105665", 144 | "tt0105695", 145 | "tt0106332", 146 | "tt0106582", 147 | "tt0107507", 148 | "tt0107653", 149 | "tt0107736", 150 | "tt0107808", 151 | "tt0107822", 152 | "tt0108122", 153 | "tt0108289", 154 | "tt0108330", 155 | "tt0109686", 156 | "tt0109830", 157 | "tt0109831", 158 | "tt0110074", 159 | "tt0110148", 160 | "tt0110201", 161 | "tt0110322", 162 | "tt0110632", 163 | "tt0110912", 164 | "tt0111003", 165 | "tt0112769", 166 | "tt0113101", 167 | "tt0113243", 168 | "tt0113253", 169 | "tt0113497", 170 | "tt0114367", 171 | "tt0114369", 172 | "tt0114388", 173 | "tt0114814", 174 | "tt0115798", 175 | "tt0115964", 176 | "tt0116209", 177 | "tt0116477", 178 | "tt0116629", 179 | "tt0116695", 180 | "tt0117500", 181 | "tt0117509", 182 | "tt0117666", 183 | "tt0117731", 184 | "tt0117883", 185 | "tt0117951", 186 | "tt0118636", 187 | "tt0118655", 188 | "tt0118688", 189 | "tt0118689", 190 | "tt0118749", 191 | "tt0118799", 192 | "tt0118845", 193 | "tt0118883", 194 | "tt0118929", 195 | "tt0118971", 196 | "tt0119008", 197 | "tt0119081", 198 | "tt0119177", 199 | "tt0119250", 200 | "tt0119314", 201 | "tt0119396", 202 | "tt0119528", 203 | "tt0119567", 204 | "tt0119643", 205 | "tt0119654", 206 | "tt0119670", 207 | "tt0119738", 208 | "tt0120263", 209 | "tt0120338", 210 | "tt0120586", 211 | "tt0120591", 212 | "tt0120616", 213 | "tt0120655", 214 | "tt0120660", 215 | "tt0120667", 216 | "tt0120669", 217 | "tt0120696", 218 | "tt0120735", 219 | "tt0120737", 220 | "tt0120744", 221 | "tt0120755", 222 | "tt0120787", 223 | "tt0120804", 224 | "tt0120815", 225 | "tt0120885", 226 | "tt0120902", 227 | "tt0120912", 228 | "tt0120915", 229 | "tt0121766", 230 | "tt0122690", 231 | "tt0125439", 232 | "tt0125664", 233 | "tt0126886", 234 | "tt0128445", 235 | "tt0134119", 236 | "tt0134273", 237 | "tt0134847", 238 | "tt0137494", 239 | "tt0139134", 240 | "tt0139654", 241 | "tt0142688", 242 | "tt0143145", 243 | "tt0144084", 244 | "tt0144117", 245 | "tt0159365", 246 | "tt0159784", 247 | "tt0160127", 248 | "tt0162346", 249 | "tt0162661", 250 | "tt0164052", 251 | "tt0167190", 252 | "tt0167260", 253 | "tt0167331", 254 | "tt0169547", 255 | "tt0171363", 256 | "tt0175880", 257 | "tt0180073", 258 | "tt0180093", 259 | "tt0181689", 260 | "tt0181875", 261 | "tt0183523", 262 | "tt0183649", 263 | "tt0187078", 264 | "tt0187393", 265 | "tt0190590", 266 | "tt0195685", 267 | "tt0199354", 268 | "tt0199753", 269 | "tt0203009", 270 | "tt0206634", 271 | "tt0207201", 272 | "tt0208092", 273 | "tt0209144", 274 | "tt0209463", 275 | "tt0210727", 276 | "tt0212338", 277 | "tt0213149", 278 | "tt0227445", 279 | "tt0232500", 280 | "tt0234215", 281 | "tt0240772", 282 | "tt0242653", 283 | "tt0243876", 284 | "tt0244244", 285 | "tt0244353", 286 | "tt0245844", 287 | "tt0246578", 288 | "tt0250494", 289 | "tt0250797", 290 | "tt0251160", 291 | "tt0253754", 292 | 
"tt0258000", 293 | "tt0264395", 294 | "tt0264616", 295 | "tt0266697", 296 | "tt0266915", 297 | "tt0268695", 298 | "tt0272152", 299 | "tt0275719", 300 | "tt0278504", 301 | "tt0283509", 302 | "tt0286106", 303 | "tt0288477", 304 | "tt0290334", 305 | "tt0294870", 306 | "tt0299658", 307 | "tt0308476", 308 | "tt0309698", 309 | "tt0313542", 310 | "tt0315327", 311 | "tt0316654", 312 | "tt0317198", 313 | "tt0317919", 314 | "tt0318627", 315 | "tt0318974", 316 | "tt0325710", 317 | "tt0325980", 318 | "tt0328107", 319 | "tt0329101", 320 | "tt0331811", 321 | "tt0332452", 322 | "tt0335266", 323 | "tt0337978", 324 | "tt0338013", 325 | "tt0338751", 326 | "tt0343660", 327 | "tt0346094", 328 | "tt0349903", 329 | "tt0350258", 330 | "tt0351977", 331 | "tt0357413", 332 | "tt0359950", 333 | "tt0362227", 334 | "tt0363589", 335 | "tt0363771", 336 | "tt0365907", 337 | "tt0369339", 338 | "tt0369702", 339 | "tt0371257", 340 | "tt0372183", 341 | "tt0372784", 342 | "tt0372824", 343 | "tt0373074", 344 | "tt0374546", 345 | "tt0375679", 346 | "tt0376994", 347 | "tt0377713", 348 | "tt0378194", 349 | "tt0379306", 350 | "tt0382625", 351 | "tt0383028", 352 | "tt0383216", 353 | "tt0383574", 354 | "tt0385004", 355 | "tt0387564", 356 | "tt0387877", 357 | "tt0388795", 358 | "tt0390022", 359 | "tt0393109", 360 | "tt0395699", 361 | "tt0397078", 362 | "tt0398027", 363 | "tt0399295", 364 | "tt0405159", 365 | "tt0407887", 366 | "tt0408306", 367 | "tt0408790", 368 | "tt0413893", 369 | "tt0414055", 370 | "tt0414387", 371 | "tt0414982", 372 | "tt0415380", 373 | "tt0417741", 374 | "tt0418279", 375 | "tt0418819", 376 | "tt0419887", 377 | "tt0420223", 378 | "tt0421715", 379 | "tt0424345", 380 | "tt0425061", 381 | "tt0425210", 382 | "tt0427309", 383 | "tt0427954", 384 | "tt0430357", 385 | "tt0433035", 386 | "tt0435705", 387 | "tt0439815", 388 | "tt0443453", 389 | "tt0443706", 390 | "tt0448157", 391 | "tt0449088", 392 | "tt0450259", 393 | "tt0450385", 394 | "tt0454841", 395 | "tt0454876", 396 | "tt0454921", 397 | "tt0457939", 398 | "tt0458413", 399 | "tt0462200", 400 | "tt0467406", 401 | "tt0468565", 402 | "tt0468569", 403 | "tt0473705", 404 | "tt0475293", 405 | "tt0477348", 406 | "tt0479997", 407 | "tt0489018", 408 | "tt0493464", 409 | "tt0499448", 410 | "tt0499549", 411 | "tt0758774", 412 | "tt0765128", 413 | "tt0765429", 414 | "tt0765447", 415 | "tt0780504", 416 | "tt0790636", 417 | "tt0790686", 418 | "tt0796366", 419 | "tt0800320", 420 | "tt0810819", 421 | "tt0815236", 422 | "tt0824747", 423 | "tt0826711", 424 | "tt0829482", 425 | "tt0844286", 426 | "tt0846308", 427 | "tt0848228", 428 | "tt0862846", 429 | "tt0887883", 430 | "tt0913425", 431 | "tt0914798", 432 | "tt0942385", 433 | "tt0947798", 434 | "tt0958860", 435 | "tt0959337", 436 | "tt0963794", 437 | "tt0963966", 438 | "tt0970416", 439 | "tt0974661", 440 | "tt0975645", 441 | "tt0977855", 442 | "tt0985694", 443 | "tt0985699", 444 | "tt0986233", 445 | "tt0986263", 446 | "tt0988045", 447 | "tt0993846", 448 | "tt1010048", 449 | "tt1013753", 450 | "tt1016268", 451 | "tt1022603", 452 | "tt1024648", 453 | "tt1027718", 454 | "tt1029234", 455 | "tt1029360", 456 | "tt1037705", 457 | "tt1045658", 458 | "tt1045772", 459 | "tt1054606", 460 | "tt1055369", 461 | "tt1057500", 462 | "tt1059786", 463 | "tt1068649", 464 | "tt1068680", 465 | "tt1072748", 466 | "tt1074638", 467 | "tt1084950", 468 | "tt1104001", 469 | "tt1124035", 470 | "tt1125849", 471 | "tt1131729", 472 | "tt1133985", 473 | "tt1135952", 474 | "tt1139797", 475 | "tt1148204", 476 | "tt1156466", 477 | "tt1158278", 478 | "tt1174732", 479 | 
"tt1179031", 480 | "tt1179904", 481 | "tt1186367", 482 | "tt1188729", 483 | "tt1193138", 484 | "tt1194173", 485 | "tt1210166", 486 | "tt1217613", 487 | "tt1219289", 488 | "tt1220719", 489 | "tt1228705", 490 | "tt1229340", 491 | "tt1229822", 492 | "tt1233381", 493 | "tt1244754", 494 | "tt1253863", 495 | "tt1255953", 496 | "tt1274586", 497 | "tt1276104", 498 | "tt1282140", 499 | "tt1285016", 500 | "tt1291150", 501 | "tt1291584", 502 | "tt1298650", 503 | "tt1300851", 504 | "tt1305806", 505 | "tt1306980", 506 | "tt1322269", 507 | "tt1324999", 508 | "tt1340800", 509 | "tt1343092", 510 | "tt1360860", 511 | "tt1371111", 512 | "tt1375670", 513 | "tt1396218", 514 | "tt1401152", 515 | "tt1403865", 516 | "tt1411238", 517 | "tt1438176", 518 | "tt1439572", 519 | "tt1446714", 520 | "tt1454029", 521 | "tt1454468", 522 | "tt1458175", 523 | "tt1462758", 524 | "tt1468846", 525 | "tt1478338", 526 | "tt1486190", 527 | "tt1502712", 528 | "tt1533117", 529 | "tt1535970", 530 | "tt1560747", 531 | "tt1563738", 532 | "tt1564367", 533 | "tt1568346", 534 | "tt1602620", 535 | "tt1606378", 536 | "tt1615147", 537 | "tt1616195", 538 | "tt1628841", 539 | "tt1637725", 540 | "tt1646987", 541 | "tt1649443", 542 | "tt1655420", 543 | "tt1670345", 544 | "tt1675434", 545 | "tt1692486", 546 | "tt1706593", 547 | "tt1723811", 548 | "tt1747958", 549 | "tt1757746", 550 | "tt1781769", 551 | "tt1800241", 552 | "tt1800246", 553 | "tt1809398", 554 | "tt1832382", 555 | "tt1855325", 556 | "tt1877832", 557 | "tt1907668", 558 | "tt1951266", 559 | "tt1971325", 560 | "tt1979320", 561 | "tt1981115", 562 | "tt2017561", 563 | "tt2053463", 564 | "tt2056771", 565 | "tt2058107", 566 | "tt2058673", 567 | "tt2059255", 568 | "tt2070649", 569 | "tt2084970", 570 | "tt2103281", 571 | "tt2109184", 572 | "tt2118775", 573 | "tt2140373", 574 | "tt2167266", 575 | "tt2238032", 576 | "tt2258281", 577 | "tt2267998", 578 | "tt2294449", 579 | "tt2310332", 580 | "tt2334873", 581 | "tt2345567", 582 | "tt2366450", 583 | "tt2381249", 584 | "tt2382298", 585 | "tt2404435", 586 | "tt2463288", 587 | "tt2473794", 588 | "tt2567026", 589 | "tt2582802", 590 | "tt2639344", 591 | "tt2675914", 592 | "tt2713180", 593 | "tt2717822", 594 | "tt2800240", 595 | "tt2823054", 596 | "tt2884018", 597 | "tt2908856", 598 | "tt2911666", 599 | "tt2923316", 600 | "tt2980516", 601 | "tt3062096", 602 | "tt3064298", 603 | "tt3077214", 604 | "tt3289956", 605 | "tt3296658", 606 | "tt3312830", 607 | "tt3319920", 608 | "tt3395184", 609 | "tt3410834", 610 | "tt3416744", 611 | "tt3439114", 612 | "tt3465916", 613 | "tt3474602", 614 | "tt3478232", 615 | "tt3498820", 616 | "tt3501416", 617 | "tt3531578", 618 | "tt3630276", 619 | "tt3659786", 620 | "tt3671542", 621 | "tt3700392", 622 | "tt3700804", 623 | "tt3707106", 624 | "tt3714720", 625 | "tt3766394", 626 | "tt3808342", 627 | "tt3860916", 628 | "tt3960412", 629 | "tt4046784", 630 | "tt4052882", 631 | "tt4136084", 632 | "tt4151192", 633 | "tt4176826", 634 | "tt4242158", 635 | "tt4273292", 636 | "tt4501454", 637 | "tt4651520", 638 | "tt4698684", 639 | "tt4721400", 640 | "tt4781612", 641 | "tt4786282", 642 | "tt4824302", 643 | "tt4939066", 644 | "tt5052448", 645 | "tt5065810", 646 | "tt5294550", 647 | "tt5564148", 648 | "tt5576318", 649 | "tt5580036", 650 | "tt5593416", 651 | "tt5649144", 652 | "tt5688868", 653 | "tt5827496", 654 | "tt5866930", 655 | "tt6133130", 656 | "tt6298600", 657 | "tt6466464", 658 | "tt6513406", 659 | "tt6788942", 660 | "tt7055592", 661 | "tt7131870", 662 | "tt7180392" 663 | ], 664 | "val": [ 665 | "tt0032138", 666 | "tt0038650", 667 
| "tt0048545", 668 | "tt0053221", 669 | "tt0053579", 670 | "tt0054167", 671 | "tt0061722", 672 | "tt0064115", 673 | "tt0066026", 674 | "tt0067140", 675 | "tt0069762", 676 | "tt0070245", 677 | "tt0071562", 678 | "tt0072443", 679 | "tt0072890", 680 | "tt0073486", 681 | "tt0074811", 682 | "tt0076759", 683 | "tt0079182", 684 | "tt0079470", 685 | "tt0080610", 686 | "tt0080684", 687 | "tt0083866", 688 | "tt0083929", 689 | "tt0084899", 690 | "tt0085991", 691 | "tt0087332", 692 | "tt0089853", 693 | "tt0089907", 694 | "tt0090756", 695 | "tt0091355", 696 | "tt0091369", 697 | "tt0092699", 698 | "tt0092991", 699 | "tt0094737", 700 | "tt0094761", 701 | "tt0095765", 702 | "tt0095953", 703 | "tt0096256", 704 | "tt0096446", 705 | "tt0097576", 706 | "tt0099685", 707 | "tt0100112", 708 | "tt0100403", 709 | "tt0101410", 710 | "tt0102138", 711 | "tt0103074", 712 | "tt0103241", 713 | "tt0103292", 714 | "tt0104797", 715 | "tt0105236", 716 | "tt0105652", 717 | "tt0106226", 718 | "tt0106977", 719 | "tt0107614", 720 | "tt0108160", 721 | "tt0108656", 722 | "tt0109020", 723 | "tt0110475", 724 | "tt0110932", 725 | "tt0112462", 726 | "tt0112641", 727 | "tt0112740", 728 | "tt0113870", 729 | "tt0114558", 730 | "tt0116367", 731 | "tt0116996", 732 | "tt0117381", 733 | "tt0118548", 734 | "tt0118571", 735 | "tt0118842", 736 | "tt0119094", 737 | "tt0119116", 738 | "tt0119174", 739 | "tt0119822", 740 | "tt0120483", 741 | "tt0120601", 742 | "tt0120780", 743 | "tt0120863", 744 | "tt0121765", 745 | "tt0122933", 746 | "tt0129387", 747 | "tt0133093", 748 | "tt0138097", 749 | "tt0140352", 750 | "tt0145487", 751 | "tt0166896", 752 | "tt0166924", 753 | "tt0167261", 754 | "tt0167404", 755 | "tt0182789", 756 | "tt0186151", 757 | "tt0209958", 758 | "tt0217869", 759 | "tt0240890", 760 | "tt0248667", 761 | "tt0258463", 762 | "tt0261392", 763 | "tt0265666", 764 | "tt0268126", 765 | "tt0268978", 766 | "tt0277027", 767 | "tt0285742", 768 | "tt0289879", 769 | "tt0290002", 770 | "tt0298228", 771 | "tt0311113", 772 | "tt0317740", 773 | "tt0319262", 774 | "tt0322259", 775 | "tt0324197", 776 | "tt0337921", 777 | "tt0341495", 778 | "tt0343818", 779 | "tt0360486", 780 | "tt0370263", 781 | "tt0371724", 782 | "tt0375063", 783 | "tt0395169", 784 | "tt0401383", 785 | "tt0408236", 786 | "tt0416320", 787 | "tt0432021", 788 | "tt0434409", 789 | "tt0454848", 790 | "tt0455760", 791 | "tt0457297", 792 | "tt0457430", 793 | "tt0457513", 794 | "tt0467200", 795 | "tt0469494", 796 | "tt0470752", 797 | "tt0480025", 798 | "tt0758730", 799 | "tt0758758", 800 | "tt0780653", 801 | "tt0790628", 802 | "tt0808151", 803 | "tt0816692", 804 | "tt0824758", 805 | "tt0838232", 806 | "tt0898367", 807 | "tt0940709", 808 | "tt0964517", 809 | "tt0993842", 810 | "tt1000774", 811 | "tt1019452", 812 | "tt1032755", 813 | "tt1041829", 814 | "tt1055292", 815 | "tt1065073", 816 | "tt1071875", 817 | "tt1073498", 818 | "tt1093906", 819 | "tt1100089", 820 | "tt1144884", 821 | "tt1172049", 822 | "tt1178663", 823 | "tt1182345", 824 | "tt1190080", 825 | "tt1211837", 826 | "tt1216496", 827 | "tt1232829", 828 | "tt1284575", 829 | "tt1341167", 830 | "tt1355683", 831 | "tt1385826", 832 | "tt1409024", 833 | "tt1441953", 834 | "tt1462900", 835 | "tt1504320", 836 | "tt1540133", 837 | "tt1582248", 838 | "tt1586752", 839 | "tt1591095", 840 | "tt1596363", 841 | "tt1602613", 842 | "tt1611840", 843 | "tt1619029", 844 | "tt1645170", 845 | "tt1659337", 846 | "tt1703957", 847 | "tt1722484", 848 | "tt1725986", 849 | "tt1731141", 850 | "tt1742683", 851 | "tt1840309", 852 | "tt1895587", 853 | "tt1974419", 854 | 
"tt2032557", 855 | "tt2076220", 856 | "tt2078768", 857 | "tt2109248", 858 | "tt2132285", 859 | "tt2381991", 860 | "tt2645044", 861 | "tt2788732", 862 | "tt2832470", 863 | "tt2872732", 864 | "tt2978462", 865 | "tt3110958", 866 | "tt3316960", 867 | "tt3421514", 868 | "tt3464902", 869 | "tt3488710", 870 | "tt3508840", 871 | "tt3553442", 872 | "tt3672840", 873 | "tt3726704", 874 | "tt3824458", 875 | "tt3882082", 876 | "tt3922798", 877 | "tt4160708", 878 | "tt4647900", 879 | "tt4967094", 880 | "tt5726086", 881 | "tt6121428", 882 | "tt6190198", 883 | "tt7160070", 884 | "tt7672188" 885 | ], 886 | "test": [ 887 | "tt0048028", 888 | "tt0049470", 889 | "tt0049833", 890 | "tt0050419", 891 | "tt0052357", 892 | "tt0058331", 893 | "tt0061811", 894 | "tt0063442", 895 | "tt0066206", 896 | "tt0068646", 897 | "tt0070291", 898 | "tt0070511", 899 | "tt0073195", 900 | "tt0073582", 901 | "tt0073629", 902 | "tt0075314", 903 | "tt0075686", 904 | "tt0078788", 905 | "tt0079672", 906 | "tt0080455", 907 | "tt0080761", 908 | "tt0082089", 909 | "tt0082198", 910 | "tt0083907", 911 | "tt0083946", 912 | "tt0084390", 913 | "tt0086190", 914 | "tt0086856", 915 | "tt0087921", 916 | "tt0088247", 917 | "tt0088944", 918 | "tt0089218", 919 | "tt0089881", 920 | "tt0090257", 921 | "tt0091867", 922 | "tt0092099", 923 | "tt0093773", 924 | "tt0094964", 925 | "tt0095250", 926 | "tt0096320", 927 | "tt0099423", 928 | "tt0100405", 929 | "tt0103776", 930 | "tt0103855", 931 | "tt0104466", 932 | "tt0104553", 933 | "tt0104691", 934 | "tt0107290", 935 | "tt0107617", 936 | "tt0108399", 937 | "tt0110116", 938 | "tt0110167", 939 | "tt0110604", 940 | "tt0111280", 941 | "tt0111797", 942 | "tt0112384", 943 | "tt0112573", 944 | "tt0112818", 945 | "tt0112883", 946 | "tt0113277", 947 | "tt0114746", 948 | "tt0115734", 949 | "tt0115759", 950 | "tt0115956", 951 | "tt0116213", 952 | "tt0116282", 953 | "tt0116767", 954 | "tt0116922", 955 | "tt0117060", 956 | "tt0117571", 957 | "tt0118583", 958 | "tt0118715", 959 | "tt0119303", 960 | "tt0119349", 961 | "tt0119375", 962 | "tt0119488", 963 | "tt0120255", 964 | "tt0120382", 965 | "tt0120689", 966 | "tt0120731", 967 | "tt0120738", 968 | "tt0120812", 969 | "tt0120890", 970 | "tt0120903", 971 | "tt0123755", 972 | "tt0124315", 973 | "tt0127536", 974 | "tt0133152", 975 | "tt0137439", 976 | "tt0137523", 977 | "tt0142342", 978 | "tt0163025", 979 | "tt0172495", 980 | "tt0178868", 981 | "tt0190332", 982 | "tt0195714", 983 | "tt0212985", 984 | "tt0217505", 985 | "tt0219822", 986 | "tt0253474", 987 | "tt0257360", 988 | "tt0280609", 989 | "tt0281358", 990 | "tt0281686", 991 | "tt0319061", 992 | "tt0330373", 993 | "tt0335119", 994 | "tt0361748", 995 | "tt0368891", 996 | "tt0368933", 997 | "tt0369441", 998 | "tt0370032", 999 | "tt0373051", 1000 | "tt0373469", 1001 | "tt0379786", 1002 | "tt0381061", 1003 | "tt0386588", 1004 | "tt0387898", 1005 | "tt0399201", 1006 | "tt0404978", 1007 | "tt0409459", 1008 | "tt0416508", 1009 | "tt0440963", 1010 | "tt0443272", 1011 | "tt0443680", 1012 | "tt0452625", 1013 | "tt0455824", 1014 | "tt0458352", 1015 | "tt0458525", 1016 | "tt0460791", 1017 | "tt0460989", 1018 | "tt0462499", 1019 | "tt0477347", 1020 | "tt0479884", 1021 | "tt0481369", 1022 | "tt0780571", 1023 | "tt0783233", 1024 | "tt0800080", 1025 | "tt0800369", 1026 | "tt0815245", 1027 | "tt0822832", 1028 | "tt0844347", 1029 | "tt0878804", 1030 | "tt0903624", 1031 | "tt0905372", 1032 | "tt0944835", 1033 | "tt0945513", 1034 | "tt0970179", 1035 | "tt0976051", 1036 | "tt1001508", 1037 | "tt1007029", 1038 | "tt1017460", 1039 | "tt1033575", 
1040 | "tt1034314", 1041 | "tt1038919", 1042 | "tt1046173", 1043 | "tt1063669", 1044 | "tt1086772", 1045 | "tt1092026", 1046 | "tt1099212", 1047 | "tt1119646", 1048 | "tt1120985", 1049 | "tt1124037", 1050 | "tt1170358", 1051 | "tt1181614", 1052 | "tt1189340", 1053 | "tt1201607", 1054 | "tt1205489", 1055 | "tt1220634", 1056 | "tt1229238", 1057 | "tt1287878", 1058 | "tt1292566", 1059 | "tt1318514", 1060 | "tt1375666", 1061 | "tt1386932", 1062 | "tt1392190", 1063 | "tt1397514", 1064 | "tt1399103", 1065 | "tt1412386", 1066 | "tt1413492", 1067 | "tt1424381", 1068 | "tt1431045", 1069 | "tt1440728", 1070 | "tt1446147", 1071 | "tt1483013", 1072 | "tt1510906", 1073 | "tt1524137", 1074 | "tt1570728", 1075 | "tt1623205", 1076 | "tt1663662", 1077 | "tt1707386", 1078 | "tt1748122", 1079 | "tt1843287", 1080 | "tt1853728", 1081 | "tt1872181", 1082 | "tt2011351", 1083 | "tt2024544", 1084 | "tt2099556", 1085 | "tt2115388", 1086 | "tt2194499", 1087 | "tt2402927", 1088 | "tt2409818", 1089 | "tt2446980", 1090 | "tt2488496", 1091 | "tt2567712", 1092 | "tt2582846", 1093 | "tt2614684", 1094 | "tt2802144", 1095 | "tt3385516", 1096 | "tt3480796", 1097 | "tt3495026", 1098 | "tt4008652", 1099 | "tt4034354", 1100 | "tt4520364", 1101 | "tt4915672", 1102 | "tt4972062", 1103 | "tt5140878", 1104 | "tt6157626", 1105 | "tt6518634", 1106 | "tt6644200" 1107 | ], 1108 | "full": [ 1109 | "tt0032138", 1110 | "tt0035423", 1111 | "tt0038650", 1112 | "tt0045537", 1113 | "tt0047396", 1114 | "tt0048028", 1115 | "tt0048545", 1116 | "tt0048605", 1117 | "tt0049470", 1118 | "tt0049730", 1119 | "tt0049833", 1120 | "tt0050419", 1121 | "tt0050706", 1122 | "tt0052357", 1123 | "tt0053125", 1124 | "tt0053221", 1125 | "tt0053579", 1126 | "tt0054167", 1127 | "tt0056869", 1128 | "tt0056923", 1129 | "tt0057115", 1130 | "tt0058331", 1131 | "tt0058461", 1132 | "tt0059043", 1133 | "tt0059592", 1134 | "tt0060522", 1135 | "tt0061138", 1136 | "tt0061418", 1137 | "tt0061722", 1138 | "tt0061781", 1139 | "tt0061811", 1140 | "tt0062622", 1141 | "tt0063442", 1142 | "tt0064040", 1143 | "tt0064115", 1144 | "tt0064276", 1145 | "tt0064665", 1146 | "tt0065214", 1147 | "tt0065724", 1148 | "tt0065988", 1149 | "tt0066026", 1150 | "tt0066206", 1151 | "tt0066249", 1152 | "tt0066921", 1153 | "tt0067116", 1154 | "tt0067140", 1155 | "tt0067185", 1156 | "tt0068646", 1157 | "tt0068935", 1158 | "tt0069293", 1159 | "tt0069467", 1160 | "tt0069762", 1161 | "tt0069995", 1162 | "tt0070047", 1163 | "tt0070245", 1164 | "tt0070246", 1165 | "tt0070291", 1166 | "tt0070379", 1167 | "tt0070511", 1168 | "tt0070735", 1169 | "tt0070849", 1170 | "tt0071129", 1171 | "tt0071315", 1172 | "tt0071360", 1173 | "tt0071562", 1174 | "tt0072443", 1175 | "tt0072684", 1176 | "tt0072890", 1177 | "tt0073195", 1178 | "tt0073440", 1179 | "tt0073486", 1180 | "tt0073582", 1181 | "tt0073629", 1182 | "tt0074119", 1183 | "tt0074285", 1184 | "tt0074686", 1185 | "tt0074749", 1186 | "tt0074811", 1187 | "tt0075148", 1188 | "tt0075314", 1189 | "tt0075686", 1190 | "tt0076729", 1191 | "tt0076759", 1192 | "tt0077402", 1193 | "tt0077405", 1194 | "tt0077416", 1195 | "tt0077651", 1196 | "tt0078788", 1197 | "tt0078841", 1198 | "tt0078908", 1199 | "tt0079095", 1200 | "tt0079116", 1201 | "tt0079182", 1202 | "tt0079417", 1203 | "tt0079470", 1204 | "tt0079672", 1205 | "tt0079944", 1206 | "tt0079945", 1207 | "tt0080339", 1208 | "tt0080453", 1209 | "tt0080455", 1210 | "tt0080610", 1211 | "tt0080684", 1212 | "tt0080745", 1213 | "tt0080761", 1214 | "tt0080958", 1215 | "tt0080979", 1216 | "tt0081505", 1217 | "tt0082089", 1218 | 
"tt0082186", 1219 | "tt0082198", 1220 | "tt0082846", 1221 | "tt0082971", 1222 | "tt0083658", 1223 | "tt0083866", 1224 | "tt0083907", 1225 | "tt0083929", 1226 | "tt0083946", 1227 | "tt0083987", 1228 | "tt0084390", 1229 | "tt0084549", 1230 | "tt0084628", 1231 | "tt0084726", 1232 | "tt0084787", 1233 | "tt0084899", 1234 | "tt0085794", 1235 | "tt0085991", 1236 | "tt0086190", 1237 | "tt0086250", 1238 | "tt0086837", 1239 | "tt0086856", 1240 | "tt0086879", 1241 | "tt0086969", 1242 | "tt0087182", 1243 | "tt0087332", 1244 | "tt0087469", 1245 | "tt0087921", 1246 | "tt0088170", 1247 | "tt0088222", 1248 | "tt0088247", 1249 | "tt0088847", 1250 | "tt0088939", 1251 | "tt0088944", 1252 | "tt0088993", 1253 | "tt0089218", 1254 | "tt0089853", 1255 | "tt0089881", 1256 | "tt0089907", 1257 | "tt0090022", 1258 | "tt0090257", 1259 | "tt0090605", 1260 | "tt0090756", 1261 | "tt0091042", 1262 | "tt0091203", 1263 | "tt0091251", 1264 | "tt0091355", 1265 | "tt0091369", 1266 | "tt0091406", 1267 | "tt0091738", 1268 | "tt0091763", 1269 | "tt0091867", 1270 | "tt0092099", 1271 | "tt0092603", 1272 | "tt0092699", 1273 | "tt0092991", 1274 | "tt0093010", 1275 | "tt0093209", 1276 | "tt0093565", 1277 | "tt0093748", 1278 | "tt0093773", 1279 | "tt0093779", 1280 | "tt0094226", 1281 | "tt0094291", 1282 | "tt0094737", 1283 | "tt0094761", 1284 | "tt0094964", 1285 | "tt0095016", 1286 | "tt0095250", 1287 | "tt0095497", 1288 | "tt0095765", 1289 | "tt0095953", 1290 | "tt0095956", 1291 | "tt0096256", 1292 | "tt0096320", 1293 | "tt0096446", 1294 | "tt0096463", 1295 | "tt0096754", 1296 | "tt0096874", 1297 | "tt0096895", 1298 | "tt0097216", 1299 | "tt0097372", 1300 | "tt0097428", 1301 | "tt0097576", 1302 | "tt0098258", 1303 | "tt0098635", 1304 | "tt0098724", 1305 | "tt0099348", 1306 | "tt0099423", 1307 | "tt0099487", 1308 | "tt0099653", 1309 | "tt0099674", 1310 | "tt0099685", 1311 | "tt0099810", 1312 | "tt0100112", 1313 | "tt0100150", 1314 | "tt0100157", 1315 | "tt0100234", 1316 | "tt0100403", 1317 | "tt0100405", 1318 | "tt0100802", 1319 | "tt0100935", 1320 | "tt0100998", 1321 | "tt0101272", 1322 | "tt0101393", 1323 | "tt0101410", 1324 | "tt0101700", 1325 | "tt0101889", 1326 | "tt0101921", 1327 | "tt0102138", 1328 | "tt0102492", 1329 | "tt0102926", 1330 | "tt0103064", 1331 | "tt0103074", 1332 | "tt0103241", 1333 | "tt0103292", 1334 | "tt0103772", 1335 | "tt0103776", 1336 | "tt0103786", 1337 | "tt0103855", 1338 | "tt0104036", 1339 | "tt0104257", 1340 | "tt0104348", 1341 | "tt0104466", 1342 | "tt0104553", 1343 | "tt0104691", 1344 | "tt0104797", 1345 | "tt0105226", 1346 | "tt0105236", 1347 | "tt0105652", 1348 | "tt0105665", 1349 | "tt0105695", 1350 | "tt0106226", 1351 | "tt0106332", 1352 | "tt0106582", 1353 | "tt0106977", 1354 | "tt0107290", 1355 | "tt0107507", 1356 | "tt0107614", 1357 | "tt0107617", 1358 | "tt0107653", 1359 | "tt0107736", 1360 | "tt0107808", 1361 | "tt0107822", 1362 | "tt0108122", 1363 | "tt0108160", 1364 | "tt0108289", 1365 | "tt0108330", 1366 | "tt0108399", 1367 | "tt0108656", 1368 | "tt0109020", 1369 | "tt0109686", 1370 | "tt0109830", 1371 | "tt0109831", 1372 | "tt0110074", 1373 | "tt0110116", 1374 | "tt0110148", 1375 | "tt0110167", 1376 | "tt0110201", 1377 | "tt0110322", 1378 | "tt0110475", 1379 | "tt0110604", 1380 | "tt0110632", 1381 | "tt0110912", 1382 | "tt0110932", 1383 | "tt0111003", 1384 | "tt0111280", 1385 | "tt0111797", 1386 | "tt0112384", 1387 | "tt0112462", 1388 | "tt0112573", 1389 | "tt0112641", 1390 | "tt0112740", 1391 | "tt0112769", 1392 | "tt0112818", 1393 | "tt0112883", 1394 | "tt0113101", 1395 | "tt0113243", 
1396 | "tt0113253", 1397 | "tt0113277", 1398 | "tt0113497", 1399 | "tt0113870", 1400 | "tt0114367", 1401 | "tt0114369", 1402 | "tt0114388", 1403 | "tt0114558", 1404 | "tt0114746", 1405 | "tt0114814", 1406 | "tt0115734", 1407 | "tt0115759", 1408 | "tt0115798", 1409 | "tt0115956", 1410 | "tt0115964", 1411 | "tt0116209", 1412 | "tt0116213", 1413 | "tt0116282", 1414 | "tt0116367", 1415 | "tt0116477", 1416 | "tt0116629", 1417 | "tt0116695", 1418 | "tt0116767", 1419 | "tt0116922", 1420 | "tt0116996", 1421 | "tt0117060", 1422 | "tt0117381", 1423 | "tt0117500", 1424 | "tt0117509", 1425 | "tt0117571", 1426 | "tt0117666", 1427 | "tt0117731", 1428 | "tt0117883", 1429 | "tt0117951", 1430 | "tt0118548", 1431 | "tt0118571", 1432 | "tt0118583", 1433 | "tt0118636", 1434 | "tt0118655", 1435 | "tt0118688", 1436 | "tt0118689", 1437 | "tt0118715", 1438 | "tt0118749", 1439 | "tt0118799", 1440 | "tt0118842", 1441 | "tt0118845", 1442 | "tt0118883", 1443 | "tt0118929", 1444 | "tt0118971", 1445 | "tt0119008", 1446 | "tt0119081", 1447 | "tt0119094", 1448 | "tt0119116", 1449 | "tt0119174", 1450 | "tt0119177", 1451 | "tt0119250", 1452 | "tt0119303", 1453 | "tt0119314", 1454 | "tt0119349", 1455 | "tt0119375", 1456 | "tt0119396", 1457 | "tt0119488", 1458 | "tt0119528", 1459 | "tt0119567", 1460 | "tt0119643", 1461 | "tt0119654", 1462 | "tt0119670", 1463 | "tt0119738", 1464 | "tt0119822", 1465 | "tt0120255", 1466 | "tt0120263", 1467 | "tt0120338", 1468 | "tt0120382", 1469 | "tt0120483", 1470 | "tt0120586", 1471 | "tt0120591", 1472 | "tt0120601", 1473 | "tt0120616", 1474 | "tt0120655", 1475 | "tt0120660", 1476 | "tt0120667", 1477 | "tt0120669", 1478 | "tt0120689", 1479 | "tt0120696", 1480 | "tt0120731", 1481 | "tt0120735", 1482 | "tt0120737", 1483 | "tt0120738", 1484 | "tt0120744", 1485 | "tt0120755", 1486 | "tt0120780", 1487 | "tt0120787", 1488 | "tt0120804", 1489 | "tt0120812", 1490 | "tt0120815", 1491 | "tt0120863", 1492 | "tt0120885", 1493 | "tt0120890", 1494 | "tt0120902", 1495 | "tt0120903", 1496 | "tt0120912", 1497 | "tt0120915", 1498 | "tt0121765", 1499 | "tt0121766", 1500 | "tt0122690", 1501 | "tt0122933", 1502 | "tt0123755", 1503 | "tt0124315", 1504 | "tt0125439", 1505 | "tt0125664", 1506 | "tt0126886", 1507 | "tt0127536", 1508 | "tt0128445", 1509 | "tt0129387", 1510 | "tt0133093", 1511 | "tt0133152", 1512 | "tt0134119", 1513 | "tt0134273", 1514 | "tt0134847", 1515 | "tt0137439", 1516 | "tt0137494", 1517 | "tt0137523", 1518 | "tt0138097", 1519 | "tt0139134", 1520 | "tt0139654", 1521 | "tt0140352", 1522 | "tt0142342", 1523 | "tt0142688", 1524 | "tt0143145", 1525 | "tt0144084", 1526 | "tt0144117", 1527 | "tt0145487", 1528 | "tt0159365", 1529 | "tt0159784", 1530 | "tt0160127", 1531 | "tt0162346", 1532 | "tt0162661", 1533 | "tt0163025", 1534 | "tt0164052", 1535 | "tt0166896", 1536 | "tt0166924", 1537 | "tt0167190", 1538 | "tt0167260", 1539 | "tt0167261", 1540 | "tt0167331", 1541 | "tt0167404", 1542 | "tt0169547", 1543 | "tt0171363", 1544 | "tt0172495", 1545 | "tt0175880", 1546 | "tt0178868", 1547 | "tt0180073", 1548 | "tt0180093", 1549 | "tt0181689", 1550 | "tt0181875", 1551 | "tt0182789", 1552 | "tt0183523", 1553 | "tt0183649", 1554 | "tt0186151", 1555 | "tt0187078", 1556 | "tt0187393", 1557 | "tt0190332", 1558 | "tt0190590", 1559 | "tt0195685", 1560 | "tt0195714", 1561 | "tt0199354", 1562 | "tt0199753", 1563 | "tt0203009", 1564 | "tt0206634", 1565 | "tt0207201", 1566 | "tt0208092", 1567 | "tt0209144", 1568 | "tt0209463", 1569 | "tt0209958", 1570 | "tt0210727", 1571 | "tt0212338", 1572 | "tt0212985", 1573 | 
"tt0213149", 1574 | "tt0217505", 1575 | "tt0217869", 1576 | "tt0219822", 1577 | "tt0227445", 1578 | "tt0232500", 1579 | "tt0234215", 1580 | "tt0240772", 1581 | "tt0240890", 1582 | "tt0242653", 1583 | "tt0243876", 1584 | "tt0244244", 1585 | "tt0244353", 1586 | "tt0245844", 1587 | "tt0246578", 1588 | "tt0248667", 1589 | "tt0250494", 1590 | "tt0250797", 1591 | "tt0251160", 1592 | "tt0253474", 1593 | "tt0253754", 1594 | "tt0257360", 1595 | "tt0258000", 1596 | "tt0258463", 1597 | "tt0261392", 1598 | "tt0264395", 1599 | "tt0264616", 1600 | "tt0265666", 1601 | "tt0266697", 1602 | "tt0266915", 1603 | "tt0268126", 1604 | "tt0268695", 1605 | "tt0268978", 1606 | "tt0272152", 1607 | "tt0275719", 1608 | "tt0277027", 1609 | "tt0278504", 1610 | "tt0280609", 1611 | "tt0281358", 1612 | "tt0281686", 1613 | "tt0283509", 1614 | "tt0285742", 1615 | "tt0286106", 1616 | "tt0288477", 1617 | "tt0289879", 1618 | "tt0290002", 1619 | "tt0290334", 1620 | "tt0294870", 1621 | "tt0298228", 1622 | "tt0299658", 1623 | "tt0308476", 1624 | "tt0309698", 1625 | "tt0311113", 1626 | "tt0313542", 1627 | "tt0315327", 1628 | "tt0316654", 1629 | "tt0317198", 1630 | "tt0317740", 1631 | "tt0317919", 1632 | "tt0318627", 1633 | "tt0318974", 1634 | "tt0319061", 1635 | "tt0319262", 1636 | "tt0322259", 1637 | "tt0324197", 1638 | "tt0325710", 1639 | "tt0325980", 1640 | "tt0328107", 1641 | "tt0329101", 1642 | "tt0330373", 1643 | "tt0331811", 1644 | "tt0332452", 1645 | "tt0335119", 1646 | "tt0335266", 1647 | "tt0337921", 1648 | "tt0337978", 1649 | "tt0338013", 1650 | "tt0338751", 1651 | "tt0341495", 1652 | "tt0343660", 1653 | "tt0343818", 1654 | "tt0346094", 1655 | "tt0349903", 1656 | "tt0350258", 1657 | "tt0351977", 1658 | "tt0357413", 1659 | "tt0359950", 1660 | "tt0360486", 1661 | "tt0361748", 1662 | "tt0362227", 1663 | "tt0363589", 1664 | "tt0363771", 1665 | "tt0365907", 1666 | "tt0368891", 1667 | "tt0368933", 1668 | "tt0369339", 1669 | "tt0369441", 1670 | "tt0369702", 1671 | "tt0370032", 1672 | "tt0370263", 1673 | "tt0371257", 1674 | "tt0371724", 1675 | "tt0372183", 1676 | "tt0372784", 1677 | "tt0372824", 1678 | "tt0373051", 1679 | "tt0373074", 1680 | "tt0373469", 1681 | "tt0374546", 1682 | "tt0375063", 1683 | "tt0375679", 1684 | "tt0376994", 1685 | "tt0377713", 1686 | "tt0378194", 1687 | "tt0379306", 1688 | "tt0379786", 1689 | "tt0381061", 1690 | "tt0382625", 1691 | "tt0383028", 1692 | "tt0383216", 1693 | "tt0383574", 1694 | "tt0385004", 1695 | "tt0386588", 1696 | "tt0387564", 1697 | "tt0387877", 1698 | "tt0387898", 1699 | "tt0388795", 1700 | "tt0390022", 1701 | "tt0393109", 1702 | "tt0395169", 1703 | "tt0395699", 1704 | "tt0397078", 1705 | "tt0398027", 1706 | "tt0399201", 1707 | "tt0399295", 1708 | "tt0401383", 1709 | "tt0404978", 1710 | "tt0405159", 1711 | "tt0407887", 1712 | "tt0408236", 1713 | "tt0408306", 1714 | "tt0408790", 1715 | "tt0409459", 1716 | "tt0413893", 1717 | "tt0414055", 1718 | "tt0414387", 1719 | "tt0414982", 1720 | "tt0415380", 1721 | "tt0416320", 1722 | "tt0416508", 1723 | "tt0417741", 1724 | "tt0418279", 1725 | "tt0418819", 1726 | "tt0419887", 1727 | "tt0420223", 1728 | "tt0421715", 1729 | "tt0424345", 1730 | "tt0425061", 1731 | "tt0425210", 1732 | "tt0427309", 1733 | "tt0427954", 1734 | "tt0430357", 1735 | "tt0432021", 1736 | "tt0433035", 1737 | "tt0434409", 1738 | "tt0435705", 1739 | "tt0439815", 1740 | "tt0440963", 1741 | "tt0443272", 1742 | "tt0443453", 1743 | "tt0443680", 1744 | "tt0443706", 1745 | "tt0448157", 1746 | "tt0449088", 1747 | "tt0450259", 1748 | "tt0450385", 1749 | "tt0452625", 1750 | "tt0454841", 
1751 | "tt0454848", 1752 | "tt0454876", 1753 | "tt0454921", 1754 | "tt0455760", 1755 | "tt0455824", 1756 | "tt0457297", 1757 | "tt0457430", 1758 | "tt0457513", 1759 | "tt0457939", 1760 | "tt0458352", 1761 | "tt0458413", 1762 | "tt0458525", 1763 | "tt0460791", 1764 | "tt0460989", 1765 | "tt0462200", 1766 | "tt0462499", 1767 | "tt0467200", 1768 | "tt0467406", 1769 | "tt0468565", 1770 | "tt0468569", 1771 | "tt0469494", 1772 | "tt0470752", 1773 | "tt0473705", 1774 | "tt0475293", 1775 | "tt0477347", 1776 | "tt0477348", 1777 | "tt0479884", 1778 | "tt0479997", 1779 | "tt0480025", 1780 | "tt0481369", 1781 | "tt0489018", 1782 | "tt0493464", 1783 | "tt0499448", 1784 | "tt0499549", 1785 | "tt0758730", 1786 | "tt0758758", 1787 | "tt0758774", 1788 | "tt0765128", 1789 | "tt0765429", 1790 | "tt0765447", 1791 | "tt0780504", 1792 | "tt0780571", 1793 | "tt0780653", 1794 | "tt0783233", 1795 | "tt0790628", 1796 | "tt0790636", 1797 | "tt0790686", 1798 | "tt0796366", 1799 | "tt0800080", 1800 | "tt0800320", 1801 | "tt0800369", 1802 | "tt0808151", 1803 | "tt0810819", 1804 | "tt0815236", 1805 | "tt0815245", 1806 | "tt0816692", 1807 | "tt0822832", 1808 | "tt0824747", 1809 | "tt0824758", 1810 | "tt0826711", 1811 | "tt0829482", 1812 | "tt0838232", 1813 | "tt0844286", 1814 | "tt0844347", 1815 | "tt0846308", 1816 | "tt0848228", 1817 | "tt0862846", 1818 | "tt0878804", 1819 | "tt0887883", 1820 | "tt0898367", 1821 | "tt0903624", 1822 | "tt0905372", 1823 | "tt0913425", 1824 | "tt0914798", 1825 | "tt0940709", 1826 | "tt0942385", 1827 | "tt0944835", 1828 | "tt0945513", 1829 | "tt0947798", 1830 | "tt0958860", 1831 | "tt0959337", 1832 | "tt0963794", 1833 | "tt0963966", 1834 | "tt0964517", 1835 | "tt0970179", 1836 | "tt0970416", 1837 | "tt0974661", 1838 | "tt0975645", 1839 | "tt0976051", 1840 | "tt0977855", 1841 | "tt0985694", 1842 | "tt0985699", 1843 | "tt0986233", 1844 | "tt0986263", 1845 | "tt0988045", 1846 | "tt0993842", 1847 | "tt0993846", 1848 | "tt1000774", 1849 | "tt1001508", 1850 | "tt1007029", 1851 | "tt1010048", 1852 | "tt1013753", 1853 | "tt1016268", 1854 | "tt1017460", 1855 | "tt1019452", 1856 | "tt1022603", 1857 | "tt1024648", 1858 | "tt1027718", 1859 | "tt1029234", 1860 | "tt1029360", 1861 | "tt1032755", 1862 | "tt1033575", 1863 | "tt1034314", 1864 | "tt1037705", 1865 | "tt1038919", 1866 | "tt1041829", 1867 | "tt1045658", 1868 | "tt1045772", 1869 | "tt1046173", 1870 | "tt1054606", 1871 | "tt1055292", 1872 | "tt1055369", 1873 | "tt1057500", 1874 | "tt1059786", 1875 | "tt1063669", 1876 | "tt1065073", 1877 | "tt1068649", 1878 | "tt1068680", 1879 | "tt1071875", 1880 | "tt1072748", 1881 | "tt1073498", 1882 | "tt1074638", 1883 | "tt1084950", 1884 | "tt1086772", 1885 | "tt1092026", 1886 | "tt1093906", 1887 | "tt1099212", 1888 | "tt1100089", 1889 | "tt1104001", 1890 | "tt1119646", 1891 | "tt1120985", 1892 | "tt1124035", 1893 | "tt1124037", 1894 | "tt1125849", 1895 | "tt1131729", 1896 | "tt1133985", 1897 | "tt1135952", 1898 | "tt1139797", 1899 | "tt1144884", 1900 | "tt1148204", 1901 | "tt1156466", 1902 | "tt1158278", 1903 | "tt1170358", 1904 | "tt1172049", 1905 | "tt1174732", 1906 | "tt1178663", 1907 | "tt1179031", 1908 | "tt1179904", 1909 | "tt1181614", 1910 | "tt1182345", 1911 | "tt1186367", 1912 | "tt1188729", 1913 | "tt1189340", 1914 | "tt1190080", 1915 | "tt1193138", 1916 | "tt1194173", 1917 | "tt1201607", 1918 | "tt1205489", 1919 | "tt1210166", 1920 | "tt1211837", 1921 | "tt1216496", 1922 | "tt1217613", 1923 | "tt1219289", 1924 | "tt1220634", 1925 | "tt1220719", 1926 | "tt1228705", 1927 | "tt1229238", 1928 | 
"tt1229340", 1929 | "tt1229822", 1930 | "tt1232829", 1931 | "tt1233381", 1932 | "tt1244754", 1933 | "tt1253863", 1934 | "tt1255953", 1935 | "tt1274586", 1936 | "tt1276104", 1937 | "tt1282140", 1938 | "tt1284575", 1939 | "tt1285016", 1940 | "tt1287878", 1941 | "tt1291150", 1942 | "tt1291584", 1943 | "tt1292566", 1944 | "tt1298650", 1945 | "tt1300851", 1946 | "tt1305806", 1947 | "tt1306980", 1948 | "tt1318514", 1949 | "tt1322269", 1950 | "tt1324999", 1951 | "tt1340800", 1952 | "tt1341167", 1953 | "tt1343092", 1954 | "tt1355683", 1955 | "tt1360860", 1956 | "tt1371111", 1957 | "tt1375666", 1958 | "tt1375670", 1959 | "tt1385826", 1960 | "tt1386932", 1961 | "tt1392190", 1962 | "tt1396218", 1963 | "tt1397514", 1964 | "tt1399103", 1965 | "tt1401152", 1966 | "tt1403865", 1967 | "tt1409024", 1968 | "tt1411238", 1969 | "tt1412386", 1970 | "tt1413492", 1971 | "tt1424381", 1972 | "tt1431045", 1973 | "tt1438176", 1974 | "tt1439572", 1975 | "tt1440728", 1976 | "tt1441953", 1977 | "tt1446147", 1978 | "tt1446714", 1979 | "tt1454029", 1980 | "tt1454468", 1981 | "tt1458175", 1982 | "tt1462758", 1983 | "tt1462900", 1984 | "tt1468846", 1985 | "tt1478338", 1986 | "tt1483013", 1987 | "tt1486190", 1988 | "tt1502712", 1989 | "tt1504320", 1990 | "tt1510906", 1991 | "tt1524137", 1992 | "tt1533117", 1993 | "tt1535970", 1994 | "tt1540133", 1995 | "tt1560747", 1996 | "tt1563738", 1997 | "tt1564367", 1998 | "tt1568346", 1999 | "tt1570728", 2000 | "tt1582248", 2001 | "tt1586752", 2002 | "tt1591095", 2003 | "tt1596363", 2004 | "tt1602613", 2005 | "tt1602620", 2006 | "tt1606378", 2007 | "tt1611840", 2008 | "tt1615147", 2009 | "tt1616195", 2010 | "tt1619029", 2011 | "tt1623205", 2012 | "tt1628841", 2013 | "tt1637725", 2014 | "tt1645170", 2015 | "tt1646987", 2016 | "tt1649443", 2017 | "tt1655420", 2018 | "tt1659337", 2019 | "tt1663662", 2020 | "tt1670345", 2021 | "tt1675434", 2022 | "tt1692486", 2023 | "tt1703957", 2024 | "tt1706593", 2025 | "tt1707386", 2026 | "tt1722484", 2027 | "tt1723811", 2028 | "tt1725986", 2029 | "tt1731141", 2030 | "tt1742683", 2031 | "tt1747958", 2032 | "tt1748122", 2033 | "tt1757746", 2034 | "tt1781769", 2035 | "tt1800241", 2036 | "tt1800246", 2037 | "tt1809398", 2038 | "tt1832382", 2039 | "tt1840309", 2040 | "tt1843287", 2041 | "tt1853728", 2042 | "tt1855325", 2043 | "tt1872181", 2044 | "tt1877832", 2045 | "tt1895587", 2046 | "tt1907668", 2047 | "tt1951266", 2048 | "tt1971325", 2049 | "tt1974419", 2050 | "tt1979320", 2051 | "tt1981115", 2052 | "tt2011351", 2053 | "tt2017561", 2054 | "tt2024544", 2055 | "tt2032557", 2056 | "tt2053463", 2057 | "tt2056771", 2058 | "tt2058107", 2059 | "tt2058673", 2060 | "tt2059255", 2061 | "tt2070649", 2062 | "tt2076220", 2063 | "tt2078768", 2064 | "tt2084970", 2065 | "tt2099556", 2066 | "tt2103281", 2067 | "tt2109184", 2068 | "tt2109248", 2069 | "tt2115388", 2070 | "tt2118775", 2071 | "tt2132285", 2072 | "tt2140373", 2073 | "tt2167266", 2074 | "tt2194499", 2075 | "tt2238032", 2076 | "tt2258281", 2077 | "tt2267998", 2078 | "tt2294449", 2079 | "tt2310332", 2080 | "tt2334873", 2081 | "tt2345567", 2082 | "tt2366450", 2083 | "tt2381249", 2084 | "tt2381991", 2085 | "tt2382298", 2086 | "tt2402927", 2087 | "tt2404435", 2088 | "tt2409818", 2089 | "tt2446980", 2090 | "tt2463288", 2091 | "tt2473794", 2092 | "tt2488496", 2093 | "tt2567026", 2094 | "tt2567712", 2095 | "tt2582802", 2096 | "tt2582846", 2097 | "tt2614684", 2098 | "tt2639344", 2099 | "tt2645044", 2100 | "tt2675914", 2101 | "tt2713180", 2102 | "tt2717822", 2103 | "tt2788732", 2104 | "tt2800240", 2105 | "tt2802144", 
2106 | "tt2823054", 2107 | "tt2832470", 2108 | "tt2872732", 2109 | "tt2884018", 2110 | "tt2908856", 2111 | "tt2911666", 2112 | "tt2923316", 2113 | "tt2978462", 2114 | "tt2980516", 2115 | "tt3062096", 2116 | "tt3064298", 2117 | "tt3077214", 2118 | "tt3110958", 2119 | "tt3289956", 2120 | "tt3296658", 2121 | "tt3312830", 2122 | "tt3316960", 2123 | "tt3319920", 2124 | "tt3385516", 2125 | "tt3395184", 2126 | "tt3410834", 2127 | "tt3416744", 2128 | "tt3421514", 2129 | "tt3439114", 2130 | "tt3464902", 2131 | "tt3465916", 2132 | "tt3474602", 2133 | "tt3478232", 2134 | "tt3480796", 2135 | "tt3488710", 2136 | "tt3495026", 2137 | "tt3498820", 2138 | "tt3501416", 2139 | "tt3508840", 2140 | "tt3531578", 2141 | "tt3553442", 2142 | "tt3630276", 2143 | "tt3659786", 2144 | "tt3671542", 2145 | "tt3672840", 2146 | "tt3700392", 2147 | "tt3700804", 2148 | "tt3707106", 2149 | "tt3714720", 2150 | "tt3726704", 2151 | "tt3766394", 2152 | "tt3808342", 2153 | "tt3824458", 2154 | "tt3860916", 2155 | "tt3882082", 2156 | "tt3922798", 2157 | "tt3960412", 2158 | "tt4008652", 2159 | "tt4034354", 2160 | "tt4046784", 2161 | "tt4052882", 2162 | "tt4136084", 2163 | "tt4151192", 2164 | "tt4160708", 2165 | "tt4176826", 2166 | "tt4242158", 2167 | "tt4273292", 2168 | "tt4501454", 2169 | "tt4520364", 2170 | "tt4647900", 2171 | "tt4651520", 2172 | "tt4698684", 2173 | "tt4721400", 2174 | "tt4781612", 2175 | "tt4786282", 2176 | "tt4824302", 2177 | "tt4915672", 2178 | "tt4939066", 2179 | "tt4967094", 2180 | "tt4972062", 2181 | "tt5052448", 2182 | "tt5065810", 2183 | "tt5140878", 2184 | "tt5294550", 2185 | "tt5564148", 2186 | "tt5576318", 2187 | "tt5580036", 2188 | "tt5593416", 2189 | "tt5649144", 2190 | "tt5688868", 2191 | "tt5726086", 2192 | "tt5827496", 2193 | "tt5866930", 2194 | "tt6121428", 2195 | "tt6133130", 2196 | "tt6157626", 2197 | "tt6190198", 2198 | "tt6298600", 2199 | "tt6466464", 2200 | "tt6513406", 2201 | "tt6518634", 2202 | "tt6644200", 2203 | "tt6788942", 2204 | "tt7055592", 2205 | "tt7131870", 2206 | "tt7160070", 2207 | "tt7180392", 2208 | "tt7672188" 2209 | ] 2210 | } -------------------------------------------------------------------------------- /data/movienet_data.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageFilter 2 | import random 3 | import torch 4 | import torchvision.transforms as transforms 5 | import json 6 | import cv2 7 | import numpy as np 8 | from torchvision import utils as vutils 9 | 10 | class TwoWayTransform: 11 | def __init__(self, base_transform_a, 12 | base_transform_b, fixed_aug_shot=True): 13 | self.base_transform_a = base_transform_a 14 | self.base_transform_b = base_transform_b 15 | self.fixed = fixed_aug_shot 16 | 17 | def __call__(self, x): 18 | frame_num = len(x) 19 | if self.fixed: 20 | seed = np.random.randint(2147483647) 21 | q, k = [], [] 22 | for i in range(frame_num): 23 | random.seed(seed) 24 | q.append(self.base_transform_a(x[i])) 25 | seed = np.random.randint(2147483647) 26 | for i in range(frame_num): 27 | random.seed(seed) 28 | k.append(self.base_transform_b(x[i])) 29 | else: 30 | q = [self.base_transform_a(x[i]) for i in range(frame_num)] 31 | k = [self.base_transform_b(x[i]) for i in range(frame_num)] 32 | q = torch.cat(q, axis = 0) 33 | k = torch.cat(k, axis = 0) 34 | return [q, k] 35 | 36 | 37 | class MovieNet_Shot_Dataset(torch.utils.data.Dataset): 38 | def __init__(self, img_path, shot_info_path, transform, 39 | shot_len = 16, frame_per_shot = 3, _Type='train'): 40 | self.img_path = img_path 41 | with 
open(shot_info_path, 'rb') as f: 42 | self.shot_info = json.load(f) 43 | self.img_path = img_path 44 | self.shot_len = shot_len 45 | self.frame_per_shot = frame_per_shot 46 | self.transform = transform 47 | self._Type = _Type.lower() 48 | assert self._Type in ['train','val','test'] 49 | self.idx_imdb_map = {} 50 | data_length = 0 51 | for imdb, shot_num in self.shot_info[_Type].items(): 52 | for i in range(shot_num // shot_len): 53 | self.idx_imdb_map[data_length] = (imdb, i) 54 | data_length += 1 55 | 56 | 57 | def __len__(self): 58 | return len(self.idx_imdb_map.keys()) 59 | 60 | 61 | def _transform(self, img_list): 62 | q, k = [], [] 63 | for item in img_list: 64 | out = self.transform(item) 65 | q.append(out[0]) 66 | k.append(out[1]) 67 | out_q = torch.stack(q, axis=0) 68 | out_k = torch.stack(k, axis=0) 69 | return [out_q, out_k] 70 | 71 | 72 | def _process_puzzle(self, idx): 73 | imdb, puzzle_id = self.idx_imdb_map[idx] 74 | img_path = f'{self.img_path}/{imdb}/{str(puzzle_id).zfill(4)}.jpg' 75 | img = cv2.imread(img_path) 76 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 77 | img = np.vsplit(img, self.shot_len) 78 | img = [np.hsplit(i, self.frame_per_shot) for i in img] 79 | data = self._transform(img) 80 | return data 81 | 82 | 83 | def __getitem__(self, idx): 84 | return self._process_puzzle(idx) 85 | 86 | 87 | 88 | class GaussianBlur: 89 | def __init__(self, sigma=[.1, 2.]): 90 | self.sigma = sigma 91 | 92 | def __call__(self, x): 93 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 94 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 95 | return x 96 | 97 | def get_train_loader(cfg): 98 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 99 | std=[0.229, 0.224, 0.225]) 100 | augmentation_base = [ 101 | transforms.ToPILImage(), 102 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 103 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 104 | transforms.RandomHorizontalFlip(), 105 | transforms.ToTensor(), 106 | normalize 107 | ] 108 | augmentation_color = [ 109 | transforms.ToPILImage(), 110 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 111 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.5), 112 | transforms.RandomGrayscale(p=0.2), 113 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 114 | transforms.RandomHorizontalFlip(), 115 | transforms.ToTensor(), 116 | normalize 117 | ] 118 | augmentation_q = augmentation_color if cfg['data']['color_aug_for_q'] else augmentation_base 119 | augmentation_k = augmentation_color if cfg['data']['color_aug_for_k'] else augmentation_base 120 | 121 | train_transform = TwoWayTransform( 122 | transforms.Compose(augmentation_q), 123 | transforms.Compose(augmentation_k), 124 | fixed_aug_shot=cfg['data']['fixed_aug_shot']) 125 | 126 | img_path = cfg['data']['data_path'] 127 | shot_info_path = cfg['data']['shot_info'] 128 | train_dataset = MovieNet_Shot_Dataset(img_path, shot_info_path, train_transform) 129 | train_sampler = None 130 | if cfg['DDP']['multiprocessing_distributed']: 131 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True) 132 | train_loader = torch.utils.data.DataLoader(train_dataset, 133 | batch_size=cfg['optim']['bs'], num_workers=cfg['data']['workers'], 134 | sampler=train_sampler, shuffle=(train_sampler is None), pin_memory=True, drop_last=True) 135 | return train_loader, train_sampler 136 | 137 | if __name__ == '__main__': 138 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 139 | std=[0.229, 0.224, 0.225]) 
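# Quick sanity check of the puzzle layout assumed by _process_puzzle() above:
# each puzzle JPEG is vsplit into shot_len=16 shot rows, each row is hsplit
# into frame_per_shot=3 frames, and TwoWayTransform concatenates the 3 RGB
# frames channel-wise, so the view printed below should come out as a
# (16, 9, 224, 224) tensor. (normalize is commented out in this smoke test,
# unlike in get_train_loader; the two pipelines are otherwise identical.)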
140 | augmentation_base = [ 141 | transforms.ToPILImage(), 142 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 143 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 144 | transforms.RandomHorizontalFlip(), 145 | transforms.ToTensor(), 146 | # normalize 147 | ] 148 | augmentation_color = [ 149 | transforms.ToPILImage(), 150 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 151 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.5), 152 | transforms.RandomGrayscale(p=0.2), 153 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 154 | transforms.RandomHorizontalFlip(), 155 | transforms.ToTensor(), 156 | # normalize 157 | ] 158 | train_transform = TwoWayTransform( 159 | transforms.Compose(augmentation_base), 160 | transforms.Compose(augmentation_color), 161 | fixed_aug_shot=False) 162 | img_path = './compressed_shot_images' 163 | shot_info_path = './MovieNet_shot_num.json' 164 | train_dataset = MovieNet_Shot_Dataset(img_path, shot_info_path, train_transform) 165 | print(f'len: {len(train_dataset)}') 166 | i = train_dataset[0] 167 | print(i[0].size()) 168 | 169 | 170 | 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /extract_embeddings.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import torch 4 | import argparse 5 | import time 6 | from models.backbones.visual.resnet import encoder_resnet50 7 | import json 8 | import cv2 9 | from torchvision import transforms 10 | from torch.utils.data import DataLoader 11 | 12 | class MovieNet_SingleShot_Dataset(torch.utils.data.Dataset): 13 | def __init__(self, img_path, shot_info_path, transform, 14 | frame_per_shot = 3, _Type='train'): 15 | self.img_path = img_path 16 | with open(shot_info_path, 'rb') as f: 17 | self.shot_info = json.load(f) 18 | self.img_path = img_path 19 | self.frame_per_shot = frame_per_shot 20 | self.transform = transform 21 | self._Type = _Type.lower() 22 | assert self._Type in ['train','val','test'] 23 | self.idx_imdb_map = {} 24 | data_length = 0 25 | for info in self.shot_info[_Type]: 26 | imdb = info['name'] 27 | for shot in info['label']: 28 | self.idx_imdb_map[data_length] = (imdb, shot[0], shot[1]) 29 | data_length += 1 30 | 31 | def __len__(self): 32 | return len(self.idx_imdb_map.keys()) 33 | 34 | def _process(self, idx): 35 | imdb, _id, label = self.idx_imdb_map[idx] 36 | img_path_0 = f'{self.img_path}/{imdb}/shot_{_id}_img_0.jpg' 37 | img_path_1 = f'{self.img_path}/{imdb}/shot_{_id}_img_1.jpg' 38 | img_path_2 = f'{self.img_path}/{imdb}/shot_{_id}_img_2.jpg' 39 | img_0 = cv2.cvtColor(cv2.imread(img_path_0), cv2.COLOR_BGR2RGB) 40 | img_1 = cv2.cvtColor(cv2.imread(img_path_1), cv2.COLOR_BGR2RGB) 41 | img_2 = cv2.cvtColor(cv2.imread(img_path_2), cv2.COLOR_BGR2RGB) 42 | data_0 = self.transform(img_0) 43 | data_1 = self.transform(img_1) 44 | data_2 = self.transform(img_2) 45 | data = torch.cat([data_0, data_1, data_2], axis=0) 46 | label = int(label) 47 | # According to LGSS[1] 48 | # [1] https://arxiv.org/abs/2004.02678 49 | if label == -1: 50 | label = 1 51 | return data, label, (imdb, _id) 52 | 53 | 54 | def __getitem__(self, idx): 55 | return self._process(idx) 56 | 57 | def get_loader(cfg, _Type='train'): 58 | normalize = transforms.Normalize( 59 | mean=[0.485, 0.456, 0.406], 60 | std=[0.229, 0.224, 0.225] 61 | ) 62 | 63 | _transform = transforms.Compose([ 64 | transforms.ToPILImage(), 65 | transforms.Resize(224), 66 | transforms.CenterCrop(224), 
67 | transforms.ToTensor(), 68 | normalize, 69 | ]) 70 | dataset = MovieNet_SingleShot_Dataset( 71 | img_path = cfg.shot_img_path, 72 | shot_info_path = cfg.shot_info_path, 73 | transform = _transform, 74 | frame_per_shot = cfg.frame_per_shot, 75 | _Type=_Type, 76 | ) 77 | loader = DataLoader( 78 | dataset, batch_size=cfg.bs, drop_last=False, 79 | shuffle=False, num_workers=cfg.worker_num, pin_memory=True 80 | ) 81 | return loader 82 | 83 | def get_encoder(model_name='resnet50', weight_path='', input_channel=9): 84 | encoder = None 85 | model_name = model_name.lower() 86 | if model_name == 'resnet50': 87 | encoder = encoder_resnet50(weight_path='',input_channel=input_channel) 88 | model_weight = torch.load(weight_path,map_location=torch.device('cpu'))['state_dict'] 89 | pretrained_dict = {} 90 | for k, v in model_weight.items(): 91 | # moco loading 92 | if k.startswith('module.encoder_k'): 93 | continue 94 | if k == 'module.queue' or k == 'module.queue_ptr': 95 | continue 96 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 97 | k = k[17:] 98 | 99 | pretrained_dict[k] = v 100 | encoder.load_state_dict(pretrained_dict, strict = False) 101 | print(f'loaded from {weight_path}') 102 | return encoder 103 | 104 | 105 | @torch.no_grad() 106 | def get_save_embeddings(model, loader, shot_num, filename, log_interval=100): 107 | # dict 108 | # key: index, value: [(embeddings, label), ...] 109 | embeddings = {} 110 | model.eval() 111 | 112 | print(f'total length of dataset: {len(loader.dataset)}') 113 | print(f'total length of loader: {len(loader)}') 114 | 115 | for batch_idx, (data, target, index) in enumerate(loader): 116 | if batch_idx % log_interval == 0: 117 | print(f'processed: {batch_idx}') 118 | 119 | data = data.cuda(non_blocking=True) # ([bs, shot_num, 9, 224, 224]) 120 | data = data.view(-1, 9, 224, 224) 121 | 122 | target = target.view(-1).cuda() 123 | output = model(data, False) # ([bs * shot_num, 2048]) 124 | for i, key in enumerate(index[0]): 125 | if key not in embeddings: 126 | embeddings[key] = [] 127 | t_emb = output[i*shot_num:(i+1)*shot_num].cpu().numpy() 128 | t_label = target[i].cpu().numpy() 129 | embeddings[key].append((t_emb.copy() ,t_label.copy())) 130 | pickle.dump(embeddings, open(filename, 'wb')) 131 | 132 | 133 | def extract_features(cfg): 134 | time_str = time.strftime("%Y-%m-%d_%H_%M_%S", time.localtime()) 135 | save_dir = os.path.join(cfg.save_dir, time_str) 136 | if not os.path.exists(save_dir): 137 | os.makedirs(save_dir) 138 | cfg.log_file = save_dir + '/extraction.log' 139 | encoder = get_encoder( 140 | model_name=cfg.model_name, 141 | weight_path=cfg.model_path, 142 | input_channel=cfg.frame_per_shot * 3 143 | ).cuda() 144 | dataType = [cfg.Type] 145 | if dataType[0] == 'all': 146 | dataType = ['train','test','val'] 147 | for _T in dataType: 148 | to_log(cfg, f'processing: {_T} \n') 149 | loader = get_loader(cfg, _Type = _T) 150 | filename = os.path.join(save_dir, _T+'.pkl') 151 | get_save_embeddings(encoder, 152 | loader, 153 | cfg.shot_num, 154 | filename, 155 | log_interval=100 156 | ) 157 | to_log(cfg, f'{_T} embeddings are saved in {filename}!\n') 158 | 159 | 160 | def to_log(cfg, content, echo=True): 161 | with open(cfg.log_file, 'a') as f: 162 | f.writelines(content+'\n') 163 | if echo: print(content) 164 | 165 | 166 | def get_config(): 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument('model_path', type=str) 169 | parser.add_argument('--shot_info_path', type=str, 170 | 
default='./data/movie1K.scene_seg_318_name_index_shotnum_label.v1.json') 171 | parser.add_argument('--shot_img_path', type=str, default='./MovieNet_unzip/240P/') 172 | parser.add_argument('--Type', type=str, default='train', choices=['train','test','val','all']) 173 | parser.add_argument('--model_name', type=str, default='resnet50') 174 | parser.add_argument('--frame_per_shot', type=int, default=3) 175 | parser.add_argument('--shot_num', type=int, default=1) 176 | parser.add_argument('--worker_num', type=int, default=16) 177 | parser.add_argument('--bs', type=int, default=64) 178 | parser.add_argument('--save_dir', type=str, default='./embeddings/') 179 | parser.add_argument('--gpu-id', type=str, default='0') 180 | cfg = parser.parse_args() 181 | 182 | # select GPU 183 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 184 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_id 185 | 186 | return cfg 187 | 188 | 189 | if __name__ == '__main__': 190 | cfg = get_config() 191 | extract_features(cfg) -------------------------------------------------------------------------------- /figures/puzzle_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/figures/puzzle_example.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/models/__init__.py -------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/models/backbones/__init__.py -------------------------------------------------------------------------------- /models/backbones/visual/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | from typing import Type, Any, Callable, Union, List, Optional 5 | from torch.hub import load_state_dict_from_url # needed by _resnet() when pretrained=True 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 9 | 'wide_resnet50_2', 'wide_resnet101_2'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 22 | } 23 | 24 | 25 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 26 | """3x3 convolution with padding""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3,
stride=stride, 28 | padding=dilation, groups=groups, bias=False, dilation=dilation) 29 | 30 | 31 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 32 | """1x1 convolution""" 33 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion: int = 1 38 | 39 | def __init__( 40 | self, 41 | inplanes: int, 42 | planes: int, 43 | stride: int = 1, 44 | downsample: Optional[nn.Module] = None, 45 | groups: int = 1, 46 | base_width: int = 64, 47 | dilation: int = 1, 48 | norm_layer: Optional[Callable[..., nn.Module]] = None 49 | ) -> None: 50 | super(BasicBlock, self).__init__() 51 | if norm_layer is None: 52 | norm_layer = nn.BatchNorm2d 53 | if groups != 1 or base_width != 64: 54 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 55 | if dilation > 1: 56 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 57 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 58 | self.conv1 = conv3x3(inplanes, planes, stride) 59 | self.bn1 = norm_layer(planes) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.conv2 = conv3x3(planes, planes) 62 | self.bn2 = norm_layer(planes) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x: Tensor) -> Tensor: 67 | identity = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | 76 | if self.downsample is not None: 77 | identity = self.downsample(x) 78 | 79 | out += identity 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class Bottleneck(nn.Module): 86 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 87 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 88 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 89 | # This variant is also known as ResNet V1.5 and improves accuracy according to 90 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
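# Concretely: in a stride-2 stage, the original V1 design would downsample in
# the first 1x1 conv (conv1), while this V1.5 variant downsamples in the 3x3
# conv (conv2); the NVIDIA page cited above reports this as slightly more
# accurate at a small cost in images/sec.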
91 | 92 | expansion: int = 4 93 | 94 | def __init__( 95 | self, 96 | inplanes: int, 97 | planes: int, 98 | stride: int = 1, 99 | downsample: Optional[nn.Module] = None, 100 | groups: int = 1, 101 | base_width: int = 64, 102 | dilation: int = 1, 103 | norm_layer: Optional[Callable[..., nn.Module]] = None 104 | ) -> None: 105 | super(Bottleneck, self).__init__() 106 | if norm_layer is None: 107 | norm_layer = nn.BatchNorm2d 108 | width = int(planes * (base_width / 64.)) * groups 109 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 110 | self.conv1 = conv1x1(inplanes, width) 111 | self.bn1 = norm_layer(width) 112 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 113 | self.bn2 = norm_layer(width) 114 | self.conv3 = conv1x1(width, planes * self.expansion) 115 | self.bn3 = norm_layer(planes * self.expansion) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.downsample = downsample 118 | self.stride = stride 119 | 120 | def forward(self, x: Tensor) -> Tensor: 121 | identity = x 122 | 123 | out = self.conv1(x) 124 | out = self.bn1(out) 125 | out = self.relu(out) 126 | 127 | out = self.conv2(out) 128 | out = self.bn2(out) 129 | out = self.relu(out) 130 | 131 | out = self.conv3(out) 132 | out = self.bn3(out) 133 | 134 | if self.downsample is not None: 135 | identity = self.downsample(x) 136 | 137 | out += identity 138 | out = self.relu(out) 139 | 140 | return out 141 | 142 | 143 | class ResNet(nn.Module): 144 | 145 | def __init__( 146 | self, 147 | block: Type[Union[BasicBlock, Bottleneck]], 148 | layers: List[int], 149 | input_channel:int = 3, 150 | num_classes: int = 1000, 151 | zero_init_residual: bool = True, 152 | groups: int = 1, 153 | width_per_group: int = 64, 154 | replace_stride_with_dilation: Optional[List[bool]] = None, 155 | norm_layer: Optional[Callable[..., nn.Module]] = None 156 | ) -> None: 157 | super(ResNet, self).__init__() 158 | if norm_layer is None: 159 | norm_layer = nn.BatchNorm2d 160 | self._norm_layer = norm_layer 161 | 162 | self.inplanes = 64 163 | self.dilation = 1 164 | if replace_stride_with_dilation is None: 165 | # each element in the tuple indicates if we should replace 166 | # the 2x2 stride with a dilated convolution instead 167 | replace_stride_with_dilation = [False, False, False] 168 | if len(replace_stride_with_dilation) != 3: 169 | raise ValueError("replace_stride_with_dilation should be None " 170 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 171 | self.groups = groups 172 | self.base_width = width_per_group 173 | self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=7, stride=2, padding=3, 174 | bias=False) 175 | self.bn1 = norm_layer(self.inplanes) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 178 | self.layer1 = self._make_layer(block, 64, layers[0]) 179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 180 | dilate=replace_stride_with_dilation[0]) 181 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 182 | dilate=replace_stride_with_dilation[1]) 183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 184 | dilate=replace_stride_with_dilation[2]) 185 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 186 | self.fc = nn.Linear(512 * block.expansion, num_classes) 187 | 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 
192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 204 | 205 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 206 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 207 | norm_layer = self._norm_layer 208 | downsample = None 209 | previous_dilation = self.dilation 210 | if dilate: 211 | self.dilation *= stride 212 | stride = 1 213 | if stride != 1 or self.inplanes != planes * block.expansion: 214 | downsample = nn.Sequential( 215 | conv1x1(self.inplanes, planes * block.expansion, stride), 216 | norm_layer(planes * block.expansion), 217 | ) 218 | 219 | layers = [] 220 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 221 | self.base_width, previous_dilation, norm_layer)) 222 | self.inplanes = planes * block.expansion 223 | for _ in range(1, blocks): 224 | layers.append(block(self.inplanes, planes, groups=self.groups, 225 | base_width=self.base_width, dilation=self.dilation, 226 | norm_layer=norm_layer)) 227 | 228 | return nn.Sequential(*layers) 229 | 230 | def _forward_impl(self, x: Tensor, is_fc: bool) -> Tensor: 231 | # See note [TorchScript super()] 232 | x = self.conv1(x) 233 | x = self.bn1(x) 234 | x = self.relu(x) 235 | x = self.maxpool(x) 236 | 237 | x = self.layer1(x) 238 | x = self.layer2(x) 239 | x = self.layer3(x) 240 | x = self.layer4(x) 241 | 242 | x = self.avgpool(x) 243 | x = torch.flatten(x, 1) 244 | if is_fc: 245 | x = self.fc(x) 246 | return x 247 | 248 | def forward(self, x: Tensor, is_fc=True) -> Tensor: 249 | return self._forward_impl(x, is_fc) 250 | 251 | 252 | class Encoder(ResNet): 253 | def __init__(self, 254 | input_channel:int, 255 | block: Type[Union[BasicBlock, Bottleneck]], 256 | layers: List[int], 257 | weight_path: str, 258 | num_classes: int = 2048, 259 | **kwargs: Any 260 | ) -> None: 261 | super(Encoder, self).__init__(block, layers, input_channel, num_classes, **kwargs) 262 | self.input_channel = input_channel 263 | if weight_path is not None and len(weight_path) > 1: 264 | print(f'loading weight from {weight_path}') 265 | self._load_from_weight(weight_path) 266 | 267 | def _load_from_weight(self, weight_path: str): 268 | model_weight = torch.load(weight_path) 269 | pretrained_dict = {} 270 | for k, v in model_weight.items(): 271 | if k.startswith('conv1') or k.startswith('fc'): 272 | continue 273 | pretrained_dict[k] = v 274 | self.load_state_dict(pretrained_dict, strict = False) 275 | 276 | 277 | def _forward_fc(self, x: Tensor): 278 | x = self.fc(x) 279 | return x 280 | 281 | 282 | 283 | 284 | def _resnet( 285 | arch: str, 286 | block: Type[Union[BasicBlock, Bottleneck]], 287 | layers: List[int], 288 | pretrained: bool, 289 | progress: bool, 290 | **kwargs: Any 291 | ) -> ResNet: 292 | model = ResNet(block, layers, **kwargs) 293 | if pretrained: 294 | state_dict = load_state_dict_from_url(model_urls[arch], 295 | progress=progress) 296 | model.load_state_dict(state_dict) 297 | return model 298 | 299 | 300 | def resnet18(pretrained: 
bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 301 | r"""ResNet-18 model from 302 | `"Deep Residual Learning for Image Recognition" `_. 303 | 304 | Args: 305 | pretrained (bool): If True, returns a model pre-trained on ImageNet 306 | progress (bool): If True, displays a progress bar of the download to stderr 307 | """ 308 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 309 | **kwargs) 310 | 311 | 312 | 313 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 314 | r"""ResNet-34 model from 315 | `"Deep Residual Learning for Image Recognition" `_. 316 | 317 | Args: 318 | pretrained (bool): If True, returns a model pre-trained on ImageNet 319 | progress (bool): If True, displays a progress bar of the download to stderr 320 | """ 321 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 322 | **kwargs) 323 | 324 | 325 | 326 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 327 | r"""ResNet-50 model from 328 | `"Deep Residual Learning for Image Recognition" `_. 329 | 330 | Args: 331 | pretrained (bool): If True, returns a model pre-trained on ImageNet 332 | progress (bool): If True, displays a progress bar of the download to stderr 333 | """ 334 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 335 | **kwargs) 336 | 337 | 338 | 339 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 340 | r"""ResNet-101 model from 341 | `"Deep Residual Learning for Image Recognition" `_. 342 | 343 | Args: 344 | pretrained (bool): If True, returns a model pre-trained on ImageNet 345 | progress (bool): If True, displays a progress bar of the download to stderr 346 | """ 347 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 348 | **kwargs) 349 | 350 | 351 | 352 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 353 | r"""ResNet-152 model from 354 | `"Deep Residual Learning for Image Recognition" `_. 355 | 356 | Args: 357 | pretrained (bool): If True, returns a model pre-trained on ImageNet 358 | progress (bool): If True, displays a progress bar of the download to stderr 359 | """ 360 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 361 | **kwargs) 362 | 363 | 364 | 365 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 366 | r"""ResNeXt-50 32x4d model from 367 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 368 | 369 | Args: 370 | pretrained (bool): If True, returns a model pre-trained on ImageNet 371 | progress (bool): If True, displays a progress bar of the download to stderr 372 | """ 373 | kwargs['groups'] = 32 374 | kwargs['width_per_group'] = 4 375 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 376 | pretrained, progress, **kwargs) 377 | 378 | 379 | 380 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 381 | r"""ResNeXt-101 32x8d model from 382 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
383 | 
384 | Args:
385 | pretrained (bool): If True, returns a model pre-trained on ImageNet
386 | progress (bool): If True, displays a progress bar of the download to stderr
387 | """
388 | kwargs['groups'] = 32
389 | kwargs['width_per_group'] = 8
390 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
391 | pretrained, progress, **kwargs)
392 | 
393 | 
394 | 
395 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
396 | r"""Wide ResNet-50-2 model from
397 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
398 | 
399 | The model is the same as ResNet except for the bottleneck number of channels,
400 | which is twice as large in every block. The number of channels in outer 1x1
401 | convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
402 | channels, and in Wide ResNet-50-2 it has 2048-1024-2048.
403 | 
404 | Args:
405 | pretrained (bool): If True, returns a model pre-trained on ImageNet
406 | progress (bool): If True, displays a progress bar of the download to stderr
407 | """
408 | kwargs['width_per_group'] = 64 * 2
409 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
410 | pretrained, progress, **kwargs)
411 | 
412 | 
413 | 
414 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
415 | r"""Wide ResNet-101-2 model from
416 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
417 | 
418 | The model is the same as ResNet except for the bottleneck number of channels,
419 | which is twice as large in every block. The number of channels in outer 1x1
420 | convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
421 | channels, and in Wide ResNet-50-2 it has 2048-1024-2048.
422 | 
423 | Args:
424 | pretrained (bool): If True, returns a model pre-trained on ImageNet
425 | progress (bool): If True, displays a progress bar of the download to stderr
426 | """
427 | kwargs['width_per_group'] = 64 * 2
428 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
429 | pretrained, progress, **kwargs)
430 | 
431 | 
432 | def encoder_resnet50(input_channel: int = 9, weight_path: str = '', progress: bool = True, num_classes=2048,
433 | **kwargs: Any) -> Encoder:
434 | 
435 | return Encoder(input_channel, Bottleneck, [3, 4, 6, 3], weight_path, num_classes, **kwargs)
436 | 
437 | 
438 | def get_encoder(model_name='resnet50', weight_path='', modal='v', input_channel=9, ssl_type='moco'):
439 | encoder = None
440 | model_name = model_name.lower()
441 | modal = modal.lower()
442 | ssl_type = ssl_type.lower()
443 | 
444 | if modal.startswith('v'):
445 | if model_name == 'resnet50':
446 | if 'resnet50-19c8e357' in weight_path or len(weight_path) < 1:
447 | encoder = encoder_resnet50(weight_path=weight_path, input_channel=input_channel)
448 | else:
449 | encoder = encoder_resnet50(weight_path='', input_channel=input_channel)
450 | # load a self-supervised pre-training checkpoint (e.g. one saved by pretrain_main.py) on CPU
451 | model_weight = torch.load(weight_path, map_location=torch.device('cpu'))['state_dict']
452 | pretrained_dict = {}
453 | for k, v in model_weight.items():
454 | # remap checkpoint keys so they match the bare Encoder's state_dict
455 | if ssl_type == 'moco':
456 | # MoCo checkpoint: keep only the query encoder; drop the key encoder and the queue
457 | if k.startswith('module.encoder_k'):
458 | continue
459 | if k == 'module.queue' or k == 'module.queue_ptr':
460 | continue
461 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
462 | k = k[17:]  # strip the 'module.encoder_q.' prefix (17 characters)
463 | else:
464 | # SimSiam checkpoint: keep the backbone weights
465 | if k.startswith('module.encoder') and not k.startswith('module.encoder.fc'):
466 | k = k[15:]  # strip the 'module.encoder.' prefix (15 characters)
467 | 
468 | pretrained_dict[k] = v
469 | encoder.load_state_dict(pretrained_dict, strict=False)
470 | 
471 | return encoder
472 | 
473 | 
474 | 
475 | if __name__ == '__main__':
476 | model = encoder_resnet50(input_channel=9, weight_path='./pretrain/resnet50-19c8e357.pth')
477 | x = torch.randn(2, 9, 224, 224)
478 | out = model(x)
479 | print(out.size())
480 | 
-------------------------------------------------------------------------------- /models/core/SCRL_MoCo.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from cluster.Group import Cluster_GPU
4 | 
5 | class SCRL(nn.Module):
6 | """
7 | Referenced from MoCo[1] and SCRL[2].
8 | [1] https://arxiv.org/abs/1911.05722
9 | [2] https://arxiv.org/abs/2205.05487
10 | """
11 | def __init__(self, base_encoder, dim=2048, K=65536,
12 | m=0.999, T=0.07, mlp=False,
13 | encoder_pretrained_path: str = '',
14 | multi_positive=False,
15 | positive_selection='cluster',
16 | cluster_num=10,
17 | soft_gamma=0.5):
18 | super(SCRL, self).__init__()
19 | 
20 | self.K = K
21 | self.m = m
22 | self.T = T
23 | self.dim = dim
24 | self.multi_positive = multi_positive
25 | self.forward_fn = self.forward_SCRL
26 | self.cluster_num = cluster_num
27 | self.soft_gamma = soft_gamma
28 | assert self.cluster_num > 0
29 | 
30 | # positive selection strategy
31 | if 'cluster' in positive_selection:
32 | self.selection_fn = self.get_q_and_k_index_cluster
33 | self.cluster_obj = Cluster_GPU(self.cluster_num)
34 | else:
35 | raise NotImplementedError
36 | 
37 | self.encoder_q = base_encoder(weight_path=encoder_pretrained_path)
38 | self.encoder_k = base_encoder(weight_path=encoder_pretrained_path)
39 | self.mlp = mlp
40 | 
41 | # hack: brute-force replacement of the fc head with a two-layer MLP projection
42 | if mlp:
43 | dim_mlp = self.encoder_q.fc.weight.shape[1]
44 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
45 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
46 | 
47 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
48 | param_k.data.copy_(param_q.data)  # initialize the key encoder from the query encoder
49 | param_k.requires_grad = False  # the key encoder is updated by momentum, not by gradient
50 | 
51 | # create the queue
52 | self.register_buffer("queue", torch.randn(dim, K))
53 | self.queue = nn.functional.normalize(self.queue, dim=0)
54 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
55 | 
56 | @torch.no_grad()
57 | def _momentum_update_key_encoder(self):
58 | """
59 | Momentum update of the key encoder
60 | """
61 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
62 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
63 | 
64 | @torch.no_grad()
65 | def _dequeue_and_enqueue(self, keys):
66 | # gather keys before updating queue
67 | keys = concat_all_gather(keys)
68 | 
69 | batch_size = keys.shape[0]
70 | 
71 | ptr = int(self.queue_ptr)
72 | 
73 | assert self.K % batch_size == 0 # for simplicity
74 | 
75 | # replace the keys at ptr (dequeue and enqueue)
76 | self.queue[:, ptr:ptr + batch_size] = keys.T
77 | ptr = (ptr + batch_size) % self.K # move pointer
78 | 
79 | self.queue_ptr[0] = ptr
80 | 
81 | @torch.no_grad()
82 | def _batch_shuffle_ddp(self, x):
83 | """
84 | Batch shuffle, for making use of BatchNorm.
85 | *** Only supports DistributedDataParallel (DDP) models. ***
86 | """
87 | # gather from all gpus
88 | batch_size_this = x.shape[0]
89 | x_gather = concat_all_gather(x)
90 | batch_size_all = x_gather.shape[0]
91 | 
92 | num_gpus = batch_size_all // batch_size_this
93 | 
94 | # random shuffle index
95 | idx_shuffle = torch.randperm(batch_size_all).cuda()
96 | 
97 | # broadcast to all gpus
98 | torch.distributed.broadcast(idx_shuffle, src=0)
99 | 
100 | # index for restoring
101 | idx_unshuffle = torch.argsort(idx_shuffle)
102 | 
103 | # shuffled index for this gpu
104 | gpu_idx = torch.distributed.get_rank()
105 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
106 | 
107 | return x_gather[idx_this], idx_unshuffle
108 | 
109 | 
110 | @torch.no_grad()
111 | def _batch_unshuffle_ddp(self, x, idx_unshuffle):
112 | """
113 | Undo batch shuffle.
114 | *** Only supports DistributedDataParallel (DDP) models. ***
115 | """
116 | # gather from all gpus
117 | batch_size_this = x.shape[0]
118 | x_gather = concat_all_gather(x)
119 | batch_size_all = x_gather.shape[0]
120 | 
121 | num_gpus = batch_size_all // batch_size_this
122 | 
123 | # restored index for this gpu
124 | gpu_idx = torch.distributed.get_rank()
125 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
126 | 
127 | return x_gather[idx_this]
128 | 
129 | @torch.no_grad()
130 | def get_q_and_k_index_cluster(self, embeddings, return_group=False) -> tuple:
131 | # scene-consistent positive selection: pair each query shot with a key from its own embedding cluster
132 | B = embeddings.size(0)
133 | target_index = list(range(0, B))
134 | q_index = target_index  # every shot in the batch acts as a query
135 | 
136 | choice_cluster, choice_points = self.cluster_obj(embeddings)  # cluster id per shot, representative shot per cluster
137 | k_index = []
138 | for c in choice_cluster:
139 | k_index.append(int(choice_points[c]))  # the key of each query is its cluster's representative shot
140 | if return_group:
141 | return (q_index, k_index, choice_cluster, choice_points)
142 | else:
143 | return (q_index, k_index)
144 | 
145 | 
146 | def forward(self, img_q, img_k):
147 | """
148 | Input:
149 | query, key (images)
150 | Output:
151 | logits, targets
152 | """
153 | return self.forward_fn(img_q, img_k)
154 | 
155 | 
156 | def forward_SCRL(self, img_q, img_k):
157 | # compute query features
158 | embeddings = self.encoder_q(img_q, self.mlp)
159 | embeddings = nn.functional.normalize(embeddings, dim=1)
160 | 
161 | # select query and key indices via scene-consistency clustering
162 | index_q, index_k = self.selection_fn(embeddings)
163 | 
164 | # features of q
165 | q = embeddings[index_q]
166 | 
167 | # compute key features
168 | with torch.no_grad():
169 | # update the key encoder
170 | self._momentum_update_key_encoder()
171 | 
172 | # shuffle for making use of BN
173 | img_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)
174 | 
175 | k = self.encoder_k(img_k, self.mlp)
176 | k = nn.functional.normalize(k, dim=1)
177 | 
178 | # undo shuffle
179 | k = self._batch_unshuffle_ddp(k, idx_unshuffle)
180 | 
181 | k_ori = k
182 | k = k[index_k]
183 | 
184 | # compute logits
185 | # positive logits: Nx1
186 | if self.multi_positive:
187 | # SCRL Soft-SC: blend the cluster-selected key with the shot's own key
188 | k = (k + k_ori) * self.soft_gamma
189 | 
190 | 
191 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
192 | 
193 | # negative logits: NxK
194 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
195 | 
196 | # logits: Nx(1+K)
197 | logits = torch.cat([l_pos, l_neg], dim=1)
198 | 
199 | # apply temperature
200 | logits /= self.T
201 | 
202 | # labels: positive key indicators
203 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
204 | 
205 | # dequeue and enqueue
206 | self._dequeue_and_enqueue(k)
207 | 
208 | return logits, labels
209 | 
210 | # the original MoCo forward function, kept for reference
211 | def forward_moco_old(self, im_q, im_k):
212 | """
213 | Input:
214 | im_q: 
a batch of query images 215 | im_k: a batch of key images 216 | Output: 217 | logits, targets 218 | """ 219 | 220 | # compute query features 221 | q = self.encoder_q(im_q) # queries: NxC 222 | q = nn.functional.normalize(q, dim=1) 223 | 224 | # compute key features 225 | with torch.no_grad(): # no gradient to keys 226 | self._momentum_update_key_encoder() # update the key encoder 227 | 228 | # shuffle for making use of BN 229 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) 230 | 231 | k = self.encoder_k(im_k) # keys: NxC 232 | k = nn.functional.normalize(k, dim=1) 233 | 234 | # undo shuffle 235 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 236 | 237 | # compute logits 238 | # Einstein sum is more intuitive 239 | # positive logits: Nx1 240 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 241 | # negative logits: NxK 242 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 243 | 244 | # logits: Nx(1+K) 245 | logits = torch.cat([l_pos, l_neg], dim=1) 246 | 247 | # apply temperature 248 | logits /= self.T 249 | 250 | # labels: positive key indicators 251 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 252 | 253 | # dequeue and enqueue 254 | self._dequeue_and_enqueue(k) 255 | 256 | return logits, labels 257 | 258 | 259 | # utils 260 | @torch.no_grad() 261 | def concat_all_gather(tensor): 262 | """ 263 | Performs all_gather operation on the provided tensors. 264 | *** Warning ***: torch.distributed.all_gather has no gradient. 265 | """ 266 | tensors_gather = [torch.ones_like(tensor) 267 | for _ in range(torch.distributed.get_world_size())] 268 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 269 | 270 | output = torch.cat(tensors_gather, dim=0) 271 | return output 272 | -------------------------------------------------------------------------------- /models/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentYoutuResearch/SceneSegmentation-SCRL/7d2daed4c8f1922aa6c85abaf9db36abaf0ae67e/models/core/__init__.py -------------------------------------------------------------------------------- /models/factory.py: -------------------------------------------------------------------------------- 1 | import models.backbones.visual.resnet as resnet 2 | from models.core.SCRL_MoCo import SCRL 3 | from data.movienet_data import get_train_loader 4 | import torch, os 5 | from utils import to_log 6 | 7 | def get_model(cfg): 8 | encoder = None 9 | model = None 10 | if 'multimodal' not in cfg or cfg['multimodal']['using_audio'] == False: 11 | encoder = resnet.encoder_resnet50 12 | else: 13 | raise NotImplementedError 14 | assert encoder is not None 15 | 16 | to_log(cfg, 'backbone init: ' + cfg['model']['backbone'], True) 17 | 18 | if cfg['model']['SSL'] == 'SCRL': 19 | model = SCRL( 20 | base_encoder = encoder, 21 | dim = cfg['MoCo']['dim'], 22 | K = cfg['MoCo']['k'], 23 | m = cfg['MoCo']['m'], 24 | T = cfg['MoCo']['t'], 25 | mlp = cfg['MoCo']['mlp'], 26 | encoder_pretrained_path = cfg['model']['backbone_pretrain'], 27 | multi_positive = cfg['MoCo']['multi_positive'], 28 | positive_selection = cfg['model']['Positive_Selection'], 29 | cluster_num = cfg['model']['cluster_num'], 30 | soft_gamma = cfg['model']['soft_gamma'], 31 | ) 32 | else: 33 | raise NotImplementedError 34 | to_log(cfg, 'model init: ' + cfg['model']['SSL'], True) 35 | 36 | if cfg['model']['SyncBatchNorm']: 37 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 38 | to_log(cfg, 
'SyncBatchNorm: on' if cfg['model']['SyncBatchNorm'] else 'SyncBatchNorm: off', True) 39 | return model 40 | 41 | def get_loader(cfg): 42 | train_loader, train_sampler = get_train_loader(cfg) 43 | return train_loader, train_sampler 44 | 45 | 46 | def get_criterion(cfg): 47 | criterion = None 48 | if cfg['model']['SSL'] == 'simsiam': 49 | criterion = torch.nn.CosineSimilarity(dim=1) 50 | elif cfg['model']['SSL'] == 'SCRL': 51 | criterion = torch.nn.CrossEntropyLoss() 52 | else: 53 | raise NotImplementedError 54 | to_log(cfg, 'criterion init: ' + str(criterion), True) 55 | return criterion 56 | 57 | def get_optimizer(cfg, model): 58 | optimizer = None 59 | if cfg['optim']['optimizer'] == 'sgd': 60 | if cfg['model']['SSL'] == 'simsiam': 61 | if cfg['model']['fix_pred_lr']: 62 | optim_params = [{'params': model.module.encoder.parameters(), 'fix_lr': False}, 63 | {'params': model.module.predictor.parameters(), 'fix_lr': True}] 64 | else: 65 | optim_params = model.parameters() 66 | elif cfg['model']['SSL'] == 'SCRL': 67 | optim_params = model.parameters() 68 | else: 69 | raise NotImplementedError 70 | 71 | optimizer = torch.optim.SGD(optim_params, cfg['optim']['lr'], 72 | momentum=cfg['optim']['momentum'], 73 | weight_decay=cfg['optim']['wd']) 74 | else: 75 | raise NotImplementedError 76 | return optimizer 77 | 78 | def get_training_stuff(cfg, gpu, ngpus_per_node): 79 | cfg['optim']['bs'] = int(cfg['optim']['bs'] / ngpus_per_node) 80 | to_log(cfg, 'shot per GPU: ' + str(cfg['optim']['bs']), True) 81 | 82 | if cfg['data']['clipshuffle']: 83 | len_per_data = cfg['data']['clipshuffle_len'] 84 | else: 85 | len_per_data = 1 86 | assert cfg['optim']['bs'] % len_per_data == 0 87 | cfg['optim']['bs'] = int(cfg['optim']['bs'] / len_per_data ) 88 | cfg['data']['workers'] = int(( cfg['data']['workers'] + ngpus_per_node - 1) / ngpus_per_node) 89 | to_log(cfg, 'batch size per GPU: ' + str(cfg['optim']['bs']), True) 90 | to_log(cfg, 'worker per GPU: ' + str(cfg['data']['workers']) , True) 91 | 92 | train_loader, train_sampler = get_train_loader(cfg) 93 | model = get_model(cfg) 94 | model.cuda(gpu) 95 | model = torch.nn.parallel.DistributedDataParallel(model, 96 | device_ids=[gpu], 97 | output_device=gpu, 98 | find_unused_parameters=True) 99 | 100 | criterion = get_criterion(cfg).cuda(gpu) 101 | optimizer = get_optimizer(cfg, model) 102 | cfg['optim']['start_epoch'] = 0 103 | resume = cfg['model']['resume'] 104 | if resume is not None and len(resume) > 1: 105 | if os.path.isfile(resume): 106 | to_log(cfg, "=> loading checkpoint '{}'".format(resume), True) 107 | if gpu is None: 108 | checkpoint = torch.load(resume) 109 | else: 110 | loc = f'cuda:{gpu}' 111 | checkpoint = torch.load(resume, map_location=loc) 112 | start_epoch = checkpoint['epoch'] 113 | cfg['optim']['start_epoch'] = start_epoch 114 | model.load_state_dict(checkpoint['state_dict']) 115 | optimizer.load_state_dict(checkpoint['optimizer']) 116 | to_log(cfg, "=> loaded checkpoint '{}' (epoch {})" 117 | .format(resume, checkpoint['epoch']), True) 118 | else: 119 | to_log(cfg, "=> no checkpoint found at '{}'".format(resume), True) 120 | raise FileNotFoundError 121 | 122 | 123 | assert model is not None \ 124 | and train_loader is not None \ 125 | and criterion is not None \ 126 | and optimizer is not None 127 | 128 | return (model, train_loader, train_sampler, criterion, optimizer) 129 | -------------------------------------------------------------------------------- /pretrain_main.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os, shutil 4 | import builtins 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.optim 13 | import torch.multiprocessing as mp 14 | import random 15 | import numpy as np 16 | 17 | from models import factory 18 | from utils import to_log, set_log 19 | from pretrain_trainer import train_SCRL 20 | 21 | 22 | def start_training(cfg): 23 | # only multiprocessing_distributed is supported 24 | if cfg['DDP']['multiprocessing_distributed']: 25 | 26 | ngpus_per_node = torch.cuda.device_count() 27 | 28 | if cfg['DDP']['dist_url'] == "env://": 29 | os.environ['MASTER_ADDR'] = cfg['DDP']['master_ip'] 30 | os.environ['MASTER_PORT'] = str(cfg['DDP']['master_port']) 31 | os.environ['WORLD_SIZE'] = str(ngpus_per_node * cfg['DDP']['machine_num']) 32 | os.environ['NODE_RANK'] = str(cfg['DDP']['node_num']) 33 | os.environ['NUM_NODES'] = str(cfg['DDP']['machine_num']) 34 | os.environ['NUM_GPUS_PER_NODE'] = str(ngpus_per_node) 35 | # os.environ['NCCL_IB_DISABLE'] = "1" 36 | 37 | cfg['DDP']['world_size'] = ngpus_per_node * cfg['DDP']['machine_num'] 38 | print(cfg['DDP']['world_size'], ngpus_per_node) 39 | 40 | mp.spawn(task_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, cfg)) 41 | 42 | 43 | def setup_worker(seed, gpu): 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | np.random.seed(seed) 47 | random.seed(seed) 48 | cudnn.benchmark = True 49 | torch.cuda.set_device(gpu) 50 | 51 | def task_worker(gpu, ngpus_per_node, cfg): 52 | setup_worker(seed = 100, gpu = gpu) 53 | if gpu != 0: 54 | def print_pass(*args): 55 | pass 56 | builtins.print = print_pass 57 | 58 | cfg['DDP']['rank']= cfg['DDP']['node_num'] * ngpus_per_node + gpu 59 | cfg['DDP']['gpu'] = gpu 60 | if cfg['DDP']['dist_url'] == 'env://': 61 | os.environ['RANK'] = str(cfg['DDP']['rank']) 62 | 63 | print(cfg['DDP']['dist_backend'], cfg['DDP']['dist_url'], cfg['DDP']['world_size'],cfg['DDP']['rank'] ) 64 | dist.init_process_group(backend=cfg['DDP']['dist_backend'], init_method=cfg['DDP']['dist_url']) 65 | 66 | if gpu == 0: 67 | to_log(cfg, 'DDP init succeed!', True) 68 | 69 | model, train_loader, train_sampler, criterion, optimizer \ 70 | = factory.get_training_stuff(cfg, gpu, ngpus_per_node) 71 | 72 | # training function 73 | if cfg['model']['SSL'] == 'SCRL': 74 | train_fun = train_SCRL 75 | else: 76 | raise NotImplementedError 77 | 78 | start_epoch = cfg['optim']['start_epoch'] 79 | end_epoch = cfg['optim']['epochs'] 80 | 81 | assert train_fun is not None 82 | for epoch in range(start_epoch, end_epoch): 83 | train_sampler.set_epoch(epoch) 84 | adjust_learning_rate(optimizer, cfg['optim']['lr'], epoch, cfg) 85 | train_fun(gpu, train_loader, model, criterion, optimizer, epoch, cfg) 86 | if cfg['DDP']['rank'] == 0 and (epoch + 1) % 4 == 0: 87 | save_checkpoint(cfg,{ 88 | 'epoch': epoch + 1, 89 | 'arch': cfg['model']['backbone'], 90 | 'state_dict': model.state_dict(), 91 | 'optimizer' : optimizer.state_dict(), 92 | }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch)) 93 | 94 | 95 | def adjust_learning_rate(optimizer, init_lr, epoch, cfg): 96 | """Decay the learning rate based on schedule""" 97 | if cfg['optim']['lr_cos'] == True: 98 | cur_lr = init_lr * 0.5 * (1. 
+ math.cos(0.5 * math.pi * epoch / cfg['optim']['epochs'])) 99 | else: 100 | cur_lr = init_lr 101 | for milestone in cfg['optim']['schedule']: 102 | cur_lr *= 0.1 if epoch >= milestone else 1. 103 | for param_group in optimizer.param_groups: 104 | if 'fix_lr' in param_group and param_group['fix_lr']: 105 | param_group['lr'] = init_lr 106 | else: 107 | param_group['lr'] = cur_lr 108 | 109 | 110 | 111 | def save_checkpoint(cfg, state, is_best, filename='checkpoint.pth.tar'): 112 | p = os.path.join(cfg['log']['dir'], 'checkpoints') 113 | if not os.path.exists(p): 114 | os.makedirs(p) 115 | 116 | torch.save(state, os.path.join(p, filename)) 117 | if is_best: 118 | shutil.copyfile(os.path.join(p, filename), os.path.join(p, 'model_best.pth.tar')) 119 | 120 | 121 | def get_config(): 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('--config', type=str, default='./config/SCRL_pretrain_default.yaml') 124 | args = parser.parse_args() 125 | cfg = yaml.safe_load(open(args.config, encoding='utf8')) 126 | cfg = set_log(cfg) 127 | shutil.copy(args.config, cfg['log']['dir']) 128 | return cfg 129 | 130 | 131 | def main(): 132 | cfg = get_config() 133 | start_training(cfg) 134 | 135 | if __name__ == '__main__': 136 | main() 137 | -------------------------------------------------------------------------------- /pretrain_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn.parallel 4 | import torch.optim 5 | from utils import AverageMeter, ProgressMeter, to_log, accuracy 6 | 7 | 8 | 9 | def train_SCRL(gpu, train_loader, model, criterion, optimizer, epoch, cfg): 10 | batch_time = AverageMeter('Time', ':6.3f') 11 | data_time = AverageMeter('Data', ':6.3f') 12 | losses = AverageMeter('Loss', ':.4e') 13 | top1 = AverageMeter('Acc@1', ':6.2f') 14 | top5 = AverageMeter('Acc@5', ':6.2f') 15 | 16 | progress = ProgressMeter( 17 | len(train_loader), 18 | [batch_time, data_time, losses, top1, top5], 19 | prefix="Epoch: [{}]".format(epoch)) 20 | 21 | gradient_clip_val = cfg['optim']['gradient_norm'] 22 | 23 | model.train() 24 | view_size = (-1, 3 * cfg['data']['frame_size'], 224, 224) 25 | pivot = time.time() 26 | for i, data in enumerate(train_loader): 27 | if gpu is not None: 28 | data_q = data[0].cuda(gpu, non_blocking=True) 29 | data_k = data[1].cuda(gpu, non_blocking=True) 30 | data_time.update(time.time() - pivot) 31 | data_q = data_q.view(view_size) 32 | data_k = data_k.view(view_size) 33 | 34 | output, target = model(data_q, data_k) 35 | 36 | loss = criterion(output, target) 37 | 38 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 39 | 40 | losses.update(loss.item(), target.size(0)) 41 | top1.update(acc1[0], target.size(0)) 42 | top5.update(acc5[0], target.size(0)) 43 | 44 | optimizer.zero_grad() 45 | loss.backward() 46 | 47 | # gradient clipping 48 | if gradient_clip_val > 0: 49 | torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val) 50 | 51 | optimizer.step() 52 | 53 | batch_time.update(time.time() - pivot) 54 | pivot = time.time() 55 | 56 | if gpu == 0 and i % cfg['log']['print_freq'] == 0: 57 | _out = progress.display(i) 58 | to_log(cfg, _out, True) 59 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import torch 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def 
__init__(self, name, fmt=':f'):
8 | self.name = name
9 | self.fmt = fmt
10 | self.reset()
11 | 
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 | 
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 | 
24 | def __str__(self):
25 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
26 | return fmtstr.format(**self.__dict__)
27 | 
28 | 
29 | class ProgressMeter(object):
30 | def __init__(self, num_batches, meters, prefix=""):
31 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
32 | self.meters = meters
33 | self.prefix = prefix
34 | 
35 | def display(self, batch):
36 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
37 | entries += [str(meter) for meter in self.meters]
38 | out = '\t'.join(entries)
39 | return out
40 | 
41 | def _get_batch_fmtstr(self, num_batches):
42 | num_digits = len(str(num_batches // 1))
43 | fmt = '{:' + str(num_digits) + 'd}'
44 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
45 | 
46 | def set_log(cfg):
47 | time_str = time.strftime("%Y-%m-%d_%H_%M_%S", time.localtime())
48 | cfg['log']['dir'] = cfg['log']['dir'] + time_str
49 | if not os.path.exists(cfg['log']['dir']):
50 | os.makedirs(cfg['log']['dir'])
51 | return cfg
52 | 
53 | def to_log(cfg, content, echo=False, gpu_print_id=0):
54 | # a negative gpu_print_id forces logging from every process
55 | if gpu_print_id < 0 or cfg['DDP']['gpu'] == gpu_print_id:
56 | log_path = os.path.join(cfg['log']['dir'], 'log.txt')
57 | with open(log_path, 'a') as f:
58 | f.write(content + '\n')
59 | if echo:
60 | print(content)
61 | 
62 | def accuracy(output, target, topk=(1,)):
63 | """Computes the accuracy over the k top predictions for the specified values of k"""
64 | with torch.no_grad():
65 | maxk = max(topk)
66 | batch_size = target.size(0)
67 | 
68 | _, pred = output.topk(maxk, 1, True, True)
69 | pred = pred.t()
70 | correct = pred.eq(target.view(1, -1).expand_as(pred))
71 | 
72 | res = []
73 | for k in topk:
74 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
75 | res.append(correct_k.mul_(100.0 / batch_size))
76 | return res
--------------------------------------------------------------------------------
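For reference, the InfoNCE head in `forward_SCRL` can be exercised in isolation. The sketch below is illustrative only: random tensors stand in for the L2-normalized encoder outputs and the negative queue, the sizes (`N=4`, `C=128`, `K=1024`) are deliberately small (the model defaults above use `dim=2048` and `K=65536`), and it runs in a single process because the distributed shuffle/gather steps are skipped.

```
import torch
import torch.nn as nn

N, C, K = 4, 128, 1024      # illustrative batch size, feature dim, queue length
T, soft_gamma = 0.07, 0.5   # temperature and Soft-SC mixing weight, as in the defaults

# stand-ins for the normalized encoder outputs and the negative queue
q     = nn.functional.normalize(torch.randn(N, C), dim=1)  # query features
k     = nn.functional.normalize(torch.randn(N, C), dim=1)  # cluster-selected keys
k_ori = nn.functional.normalize(torch.randn(N, C), dim=1)  # each shot's own key
queue = nn.functional.normalize(torch.randn(C, K), dim=0)

# Soft-SC (multi_positive=True): blend the selected key with the shot's own key
k = (k + k_ori) * soft_gamma

l_pos  = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # Nx1 positive logits
l_neg  = torch.einsum('nc,ck->nk', [q, queue])           # NxK negative logits
logits = torch.cat([l_pos, l_neg], dim=1) / T            # Nx(1+K)
labels = torch.zeros(N, dtype=torch.long)                # the positive sits at index 0

loss = nn.CrossEntropyLoss()(logits, labels)
print(logits.shape, loss.item())
```

As in the training loop above, `labels` is all zeros because the positive logit is concatenated first; the cross-entropy loss then pulls each query toward its selected key and away from the `K` queued negatives.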