├── ops ├── __init__.py ├── basic_ops.py ├── generate_scripts.py ├── dataset_config.py ├── non_local.py ├── rstg.py ├── models_config.py ├── dataset.py ├── temporal_shift.py ├── dataset_syncMNIST.py ├── nested_model.py ├── models_utils.py ├── resnet3d_xl.py ├── resnet2d.py └── utils.py ├── .gitignore ├── DyReg_GNN_arch.png ├── train_model.sh ├── tools ├── gen_label_other.py ├── vid2img_sthv2.py ├── gen_label_sthv2.py ├── gen_label_sthv1.py └── vid2img_other.py ├── evaluate_model.sh ├── evaluate_model_multi_clip.sh ├── create_model.py ├── README.md ├── opts.py ├── LICENSE └── test_models.py /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from ops.basic_ops import * -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore all logs 2 | **__pycache__* 3 | -------------------------------------------------------------------------------- /DyReg_GNN_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bit-ml/DyReg-GNN/HEAD/DyReg_GNN_arch.png -------------------------------------------------------------------------------- /ops/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Identity(torch.nn.Module): 5 | def forward(self, input): 6 | return input 7 | 8 | 9 | class SegmentConsensus(torch.nn.Module): 10 | 11 | def __init__(self, consensus_type, dim=1): 12 | super(SegmentConsensus, self).__init__() 13 | self.consensus_type = consensus_type 14 | self.dim = dim 15 | self.shape = None 16 | 17 | def forward(self, input_tensor): 18 | self.shape = input_tensor.size() 19 | if self.consensus_type == 'avg': 20 | output = input_tensor.mean(dim=self.dim, keepdim=True) 21 | elif self.consensus_type == 'identity': 22 | output = input_tensor 23 | else: 24 | output = None 25 | 26 | return output 27 | 28 | 29 | class ConsensusModule(torch.nn.Module): 30 | 31 | def __init__(self, consensus_type, dim=1): 32 | super(ConsensusModule, self).__init__() 33 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 34 | self.dim = dim 35 | 36 | def forward(self, input): 37 | return SegmentConsensus(self.consensus_type, self.dim)(input) 38 | -------------------------------------------------------------------------------- /train_model.sh: -------------------------------------------------------------------------------- 1 | RAND=$((RANDOM)) 2 | 3 | MODEL_DIR='./models/'$1 4 | LOG_NAME=$MODEL_DIR'/log_'$RAND 5 | mkdir $MODEL_DIR 6 | 7 | args="--print-freq=1 --name=$1 8 | --npb --offset_generator=big 9 | --init_skip_zero=False --warmup_validate=False 10 | --dynamic_regions=dyreg --init_regions=center 11 | --place_graph=layer2.2_layer3.4_layer4.1 12 | --ch_div=4 --graph_residual_type=norm 13 | --rnn_type=GRU --aggregation_type=dot --send_layers=1 14 | --use_rstg=True --rstg_skip_connection=False --remap_first=True 15 | --update_layers=0 --combine_by_sum=True --project_at_input=True 16 | --tmp_norm_skip_conn=False --bottleneck_graph=True 17 | --lr 0.001 --batch-size 2 --dataset=somethingv2 --modality=RGB --arch=resnet50 --num_segments 16 --gd 20 18 | --lr_type=step --lr_steps 20 30 40 --epochs 60 -j 16 --dropout=0.5 --consensus_type=avg --eval-freq=1 19 | --shift_div=8 --shift --shift_place=blockres --model_dir=$MODEL_DIR" 20 | 21 | 22 | echo $args > $MODEL_DIR'/args_'$RAND 23 | 24 | # run on CPU: 25 | # CUDA_VISIBLE_DEVICES="" python -u main_standard.py $args & tee $LOG_NAME 26 | 27 | # run on GPU 28 | CUDA_VISIBLE_DEVICES=0 python -u -m torch.distributed.launch --nproc_per_node=1 --master_port=6004 main_standard.py $args |& tee $LOG_NAME 29 | -------------------------------------------------------------------------------- /tools/gen_label_other.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | # ------------------------------------------------------ 6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py 7 | # processing the raw data of the video Something-Something-V2 8 | 9 | import os 10 | import json 11 | import pdb 12 | 13 | if __name__ == '__main__': 14 | num_sec = 5 15 | 16 | dataset_folder = f'/data/datasets/in_the_wild/gifs-frames-{num_sec}s' 17 | filename_output = f'/data/datasets/in_the_wild/gifs-frames-{num_sec}s.txt' 18 | 19 | folders = os.listdir(dataset_folder) 20 | output = [] 21 | for i in range(len(folders)): 22 | curFolder = folders[i] 23 | curIDX = 0 24 | # counting the number of frames in each video folders 25 | video_folder = os.path.join(dataset_folder, curFolder) 26 | if os.path.exists(video_folder): 27 | dir_files = os.listdir(video_folder) 28 | output.append('%s %d %d' % (curFolder, len(dir_files), curIDX)) 29 | print('%d/%d' % (i, len(folders))) 30 | else: 31 | print(f'video {video_folder} does not exist: skipping') 32 | 33 | with open(filename_output, 'w') as f: 34 | f.write('\n'.join(output)) 35 | -------------------------------------------------------------------------------- /evaluate_model.sh: -------------------------------------------------------------------------------- 1 | RAND=$((RANDOM)) 2 | 3 | MODEL_DIR='./models/'$1 4 | LOG_NAME=$MODEL_DIR'/log_'$RAND 5 | 6 | mkdir $MODEL_DIR 7 | 8 | RESUME_CKPT='./checkpoints/dyreg_gnn_model_l2l3l4.pth.tar' 9 | 10 | args="--name=$1 --visualisation=hsv --weights=$RESUME_CKPT 11 | --npb --warmup_validate=False 12 | --offset_generator=big --dynamic_regions=dyreg 13 | --use_rstg=True --bottleneck_graph=True 14 | --rstg_skip_connection=False --remap_first=True 15 | --place_graph=layer2.2_layer3.4_layer4.1 16 | --ch_div=4 --graph_residual_type=norm 17 | --rnn_type=GRU --aggregation_type=dot --send_layers=1 18 | --update_layers=0 --combine_by_sum=True --project_at_input=True 19 | --tmp_norm_skip_conn=False --init_regions=center 20 | --lr 0.001 --batch-size 2 --dataset=somethingv2 --modality=RGB --arch resnet50 21 | --test_segments 16 --gd 20 --lr_steps 20 40 --epochs 60 -j 16 22 | --dropout 0.000000000000000001 --consensus_type=avg --eval-freq=1 23 | --shift --shift_div=8 --shift_place=blockres --model_dir=$MODEL_DIR 24 | --full_size_224=False --full_res=False --save_kernels=True" 25 | 26 | 27 | echo $args > $MODEL_DIR'/args_'$RAND 28 | 29 | # run on CPU: 30 | # CUDA_VISIBLE_DEVICES="" python -u test_models.py $args & tee $LOG_NAME 31 | 32 | # run on GPU 33 | CUDA_VISIBLE_DEVICES=1 python -u -m torch.distributed.launch --nproc_per_node=1 --master_port=6004 test_models.py $args |& tee $LOG_NAME 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /evaluate_model_multi_clip.sh: -------------------------------------------------------------------------------- 1 | RAND=$((RANDOM)) 2 | 3 | MODEL_DIR='./models/'$1 4 | LOG_NAME=$MODEL_DIR'/log_'$RAND 5 | 6 | mkdir $MODEL_DIR 7 | 8 | RESUME_CKPT='./checkpoints/dyreg_gnn_model_l2l3l4.pth.tar' 9 | 10 | args="--name=$1 --visualisation=hsv --weights=$RESUME_CKPT 11 | --npb --warmup_validate=False 12 | --offset_generator=big --dynamic_regions=dyreg 13 | --use_rstg=True --bottleneck_graph=True 14 | --rstg_skip_connection=False --remap_first=True 15 | --place_graph=layer2.2_layer3.4_layer4.1 16 | --ch_div=4 --graph_residual_type=norm 17 | --rnn_type=GRU --aggregation_type=dot --send_layers=1 18 | --update_layers=0 --combine_by_sum=True --project_at_input=True 19 | --tmp_norm_skip_conn=False --init_regions=center 20 | --lr 0.001 --batch-size 2 --dataset=somethingv2 --modality=RGB --arch resnet50 21 | --test_segments 16 --gd 20 --lr_steps 20 40 --epochs 60 -j 16 22 | --dropout 0.000000000000000001 --consensus_type=avg --eval-freq=1 23 | --shift --shift_div=8 --shift_place=blockres --model_dir=$MODEL_DIR 24 | --full_size_224=False --full_res=True --test_crops=3 --twice_sample" 25 | 26 | 27 | echo $args > $MODEL_DIR'/args_'$RAND 28 | 29 | # run on CPU: 30 | #CUDA_VISIBLE_DEVICES="" python -u test_models.py $args & tee $LOG_NAME 31 | 32 | # run on GPU 33 | CUDA_VISIBLE_DEVICES=0 python -u -m torch.distributed.launch --nproc_per_node=1 --master_port=6004 test_models.py $args |& tee $LOG_NAME 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /tools/vid2img_sthv2.py: -------------------------------------------------------------------------------- 1 | # Code adapted from "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # Ji Lin*, Chuang Gan, Song Han 3 | 4 | import os 5 | import threading 6 | import pdb 7 | 8 | NUM_THREADS = 12 9 | VIDEO_ROOT = './data/smt-smt-V2/smt-smt-V2-videos/20bn-something-something-v2/' 10 | FRAME_ROOT = './data/smt-smt-V2/smt-smt-V2-frames/' 11 | 12 | 13 | def split(l, n): 14 | """Yield successive n-sized chunks from l.""" 15 | for i in range(0, len(l), n): 16 | yield l[i:i + n] 17 | 18 | 19 | def extract(video, tmpl='%06d.jpg'): 20 | cmd = 'ffmpeg -i \"{}/{}\" -threads 1 -vf scale=-1:256 -q:v 0 \"{}/{}/%06d.jpg\"'.format(VIDEO_ROOT, video, 21 | FRAME_ROOT, video[:-5]) 22 | os.system(cmd) 23 | 24 | 25 | def target(video_list): 26 | for video in video_list: 27 | video_path = os.path.join(FRAME_ROOT, video[:-5]) 28 | if not os.path.exists(video_path): 29 | os.makedirs(os.path.join(FRAME_ROOT, video[:-5])) 30 | extract(video) 31 | else: 32 | dir_files = os.listdir(os.path.join(FRAME_ROOT, video[:-5])) 33 | if len(dir_files) <= 10 and len(dir_files) != 0: 34 | print(f'folder {video} has only {len(dir_files)} frames') 35 | extract(video) 36 | 37 | 38 | 39 | if __name__ == '__main__': 40 | if not os.path.exists(VIDEO_ROOT): 41 | raise ValueError('Please download videos and set VIDEO_ROOT variable.') 42 | if not os.path.exists(FRAME_ROOT): 43 | os.makedirs(FRAME_ROOT) 44 | 45 | video_list = os.listdir(VIDEO_ROOT) 46 | splits = list(split(video_list, NUM_THREADS)) 47 | 48 | threads = [] 49 | for i, split in enumerate(splits): 50 | thread = threading.Thread(target=target, args=(split,)) 51 | thread.start() 52 | threads.append(thread) 53 | 54 | for thread in threads: 55 | thread.join() -------------------------------------------------------------------------------- /tools/gen_label_sthv2.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | # ------------------------------------------------------ 6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py 7 | # processing the raw data of the video Something-Something-V2 8 | 9 | import os 10 | import json 11 | DATA_UTILS_ROOT='./data/smt-smt-V2/tsm_data/' 12 | FRAME_ROOT='./data/smt-smt-V2/smt-smt-V2-frames/' 13 | 14 | if __name__ == '__main__': 15 | dataset_name = DATA_UTILS_ROOT+'something-something-v2' 16 | with open('%s-labels.json' % dataset_name) as f: 17 | data = json.load(f) 18 | categories = [] 19 | for i, (cat, idx) in enumerate(data.items()): 20 | assert i == int(idx) # make sure the rank is right 21 | categories.append(cat) 22 | 23 | with open(DATA_UTILS_ROOT+'category.txt', 'w') as f: 24 | f.write('\n'.join(categories)) 25 | 26 | dict_categories = {} 27 | for i, category in enumerate(categories): 28 | dict_categories[category] = i 29 | 30 | files_input = ['%s-validation.json' % dataset_name, '%s-train.json' % dataset_name, '%s-test.json' % dataset_name] 31 | files_output = ['val_videofolder.txt', 'train_videofolder.txt', 'test_videofolder.txt'] 32 | for (filename_input, filename_output) in zip(files_input, files_output): 33 | with open(filename_input) as f: 34 | data = json.load(f) 35 | folders = [] 36 | idx_categories = [] 37 | for item in data: 38 | folders.append(item['id']) 39 | if 'test' not in filename_input: 40 | idx_categories.append(dict_categories[item['template'].replace('[', '').replace(']', '')]) 41 | else: 42 | idx_categories.append(0) 43 | output = [] 44 | for i in range(len(folders)): 45 | curFolder = folders[i] 46 | curIDX = idx_categories[i] 47 | # counting the number of frames in each video folders 48 | video_folder = os.path.join(FRAME_ROOT, curFolder) 49 | if os.path.exists(video_folder): 50 | dir_files = os.listdir(video_folder) 51 | output.append('%s %d %d' % (curFolder, len(dir_files), curIDX)) 52 | print('%d/%d' % (i, len(folders))) 53 | else: 54 | print(f'video {video_folder} does not exist: skipping') 55 | 56 | with open(DATA_UTILS_ROOT+filename_output, 'w') as f: 57 | f.write('\n'.join(output)) 58 | -------------------------------------------------------------------------------- /tools/gen_label_sthv1.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | # ------------------------------------------------------ 6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py 7 | # processing the raw data of the video Something-Something-V1 8 | 9 | import os 10 | 11 | if __name__ == '__main__': 12 | dataset_name = '/data/datasets/something-something/something-something-v1' # 'jester-v1''something-something-v1' # 'jester-v1' 13 | with open('%s-labels.csv' % dataset_name) as f: 14 | lines = f.readlines() 15 | categories = [] 16 | for line in lines: 17 | line = line.rstrip() 18 | categories.append(line) 19 | categories = sorted(categories) 20 | with open('smtv1_category.txt', 'w') as f: 21 | f.write('\n'.join(categories)) 22 | 23 | dict_categories = {} 24 | for i, category in enumerate(categories): 25 | dict_categories[category] = i 26 | 27 | files_input = ['%s-validation.csv' % dataset_name, '%s-train.csv' % dataset_name] 28 | files_output = ['smtv1-val_videofolder.txt', 'smtv1-train_videofolder.txt'] 29 | for (filename_input, filename_output) in zip(files_input, files_output): 30 | if 'val' in filename_output: 31 | split = 'valid' 32 | elif 'train' in filename_output: 33 | split = 'train' 34 | with open(filename_input) as f: 35 | lines = f.readlines() 36 | folders = [] 37 | idx_categories = [] 38 | for line in lines: 39 | line = line.rstrip() 40 | items = line.split(';') 41 | folders.append(items[0]) 42 | idx_categories.append(dict_categories[items[1]]) 43 | output = [] 44 | for i in range(len(folders)): 45 | curFolder = folders[i] 46 | curIDX = idx_categories[i] 47 | # counting the number of frames in each video folders 48 | video_folder = os.path.join(f'/data/datasets/something-something/20bn-something-something-v1/{split}/', curFolder) 49 | 50 | # dir_files = os.listdir(os.path.join('../img', curFolder)) 51 | # output.append('%s %d %d' % ('something/v1/img/' + curFolder, len(dir_files), curIDX)) 52 | # print('%d/%d' % (i, len(folders))) 53 | if os.path.exists(video_folder): 54 | dir_files = os.listdir(video_folder) 55 | output.append('%s %d %d' % (curFolder, len(dir_files), curIDX)) 56 | print('%d/%d' % (i, len(folders))) 57 | else: 58 | print(f'video {video_folder} does not exist: skipping') 59 | 60 | with open(filename_output, 'w') as f: 61 | f.write('\n'.join(output)) 62 | -------------------------------------------------------------------------------- /tools/vid2img_other.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import os 7 | import threading 8 | import pdb 9 | 10 | NUM_THREADS = 12 11 | # VIDEO_ROOT = '/ssd/video/something/v2/20bn-something-something-v2' # Downloaded webm videos 12 | # FRAME_ROOT = '/ssd/video/something/v2/20bn-something-something-v2-frames' # Directory for extracted frames 13 | # VIDEO_ROOT = '/data/datasets/smt-smt-V2/20bn-something-something-v2' 14 | # FRAME_ROOT = '/data/datasets/smt-smt-V2/20bn-something-something-v2-frames' 15 | 16 | num_sec = 5 17 | VIDEO_ROOT = '/data/datasets/in_the_wild/gifs' 18 | FRAME_ROOT = f'/data/datasets/in_the_wild/gifs-frames-{num_sec}s' 19 | 20 | # VIDEO_ROOT = '/data/datasets/in_the_wild/dataset_imar' 21 | # FRAME_ROOT = f'/data/datasets/in_the_wild/dataset_imar-frames-{num_sec}s' 22 | 23 | 24 | def split(l, n): 25 | """Yield successive n-sized chunks from l.""" 26 | for i in range(0, len(l), n): 27 | yield l[i:i + n] 28 | 29 | 30 | def extract(video, tmpl='%06d.jpg'): 31 | # os.system(f'ffmpeg -i {VIDEO_ROOT}/{video} -vf -threads 1 -vf scale=-1:256 -q:v 0 ' 32 | # f'{FRAME_ROOT}/{video[:-5]}/{tmpl}') 33 | # cmd0 = 'ffmpeg -i \"{}/{}\" -threads 1 -vf scale=-1:256 -q:v 0 \"{}/{}/%06d.jpg\"'.format(VIDEO_ROOT, video, 34 | # FRAME_ROOT, video[:-5]) 35 | 36 | cmd = f'ffmpeg -t 00:0{num_sec} ' 37 | cmd = cmd + '-i \"{}/{}\" -threads 1 -vf scale=-1:256 -q:v 0 \"{}/{}/%06d.jpg\"'.format(VIDEO_ROOT, video, 38 | FRAME_ROOT, video[:-5]) 39 | 40 | 41 | os.system(cmd) 42 | 43 | 44 | def target(video_list): 45 | for video in video_list: 46 | video_path = os.path.join(FRAME_ROOT, video[:-5]) 47 | if not os.path.exists(video_path): 48 | #print(f'video {video_path} does not exists') 49 | os.makedirs(os.path.join(FRAME_ROOT, video[:-5])) 50 | extract(video) 51 | else: 52 | dir_files = os.listdir(os.path.join(FRAME_ROOT, video[:-5])) 53 | if len(dir_files) <= 10: 54 | print(f'folder {video} has only {len(dir_files)} frames') 55 | extract(video) 56 | 57 | 58 | 59 | if __name__ == '__main__': 60 | if not os.path.exists(VIDEO_ROOT): 61 | raise ValueError('Please download videos and set VIDEO_ROOT variable.') 62 | if not os.path.exists(FRAME_ROOT): 63 | os.makedirs(FRAME_ROOT) 64 | 65 | video_list = os.listdir(VIDEO_ROOT) 66 | splits = list(split(video_list, NUM_THREADS)) 67 | 68 | threads = [] 69 | for i, split in enumerate(splits): 70 | #target(split) 71 | thread = threading.Thread(target=target, args=(split,)) 72 | thread.start() 73 | threads.append(thread) 74 | 75 | for thread in threads: 76 | thread.join() -------------------------------------------------------------------------------- /create_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from ops.dyreg import DynamicGraph, dyregParams 4 | 5 | class SpaceTimeModel(torch.nn.Module): 6 | def __init__(self): 7 | super(SpaceTimeModel, self).__init__() 8 | dyreg_params = dyregParams() 9 | dyregParams.offset_lstm_dim = 32 10 | self.dyreg = DynamicGraph(dyreg_params, 11 | backbone_dim=32, node_dim=32, out_num_ch=32, 12 | H=16, W=16, 13 | iH=16, iW=16, 14 | project_i3d=False, 15 | name='lalalal') 16 | 17 | 18 | self.fc = torch.nn.Linear(32, 10) 19 | 20 | def forward(self, x): 21 | dx = self.dyreg(x) 22 | # you can initialize the dyreg branch as identity function by normalisation, 23 | # as done in DynamicGraphWrapper found in ./ops/dyreg.py 24 | x = x + dx 25 | # average over time and space: T, H, W 26 | x = x.mean(-1).mean(-1).mean(-2) 27 | x = self.fc(x) 28 | return x 29 | 30 | 31 | 32 | class ConvSpaceTimeModel(torch.nn.Module): 33 | def __init__(self): 34 | super(ConvSpaceTimeModel, self).__init__() 35 | dyreg_params = dyregParams() 36 | dyregParams.offset_lstm_dim = 32 37 | self.conv1 = torch.nn.Conv3d(in_channels=3, out_channels=32, kernel_size=[1,3,3], stride=[1,2,2],padding=[0,1,1]) 38 | self.conv2 = torch.nn.Conv3d(in_channels=32, out_channels=32, kernel_size=[1,3,3], stride=[1,2,2],padding=[0,1,1]) 39 | self.dyreg = DynamicGraph(dyreg_params, 40 | backbone_dim=32, node_dim=32, out_num_ch=32, 41 | H=16, W=16, 42 | iH=16, iW=16, 43 | project_i3d=False, 44 | name='lalalal') 45 | 46 | self.conv3 = torch.nn.Conv3d(in_channels=32, out_channels=32, kernel_size=[1,3,3], stride=[1,2,2],padding=[0,1,1]) 47 | self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) 48 | self.fc = torch.nn.Linear(32, 10) 49 | 50 | def forward(self, x): 51 | x = F.relu(self.conv1(x)) 52 | x = F.relu(self.conv2(x)) 53 | input = x.permute(0,2,1,3,4).contiguous() 54 | dx = self.dyreg(input) 55 | dx = dx.permute(0,2,1,3,4).contiguous() 56 | # you can initialize the dyreg branch as identity function by normalisation, 57 | # as done in DynamicGraphWrapper found in ./ops/dyreg.py 58 | x = x + dx 59 | x = F.relu(self.conv3(x)) 60 | # average over time and space: T, H, W 61 | x = x.mean(-1).mean(-1).mean(-1) 62 | x = self.fc(x) 63 | return x 64 | 65 | B = 8 66 | T = 10 67 | C = 32 68 | H = 16 69 | W = 16 70 | x = torch.ones(B,T,C,H,W) 71 | st_model = SpaceTimeModel() 72 | out1 = st_model(x) 73 | 74 | 75 | B = 8 76 | T = 10 77 | C = 3 78 | H = 64 79 | W = 64 80 | x = torch.ones(B,C,T,H,W) 81 | conv_st_model = ConvSpaceTimeModel() 82 | out2 = conv_st_model(x) 83 | [print(f'{k} {v.shape}') for k,v in conv_st_model.named_parameters()] 84 | 85 | out = out1 + out2 86 | print('done') -------------------------------------------------------------------------------- /ops/generate_scripts.py: -------------------------------------------------------------------------------- 1 | 2 | import itertools 3 | 4 | var_params = {} 5 | 6 | # var_params['aux_loss'] = ['inter_videos_all', 'inter_videos_content', 'inter_nodes_all', 'inter_nodes_content'] 7 | # var_params['aux_loss'] = ['global_foreground_pos_background_neg', 8 | # 'global_foreground_pos_foreground_neg', 9 | # 'global_foreground_pos_foreground_background_neg', 10 | # 'global_foreground_pos_foreground_background_neg_rand'] 11 | 12 | 13 | var_params['contrastive_mlp'] = [False] 14 | 15 | var_params['aux_loss'] = ['nodes_temporal_matching'] 16 | var_params['contrastive_alpha'] = [0.1, 1.0] 17 | var_params['contrastive_temp'] = [0.1, 0.5, 1.0, 2.0] 18 | # var_params['contrastive_temp'] = [0.05] 19 | 20 | # var_params['place_graph'] = ['layer2.1_layer3.1_layer4.1'] 21 | # var_params['place_graph'] = ['layer3.1'] 22 | 23 | somelists = [val for k,val in var_params.items()] 24 | 25 | 26 | for element in itertools.product(*somelists): 27 | print('PARAMS') 28 | args = '' 29 | keys = list(var_params.keys()) 30 | 31 | name_model = 'model_contrastive' 32 | for i,k in enumerate(keys): 33 | name_model += f'_{k}_{element[i]}' 34 | args += f' --{k}={element[i]}' 35 | args += f' --name={name_model}' 36 | # print(args) 37 | 38 | args = args + f" --place_graph=layer2.1_layer3.1_layer4.1 --contrastive_type=nodes_temporal --rstg_combine=plus --node_confidence=none --use_detector=False --compute_iou=False --tmp_fix_kernel_grads=False --create_order=False --distill=False --distributed=True --bottleneck_graph=False --scaled_distill=False --alpha_distill=1.0 --glore_pos_emb_type=sinusoidal --npb --offset_generator=fishnet --predefined_augm=False --freeze_backbone=False --init_skip_zero=False --warmup_validate=False --smt_split=classic --use_rstg=True --tmp_init_different=0.0 --rstg_skip_connection=False --remap_first=True --isolate_graphs=False --dynamic_regions=constrained_fix_size --ch_div=4 --graph_residual_type=norm --rnn_type=GRU --aggregation_type=dot --send_layers=1 --update_layers=0 --combine_by_sum=True --project_at_input=True --tmp_norm_after_project_at_input=True --tmp_norm_skip_conn=False --tmp_norm_before_dynamic_graph=False --init_regions=center --tmp_increment_reg=True --lr 0.01 --batch-size 14 --dataset=cater --modality=RGB --arch=resnet18 --num_segments 13 --gd 20 --lr_type=step --lr_steps 75 125 200 --epochs 150 -j 16 --dropout=0.5 --consensus_type=avg --eval-freq=1 --shift_div=8 --shift --shift_place=blockres --model_dir=$MODEL_DIR" 39 | 40 | 41 | script_name = f'auto_scripts/start_' + name_model + '.sh' 42 | 43 | with open(script_name, 'w') as f: 44 | f.write('RAND=$((RANDOM))\n') 45 | f.write(f'MODEL_DIR=\'/models/graph_models/models_cater_dynamic_pytorch/exp_contrastive_global/{name_model}\'\n') 46 | 47 | f.write('LOG_NAME=$MODEL_DIR\'/log_\'$RAND\n') 48 | f.write('mkdir $MODEL_DIR\n') 49 | f.write('\n') 50 | f.write(f'args="{args}"\n') 51 | f.write('\n') 52 | 53 | 54 | f.write('cp ./train.py $MODEL_DIR\'/train.py_\'$RAND\n') 55 | f.write('mkdir $MODEL_DIR\'/code/\'\n') 56 | f.write('mkdir $MODEL_DIR\'/code/ops_\'$RAND\'/\'\n') 57 | f.write('cp -r ./ops/ $MODEL_DIR\'/code/ops_\'$RAND\'/\'\n') 58 | f.write('cp -r ./main.py $MODEL_DIR\'/code/\'\n') 59 | f.write('echo $args > $MODEL_DIR\'/args_\'$RAND\n') 60 | 61 | f.write('\n') 62 | f.write('\n') 63 | 64 | f.write(f'CUDA_VISIBLE_DEVICES=$1 python -u -m torch.distributed.launch --nproc_per_node=1 --master_port=$2 main_contrastive.py $args |& tee $LOG_NAME\n') 65 | -------------------------------------------------------------------------------- /ops/dataset_config.py: -------------------------------------------------------------------------------- 1 | # Code adapted from "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # Ji Lin*, Chuang Gan, Song Han 3 | 4 | 5 | import os 6 | 7 | global args, best_prec1 8 | from opts import parse_args 9 | args = parse_args() 10 | 11 | if args.dataset == 'somethingv2': 12 | ROOT_DATASET = './data/smt-smt-V2/' 13 | elif args.dataset == 'something': 14 | ROOT_DATASET = './data/datasets/something-something/' 15 | elif args.dataset == 'others': 16 | ROOT_DATASET = './data/datasets/in_the_wild/' 17 | elif args.dataset == 'syncMNIST': 18 | ROOT_DATASET = './data/syncMNIST/' 19 | elif args.dataset == 'multiSyncMNIST': 20 | ROOT_DATASET = './data/multiSyncMNIST/' 21 | 22 | 23 | def return_something(modality): 24 | filename_categories = 'tsm_data/smtv1_category.txt' 25 | if modality == 'RGB': 26 | root_data = ROOT_DATASET + '/smt-smt-V2-frames/' 27 | filename_imglist_train = 'tsm_data/smtv1-train_videofolder.txt' 28 | filename_imglist_val = 'tsm_data/smtv1-val_videofolder.txt' 29 | prefix = '{:05d}.jpg' 30 | else: 31 | print('no such modality:'+modality) 32 | raise NotImplementedError 33 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 34 | 35 | 36 | def return_others(modality): 37 | filename_categories = 'categories.txt' 38 | 39 | root_data = ROOT_DATASET + '/gifs-frames-5s/' 40 | filename_imglist_train = 'gifs-frames-5s.txt' 41 | filename_imglist_val = 'gifs-frames-5s.txt' 42 | prefix = '{:06d}.jpg' 43 | 44 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 45 | 46 | 47 | 48 | def return_somethingv2(modality): 49 | if modality == 'RGB': 50 | filename_categories = 'tsm_data/category.txt' 51 | root_data = ROOT_DATASET + '/smt-smt-V2-frames/' 52 | filename_imglist_train = 'tsm_data/train_videofolder.txt' 53 | filename_imglist_val = 'tsm_data/val_videofolder.txt' 54 | prefix = '{:06d}.jpg' 55 | else: 56 | print('no such modality:'+modality) 57 | raise NotImplementedError 58 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 59 | 60 | 61 | 62 | def return_dataset(dataset, modality): 63 | if dataset == 'syncMNIST': 64 | n_class = 46 65 | test_dataset = '/data/datasets/video_mnist/sync_mnist_large_v2_split_test_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0/' 66 | train_dataset = '/data/datasets/video_mnist/sync_mnist_large_v2_split_train_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0/' 67 | return n_class, train_dataset, test_dataset, None, None 68 | elif dataset == 'multiSyncMNIST': 69 | n_class = 56 70 | test_dataset = args.test_dataset 71 | train_dataset = args.train_dataset 72 | return n_class, train_dataset, test_dataset, None, None 73 | dict_single = {'something': return_something, 74 | 'somethingv2': return_somethingv2, 75 | 'others' : return_others} 76 | if dataset in dict_single: 77 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality) 78 | else: 79 | raise ValueError('Unknown dataset '+dataset) 80 | 81 | file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train) 82 | file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val) 83 | 84 | if isinstance(file_categories, str): 85 | file_categories = os.path.join(ROOT_DATASET, file_categories) 86 | with open(file_categories) as f: 87 | lines = f.readlines() 88 | categories = [item.rstrip() for item in lines] 89 | else: # number of categories 90 | categories = [None] * file_categories 91 | n_class = len(categories) 92 | print('{}: {} classes'.format(dataset, n_class)) 93 | return n_class, file_imglist_train, file_imglist_val, root_data, prefix 94 | -------------------------------------------------------------------------------- /ops/non_local.py: -------------------------------------------------------------------------------- 1 | # Non-local block using embedded gaussian 2 | # Code from 3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class _NonLocalBlockND(nn.Module): 10 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 11 | super(_NonLocalBlockND, self).__init__() 12 | 13 | assert dimension in [1, 2, 3] 14 | 15 | self.dimension = dimension 16 | self.sub_sample = sub_sample 17 | 18 | self.in_channels = in_channels 19 | self.inter_channels = inter_channels 20 | 21 | if self.inter_channels is None: 22 | self.inter_channels = in_channels // 2 23 | if self.inter_channels == 0: 24 | self.inter_channels = 1 25 | 26 | if dimension == 3: 27 | conv_nd = nn.Conv3d 28 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 29 | bn = nn.BatchNorm3d 30 | elif dimension == 2: 31 | conv_nd = nn.Conv2d 32 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 33 | bn = nn.BatchNorm2d 34 | else: 35 | conv_nd = nn.Conv1d 36 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 37 | bn = nn.BatchNorm1d 38 | 39 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 40 | kernel_size=1, stride=1, padding=0) 41 | 42 | if bn_layer: 43 | self.W = nn.Sequential( 44 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 45 | kernel_size=1, stride=1, padding=0), 46 | bn(self.in_channels) 47 | ) 48 | nn.init.constant_(self.W[1].weight, 0) 49 | nn.init.constant_(self.W[1].bias, 0) 50 | else: 51 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 52 | kernel_size=1, stride=1, padding=0) 53 | nn.init.constant_(self.W.weight, 0) 54 | nn.init.constant_(self.W.bias, 0) 55 | 56 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 59 | kernel_size=1, stride=1, padding=0) 60 | 61 | if sub_sample: 62 | self.g = nn.Sequential(self.g, max_pool_layer) 63 | self.phi = nn.Sequential(self.phi, max_pool_layer) 64 | 65 | def forward(self, x): 66 | ''' 67 | :param x: (b, c, t, h, w) 68 | :return: 69 | ''' 70 | 71 | batch_size = x.size(0) 72 | 73 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 74 | g_x = g_x.permute(0, 2, 1) 75 | 76 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 77 | theta_x = theta_x.permute(0, 2, 1) 78 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 79 | f = torch.matmul(theta_x, phi_x) 80 | f_div_C = F.softmax(f, dim=-1) 81 | 82 | y = torch.matmul(f_div_C, g_x) 83 | y = y.permute(0, 2, 1).contiguous() 84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 85 | W_y = self.W(y) 86 | z = W_y + x 87 | 88 | return z 89 | 90 | 91 | class NONLocalBlock1D(_NonLocalBlockND): 92 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 93 | super(NONLocalBlock1D, self).__init__(in_channels, 94 | inter_channels=inter_channels, 95 | dimension=1, sub_sample=sub_sample, 96 | bn_layer=bn_layer) 97 | 98 | 99 | class NONLocalBlock2D(_NonLocalBlockND): 100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 101 | super(NONLocalBlock2D, self).__init__(in_channels, 102 | inter_channels=inter_channels, 103 | dimension=2, sub_sample=sub_sample, 104 | bn_layer=bn_layer) 105 | 106 | 107 | class NONLocalBlock3D(_NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NONLocalBlock3D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=3, sub_sample=sub_sample, 112 | bn_layer=bn_layer) 113 | 114 | 115 | class NL3DWrapper(nn.Module): 116 | def __init__(self, block, n_segment): 117 | super(NL3DWrapper, self).__init__() 118 | self.block = block 119 | self.nl = NONLocalBlock3D(block.bn3.num_features) 120 | self.n_segment = n_segment 121 | 122 | def forward(self, x): 123 | x = self.block(x) 124 | 125 | nt, c, h, w = x.size() 126 | x = x.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w 127 | x = self.nl(x) 128 | x = x.transpose(1, 2).contiguous().view(nt, c, h, w) 129 | return x 130 | 131 | 132 | def make_non_local(net, n_segment): 133 | import torchvision 134 | import archs 135 | if isinstance(net, torchvision.models.ResNet): 136 | net.layer2 = nn.Sequential( 137 | NL3DWrapper(net.layer2[0], n_segment), 138 | net.layer2[1], 139 | NL3DWrapper(net.layer2[2], n_segment), 140 | net.layer2[3], 141 | ) 142 | net.layer3 = nn.Sequential( 143 | NL3DWrapper(net.layer3[0], n_segment), 144 | net.layer3[1], 145 | NL3DWrapper(net.layer3[2], n_segment), 146 | net.layer3[3], 147 | NL3DWrapper(net.layer3[4], n_segment), 148 | net.layer3[5], 149 | ) 150 | else: 151 | raise NotImplementedError 152 | 153 | 154 | if __name__ == '__main__': 155 | from torch.autograd import Variable 156 | import torch 157 | 158 | sub_sample = True 159 | bn_layer = True 160 | 161 | img = Variable(torch.zeros(2, 3, 20)) 162 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 163 | out = net(img) 164 | print(out.size()) 165 | 166 | img = Variable(torch.zeros(2, 3, 20, 20)) 167 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 168 | out = net(img) 169 | print(out.size()) 170 | 171 | img = Variable(torch.randn(2, 3, 10, 20, 20)) 172 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 173 | out = net(img) 174 | print(out.size()) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Discovering Dynamic Salient Regions with Spatio-Temporal Graph Neural Networks 2 | This is the official code for DyReg model inroduced in [Discovering Dynamic Salient Regions with Spatio-Temporal Graph Neural Networks](https://arxiv.org/abs/2009.08427) 3 | 4 |
5 | 6 |
7 | 8 | ## Citation 9 | Please use the following BibTeX to cite our work. 10 | ``` 11 | @incollection{duta2021dynamic_dyreg_gnn_neurips2021, 12 | title = {Discovering Dynamic Salient Regions with Spatio-Temporal Graph 13 | Neural Networks}, 14 | author = {Duta, Iulia and Nicolicioiu, Andrei and Leordeanu, Marius}, 15 | booktitle = {Advances in Neural Information Processing Systems 34}, 16 | year = {2021} 17 | } 18 | 19 | @article{duta2020dynamic_dyreg, 20 | title = {Dynamic Regions Graph Neural Networks for Spatio-Temporal Reasoning}, 21 | author = {Duta, Iulia and Nicolicioiu, Andrei and Leordeanu, Marius}, 22 | journal = {NeurIPS 2020 Workshop on Object Representations for Learning and Reasoning}, 23 | year = {2020}, 24 | } 25 | ``` 26 | 27 | ## Requirements 28 | 29 | The code was developed using: 30 | 31 | - python 3.7 32 | - matplotlib 33 | - torch 1.7.1 34 | - script 35 | - pandas 36 | - torchvision 37 | - moviepy 38 | - ffmpeg 39 | 40 | 41 | ## Overview: 42 | The repository contains the Pytorch implementation of the DyReg-GNN model. 43 | The model is defined and trained in the following files: 44 | - [ops/dyreg.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/ops/dyreg.py) - code for our DyReg module 45 | - [ops/rstg.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/ops/rstg.py) - code for the Spatio-temporal GNN (RSTG) used to process the graph extracted using DyReg 46 | 47 | - [create_model.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/create_model.py) - two examples how to integrate the DyReg-GNN module inside an existing backbone 48 | - [main_standard.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/main_standard.py) - code to train a model on Smt-Smt dataset 49 | - [test_models.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/test_models.py) - code for multi-clip evaluation 50 | 51 | Scripts for preparing the data, training and testing the model: 52 | - [train_model.sh](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/train_model.sh) - example of script to train DyReg-GNN 53 | - [evaluate_model.sh](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/evaluate_model.sh) - example of script to evaluate on a single clip DyReg-GNN 54 | - [evaluate_model_multi_clip.sh](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/evaluate_model_multi_clip.sh) - example of script to evaluate on multiple clips DyReg-GNN 55 | 56 | - [tools/](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/tools) contains all the script used to prepare Smt-Smt dataset (similar to the setup used in TSM) 57 | 58 | 59 | ## Prepare dataset 60 | 61 | For [Something Something dataset](https://arxiv.org/pdf/1706.04261v2.pdf): 62 | * the json files containing meta-data should be stored in `./data/smt-smt-V2/tsm_data` 63 | * the zip files containing the videos should be stored in `./data/smt-smt-V2/` 64 | 65 | - - - - 66 | 67 | 1. **To extract the videos from the zip files run:** 68 | 69 | `cat 20bn-something-something-v2-?? | tar zx` 70 | 71 | 2. **To extract the frames from videos run:** 72 | 73 | `python tools/vid2img_sthv2.py` 74 | 75 | → The videos will be stored in *$FRAME_ROOT* (default `'./data/smt-smt-V2/tmp_smt-smt-V2-frames'`) 76 | 77 | 💡 *If you already have the dataset as frames, place them under `./data/smt-smt-V2/smt-smt-V2-frames/`, one folder for each video* \ 78 | 💡💡 *If you need to change the path for datasets modify *$ROOT_DATASET* in [dataset_config.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/ops/dataset_config.py)* 79 | 80 | 81 | 3. **To generate the labels file in the required format please run:** 82 | 83 | `python tools/gen_label_sthv2.py ` 84 | 85 | → The resulting txt files, for each split, will be stored in *$DATA_UTILS_ROOT* (default `'./data/smt-smt-V2/tsm_data/'`) 86 | 87 | 88 | ## How to run the model 89 | 90 | DyReg-GNN module can be simply inserted into any space-time model. 91 | ``` 92 | import torch 93 | from torch.nn import functional as F 94 | from ops.dyreg import DynamicGraph, dyregParams 95 | 96 | class SpaceTimeModel(torch.nn.Module): 97 | def __init__(self): 98 | super(SpaceTimeModel, self).__init__() 99 | dyreg_params = dyregParams() 100 | dyregParams.offset_lstm_dim = 32 101 | self.dyreg = DynamicGraph(dyreg_params, 102 | backbone_dim=32, node_dim=32, out_num_ch=32, 103 | H=16, W=16, 104 | iH=16, iW=16, 105 | project_i3d=False, 106 | name='lalalal') 107 | 108 | 109 | self.fc = torch.nn.Linear(32, 10) 110 | 111 | def forward(self, x): 112 | dx = self.dyreg(x) 113 | # you can initialize the dyreg branch as identity function by normalisation, 114 | # as done in DynamicGraphWrapper found in ./ops/dyreg.py 115 | x = x + dx 116 | # average over time and space: T, H, W 117 | x = x.mean(-1).mean(-1).mean(-2) 118 | x = self.fc(x) 119 | return x 120 | 121 | 122 | B = 8 123 | T = 10 124 | C = 32 125 | H = 16 126 | W = 16 127 | x = torch.ones(B,T,C,H,W) 128 | st_model = SpaceTimeModel() 129 | out = st_model(x) 130 | ``` 131 | 132 | 133 | For another example of how to integrate DyReg (DynamicGraph module) inside your model please look at [create_model.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/create_model.py) or run: 134 | 135 | `python create_model.py` 136 | 137 | 138 | ### Something-Something experiments 139 | 140 | #### Training a model 141 | 142 | To train a model on smt-smt v2 dataset please run 143 | 144 | `./start_main_standard.sh model_name` 145 | 146 | For default hyperparameters check opts.py. For example, `place_graph` flag controls how many DyReg-GNN modules to use and where to place them inside the backbone: 147 | ``` 148 | # for a model with 3 DyReg-GNN modules placed after layer 2-block 2, layer 3-block 4 and layer 4-block 1 of the backbone 149 | --place_graph=layer2.2_layer3.4_layer4.1 150 | # for a model with 1 dyreg module placed after layer 3 block 4 of the backbone 151 | --place_graph=layer3.4 152 | ``` 153 | 154 | #### Single clip evaluation 155 | 156 | Train a model with the above script or download a pre-trained DyReg-GNN model from [here](https://drive.google.com/file/d/1MA_zfHSoutTVL6X9JxyIBLkX0rNJmC_6/view?usp=sharing) and put the checkpoint in ./ckeckpoints/ 157 | 158 | To evaluate a model on smt-smt v2 dataset on a single 224 x 224 central crop, run: 159 | 160 | `./start_main_standard_test.sh model_name` 161 | 162 | The flag `$RESUME_CKPT` indicate the the checkpoint used for evaluation. 163 | 164 | 165 | #### Multi clips evaluation 166 | 167 | To evaluate a model in the multi-clips setup (3 spatials clips x 2 temporal samplings) on Smt-Smt v2 dataset please run 168 | 169 | `./evaluate_model.sh model_name` 170 | 171 | The flag `$RESUME_CKPT` indicate the the checkpoint used for evaluation. 172 | 173 | ### TSM Baseline 174 | This repository adds DyReg-GNN modules to a TSM backbone based on code from [here](https://github.com/mit-han-lab/temporal-shift-module). 175 | -------------------------------------------------------------------------------- /ops/rstg.py: -------------------------------------------------------------------------------- 1 | 2 | # Code for RSTG model 3 | # Recurrent Space-time Graph Neural Networks - RSTG 4 | # adapted from https://github.com/IuliaDuta/RSTG 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | import sys 10 | import os 11 | sys.path.append(os.path.abspath(os.path.join( 12 | os.path.dirname(os.path.realpath(__file__)), '..'))) 13 | from opts import parser 14 | from ops.models_utils import * 15 | 16 | global args, best_prec1 17 | from opts import parse_args 18 | args = parser.parse_args() 19 | args = parse_args() 20 | import pdb 21 | 22 | 23 | class RSTG(nn.Module): 24 | def __init__(self,params, backbone_dim=1024, node_dim=512, project_i3d=True): 25 | 26 | super(RSTG, self).__init__() 27 | self.params = params 28 | self.backbone_dim = backbone_dim 29 | self.node_dim = node_dim 30 | self.number_iterations = 3 31 | self.num_nodes = 9 32 | self.project_i3d = project_i3d 33 | self.norm_dict = nn.ModuleDict({}) 34 | 35 | # intern LSTM 36 | self.internal_lstm = recurrent_net(batch_first=True, 37 | input_size=self.node_dim, hidden_size=self.node_dim) 38 | # extern LSTM 39 | self.external_lstm = recurrent_net(batch_first=True, 40 | input_size=self.node_dim, hidden_size=self.node_dim) 41 | 42 | # send function 43 | if self.params.send_layers == 1: 44 | self.send_mlp = nn.Sequential( 45 | torch.nn.Linear(self.node_dim, self.node_dim), 46 | torch.nn.ReLU() 47 | ) 48 | elif self.params.send_layers == 2: 49 | if self.params.combine_by_sum: 50 | comb_mult = 1 51 | else: 52 | comb_mult = 2 53 | self.send_mlp = nn.Sequential( 54 | torch.nn.Linear(comb_mult * self.node_dim, self.node_dim), 55 | torch.nn.ReLU(), 56 | torch.nn.Linear(self.node_dim, self.node_dim), 57 | torch.nn.ReLU() 58 | ) 59 | 60 | # norm send 61 | self.norm_dict['send_message_norm'] = LayerNormAffineXC( 62 | self.node_dim, (self.num_nodes,self.num_nodes,self.node_dim) 63 | ) 64 | self.norm_dict['update_norm'] = LayerNormAffineXC( 65 | self.node_dim, (self.num_nodes, self.node_dim) 66 | ) 67 | self.norm_dict['before_graph_norm'] = LayerNormAffineXC( 68 | self.node_dim, (self.num_nodes, self.node_dim) 69 | ) 70 | 71 | 72 | 73 | # attention function 74 | if 'dot' in self.params.aggregation_type: 75 | self.att_q = torch.nn.Linear(self.node_dim, self.node_dim) 76 | self.att_k = torch.nn.Linear(self.node_dim, self.node_dim) 77 | # attention bias 78 | self.att_bias = torch.nn.Parameter( 79 | torch.zeros(size=[1,1,1,self.node_dim], requires_grad=True) 80 | ) 81 | 82 | # update function 83 | if self.params.update_layers == 0: 84 | self.update_mlp = nn.Identity() 85 | elif self.params.update_layers == 1: 86 | self.update_mlp = nn.Sequential( 87 | torch.nn.Linear(self.node_dim, self.node_dim), 88 | torch.nn.ReLU() 89 | ) 90 | elif self.params.update_layers == 2: 91 | if self.params.combine_by_sum: 92 | comb_mult = 1 93 | else: 94 | comb_mult = 2 95 | self.update_mlp = nn.Sequential( 96 | torch.nn.Linear(comb_mult * self.node_dim, self.node_dim), 97 | torch.nn.ReLU(), 98 | torch.nn.Linear(self.node_dim, self.node_dim), 99 | torch.nn.ReLU() 100 | ) 101 | 102 | def get_norm(self, input, name, zero_init=False): 103 | # input: B x N x T x C 104 | # or B x N x N x T x C 105 | if len(input.size()) == 5: 106 | input = input.permute(0,3,1,2,4).contiguous() 107 | elif len(input.size()) == 4: 108 | input = input.permute(0,2,1,3).contiguous() 109 | 110 | norm = self.norm_dict[name] 111 | 112 | input = norm(input) 113 | if len(input.size()) == 5: 114 | input = input.permute(0,2,3,1,4).contiguous() 115 | elif len(input.size()) == 4: 116 | input = input.permute(0,2,1,3).contiguous() 117 | return input 118 | 119 | def send_messages(self, nodes): 120 | # nodes: B x N x T x C 121 | # nodes1: B x 1 x N x T x C 122 | # nodes2: B x N x 1 x T x C 123 | nodes1 = nodes.unsqueeze(1) 124 | nodes2 = nodes.unsqueeze(2) 125 | 126 | nodes1 = nodes1.repeat(1,self.num_nodes,1,1,1) 127 | nodes2 = nodes2.repeat(1,1,self.num_nodes,1,1) 128 | # B x N x N x T x 2C 129 | if self.params.combine_by_sum: 130 | messages = nodes1 + nodes2 131 | else: 132 | messages = torch.cat((nodes1, nodes2),dim=-1) 133 | messages = self.send_mlp(messages) 134 | messages = self.get_norm(messages, 'send_message_norm') 135 | return messages 136 | 137 | def aggregate(self, nodes, messages): 138 | if 'sum' in self.params.aggregation_type: 139 | return self.sum_aggregate(messages) 140 | elif 'dot' in self.params.aggregation_type: 141 | return self.dot_attention(nodes, messages) 142 | 143 | def sum_aggregate(self, messages): 144 | # nodes: B x N x T x C 145 | # messages: B x NxN x T x C 146 | return messages.mean(dim=2) 147 | 148 | def dot_attention(self, nodes, messages): 149 | # nodes: B x N x T x C 150 | # messages: B x NxN x T x C 151 | 152 | # nodes1: B x N x 1 x T x C 153 | # nodes2: B x 1 x N x T x C 154 | # corr B x N x N x T 155 | 156 | nodes_q = self.att_q(nodes) 157 | nodes_k = self.att_k(nodes) 158 | 159 | nodes_q = nodes_q.permute(0,2,1,3) 160 | nodes_k = nodes_k.permute(0,2,3,1) 161 | 162 | corr = torch.matmul(nodes_q, nodes_k).unsqueeze(-1) 163 | corr = corr.permute(0,2,3,1,4) 164 | 165 | nodes = F.softmax(corr, dim=2) * messages 166 | nodes = nodes.sum(dim=2) 167 | 168 | nodes = nodes + self.att_bias 169 | nodes = F.relu(nodes) 170 | return nodes 171 | 172 | def update_nodes(self, nodes, aggregated): 173 | if self.params.combine_by_sum: 174 | upd_input = nodes + aggregated 175 | else: 176 | upd_input = torch.cat((nodes, aggregated), dim=-1) 177 | nodes = self.update_mlp(upd_input) 178 | nodes = self.get_norm(nodes, 'update_norm') 179 | return nodes 180 | 181 | def forward(self, input): 182 | self.B = input.shape[0] 183 | self.T = input.shape[1] 184 | 185 | # input RSTG: B x T x C x H x W 186 | # set input: ... -> B x T x C x N 187 | 188 | # for LSTM we need (B * N) x T x C 189 | # propagation: B x 1 x N x T x C 190 | # + B x N x 1 x T x C 191 | # (B x N*N x T) x C => liniar 192 | # nodes: B x N x T x C 193 | nodes = input.permute(0,3,1,2) 194 | 195 | nodes = self.get_norm(nodes, 'before_graph_norm') 196 | time_iter_mom = [0, 1, 2] 197 | 198 | for space_iter in range(self.number_iterations): 199 | # internal time processing 200 | if space_iter in time_iter_mom: 201 | nodes = nodes.view(self.B * self.num_nodes, self.T, self.node_dim) 202 | self.internal_lstm.flatten_parameters() 203 | nodes, _ = self.internal_lstm(nodes) 204 | nodes = nodes.view(self.B, self.num_nodes, self.T, self.node_dim) 205 | 206 | # space_processing: send, aggregate, update 207 | messages = self.send_messages(nodes) 208 | aggregated = self.aggregate(nodes, messages) 209 | nodes = self.update_nodes(nodes, aggregated) 210 | 211 | # external time processing 212 | nodes = nodes.view(self.B * self.num_nodes, self.T, self.node_dim) 213 | self.external_lstm.flatten_parameters() 214 | nodes, _ = self.external_lstm(nodes) 215 | nodes = nodes.view(self.B, self.num_nodes, self.T, self.node_dim) 216 | 217 | 218 | # B x N x T x C -> B x T x C x N 219 | nodes = nodes.permute(0,2,3,1).contiguous() 220 | 221 | return nodes 222 | -------------------------------------------------------------------------------- /ops/models_config.py: -------------------------------------------------------------------------------- 1 | from opts import parser 2 | global args 3 | args = parser.parse_args() 4 | 5 | def return_config_resnet13(): 6 | graph_params = {} 7 | graph_params[1] = {'in_channels' : 64, 'H' : 32, 'node_dim' : 2 * 64 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'} 8 | graph_params[2] = {'in_channels' : 128, 'H' : 16, 'node_dim' : 128 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'} 9 | graph_params[3] = {'in_channels' : 256, 'H' : 8, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'} 10 | out_pool_size = 8 11 | # 256 12 | return graph_params, out_pool_size 13 | 14 | def return_config_resnet18(): 15 | graph_params = {} 16 | if args.bottleneck_graph: 17 | if args.full_res: 18 | graph_params[1] = {'in_channels' : 64 // args.ch_div, 'iH' : 56, 'H' : 64, 'node_dim' : 64 // args.ch_div, 'name': 'layer1'} 19 | graph_params[2] = {'in_channels' : 128// args.ch_div, 'iH' : 28, 'H' : 32, 'node_dim' : 128 // args.ch_div, 'name': 'layer2'} 20 | graph_params[3] = {'in_channels' : 256// args.ch_div, 'iH' : 14, 'H' : 16, 'node_dim' : 256 // args.ch_div, 'name': 'layer3'} 21 | graph_params[4] = {'in_channels' : 512// args.ch_div, 'iH' : 7, 'H' : 8, 'node_dim' : min(256,512 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 22 | out_pool_size = 8 23 | else: 24 | graph_params[1] = {'in_channels' : 64 // args.ch_div, 'H' : 56, 'iH' : 56, 'node_dim' : 64 // args.ch_div, 'name': 'layer1'} 25 | graph_params[2] = {'in_channels' : 128// args.ch_div, 'H' : 28, 'iH' : 28, 'node_dim' : 128 // args.ch_div, 'name': 'layer2'} 26 | graph_params[3] = {'in_channels' : 256// args.ch_div, 'H' : 14, 'iH' : 14, 'node_dim' : 256 // args.ch_div, 'name': 'layer3'} 27 | graph_params[4] = {'in_channels' : 512// args.ch_div, 'H' : 7, 'iH' : 7, 'node_dim' : min(256,512 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 28 | out_pool_size = 7 29 | 30 | for i in [1,2,3]: 31 | graph_params[i]['project_i3d'] = (args.offset_generator == 'glore') #True for glore, False for fishnet 32 | else: 33 | graph_params[1] = {'in_channels' : 64, 'H' : 56, 'node_dim' : 64 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'} 34 | graph_params[2] = {'in_channels' : 128, 'H' : 28, 'node_dim' : 128 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'} 35 | graph_params[3] = {'in_channels' : 256, 'H' : 14, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'} 36 | graph_params[4] = {'in_channels' : 512, 'H' : 7, 'node_dim' : min(512,512 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 37 | out_pool_size = 7 38 | 39 | return graph_params, out_pool_size 40 | 41 | def return_config_resnet34(): 42 | graph_params = {} 43 | 44 | graph_params[1] = {'in_channels' : 64, 'H' : 56, 'node_dim' : 64 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'} 45 | graph_params[2] = {'in_channels' : 128, 'H' : 28, 'node_dim' : 128 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'} 46 | graph_params[3] = {'in_channels' : 256, 'H' : 14, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'} 47 | graph_params[4] = {'in_channels' : 512, 'H' : 7, 'node_dim' : min(512,512 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 48 | out_pool_size = 7 49 | 50 | return graph_params, out_pool_size 51 | 52 | 53 | 54 | def return_config_wide_resnet50_2(): 55 | graph_params = {} 56 | 57 | graph_params[1] = {'in_channels' : 512 // args.ch_div, 'H' : 56, 'iH' : 56, 'node_dim' : 512 // args.ch_div, 'name': 'layer1'} 58 | graph_params[2] = {'in_channels' : 1024// args.ch_div, 'H' : 28, 'iH' : 28, 'node_dim' : 1024 // args.ch_div, 'name': 'layer2'} 59 | graph_params[3] = {'in_channels' : 2048// args.ch_div, 'H' : 14, 'iH' : 14, 'node_dim' : 2048 // args.ch_div, 'name': 'layer3'} 60 | graph_params[4] = {'in_channels' : 4096// args.ch_div, 'H' : 7, 'iH' : 7, 'node_dim' : min(512,4096 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 61 | 62 | 63 | out_pool_size = 7 64 | 65 | for i in [1,2,3]: 66 | graph_params[i]['project_i3d'] = (args.offset_generator == 'glore') #True for glore, False for fishnet 67 | 68 | 69 | return graph_params, out_pool_size 70 | def return_config_resnet50(): 71 | graph_params = {} 72 | out_pool_size = 0 73 | if args.bottleneck_graph: 74 | 75 | if args.full_res: 76 | graph_params[1] = {'in_channels' : 256 // args.ch_div, 'iH' : 56, 'H' : 64, 'node_dim' : 256 // args.ch_div, 'name': 'layer1'} 77 | graph_params[2] = {'in_channels' : 512// args.ch_div, 'iH' : 28, 'H' : 32, 'node_dim' : 512 // args.ch_div, 'name': 'layer2'} 78 | graph_params[3] = {'in_channels' : 1024// args.ch_div, 'iH' : 14, 'H' : 16, 'node_dim' : 1024 // args.ch_div, 'name': 'layer3'} 79 | graph_params[4] = {'in_channels' : 2048// args.ch_div, 'iH' : 7, 'H' : 8, 'node_dim' : min(256,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 80 | out_pool_size = 8 81 | 82 | else: 83 | graph_params[1] = {'in_channels' : 256 // args.ch_div, 'H' : 56, 'iH' : 56, 'node_dim' : 256 // args.ch_div, 'name': 'layer1'} 84 | graph_params[2] = {'in_channels' : 512// args.ch_div, 'H' : 28, 'iH' : 28, 'node_dim' : 512 // args.ch_div, 'name': 'layer2'} 85 | graph_params[3] = {'in_channels' : 1024// args.ch_div, 'H' : 14, 'iH' : 14, 'node_dim' : 1024 // args.ch_div, 'name': 'layer3'} 86 | graph_params[4] = {'in_channels' : 2048// args.ch_div, 'H' : 7, 'iH' : 7, 'node_dim' : min(256,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 87 | out_pool_size = 7 88 | 89 | for i in [1,2,3]: 90 | graph_params[i]['project_i3d'] = (args.offset_generator == 'glore') #True for glore, False for fishnet 91 | 92 | 93 | else: 94 | # TO BE REEMOVED 95 | graph_params[1] = {'in_channels' : 256, 'H' : 56, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'} 96 | graph_params[2] = {'in_channels' : 512, 'H' : 28, 'node_dim' : 512 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'} 97 | graph_params[3] = {'in_channels' : 1024, 'H' : 14, 'node_dim' : 1024 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'} 98 | graph_params[4] = {'in_channels' : 2048, 'H' : 7, 'node_dim' : min(512,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 99 | 100 | return graph_params, out_pool_size 101 | 102 | def return_config_resnet101(): 103 | graph_params = {} 104 | if args.bottleneck_graph: 105 | 106 | graph_params[1] = {'in_channels' : 256 // args.ch_div, 'H' : 56, 'iH' : 56, 'node_dim' : 256 // args.ch_div, 'name': 'layer1'} 107 | graph_params[2] = {'in_channels' : 512// args.ch_div, 'H' : 28, 'iH' : 28, 'node_dim' : 512 // args.ch_div, 'name': 'layer2'} 108 | graph_params[3] = {'in_channels' : 1024// args.ch_div, 'H' : 14, 'iH' : 14, 'node_dim' : 1024 // args.ch_div, 'name': 'layer3'} 109 | graph_params[4] = {'in_channels' : 2048// args.ch_div, 'H' : 7, 'iH' : 7, 'node_dim' : min(256,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 110 | out_pool_size = 7 111 | 112 | for i in [1,2,3]: 113 | graph_params[i]['project_i3d'] = (args.offset_generator == 'glore') #True for glore, False for fishnet 114 | else: 115 | graph_params[1] = {'in_channels' : 256, 'H' : 56, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'} 116 | graph_params[2] = {'in_channels' : 512, 'H' : 28, 'node_dim' : 512 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'} 117 | graph_params[3] = {'in_channels' : 1024, 'H' : 14, 'node_dim' : 1024 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'} 118 | graph_params[4] = {'in_channels' : 2048, 'H' : 7, 'node_dim' : min(512,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'} 119 | out_pool_size = 7 120 | if args.full_res: 121 | out_pool_size = 8 122 | return graph_params, out_pool_size 123 | 124 | 125 | 126 | 127 | 128 | def get_models_config(): 129 | graph_params = {} 130 | if args.arch == 'resnet13': 131 | graph_params, out_pool_size = return_config_resnet13() 132 | if args.arch == 'resnet18': 133 | graph_params, out_pool_size = return_config_resnet18() 134 | elif args.arch == 'resnet34': 135 | graph_params, out_pool_size = return_config_resnet34() 136 | elif args.arch == 'resnet50': 137 | graph_params, out_pool_size = return_config_resnet50() 138 | elif args.arch == 'wide_resnet50_2': 139 | graph_params, out_pool_size = return_config_wide_resnet50_2() 140 | elif args.arch == 'resnet101': 141 | graph_params, out_pool_size = return_config_resnet101() 142 | out_num_ch = 2048 143 | return graph_params, out_pool_size, out_num_ch, None -------------------------------------------------------------------------------- /ops/dataset.py: -------------------------------------------------------------------------------- 1 | # Code adapted from "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # Ji Lin*, Chuang Gan, Song Han 3 | 4 | 5 | import torch.utils.data as data 6 | 7 | from PIL import Image 8 | import os 9 | import numpy as np 10 | from numpy.random import randint 11 | global args, best_prec1 12 | from opts import parse_args 13 | args = parse_args() 14 | 15 | class VideoRecord(object): 16 | def __init__(self, row): 17 | self._data = row 18 | 19 | @property 20 | def path(self): 21 | return self._data[0] 22 | 23 | @property 24 | def num_frames(self): 25 | return int(self._data[1]) 26 | 27 | @property 28 | def label(self): 29 | return int(self._data[2]) 30 | 31 | 32 | class TSNDataSet(data.Dataset): 33 | def __init__(self, root_path, list_file, 34 | num_segments=3, new_length=1, modality='RGB', 35 | image_tmpl='img_{:05d}.jpg', transform=None, 36 | random_shift=True, test_mode=False, 37 | remove_missing=False, dense_sample=False, twice_sample=False, split=''): 38 | 39 | self.root_path = root_path 40 | self.list_file = list_file 41 | self.num_segments = num_segments 42 | self.new_length = new_length 43 | self.modality = modality 44 | self.image_tmpl = image_tmpl 45 | self.transform = transform 46 | self.random_shift = random_shift 47 | self.test_mode = test_mode 48 | self.remove_missing = remove_missing 49 | self.dense_sample = dense_sample # using dense sample as I3D 50 | self.twice_sample = twice_sample # twice sample for more validation 51 | if self.dense_sample: 52 | print('=> Using dense sample for the dataset...') 53 | if self.twice_sample: 54 | print('=> Using twice sample for the dataset...') 55 | 56 | self.split = split 57 | self.epoch = -1 58 | self._parse_list() 59 | 60 | 61 | def set_epoch_attributes(self, epoch, augment_dict, detected_boxes_dict=None): 62 | self.epoch = epoch 63 | self.augment_dict = augment_dict 64 | 65 | print(self.epoch) 66 | 67 | def _load_image(self, directory, idx): 68 | try: 69 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')] 70 | except Exception: 71 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) 72 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')] 73 | 74 | 75 | def _parse_list(self): 76 | # check the frame number is large >3: 77 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 78 | if not self.test_mode or self.remove_missing: 79 | print(f'before removing missing: {len(tmp)}') 80 | tmp = [item for item in tmp if int(item[1]) >= 3] 81 | print(f'after removing missing: {len(tmp)}') 82 | 83 | self.video_list = [VideoRecord(item) for item in tmp] 84 | 85 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 86 | for v in self.video_list: 87 | v._data[1] = int(v._data[1]) / 2 88 | print('video number:%d' % (len(self.video_list))) 89 | 90 | def get_video_list(self): 91 | return self.video_list 92 | 93 | def _sample_indices(self, record): 94 | """ 95 | :param record: VideoRecord 96 | :return: list 97 | """ 98 | if self.dense_sample: # i3d dense sample 99 | sample_pos = max(1, 1 + record.num_frames - 64) 100 | t_stride = 64 // self.num_segments 101 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 102 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 103 | return np.array(offsets) + 1 104 | else: # normal sample 105 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments 106 | if average_duration > 0: 107 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, 108 | size=self.num_segments) 109 | elif record.num_frames > self.num_segments: 110 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments)) 111 | else: 112 | # shouldn here be a range? 113 | offsets = np.zeros((self.num_segments,)) 114 | return offsets + 1 115 | 116 | def _get_val_indices(self, record): 117 | if self.dense_sample: # i3d dense sample 118 | sample_pos = max(1, 1 + record.num_frames - 64) 119 | t_stride = 64 // self.num_segments 120 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 121 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 122 | return np.array(offsets) + 1 123 | else: 124 | if record.num_frames > self.num_segments + self.new_length - 1: 125 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 126 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 127 | else: 128 | # shouldn here be a range? 129 | offsets = np.zeros((self.num_segments,)) 130 | return offsets + 1 131 | 132 | def _get_test_indices(self, record): 133 | if self.dense_sample: 134 | sample_pos = max(1, 1 + record.num_frames - 64) 135 | t_stride = 64 // self.num_segments 136 | start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int) 137 | offsets = [] 138 | for start_idx in start_list.tolist(): 139 | offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 140 | return np.array(offsets) + 1 141 | elif self.twice_sample: 142 | 143 | 144 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 145 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] + 146 | [int(tick * x) for x in range(self.num_segments)]) 147 | 148 | return offsets + 1 149 | else: 150 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 151 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 152 | return offsets + 1 153 | 154 | 155 | 156 | def __getitem__(self, index): 157 | record = self.video_list[index] 158 | # check this is a legit video folder 159 | 160 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': 161 | file_name = self.image_tmpl.format('x', 1) 162 | full_path = os.path.join(self.root_path, record.path, file_name) 163 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 164 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) 165 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) 166 | else: 167 | file_name = self.image_tmpl.format(1) 168 | full_path = os.path.join(self.root_path, record.path, file_name) 169 | 170 | while not os.path.exists(full_path): 171 | print('################## Not Found:', os.path.join(self.root_path, record.path, file_name)) 172 | index = np.random.randint(len(self.video_list)) 173 | record = self.video_list[index] 174 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': 175 | file_name = self.image_tmpl.format('x', 1) 176 | full_path = os.path.join(self.root_path, record.path, file_name) 177 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 178 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) 179 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) 180 | else: 181 | file_name = self.image_tmpl.format(1) 182 | full_path = os.path.join(self.root_path, record.path, file_name) 183 | 184 | if not self.test_mode: 185 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) 186 | else: 187 | segment_indices = self._get_test_indices(record) 188 | crt_detected_boxes = 10 189 | return self.get(record, segment_indices, crt_detected_boxes) 190 | 191 | def get(self, record, indices, detected_boxes): 192 | images = list() 193 | temp_indices = list() 194 | for seg_ind in indices: 195 | p = int(seg_ind) 196 | for i in range(self.new_length): 197 | seg_imgs = self._load_image(record.path, p) 198 | 199 | temp_indices.extend([p]) 200 | images.extend(seg_imgs) 201 | if p < record.num_frames: 202 | p += 1 203 | 204 | process_data = self.transform(images) 205 | gt_box = np.zeros((16, 10,4)) 206 | return gt_box, process_data, record.label, 0 207 | 208 | 209 | def __len__(self): 210 | return len(self.video_list) -------------------------------------------------------------------------------- /ops/temporal_shift.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import pdb 10 | 11 | class TemporalShift(nn.Module): 12 | def __init__(self, net, n_segment=3, n_div=8, inplace=False): 13 | super(TemporalShift, self).__init__() 14 | self.net = net 15 | self.n_segment = n_segment 16 | self.fold_div = n_div 17 | self.inplace = inplace 18 | if inplace: 19 | print('=> Using in-place shift...') 20 | print('=> Using fold div: {}'.format(self.fold_div)) 21 | 22 | def forward(self, x): 23 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace) 24 | return self.net(x) 25 | 26 | @staticmethod 27 | def shift(x, n_segment, fold_div=3, inplace=False): 28 | nt, c, h, w = x.size() 29 | n_batch = nt // n_segment 30 | x = x.view(n_batch, n_segment, c, h, w) 31 | 32 | fold = c // fold_div 33 | if inplace: 34 | # Due to some out of order error when performing parallel computing. 35 | # May need to write a CUDA kernel. 36 | raise NotImplementedError 37 | # out = InplaceShift.apply(x, fold) 38 | else: 39 | # pdb.set_trace() 40 | # print(f'shift input mean: {x.mean()}') 41 | out = torch.zeros_like(x) 42 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left 43 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right 44 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift 45 | # print(f'shift out mean: {out.mean()}') 46 | # print(f'out shape {out.shape}') 47 | return out.view(nt, c, h, w) 48 | 49 | class MyIdentity(nn.Module): 50 | def __init__(self, net): 51 | super(MyIdentity, self).__init__() 52 | 53 | 54 | def forward(self, x): 55 | return x 56 | 57 | 58 | 59 | class InplaceShift(torch.autograd.Function): 60 | # Special thanks to @raoyongming for the help to this function 61 | @staticmethod 62 | def forward(ctx, input, fold): 63 | # not support higher order gradient 64 | # input = input.detach_() 65 | ctx.fold_ = fold 66 | n, t, c, h, w = input.size() 67 | buffer = input.data.new(n, t, fold, h, w).zero_() 68 | buffer[:, :-1] = input.data[:, 1:, :fold] 69 | input.data[:, :, :fold] = buffer 70 | buffer.zero_() 71 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold] 72 | input.data[:, :, fold: 2 * fold] = buffer 73 | return input 74 | 75 | @staticmethod 76 | def backward(ctx, grad_output): 77 | # grad_output = grad_output.detach_() 78 | fold = ctx.fold_ 79 | n, t, c, h, w = grad_output.size() 80 | buffer = grad_output.data.new(n, t, fold, h, w).zero_() 81 | buffer[:, 1:] = grad_output.data[:, :-1, :fold] 82 | grad_output.data[:, :, :fold] = buffer 83 | buffer.zero_() 84 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold] 85 | grad_output.data[:, :, fold: 2 * fold] = buffer 86 | return grad_output, None 87 | 88 | 89 | class TemporalPool(nn.Module): 90 | def __init__(self, net, n_segment): 91 | super(TemporalPool, self).__init__() 92 | self.net = net 93 | self.n_segment = n_segment 94 | 95 | def forward(self, x): 96 | x = self.temporal_pool(x, n_segment=self.n_segment) 97 | return self.net(x) 98 | 99 | @staticmethod 100 | def temporal_pool(x, n_segment): 101 | nt, c, h, w = x.size() 102 | n_batch = nt // n_segment 103 | x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w 104 | x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0)) 105 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w) 106 | return x 107 | 108 | 109 | def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False): 110 | if temporal_pool: 111 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2] 112 | else: 113 | n_segment_list = [n_segment] * 4 114 | assert n_segment_list[-1] > 0 115 | print('=> n_segment per stage: {}'.format(n_segment_list)) 116 | 117 | import torchvision 118 | # if isinstance(net, torchvision.models.ResNet): 119 | if True: 120 | if place == 'block': 121 | # print('place == block') 122 | def make_block_temporal(stage, this_segment): 123 | blocks = list(stage.children()) 124 | # print('=> Processing stage with {} blocks'.format(len(blocks))) 125 | for i, b in enumerate(blocks): 126 | # print(f'tenporal shift {b}') 127 | # pdb.set_trace() 128 | blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div) 129 | return nn.Sequential(*(blocks)) 130 | # pdb.set_trace() 131 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) 132 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) 133 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) 134 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) 135 | 136 | elif 'blockres' in place: 137 | print(f'place={place}') 138 | 139 | n_round = 1 140 | if len(list(net.layer3.children())) >= 23: 141 | n_round = 2 142 | print('=> Using n_round {} to insert temporal shift'.format(n_round)) 143 | 144 | def make_block_temporal(stage, this_segment): 145 | blocks = list(stage.children()) 146 | print('=> Processing stage with {} blocks residual'.format(len(blocks))) 147 | nr_shift_layers = 0 148 | # pdb.set_trace() 149 | for i, b in enumerate(blocks): 150 | if i % n_round == 0: 151 | # print(f'[{i}] tenporal shift. this_segment = {this_segment}. n_div = {n_div}. n_round = {n_round}') 152 | #pdb.set_trace() 153 | nr_shift_layers += 1 154 | blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div) 155 | # blocks[i].conv1 = b.conv1 156 | # print(f'nr_shift_layers: {nr_shift_layers}') 157 | return nn.Sequential(*blocks) 158 | # pdb.set_trace() 159 | # net.bn1 = MyIdentity(net.bn1) 160 | 161 | # net.conv1 = TemporalShift(net.conv1, n_segment=n_segment_list[0], n_div=3) 162 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) 163 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) 164 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) 165 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) 166 | else: 167 | raise NotImplementedError(place) 168 | 169 | 170 | def make_temporal_pool(net, n_segment): 171 | import torchvision 172 | if isinstance(net, torchvision.models.ResNet): 173 | print('=> Injecting nonlocal pooling') 174 | net.layer2 = TemporalPool(net.layer2, n_segment) 175 | else: 176 | raise NotImplementedError 177 | 178 | 179 | if __name__ == '__main__': 180 | # test inplace shift v.s. vanilla shift 181 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False) 182 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True) 183 | 184 | print('=> Testing CPU...') 185 | # test forward 186 | with torch.no_grad(): 187 | for i in range(10): 188 | x = torch.rand(2 * 8, 3, 224, 224) 189 | y1 = tsm1(x) 190 | y2 = tsm2(x) 191 | assert torch.norm(y1 - y2).item() < 1e-5 192 | 193 | # test backward 194 | with torch.enable_grad(): 195 | for i in range(10): 196 | x1 = torch.rand(2 * 8, 3, 224, 224) 197 | x1.requires_grad_() 198 | x2 = x1.clone() 199 | y1 = tsm1(x1) 200 | y2 = tsm2(x2) 201 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 202 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 203 | assert torch.norm(grad1 - grad2).item() < 1e-5 204 | 205 | print('=> Testing GPU...') 206 | tsm1.cuda() 207 | tsm2.cuda() 208 | # test forward 209 | with torch.no_grad(): 210 | for i in range(10): 211 | x = torch.rand(2 * 8, 3, 224, 224).cuda() 212 | y1 = tsm1(x) 213 | y2 = tsm2(x) 214 | assert torch.norm(y1 - y2).item() < 1e-5 215 | 216 | # test backward 217 | with torch.enable_grad(): 218 | for i in range(10): 219 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda() 220 | x1.requires_grad_() 221 | x2 = x1.clone() 222 | y1 = tsm1(x1) 223 | y2 = tsm2(x2) 224 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 225 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 226 | assert torch.norm(grad1 - grad2).item() < 1e-5 227 | print('Test passed.') 228 | 229 | 230 | 231 | 232 | -------------------------------------------------------------------------------- /ops/dataset_syncMNIST.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from glob import glob 3 | import os 4 | import pickle 5 | import torch 6 | import time 7 | import pdb 8 | import math 9 | 10 | 11 | def get_labels_dict_multi_sincron(): 12 | # create labels 13 | label = 0 14 | labels_dict = {} 15 | for i in range(10): 16 | for j in range(i): 17 | labels_dict[(i, j)] = label 18 | labels_dict[(j, i)] = label 19 | label += 1 20 | # no sync digits 21 | labels_dict[(-1, -1)] = label 22 | label += 1 23 | # 24 | for i in range(10): 25 | labels_dict[(i, i)] = label 26 | label += 1 27 | 28 | return labels_dict, label 29 | 30 | class SyncedMNISTDataSet(torch.utils.data.IterableDataset): 31 | def __init__(self, split='train', dataset_path=None, dataset_fraction_used=1.0): 32 | super(SyncedMNISTDataSet).__init__() 33 | #assert end > start, "this example code only works with end >= start" 34 | self.split = split 35 | #self.dataset_path = f'/data/datasets/video_mnist/sync_mnist_large_v2_split_{split}_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0_uint8/' 36 | #self.dataset_path = f'/data/datasets/video_mnist/sync_mnist_large_v2_split_{split}_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0_uint8/' 37 | #self.dataset_path = f'/data/datasets/video_mnist/sync_mnist_large_v2_split_{split}_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0_uint8/' 38 | self.dataset_path = dataset_path 39 | self.worker_id = 0 40 | self.max_workers = 0 41 | self.dataset_len = len(glob(self.dataset_path + '/data*pickle')) * 1000 42 | self.dataset_fraction_used = dataset_fraction_used 43 | # print(f'in iterable dataset: {self.worker_id } / {self.max_workers}') 44 | self.gen = None 45 | 46 | def __iter__(self): 47 | # print(f'worker: [{self.worker_id } / {self.max_workers}]') 48 | return self.gen.next_item(self.worker_id) 49 | def __len__(self): 50 | return self.dataset_len 51 | def syncMNIST_worker_init_fn(worker_id): 52 | worker_info = torch.utils.data.get_worker_info() 53 | dataset = worker_info.dataset # the dataset copy in this worker process 54 | dataset.worker_id = worker_info.id 55 | dataset.max_workers = worker_info.num_workers 56 | # print(f'in woker init fct: {dataset.worker_id}/ {dataset.max_workers}') 57 | dataset.gen = SyncMNISTGenerator(dataset_path=dataset.dataset_path, 58 | worker_id=dataset.worker_id, max_workers=dataset.max_workers, 59 | split=dataset.split, dataset_fraction_used=dataset.dataset_fraction_used) 60 | # print(f'in woker init fct: len gen {len(dataset.gen)}') 61 | 62 | def read_data_mnist(file, num_classes=65): 63 | with open(file, 'rb') as fo: 64 | # print(f'loading: {file}') 65 | videos_dict = pickle.load(fo) 66 | # x = np.expand_dims(videos_dict['videos'], 4) 67 | # x = videos_dict['videos']#.astype(np.float32) / 255.0 * 2 - 1.0 68 | x = videos_dict['videos']#.astype(np.uint8) 69 | 70 | y = videos_dict['labels'].astype(int).squeeze() 71 | y = np.clip(y, 0, num_classes-1) 72 | #y = np.expand_dims(np.eye(num_classes)[y], axis=1) 73 | # y = y.astype(np.float32) 74 | 75 | coords = videos_dict['videos_digits_coords'] 76 | top_left = coords 77 | bot_right = coords + 28 78 | digits_boxes = np.concatenate([top_left,bot_right], axis=-1) // 2 79 | is_ann_boxes = np.ones((digits_boxes.shape[0], 1)) 80 | 81 | video = x 82 | label = y.astype(np.int64) 83 | 84 | if False: 85 | digits = videos_dict['videos_digits'] 86 | min_digits = digits.min(1) 87 | max_digits = digits.max(1) 88 | labels_dict, max_labels = get_labels_dict_multi_sincron() 89 | 90 | min_max = np.zeros_like(label) 91 | for i in range(min_max.shape[0]): 92 | min_max[i] = labels_dict[(min_digits[i], max_digits[i])] 93 | 94 | 95 | # print(f'video.dtype: {video.dtype}') 96 | # video = torch.from_numpy(x).float() 97 | # label = torch.from_numpy(y).long() 98 | # print(f'video: {video.shape}') 99 | return video, label, digits_boxes, is_ann_boxes 100 | # return video, label, digits_boxes, is_ann_boxes, min_max 101 | 102 | class SyncMNISTGenerator(): 103 | def __init__(self, dataset_path,split, max_epochs = 100, num_classes=46, num_digits=0, 104 | worker_id=0, max_workers=1, dataset_fraction_used=1.0): 105 | no_videos_per_pickle = 1000 106 | self.num_classes = num_classes 107 | self.max_epochs = max_epochs 108 | self.train_files = glob(dataset_path + '/data*pickle') 109 | 110 | # if num_digits == 5: 111 | # train_dataset_files = f'/data/datasets/video_mnist/files_order_pytorch/{split}_files_order3.pickle' 112 | # #elif num_digits == 3: 113 | # else: 114 | # train_dataset_files = f'/data/datasets/video_mnist/files_order_pytorch/{split}_files_{num_digits}digits_order3.pickle' 115 | # if os.path.exists(train_dataset_files): 116 | # with open(train_dataset_files, 'rb') as f: 117 | # self.mnist_random_order = pickle.load(f) 118 | # print(f'Reading from: {self.mnist_random_order[0]}') 119 | # else: 120 | # print('Shuffling and saving train files') 121 | # self.mnist_random_order = [] 122 | # for ep in range(100): 123 | # np.random.shuffle(self.train_files) 124 | # self.mnist_random_order.append(self.train_files) 125 | # with open(train_dataset_files, 'wb') as f: 126 | # pickle.dump(self.mnist_random_order, f) 127 | # print(f'Reading: {len(self.mnist_random_order[0])} pickles') 128 | 129 | overall_start = 0 130 | # print(f'dataset_path: {dataset_path}') 131 | # print(f'len(self.train_files): {len(self.train_files)}') 132 | # print(f'dataset_fraction_used: {dataset_fraction_used}') 133 | overall_end = int( dataset_fraction_used * len(self.train_files)) 134 | 135 | per_worker = int(math.ceil((overall_end - overall_start) / float(max_workers))) 136 | self.start = overall_start + worker_id * per_worker 137 | self.end = min(self.start + per_worker, overall_end) 138 | # print(f'generator has overall_end:{overall_end} start-end {self.start}-{self.end}') 139 | 140 | def next_item(self, idx): 141 | #for epoch in range(self.max_epochs): 142 | # print(f"Generator epoch: {epoch}") 143 | #return 1 144 | train_files = self.train_files #self.mnist_random_order[0] 145 | train_files = train_files[self.start:self.end] 146 | # print(f'Read:{train_files}') 147 | for file in train_files: 148 | train_videos, train_labels, target_boxes_np, is_ann_boxes = read_data_mnist(file, num_classes=self.num_classes) 149 | # train_videos, train_labels, target_boxes_np, is_ann_boxes, min_max_label = read_data_mnist(file, num_classes=self.num_classes) 150 | 151 | for pick_i in range(train_videos.shape[0]): 152 | # print(f'idx [{idx}] element {pick_i} from {file}') 153 | # video_ids = np.zeros_like(train_labels[pick_i], dtype=np.float32) 154 | #yield (train_videos[pick_i], train_labels[pick_i],video_ids, target_boxes_np[pick_i], is_ann_boxes[pick_i], pick_i, file) 155 | video = train_videos[pick_i] #.astype(np.float32) / 255.0 * 2 - 1.0 156 | yield video, train_labels[pick_i], target_boxes_np[pick_i]#, min_max_label[pick_i] 157 | 158 | 159 | if __name__ == "__main__": 160 | if False: 161 | dataset_path = '/data/datasets/video_mnist/sync_mnist_large_v2_split_test_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0/' 162 | gen = SyncMNISTGenerator(dataset_path=dataset_path, split='test') 163 | 164 | get_next_item = gen.next_item() 165 | nr_videos = 15000 166 | 167 | time1 = time.time() 168 | # for b in range(nr_videos): 169 | for b, (train_videos, train_labels, target_boxes_np, _ , _) in enumerate(get_next_item): 170 | # train_videos, train_labels, target_boxes_np, _ , _ = next(get_next_item) 171 | print(f'[{b}] {train_videos.mean()}') 172 | time2 = time.time() 173 | print(f'time read {nr_videos} videos: {time2 - time1}') 174 | 175 | else: 176 | batch_size = 8 177 | ds = SyncedMNISTDataSet(split='train') 178 | loader = torch.utils.data.DataLoader(ds,batch_size=batch_size, num_workers=2, worker_init_fn=syncMNIST_worker_init_fn) 179 | time1 = time.time() 180 | 181 | 182 | #for b, (train_videos, train_labels, target_boxes_np, _ , _, _, _) in enumerate(loader): 183 | for b, (train_videos, train_labels) in enumerate(loader): 184 | if b % 100 == 0: 185 | print(b * batch_size) 186 | pass 187 | # pdb.set_trace() 188 | # print(f'[{b}] {train_videos.shape}, {train_labels.shape}') 189 | 190 | # print('Epoch 2') 191 | # for b, (train_videos, train_labels, target_boxes_np, _ , _) in enumerate(loader): 192 | # # pdb.set_trace() 193 | # print(f'[{b}] {train_videos[0].numpy().mean()}') 194 | 195 | 196 | 197 | time2 = time.time() 198 | print(f'time read : {time2 - time1}') 199 | 200 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | import argparse 6 | def str2bool(v): 7 | if isinstance(v, bool): 8 | return v 9 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 10 | return True 11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 12 | return False 13 | else: 14 | raise argparse.ArgumentTypeError('Boolean value expected.') 15 | 16 | 17 | parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks") 18 | parser.add_argument('--dataset', type=str, default='somethingv2') 19 | parser.add_argument('--dataset_fraction_used', default=1.0, type=float, 20 | help='use just x% of the total dataset') 21 | 22 | parser.add_argument('--test_dataset', type=str, default='../test') 23 | parser.add_argument('--train_dataset', type=str, default='../train') 24 | parser.add_argument('--modality', type=str, choices=['RGB', 'Flow', 'gray'], default='RGB') 25 | parser.add_argument('--train_list', type=str, default="") 26 | parser.add_argument('--val_list', type=str, default="") 27 | parser.add_argument('--root_path', type=str, default="") 28 | parser.add_argument('--store_name', type=str, default="") 29 | # ========================= Model Configs ========================== 30 | parser.add_argument('--arch', type=str, default="resnet50") 31 | parser.add_argument('--num_segments', type=int, default=16) 32 | 33 | parser.add_argument('--consensus_type', type=str, default='avg') 34 | parser.add_argument('--k', type=int, default=3) 35 | 36 | parser.add_argument('--dropout', '--do', default=0.5, type=float, 37 | metavar='DO', help='dropout ratio (default: 0.5)') 38 | parser.add_argument('--loss_type', type=str, default="nll", 39 | choices=['nll']) 40 | parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame") 41 | parser.add_argument('--suffix', type=str, default=None) 42 | parser.add_argument('--pretrain', type=str, default='imagenet') 43 | parser.add_argument('--tune_from', type=str, default=None, help='fine-tune from checkpoint') 44 | 45 | # ========================= Learning Configs ========================== 46 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 47 | help='number of total epochs to run') 48 | parser.add_argument('-b', '--batch-size', default=10, type=int, 49 | metavar='N', help='mini-batch size (default: 256)') 50 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 51 | metavar='LR', help='initial learning rate') 52 | parser.add_argument('--lr_type', default='step', type=str, 53 | metavar='LRtype', help='learning rate type') 54 | parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+", 55 | metavar='LRSteps', help='epochs to decay learning rate by 10') 56 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 57 | help='momentum') 58 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 59 | metavar='W', help='weight decay (default: 5e-4)') 60 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float, 61 | metavar='W', help='gradient norm clipping (default: disabled)') 62 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true") 63 | 64 | # no_partialbn == NOT FREEZE 65 | # partialbn == enable_partialbn == FREZE BN (except the first one) 66 | 67 | # ========================= Monitor Configs ========================== 68 | parser.add_argument('--print-freq', '-p', default=20, type=int, 69 | metavar='N', help='print frequency (default: 10)') 70 | parser.add_argument('--eval-freq', '-ef', default=5, type=int, 71 | metavar='N', help='evaluation frequency (default: 5)') 72 | 73 | 74 | # ========================= Runtime Configs ========================== 75 | parser.add_argument('-j', '--workers', default=5, type=int, metavar='N', 76 | help='number of data loading workers (default: 8)') 77 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 78 | help='path to latest checkpoint (default: none)') 79 | parser.add_argument('--replace_ignore', default=True, type=str2bool, 80 | help='change name in resnet backbone') 81 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 82 | help='evaluate model on validation set') 83 | parser.add_argument('--snapshot_pref', type=str, default="") 84 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 85 | help='manual epoch number (useful on restarts)') 86 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 87 | parser.add_argument('--flow_prefix', default="", type=str) 88 | parser.add_argument('--root_log',type=str, default='log') 89 | parser.add_argument('--root_model', type=str, default='checkpoint') 90 | parser.add_argument('--model_dir', type=str, default='./models/') 91 | parser.add_argument('--coeff', type=str, default=None) #I don't know what is this for 92 | 93 | parser.add_argument('--shift', default=False, action="store_true", help='use shift for models') 94 | parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)') 95 | parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: stageres)') 96 | parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling') 97 | parser.add_argument('--non_local', default=False, action="store_true", help='add non local block') 98 | parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset') 99 | parser.add_argument('--name', default='graph', type=str, 100 | help='name of the model') 101 | 102 | # # may contain splits 103 | parser.add_argument('--weights', type=str, default=None) 104 | parser.add_argument('--test_segments', type=str, default=25) 105 | 106 | parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble') 107 | parser.add_argument('--full_res', default=False, type=str2bool, help='Evaluate at full resolution') 108 | 109 | parser.add_argument('--full_size_224', default=False, type=str2bool, 110 | help='reschale to 224 crop the center 224x225') 111 | parser.add_argument('--test_crops', type=int, default=1) 112 | parser.add_argument('--test_batch_size', type=int, default=1) 113 | 114 | # for true test 115 | parser.add_argument('--test_list', type=str, default=None) 116 | parser.add_argument('--csv_file', type=str, default=None) 117 | parser.add_argument('--softmax', default=False, action="store_true", help='use softmax') 118 | parser.add_argument('--max_num', type=int, default=-1) 119 | parser.add_argument('--input_size', type=int, default=224) 120 | parser.add_argument('--crop_fusion_type', type=str, default='avg') 121 | parser.add_argument('--num_set_segments',type=int, default=1,help='TODO: select multiply set of n-frames from a video') 122 | 123 | # ========================= Dynamic Regions Graph ========================== 124 | 125 | parser.add_argument('--dynamic_regions', default='dyreg', type=str, choices=['none', 'pos_only', 'dyreg', 126 | 'semantic'], 127 | help='type of regions used') 128 | parser.add_argument('--init_regions', default='grid', type=str, choices=['center', 'grid'], 129 | help='anchor position (default: center)') 130 | parser.add_argument('--offset_lstm_dim', default=128, type=int, help='number channels for the offset lstm (default: 128)') 131 | parser.add_argument('--use_rstg', default=False, type=str2bool, 132 | help='divide regions size by a scaling factor') 133 | parser.add_argument('--combine_by_sum', default=False, type=str2bool, 134 | help='combine two vectors by concatenation or by summing') 135 | parser.add_argument('--project_at_input', default=False, type=str2bool, 136 | help='combine two vectors by concatenation or by summing') 137 | 138 | 139 | parser.add_argument('--update_layers', type=int, default=2) 140 | parser.add_argument('--send_layers', type=int, default=2) 141 | parser.add_argument('--rnn_type', type=str, choices=['LSTM', 'GRU'], default='LSTM') 142 | parser.add_argument('--aggregation_type', type=str, choices=['dot', 'sum'], default='dot') 143 | parser.add_argument('--offset_generator', type=str, choices=['none', 'big', 'small'], default='big') 144 | 145 | parser.add_argument('--place_graph', default='layer1.1_layer2.2_layer3.4_layer4.1', type=str, 146 | help='where to place the graph: layeri.j_') 147 | parser.add_argument('--rstg_combine', type=str, choices=['serial', 'plus'], default='plus') 148 | parser.add_argument('--ch_div', type=int, default=2) 149 | parser.add_argument('--graph_residual_type', default='norm', type=str, 150 | help='norm / out_gate/ 1chan_out_gate/ gru_gate') 151 | parser.add_argument('--remap_first', default=False, type=str2bool, 152 | help='project and remap or remap and project') 153 | parser.add_argument('--rstg_skip_connection', default=False, type=str2bool, 154 | help='use skip connection from graphs') 155 | parser.add_argument('--warmup_validate', default=False, type=str2bool, 156 | help='warmup 100 steps before validate') 157 | parser.add_argument('--tmp_norm_skip_conn', default=False, type=str2bool, 158 | help='norm for skip connection') 159 | parser.add_argument('--init_skip_zero', default=False, type=str2bool, 160 | help='norm for skip connection') 161 | parser.add_argument('--bottleneck_graph', default=False, type=str2bool, 162 | help='smaller graph in bottleneck layer') 163 | 164 | 165 | parser.add_argument('--eval_mode', default='test', type=str, 166 | choices=['train', 'test']) 167 | parser.add_argument('--visualisation', type=str, choices=['rgb', 'hsv'], default='hsv') 168 | parser.add_argument('--save_kernels', default=False, type=str2bool, 169 | help='save the kernels') 170 | 171 | #params for running distributed 172 | parser.add_argument('--local_rank', type=int, default=0) 173 | parser.add_argument('--ngpu', type=int, default=0) 174 | parser.add_argument('--world_size', type=int, default=0) 175 | parser.add_argument('--check_learned_params', default=False, type=str2bool) 176 | 177 | def parse_args(): 178 | from ops import models_config 179 | args = parser.parse_args() 180 | args.graph_params, args.out_pool_size, args.out_num_ch, args.distill_path = models_config.get_models_config() 181 | return args -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /ops/nested_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optional: Data Parallelism 3 | ========================== 4 | **Authors**: `Sung Kim `_ and `Jenny Kang `_ 5 | 6 | In this tutorial, we will learn how to use multiple GPUs using ``DataParallel``. 7 | 8 | It's very easy to use GPUs with PyTorch. You can put the model on a GPU: 9 | 10 | .. code:: python 11 | 12 | device = torch.device("cuda:0") 13 | model.to(device) 14 | 15 | Then, you can copy all your tensors to the GPU: 16 | 17 | .. code:: python 18 | 19 | mytensor = my_tensor.to(device) 20 | 21 | Please note that just calling ``my_tensor.to(device)`` returns a new copy of 22 | ``my_tensor`` on GPU instead of rewriting ``my_tensor``. You need to assign it to 23 | a new tensor and use that tensor on the GPU. 24 | 25 | It's natural to execute your forward, backward propagations on multiple GPUs. 26 | However, Pytorch will only use one GPU by default. You can easily run your 27 | operations on multiple GPUs by making your model run parallelly using 28 | ``DataParallel``: 29 | 30 | .. code:: python 31 | 32 | model = nn.DataParallel(model) 33 | 34 | That's the core behind this tutorial. We will explore it in more detail below. 35 | """ 36 | 37 | 38 | ###################################################################### 39 | # Imports and parameters 40 | # ---------------------- 41 | # 42 | # Import PyTorch modules and define parameters. 43 | # 44 | 45 | import torch 46 | import torch.nn as nn 47 | from torch.utils.data import Dataset, DataLoader 48 | import pdb 49 | # Parameters and DataLoaders 50 | ch_dim = 128 51 | 52 | input_size = (16,14,14,ch_dim) 53 | output_size = 2 54 | 55 | batch_size = 30 56 | data_size = 100 57 | 58 | 59 | ###################################################################### 60 | # Device 61 | # 62 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 63 | 64 | ###################################################################### 65 | # Dummy DataSet 66 | # ------------- 67 | # 68 | # Make a dummy (random) dataset. You just need to implement the 69 | # getitem 70 | # 71 | 72 | class RandomDataset(Dataset): 73 | 74 | def __init__(self, size, length): 75 | self.len = length 76 | self.data = torch.randn(length, size[0], size[1], size[2],size[3]) 77 | 78 | def __getitem__(self, index): 79 | return self.data[index] 80 | 81 | def __len__(self): 82 | return self.len 83 | 84 | rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size), 85 | batch_size=batch_size, shuffle=True) 86 | 87 | 88 | ###################################################################### 89 | # Simple Model 90 | # ------------ 91 | # 92 | # For the demo, our model just gets an input, performs a linear operation, and 93 | # gives an output. However, you can use ``DataParallel`` on any model (CNN, RNN, 94 | # Capsule Net etc.) 95 | # 96 | # We've placed a print statement inside the model to monitor the size of input 97 | # and output tensors. 98 | # Please pay attention to what is printed at batch rank 0. 99 | # 100 | 101 | 102 | class LayerNormAffine(nn.Module): 103 | def __init__(self,input, size=128): 104 | super(LayerNormAffine, self).__init__() 105 | self.norm = nn.LayerNorm(128,elementwise_affine=False).to(input.device) 106 | # nn.init.constant_(self.norm.weight, 0) 107 | # nn.init.constant_(self.norm.bias, 0) 108 | 109 | self.scale = torch.nn.Parameter(torch.Tensor(size=[1,1,1,size])).to(input.device) 110 | self.bias = torch.nn.Parameter(torch.Tensor(size=[1,1,1,size])).to(input.device) 111 | 112 | # pdb.set_trace() 113 | # self.norm = self.norm.to(input.device) 114 | # self.scale = scale.to(input.device) 115 | # self.bias = bias.to(input.device) 116 | 117 | nn.init.constant_(self.bias, 0) 118 | nn.init.constant_(self.scale, 1) 119 | 120 | # self.register_parameter('bias', self.bias) 121 | # self.register_parameter('scale', self.scale) 122 | 123 | def forward(self, input): 124 | # input = self.norm(input) * self.params['scale'] + self.params['bias'] 125 | input = self.norm(input) * self.scale + self.bias 126 | 127 | 128 | return input 129 | 130 | 131 | class Model2(nn.Module): 132 | # Our model 133 | 134 | def __init__(self, input_size, output_size): 135 | super(Model2, self).__init__() 136 | self.fc = nn.Linear(ch_dim, ch_dim) 137 | self.norm_dict = nn.ModuleDict({}) 138 | self.module_norm_dict = nn.ModuleList() 139 | 140 | self.model2_bias = torch.nn.Parameter(torch.Tensor(size=[1,1,1,128])) 141 | nn.init.constant_(self.model2_bias, 1) 142 | 143 | # self.init_norm = LayerNormAffine() 144 | 145 | 146 | def get_norm(self, input, name, zero_init=False): 147 | # input: B * T x C x H x W 148 | 149 | if name not in self.norm_dict: 150 | self.model2_asdasdiasdadksa = torch.nn.Parameter(torch.Tensor(size=[1,1,1,128])) 151 | nn.init.constant_(self.model2_asdasdiasdadksa, 1) 152 | 153 | 154 | norm = LayerNormAffine(input) 155 | self.norm_dict[name] = norm 156 | # print(self) 157 | else: 158 | norm = self.norm_dict[name] 159 | 160 | input = self.norm_dict[name](input) 161 | return input 162 | 163 | def forward(self, input): 164 | pdb.set_trace() 165 | input = self.get_norm(input, 'norm1') 166 | input = input + self.model2_bias 167 | 168 | output = self.fc(input) 169 | print("\tIn Model2: input size", input.size(), 170 | "output size", output.size()) 171 | 172 | return output 173 | 174 | 175 | class Model(nn.Module): 176 | # Our model 177 | 178 | def __init__(self, input_size, output_size): 179 | super(Model, self).__init__() 180 | self.model2 = Model2(input_size, ch_dim) 181 | self.fc = nn.Linear(ch_dim, output_size) 182 | 183 | def forward(self, input): 184 | input = self.model2(input) 185 | output = self.fc(input) 186 | print(f"Forward model: {self}") 187 | # print(f'Model parameters: {self.parameters()}') 188 | #for p in self.parameters(): 189 | for name, param in model.named_parameters(): 190 | print(f'param: {name}') 191 | print(f'param: {param.device}') 192 | 193 | print("\tIn Model: input size", input.size(), 194 | "output size", output.size()) 195 | 196 | return output 197 | 198 | 199 | ###################################################################### 200 | # Create Model and DataParallel 201 | # ----------------------------- 202 | # 203 | # This is the core part of the tutorial. First, we need to make a model instance 204 | # and check if we have multiple GPUs. If we have multiple GPUs, we can wrap 205 | # our model using ``nn.DataParallel``. Then we can put our model on GPUs by 206 | # ``model.to(device)`` 207 | # 208 | 209 | model = Model(input_size, output_size) 210 | if torch.cuda.device_count() > 1: 211 | print("Let's use", torch.cuda.device_count(), "GPUs!") 212 | print(f"Model: {model}") 213 | # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs 214 | model = nn.DataParallel(model) 215 | 216 | model.to(device) 217 | 218 | 219 | ###################################################################### 220 | # Run the Model 221 | # ------------- 222 | # 223 | # Now we can see the sizes of input and output tensors. 224 | # 225 | 226 | for data in rand_loader: 227 | input = data.to(device) 228 | output = model(input) 229 | print("Outside: input size", input.size(), 230 | "output_size", output.size()) 231 | 232 | 233 | ###################################################################### 234 | # Results 235 | # ------- 236 | # 237 | # If you have no GPU or one GPU, when we batch 30 inputs and 30 outputs, the model gets 30 and outputs 30 as 238 | # expected. But if you have multiple GPUs, then you can get results like this. 239 | # 240 | # 2 GPUs 241 | # ~~~~~~ 242 | # 243 | # If you have 2, you will see: 244 | # 245 | # .. code:: bash 246 | # 247 | # # on 2 GPUs 248 | # Let's use 2 GPUs! 249 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2]) 250 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2]) 251 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2]) 252 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2]) 253 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2]) 254 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2]) 255 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2]) 256 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2]) 257 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2]) 258 | # In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2]) 259 | # In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2]) 260 | # Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2]) 261 | # 262 | # 3 GPUs 263 | # ~~~~~~ 264 | # 265 | # If you have 3 GPUs, you will see: 266 | # 267 | # .. code:: bash 268 | # 269 | # Let's use 3 GPUs! 270 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2]) 271 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2]) 272 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2]) 273 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2]) 274 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2]) 275 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2]) 276 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2]) 277 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2]) 278 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2]) 279 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2]) 280 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2]) 281 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2]) 282 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 283 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 284 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2]) 285 | # Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2]) 286 | # 287 | # 8 GPUs 288 | # ~~~~~~~~~~~~~~ 289 | # 290 | # If you have 8, you will see: 291 | # 292 | # .. code:: bash 293 | # 294 | # Let's use 8 GPUs! 295 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 296 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 297 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2]) 298 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 299 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 300 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 301 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 302 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 303 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2]) 304 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 305 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 306 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 307 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 308 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 309 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 310 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2]) 311 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 312 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2]) 313 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 314 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 315 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 316 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 317 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 318 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 319 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2]) 320 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2]) 321 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2]) 322 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2]) 323 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2]) 324 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2]) 325 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2]) 326 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2]) 327 | # Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2]) 328 | # 329 | 330 | 331 | ###################################################################### 332 | # Summary 333 | # ------- 334 | # 335 | # DataParallel splits your data automatically and sends job orders to multiple 336 | # models on several GPUs. After each model finishes their job, DataParallel 337 | # collects and merges the results before returning it to you. 338 | # 339 | # For more information, please check out 340 | # https://pytorch.org/tutorials/beginner/former\_torchies/parallelism\_tutorial.html. 341 | # -------------------------------------------------------------------------------- /ops/models_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | import matplotlib.pyplot as plt 6 | import matplotlib.colors as colors 7 | 8 | import sys 9 | import os 10 | sys.path.append(os.path.abspath(os.path.join( 11 | os.path.dirname(os.path.realpath(__file__)), '..'))) 12 | import time 13 | from opts import parser 14 | 15 | global args, best_prec1 16 | from opts import parse_args 17 | args = parser.parse_args() 18 | args = parse_args() 19 | 20 | import pickle 21 | import pdb 22 | 23 | if args.rnn_type == 'LSTM': 24 | recurrent_net = torch.nn.LSTM 25 | elif args.rnn_type == 'GRU': 26 | recurrent_net = torch.nn.GRU 27 | 28 | def differentiable_resize_area(x, kernel): 29 | # x: B x C x H x W 30 | # kernel: B x h x w x H x W 31 | B = x.size()[0] 32 | C = x.size()[1] 33 | H = x.size()[2] 34 | W = x.size()[3] 35 | h = kernel.size()[1] 36 | w = kernel.size()[2] 37 | 38 | kernel_res = kernel.view(B, h * w, H * W) 39 | kernel_res = kernel_res.permute(0,2,1) 40 | x_res = x.view(B, C, H * W) 41 | 42 | x_resize = torch.matmul(x_res, kernel_res) 43 | x_resize = x_resize.view(B, C, h , w) 44 | # x_resize: B x C x h x w 45 | return x_resize 46 | 47 | def save_kernel(mean_kernels, folder='.'): 48 | kernel = mean_kernels[0].detach().numpy() 49 | max_t = kernel.shape[0] 50 | num_rows = kernel.shape[1] 51 | 52 | #for tt in range(max_t-5,max_t): 53 | all_frames = [] 54 | for tt in range(max_t): 55 | f, axarr = plt.subplots(num_rows,num_rows) 56 | N = kernel.shape[1] * kernel.shape[2] 57 | for ii in range(kernel.shape[1]): 58 | for jj in range(kernel.shape[2]): 59 | axarr[ii][jj].imshow(kernel[tt][ii][jj]) 60 | 61 | plt.savefig(f'{folder}/kernel_mean_{tt}.png') 62 | 63 | 64 | def atanh(x: torch.Tensor): 65 | x = x.clamp(-1 + 1e-7, 1 - 1e-7) 66 | return (torch.log(1 + x).sub(torch.log(1 - x))).mul(0.5) 67 | 68 | 69 | class LayerNormAffine2D(nn.Module): 70 | # B * T x num_chan x H x W 71 | def __init__(self, num_ch, norm_shape, zero_init=False): 72 | super(LayerNormAffine2D, self).__init__() 73 | self.scale = torch.nn.Parameter(torch.Tensor(size=[1,num_ch,1,1]))#.to(input.device) 74 | self.bias = torch.nn.Parameter(torch.Tensor(size=[1,num_ch,1,1]))#.to(input.device) 75 | self.norm = nn.LayerNorm(norm_shape, elementwise_affine=False)#.to(input.device) 76 | 77 | bias_init = 0 78 | scale_init = 1 79 | if zero_init: 80 | scale_init = 0 81 | nn.init.constant_(self.bias, bias_init) 82 | nn.init.constant_(self.scale, scale_init) 83 | 84 | def forward(self, input): 85 | input = self.norm(input) * self.scale + self.bias 86 | return input 87 | 88 | class LayerNormAffine1D(nn.Module): 89 | # B * T x num_chan x N 90 | def __init__(self, num_ch, norm_shape, zero_init=False): 91 | super(LayerNormAffine1D, self).__init__() 92 | self.scale = torch.nn.Parameter(torch.Tensor(size=[1,num_ch,1]))#.to(input.device) 93 | self.bias = torch.nn.Parameter(torch.Tensor(size=[1,num_ch,1]))#.to(input.device) 94 | self.norm = nn.LayerNorm(norm_shape, elementwise_affine=False)#.to(input.device) 95 | 96 | bias_init = 0 97 | scale_init = 1 98 | if zero_init: 99 | scale_init = 0 100 | nn.init.constant_(self.bias, bias_init) 101 | nn.init.constant_(self.scale, scale_init) 102 | 103 | def forward(self, input): 104 | input = self.norm(input) * self.scale + self.bias 105 | return input 106 | 107 | # input: D1 x D2 x ... x C 108 | # ex B x N x T x C 109 | # or B x N x N x T x C 110 | class LayerNormAffineXC(nn.Module): 111 | def __init__(self, num_ch, norm_shape): 112 | super(LayerNormAffineXC, self).__init__() 113 | self.scale = torch.nn.Parameter(torch.Tensor(size=[num_ch])) 114 | self.bias = torch.nn.Parameter(torch.Tensor(size=[num_ch])) 115 | self.norm = nn.LayerNorm(norm_shape, elementwise_affine=False) 116 | 117 | nn.init.constant_(self.bias, 0) 118 | nn.init.constant_(self.scale, 1) 119 | 120 | def forward(self, input): 121 | input = self.norm(input) * self.scale + self.bias 122 | return input 123 | 124 | 125 | 126 | 127 | class GloRe(nn.Module): 128 | def __init__(self, input_height=14, input_channels=16, h=3,w=3): 129 | super(GloRe, self).__init__() 130 | 131 | self.h = h 132 | self.w = w 133 | self.H = input_height 134 | self.W = input_height 135 | self.input_channels = input_channels 136 | self.num_nodes = h * w 137 | 138 | self.norm_dict = nn.ModuleDict({}) 139 | self.norm_dict[f'offset_location_emb'] = LayerNormAffine2D( 140 | input_channels, (input_channels, self.h, self.w) 141 | ) 142 | # reduce channels 143 | self.mask_conv = nn.Conv2d(self.input_channels, self.num_nodes, [1, 1]) 144 | self.update = nn.Conv1d(self.input_channels, args.offset_lstm_dim, [1]) 145 | 146 | # positional embeding 147 | pos_emb = self.positional_emb(self.input_channels) 148 | self.register_buffer('pos_emb_buf', pos_emb) 149 | self.sin_pos_emb_proj = nn.Conv2d(self.input_channels, self.input_channels, [1,1]) 150 | 151 | def get_norm(self, input, name): 152 | # input: B * T x C x H x W 153 | norm = self.norm_dict[name] 154 | input = norm(input) 155 | return input 156 | def apply_sin_positional_emb(self, x): 157 | # pos_emb_buf: C x H x W 158 | # x:B * T x C x H x W 159 | emb = self.pos_emb_buf.unsqueeze(0) 160 | 161 | emb = emb.repeat(x.shape[0],1,1,1) 162 | emb = self.sin_pos_emb_proj(emb) 163 | out = x + emb 164 | return out #emb- just pos 165 | def positional_emb(self, num_channels=2*32): 166 | # pos_h: H x 1 x 1 167 | # T: 1 x C x 1 168 | # T_even: 1 x C/2 x 1 169 | # emb_h_even: H x C/2 x 1 170 | channels = num_channels // 2 171 | pos_h = torch.linspace(0,1,self.H).unsqueeze(1).unsqueeze(2) 172 | T = torch.pow(10000, torch.arange(0,channels).float() / channels) 173 | T_even = T.view(-1,2)[:,:1].unsqueeze(0) 174 | emb_h_even = torch.sin(pos_h / T_even) 175 | 176 | T_odd = T.view(-1,2)[:,1:].unsqueeze(0) 177 | emb_h_odd = torch.cos(pos_h / T_odd) 178 | 179 | emb_h = torch.cat((emb_h_even, emb_h_odd), dim=2).view(self.H,1, channels) 180 | emb_h = emb_h.repeat(1,self.W,1) 181 | 182 | # pos_w: W x 1 x 1 183 | # T: 1 x C x 1 184 | # T_even: 1 x C/2 x 1 185 | # emb_h_even: W x C/2 x 1 186 | pos_w = torch.linspace(0,1,self.W).unsqueeze(1).unsqueeze(2) 187 | emb_w_even = torch.sin(pos_w / T_even) 188 | emb_w_odd = torch.cos(pos_w / T_odd) 189 | emb_w = torch.cat((emb_w_even, emb_w_odd), dim=2).view(1, self.W, channels) 190 | emb_w = emb_w.repeat(self.H,1,1) 191 | 192 | emb = torch.cat((emb_h, emb_w), dim=2).permute(2,0,1) 193 | return emb 194 | def forward(self, x): 195 | # x: B * T x C x H x W 196 | # mask: B * T x num_nodes x H x W 197 | mask = self.mask_conv(x) 198 | x = self.apply_sin_positional_emb(x) 199 | 200 | # softmax over the number of receiving nodes: each of the 201 | # H * W location send predominantly to one node 202 | mask = F.softmax(mask, dim=1) 203 | mask = mask.view(mask.shape[0], 3,3, mask.shape[2], mask.shape[3]) 204 | 205 | # nodes: B*T x C x num_nodes 206 | offset_nodes = differentiable_resize_area(x, mask) 207 | 208 | # B*T x C x h x w 209 | offset_nodes = offset_nodes.view(offset_nodes.shape[0], offset_nodes.shape[1], offset_nodes.shape[2] * offset_nodes.shape[3]) 210 | offset_nodes = self.update(offset_nodes) 211 | # nodes: B*T x offset_lstm_dim x num_nodes 212 | return offset_nodes 213 | 214 | 215 | def get_fishnet_params(input_height,keep_spatial_size = False): 216 | first_height = input_height 217 | if input_height == 32: 218 | first_height = 32 219 | strides = [(2, 2), (2,2)] 220 | padding = [(0,0), (0,0)] 221 | norm_size = [15, 7, 7] 222 | elif input_height == 16: 223 | first_height = 16 224 | strides = [(2, 2), (1,1)] 225 | padding = [(0,0), (1,1)] 226 | norm_size = [7, 7, 7] 227 | elif input_height == 8: 228 | strides = [(1, 1), (1,1)] 229 | padding = [(0,0), (1,1)] 230 | norm_size = [6, 6, 6] 231 | elif input_height > 50: 232 | strides = [(2, 2), (2,2)] 233 | padding = [(0,0), (0,0)] 234 | norm_size = [13, 6, 6] 235 | self.max_pool = torch.nn.MaxPool2d(2, stride=2) 236 | first_height = 28 237 | elif input_height > 20: 238 | strides = [(2, 2), (2,2)] 239 | padding = [(0,0), (0,0)] 240 | norm_size = [13, 6, 6] 241 | elif input_height > 10: 242 | strides = [(2, 2), (1,1)] 243 | padding = [(0,0), (1,1)] 244 | norm_size = [6, 6, 6] 245 | else: 246 | strides = [(1, 1), (1,1)] 247 | padding = [(0,0), (1,1)] 248 | norm_size = [5, 5, 5] 249 | 250 | output_padding = (0,0) 251 | if keep_spatial_size: 252 | if input_height == 16: 253 | padding = [(1,1), (1,1)] 254 | output_padding = (1,1) 255 | norm_size = [8, 8, 16] 256 | return first_height, strides, padding, norm_size, output_padding 257 | 258 | class Fishnet(nn.Module): 259 | def __init__(self, input_height=14, input_channels=16, keep_spatial_size=False): 260 | super(Fishnet, self).__init__() 261 | 262 | self.offset_layers = 2 263 | self.offset_channels = [32, 16, 16] 264 | self.offset_channels_tr = [32, 32, 16] 265 | self.input_height = input_height 266 | self.keep_spatial_size = keep_spatial_size 267 | 268 | first_height, strides, padding, norm_size, output_padding = get_fishnet_params( 269 | input_height, keep_spatial_size) 270 | self.norm_size = norm_size 271 | 272 | self.norm_dict = nn.ModuleDict({}) 273 | 274 | # reduce channels 275 | self.conv1 = nn.Conv2d(input_channels, self.offset_channels[0], [1, 1]) # de pus relu 276 | self.norm_dict[f'norm1'] = LayerNormAffine2D( 277 | self.offset_channels[0], 278 | (self.offset_channels[0], first_height, first_height) 279 | ) 280 | 281 | self.tail_conv = nn.ModuleList() 282 | self.body_trans_conv = nn.ModuleList() 283 | self.head_conv = nn.ModuleList() 284 | for i in range(self.offset_layers): 285 | self.tail_conv.append(nn.Conv2d( 286 | self.offset_channels[max(0,i-1)], self.offset_channels[i], 287 | [3, 3], padding=padding[i], stride=strides[i] 288 | ) 289 | ) 290 | self.norm_dict[f'tail_norm_{i}'] = LayerNormAffine2D( 291 | self.offset_channels[i], 292 | (self.offset_channels[i], norm_size[i], norm_size[i]) 293 | ) 294 | stop = 0 295 | if self.keep_spatial_size: 296 | stop = -1 297 | for ind, i in enumerate(range(self.offset_layers - 1, stop,-1)): 298 | if keep_spatial_size and i == 0: 299 | self.body_trans_conv.append(nn.ConvTranspose2d( 300 | self.offset_channels_tr[i+1], self.offset_channels_tr[i], 301 | [3, 3], padding=padding[i], output_padding=output_padding, 302 | stride=strides[i] 303 | ) 304 | ) 305 | else: 306 | self.body_trans_conv.append(nn.ConvTranspose2d( 307 | self.offset_channels_tr[i+1], self.offset_channels_tr[i], 308 | [3, 3], padding=padding[i], stride=strides[i] 309 | ) 310 | ) 311 | 312 | self.norm_dict[f'body_trans_norm_{ind}'] = LayerNormAffine2D( 313 | self.offset_channels_tr[i], 314 | (self.offset_channels_tr[i], norm_size[i-1], norm_size[i-1]) 315 | ) 316 | 317 | if not self.keep_spatial_size: 318 | for i in range(1, self.offset_layers): 319 | self.head_conv.append(nn.Conv2d( 320 | self.offset_channels[i-1], self.offset_channels[i], 321 | [3, 3], padding=padding[i], stride=strides[i] 322 | ) 323 | ) 324 | self.norm_dict[f'head_norm_{i-1}'] = LayerNormAffine2D( 325 | self.offset_channels[i], 326 | (self.offset_channels[i], norm_size[i], norm_size[i]) 327 | ) 328 | 329 | def get_norm(self, input, name, zero_init=False): 330 | # input: B * T x C x H x W 331 | norm = self.norm_dict[name] 332 | input = norm(input) 333 | return input 334 | 335 | def forward(self, x): 336 | # x: B * T x C x H x W 337 | x = F.relu(self.conv1(x)) 338 | if self.input_height > 50: 339 | x = self.max_pool(x) 340 | 341 | x = self.get_norm(x, 'norm1') 342 | all_x = [] 343 | all_x = [x] 344 | 345 | for i in range(self.offset_layers): 346 | x = F.relu(self.tail_conv[i](x)) 347 | x = self.get_norm(x, f'tail_norm_{i}') 348 | all_x.append(x) 349 | 350 | stop = 0 351 | if self.keep_spatial_size: 352 | stop = -1 353 | 354 | for ind, i in enumerate(range(self.offset_layers - 1, stop,-1)): 355 | x = F.relu(self.body_trans_conv[ind](x)) 356 | x = self.get_norm(x, f'body_trans_norm_{ind}') 357 | x = x + all_x[i] 358 | 359 | if not self.keep_spatial_size: 360 | for i in range(1, self.offset_layers): 361 | x = F.relu(self.head_conv[i-1](x)) 362 | x = self.get_norm(x, f'head_norm_{i-1}') 363 | 364 | return x -------------------------------------------------------------------------------- /test_models.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | # Notice that this file has been modified to support ensemble testing 7 | 8 | import argparse 9 | import time 10 | 11 | import torch.nn.parallel 12 | import torch.optim 13 | from sklearn.metrics import confusion_matrix 14 | from ops.dataset import TSNDataSet 15 | from ops.models import TSN 16 | from ops.transforms import * 17 | from ops import dataset_config 18 | from torch.nn import functional as F 19 | import pdb 20 | # options 21 | import os 22 | from opts import parser 23 | from ops.utils import save_kernels 24 | 25 | global args, best_prec1 26 | args = parser.parse_args() 27 | print(args) 28 | 29 | class AverageMeter(object): 30 | """Computes and stores the average and current value""" 31 | def __init__(self): 32 | self.reset() 33 | 34 | def reset(self): 35 | self.val = 0 36 | self.avg = 0 37 | self.sum = 0 38 | self.count = 0 39 | 40 | def update(self, val, n=1): 41 | self.val = val 42 | self.sum += val * n 43 | self.count += n 44 | self.avg = self.sum / self.count 45 | 46 | 47 | def accuracy(output, target, topk=(1,)): 48 | """Computes the precision@k for the specified values of k""" 49 | maxk = max(topk) 50 | batch_size = target.size(0) 51 | _, pred = output.topk(maxk, 1, True, True) 52 | pred = pred.t() 53 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 54 | res = [] 55 | for k in topk: 56 | # correct_k = correct[:k].view(-1).float().sum(0) 57 | correct_k = correct[:k].float().sum() 58 | res.append(correct_k.mul_(100.0 / batch_size)) 59 | return res 60 | 61 | 62 | def parse_shift_option_from_log_name(log_name): 63 | if 'shift' in log_name: 64 | strings = log_name.split('_') 65 | for i, s in enumerate(strings): 66 | if 'shift' in s: 67 | break 68 | return True, int(strings[i].replace('shift', '')), strings[i + 1] 69 | else: 70 | return False, None, None 71 | 72 | 73 | weights_list = args.weights.split(',') 74 | test_segments_list = [int(s) for s in args.test_segments.split(',')] 75 | assert len(weights_list) == len(test_segments_list) 76 | if args.coeff is None: 77 | coeff_list = [1] * len(weights_list) 78 | else: 79 | coeff_list = [float(c) for c in args.coeff.split(',')] 80 | 81 | if args.test_list is not None: 82 | test_file_list = args.test_list.split(',') 83 | else: 84 | test_file_list = [None] * len(weights_list) 85 | 86 | 87 | data_iter_list = [] 88 | net_list = [] 89 | modality_list = [] 90 | 91 | total_num = None 92 | for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list): 93 | is_shift, shift_div, shift_place = args.shift, args.shift_div, args.shift_place 94 | modality = args.modality 95 | #this_arch = this_weights.split('TSM_')[1].split('_')[2] 96 | this_arch = args.arch 97 | modality_list.append(modality) 98 | num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset, 99 | modality) 100 | print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place)) 101 | net = TSN(num_class, this_test_segments if is_shift else 1, modality, 102 | base_model=this_arch, 103 | consensus_type=args.crop_fusion_type, 104 | dropout=args.dropout, 105 | img_feature_dim=args.img_feature_dim, 106 | partial_bn=not args.no_partialbn, 107 | fc_lr5=not (args.tune_from and args.dataset in args.tune_from), 108 | pretrain=args.pretrain, 109 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, 110 | temporal_pool=args.temporal_pool, 111 | non_local='_nl' in this_weights, 112 | ) 113 | 114 | if 'tpool' in this_weights: 115 | from ops.temporal_shift import make_temporal_pool 116 | make_temporal_pool(net.base_model, this_test_segments) # since DataParallel 117 | 118 | print(f'Loading weights from: {this_weights}') 119 | checkpoint = torch.load(this_weights, map_location=torch.device('cpu')) 120 | epoch = checkpoint['epoch'] 121 | checkpoint = checkpoint['state_dict'] 122 | print(f'Evaluating epoch: {epoch}') 123 | ckpt_dict = {} 124 | for k, v in list(checkpoint.items()): 125 | if ('dynamic_graph.ph' not in k and 'dynamic_graph.pw' not in k and 'dynamic_graph.arange_h' not in k and 'dynamic_graph.arange_w' not in k 126 | and 'const_dh_ones' not in k and 'const_dw_ones' not in k and 'fix_offsets' not in k): 127 | # remove first tag ('module') 128 | key_name = '.'.join(k.split('.')[1:]) 129 | # remove tags from checkpoints vars 130 | key_name = key_name.replace('base_model.map_final_project','map_final_project') 131 | key_name = key_name.replace('.block','.0.block') 132 | key_name = key_name.replace('.dynamic_graph.','.1.dynamic_graph.') 133 | key_name = key_name.replace('.norm_dict.residual_norm','.1.norm_dict.residual_norm') 134 | ckpt_dict[key_name] = v 135 | 136 | model_dict = net.state_dict() 137 | 138 | print('Model parameters') 139 | [print(k) for k in model_dict.keys()] 140 | 141 | print('checkpoint parameters') 142 | [print(k) for k in checkpoint.keys()] 143 | 144 | # 145 | for k, v in model_dict.items(): 146 | if 'ignore' in k: 147 | old_name = k.replace('ignore.', '') 148 | ckpt_val = checkpoint['module.' +old_name] 149 | del ckpt_dict[old_name] 150 | ckpt_dict[k] = ckpt_val 151 | 152 | model_dict.update(ckpt_dict) 153 | net.load_state_dict(model_dict) 154 | 155 | input_size = net.scale_size if args.full_res else net.input_size 156 | 157 | 158 | scale1 = net.scale_size 159 | crop1 = input_size 160 | if args.test_crops == 1: 161 | cropping = torchvision.transforms.Compose([ 162 | GroupScale(scale1), 163 | GroupCenterCrop(crop1), 164 | ]) 165 | elif args.test_crops == 3: # do not flip, so only 5 crops 166 | cropping = torchvision.transforms.Compose([ 167 | GroupFullResSample(input_size, net.scale_size, flip=False) 168 | ]) 169 | elif args.test_crops == 5: # do not flip, so only 5 crops 170 | cropping = torchvision.transforms.Compose([ 171 | GroupOverSample(input_size, net.scale_size, flip=False) 172 | ]) 173 | elif args.test_crops == 10: 174 | cropping = torchvision.transforms.Compose([ 175 | GroupOverSample(input_size, net.scale_size) 176 | ]) 177 | else: 178 | raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops)) 179 | data_loader = torch.utils.data.DataLoader( 180 | TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments, 181 | new_length=1, 182 | modality=modality, 183 | image_tmpl=prefix, 184 | test_mode=True, 185 | remove_missing=len(weights_list) == 1, 186 | transform=torchvision.transforms.Compose([ 187 | cropping, 188 | Stack(roll=(False)), 189 | ToTorchFormatTensor(div=(True)), 190 | GroupNormalize(net.input_mean, net.input_std), 191 | ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample), 192 | batch_size=args.batch_size, shuffle=False, 193 | num_workers=args.workers if torch.cuda.is_available() else 0, pin_memory=True 194 | ) 195 | 196 | if args.gpus is not None: 197 | devices = [args.gpus[i] for i in range(args.workers)] 198 | else: 199 | devices = list(range(args.workers)) 200 | 201 | if torch.cuda.is_available(): 202 | net = torch.nn.DataParallel(net.cuda()) 203 | net.eval() 204 | data_gen = enumerate(data_loader) 205 | if total_num is None: 206 | total_num = len(data_loader.dataset) 207 | else: 208 | assert total_num == len(data_loader.dataset) 209 | data_iter_list.append(data_gen) 210 | net_list.append(net) 211 | output = [] 212 | 213 | def eval_video(video_data, net, this_test_segments, modality, mode='eval'): 214 | if mode == 'eval': 215 | net.eval() 216 | else: 217 | print("something wrong. double check the flags") 218 | net.train() 219 | with torch.no_grad(): 220 | i, data, label, _ = video_data 221 | batch_size = label.numel() 222 | num_crop = args.test_crops 223 | if args.dense_sample: 224 | num_crop *= 10 # 10 clips for testing when using dense sample 225 | 226 | if args.twice_sample: 227 | num_crop *= 2 228 | 229 | length = 3 230 | data_in = data 231 | 232 | if is_shift: 233 | data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3)) 234 | 235 | 236 | rst, model_aux_feats = net(data_in) 237 | aux = model_aux_feats['interm_feats'] 238 | offsets = model_aux_feats['offsets'] 239 | 240 | list_save_iter = [0,100,1000] 241 | if args.save_kernels and i in list_save_iter: 242 | folder = os.path.join(args.model_dir, args.store_name, args.root_log,'kernels') 243 | save_kernels(data_in, aux, folder=folder, name=f'validation_iter_{i}', predicted_offsets=offsets) 244 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) 245 | 246 | if args.softmax: 247 | # take the softmax to normalize the output to probability 248 | rst = F.softmax(rst, dim=1) 249 | rst = rst.data.cpu().numpy().copy() 250 | 251 | if torch.cuda.is_available(): 252 | if net.module.is_shift: 253 | rst = rst.reshape(batch_size, num_class) 254 | else: 255 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) 256 | else: 257 | if net.is_shift: 258 | rst = rst.reshape(batch_size, num_class) 259 | else: 260 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) 261 | return i, rst, label 262 | 263 | 264 | proc_start_time = time.time() 265 | max_num = args.max_num if args.max_num > 0 else total_num 266 | 267 | top1 = AverageMeter() 268 | top5 = AverageMeter() 269 | 270 | all_batch_time = [] 271 | for i, data_label_pairs in enumerate(zip(*data_iter_list)): 272 | with torch.no_grad(): 273 | if i >= max_num: 274 | break 275 | this_rst_list = [] 276 | this_label = None 277 | begin_proc = time.time() 278 | for n_seg, (_, (_, data, label,detector_out)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list): 279 | rst = eval_video((i, data, label,detector_out), net, n_seg, modality, mode='eval') 280 | this_rst_list.append(rst[1]) 281 | this_label = label 282 | assert len(this_rst_list) == len(coeff_list) 283 | for i_coeff in range(len(this_rst_list)): 284 | this_rst_list[i_coeff] *= coeff_list[i_coeff] 285 | ensembled_predict = sum(this_rst_list) / len(this_rst_list) 286 | 287 | for p, g in zip(ensembled_predict, this_label.cpu().numpy()): 288 | output.append([p[None, ...], g]) 289 | cnt_time = time.time() - proc_start_time 290 | batch_time = time.time() - begin_proc 291 | if i > 0: 292 | all_batch_time.append(batch_time) 293 | prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5)) 294 | top1.update(prec1.item(), this_label.numel()) 295 | top5.update(prec5.item(), this_label.numel()) 296 | if i % 5 == 0: 297 | print('video {} done, total {}/{}, average {:.3f} sec/video, ' 298 | 'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num, 299 | float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg)) 300 | print(f"average sec/video: {np.array(all_batch_time).mean() / args.batch_size} ") 301 | video_pred = [np.argmax(x[0]) for x in output] 302 | video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output] 303 | 304 | video_labels = [x[1] for x in output] 305 | 306 | 307 | if args.csv_file is not None: 308 | print('=> Writing result to csv file: {}'.format(args.csv_file)) 309 | with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f: 310 | categories = f.readlines() 311 | categories = [f.strip() for f in categories] 312 | with open(test_file_list[0]) as f: 313 | vid_names = f.readlines() 314 | vid_names = [n.split(' ')[0] for n in vid_names] 315 | assert len(vid_names) == len(video_pred) 316 | if args.dataset != 'somethingv2': # only output top1 317 | with open(args.csv_file, 'w') as f: 318 | for n, pred in zip(vid_names, video_pred): 319 | f.write('{};{}\n'.format(n, categories[pred])) 320 | else: 321 | with open(args.csv_file, 'w') as f: 322 | for n, pred5 in zip(vid_names, video_pred_top5): 323 | fill = [n] 324 | for p in list(pred5): 325 | fill.append(p) 326 | f.write('{};{};{};{};{};{}\n'.format(*fill)) 327 | 328 | 329 | cf = confusion_matrix(video_labels, video_pred).astype(float) 330 | 331 | np.save('cm.npy', cf) 332 | cls_cnt = cf.sum(axis=1) 333 | cls_hit = np.diag(cf) 334 | eps = 0.0000001 335 | cls_acc = cls_hit / (cls_cnt + eps) 336 | print(cls_acc) 337 | upper = np.mean(np.max(cf, axis=1) / (cls_cnt + eps)) 338 | print('upper bound: {}'.format(upper)) 339 | 340 | print('-----Evaluation is finished------') 341 | print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) 342 | print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg)) 343 | 344 | 345 | -------------------------------------------------------------------------------- /ops/resnet3d_xl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | import numpy as np 7 | 8 | import pdb 9 | from functools import partial 10 | 11 | from opts import parser 12 | args = parser.parse_args() 13 | from ops.rstg import * 14 | 15 | __all__ = [ 16 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 17 | 'resnet152', 'resnet200', 18 | ] 19 | 20 | 21 | def conv3x3x3(in_planes, out_planes, stride=1): 22 | # 3x3x3 convolution with padding 23 | return nn.Conv3d( 24 | in_planes, 25 | out_planes, 26 | kernel_size=3, 27 | stride=stride, 28 | padding=1, 29 | bias=False) 30 | 31 | 32 | def downsample_basic_block(x, planes, stride): 33 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 34 | zero_pads = torch.Tensor( 35 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 36 | out.size(4)).zero_() 37 | if isinstance(out.data, torch.cuda.FloatTensor): 38 | zero_pads = zero_pads.cuda() 39 | 40 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 41 | 42 | return out 43 | 44 | 45 | class BasicBlock(nn.Module): 46 | expansion = 1 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(BasicBlock, self).__init__() 50 | self.conv1 = conv3x3x3(inplanes, planes, stride) 51 | self.bn1 = nn.BatchNorm3d(planes) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.conv2 = conv3x3x3(planes, planes) 54 | self.bn2 = nn.BatchNorm3d(planes) 55 | self.downsample = downsample 56 | self.stride = stride 57 | 58 | def forward(self, x): 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | 68 | if self.downsample is not None: 69 | residual = self.downsample(x) 70 | 71 | out += residual 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | conv_op = None 79 | offset_groups = 1 80 | 81 | def __init__(self, dim_in, dim_out, stride, dim_inner, group=1, use_temp_conv=1, temp_stride=1, dcn=False, 82 | shortcut_type='B'): 83 | super(Bottleneck, self).__init__() 84 | # 1 x 1 layer 85 | self.with_dcn = dcn 86 | self.conv1 = self.Conv3dBN(dim_in, dim_inner, (1 + use_temp_conv * 2, 1, 1), (temp_stride, 1, 1), 87 | (use_temp_conv, 0, 0)) 88 | self.relu = nn.ReLU(inplace=True) 89 | # 3 x 3 layer 90 | self.conv2 = self.Conv3dBN(dim_inner, dim_inner, (1, 3, 3), (1, stride, stride), (0, 1, 1)) 91 | # 1 x 1 layer 92 | self.conv3 = self.Conv3dBN(dim_inner, dim_out, (1, 1, 1), (1, 1, 1), (0, 0, 0)) 93 | 94 | self.shortcut_type = shortcut_type 95 | self.dim_in = dim_in 96 | self.dim_out = dim_out 97 | self.temp_stride = temp_stride 98 | self.stride = stride 99 | # nn.Conv3d(dim_in, dim_out, (1,1,1),(temp_stride,stride,stride),(0,0,0)) 100 | if self.shortcut_type == 'B': 101 | if self.dim_in == self.dim_out and self.temp_stride == 1 and self.stride == 1: # or (self.dim_in == self.dim_out and self.dim_in == 64 and self.stride ==1): 102 | 103 | pass 104 | else: 105 | # pass 106 | self.shortcut = self.Conv3dBN(dim_in, dim_out, (1, 1, 1), (temp_stride, stride, stride), (0, 0, 0)) 107 | 108 | # nn.Conv3d(dim_in,dim_inner,kernel_size=(1+use_temp_conv*2,1,1),stride = (temp_stride,1,1),padding = ) 109 | 110 | def forward(self, x): 111 | residual = x 112 | out = self.conv1(x) 113 | out = self.relu(out) 114 | out = self.conv2(out) 115 | out = self.relu(out) 116 | out = self.conv3(out) 117 | if self.dim_in == self.dim_out and self.temp_stride == 1 and self.stride == 1: 118 | pass 119 | else: 120 | residual = self.shortcut(residual) 121 | out += residual 122 | out = self.relu(out) 123 | return out 124 | 125 | def Conv3dBN(self, dim_in, dim_out, kernels, strides, pads, group=1): 126 | if self.with_dcn and kernels[0] > 1: 127 | # use deformable conv 128 | return nn.Sequential( 129 | self.conv_op(dim_in, dim_out, kernel_size=kernels, stride=strides, padding=pads, bias=False, 130 | offset_groups=self.offset_groups), 131 | nn.BatchNorm3d(dim_out) 132 | ) 133 | else: 134 | return nn.Sequential( 135 | nn.Conv3d(dim_in, dim_out, kernel_size=kernels, stride=strides, padding=pads, bias=False), 136 | nn.BatchNorm3d(dim_out) 137 | ) 138 | 139 | 140 | class ResNet(nn.Module): 141 | 142 | def __init__(self, 143 | block, 144 | layers, 145 | use_temp_convs_set, 146 | temp_strides_set, 147 | sample_size, 148 | sample_duration, 149 | shortcut_type='B', 150 | num_classes=400, 151 | stage_with_dcn=(False, False, False, False), 152 | extract_features=False, 153 | loss_type='softmax'): 154 | super(ResNet, self).__init__() 155 | self.extract_features = extract_features 156 | self.stage_with_dcn = stage_with_dcn 157 | self.group = 1 158 | self.width_per_group = 64 159 | self.dim_inner = self.group * self.width_per_group 160 | # self.shortcut_type = shortcut_type 161 | self.conv1 = nn.Conv3d( 162 | 3, 163 | 64, 164 | kernel_size=(1 + use_temp_convs_set[0][0] * 2, 7, 7), 165 | stride=(temp_strides_set[0][0], 2, 2), 166 | padding=(use_temp_convs_set[0][0], 3, 3), 167 | bias=False) 168 | self.bn1 = nn.BatchNorm3d(64) 169 | self.relu = nn.ReLU(inplace=True) 170 | self.maxpool1 = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 0, 0)) 171 | with_dcn = True if self.stage_with_dcn[0] else False 172 | self.layer1 = self._make_layer(block, 64, 256, shortcut_type, stride=1, num_blocks=layers[0], 173 | dim_inner=self.dim_inner, group=self.group, use_temp_convs=use_temp_convs_set[1], 174 | temp_strides=temp_strides_set[1], dcn=with_dcn) 175 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) 176 | with_dcn = True if self.stage_with_dcn[1] else False 177 | self.layer2 = self._make_layer(block, 256, 512, shortcut_type, stride=2, num_blocks=layers[1], 178 | dim_inner=self.dim_inner * 2, group=self.group, 179 | use_temp_convs=use_temp_convs_set[2], temp_strides=temp_strides_set[2], 180 | dcn=with_dcn) 181 | with_dcn = True if self.stage_with_dcn[2] else False 182 | self.layer3 = self._make_layer(block, 512, 1024, shortcut_type, stride=2, num_blocks=layers[2], 183 | dim_inner=self.dim_inner * 4, group=self.group, 184 | use_temp_convs=use_temp_convs_set[3], temp_strides=temp_strides_set[3], 185 | dcn=with_dcn) 186 | with_dcn = True if self.stage_with_dcn[3] else False 187 | self.layer4 = self._make_layer(block, 1024, 2048, shortcut_type, stride=1, num_blocks=layers[3], 188 | dim_inner=self.dim_inner * 8, group=self.group, 189 | use_temp_convs=use_temp_convs_set[4], temp_strides=temp_strides_set[4], 190 | dcn=with_dcn) 191 | last_duration = int(math.ceil(sample_duration / 2)) # int(math.ceil(sample_duration / 8)) 192 | last_size = int(math.ceil(sample_size / 16)) 193 | # self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) #nn.AdaptiveAvgPool3d((1, 1, 1)) # 194 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 195 | self.dropout = torch.nn.Dropout(p=0.5) 196 | self.classifier = nn.Linear(2048, num_classes) 197 | 198 | for m in self.modules(): 199 | # if isinstance(m, nn.Conv3d): 200 | # m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 201 | # elif isinstance(m,nn.Linear): 202 | # m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') 203 | # elif 204 | if isinstance(m, nn.BatchNorm3d): 205 | m.weight.data.fill_(1) 206 | m.bias.data.zero_() 207 | 208 | def _make_layer(self, block, dim_in, dim_out, shortcut_type, stride, num_blocks, dim_inner=None, group=None, 209 | use_temp_convs=None, temp_strides=None, dcn=False): 210 | if use_temp_convs is None: 211 | use_temp_convs = np.zeros(num_blocks).astype(int) 212 | if temp_strides is None: 213 | temp_strides = np.ones(num_blocks).astype(int) 214 | if len(use_temp_convs) < num_blocks: 215 | for _ in range(num_blocks - len(use_temp_convs)): 216 | use_temp_convs.append(0) 217 | temp_strides.append(1) 218 | layers = [] 219 | for idx in range(num_blocks): 220 | block_stride = 2 if (idx == 0 and stride == 2) else 1 221 | 222 | layers.append( 223 | block(dim_in, dim_out, block_stride, dim_inner, group, use_temp_convs[idx], temp_strides[idx], dcn)) 224 | dim_in = dim_out 225 | return nn.Sequential(*layers) 226 | 227 | def forward_single(self, x): 228 | x = self.conv1(x) 229 | 230 | x = self.bn1(x) 231 | x = self.relu(x) 232 | x = self.maxpool1(x) 233 | 234 | x = self.layer1(x) 235 | x = self.maxpool2(x) 236 | x = self.layer2(x) 237 | 238 | x = self.layer3(x) 239 | features = self.layer4(x) 240 | 241 | x = self.avgpool(features) 242 | 243 | y = x 244 | # x = x.view(x.size(0), -1) 245 | # x = self.dropout(x) 246 | 247 | # y = self.classifier(x) 248 | if self.extract_features: 249 | return y, features 250 | else: 251 | return y 252 | 253 | def forward_multi(self, x): 254 | clip_preds = [] 255 | # import ipdb;ipdb.set_trace() 256 | for clip_idx in range(x.shape[1]): # B, 10, 3, 3, 32, 224, 224 257 | spatial_crops = [] 258 | for crop_idx in range(x.shape[2]): 259 | clip = x[:, clip_idx, crop_idx] 260 | clip = self.forward_single(clip) 261 | spatial_crops.append(clip) 262 | spatial_crops = torch.stack(spatial_crops, 1).mean(1) # (B, 400) 263 | clip_preds.append(spatial_crops) 264 | clip_preds = torch.stack(clip_preds, 1).mean(1) # (B, 400) 265 | return clip_preds 266 | 267 | def forward(self, x): 268 | # pdb.set_trace() 269 | # x: BT x 3 x H x W -> B x T x 3 x H x W 270 | # pdb.set_trace() 271 | x = x.view([args.batch_size, args.num_segments, x.shape[-3], x.shape[-2], x.shape[-1]]) 272 | x = x.permute([0,2,1,3,4]) 273 | # 5D tensor == single clip 274 | if x.dim() == 5: 275 | pred = self.forward_single(x) 276 | 277 | # 7D tensor == 3 crops/10 clips 278 | elif x.dim() == 7: 279 | pred = self.forward_multi(x) 280 | 281 | # loss_dict = {} 282 | # if 'label' in batch: 283 | # loss = F.cross_entropy(pred, batch['label'], reduction='none') 284 | # loss_dict = {'clf': loss} 285 | 286 | return pred 287 | 288 | 289 | def get_fine_tuning_parameters(model, ft_begin_index): 290 | if ft_begin_index == 0: 291 | return model.parameters() 292 | 293 | ft_module_names = [] 294 | for i in range(ft_begin_index, 5): 295 | ft_module_names.append('layer{}'.format(i)) 296 | ft_module_names.append('fc') 297 | # import ipdb;ipdb.set_trace() 298 | parameters = [] 299 | for k, v in model.named_parameters(): 300 | for ft_module in ft_module_names: 301 | if ft_module in k: 302 | parameters.append({'params': v}) 303 | break 304 | else: 305 | parameters.append({'params': v, 'lr': 0.0}) 306 | 307 | return parameters 308 | 309 | 310 | def obtain_arc(arc_type): 311 | # c2d, ResNet50 312 | if arc_type == 1: 313 | use_temp_convs_1 = [0] 314 | temp_strides_1 = [2] 315 | use_temp_convs_2 = [0, 0, 0] 316 | temp_strides_2 = [1, 1, 1] 317 | use_temp_convs_3 = [0, 0, 0, 0] 318 | temp_strides_3 = [1, 1, 1, 1] 319 | use_temp_convs_4 = [0, ] * 6 320 | temp_strides_4 = [1, ] * 6 321 | use_temp_convs_5 = [0, 0, 0] 322 | temp_strides_5 = [1, 1, 1] 323 | 324 | # i3d, ResNet50 325 | if arc_type == 2: 326 | use_temp_convs_1 = [2] 327 | temp_strides_1 = [1] 328 | use_temp_convs_2 = [1, 1, 1] 329 | temp_strides_2 = [1, 1, 1] 330 | use_temp_convs_3 = [1, 0, 1, 0] 331 | temp_strides_3 = [1, 1, 1, 1] 332 | use_temp_convs_4 = [1, 0, 1, 0, 1, 0] 333 | temp_strides_4 = [1, 1, 1, 1, 1, 1] 334 | use_temp_convs_5 = [0, 1, 0] 335 | temp_strides_5 = [1, 1, 1] 336 | 337 | # c2d, ResNet101 338 | if arc_type == 3: 339 | use_temp_convs_1 = [0] 340 | temp_strides_1 = [2] 341 | use_temp_convs_2 = [0, 0, 0] 342 | temp_strides_2 = [1, 1, 1] 343 | use_temp_convs_3 = [0, 0, 0, 0] 344 | temp_strides_3 = [1, 1, 1, 1] 345 | use_temp_convs_4 = [0, ] * 23 346 | temp_strides_4 = [1, ] * 23 347 | use_temp_convs_5 = [0, 0, 0] 348 | temp_strides_5 = [1, 1, 1] 349 | 350 | # i3d, ResNet101 351 | if arc_type == 4: 352 | use_temp_convs_1 = [2] 353 | temp_strides_1 = [2] 354 | use_temp_convs_2 = [1, 1, 1] 355 | temp_strides_2 = [1, 1, 1] 356 | use_temp_convs_3 = [1, 0, 1, 0] 357 | temp_strides_3 = [1, 1, 1, 1] 358 | use_temp_convs_4 = [] 359 | for i in range(23): 360 | if i % 2 == 0: 361 | use_temp_convs_4.append(1) 362 | else: 363 | use_temp_convs_4.append(0) 364 | 365 | temp_strides_4 = [1, ] * 23 366 | use_temp_convs_5 = [0, 1, 0] 367 | temp_strides_5 = [1, 1, 1] 368 | 369 | use_temp_convs_set = [use_temp_convs_1, use_temp_convs_2, use_temp_convs_3, use_temp_convs_4, use_temp_convs_5] 370 | temp_strides_set = [temp_strides_1, temp_strides_2, temp_strides_3, temp_strides_4, temp_strides_5] 371 | 372 | return use_temp_convs_set, temp_strides_set 373 | 374 | 375 | def resnet10(**kwargs): 376 | """Constructs a ResNet-18 model. 377 | """ 378 | use_temp_convs_set = [] 379 | temp_strides_set = [] 380 | model = ResNet(BasicBlock, [1, 1, 1, 1], use_temp_convs_set, temp_strides_set, **kwargs) 381 | return model 382 | 383 | 384 | def resnet18(**kwargs): 385 | """Constructs a ResNet-18 model. 386 | """ 387 | use_temp_convs_set = [] 388 | temp_strides_set = [] 389 | model = ResNet(BasicBlock, [2, 2, 2, 2], use_temp_convs_set, temp_strides_set, **kwargs) 390 | return model 391 | 392 | 393 | def resnet34(**kwargs): 394 | """Constructs a ResNet-34 model. 395 | """ 396 | use_temp_convs_set = [] 397 | temp_strides_set = [] 398 | model = ResNet(BasicBlock, [3, 4, 6, 3], use_temp_convs_set, temp_strides_set, **kwargs) 399 | return model 400 | 401 | 402 | def resnet50(extract_features, **kwargs): 403 | """Constructs a ResNet-50 model. 404 | """ 405 | use_temp_convs_set, temp_strides_set = obtain_arc(2) 406 | model = ResNet(Bottleneck, [3, 4, 6, 3], use_temp_convs_set, temp_strides_set, 407 | extract_features=extract_features, **kwargs) 408 | return model 409 | 410 | 411 | def resnet101(**kwargs): 412 | """Constructs a ResNet-101 model. 413 | """ 414 | use_temp_convs_set, temp_strides_set = obtain_arc(4) 415 | model = ResNet(Bottleneck, [3, 4, 23, 3], use_temp_convs_set, temp_strides_set, **kwargs) 416 | return model 417 | 418 | 419 | def resnet152(**kwargs): 420 | """Constructs a ResNet-101 model. 421 | """ 422 | use_temp_convs_set = [] 423 | temp_strides_set = [] 424 | model = ResNet(Bottleneck, [3, 8, 36, 3], use_temp_convs_set, temp_strides_set, **kwargs) 425 | return model 426 | 427 | 428 | def resnet200(**kwargs): 429 | """Constructs a ResNet-101 model. 430 | """ 431 | use_temp_convs_set = [] 432 | temp_strides_set = [] 433 | model = ResNet(Bottleneck, [3, 24, 36, 3], use_temp_convs_set, temp_strides_set, **kwargs) 434 | return model 435 | 436 | 437 | def Net(num_classes, extract_features=False, loss_type='softmax', 438 | weights=None, freeze_all_but_cls=False): 439 | net = globals()['resnet' + str(50)]( 440 | num_classes=num_classes, 441 | sample_size=50, 442 | sample_duration=32, 443 | extract_features=extract_features, 444 | loss_type=loss_type, 445 | ) 446 | 447 | if weights is not None: 448 | kinetics_weights = torch.load(weights)['state_dict'] 449 | print("Found weights in {}.".format(weights)) 450 | cls_name = 'fc' 451 | else: 452 | kinetics_weights = torch.load('kinetics-res50.pth') 453 | cls_name = 'fc' 454 | print('\n Restoring Kintetics \n') 455 | 456 | new_weights = {} 457 | for k, v in kinetics_weights.items(): 458 | if not k.startswith('module.' + cls_name): 459 | new_weights[k.replace('module.', '')] = v 460 | else: 461 | print(f"!!! Smt wrong with restore {k}") 462 | net.load_state_dict(new_weights, strict=False) 463 | 464 | if freeze_all_but_cls: 465 | for name, par in net.named_parameters(): 466 | if not name.startswith('classifier'): 467 | par.requires_grad = False 468 | return net 469 | -------------------------------------------------------------------------------- /ops/resnet2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | # from .utils import load_state_dict_from_url 5 | from typing import Type, Any, Callable, Union, List, Optional 6 | import pdb 7 | try: 8 | from torch.hub import load_state_dict_from_url 9 | except ImportError: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | 12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 13 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 14 | 'wide_resnet50_2', 'wide_resnet101_2'] 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 27 | } 28 | 29 | 30 | 31 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 32 | """3x3 convolution with padding""" 33 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 34 | padding=dilation, groups=groups, bias=False, dilation=dilation) 35 | 36 | 37 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 38 | """1x1 convolution""" 39 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 40 | 41 | 42 | class BasicBlock(nn.Module): 43 | expansion: int = 1 44 | 45 | def __init__( 46 | self, 47 | inplanes: int, 48 | planes: int, 49 | stride: int = 1, 50 | downsample: Optional[nn.Module] = None, 51 | groups: int = 1, 52 | base_width: int = 64, 53 | dilation: int = 1, 54 | norm_layer: Optional[Callable[..., nn.Module]] = None 55 | ) -> None: 56 | super(BasicBlock, self).__init__() 57 | if norm_layer is None: 58 | norm_layer = nn.BatchNorm2d 59 | if groups != 1 or base_width != 64: 60 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 61 | if dilation > 1: 62 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 63 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 64 | self.conv1 = conv3x3(inplanes, planes, stride) 65 | self.bn1 = norm_layer(planes) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.conv2 = conv3x3(planes, planes) 68 | self.bn2 = norm_layer(planes) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x: Tensor) -> Tensor: 73 | identity = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | 82 | if self.downsample is not None: 83 | identity = self.downsample(x) 84 | 85 | out += identity 86 | out = self.relu(out) 87 | 88 | return out 89 | 90 | 91 | class Bottleneck(nn.Module): 92 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 93 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 94 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 95 | # This variant is also known as ResNet V1.5 and improves accuracy according to 96 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 97 | 98 | expansion: int = 4 99 | 100 | def __init__( 101 | self, 102 | inplanes: int, 103 | planes: int, 104 | stride: int = 1, 105 | downsample: Optional[nn.Module] = None, 106 | groups: int = 1, 107 | base_width: int = 64, 108 | dilation: int = 1, 109 | norm_layer: Optional[Callable[..., nn.Module]] = None 110 | ) -> None: 111 | super(Bottleneck, self).__init__() 112 | if norm_layer is None: 113 | norm_layer = nn.BatchNorm2d 114 | width = int(planes * (base_width / 64.)) * groups 115 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 116 | self.conv1 = conv1x1(inplanes, width) 117 | self.bn1 = norm_layer(width) 118 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 119 | self.bn2 = norm_layer(width) 120 | self.conv3 = conv1x1(width, planes * self.expansion) 121 | self.bn3 = norm_layer(planes * self.expansion) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.downsample = downsample 124 | self.stride = stride 125 | 126 | def forward(self, x: Tensor) -> Tensor: 127 | identity = x 128 | 129 | out = self.conv1(x) 130 | out = self.bn1(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv2(out) 134 | out = self.bn2(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv3(out) 138 | out = self.bn3(out) 139 | 140 | if self.downsample is not None: 141 | identity = self.downsample(x) 142 | 143 | out += identity 144 | out = self.relu(out) 145 | 146 | return out 147 | 148 | 149 | class ResNet(nn.Module): 150 | 151 | def __init__( 152 | self, 153 | block: Type[Union[BasicBlock, Bottleneck]], 154 | layers: List[int], 155 | num_classes: int = 1000, 156 | zero_init_residual: bool = False, 157 | groups: int = 1, 158 | width_per_group: int = 64, 159 | replace_stride_with_dilation: Optional[List[bool]] = None, 160 | norm_layer: Optional[Callable[..., nn.Module]] = None 161 | ) -> None: 162 | super(ResNet, self).__init__() 163 | if norm_layer is None: 164 | norm_layer = nn.BatchNorm2d 165 | self._norm_layer = norm_layer 166 | 167 | self.inplanes = 64 168 | self.dilation = 1 169 | if replace_stride_with_dilation is None: 170 | # each element in the tuple indicates if we should replace 171 | # the 2x2 stride with a dilated convolution instead 172 | replace_stride_with_dilation = [False, False, False] 173 | if len(replace_stride_with_dilation) != 3: 174 | raise ValueError("replace_stride_with_dilation should be None " 175 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 176 | self.groups = groups 177 | self.base_width = width_per_group 178 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 179 | bias=False) 180 | self.bn1 = norm_layer(self.inplanes) 181 | self.relu = nn.ReLU(inplace=True) 182 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 183 | self.layer1 = self._make_layer(block, 64, layers[0]) 184 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 185 | dilate=replace_stride_with_dilation[0]) 186 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 187 | dilate=replace_stride_with_dilation[1]) 188 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 189 | dilate=replace_stride_with_dilation[2]) 190 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 191 | self.fc = nn.Linear(512 * block.expansion, num_classes) 192 | 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 196 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 197 | nn.init.constant_(m.weight, 1) 198 | nn.init.constant_(m.bias, 0) 199 | 200 | # Zero-initialize the last BN in each residual branch, 201 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 202 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 203 | if zero_init_residual: 204 | for m in self.modules(): 205 | if isinstance(m, Bottleneck): 206 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 207 | elif isinstance(m, BasicBlock): 208 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 209 | 210 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 211 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 212 | norm_layer = self._norm_layer 213 | downsample = None 214 | previous_dilation = self.dilation 215 | if dilate: 216 | self.dilation *= stride 217 | stride = 1 218 | if stride != 1 or self.inplanes != planes * block.expansion: 219 | downsample = nn.Sequential( 220 | conv1x1(self.inplanes, planes * block.expansion, stride), 221 | norm_layer(planes * block.expansion), 222 | ) 223 | 224 | layers = [] 225 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 226 | self.base_width, previous_dilation, norm_layer)) 227 | self.inplanes = planes * block.expansion 228 | for _ in range(1, blocks): 229 | layers.append(block(self.inplanes, planes, groups=self.groups, 230 | base_width=self.base_width, dilation=self.dilation, 231 | norm_layer=norm_layer)) 232 | 233 | return nn.Sequential(*layers) 234 | 235 | def _forward_impl(self, x) -> Tensor: 236 | # See note [TorchScript super()] 237 | x = self.conv1(x) 238 | x = self.bn1(x) 239 | x = self.relu(x) 240 | x = self.maxpool(x) 241 | x = self.layer1(x) 242 | x = self.layer2(x) 243 | x = self.layer3(x) 244 | x = self.layer4(x) 245 | x = self.avgpool(x) 246 | x = torch.flatten(x, 1) 247 | x = self.fc(x) 248 | 249 | return x 250 | 251 | def forward(self, x: Tensor) -> Tensor: 252 | return self._forward_impl(x) 253 | 254 | 255 | def _resnet( 256 | arch: str, 257 | block: Type[Union[BasicBlock, Bottleneck]], 258 | layers: List[int], 259 | pretrained: bool, 260 | progress: bool, 261 | **kwargs: Any 262 | ) -> ResNet: 263 | model = ResNet(block, layers, **kwargs) 264 | if pretrained: 265 | ckpt_dict = load_state_dict_from_url(model_urls[arch], 266 | progress=progress) 267 | 268 | model_dict = model.state_dict() 269 | restore_dict = {} 270 | for k, v in ckpt_dict.items(): 271 | restore_dict[k] = v 272 | 273 | for k, v in model_dict.items(): 274 | if 'ignore' in k: 275 | old_name = k.replace('ignore.', '') 276 | ckpt_val = ckpt_dict[old_name] 277 | del restore_dict[old_name] 278 | restore_dict[k] = ckpt_val 279 | 280 | model.load_state_dict(restore_dict) 281 | return model 282 | 283 | 284 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 285 | r"""ResNet-18 model from 286 | `"Deep Residual Learning for Image Recognition" `_ 287 | 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | progress (bool): If True, displays a progress bar of the download to stderr 291 | """ 292 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 293 | **kwargs) 294 | 295 | 296 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 297 | r"""ResNet-34 model from 298 | `"Deep Residual Learning for Image Recognition" `_ 299 | 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 305 | **kwargs) 306 | 307 | 308 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 309 | r"""ResNet-50 model from 310 | `"Deep Residual Learning for Image Recognition" `_ 311 | 312 | Args: 313 | pretrained (bool): If True, returns a model pre-trained on ImageNet 314 | progress (bool): If True, displays a progress bar of the download to stderr 315 | """ 316 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 317 | **kwargs) 318 | 319 | 320 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 321 | r"""ResNet-101 model from 322 | `"Deep Residual Learning for Image Recognition" `_ 323 | 324 | Args: 325 | pretrained (bool): If True, returns a model pre-trained on ImageNet 326 | progress (bool): If True, displays a progress bar of the download to stderr 327 | """ 328 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 329 | **kwargs) 330 | 331 | 332 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 333 | r"""ResNet-152 model from 334 | `"Deep Residual Learning for Image Recognition" `_ 335 | 336 | Args: 337 | pretrained (bool): If True, returns a model pre-trained on ImageNet 338 | progress (bool): If True, displays a progress bar of the download to stderr 339 | """ 340 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 341 | **kwargs) 342 | 343 | 344 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 345 | r"""ResNeXt-50 32x4d model from 346 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 347 | 348 | Args: 349 | pretrained (bool): If True, returns a model pre-trained on ImageNet 350 | progress (bool): If True, displays a progress bar of the download to stderr 351 | """ 352 | kwargs['groups'] = 32 353 | kwargs['width_per_group'] = 4 354 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 355 | pretrained, progress, **kwargs) 356 | 357 | 358 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 359 | r"""ResNeXt-101 32x8d model from 360 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 361 | 362 | Args: 363 | pretrained (bool): If True, returns a model pre-trained on ImageNet 364 | progress (bool): If True, displays a progress bar of the download to stderr 365 | """ 366 | kwargs['groups'] = 32 367 | kwargs['width_per_group'] = 8 368 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 369 | pretrained, progress, **kwargs) 370 | 371 | 372 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 373 | r"""Wide ResNet-50-2 model from 374 | `"Wide Residual Networks" `_ 375 | 376 | The model is the same as ResNet except for the bottleneck number of channels 377 | which is twice larger in every block. The number of channels in outer 1x1 378 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 379 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 380 | 381 | Args: 382 | pretrained (bool): If True, returns a model pre-trained on ImageNet 383 | progress (bool): If True, displays a progress bar of the download to stderr 384 | """ 385 | kwargs['width_per_group'] = 64 * 2 386 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 387 | pretrained, progress, **kwargs) 388 | 389 | 390 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 391 | r"""Wide ResNet-101-2 model from 392 | `"Wide Residual Networks" `_ 393 | 394 | The model is the same as ResNet except for the bottleneck number of channels 395 | which is twice larger in every block. The number of channels in outer 1x1 396 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 397 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 398 | 399 | Args: 400 | pretrained (bool): If True, returns a model pre-trained on ImageNet 401 | progress (bool): If True, displays a progress bar of the download to stderr 402 | """ 403 | kwargs['width_per_group'] = 64 * 2 404 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 405 | pretrained, progress, **kwargs) 406 | 407 | 408 | 409 | # def ResNet2D(num_classes, extract_features=False, loss_type='softmax', 410 | # weights=None, freeze_all_but_cls=False): 411 | # net = globals()['resnet' + str(50)]( 412 | # num_classes=num_classes, 413 | # sample_size=50, 414 | # sample_duration=32, 415 | # extract_features=extract_features, 416 | # loss_type=loss_type, 417 | # ) 418 | 419 | # if weights is not None: 420 | # kinetics_weights = torch.load(weights)['state_dict'] 421 | # print("Found weights in {}.".format(weights)) 422 | # cls_name = 'fc' 423 | # else: 424 | # kinetics_weights = torch.load(model_urls['resnet50']) 425 | # cls_name = 'fc' 426 | # print('\n Restoring Kintetics \n') 427 | 428 | # new_weights = {} 429 | # for k, v in kinetics_weights.items(): 430 | # if not k.startswith('module.' + cls_name): 431 | # new_weights[k.replace('module.', '')] = v 432 | # else: 433 | # print(f"!!! Smt wrong with restore {k}") 434 | # net.load_state_dict(new_weights, strict=False) 435 | 436 | # if freeze_all_but_cls: 437 | # for name, par in net.named_parameters(): 438 | # if not name.startswith('classifier'): 439 | # par.requires_grad = False 440 | # return net 441 | -------------------------------------------------------------------------------- /ops/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pdb 3 | import os 4 | import matplotlib.pyplot as plt 5 | import matplotlib.colors as colors 6 | 7 | from PIL import Image 8 | import PIL 9 | from moviepy.editor import ImageSequenceClip 10 | import matplotlib 11 | matplotlib.use("Agg") 12 | global args 13 | from opts import parse_args 14 | args = parse_args() 15 | 16 | 17 | def offsets_to_boxes(offsets, frame_size=224): 18 | pred_boxes_dict = {} 19 | places = args.place_graph.replace('layer','').split('_') 20 | all_positions = [] 21 | for place in places: 22 | layer = int(place.split('.')[0]) 23 | block = int(place.split('.')[1]) 24 | pred_position = f'layer{layer}_{block}' 25 | lb = (layer,block) 26 | all_positions.append(lb) 27 | for (layer,block) in all_positions: 28 | pred_position = f'layer{layer}_block{block}' 29 | num_layer = int(layer) 30 | pred_offset = offsets[pred_position].squeeze().detach().cpu().numpy() 31 | pred_offsets = pred_offset / args.graph_params[num_layer]['H'] * frame_size #224.0 32 | regions_h = pred_offsets[:,:,0] 33 | regions_w = pred_offsets[:,:,1] 34 | kernel_center_h = pred_offsets[:,:,2] 35 | kernel_center_w = pred_offsets[:,:,3] 36 | 37 | y1 = np.minimum(frame_size-1,np.maximum(0,kernel_center_h - regions_h)) 38 | x1 = np.minimum(frame_size-1,np.maximum(0,kernel_center_w - regions_w)) 39 | y2 = np.minimum(frame_size-1, kernel_center_h + regions_h) 40 | x2 = np.minimum(frame_size-1, kernel_center_w + regions_w) 41 | 42 | pred_boxes = np.stack([y1, x1, y2, x2], axis=-1) 43 | pred_boxes_dict[f'layer{layer}_{block}'] = pred_boxes 44 | return pred_boxes_dict 45 | 46 | 47 | 48 | def draw_box(frame, box, color=[1,1,1]): 49 | H = args.input_size 50 | (h,w,dh,dw) = box 51 | 52 | left = max(0,int(dw-w)) 53 | right = min(int(dw+w),H-1) 54 | up = max(int(dh-h),0) 55 | down = min(int(dh+h), H-1) 56 | 57 | # print(left,right,up,down) 58 | frame[left,up:down] = color#(0,0,1) 59 | 60 | frame[right,up:down] = color#(0,0,1) 61 | frame[left:right, up] = color#(0,0,1) 62 | frame[left:right, down] = color#(0,0,1) 63 | return frame 64 | 65 | def draw_box_gt(frame, target_boxes, colors=[1,1,1], partial=False, line_width=1): 66 | if len(colors) < target_boxes.shape[0]: 67 | colors = colors * target_boxes.shape[0] 68 | for b in range(target_boxes.shape[0]): 69 | if partial and b % 3 != 0: 70 | continue 71 | box = target_boxes[b] 72 | 73 | top_h = int(box[0]) 74 | left_w = int(box[1]) 75 | bot_h = int(box[2]) 76 | right_w = int(box[3]) 77 | frame[top_h:bot_h, left_w:left_w + line_width] = colors[b] 78 | frame[top_h:bot_h, right_w:right_w + line_width] = colors[b] 79 | 80 | frame[top_h: top_h + line_width, left_w : right_w] = colors[b] 81 | frame[bot_h : bot_h + line_width, left_w : right_w] = colors[b] 82 | 83 | return frame 84 | 85 | def save_grad_cam(input, grad_cams, folder, name='initial'): 86 | 87 | import cv2 88 | folder_dir = folder + f'/viz_grad_cam_{name}/' 89 | if not os.path.exists(folder_dir): 90 | os.makedirs(folder_dir) 91 | 92 | video_input_val = input.detach().cpu().numpy() 93 | video_input_val = np.reshape(video_input_val, [args.batch_size, -1, 3 , video_input_val.shape[3],video_input_val.shape[4]]) 94 | video_input_val = np.transpose(video_input_val, [0,1,3,4,2]) 95 | for ind_grad, gr in enumerate(grad_cams): 96 | name = gr[0] 97 | grad_cam = gr[1] 98 | grad_cam = grad_cam.view((-1,16) + grad_cam.shape[1:]) 99 | num_saved_videos = 2 100 | for video_idx in range(min(num_saved_videos, args.batch_size)): 101 | max_t = 16 102 | all_frames = [] 103 | for tt in range(max_t): 104 | if 'resnet50_smt_else' in args.arch: 105 | real_tt = 2*tt 106 | else: 107 | real_tt = tt 108 | 109 | frame = video_input_val[video_idx,real_tt,:,:] 110 | if args.modality == 'RGB': 111 | frame = frame * 57.375 + 114.75 112 | 113 | grad_c = grad_cam[video_idx, tt].cpu().detach().numpy() 114 | 115 | grad_min = grad_c.min() 116 | grad_max = (grad_c - grad_min).max() 117 | grad_c = (grad_c - grad_min) / grad_max 118 | grad_c = (grad_c * 255).astype(np.uint8) 119 | grad_c = np.array(Image.fromarray(grad_c).resize((frame.shape[0],frame.shape[1]), resample=PIL.Image.BILINEAR)) 120 | 121 | heatmap = cv2.applyColorMap(grad_c, cv2.COLORMAP_JET) 122 | cam = np.float32(heatmap) + np.float32(frame) 123 | cam = cam / cam.max() 124 | 125 | combined_img =np.concatenate((np.uint8(frame), np.uint8(255 * cam)), axis=1) 126 | cv2.imwrite(os.path.join(folder_dir, f'video_{video_idx}_frame_{tt}_{ind_grad}_grad_{name}.jpg'), combined_img) 127 | 128 | 129 | def save_mean_kernels(all_kernels,epoch=0,folder=''): 130 | places = args.place_graph.replace('layer','').split('_') 131 | placement_all_models = [] 132 | for place in places: 133 | layer = int(place.split('.')[0]) 134 | block = int(place.split('.')[1]) 135 | placement_all_models.append(f'layer{layer}_block{block}') 136 | 137 | for placement in placement_all_models: 138 | kernel_val = all_kernels[placement] 139 | 140 | kernel_val = np.reshape(kernel_val, [args.num_segments , 3,3,kernel_val.shape[-2],kernel_val.shape[-1]]) 141 | for tt in range(args.num_segments): 142 | f, axarr = plt.subplots(3,3) 143 | for ii in range(3): 144 | for jj in range(3): 145 | curent_kernel = kernel_val[tt][ii][jj] 146 | curent_kernel_max = curent_kernel.max() 147 | curent_kernel = curent_kernel / curent_kernel_max 148 | curent_kernel = (curent_kernel * 255).astype(np.uint8) 149 | curent_kernel = np.array(Image.fromarray(curent_kernel).resize((224,224), resample=PIL.Image.BILINEAR)) 150 | curent_kernel = curent_kernel.astype(np.float32) / 255.0 151 | 152 | axarr[ii][jj].imshow(curent_kernel) 153 | 154 | placement = placement.replace('block','') 155 | folder_dir = folder + f'/viz_kernel_{placement}/' 156 | if not os.path.exists(folder_dir): 157 | os.makedirs(folder_dir) 158 | plt.savefig(f'{folder_dir}/mean_kernels_time_epoch_{epoch}_time_{tt}.png') 159 | 160 | # offsets: predicted offsets 161 | # target_offset: distilled offsets 162 | # input: B x TC x 224 x 224 163 | def save_kernels(input, interm_feats, folder, name='initial', target_offset=None, target_boxes_val=None, predicted_offsets=None): 164 | # predicted offsets is a dict. detach when used 165 | predicted_boxes_dict = offsets_to_boxes(predicted_offsets,frame_size=input.shape[-1]) 166 | 167 | places = args.place_graph.replace('layer','').split('_') 168 | placement_all_models = []# ['layer4', 'layer3','layer2','layer1'] 169 | for place in places: 170 | layer = int(place.split('.')[0]) 171 | block = int(place.split('.')[1]) 172 | placement_all_models.append(f'layer{layer}_{block}') 173 | input_ch = 3 174 | if args.modality == 'gray': 175 | input_ch = 1 176 | for placement in placement_all_models: 177 | predicted_boxes = predicted_boxes_dict[placement] 178 | predicted_boxes = predicted_boxes.reshape(input.shape[0], args.num_segments, predicted_boxes.shape[1], predicted_boxes.shape[2]) 179 | 180 | kernel_val = interm_feats[placement+'_kernels'].detach().cpu().numpy() 181 | video_input_val = input.detach().cpu().numpy() 182 | 183 | video_input_val = np.reshape(video_input_val, [args.batch_size * args.test_crops, -1, input_ch , video_input_val.shape[-2],video_input_val.shape[-1]]) 184 | video_input_val = np.transpose(video_input_val, [0,1,3,4,2]) 185 | if args.modality == 'gray': 186 | tmp_zeros = -1 * np.ones((video_input_val.shape[0], video_input_val.shape[1], video_input_val.shape[2], video_input_val.shape[3], 3)) 187 | tmp_zeros[:,:,:,:,0] = video_input_val[:,:,:,:,0] 188 | video_input_val = tmp_zeros 189 | 190 | 191 | kernel_val = np.reshape(kernel_val, [args.batch_size * args.test_crops, -1 , kernel_val.shape[1],kernel_val.shape[2],kernel_val.shape[3],kernel_val.shape[4]]) 192 | 193 | folder_dir = folder + f'/viz_kernel_{placement}/' 194 | if not os.path.exists(folder_dir): 195 | os.makedirs(folder_dir) 196 | 197 | folder_dir = folder_dir + f'/{name}/' 198 | if not os.path.exists(folder_dir): 199 | os.makedirs(folder_dir) 200 | tt = 2 201 | 202 | num_rows = kernel_val.shape[2] 203 | save_individual_frames = True 204 | 205 | num_saved_videos = 2 206 | 207 | for video_idx in range(min(num_saved_videos, args.test_crops * args.batch_size)): 208 | max_t = kernel_val.shape[1] 209 | all_frames = [] 210 | for tt in range(max_t): 211 | if 'resnet50_smt_else' in args.arch: 212 | real_tt = 2*tt 213 | else: 214 | real_tt = tt 215 | frame = video_input_val[video_idx,real_tt,:,:] 216 | if args.modality == 'RGB': 217 | frame = frame * 57.375 + 114.75 218 | frame = frame / 255.0 219 | else: 220 | frame = (frame + 1.0) / 2.0 221 | 222 | frame_hsv = colors.rgb_to_hsv(frame) 223 | 224 | rgb = np.array([ [1, 0, 0 ], [0, 1, 0 ], [0, 0, 1 ] ]) 225 | hsv = colors.rgb_to_hsv(rgb) 226 | 227 | f, axarr = plt.subplots(num_rows,num_rows) 228 | N = kernel_val.shape[2] * kernel_val.shape[3] 229 | 230 | max_s = frame_hsv[:,:,1].max() 231 | max_v = frame_hsv[:,:,2].mean() + 0.85 * (frame_hsv[:,:,2].max() - frame_hsv[:,:,2].mean()) # poate de pus mean 232 | HSV_tuples = [(x*1.0/N, max_s, max_v) for x in range(N)] 233 | 234 | color_kernels = np.zeros( frame.shape[:2] + (3,) ) 235 | 236 | cc = 0 237 | for ii in range(kernel_val.shape[2]): 238 | for jj in range(kernel_val.shape[3]): 239 | 240 | if save_individual_frames: 241 | axarr[ii][jj].imshow(kernel_val[video_idx][tt][ii][jj]) 242 | curent_kernel = kernel_val[video_idx][tt][ii][jj] 243 | curent_kernel_max = curent_kernel.max() 244 | curent_kernel = curent_kernel / curent_kernel_max 245 | curent_kernel = (curent_kernel * 255).astype(np.uint8) 246 | curent_kernel = np.array(Image.fromarray(curent_kernel).resize((frame.shape[0],frame.shape[1]), resample=PIL.Image.BILINEAR)) 247 | curent_kernel = curent_kernel.astype(np.float32) / 255.0 248 | 249 | mask = ((curent_kernel / curent_kernel.max()) > 0.3) 250 | 251 | curent_kernel = curent_kernel 252 | 253 | color_kernels[:,:,0] = (1.0 - mask) * color_kernels[:,:,0] + mask * HSV_tuples[cc][0] 254 | color_kernels[:,:,1] = (1.0 - mask) * color_kernels[:,:,1] + mask * HSV_tuples[cc][1] 255 | color_kernels[:,:,2] = (1.0 - mask) * color_kernels[:,:,2] + mask * curent_kernel * HSV_tuples[cc][2] 256 | 257 | cc += 1 258 | 259 | if frame.shape[2] == 1: 260 | # gray to rgb 261 | frame = np.tile(frame, [1,1,3]) 262 | 263 | 264 | if target_offset and 'layer1' not in placement: 265 | frame_distill = frame.copy() 266 | for ii in range(kernel_val.shape[2] * kernel_val.shape[3]): 267 | place = placement.replace('layer', '') 268 | layer = int(place.split('_')[0]) 269 | block = int(place.split('_')[1]) 270 | # TODO: un-hardcodat 224 271 | h,w,dh,dw = target_offset[f'layer{layer}_block{block}'][video_idx, tt,ii] / kernel_val.shape[-1] * 224 272 | frame_distill = draw_box(frame_distill, (h,w,dh,dw)) 273 | frame_distill = np.clip(frame_distill, 0.0, 1.0) 274 | rgb_color_kernels = colors.hsv_to_rgb(color_kernels) 275 | frame[:,:] += 1 * rgb_color_kernels#10 * kernel_sum 276 | frame = np.clip(frame, 0.0, 1.0) 277 | 278 | if save_individual_frames: 279 | plt.savefig(f'{folder_dir}/kernel_{video_idx}_{tt}.png') 280 | f, axarr = plt.subplots(1,1) 281 | axarr.imshow(frame) 282 | plt.savefig(f'{folder_dir}/frame_{video_idx}_{tt}.png') 283 | if target_offset and 'layer1' not in placement: 284 | f, axarr = plt.subplots(1,1) 285 | axarr.imshow(frame_distill) 286 | plt.savefig(f'{folder_dir}/z_distill_{video_idx}_{tt}.png') 287 | plt.close(f) 288 | 289 | 290 | frame = (frame / frame.max() * 255.0).astype(np.uint8) 291 | all_frames.append(frame) 292 | 293 | clip = ImageSequenceClip(list(all_frames), fps=3) 294 | clip.write_gif(f'{folder_dir}/video_{video_idx}.gif', fps=3) 295 | 296 | plt.close('all') 297 | 298 | 299 | def draw_frames(input, folder, ref_point, ref_dim, video_id, name=''): 300 | print(input.shape) 301 | video_input_val = input 302 | if name == 'crop': 303 | video_input_val = np.reshape(video_input_val, [-1, 3 , video_input_val.shape[1],video_input_val.shape[2]]) 304 | video_input_val = np.transpose(video_input_val, [0,2,3,1]) 305 | 306 | folder_dir = folder + f'/viz_kernel_tmp/' 307 | if not os.path.exists(folder_dir): 308 | os.makedirs(folder_dir) 309 | 310 | folder_dir = folder_dir + f'/' #'./viz_kernel_center/' 311 | if not os.path.exists(folder_dir): 312 | os.makedirs(folder_dir) 313 | tt = 2 314 | 315 | video_idx = int(video_id) 316 | for tt in range(16): 317 | frame = video_input_val[tt,:,:] 318 | if args.modality == 'RGB' and name == 'crop': 319 | frame = frame * 57.375 + 114.75 320 | frame = frame / 255.0 321 | elif name == 'crop': 322 | frame = (frame + 1.0) / 2.0 323 | else: 324 | frame = frame / 255.0 325 | 326 | frame_hsv = colors.rgb_to_hsv(frame) 327 | 328 | rgb = np.array([ [1, 0, 0 ], [0, 1, 0 ], [0, 0, 1 ] ]) 329 | 330 | frame = np.clip(frame, 0.0, 1.0) 331 | # draw center 332 | frame[ref_point[0]-3:ref_point[0]+3, ref_point[1]-3:ref_point[1]+3] = (1.0, 0.0, 0.0) 333 | # draw box 334 | left = ref_point[1] - ref_dim[1] 335 | right = ref_point[1] + ref_dim[1] 336 | up = ref_point[0] - ref_dim[0] 337 | down = ref_point[0] + ref_dim[0] 338 | frame[up-1:up+1,left:right] = (0,1,0) 339 | frame[down-1:down+1,left:right] = (0,1,0) 340 | frame[up:down,left-1:left+1] = (0,1,0) 341 | frame[up:down,right-1:right+1] = (0,1,0) 342 | 343 | f, axarr = plt.subplots(1,1) 344 | axarr.imshow(frame) 345 | plt.savefig(f'{folder_dir}/frame_{video_idx}_{tt}_{name}.png') 346 | plt.close(f) 347 | plt.close('all') 348 | 349 | 350 | 351 | def count_params(params, contains=[''], ignores=['nonenone']): 352 | total_params = 0 353 | for name, shape in params: 354 | # print(f'{name} shape: {shape}') 355 | ok = True 356 | for ignore_name in ignores: 357 | if ignore_name in name: 358 | ok = False 359 | if not ok: 360 | continue 361 | prod = 1 362 | for d in shape: 363 | prod *= d 364 | selected_param = False 365 | for c in contains: 366 | if c in name: 367 | selected_param = True 368 | if selected_param: 369 | total_params += prod 370 | return total_params 371 | def softmax(scores): 372 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 373 | return es / es.sum(axis=-1)[..., None] 374 | 375 | 376 | class LearnedParamChecker(): 377 | def __init__(self,model): 378 | self.model = model 379 | self.initial_params = self.save_initial_params() 380 | 381 | def save_initial_params(self): 382 | initial_params = {} 383 | for name, p in self.model.named_parameters(): 384 | initial_params[name] = p.detach().cpu().numpy() 385 | return initial_params 386 | 387 | def compare_current_initial_params(self): 388 | for name, current_p in self.model.named_parameters(): 389 | initial_p = self.initial_params[name] 390 | current_p = current_p.detach().cpu().numpy() 391 | diff = np.mean(np.abs(initial_p - current_p)) 392 | print(f'params: {name} mean change : {diff}') 393 | 394 | 395 | 396 | class AverageMeter(object): 397 | """Computes and stores the average and current value""" 398 | 399 | def __init__(self): 400 | self.reset() 401 | 402 | def reset(self): 403 | self.val = 0 404 | self.avg = 0 405 | self.sum = 0 406 | self.count = 0 407 | 408 | def update(self, val, n=1): 409 | self.val = val 410 | self.sum += val * n 411 | self.count += n 412 | self.avg = self.sum / self.count 413 | 414 | 415 | def accuracy(output, target, topk=(1,)): 416 | """Computes the precision@k for the specified values of k""" 417 | maxk = max(topk) 418 | batch_size = target.size(0) 419 | 420 | _, pred = output.topk(maxk, 1, True, True) 421 | pred = pred.t() 422 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 423 | 424 | res = [] 425 | for k in topk: 426 | # correct_k = correct[:k].view(-1).float().sum(0) 427 | correct_k = correct[:k].reshape(-1).float().sum(0) 428 | 429 | res.append(correct_k.mul_(100.0 / batch_size)) 430 | return res 431 | 432 | 433 | 434 | 435 | def adjust_kernel(pred_kernel_val, pred_boxes): 436 | # pred_kernel_val: (BT, 3, 3, 14, 14) 437 | # pred_boxes: (BT, 9, 4) ([left_h, left_w, righ_h, right_w])) 438 | 439 | pred_kernel_val = np.reshape(pred_kernel_val,(pred_kernel_val.shape[0], pred_kernel_val.shape[1]*pred_kernel_val.shape[2], pred_kernel_val.shape[3], pred_kernel_val.shape[4])) 440 | all_adjust_kernel_boxes = np.zeros_like(pred_boxes) 441 | 442 | time1 = time.time() 443 | for b in range(pred_kernel_val.shape[0]): 444 | for i in range(9): 445 | # kernel from 14x14 ->224x224 446 | curent_kernel = pred_kernel_val[b,i] / pred_kernel_val[b,i].max() 447 | curent_kernel = (curent_kernel * 255).astype(np.uint8) 448 | 449 | time5 = time.time() 450 | curent_kernel = np.array(Image.fromarray(curent_kernel).resize((224,224), resample=PIL.Image.BILINEAR)) 451 | time6 = time.time() 452 | # print(f'Resize time: {time6-time5}') 453 | curent_kernel = curent_kernel.astype(np.float32) / 255.0 454 | curent_kernel = curent_kernel / curent_kernel.sum() 455 | 456 | (left_h, left_w, right_h, right_w) = pred_boxes[b,i] 457 | h = right_h - left_h 458 | w = right_w - left_w 459 | # print(f'New frame h={h}, w={w}') 460 | prev_sums = curent_kernel.sum() 461 | prev_adjust_kernel_boxes = np.array([left_h, left_w, right_h, right_w]) 462 | prev_dx = 0.0 463 | prev_dy = 0.0 464 | time7 = time.time() 465 | for j in np.arange(0.01,1, 0.01): 466 | dx = j * (h/2) 467 | dy = j * (w/2) 468 | new_left_h = int(left_h+dx) 469 | new_left_w = int(left_w+dy) 470 | new_right_h = int(right_h-dx) 471 | new_right_w = int(right_w-dy) 472 | adjust_kernel_boxes = np.array([new_left_h, new_left_w, new_right_h, new_right_w]) 473 | crt_sums = curent_kernel[new_left_h:new_right_h+1, new_left_w:new_right_w+1].sum() 474 | if crt_sums < 0.90 and prev_sums >= 0.90: 475 | # print(prev_dx, prev_dy, prev_sums) 476 | # print(dx, dy, crt_sums) 477 | 478 | # print(f'[ {left_h}, {left_w}, {right_h}, {right_w}] -> [{new_left_h}, {new_left_w}, {new_right_h}, {new_right_w}]') 479 | all_adjust_kernel_boxes[b,i] = prev_adjust_kernel_boxes 480 | break 481 | 482 | prev_sums = crt_sums 483 | prev_adjust_kernel_boxes = np.array([new_left_h, new_left_w, new_right_h, new_right_w]) 484 | prev_dx = dx 485 | prev_dy = dy 486 | time8 = time.time() 487 | # print(f'find adjustment: {time8-time7}') 488 | time2 = time.time() 489 | print(f'Adjust kernel: {time2-time1}') 490 | return all_adjust_kernel_boxes 491 | --------------------------------------------------------------------------------