├── ops
│   ├── __init__.py
│   ├── basic_ops.py
│   ├── generate_scripts.py
│   ├── dataset_config.py
│   ├── non_local.py
│   ├── rstg.py
│   ├── models_config.py
│   ├── dataset.py
│   ├── temporal_shift.py
│   ├── dataset_syncMNIST.py
│   ├── nested_model.py
│   ├── models_utils.py
│   ├── resnet3d_xl.py
│   ├── resnet2d.py
│   └── utils.py
├── .gitignore
├── DyReg_GNN_arch.png
├── train_model.sh
├── tools
│   ├── gen_label_other.py
│   ├── vid2img_sthv2.py
│   ├── gen_label_sthv2.py
│   ├── gen_label_sthv1.py
│   └── vid2img_other.py
├── evaluate_model.sh
├── evaluate_model_multi_clip.sh
├── create_model.py
├── README.md
├── opts.py
├── LICENSE
└── test_models.py
/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from ops.basic_ops import *
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore all __pycache__ directories
2 | **__pycache__*
3 |
--------------------------------------------------------------------------------
/DyReg_GNN_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bit-ml/DyReg-GNN/HEAD/DyReg_GNN_arch.png
--------------------------------------------------------------------------------
/ops/basic_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Identity(torch.nn.Module):
5 | def forward(self, input):
6 | return input
7 |
8 |
9 | class SegmentConsensus(torch.nn.Module):
10 |
11 | def __init__(self, consensus_type, dim=1):
12 | super(SegmentConsensus, self).__init__()
13 | self.consensus_type = consensus_type
14 | self.dim = dim
15 | self.shape = None
16 |
17 | def forward(self, input_tensor):
18 | self.shape = input_tensor.size()
19 | if self.consensus_type == 'avg':
20 | output = input_tensor.mean(dim=self.dim, keepdim=True)
21 | elif self.consensus_type == 'identity':
22 | output = input_tensor
23 | else:
24 | output = None
25 |
26 | return output
27 |
28 |
29 | class ConsensusModule(torch.nn.Module):
30 |
31 | def __init__(self, consensus_type, dim=1):
32 | super(ConsensusModule, self).__init__()
33 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity'
34 | self.dim = dim
35 |
36 | def forward(self, input):
37 | return SegmentConsensus(self.consensus_type, self.dim)(input)
38 |
--------------------------------------------------------------------------------
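A minimal usage sketch for the consensus module above (the batch/segment/class sizes are illustrative assumptions; in the full pipeline the per-segment logits come from the TSN head):

```
import torch
from ops.basic_ops import ConsensusModule

logits = torch.randn(4, 8, 174)            # assumed shape: (batch, num_segments, num_classes)
consensus = ConsensusModule('avg', dim=1)  # average the per-segment predictions
video_logits = consensus(logits)           # -> shape (4, 1, 174), one prediction per video
print(video_logits.shape)
```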
/train_model.sh:
--------------------------------------------------------------------------------
1 | RAND=$((RANDOM))
2 |
3 | MODEL_DIR='./models/'$1
4 | LOG_NAME=$MODEL_DIR'/log_'$RAND
5 | mkdir $MODEL_DIR
6 |
7 | args="--print-freq=1 --name=$1
8 | --npb --offset_generator=big
9 | --init_skip_zero=False --warmup_validate=False
10 | --dynamic_regions=dyreg --init_regions=center
11 | --place_graph=layer2.2_layer3.4_layer4.1
12 | --ch_div=4 --graph_residual_type=norm
13 | --rnn_type=GRU --aggregation_type=dot --send_layers=1
14 | --use_rstg=True --rstg_skip_connection=False --remap_first=True
15 | --update_layers=0 --combine_by_sum=True --project_at_input=True
16 | --tmp_norm_skip_conn=False --bottleneck_graph=True
17 | --lr 0.001 --batch-size 2 --dataset=somethingv2 --modality=RGB --arch=resnet50 --num_segments 16 --gd 20
18 | --lr_type=step --lr_steps 20 30 40 --epochs 60 -j 16 --dropout=0.5 --consensus_type=avg --eval-freq=1
19 | --shift_div=8 --shift --shift_place=blockres --model_dir=$MODEL_DIR"
20 |
21 |
22 | echo $args > $MODEL_DIR'/args_'$RAND
23 |
24 | # run on CPU:
25 | # CUDA_VISIBLE_DEVICES="" python -u main_standard.py $args |& tee $LOG_NAME
26 |
27 | # run on GPU
28 | CUDA_VISIBLE_DEVICES=0 python -u -m torch.distributed.launch --nproc_per_node=1 --master_port=6004 main_standard.py $args |& tee $LOG_NAME
29 |
--------------------------------------------------------------------------------
/tools/gen_label_other.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 | # ------------------------------------------------------
6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py
7 | # processing the raw data of the video Something-Something-V2
8 |
9 | import os
10 | import json
11 | import pdb
12 |
13 | if __name__ == '__main__':
14 | num_sec = 5
15 |
16 | dataset_folder = f'/data/datasets/in_the_wild/gifs-frames-{num_sec}s'
17 | filename_output = f'/data/datasets/in_the_wild/gifs-frames-{num_sec}s.txt'
18 |
19 | folders = os.listdir(dataset_folder)
20 | output = []
21 | for i in range(len(folders)):
22 | curFolder = folders[i]
23 | curIDX = 0
24 | # counting the number of frames in each video folders
25 | video_folder = os.path.join(dataset_folder, curFolder)
26 | if os.path.exists(video_folder):
27 | dir_files = os.listdir(video_folder)
28 | output.append('%s %d %d' % (curFolder, len(dir_files), curIDX))
29 | print('%d/%d' % (i, len(folders)))
30 | else:
31 | print(f'video {video_folder} does not exist: skipping')
32 |
33 | with open(filename_output, 'w') as f:
34 | f.write('\n'.join(output))
35 |
--------------------------------------------------------------------------------
/evaluate_model.sh:
--------------------------------------------------------------------------------
1 | RAND=$((RANDOM))
2 |
3 | MODEL_DIR='./models/'$1
4 | LOG_NAME=$MODEL_DIR'/log_'$RAND
5 |
6 | mkdir $MODEL_DIR
7 |
8 | RESUME_CKPT='./checkpoints/dyreg_gnn_model_l2l3l4.pth.tar'
9 |
10 | args="--name=$1 --visualisation=hsv --weights=$RESUME_CKPT
11 | --npb --warmup_validate=False
12 | --offset_generator=big --dynamic_regions=dyreg
13 | --use_rstg=True --bottleneck_graph=True
14 | --rstg_skip_connection=False --remap_first=True
15 | --place_graph=layer2.2_layer3.4_layer4.1
16 | --ch_div=4 --graph_residual_type=norm
17 | --rnn_type=GRU --aggregation_type=dot --send_layers=1
18 | --update_layers=0 --combine_by_sum=True --project_at_input=True
19 | --tmp_norm_skip_conn=False --init_regions=center
20 | --lr 0.001 --batch-size 2 --dataset=somethingv2 --modality=RGB --arch resnet50
21 | --test_segments 16 --gd 20 --lr_steps 20 40 --epochs 60 -j 16
22 | --dropout 0.000000000000000001 --consensus_type=avg --eval-freq=1
23 | --shift --shift_div=8 --shift_place=blockres --model_dir=$MODEL_DIR
24 | --full_size_224=False --full_res=False --save_kernels=True"
25 |
26 |
27 | echo $args > $MODEL_DIR'/args_'$RAND
28 |
29 | # run on CPU:
30 | # CUDA_VISIBLE_DEVICES="" python -u test_models.py $args |& tee $LOG_NAME
31 |
32 | # run on GPU
33 | CUDA_VISIBLE_DEVICES=1 python -u -m torch.distributed.launch --nproc_per_node=1 --master_port=6004 test_models.py $args |& tee $LOG_NAME
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/evaluate_model_multi_clip.sh:
--------------------------------------------------------------------------------
1 | RAND=$((RANDOM))
2 |
3 | MODEL_DIR='./models/'$1
4 | LOG_NAME=$MODEL_DIR'/log_'$RAND
5 |
6 | mkdir $MODEL_DIR
7 |
8 | RESUME_CKPT='./checkpoints/dyreg_gnn_model_l2l3l4.pth.tar'
9 |
10 | args="--name=$1 --visualisation=hsv --weights=$RESUME_CKPT
11 | --npb --warmup_validate=False
12 | --offset_generator=big --dynamic_regions=dyreg
13 | --use_rstg=True --bottleneck_graph=True
14 | --rstg_skip_connection=False --remap_first=True
15 | --place_graph=layer2.2_layer3.4_layer4.1
16 | --ch_div=4 --graph_residual_type=norm
17 | --rnn_type=GRU --aggregation_type=dot --send_layers=1
18 | --update_layers=0 --combine_by_sum=True --project_at_input=True
19 | --tmp_norm_skip_conn=False --init_regions=center
20 | --lr 0.001 --batch-size 2 --dataset=somethingv2 --modality=RGB --arch resnet50
21 | --test_segments 16 --gd 20 --lr_steps 20 40 --epochs 60 -j 16
22 | --dropout 0.000000000000000001 --consensus_type=avg --eval-freq=1
23 | --shift --shift_div=8 --shift_place=blockres --model_dir=$MODEL_DIR
24 | --full_size_224=False --full_res=True --test_crops=3 --twice_sample"
25 |
26 |
27 | echo $args > $MODEL_DIR'/args_'$RAND
28 |
29 | # run on CPU:
30 | # CUDA_VISIBLE_DEVICES="" python -u test_models.py $args |& tee $LOG_NAME
31 |
32 | # run on GPU
33 | CUDA_VISIBLE_DEVICES=0 python -u -m torch.distributed.launch --nproc_per_node=1 --master_port=6004 test_models.py $args |& tee $LOG_NAME
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/tools/vid2img_sthv2.py:
--------------------------------------------------------------------------------
1 | # Code adapted from "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # Ji Lin*, Chuang Gan, Song Han
3 |
4 | import os
5 | import threading
6 | import pdb
7 |
8 | NUM_THREADS = 12
9 | VIDEO_ROOT = './data/smt-smt-V2/smt-smt-V2-videos/20bn-something-something-v2/'
10 | FRAME_ROOT = './data/smt-smt-V2/smt-smt-V2-frames/'
11 |
12 |
13 | def split(l, n):
14 | """Yield successive n-sized chunks from l."""
15 | for i in range(0, len(l), n):
16 | yield l[i:i + n]
17 |
18 |
19 | def extract(video, tmpl='%06d.jpg'):
20 | cmd = 'ffmpeg -i \"{}/{}\" -threads 1 -vf scale=-1:256 -q:v 0 \"{}/{}/%06d.jpg\"'.format(VIDEO_ROOT, video,
21 | FRAME_ROOT, video[:-5])
22 | os.system(cmd)
23 |
24 |
25 | def target(video_list):
26 | for video in video_list:
27 | video_path = os.path.join(FRAME_ROOT, video[:-5])
28 | if not os.path.exists(video_path):
29 | os.makedirs(os.path.join(FRAME_ROOT, video[:-5]))
30 | extract(video)
31 | else:
32 | dir_files = os.listdir(os.path.join(FRAME_ROOT, video[:-5]))
33 | if len(dir_files) <= 10 and len(dir_files) != 0:
34 | print(f'folder {video} has only {len(dir_files)} frames')
35 | extract(video)
36 |
37 |
38 |
39 | if __name__ == '__main__':
40 | if not os.path.exists(VIDEO_ROOT):
41 | raise ValueError('Please download videos and set VIDEO_ROOT variable.')
42 | if not os.path.exists(FRAME_ROOT):
43 | os.makedirs(FRAME_ROOT)
44 |
45 | video_list = os.listdir(VIDEO_ROOT)
46 | splits = list(split(video_list, NUM_THREADS))
47 |
48 | threads = []
49 | for i, split in enumerate(splits):
50 | thread = threading.Thread(target=target, args=(split,))
51 | thread.start()
52 | threads.append(thread)
53 |
54 | for thread in threads:
55 | thread.join()
--------------------------------------------------------------------------------
/tools/gen_label_sthv2.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 | # ------------------------------------------------------
6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py
7 | # processing the raw data of the video Something-Something-V2
8 |
9 | import os
10 | import json
11 | DATA_UTILS_ROOT='./data/smt-smt-V2/tsm_data/'
12 | FRAME_ROOT='./data/smt-smt-V2/smt-smt-V2-frames/'
13 |
14 | if __name__ == '__main__':
15 | dataset_name = DATA_UTILS_ROOT+'something-something-v2'
16 | with open('%s-labels.json' % dataset_name) as f:
17 | data = json.load(f)
18 | categories = []
19 | for i, (cat, idx) in enumerate(data.items()):
20 | assert i == int(idx) # make sure the rank is right
21 | categories.append(cat)
22 |
23 | with open(DATA_UTILS_ROOT+'category.txt', 'w') as f:
24 | f.write('\n'.join(categories))
25 |
26 | dict_categories = {}
27 | for i, category in enumerate(categories):
28 | dict_categories[category] = i
29 |
30 | files_input = ['%s-validation.json' % dataset_name, '%s-train.json' % dataset_name, '%s-test.json' % dataset_name]
31 | files_output = ['val_videofolder.txt', 'train_videofolder.txt', 'test_videofolder.txt']
32 | for (filename_input, filename_output) in zip(files_input, files_output):
33 | with open(filename_input) as f:
34 | data = json.load(f)
35 | folders = []
36 | idx_categories = []
37 | for item in data:
38 | folders.append(item['id'])
39 | if 'test' not in filename_input:
40 | idx_categories.append(dict_categories[item['template'].replace('[', '').replace(']', '')])
41 | else:
42 | idx_categories.append(0)
43 | output = []
44 | for i in range(len(folders)):
45 | curFolder = folders[i]
46 | curIDX = idx_categories[i]
47 | # counting the number of frames in each video folders
48 | video_folder = os.path.join(FRAME_ROOT, curFolder)
49 | if os.path.exists(video_folder):
50 | dir_files = os.listdir(video_folder)
51 | output.append('%s %d %d' % (curFolder, len(dir_files), curIDX))
52 | print('%d/%d' % (i, len(folders)))
53 | else:
54 | print(f'video {video_folder} does not exist: skipping')
55 |
56 | with open(DATA_UTILS_ROOT+filename_output, 'w') as f:
57 | f.write('\n'.join(output))
58 |
--------------------------------------------------------------------------------
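Each line written above follows the `<video folder> <frame count> <class index>` layout that ops/dataset.py later splits on spaces; a tiny illustrative parse (the values are made up):

```
# one line of train_videofolder.txt / val_videofolder.txt (illustrative values)
line = '12345 48 27'   # <video folder> <number of frames> <label index>
folder, num_frames, label = line.strip().split(' ')
print(folder, int(num_frames), int(label))
```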
/tools/gen_label_sthv1.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 | # ------------------------------------------------------
6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py
7 | # processing the raw data of the video Something-Something-V1
8 |
9 | import os
10 |
11 | if __name__ == '__main__':
12 | dataset_name = '/data/datasets/something-something/something-something-v1'  # alternative: 'jester-v1'
13 | with open('%s-labels.csv' % dataset_name) as f:
14 | lines = f.readlines()
15 | categories = []
16 | for line in lines:
17 | line = line.rstrip()
18 | categories.append(line)
19 | categories = sorted(categories)
20 | with open('smtv1_category.txt', 'w') as f:
21 | f.write('\n'.join(categories))
22 |
23 | dict_categories = {}
24 | for i, category in enumerate(categories):
25 | dict_categories[category] = i
26 |
27 | files_input = ['%s-validation.csv' % dataset_name, '%s-train.csv' % dataset_name]
28 | files_output = ['smtv1-val_videofolder.txt', 'smtv1-train_videofolder.txt']
29 | for (filename_input, filename_output) in zip(files_input, files_output):
30 | if 'val' in filename_output:
31 | split = 'valid'
32 | elif 'train' in filename_output:
33 | split = 'train'
34 | with open(filename_input) as f:
35 | lines = f.readlines()
36 | folders = []
37 | idx_categories = []
38 | for line in lines:
39 | line = line.rstrip()
40 | items = line.split(';')
41 | folders.append(items[0])
42 | idx_categories.append(dict_categories[items[1]])
43 | output = []
44 | for i in range(len(folders)):
45 | curFolder = folders[i]
46 | curIDX = idx_categories[i]
47 | # counting the number of frames in each video folders
48 | video_folder = os.path.join(f'/data/datasets/something-something/20bn-something-something-v1/{split}/', curFolder)
49 |
50 | # dir_files = os.listdir(os.path.join('../img', curFolder))
51 | # output.append('%s %d %d' % ('something/v1/img/' + curFolder, len(dir_files), curIDX))
52 | # print('%d/%d' % (i, len(folders)))
53 | if os.path.exists(video_folder):
54 | dir_files = os.listdir(video_folder)
55 | output.append('%s %d %d' % (curFolder, len(dir_files), curIDX))
56 | print('%d/%d' % (i, len(folders)))
57 | else:
58 | print(f'video {video_folder} does not exist: skipping')
59 |
60 | with open(filename_output, 'w') as f:
61 | f.write('\n'.join(output))
62 |
--------------------------------------------------------------------------------
/tools/vid2img_other.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | import os
7 | import threading
8 | import pdb
9 |
10 | NUM_THREADS = 12
11 | # VIDEO_ROOT = '/ssd/video/something/v2/20bn-something-something-v2' # Downloaded webm videos
12 | # FRAME_ROOT = '/ssd/video/something/v2/20bn-something-something-v2-frames' # Directory for extracted frames
13 | # VIDEO_ROOT = '/data/datasets/smt-smt-V2/20bn-something-something-v2'
14 | # FRAME_ROOT = '/data/datasets/smt-smt-V2/20bn-something-something-v2-frames'
15 |
16 | num_sec = 5
17 | VIDEO_ROOT = '/data/datasets/in_the_wild/gifs'
18 | FRAME_ROOT = f'/data/datasets/in_the_wild/gifs-frames-{num_sec}s'
19 |
20 | # VIDEO_ROOT = '/data/datasets/in_the_wild/dataset_imar'
21 | # FRAME_ROOT = f'/data/datasets/in_the_wild/dataset_imar-frames-{num_sec}s'
22 |
23 |
24 | def split(l, n):
25 | """Yield successive n-sized chunks from l."""
26 | for i in range(0, len(l), n):
27 | yield l[i:i + n]
28 |
29 |
30 | def extract(video, tmpl='%06d.jpg'):
31 | # os.system(f'ffmpeg -i {VIDEO_ROOT}/{video} -vf -threads 1 -vf scale=-1:256 -q:v 0 '
32 | # f'{FRAME_ROOT}/{video[:-5]}/{tmpl}')
33 | # cmd0 = 'ffmpeg -i \"{}/{}\" -threads 1 -vf scale=-1:256 -q:v 0 \"{}/{}/%06d.jpg\"'.format(VIDEO_ROOT, video,
34 | # FRAME_ROOT, video[:-5])
35 |
36 | cmd = f'ffmpeg -t 00:0{num_sec} '
37 | cmd = cmd + '-i \"{}/{}\" -threads 1 -vf scale=-1:256 -q:v 0 \"{}/{}/%06d.jpg\"'.format(VIDEO_ROOT, video,
38 | FRAME_ROOT, video[:-5])
39 |
40 |
41 | os.system(cmd)
42 |
43 |
44 | def target(video_list):
45 | for video in video_list:
46 | video_path = os.path.join(FRAME_ROOT, video[:-5])
47 | if not os.path.exists(video_path):
48 | #print(f'video {video_path} does not exists')
49 | os.makedirs(os.path.join(FRAME_ROOT, video[:-5]))
50 | extract(video)
51 | else:
52 | dir_files = os.listdir(os.path.join(FRAME_ROOT, video[:-5]))
53 | if len(dir_files) <= 10:
54 | print(f'folder {video} has only {len(dir_files)} frames')
55 | extract(video)
56 |
57 |
58 |
59 | if __name__ == '__main__':
60 | if not os.path.exists(VIDEO_ROOT):
61 | raise ValueError('Please download videos and set VIDEO_ROOT variable.')
62 | if not os.path.exists(FRAME_ROOT):
63 | os.makedirs(FRAME_ROOT)
64 |
65 | video_list = os.listdir(VIDEO_ROOT)
66 | splits = list(split(video_list, NUM_THREADS))
67 |
68 | threads = []
69 | for i, split in enumerate(splits):
70 | #target(split)
71 | thread = threading.Thread(target=target, args=(split,))
72 | thread.start()
73 | threads.append(thread)
74 |
75 | for thread in threads:
76 | thread.join()
--------------------------------------------------------------------------------
/create_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from ops.dyreg import DynamicGraph, dyregParams
4 |
5 | class SpaceTimeModel(torch.nn.Module):
6 | def __init__(self):
7 | super(SpaceTimeModel, self).__init__()
8 | dyreg_params = dyregParams()
9 | dyregParams.offset_lstm_dim = 32
10 | self.dyreg = DynamicGraph(dyreg_params,
11 | backbone_dim=32, node_dim=32, out_num_ch=32,
12 | H=16, W=16,
13 | iH=16, iW=16,
14 | project_i3d=False,
15 | name='lalalal')
16 |
17 |
18 | self.fc = torch.nn.Linear(32, 10)
19 |
20 | def forward(self, x):
21 | dx = self.dyreg(x)
22 | # you can initialize the dyreg branch as identity function by normalisation,
23 | # as done in DynamicGraphWrapper found in ./ops/dyreg.py
24 | x = x + dx
25 | # average over time and space: T, H, W
26 | x = x.mean(-1).mean(-1).mean(-2)
27 | x = self.fc(x)
28 | return x
29 |
30 |
31 |
32 | class ConvSpaceTimeModel(torch.nn.Module):
33 | def __init__(self):
34 | super(ConvSpaceTimeModel, self).__init__()
35 | dyreg_params = dyregParams()
36 | dyregParams.offset_lstm_dim = 32
37 | self.conv1 = torch.nn.Conv3d(in_channels=3, out_channels=32, kernel_size=[1,3,3], stride=[1,2,2],padding=[0,1,1])
38 | self.conv2 = torch.nn.Conv3d(in_channels=32, out_channels=32, kernel_size=[1,3,3], stride=[1,2,2],padding=[0,1,1])
39 | self.dyreg = DynamicGraph(dyreg_params,
40 | backbone_dim=32, node_dim=32, out_num_ch=32,
41 | H=16, W=16,
42 | iH=16, iW=16,
43 | project_i3d=False,
44 | name='lalalal')
45 |
46 | self.conv3 = torch.nn.Conv3d(in_channels=32, out_channels=32, kernel_size=[1,3,3], stride=[1,2,2],padding=[0,1,1])
47 | self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
48 | self.fc = torch.nn.Linear(32, 10)
49 |
50 | def forward(self, x):
51 | x = F.relu(self.conv1(x))
52 | x = F.relu(self.conv2(x))
53 | input = x.permute(0,2,1,3,4).contiguous()
54 | dx = self.dyreg(input)
55 | dx = dx.permute(0,2,1,3,4).contiguous()
56 | # you can initialize the dyreg branch as identity function by normalisation,
57 | # as done in DynamicGraphWrapper found in ./ops/dyreg.py
58 | x = x + dx
59 | x = F.relu(self.conv3(x))
60 | # average over time and space: T, H, W
61 | x = x.mean(-1).mean(-1).mean(-1)
62 | x = self.fc(x)
63 | return x
64 |
65 | B = 8
66 | T = 10
67 | C = 32
68 | H = 16
69 | W = 16
70 | x = torch.ones(B,T,C,H,W)
71 | st_model = SpaceTimeModel()
72 | out1 = st_model(x)
73 |
74 |
75 | B = 8
76 | T = 10
77 | C = 3
78 | H = 64
79 | W = 64
80 | x = torch.ones(B,C,T,H,W)
81 | conv_st_model = ConvSpaceTimeModel()
82 | out2 = conv_st_model(x)
83 | [print(f'{k} {v.shape}') for k,v in conv_st_model.named_parameters()]
84 |
85 | out = out1 + out2
86 | print('done')
--------------------------------------------------------------------------------
/ops/generate_scripts.py:
--------------------------------------------------------------------------------
1 |
2 | import itertools
3 |
4 | var_params = {}
5 |
6 | # var_params['aux_loss'] = ['inter_videos_all', 'inter_videos_content', 'inter_nodes_all', 'inter_nodes_content']
7 | # var_params['aux_loss'] = ['global_foreground_pos_background_neg',
8 | # 'global_foreground_pos_foreground_neg',
9 | # 'global_foreground_pos_foreground_background_neg',
10 | # 'global_foreground_pos_foreground_background_neg_rand']
11 |
12 |
13 | var_params['contrastive_mlp'] = [False]
14 |
15 | var_params['aux_loss'] = ['nodes_temporal_matching']
16 | var_params['contrastive_alpha'] = [0.1, 1.0]
17 | var_params['contrastive_temp'] = [0.1, 0.5, 1.0, 2.0]
18 | # var_params['contrastive_temp'] = [0.05]
19 |
20 | # var_params['place_graph'] = ['layer2.1_layer3.1_layer4.1']
21 | # var_params['place_graph'] = ['layer3.1']
22 |
23 | somelists = [val for k,val in var_params.items()]
24 |
25 |
26 | for element in itertools.product(*somelists):
27 | print('PARAMS')
28 | args = ''
29 | keys = list(var_params.keys())
30 |
31 | name_model = 'model_contrastive'
32 | for i,k in enumerate(keys):
33 | name_model += f'_{k}_{element[i]}'
34 | args += f' --{k}={element[i]}'
35 | args += f' --name={name_model}'
36 | # print(args)
37 |
38 | args = args + f" --place_graph=layer2.1_layer3.1_layer4.1 --contrastive_type=nodes_temporal --rstg_combine=plus --node_confidence=none --use_detector=False --compute_iou=False --tmp_fix_kernel_grads=False --create_order=False --distill=False --distributed=True --bottleneck_graph=False --scaled_distill=False --alpha_distill=1.0 --glore_pos_emb_type=sinusoidal --npb --offset_generator=fishnet --predefined_augm=False --freeze_backbone=False --init_skip_zero=False --warmup_validate=False --smt_split=classic --use_rstg=True --tmp_init_different=0.0 --rstg_skip_connection=False --remap_first=True --isolate_graphs=False --dynamic_regions=constrained_fix_size --ch_div=4 --graph_residual_type=norm --rnn_type=GRU --aggregation_type=dot --send_layers=1 --update_layers=0 --combine_by_sum=True --project_at_input=True --tmp_norm_after_project_at_input=True --tmp_norm_skip_conn=False --tmp_norm_before_dynamic_graph=False --init_regions=center --tmp_increment_reg=True --lr 0.01 --batch-size 14 --dataset=cater --modality=RGB --arch=resnet18 --num_segments 13 --gd 20 --lr_type=step --lr_steps 75 125 200 --epochs 150 -j 16 --dropout=0.5 --consensus_type=avg --eval-freq=1 --shift_div=8 --shift --shift_place=blockres --model_dir=$MODEL_DIR"
39 |
40 |
41 | script_name = f'auto_scripts/start_' + name_model + '.sh'
42 |
43 | with open(script_name, 'w') as f:
44 | f.write('RAND=$((RANDOM))\n')
45 | f.write(f'MODEL_DIR=\'/models/graph_models/models_cater_dynamic_pytorch/exp_contrastive_global/{name_model}\'\n')
46 |
47 | f.write('LOG_NAME=$MODEL_DIR\'/log_\'$RAND\n')
48 | f.write('mkdir $MODEL_DIR\n')
49 | f.write('\n')
50 | f.write(f'args="{args}"\n')
51 | f.write('\n')
52 |
53 |
54 | f.write('cp ./train.py $MODEL_DIR\'/train.py_\'$RAND\n')
55 | f.write('mkdir $MODEL_DIR\'/code/\'\n')
56 | f.write('mkdir $MODEL_DIR\'/code/ops_\'$RAND\'/\'\n')
57 | f.write('cp -r ./ops/ $MODEL_DIR\'/code/ops_\'$RAND\'/\'\n')
58 | f.write('cp -r ./main.py $MODEL_DIR\'/code/\'\n')
59 | f.write('echo $args > $MODEL_DIR\'/args_\'$RAND\n')
60 |
61 | f.write('\n')
62 | f.write('\n')
63 |
64 | f.write(f'CUDA_VISIBLE_DEVICES=$1 python -u -m torch.distributed.launch --nproc_per_node=1 --master_port=$2 main_contrastive.py $args |& tee $LOG_NAME\n')
65 |
--------------------------------------------------------------------------------
/ops/dataset_config.py:
--------------------------------------------------------------------------------
1 | # Code adapted from "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # Ji Lin*, Chuang Gan, Song Han
3 |
4 |
5 | import os
6 |
7 | global args, best_prec1
8 | from opts import parse_args
9 | args = parse_args()
10 |
11 | if args.dataset == 'somethingv2':
12 | ROOT_DATASET = './data/smt-smt-V2/'
13 | elif args.dataset == 'something':
14 | ROOT_DATASET = './data/datasets/something-something/'
15 | elif args.dataset == 'others':
16 | ROOT_DATASET = './data/datasets/in_the_wild/'
17 | elif args.dataset == 'syncMNIST':
18 | ROOT_DATASET = './data/syncMNIST/'
19 | elif args.dataset == 'multiSyncMNIST':
20 | ROOT_DATASET = './data/multiSyncMNIST/'
21 |
22 |
23 | def return_something(modality):
24 | filename_categories = 'tsm_data/smtv1_category.txt'
25 | if modality == 'RGB':
26 | root_data = ROOT_DATASET + '/smt-smt-V2-frames/'
27 | filename_imglist_train = 'tsm_data/smtv1-train_videofolder.txt'
28 | filename_imglist_val = 'tsm_data/smtv1-val_videofolder.txt'
29 | prefix = '{:05d}.jpg'
30 | else:
31 | print('no such modality:'+modality)
32 | raise NotImplementedError
33 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
34 |
35 |
36 | def return_others(modality):
37 | filename_categories = 'categories.txt'
38 |
39 | root_data = ROOT_DATASET + '/gifs-frames-5s/'
40 | filename_imglist_train = 'gifs-frames-5s.txt'
41 | filename_imglist_val = 'gifs-frames-5s.txt'
42 | prefix = '{:06d}.jpg'
43 |
44 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
45 |
46 |
47 |
48 | def return_somethingv2(modality):
49 | if modality == 'RGB':
50 | filename_categories = 'tsm_data/category.txt'
51 | root_data = ROOT_DATASET + '/smt-smt-V2-frames/'
52 | filename_imglist_train = 'tsm_data/train_videofolder.txt'
53 | filename_imglist_val = 'tsm_data/val_videofolder.txt'
54 | prefix = '{:06d}.jpg'
55 | else:
56 | print('no such modality:'+modality)
57 | raise NotImplementedError
58 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
59 |
60 |
61 |
62 | def return_dataset(dataset, modality):
63 | if dataset == 'syncMNIST':
64 | n_class = 46
65 | test_dataset = '/data/datasets/video_mnist/sync_mnist_large_v2_split_test_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0/'
66 | train_dataset = '/data/datasets/video_mnist/sync_mnist_large_v2_split_train_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0/'
67 | return n_class, train_dataset, test_dataset, None, None
68 | elif dataset == 'multiSyncMNIST':
69 | n_class = 56
70 | test_dataset = args.test_dataset
71 | train_dataset = args.train_dataset
72 | return n_class, train_dataset, test_dataset, None, None
73 | dict_single = {'something': return_something,
74 | 'somethingv2': return_somethingv2,
75 | 'others' : return_others}
76 | if dataset in dict_single:
77 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality)
78 | else:
79 | raise ValueError('Unknown dataset '+dataset)
80 |
81 | file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train)
82 | file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val)
83 |
84 | if isinstance(file_categories, str):
85 | file_categories = os.path.join(ROOT_DATASET, file_categories)
86 | with open(file_categories) as f:
87 | lines = f.readlines()
88 | categories = [item.rstrip() for item in lines]
89 | else: # number of categories
90 | categories = [None] * file_categories
91 | n_class = len(categories)
92 | print('{}: {} classes'.format(dataset, n_class))
93 | return n_class, file_imglist_train, file_imglist_val, root_data, prefix
94 |
--------------------------------------------------------------------------------
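A sketch of how the helper above is typically consumed (the dataset/modality values mirror train_model.sh; note that importing this module already calls opts.parse_args(), so it is meant to be used from the training/testing entry points):

```
from ops.dataset_config import return_dataset

# for --dataset=somethingv2 --modality=RGB this resolves the category file,
# the train/val list files, the frame root and the frame-name template
n_class, train_list, val_list, root_data, prefix = return_dataset('somethingv2', 'RGB')
print(n_class, prefix)  # e.g. 174 and '{:06d}.jpg' for Something-Something-V2
```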
/ops/non_local.py:
--------------------------------------------------------------------------------
1 | # Non-local block using embedded gaussian
2 | # Code from
3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 |
9 | class _NonLocalBlockND(nn.Module):
10 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
11 | super(_NonLocalBlockND, self).__init__()
12 |
13 | assert dimension in [1, 2, 3]
14 |
15 | self.dimension = dimension
16 | self.sub_sample = sub_sample
17 |
18 | self.in_channels = in_channels
19 | self.inter_channels = inter_channels
20 |
21 | if self.inter_channels is None:
22 | self.inter_channels = in_channels // 2
23 | if self.inter_channels == 0:
24 | self.inter_channels = 1
25 |
26 | if dimension == 3:
27 | conv_nd = nn.Conv3d
28 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
29 | bn = nn.BatchNorm3d
30 | elif dimension == 2:
31 | conv_nd = nn.Conv2d
32 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
33 | bn = nn.BatchNorm2d
34 | else:
35 | conv_nd = nn.Conv1d
36 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
37 | bn = nn.BatchNorm1d
38 |
39 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
40 | kernel_size=1, stride=1, padding=0)
41 |
42 | if bn_layer:
43 | self.W = nn.Sequential(
44 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
45 | kernel_size=1, stride=1, padding=0),
46 | bn(self.in_channels)
47 | )
48 | nn.init.constant_(self.W[1].weight, 0)
49 | nn.init.constant_(self.W[1].bias, 0)
50 | else:
51 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
52 | kernel_size=1, stride=1, padding=0)
53 | nn.init.constant_(self.W.weight, 0)
54 | nn.init.constant_(self.W.bias, 0)
55 |
56 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
57 | kernel_size=1, stride=1, padding=0)
58 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
59 | kernel_size=1, stride=1, padding=0)
60 |
61 | if sub_sample:
62 | self.g = nn.Sequential(self.g, max_pool_layer)
63 | self.phi = nn.Sequential(self.phi, max_pool_layer)
64 |
65 | def forward(self, x):
66 | '''
67 | :param x: (b, c, t, h, w)
68 | :return:
69 | '''
70 |
71 | batch_size = x.size(0)
72 |
73 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
74 | g_x = g_x.permute(0, 2, 1)
75 |
76 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
77 | theta_x = theta_x.permute(0, 2, 1)
78 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
79 | f = torch.matmul(theta_x, phi_x)
80 | f_div_C = F.softmax(f, dim=-1)
81 |
82 | y = torch.matmul(f_div_C, g_x)
83 | y = y.permute(0, 2, 1).contiguous()
84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
85 | W_y = self.W(y)
86 | z = W_y + x
87 |
88 | return z
89 |
90 |
91 | class NONLocalBlock1D(_NonLocalBlockND):
92 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
93 | super(NONLocalBlock1D, self).__init__(in_channels,
94 | inter_channels=inter_channels,
95 | dimension=1, sub_sample=sub_sample,
96 | bn_layer=bn_layer)
97 |
98 |
99 | class NONLocalBlock2D(_NonLocalBlockND):
100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
101 | super(NONLocalBlock2D, self).__init__(in_channels,
102 | inter_channels=inter_channels,
103 | dimension=2, sub_sample=sub_sample,
104 | bn_layer=bn_layer)
105 |
106 |
107 | class NONLocalBlock3D(_NonLocalBlockND):
108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
109 | super(NONLocalBlock3D, self).__init__(in_channels,
110 | inter_channels=inter_channels,
111 | dimension=3, sub_sample=sub_sample,
112 | bn_layer=bn_layer)
113 |
114 |
115 | class NL3DWrapper(nn.Module):
116 | def __init__(self, block, n_segment):
117 | super(NL3DWrapper, self).__init__()
118 | self.block = block
119 | self.nl = NONLocalBlock3D(block.bn3.num_features)
120 | self.n_segment = n_segment
121 |
122 | def forward(self, x):
123 | x = self.block(x)
124 |
125 | nt, c, h, w = x.size()
126 | x = x.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w
127 | x = self.nl(x)
128 | x = x.transpose(1, 2).contiguous().view(nt, c, h, w)
129 | return x
130 |
131 |
132 | def make_non_local(net, n_segment):
133 | import torchvision
134 | import archs
135 | if isinstance(net, torchvision.models.ResNet):
136 | net.layer2 = nn.Sequential(
137 | NL3DWrapper(net.layer2[0], n_segment),
138 | net.layer2[1],
139 | NL3DWrapper(net.layer2[2], n_segment),
140 | net.layer2[3],
141 | )
142 | net.layer3 = nn.Sequential(
143 | NL3DWrapper(net.layer3[0], n_segment),
144 | net.layer3[1],
145 | NL3DWrapper(net.layer3[2], n_segment),
146 | net.layer3[3],
147 | NL3DWrapper(net.layer3[4], n_segment),
148 | net.layer3[5],
149 | )
150 | else:
151 | raise NotImplementedError
152 |
153 |
154 | if __name__ == '__main__':
155 | from torch.autograd import Variable
156 | import torch
157 |
158 | sub_sample = True
159 | bn_layer = True
160 |
161 | img = Variable(torch.zeros(2, 3, 20))
162 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
163 | out = net(img)
164 | print(out.size())
165 |
166 | img = Variable(torch.zeros(2, 3, 20, 20))
167 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
168 | out = net(img)
169 | print(out.size())
170 |
171 | img = Variable(torch.randn(2, 3, 10, 20, 20))
172 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
173 | out = net(img)
174 | print(out.size())
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Discovering Dynamic Salient Regions with Spatio-Temporal Graph Neural Networks
2 | This is the official code for the DyReg model introduced in [Discovering Dynamic Salient Regions with Spatio-Temporal Graph Neural Networks](https://arxiv.org/abs/2009.08427).
3 |
4 |
5 | ![DyReg-GNN architecture](DyReg_GNN_arch.png)
6 |
7 |
8 | ## Citation
9 | Please use the following BibTeX to cite our work.
10 | ```
11 | @incollection{duta2021dynamic_dyreg_gnn_neurips2021,
12 | title = {Discovering Dynamic Salient Regions with Spatio-Temporal Graph
13 | Neural Networks},
14 | author = {Duta, Iulia and Nicolicioiu, Andrei and Leordeanu, Marius},
15 | booktitle = {Advances in Neural Information Processing Systems 34},
16 | year = {2021}
17 | }
18 |
19 | @article{duta2020dynamic_dyreg,
20 | title = {Dynamic Regions Graph Neural Networks for Spatio-Temporal Reasoning},
21 | author = {Duta, Iulia and Nicolicioiu, Andrei and Leordeanu, Marius},
22 | journal = {NeurIPS 2020 Workshop on Object Representations for Learning and Reasoning},
23 | year = {2020},
24 | }
25 | ```
26 |
27 | ## Requirements
28 |
29 | The code was developed using:
30 |
31 | - python 3.7
32 | - matplotlib
33 | - torch 1.7.1
34 | - script
35 | - pandas
36 | - torchvision
37 | - moviepy
38 | - ffmpeg
39 |
40 |
41 | ## Overview:
42 | The repository contains the PyTorch implementation of the DyReg-GNN model.
43 | The model is defined and trained in the following files:
44 | - [ops/dyreg.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/ops/dyreg.py) - code for our DyReg module
45 | - [ops/rstg.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/ops/rstg.py) - code for the Spatio-temporal GNN (RSTG) used to process the graph extracted using DyReg
46 |
47 | - [create_model.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/create_model.py) - two examples of how to integrate the DyReg-GNN module inside an existing backbone
48 | - [main_standard.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/main_standard.py) - code to train a model on the Smt-Smt dataset
49 | - [test_models.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/test_models.py) - code for multi-clip evaluation
50 |
51 | Scripts for preparing the data, training and testing the model:
52 | - [train_model.sh](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/train_model.sh) - example script for training DyReg-GNN
53 | - [evaluate_model.sh](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/evaluate_model.sh) - example script for evaluating DyReg-GNN on a single clip
54 | - [evaluate_model_multi_clip.sh](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/evaluate_model_multi_clip.sh) - example script for evaluating DyReg-GNN on multiple clips
55 |
56 | - [tools/](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/tools) contains all the scripts used to prepare the Smt-Smt dataset (similar to the setup used in TSM)
57 |
58 |
59 | ## Prepare dataset
60 |
61 | For [Something Something dataset](https://arxiv.org/pdf/1706.04261v2.pdf):
62 | * the json files containing meta-data should be stored in `./data/smt-smt-V2/tsm_data`
63 | * the zip files containing the videos should be stored in `./data/smt-smt-V2/`
64 |
65 | - - - -
66 |
67 | 1. **To extract the videos from the zip files run:**
68 |
69 | `cat 20bn-something-something-v2-?? | tar zx`
70 |
71 | 2. **To extract the frames from videos run:**
72 |
73 | `python tools/vid2img_sthv2.py`
74 |
75 | → The extracted frames will be stored in *$FRAME_ROOT* (default `'./data/smt-smt-V2/smt-smt-V2-frames/'`)
76 |
77 | 💡 *If you already have the dataset as frames, place them under `./data/smt-smt-V2/smt-smt-V2-frames/`, one folder for each video* \
78 | 💡💡 *If you need to change the dataset paths, modify *$ROOT_DATASET* in [dataset_config.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/ops/dataset_config.py)*
79 |
80 |
81 | 3. **To generate the label files in the required format please run:**
82 |
83 | `python tools/gen_label_sthv2.py `
84 |
85 | → The resulting txt files, for each split, will be stored in *$DATA_UTILS_ROOT* (default `'./data/smt-smt-V2/tsm_data/'`)
86 |
87 |
88 | ## How to run the model
89 |
90 | The DyReg-GNN module can simply be inserted into any space-time model.
91 | ```
92 | import torch
93 | from torch.nn import functional as F
94 | from ops.dyreg import DynamicGraph, dyregParams
95 |
96 | class SpaceTimeModel(torch.nn.Module):
97 | def __init__(self):
98 | super(SpaceTimeModel, self).__init__()
99 | dyreg_params = dyregParams()
100 | dyreg_params.offset_lstm_dim = 32
101 | self.dyreg = DynamicGraph(dyreg_params,
102 | backbone_dim=32, node_dim=32, out_num_ch=32,
103 | H=16, W=16,
104 | iH=16, iW=16,
105 | project_i3d=False,
106 | name='lalalal')
107 |
108 |
109 | self.fc = torch.nn.Linear(32, 10)
110 |
111 | def forward(self, x):
112 | dx = self.dyreg(x)
113 | # you can initialize the dyreg branch as identity function by normalisation,
114 | # as done in DynamicGraphWrapper found in ./ops/dyreg.py
115 | x = x + dx
116 | # average over time and space: T, H, W
117 | x = x.mean(-1).mean(-1).mean(-2)
118 | x = self.fc(x)
119 | return x
120 |
121 |
122 | B = 8
123 | T = 10
124 | C = 32
125 | H = 16
126 | W = 16
127 | x = torch.ones(B,T,C,H,W)
128 | st_model = SpaceTimeModel()
129 | out = st_model(x)
130 | ```
131 |
132 |
133 | For another example of how to integrate DyReg (the DynamicGraph module) into your model, please look at [create_model.py](https://github.com/andreinicolicioiu/dyreg_clean/blob/main/create_model.py) or run:
134 |
135 | `python create_model.py`
136 |
137 |
138 | ### Something-Something experiments
139 |
140 | #### Training a model
141 |
142 | To train a model on the smt-smt v2 dataset please run
143 |
144 | `./train_model.sh model_name`
145 |
146 | For the default hyperparameters check opts.py. For example, the `place_graph` flag controls how many DyReg-GNN modules to use and where to place them inside the backbone:
147 | ```
148 | # for a model with 3 DyReg-GNN modules placed after layer 2-block 2, layer 3-block 4 and layer 4-block 1 of the backbone
149 | --place_graph=layer2.2_layer3.4_layer4.1
150 | # for a model with 1 dyreg module placed after layer 3 block 4 of the backbone
151 | --place_graph=layer3.4
152 | ```
153 |
154 | #### Single clip evaluation
155 |
156 | Train a model with the above script or download a pre-trained DyReg-GNN model from [here](https://drive.google.com/file/d/1MA_zfHSoutTVL6X9JxyIBLkX0rNJmC_6/view?usp=sharing) and put the checkpoint in `./checkpoints/`
157 |
158 | To evaluate a model on the smt-smt v2 dataset on a single 224 x 224 central crop, run:
159 |
160 | `./evaluate_model.sh model_name`
161 |
162 | The `$RESUME_CKPT` variable indicates the checkpoint used for evaluation.
163 |
164 |
165 | #### Multi-clip evaluation
166 |
167 | To evaluate a model in the multi-clip setup (3 spatial crops x 2 temporal samplings) on the Smt-Smt v2 dataset please run
168 |
169 | `./evaluate_model_multi_clip.sh model_name`
170 |
171 | The `$RESUME_CKPT` variable indicates the checkpoint used for evaluation.
172 |
173 | ### TSM Baseline
174 | This repository adds DyReg-GNN modules to a TSM backbone based on code from [here](https://github.com/mit-han-lab/temporal-shift-module).
175 |
--------------------------------------------------------------------------------
/ops/rstg.py:
--------------------------------------------------------------------------------
1 |
2 | # Code for RSTG model
3 | # Recurrent Space-time Graph Neural Networks - RSTG
4 | # adapted from https://github.com/IuliaDuta/RSTG
5 |
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | import sys
10 | import os
11 | sys.path.append(os.path.abspath(os.path.join(
12 | os.path.dirname(os.path.realpath(__file__)), '..')))
13 | from opts import parser
14 | from ops.models_utils import *
15 |
16 | global args, best_prec1
17 | from opts import parse_args
18 | # command-line arguments are parsed once via opts.parse_args()
19 | args = parse_args()
20 | import pdb
21 |
22 |
23 | class RSTG(nn.Module):
24 | def __init__(self,params, backbone_dim=1024, node_dim=512, project_i3d=True):
25 |
26 | super(RSTG, self).__init__()
27 | self.params = params
28 | self.backbone_dim = backbone_dim
29 | self.node_dim = node_dim
30 | self.number_iterations = 3
31 | self.num_nodes = 9
32 | self.project_i3d = project_i3d
33 | self.norm_dict = nn.ModuleDict({})
34 |
35 | # intern LSTM
36 | self.internal_lstm = recurrent_net(batch_first=True,
37 | input_size=self.node_dim, hidden_size=self.node_dim)
38 | # extern LSTM
39 | self.external_lstm = recurrent_net(batch_first=True,
40 | input_size=self.node_dim, hidden_size=self.node_dim)
41 |
42 | # send function
43 | if self.params.send_layers == 1:
44 | self.send_mlp = nn.Sequential(
45 | torch.nn.Linear(self.node_dim, self.node_dim),
46 | torch.nn.ReLU()
47 | )
48 | elif self.params.send_layers == 2:
49 | if self.params.combine_by_sum:
50 | comb_mult = 1
51 | else:
52 | comb_mult = 2
53 | self.send_mlp = nn.Sequential(
54 | torch.nn.Linear(comb_mult * self.node_dim, self.node_dim),
55 | torch.nn.ReLU(),
56 | torch.nn.Linear(self.node_dim, self.node_dim),
57 | torch.nn.ReLU()
58 | )
59 |
60 | # norm send
61 | self.norm_dict['send_message_norm'] = LayerNormAffineXC(
62 | self.node_dim, (self.num_nodes,self.num_nodes,self.node_dim)
63 | )
64 | self.norm_dict['update_norm'] = LayerNormAffineXC(
65 | self.node_dim, (self.num_nodes, self.node_dim)
66 | )
67 | self.norm_dict['before_graph_norm'] = LayerNormAffineXC(
68 | self.node_dim, (self.num_nodes, self.node_dim)
69 | )
70 |
71 |
72 |
73 | # attention function
74 | if 'dot' in self.params.aggregation_type:
75 | self.att_q = torch.nn.Linear(self.node_dim, self.node_dim)
76 | self.att_k = torch.nn.Linear(self.node_dim, self.node_dim)
77 | # attention bias
78 | self.att_bias = torch.nn.Parameter(
79 | torch.zeros(size=[1,1,1,self.node_dim], requires_grad=True)
80 | )
81 |
82 | # update function
83 | if self.params.update_layers == 0:
84 | self.update_mlp = nn.Identity()
85 | elif self.params.update_layers == 1:
86 | self.update_mlp = nn.Sequential(
87 | torch.nn.Linear(self.node_dim, self.node_dim),
88 | torch.nn.ReLU()
89 | )
90 | elif self.params.update_layers == 2:
91 | if self.params.combine_by_sum:
92 | comb_mult = 1
93 | else:
94 | comb_mult = 2
95 | self.update_mlp = nn.Sequential(
96 | torch.nn.Linear(comb_mult * self.node_dim, self.node_dim),
97 | torch.nn.ReLU(),
98 | torch.nn.Linear(self.node_dim, self.node_dim),
99 | torch.nn.ReLU()
100 | )
101 |
102 | def get_norm(self, input, name, zero_init=False):
103 | # input: B x N x T x C
104 | # or B x N x N x T x C
105 | if len(input.size()) == 5:
106 | input = input.permute(0,3,1,2,4).contiguous()
107 | elif len(input.size()) == 4:
108 | input = input.permute(0,2,1,3).contiguous()
109 |
110 | norm = self.norm_dict[name]
111 |
112 | input = norm(input)
113 | if len(input.size()) == 5:
114 | input = input.permute(0,2,3,1,4).contiguous()
115 | elif len(input.size()) == 4:
116 | input = input.permute(0,2,1,3).contiguous()
117 | return input
118 |
119 | def send_messages(self, nodes):
120 | # nodes: B x N x T x C
121 | # nodes1: B x 1 x N x T x C
122 | # nodes2: B x N x 1 x T x C
123 | nodes1 = nodes.unsqueeze(1)
124 | nodes2 = nodes.unsqueeze(2)
125 |
126 | nodes1 = nodes1.repeat(1,self.num_nodes,1,1,1)
127 | nodes2 = nodes2.repeat(1,1,self.num_nodes,1,1)
128 | # B x N x N x T x 2C
129 | if self.params.combine_by_sum:
130 | messages = nodes1 + nodes2
131 | else:
132 | messages = torch.cat((nodes1, nodes2),dim=-1)
133 | messages = self.send_mlp(messages)
134 | messages = self.get_norm(messages, 'send_message_norm')
135 | return messages
136 |
137 | def aggregate(self, nodes, messages):
138 | if 'sum' in self.params.aggregation_type:
139 | return self.sum_aggregate(messages)
140 | elif 'dot' in self.params.aggregation_type:
141 | return self.dot_attention(nodes, messages)
142 |
143 | def sum_aggregate(self, messages):
144 | # nodes: B x N x T x C
145 | # messages: B x NxN x T x C
146 | return messages.mean(dim=2)
147 |
148 | def dot_attention(self, nodes, messages):
149 | # nodes: B x N x T x C
150 | # messages: B x NxN x T x C
151 |
152 | # nodes1: B x N x 1 x T x C
153 | # nodes2: B x 1 x N x T x C
154 | # corr B x N x N x T
155 |
156 | nodes_q = self.att_q(nodes)
157 | nodes_k = self.att_k(nodes)
158 |
159 | nodes_q = nodes_q.permute(0,2,1,3)
160 | nodes_k = nodes_k.permute(0,2,3,1)
161 |
162 | corr = torch.matmul(nodes_q, nodes_k).unsqueeze(-1)
163 | corr = corr.permute(0,2,3,1,4)
164 |
165 | nodes = F.softmax(corr, dim=2) * messages
166 | nodes = nodes.sum(dim=2)
167 |
168 | nodes = nodes + self.att_bias
169 | nodes = F.relu(nodes)
170 | return nodes
171 |
172 | def update_nodes(self, nodes, aggregated):
173 | if self.params.combine_by_sum:
174 | upd_input = nodes + aggregated
175 | else:
176 | upd_input = torch.cat((nodes, aggregated), dim=-1)
177 | nodes = self.update_mlp(upd_input)
178 | nodes = self.get_norm(nodes, 'update_norm')
179 | return nodes
180 |
181 | def forward(self, input):
182 | self.B = input.shape[0]
183 | self.T = input.shape[1]
184 |
185 | # input RSTG: B x T x C x H x W
186 | # set input: ... -> B x T x C x N
187 |
188 | # for LSTM we need (B * N) x T x C
189 | # propagation: B x 1 x N x T x C
190 | # + B x N x 1 x T x C
191 | # (B x N*N x T) x C => liniar
192 | # nodes: B x N x T x C
193 | nodes = input.permute(0,3,1,2)
194 |
195 | nodes = self.get_norm(nodes, 'before_graph_norm')
196 | time_iter_mom = [0, 1, 2]
197 |
198 | for space_iter in range(self.number_iterations):
199 | # internal time processing
200 | if space_iter in time_iter_mom:
201 | nodes = nodes.view(self.B * self.num_nodes, self.T, self.node_dim)
202 | self.internal_lstm.flatten_parameters()
203 | nodes, _ = self.internal_lstm(nodes)
204 | nodes = nodes.view(self.B, self.num_nodes, self.T, self.node_dim)
205 |
206 | # space_processing: send, aggregate, update
207 | messages = self.send_messages(nodes)
208 | aggregated = self.aggregate(nodes, messages)
209 | nodes = self.update_nodes(nodes, aggregated)
210 |
211 | # external time processing
212 | nodes = nodes.view(self.B * self.num_nodes, self.T, self.node_dim)
213 | self.external_lstm.flatten_parameters()
214 | nodes, _ = self.external_lstm(nodes)
215 | nodes = nodes.view(self.B, self.num_nodes, self.T, self.node_dim)
216 |
217 |
218 | # B x N x T x C -> B x T x C x N
219 | nodes = nodes.permute(0,2,3,1).contiguous()
220 |
221 | return nodes
222 |
--------------------------------------------------------------------------------
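A shape walkthrough of the message passing above, as a standalone sketch of the simplest configuration (combine_by_sum, 'sum' aggregation, identity update); the toy sizes are assumptions, N matches self.num_nodes = 9, and the B x T x C x N input layout follows from the permutes in forward:

```
import torch

B, T, C, N = 2, 4, 32, 9                            # toy sizes; the real node_dim comes from models_config
x = torch.randn(B, T, C, N)                         # RSTG.forward expects B x T x C x N
nodes = x.permute(0, 3, 1, 2)                       # B x N x T x C, layout used for message passing
messages = nodes.unsqueeze(1) + nodes.unsqueeze(2)  # B x N x N x T x C, pairwise 'combine_by_sum' send
aggregated = messages.mean(dim=2)                   # B x N x T x C, 'sum' aggregation over senders
out = (nodes + aggregated).permute(0, 2, 3, 1)      # identity update by sum, back to B x T x C x N
print(out.shape)                                    # torch.Size([2, 4, 32, 9])
```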
/ops/models_config.py:
--------------------------------------------------------------------------------
1 | from opts import parser
2 | global args
3 | args = parser.parse_args()
4 |
5 | def return_config_resnet13():
6 | graph_params = {}
7 | graph_params[1] = {'in_channels' : 64, 'H' : 32, 'node_dim' : 2 * 64 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'}
8 | graph_params[2] = {'in_channels' : 128, 'H' : 16, 'node_dim' : 128 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'}
9 | graph_params[3] = {'in_channels' : 256, 'H' : 8, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'}
10 | out_pool_size = 8
11 | # 256
12 | return graph_params, out_pool_size
13 |
14 | def return_config_resnet18():
15 | graph_params = {}
16 | if args.bottleneck_graph:
17 | if args.full_res:
18 | graph_params[1] = {'in_channels' : 64 // args.ch_div, 'iH' : 56, 'H' : 64, 'node_dim' : 64 // args.ch_div, 'name': 'layer1'}
19 | graph_params[2] = {'in_channels' : 128// args.ch_div, 'iH' : 28, 'H' : 32, 'node_dim' : 128 // args.ch_div, 'name': 'layer2'}
20 | graph_params[3] = {'in_channels' : 256// args.ch_div, 'iH' : 14, 'H' : 16, 'node_dim' : 256 // args.ch_div, 'name': 'layer3'}
21 | graph_params[4] = {'in_channels' : 512// args.ch_div, 'iH' : 7, 'H' : 8, 'node_dim' : min(256,512 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
22 | out_pool_size = 8
23 | else:
24 | graph_params[1] = {'in_channels' : 64 // args.ch_div, 'H' : 56, 'iH' : 56, 'node_dim' : 64 // args.ch_div, 'name': 'layer1'}
25 | graph_params[2] = {'in_channels' : 128// args.ch_div, 'H' : 28, 'iH' : 28, 'node_dim' : 128 // args.ch_div, 'name': 'layer2'}
26 | graph_params[3] = {'in_channels' : 256// args.ch_div, 'H' : 14, 'iH' : 14, 'node_dim' : 256 // args.ch_div, 'name': 'layer3'}
27 | graph_params[4] = {'in_channels' : 512// args.ch_div, 'H' : 7, 'iH' : 7, 'node_dim' : min(256,512 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
28 | out_pool_size = 7
29 |
30 | for i in [1,2,3]:
31 | graph_params[i]['project_i3d'] = (args.offset_generator == 'glore') #True for glore, False for fishnet
32 | else:
33 | graph_params[1] = {'in_channels' : 64, 'H' : 56, 'node_dim' : 64 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'}
34 | graph_params[2] = {'in_channels' : 128, 'H' : 28, 'node_dim' : 128 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'}
35 | graph_params[3] = {'in_channels' : 256, 'H' : 14, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'}
36 | graph_params[4] = {'in_channels' : 512, 'H' : 7, 'node_dim' : min(512,512 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
37 | out_pool_size = 7
38 |
39 | return graph_params, out_pool_size
40 |
41 | def return_config_resnet34():
42 | graph_params = {}
43 |
44 | graph_params[1] = {'in_channels' : 64, 'H' : 56, 'node_dim' : 64 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'}
45 | graph_params[2] = {'in_channels' : 128, 'H' : 28, 'node_dim' : 128 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'}
46 | graph_params[3] = {'in_channels' : 256, 'H' : 14, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'}
47 | graph_params[4] = {'in_channels' : 512, 'H' : 7, 'node_dim' : min(512,512 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
48 | out_pool_size = 7
49 |
50 | return graph_params, out_pool_size
51 |
52 |
53 |
54 | def return_config_wide_resnet50_2():
55 | graph_params = {}
56 |
57 | graph_params[1] = {'in_channels' : 512 // args.ch_div, 'H' : 56, 'iH' : 56, 'node_dim' : 512 // args.ch_div, 'name': 'layer1'}
58 | graph_params[2] = {'in_channels' : 1024// args.ch_div, 'H' : 28, 'iH' : 28, 'node_dim' : 1024 // args.ch_div, 'name': 'layer2'}
59 | graph_params[3] = {'in_channels' : 2048// args.ch_div, 'H' : 14, 'iH' : 14, 'node_dim' : 2048 // args.ch_div, 'name': 'layer3'}
60 | graph_params[4] = {'in_channels' : 4096// args.ch_div, 'H' : 7, 'iH' : 7, 'node_dim' : min(512,4096 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
61 |
62 |
63 | out_pool_size = 7
64 |
65 | for i in [1,2,3]:
66 | graph_params[i]['project_i3d'] = (args.offset_generator == 'glore') #True for glore, False for fishnet
67 |
68 |
69 | return graph_params, out_pool_size
70 | def return_config_resnet50():
71 | graph_params = {}
72 | out_pool_size = 0
73 | if args.bottleneck_graph:
74 |
75 | if args.full_res:
76 | graph_params[1] = {'in_channels' : 256 // args.ch_div, 'iH' : 56, 'H' : 64, 'node_dim' : 256 // args.ch_div, 'name': 'layer1'}
77 | graph_params[2] = {'in_channels' : 512// args.ch_div, 'iH' : 28, 'H' : 32, 'node_dim' : 512 // args.ch_div, 'name': 'layer2'}
78 | graph_params[3] = {'in_channels' : 1024// args.ch_div, 'iH' : 14, 'H' : 16, 'node_dim' : 1024 // args.ch_div, 'name': 'layer3'}
79 | graph_params[4] = {'in_channels' : 2048// args.ch_div, 'iH' : 7, 'H' : 8, 'node_dim' : min(256,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
80 | out_pool_size = 8
81 |
82 | else:
83 | graph_params[1] = {'in_channels' : 256 // args.ch_div, 'H' : 56, 'iH' : 56, 'node_dim' : 256 // args.ch_div, 'name': 'layer1'}
84 | graph_params[2] = {'in_channels' : 512// args.ch_div, 'H' : 28, 'iH' : 28, 'node_dim' : 512 // args.ch_div, 'name': 'layer2'}
85 | graph_params[3] = {'in_channels' : 1024// args.ch_div, 'H' : 14, 'iH' : 14, 'node_dim' : 1024 // args.ch_div, 'name': 'layer3'}
86 | graph_params[4] = {'in_channels' : 2048// args.ch_div, 'H' : 7, 'iH' : 7, 'node_dim' : min(256,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
87 | out_pool_size = 7
88 |
89 | for i in [1,2,3]:
90 | graph_params[i]['project_i3d'] = (args.offset_generator == 'glore') #True for glore, False for fishnet
91 |
92 |
93 | else:
94 | # TO BE REMOVED
95 | graph_params[1] = {'in_channels' : 256, 'H' : 56, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'}
96 | graph_params[2] = {'in_channels' : 512, 'H' : 28, 'node_dim' : 512 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'}
97 | graph_params[3] = {'in_channels' : 1024, 'H' : 14, 'node_dim' : 1024 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'}
98 | graph_params[4] = {'in_channels' : 2048, 'H' : 7, 'node_dim' : min(512,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
99 |
100 | return graph_params, out_pool_size
101 |
102 | def return_config_resnet101():
103 | graph_params = {}
104 | if args.bottleneck_graph:
105 |
106 | graph_params[1] = {'in_channels' : 256 // args.ch_div, 'H' : 56, 'iH' : 56, 'node_dim' : 256 // args.ch_div, 'name': 'layer1'}
107 | graph_params[2] = {'in_channels' : 512// args.ch_div, 'H' : 28, 'iH' : 28, 'node_dim' : 512 // args.ch_div, 'name': 'layer2'}
108 | graph_params[3] = {'in_channels' : 1024// args.ch_div, 'H' : 14, 'iH' : 14, 'node_dim' : 1024 // args.ch_div, 'name': 'layer3'}
109 | graph_params[4] = {'in_channels' : 2048// args.ch_div, 'H' : 7, 'iH' : 7, 'node_dim' : min(256,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
110 | out_pool_size = 7
111 |
112 | for i in [1,2,3]:
113 | graph_params[i]['project_i3d'] = (args.offset_generator == 'glore') #True for glore, False for fishnet
114 | else:
115 | graph_params[1] = {'in_channels' : 256, 'H' : 56, 'node_dim' : 256 // args.ch_div, 'project_i3d' : True, 'name': 'layer1'}
116 | graph_params[2] = {'in_channels' : 512, 'H' : 28, 'node_dim' : 512 // args.ch_div, 'project_i3d' : True, 'name': 'layer2'}
117 | graph_params[3] = {'in_channels' : 1024, 'H' : 14, 'node_dim' : 1024 // args.ch_div, 'project_i3d' : True, 'name': 'layer3'}
118 | graph_params[4] = {'in_channels' : 2048, 'H' : 7, 'node_dim' : min(512,2048 // args.ch_div), 'project_i3d' : True, 'name': 'layer4'}
119 | out_pool_size = 7
120 | if args.full_res:
121 | out_pool_size = 8
122 | return graph_params, out_pool_size
123 |
124 |
125 |
126 |
127 |
128 | def get_models_config():
129 | graph_params = {}
130 | if args.arch == 'resnet13':
131 | graph_params, out_pool_size = return_config_resnet13()
132 | elif args.arch == 'resnet18':
133 | graph_params, out_pool_size = return_config_resnet18()
134 | elif args.arch == 'resnet34':
135 | graph_params, out_pool_size = return_config_resnet34()
136 | elif args.arch == 'resnet50':
137 | graph_params, out_pool_size = return_config_resnet50()
138 | elif args.arch == 'wide_resnet50_2':
139 | graph_params, out_pool_size = return_config_wide_resnet50_2()
140 | elif args.arch == 'resnet101':
141 | graph_params, out_pool_size = return_config_resnet101()
142 | out_num_ch = 2048
143 | return graph_params, out_pool_size, out_num_ch, None
--------------------------------------------------------------------------------
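Note: the per-layer dictionaries built by the functions above all share the same layout: a (possibly reduced) channel count `in_channels`, the spatial resolutions `H`/`iH` the graph operates at, the node dimension, and optionally `project_i3d`. The snippet below is a minimal standalone sketch, not part of the repo: `ch_div` is passed explicitly instead of being read from `args`, and it only reproduces the bottleneck ResNet-50 branch so the layout is easy to inspect.

def sketch_resnet50_graph_params(ch_div=4):
    """Mirrors the bottleneck branch of return_config_resnet50 for a given ch_div."""
    graph_params = {
        1: {'in_channels': 256 // ch_div,  'H': 56, 'iH': 56, 'node_dim': 256 // ch_div,  'name': 'layer1'},
        2: {'in_channels': 512 // ch_div,  'H': 28, 'iH': 28, 'node_dim': 512 // ch_div,  'name': 'layer2'},
        3: {'in_channels': 1024 // ch_div, 'H': 14, 'iH': 14, 'node_dim': 1024 // ch_div, 'name': 'layer3'},
        4: {'in_channels': 2048 // ch_div, 'H': 7,  'iH': 7,  'node_dim': min(256, 2048 // ch_div),
            'project_i3d': True, 'name': 'layer4'},
    }
    # in the real code, layers 1-3 additionally get 'project_i3d' = (offset_generator == 'glore')
    out_pool_size = 7  # 8 when evaluating at full resolution
    return graph_params, out_pool_size

params, pool = sketch_resnet50_graph_params(ch_div=4)
for layer_id, cfg in params.items():
    print(layer_id, cfg['name'], cfg['in_channels'], cfg['node_dim'])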
/ops/dataset.py:
--------------------------------------------------------------------------------
1 | # Code adapted from "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # Ji Lin*, Chuang Gan, Song Han
3 |
4 |
5 | import torch.utils.data as data
6 |
7 | from PIL import Image
8 | import os
9 | import numpy as np
10 | from numpy.random import randint
11 | # parse the command-line arguments once at import time
12 | from opts import parse_args
13 | args = parse_args()
14 |
15 | class VideoRecord(object):
16 | def __init__(self, row):
17 | self._data = row
18 |
19 | @property
20 | def path(self):
21 | return self._data[0]
22 |
23 | @property
24 | def num_frames(self):
25 | return int(self._data[1])
26 |
27 | @property
28 | def label(self):
29 | return int(self._data[2])
30 |
31 |
32 | class TSNDataSet(data.Dataset):
33 | def __init__(self, root_path, list_file,
34 | num_segments=3, new_length=1, modality='RGB',
35 | image_tmpl='img_{:05d}.jpg', transform=None,
36 | random_shift=True, test_mode=False,
37 | remove_missing=False, dense_sample=False, twice_sample=False, split=''):
38 |
39 | self.root_path = root_path
40 | self.list_file = list_file
41 | self.num_segments = num_segments
42 | self.new_length = new_length
43 | self.modality = modality
44 | self.image_tmpl = image_tmpl
45 | self.transform = transform
46 | self.random_shift = random_shift
47 | self.test_mode = test_mode
48 | self.remove_missing = remove_missing
49 | self.dense_sample = dense_sample  # use dense sampling, as in I3D
50 | self.twice_sample = twice_sample  # sample each video twice for more robust validation
51 | if self.dense_sample:
52 | print('=> Using dense sample for the dataset...')
53 | if self.twice_sample:
54 | print('=> Using twice sample for the dataset...')
55 |
56 | self.split = split
57 | self.epoch = -1
58 | self._parse_list()
59 |
60 |
61 | def set_epoch_attributes(self, epoch, augment_dict, detected_boxes_dict=None):
62 | self.epoch = epoch
63 | self.augment_dict = augment_dict
64 |
65 | print(f'=> dataset epoch set to {self.epoch}')
66 |
67 | def _load_image(self, directory, idx):
68 | try:
69 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
70 | except Exception:
71 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
72 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
73 |
74 |
75 | def _parse_list(self):
76 | # keep only videos with at least 3 frames:
77 | tmp = [x.strip().split(' ') for x in open(self.list_file)]
78 | if not self.test_mode or self.remove_missing:
79 | print(f'before removing missing: {len(tmp)}')
80 | tmp = [item for item in tmp if int(item[1]) >= 3]
81 | print(f'after removing missing: {len(tmp)}')
82 |
83 | self.video_list = [VideoRecord(item) for item in tmp]
84 |
85 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
86 | for v in self.video_list:
87 | v._data[1] = int(v._data[1]) // 2
88 | print('video number:%d' % (len(self.video_list)))
89 |
90 | def get_video_list(self):
91 | return self.video_list
92 |
93 | def _sample_indices(self, record):
94 | """
95 | :param record: VideoRecord
96 | :return: list
97 | """
98 | if self.dense_sample: # i3d dense sample
99 | sample_pos = max(1, 1 + record.num_frames - 64)
100 | t_stride = 64 // self.num_segments
101 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
102 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
103 | return np.array(offsets) + 1
104 | else: # normal sample
105 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
106 | if average_duration > 0:
107 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration,
108 | size=self.num_segments)
109 | elif record.num_frames > self.num_segments:
110 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
111 | else:
112 | # should this be a range instead of zeros?
113 | offsets = np.zeros((self.num_segments,))
114 | return offsets + 1
115 |
116 | def _get_val_indices(self, record):
117 | if self.dense_sample: # i3d dense sample
118 | sample_pos = max(1, 1 + record.num_frames - 64)
119 | t_stride = 64 // self.num_segments
120 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
121 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
122 | return np.array(offsets) + 1
123 | else:
124 | if record.num_frames > self.num_segments + self.new_length - 1:
125 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
126 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
127 | else:
128 | # should this be a range instead of zeros?
129 | offsets = np.zeros((self.num_segments,))
130 | return offsets + 1
131 |
132 | def _get_test_indices(self, record):
133 | if self.dense_sample:
134 | sample_pos = max(1, 1 + record.num_frames - 64)
135 | t_stride = 64 // self.num_segments
136 | start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int)
137 | offsets = []
138 | for start_idx in start_list.tolist():
139 | offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
140 | return np.array(offsets) + 1
141 | elif self.twice_sample:
142 |
143 |
144 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
145 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] +
146 | [int(tick * x) for x in range(self.num_segments)])
147 |
148 | return offsets + 1
149 | else:
150 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
151 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
152 | return offsets + 1
153 |
154 |
155 |
156 | def __getitem__(self, index):
157 | record = self.video_list[index]
158 | # check this is a legit video folder
159 |
160 | if self.image_tmpl == 'flow_{}_{:05d}.jpg':
161 | file_name = self.image_tmpl.format('x', 1)
162 | full_path = os.path.join(self.root_path, record.path, file_name)
163 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
164 | file_name = self.image_tmpl.format(int(record.path), 'x', 1)
165 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
166 | else:
167 | file_name = self.image_tmpl.format(1)
168 | full_path = os.path.join(self.root_path, record.path, file_name)
169 |
170 | while not os.path.exists(full_path):
171 | print('################## Not Found:', os.path.join(self.root_path, record.path, file_name))
172 | index = np.random.randint(len(self.video_list))
173 | record = self.video_list[index]
174 | if self.image_tmpl == 'flow_{}_{:05d}.jpg':
175 | file_name = self.image_tmpl.format('x', 1)
176 | full_path = os.path.join(self.root_path, record.path, file_name)
177 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
178 | file_name = self.image_tmpl.format(int(record.path), 'x', 1)
179 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
180 | else:
181 | file_name = self.image_tmpl.format(1)
182 | full_path = os.path.join(self.root_path, record.path, file_name)
183 |
184 | if not self.test_mode:
185 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
186 | else:
187 | segment_indices = self._get_test_indices(record)
188 | crt_detected_boxes = 10  # placeholder; not used inside get()
189 | return self.get(record, segment_indices, crt_detected_boxes)
190 |
191 | def get(self, record, indices, detected_boxes):
192 | images = list()
193 | temp_indices = list()
194 | for seg_ind in indices:
195 | p = int(seg_ind)
196 | for i in range(self.new_length):
197 | seg_imgs = self._load_image(record.path, p)
198 |
199 | temp_indices.extend([p])
200 | images.extend(seg_imgs)
201 | if p < record.num_frames:
202 | p += 1
203 |
204 | process_data = self.transform(images)
205 | gt_box = np.zeros((16, 10, 4))  # dummy ground-truth boxes
206 | return gt_box, process_data, record.label, 0
207 |
208 |
209 | def __len__(self):
210 | return len(self.video_list)
--------------------------------------------------------------------------------
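Note: the sampling logic in `_sample_indices` and `_get_val_indices` above follows the TSN scheme: the video is split into `num_segments` equal chunks, and one frame is taken from each chunk (random during training, the chunk centre for validation). Below is a minimal standalone sketch of that scheme, with `num_frames` and `num_segments` passed explicitly instead of coming from a `VideoRecord` and `args`.

import numpy as np

def sample_train_indices(num_frames, num_segments, new_length=1):
    """Random frame from each of num_segments equal chunks (training)."""
    average_duration = (num_frames - new_length + 1) // num_segments
    if average_duration > 0:
        offsets = np.multiply(list(range(num_segments)), average_duration) + \
                  np.random.randint(average_duration, size=num_segments)
    elif num_frames > num_segments:
        offsets = np.sort(np.random.randint(num_frames - new_length + 1, size=num_segments))
    else:
        offsets = np.zeros((num_segments,), dtype=int)
    return offsets + 1  # frame file names are 1-indexed on disk

def sample_val_indices(num_frames, num_segments, new_length=1):
    """Deterministic: the centre of each chunk (validation)."""
    if num_frames > num_segments + new_length - 1:
        tick = (num_frames - new_length + 1) / float(num_segments)
        offsets = np.array([int(tick / 2.0 + tick * x) for x in range(num_segments)])
    else:
        offsets = np.zeros((num_segments,), dtype=int)
    return offsets + 1

print(sample_train_indices(num_frames=40, num_segments=8))  # e.g. [ 3  7 14 17 23 27 32 38]
print(sample_val_indices(num_frames=40, num_segments=8))    # [ 3  8 13 18 23 28 33 38]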
/ops/temporal_shift.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import pdb
10 |
11 | class TemporalShift(nn.Module):
12 | def __init__(self, net, n_segment=3, n_div=8, inplace=False):
13 | super(TemporalShift, self).__init__()
14 | self.net = net
15 | self.n_segment = n_segment
16 | self.fold_div = n_div
17 | self.inplace = inplace
18 | if inplace:
19 | print('=> Using in-place shift...')
20 | print('=> Using fold div: {}'.format(self.fold_div))
21 |
22 | def forward(self, x):
23 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
24 | return self.net(x)
25 |
26 | @staticmethod
27 | def shift(x, n_segment, fold_div=3, inplace=False):
28 | nt, c, h, w = x.size()
29 | n_batch = nt // n_segment
30 | x = x.view(n_batch, n_segment, c, h, w)
31 |
32 | fold = c // fold_div
33 | if inplace:
34 | # Due to some out of order error when performing parallel computing.
35 | # May need to write a CUDA kernel.
36 | raise NotImplementedError
37 | # out = InplaceShift.apply(x, fold)
38 | else:
39 | # pdb.set_trace()
40 | # print(f'shift input mean: {x.mean()}')
41 | out = torch.zeros_like(x)
42 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
43 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
44 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
45 | # print(f'shift out mean: {out.mean()}')
46 | # print(f'out shape {out.shape}')
47 | return out.view(nt, c, h, w)
48 |
49 | class MyIdentity(nn.Module):
50 | def __init__(self, net):
51 | super(MyIdentity, self).__init__()
52 |
53 |
54 | def forward(self, x):
55 | return x
56 |
57 |
58 |
59 | class InplaceShift(torch.autograd.Function):
60 | # Special thanks to @raoyongming for the help to this function
61 | @staticmethod
62 | def forward(ctx, input, fold):
63 | # not support higher order gradient
64 | # input = input.detach_()
65 | ctx.fold_ = fold
66 | n, t, c, h, w = input.size()
67 | buffer = input.data.new(n, t, fold, h, w).zero_()
68 | buffer[:, :-1] = input.data[:, 1:, :fold]
69 | input.data[:, :, :fold] = buffer
70 | buffer.zero_()
71 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold]
72 | input.data[:, :, fold: 2 * fold] = buffer
73 | return input
74 |
75 | @staticmethod
76 | def backward(ctx, grad_output):
77 | # grad_output = grad_output.detach_()
78 | fold = ctx.fold_
79 | n, t, c, h, w = grad_output.size()
80 | buffer = grad_output.data.new(n, t, fold, h, w).zero_()
81 | buffer[:, 1:] = grad_output.data[:, :-1, :fold]
82 | grad_output.data[:, :, :fold] = buffer
83 | buffer.zero_()
84 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold]
85 | grad_output.data[:, :, fold: 2 * fold] = buffer
86 | return grad_output, None
87 |
88 |
89 | class TemporalPool(nn.Module):
90 | def __init__(self, net, n_segment):
91 | super(TemporalPool, self).__init__()
92 | self.net = net
93 | self.n_segment = n_segment
94 |
95 | def forward(self, x):
96 | x = self.temporal_pool(x, n_segment=self.n_segment)
97 | return self.net(x)
98 |
99 | @staticmethod
100 | def temporal_pool(x, n_segment):
101 | nt, c, h, w = x.size()
102 | n_batch = nt // n_segment
103 | x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w
104 | x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
105 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w)
106 | return x
107 |
108 |
109 | def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
110 | if temporal_pool:
111 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
112 | else:
113 | n_segment_list = [n_segment] * 4
114 | assert n_segment_list[-1] > 0
115 | print('=> n_segment per stage: {}'.format(n_segment_list))
116 |
117 | import torchvision
118 | # if isinstance(net, torchvision.models.ResNet):
119 | if True:  # the isinstance check above is intentionally bypassed
120 | if place == 'block':
121 | # print('place == block')
122 | def make_block_temporal(stage, this_segment):
123 | blocks = list(stage.children())
124 | # print('=> Processing stage with {} blocks'.format(len(blocks)))
125 | for i, b in enumerate(blocks):
126 | # print(f'temporal shift {b}')
127 | # pdb.set_trace()
128 | blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
129 | return nn.Sequential(*(blocks))
130 | # pdb.set_trace()
131 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
132 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
133 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
134 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
135 |
136 | elif 'blockres' in place:
137 | print(f'place={place}')
138 |
139 | n_round = 1
140 | if len(list(net.layer3.children())) >= 23:
141 | n_round = 2
142 | print('=> Using n_round {} to insert temporal shift'.format(n_round))
143 |
144 | def make_block_temporal(stage, this_segment):
145 | blocks = list(stage.children())
146 | print('=> Processing stage with {} blocks residual'.format(len(blocks)))
147 | nr_shift_layers = 0
148 | # pdb.set_trace()
149 | for i, b in enumerate(blocks):
150 | if i % n_round == 0:
151 | # print(f'[{i}] temporal shift. this_segment = {this_segment}. n_div = {n_div}. n_round = {n_round}')
152 | #pdb.set_trace()
153 | nr_shift_layers += 1
154 | blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div)
155 | # blocks[i].conv1 = b.conv1
156 | # print(f'nr_shift_layers: {nr_shift_layers}')
157 | return nn.Sequential(*blocks)
158 | # pdb.set_trace()
159 | # net.bn1 = MyIdentity(net.bn1)
160 |
161 | # net.conv1 = TemporalShift(net.conv1, n_segment=n_segment_list[0], n_div=3)
162 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
163 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
164 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
165 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
166 | else:
167 | raise NotImplementedError(place)
168 |
169 |
170 | def make_temporal_pool(net, n_segment):
171 | import torchvision
172 | if isinstance(net, torchvision.models.ResNet):
173 | print('=> Injecting nonlocal pooling')
174 | net.layer2 = TemporalPool(net.layer2, n_segment)
175 | else:
176 | raise NotImplementedError
177 |
178 |
179 | if __name__ == '__main__':
180 | # test inplace shift vs. vanilla shift (note: the inplace path in shift() currently raises NotImplementedError)
181 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False)
182 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True)
183 |
184 | print('=> Testing CPU...')
185 | # test forward
186 | with torch.no_grad():
187 | for i in range(10):
188 | x = torch.rand(2 * 8, 3, 224, 224)
189 | y1 = tsm1(x)
190 | y2 = tsm2(x)
191 | assert torch.norm(y1 - y2).item() < 1e-5
192 |
193 | # test backward
194 | with torch.enable_grad():
195 | for i in range(10):
196 | x1 = torch.rand(2 * 8, 3, 224, 224)
197 | x1.requires_grad_()
198 | x2 = x1.clone()
199 | y1 = tsm1(x1)
200 | y2 = tsm2(x2)
201 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
202 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
203 | assert torch.norm(grad1 - grad2).item() < 1e-5
204 |
205 | print('=> Testing GPU...')
206 | tsm1.cuda()
207 | tsm2.cuda()
208 | # test forward
209 | with torch.no_grad():
210 | for i in range(10):
211 | x = torch.rand(2 * 8, 3, 224, 224).cuda()
212 | y1 = tsm1(x)
213 | y2 = tsm2(x)
214 | assert torch.norm(y1 - y2).item() < 1e-5
215 |
216 | # test backward
217 | with torch.enable_grad():
218 | for i in range(10):
219 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda()
220 | x1.requires_grad_()
221 | x2 = x1.clone()
222 | y1 = tsm1(x1)
223 | y2 = tsm2(x2)
224 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
225 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
226 | assert torch.norm(grad1 - grad2).item() < 1e-5
227 | print('Test passed.')
228 |
229 |
230 |
231 |
232 |
--------------------------------------------------------------------------------
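Note: `TemporalShift.shift` moves the first fraction of the channels one step backward in time, the next fraction one step forward, and leaves the remaining channels untouched. A tiny standalone check (assuming the repo root is on PYTHONPATH so `ops.temporal_shift` is importable) that makes the effect visible on an 8-channel, 4-frame input with `fold_div=4`:

import torch
from ops.temporal_shift import TemporalShift

n_segment, c = 4, 8
x = torch.arange(n_segment * c, dtype=torch.float32).view(n_segment, c, 1, 1)
y = TemporalShift.shift(x, n_segment=n_segment, fold_div=4, inplace=False)

print(x.view(n_segment, c))
# channels 0-1 take their values from the next frame, channels 2-3 from the
# previous frame, channels 4-7 are left unchanged
print(y.view(n_segment, c))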
/ops/dataset_syncMNIST.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from glob import glob
3 | import os
4 | import pickle
5 | import torch
6 | import time
7 | import pdb
8 | import math
9 |
10 |
11 | def get_labels_dict_multi_sincron():
12 | # create labels
13 | label = 0
14 | labels_dict = {}
15 | for i in range(10):
16 | for j in range(i):
17 | labels_dict[(i, j)] = label
18 | labels_dict[(j, i)] = label
19 | label += 1
20 | # no sync digits
21 | labels_dict[(-1, -1)] = label
22 | label += 1
23 | #
24 | for i in range(10):
25 | labels_dict[(i, i)] = label
26 | label += 1
27 |
28 | return labels_dict, label
29 |
30 | class SyncedMNISTDataSet(torch.utils.data.IterableDataset):
31 | def __init__(self, split='train', dataset_path=None, dataset_fraction_used=1.0):
32 | super(SyncedMNISTDataSet, self).__init__()
33 | #assert end > start, "this example code only works with end >= start"
34 | self.split = split
35 | #self.dataset_path = f'/data/datasets/video_mnist/sync_mnist_large_v2_split_{split}_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0_uint8/'
38 | self.dataset_path = dataset_path
39 | self.worker_id = 0
40 | self.max_workers = 0
41 | self.dataset_len = len(glob(self.dataset_path + '/data*pickle')) * 1000
42 | self.dataset_fraction_used = dataset_fraction_used
43 | # print(f'in iterable dataset: {self.worker_id } / {self.max_workers}')
44 | self.gen = None
45 |
46 | def __iter__(self):
47 | # print(f'worker: [{self.worker_id } / {self.max_workers}]')
48 | return self.gen.next_item(self.worker_id)
49 | def __len__(self):
50 | return self.dataset_len
51 | def syncMNIST_worker_init_fn(worker_id):
52 | worker_info = torch.utils.data.get_worker_info()
53 | dataset = worker_info.dataset # the dataset copy in this worker process
54 | dataset.worker_id = worker_info.id
55 | dataset.max_workers = worker_info.num_workers
56 | # print(f'in woker init fct: {dataset.worker_id}/ {dataset.max_workers}')
57 | dataset.gen = SyncMNISTGenerator(dataset_path=dataset.dataset_path,
58 | worker_id=dataset.worker_id, max_workers=dataset.max_workers,
59 | split=dataset.split, dataset_fraction_used=dataset.dataset_fraction_used)
60 | # print(f'in woker init fct: len gen {len(dataset.gen)}')
61 |
62 | def read_data_mnist(file, num_classes=65):
63 | with open(file, 'rb') as fo:
64 | # print(f'loading: {file}')
65 | videos_dict = pickle.load(fo)
66 | # x = np.expand_dims(videos_dict['videos'], 4)
67 | # x = videos_dict['videos']#.astype(np.float32) / 255.0 * 2 - 1.0
68 | x = videos_dict['videos']#.astype(np.uint8)
69 |
70 | y = videos_dict['labels'].astype(int).squeeze()
71 | y = np.clip(y, 0, num_classes-1)
72 | #y = np.expand_dims(np.eye(num_classes)[y], axis=1)
73 | # y = y.astype(np.float32)
74 |
75 | coords = videos_dict['videos_digits_coords']
76 | top_left = coords
77 | bot_right = coords + 28
78 | digits_boxes = np.concatenate([top_left,bot_right], axis=-1) // 2
79 | is_ann_boxes = np.ones((digits_boxes.shape[0], 1))
80 |
81 | video = x
82 | label = y.astype(np.int64)
83 |
84 | if False:
85 | digits = videos_dict['videos_digits']
86 | min_digits = digits.min(1)
87 | max_digits = digits.max(1)
88 | labels_dict, max_labels = get_labels_dict_multi_sincron()
89 |
90 | min_max = np.zeros_like(label)
91 | for i in range(min_max.shape[0]):
92 | min_max[i] = labels_dict[(min_digits[i], max_digits[i])]
93 |
94 |
95 | # print(f'video.dtype: {video.dtype}')
96 | # video = torch.from_numpy(x).float()
97 | # label = torch.from_numpy(y).long()
98 | # print(f'video: {video.shape}')
99 | return video, label, digits_boxes, is_ann_boxes
100 | # return video, label, digits_boxes, is_ann_boxes, min_max
101 |
102 | class SyncMNISTGenerator():
103 | def __init__(self, dataset_path, split, max_epochs=100, num_classes=46, num_digits=0,
104 | worker_id=0, max_workers=1, dataset_fraction_used=1.0):
105 | no_videos_per_pickle = 1000
106 | self.num_classes = num_classes
107 | self.max_epochs = max_epochs
108 | self.train_files = glob(dataset_path + '/data*pickle')
109 |
110 | # if num_digits == 5:
111 | # train_dataset_files = f'/data/datasets/video_mnist/files_order_pytorch/{split}_files_order3.pickle'
112 | # #elif num_digits == 3:
113 | # else:
114 | # train_dataset_files = f'/data/datasets/video_mnist/files_order_pytorch/{split}_files_{num_digits}digits_order3.pickle'
115 | # if os.path.exists(train_dataset_files):
116 | # with open(train_dataset_files, 'rb') as f:
117 | # self.mnist_random_order = pickle.load(f)
118 | # print(f'Reading from: {self.mnist_random_order[0]}')
119 | # else:
120 | # print('Shuffling and saving train files')
121 | # self.mnist_random_order = []
122 | # for ep in range(100):
123 | # np.random.shuffle(self.train_files)
124 | # self.mnist_random_order.append(self.train_files)
125 | # with open(train_dataset_files, 'wb') as f:
126 | # pickle.dump(self.mnist_random_order, f)
127 | # print(f'Reading: {len(self.mnist_random_order[0])} pickles')
128 |
129 | overall_start = 0
130 | # print(f'dataset_path: {dataset_path}')
131 | # print(f'len(self.train_files): {len(self.train_files)}')
132 | # print(f'dataset_fraction_used: {dataset_fraction_used}')
133 | overall_end = int( dataset_fraction_used * len(self.train_files))
134 |
135 | per_worker = int(math.ceil((overall_end - overall_start) / float(max_workers)))
136 | self.start = overall_start + worker_id * per_worker
137 | self.end = min(self.start + per_worker, overall_end)
138 | # print(f'generator has overall_end:{overall_end} start-end {self.start}-{self.end}')
139 |
140 | def next_item(self, idx):
141 | #for epoch in range(self.max_epochs):
142 | # print(f"Generator epoch: {epoch}")
143 | #return 1
144 | train_files = self.train_files #self.mnist_random_order[0]
145 | train_files = train_files[self.start:self.end]
146 | # print(f'Read:{train_files}')
147 | for file in train_files:
148 | train_videos, train_labels, target_boxes_np, is_ann_boxes = read_data_mnist(file, num_classes=self.num_classes)
149 | # train_videos, train_labels, target_boxes_np, is_ann_boxes, min_max_label = read_data_mnist(file, num_classes=self.num_classes)
150 |
151 | for pick_i in range(train_videos.shape[0]):
152 | # print(f'idx [{idx}] element {pick_i} from {file}')
153 | # video_ids = np.zeros_like(train_labels[pick_i], dtype=np.float32)
154 | #yield (train_videos[pick_i], train_labels[pick_i],video_ids, target_boxes_np[pick_i], is_ann_boxes[pick_i], pick_i, file)
155 | video = train_videos[pick_i] #.astype(np.float32) / 255.0 * 2 - 1.0
156 | yield video, train_labels[pick_i], target_boxes_np[pick_i]#, min_max_label[pick_i]
157 |
158 |
159 | if __name__ == "__main__":
160 | if False:
161 | dataset_path = '/data/datasets/video_mnist/sync_mnist_large_v2_split_test_max_sync_dist_160_num_classes_46_no_digits_5_no_noise_parts_0/'
162 | gen = SyncMNISTGenerator(dataset_path=dataset_path, split='test')
163 |
164 | get_next_item = gen.next_item()
165 | nr_videos = 15000
166 |
167 | time1 = time.time()
168 | # for b in range(nr_videos):
169 | for b, (train_videos, train_labels, target_boxes_np, _ , _) in enumerate(get_next_item):
170 | # train_videos, train_labels, target_boxes_np, _ , _ = next(get_next_item)
171 | print(f'[{b}] {train_videos.mean()}')
172 | time2 = time.time()
173 | print(f'time read {nr_videos} videos: {time2 - time1}')
174 |
175 | else:
176 | batch_size = 8
177 | ds = SyncedMNISTDataSet(split='train')
178 | loader = torch.utils.data.DataLoader(ds,batch_size=batch_size, num_workers=2, worker_init_fn=syncMNIST_worker_init_fn)
179 | time1 = time.time()
180 |
181 |
182 | #for b, (train_videos, train_labels, target_boxes_np, _ , _, _, _) in enumerate(loader):
183 | for b, (train_videos, train_labels, target_boxes_np) in enumerate(loader):
184 | if b % 100 == 0:
185 | print(b * batch_size)
186 | pass
187 | # pdb.set_trace()
188 | # print(f'[{b}] {train_videos.shape}, {train_labels.shape}')
189 |
190 | # print('Epoch 2')
191 | # for b, (train_videos, train_labels, target_boxes_np, _ , _) in enumerate(loader):
192 | # # pdb.set_trace()
193 | # print(f'[{b}] {train_videos[0].numpy().mean()}')
194 |
195 |
196 |
197 | time2 = time.time()
198 | print(f'time read : {time2 - time1}')
199 |
200 |
--------------------------------------------------------------------------------
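Note: `SyncedMNISTDataSet` relies on `syncMNIST_worker_init_fn` to give each DataLoader worker a disjoint, contiguous slice of the pickle files, so no two workers read the same file. Below is a minimal standalone sketch of that sharding pattern; the file names are hypothetical placeholders and the generator simply yields them instead of decoding videos.

import math
from torch.utils.data import IterableDataset, DataLoader, get_worker_info

class ShardedFiles(IterableDataset):
    """Yields file names; each DataLoader worker sees only its own slice."""
    def __init__(self, files):
        super().__init__()
        self.files = files
        self.start, self.end = 0, len(files)

    def __iter__(self):
        for f in self.files[self.start:self.end]:
            yield f

def worker_init_fn(worker_id):
    info = get_worker_info()
    ds = info.dataset  # the per-worker copy of the dataset
    per_worker = int(math.ceil(len(ds.files) / float(info.num_workers)))
    ds.start = info.id * per_worker
    ds.end = min(ds.start + per_worker, len(ds.files))

if __name__ == '__main__':
    files = [f'data_{i}.pickle' for i in range(10)]  # hypothetical file names
    loader = DataLoader(ShardedFiles(files), batch_size=None, num_workers=2,
                        worker_init_fn=worker_init_fn)
    print(sorted(loader))  # every file appears exactly once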
/opts.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 | import argparse
6 | def str2bool(v):
7 | if isinstance(v, bool):
8 | return v
9 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
10 | return True
11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
12 | return False
13 | else:
14 | raise argparse.ArgumentTypeError('Boolean value expected.')
15 |
16 |
17 | parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
18 | parser.add_argument('--dataset', type=str, default='somethingv2')
19 | parser.add_argument('--dataset_fraction_used', default=1.0, type=float,
20 | help='fraction of the dataset to use (1.0 = the whole dataset)')
21 |
22 | parser.add_argument('--test_dataset', type=str, default='../test')
23 | parser.add_argument('--train_dataset', type=str, default='../train')
24 | parser.add_argument('--modality', type=str, choices=['RGB', 'Flow', 'gray'], default='RGB')
25 | parser.add_argument('--train_list', type=str, default="")
26 | parser.add_argument('--val_list', type=str, default="")
27 | parser.add_argument('--root_path', type=str, default="")
28 | parser.add_argument('--store_name', type=str, default="")
29 | # ========================= Model Configs ==========================
30 | parser.add_argument('--arch', type=str, default="resnet50")
31 | parser.add_argument('--num_segments', type=int, default=16)
32 |
33 | parser.add_argument('--consensus_type', type=str, default='avg')
34 | parser.add_argument('--k', type=int, default=3)
35 |
36 | parser.add_argument('--dropout', '--do', default=0.5, type=float,
37 | metavar='DO', help='dropout ratio (default: 0.5)')
38 | parser.add_argument('--loss_type', type=str, default="nll",
39 | choices=['nll'])
40 | parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")
41 | parser.add_argument('--suffix', type=str, default=None)
42 | parser.add_argument('--pretrain', type=str, default='imagenet')
43 | parser.add_argument('--tune_from', type=str, default=None, help='fine-tune from checkpoint')
44 |
45 | # ========================= Learning Configs ==========================
46 | parser.add_argument('--epochs', default=120, type=int, metavar='N',
47 | help='number of total epochs to run')
48 | parser.add_argument('-b', '--batch-size', default=10, type=int,
49 | metavar='N', help='mini-batch size (default: 10)')
50 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
51 | metavar='LR', help='initial learning rate')
52 | parser.add_argument('--lr_type', default='step', type=str,
53 | metavar='LRtype', help='learning rate type')
54 | parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+",
55 | metavar='LRSteps', help='epochs to decay learning rate by 10')
56 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
57 | help='momentum')
58 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
59 | metavar='W', help='weight decay (default: 5e-4)')
60 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
61 | metavar='W', help='gradient norm clipping (default: disabled)')
62 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")
63 |
64 | # no_partialbn == do NOT freeze BN layers
65 | # partialbn == enable_partialbn == FREEZE BN layers (except the first one)
66 |
67 | # ========================= Monitor Configs ==========================
68 | parser.add_argument('--print-freq', '-p', default=20, type=int,
69 | metavar='N', help='print frequency (default: 20)')
70 | parser.add_argument('--eval-freq', '-ef', default=5, type=int,
71 | metavar='N', help='evaluation frequency (default: 5)')
72 |
73 |
74 | # ========================= Runtime Configs ==========================
75 | parser.add_argument('-j', '--workers', default=5, type=int, metavar='N',
76 | help='number of data loading workers (default: 5)')
77 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
78 | help='path to latest checkpoint (default: none)')
79 | parser.add_argument('--replace_ignore', default=True, type=str2bool,
80 | help='change name in resnet backbone')
81 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
82 | help='evaluate model on validation set')
83 | parser.add_argument('--snapshot_pref', type=str, default="")
84 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
85 | help='manual epoch number (useful on restarts)')
86 | parser.add_argument('--gpus', nargs='+', type=int, default=None)
87 | parser.add_argument('--flow_prefix', default="", type=str)
88 | parser.add_argument('--root_log',type=str, default='log')
89 | parser.add_argument('--root_model', type=str, default='checkpoint')
90 | parser.add_argument('--model_dir', type=str, default='./models/')
91 | parser.add_argument('--coeff', type=str, default=None)  # unclear what this is used for
92 |
93 | parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
94 | parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
95 | parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: blockres)')
96 | parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling')
97 | parser.add_argument('--non_local', default=False, action="store_true", help='add non local block')
98 | parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset')
99 | parser.add_argument('--name', default='graph', type=str,
100 | help='name of the model')
101 |
102 | # # may contain splits
103 | parser.add_argument('--weights', type=str, default=None)
104 | parser.add_argument('--test_segments', type=str, default='25')
105 |
106 | parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
107 | parser.add_argument('--full_res', default=False, type=str2bool, help='Evaluate at full resolution')
108 |
109 | parser.add_argument('--full_size_224', default=False, type=str2bool,
110 | help='rescale to 224 and crop the center 224x224')
111 | parser.add_argument('--test_crops', type=int, default=1)
112 | parser.add_argument('--test_batch_size', type=int, default=1)
113 |
114 | # for true test
115 | parser.add_argument('--test_list', type=str, default=None)
116 | parser.add_argument('--csv_file', type=str, default=None)
117 | parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')
118 | parser.add_argument('--max_num', type=int, default=-1)
119 | parser.add_argument('--input_size', type=int, default=224)
120 | parser.add_argument('--crop_fusion_type', type=str, default='avg')
121 | parser.add_argument('--num_set_segments',type=int, default=1,help='TODO: select multiply set of n-frames from a video')
122 |
123 | # ========================= Dynamic Regions Graph ==========================
124 |
125 | parser.add_argument('--dynamic_regions', default='dyreg', type=str, choices=['none', 'pos_only', 'dyreg',
126 | 'semantic'],
127 | help='type of regions used')
128 | parser.add_argument('--init_regions', default='grid', type=str, choices=['center', 'grid'],
129 | help='anchor position (default: grid)')
130 | parser.add_argument('--offset_lstm_dim', default=128, type=int, help='number channels for the offset lstm (default: 128)')
131 | parser.add_argument('--use_rstg', default=False, type=str2bool,
132 | help='use the recurrent space-time graph (RSTG) modules')
133 | parser.add_argument('--combine_by_sum', default=False, type=str2bool,
134 | help='combine two vectors by concatenation or by summing')
135 | parser.add_argument('--project_at_input', default=False, type=str2bool,
136 | help='apply a projection on the features at the graph input')
137 |
138 |
139 | parser.add_argument('--update_layers', type=int, default=2)
140 | parser.add_argument('--send_layers', type=int, default=2)
141 | parser.add_argument('--rnn_type', type=str, choices=['LSTM', 'GRU'], default='LSTM')
142 | parser.add_argument('--aggregation_type', type=str, choices=['dot', 'sum'], default='dot')
143 | parser.add_argument('--offset_generator', type=str, choices=['none', 'big', 'small'], default='big')
144 |
145 | parser.add_argument('--place_graph', default='layer1.1_layer2.2_layer3.4_layer4.1', type=str,
146 | help='where to place the graph: layeri.j_')
147 | parser.add_argument('--rstg_combine', type=str, choices=['serial', 'plus'], default='plus')
148 | parser.add_argument('--ch_div', type=int, default=2)
149 | parser.add_argument('--graph_residual_type', default='norm', type=str,
150 | help='norm / out_gate/ 1chan_out_gate/ gru_gate')
151 | parser.add_argument('--remap_first', default=False, type=str2bool,
152 | help='project and remap or remap and project')
153 | parser.add_argument('--rstg_skip_connection', default=False, type=str2bool,
154 | help='use skip connection from graphs')
155 | parser.add_argument('--warmup_validate', default=False, type=str2bool,
156 | help='warmup 100 steps before validate')
157 | parser.add_argument('--tmp_norm_skip_conn', default=False, type=str2bool,
158 | help='norm for skip connection')
159 | parser.add_argument('--init_skip_zero', default=False, type=str2bool,
160 | help='initialize the skip connection to zero')
161 | parser.add_argument('--bottleneck_graph', default=False, type=str2bool,
162 | help='smaller graph in bottleneck layer')
163 |
164 |
165 | parser.add_argument('--eval_mode', default='test', type=str,
166 | choices=['train', 'test'])
167 | parser.add_argument('--visualisation', type=str, choices=['rgb', 'hsv'], default='hsv')
168 | parser.add_argument('--save_kernels', default=False, type=str2bool,
169 | help='save the kernels')
170 |
171 | #params for running distributed
172 | parser.add_argument('--local_rank', type=int, default=0)
173 | parser.add_argument('--ngpu', type=int, default=0)
174 | parser.add_argument('--world_size', type=int, default=0)
175 | parser.add_argument('--check_learned_params', default=False, type=str2bool)
176 |
177 | def parse_args():
178 | from ops import models_config
179 | args = parser.parse_args()
180 | args.graph_params, args.out_pool_size, args.out_num_ch, args.distill_path = models_config.get_models_config()
181 | return args
--------------------------------------------------------------------------------
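Note: many of the options above use `str2bool`, so boolean flags are passed as `--flag=True` / `--flag=false` strings; this is how train_model.sh sets them. The following small sketch uses a hypothetical mini-parser (the real one lives in opts.py) to illustrate the behaviour.

import argparse

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('--use_rstg', default=False, type=str2bool)
parser.add_argument('--bottleneck_graph', default=False, type=str2bool)

# both the capitalised and the lowercase spellings parse to Python booleans
print(parser.parse_args(['--use_rstg=True', '--bottleneck_graph=false']))
# -> Namespace(bottleneck_graph=False, use_rstg=True)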
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/ops/nested_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Optional: Data Parallelism
3 | ==========================
4 | **Authors**: Sung Kim and Jenny Kang
5 |
6 | In this tutorial, we will learn how to use multiple GPUs using ``DataParallel``.
7 |
8 | It's very easy to use GPUs with PyTorch. You can put the model on a GPU:
9 |
10 | .. code:: python
11 |
12 | device = torch.device("cuda:0")
13 | model.to(device)
14 |
15 | Then, you can copy all your tensors to the GPU:
16 |
17 | .. code:: python
18 |
19 | mytensor = my_tensor.to(device)
20 |
21 | Please note that just calling ``my_tensor.to(device)`` returns a new copy of
22 | ``my_tensor`` on GPU instead of rewriting ``my_tensor``. You need to assign it to
23 | a new tensor and use that tensor on the GPU.
24 |
25 | It's natural to execute your forward, backward propagations on multiple GPUs.
26 | However, PyTorch will only use one GPU by default. You can easily run your
27 | operations on multiple GPUs by making your model run in parallel using
28 | ``DataParallel``:
29 |
30 | .. code:: python
31 |
32 | model = nn.DataParallel(model)
33 |
34 | That's the core behind this tutorial. We will explore it in more detail below.
35 | """
36 |
37 |
38 | ######################################################################
39 | # Imports and parameters
40 | # ----------------------
41 | #
42 | # Import PyTorch modules and define parameters.
43 | #
44 |
45 | import torch
46 | import torch.nn as nn
47 | from torch.utils.data import Dataset, DataLoader
48 | import pdb
49 | # Parameters and DataLoaders
50 | ch_dim = 128
51 |
52 | input_size = (16,14,14,ch_dim)
53 | output_size = 2
54 |
55 | batch_size = 30
56 | data_size = 100
57 |
58 |
59 | ######################################################################
60 | # Device
61 | #
62 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
63 |
64 | ######################################################################
65 | # Dummy DataSet
66 | # -------------
67 | #
68 | # Make a dummy (random) dataset. You just need to implement the
69 | # __getitem__ method.
70 | #
71 |
72 | class RandomDataset(Dataset):
73 |
74 | def __init__(self, size, length):
75 | self.len = length
76 | self.data = torch.randn(length, size[0], size[1], size[2],size[3])
77 |
78 | def __getitem__(self, index):
79 | return self.data[index]
80 |
81 | def __len__(self):
82 | return self.len
83 |
84 | rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
85 | batch_size=batch_size, shuffle=True)
86 |
87 |
88 | ######################################################################
89 | # Simple Model
90 | # ------------
91 | #
92 | # For the demo, our model just gets an input, performs a linear operation, and
93 | # gives an output. However, you can use ``DataParallel`` on any model (CNN, RNN,
94 | # Capsule Net etc.)
95 | #
96 | # We've placed a print statement inside the model to monitor the size of input
97 | # and output tensors.
98 | # Please pay attention to what is printed at batch rank 0.
99 | #
100 |
101 |
102 | class LayerNormAffine(nn.Module):
103 | def __init__(self,input, size=128):
104 | super(LayerNormAffine, self).__init__()
105 | self.norm = nn.LayerNorm(size, elementwise_affine=False).to(input.device)
106 | # nn.init.constant_(self.norm.weight, 0)
107 | # nn.init.constant_(self.norm.bias, 0)
108 |
109 | self.scale = torch.nn.Parameter(torch.empty(1, 1, 1, size, device=input.device))
110 | self.bias = torch.nn.Parameter(torch.empty(1, 1, 1, size, device=input.device))
111 |
112 | # pdb.set_trace()
113 | # self.norm = self.norm.to(input.device)
114 | # self.scale = scale.to(input.device)
115 | # self.bias = bias.to(input.device)
116 |
117 | nn.init.constant_(self.bias, 0)
118 | nn.init.constant_(self.scale, 1)
119 |
120 | # self.register_parameter('bias', self.bias)
121 | # self.register_parameter('scale', self.scale)
122 |
123 | def forward(self, input):
124 | # input = self.norm(input) * self.params['scale'] + self.params['bias']
125 | input = self.norm(input) * self.scale + self.bias
126 |
127 |
128 | return input
129 |
130 |
131 | class Model2(nn.Module):
132 | # Our model
133 |
134 | def __init__(self, input_size, output_size):
135 | super(Model2, self).__init__()
136 | self.fc = nn.Linear(ch_dim, ch_dim)
137 | self.norm_dict = nn.ModuleDict({})
138 | self.module_norm_dict = nn.ModuleList()
139 |
140 | self.model2_bias = torch.nn.Parameter(torch.empty(1, 1, 1, 128))
141 | nn.init.constant_(self.model2_bias, 1)
142 |
143 | # self.init_norm = LayerNormAffine()
144 |
145 |
146 | def get_norm(self, input, name, zero_init=False):
147 | # input: B * T x C x H x W
148 |
149 | if name not in self.norm_dict:
150 | self.model2_asdasdiasdadksa = torch.nn.Parameter(torch.empty(1, 1, 1, 128))
151 | nn.init.constant_(self.model2_asdasdiasdadksa, 1)
152 |
153 |
154 | norm = LayerNormAffine(input)
155 | self.norm_dict[name] = norm
156 | # print(self)
157 | else:
158 | norm = self.norm_dict[name]
159 |
160 | input = self.norm_dict[name](input)
161 | return input
162 |
163 | def forward(self, input):
164 | # pdb.set_trace()  # debugging breakpoint; commented out so the script can run end-to-end
165 | input = self.get_norm(input, 'norm1')
166 | input = input + self.model2_bias
167 |
168 | output = self.fc(input)
169 | print("\tIn Model2: input size", input.size(),
170 | "output size", output.size())
171 |
172 | return output
173 |
174 |
175 | class Model(nn.Module):
176 | # Our model
177 |
178 | def __init__(self, input_size, output_size):
179 | super(Model, self).__init__()
180 | self.model2 = Model2(input_size, ch_dim)
181 | self.fc = nn.Linear(ch_dim, output_size)
182 |
183 | def forward(self, input):
184 | input = self.model2(input)
185 | output = self.fc(input)
186 | print(f"Forward model: {self}")
187 | # print(f'Model parameters: {self.parameters()}')
188 | #for p in self.parameters():
189 | for name, param in self.named_parameters():
190 | print(f'param: {name}')
191 | print(f'param: {param.device}')
192 |
193 | print("\tIn Model: input size", input.size(),
194 | "output size", output.size())
195 |
196 | return output
197 |
198 |
199 | ######################################################################
200 | # Create Model and DataParallel
201 | # -----------------------------
202 | #
203 | # This is the core part of the tutorial. First, we need to make a model instance
204 | # and check if we have multiple GPUs. If we have multiple GPUs, we can wrap
205 | # our model using ``nn.DataParallel``. Then we can put our model on GPUs by
206 | # ``model.to(device)``
207 | #
208 |
209 | model = Model(input_size, output_size)
210 | if torch.cuda.device_count() > 1:
211 | print("Let's use", torch.cuda.device_count(), "GPUs!")
212 | print(f"Model: {model}")
213 | # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
214 | model = nn.DataParallel(model)
215 |
216 | model.to(device)
217 |
218 |
219 | ######################################################################
220 | # Run the Model
221 | # -------------
222 | #
223 | # Now we can see the sizes of input and output tensors.
224 | #
225 |
226 | for data in rand_loader:
227 | input = data.to(device)
228 | output = model(input)
229 | print("Outside: input size", input.size(),
230 | "output_size", output.size())
231 |
232 |
233 | ######################################################################
234 | # Results
235 | # -------
236 | #
237 | # If you have no GPU or a single GPU, then when we batch 30 inputs, the model gets 30 inputs and returns
238 | # 30 outputs, as expected. But if you have multiple GPUs, you will see results like the following.
239 | #
240 | # 2 GPUs
241 | # ~~~~~~
242 | #
243 | # If you have 2, you will see:
244 | #
245 | # .. code:: bash
246 | #
247 | # # on 2 GPUs
248 | # Let's use 2 GPUs!
249 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
250 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
251 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
252 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
253 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
254 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
255 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
256 | # In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
257 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
258 | # In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2])
259 | # In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2])
260 | # Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])
261 | #
262 | # 3 GPUs
263 | # ~~~~~~
264 | #
265 | # If you have 3 GPUs, you will see:
266 | #
267 | # .. code:: bash
268 | #
269 | # Let's use 3 GPUs!
270 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
271 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
272 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
273 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
274 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
275 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
276 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
277 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
278 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
279 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
280 | # In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
281 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
282 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
283 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
284 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
285 | # Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])
286 | #
287 | # 8 GPUs
288 | # ~~~~~~
289 | #
290 | # If you have 8, you will see:
291 | #
292 | # .. code:: bash
293 | #
294 | # Let's use 8 GPUs!
295 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
296 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
297 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
298 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
299 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
300 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
301 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
302 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
303 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
304 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
305 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
306 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
307 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
308 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
309 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
310 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
311 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
312 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
313 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
314 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
315 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
316 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
317 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
318 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
319 | # In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
320 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
321 | # Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
322 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
323 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
324 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
325 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
326 | # In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
327 | # Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])
328 | #
329 |
330 |
331 | ######################################################################
332 | # Summary
333 | # -------
334 | #
335 | # DataParallel splits your data automatically and sends job orders to multiple
336 | # models on several GPUs. After each model finishes its job, DataParallel
337 | # collects and merges the results before returning them to you.
338 | #
339 | # For more information, please check out
340 | # https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html.
341 | #
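
######################################################################
# A rough sketch of what ``DataParallel`` does for the loop above: the batch
# is scattered into per-device chunks, one replica of the module runs on each
# chunk, and the outputs are gathered back together. On a single device the
# equivalence can be checked by hand (reusing ``model``, ``rand_loader`` and
# ``device`` from above; chunking into 3 pieces is only an example):
#
# .. code:: python
#
#     sample = next(iter(rand_loader)).to(device)
#     chunks = torch.chunk(sample, 3, dim=0)             # "scatter" the batch
#     gathered = torch.cat([model(c) for c in chunks])   # run replicas, then "gather"
#     print("Gathered output size", gathered.size())
#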
--------------------------------------------------------------------------------
/ops/models_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | import matplotlib.pyplot as plt
6 | import matplotlib.colors as colors
7 |
8 | import sys
9 | import os
10 | sys.path.append(os.path.abspath(os.path.join(
11 | os.path.dirname(os.path.realpath(__file__)), '..')))
12 | import time
13 | from opts import parser
14 |
15 | # parse the project-wide options once (parse_args comes from opts.py)
16 | from opts import parse_args
17 | 
18 | args = parse_args()
19 |
20 | import pickle
21 | import pdb
22 |
23 | if args.rnn_type == 'LSTM':
24 | recurrent_net = torch.nn.LSTM
25 | elif args.rnn_type == 'GRU':
26 | recurrent_net = torch.nn.GRU
27 |
28 | def differentiable_resize_area(x, kernel):
29 | # x: B x C x H x W
30 | # kernel: B x h x w x H x W
31 | B = x.size()[0]
32 | C = x.size()[1]
33 | H = x.size()[2]
34 | W = x.size()[3]
35 | h = kernel.size()[1]
36 | w = kernel.size()[2]
37 |
38 | kernel_res = kernel.view(B, h * w, H * W)
39 | kernel_res = kernel_res.permute(0,2,1)
40 | x_res = x.view(B, C, H * W)
41 |
42 | x_resize = torch.matmul(x_res, kernel_res)
43 | x_resize = x_resize.view(B, C, h , w)
44 | # x_resize: B x C x h x w
45 | return x_resize
46 |
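# Usage sketch (illustrative values, not part of the original code): with a
# single uniform kernel over a 2x2 map, the one output "node" is the average
# of the four input locations:
#   x = torch.arange(4.).view(1, 1, 2, 2)        # B=1, C=1, H=W=2
#   k = torch.full((1, 1, 1, 2, 2), 0.25)        # B=1, h=w=1, H=W=2
#   differentiable_resize_area(x, k)             # -> tensor([[[[1.5000]]]])
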
47 | def save_kernel(mean_kernels, folder='.'):
48 |     kernel = mean_kernels[0].detach().cpu().numpy()
49 | max_t = kernel.shape[0]
50 | num_rows = kernel.shape[1]
51 |
52 | #for tt in range(max_t-5,max_t):
53 | all_frames = []
54 | for tt in range(max_t):
55 | f, axarr = plt.subplots(num_rows,num_rows)
56 | N = kernel.shape[1] * kernel.shape[2]
57 | for ii in range(kernel.shape[1]):
58 | for jj in range(kernel.shape[2]):
59 | axarr[ii][jj].imshow(kernel[tt][ii][jj])
60 |
61 | plt.savefig(f'{folder}/kernel_mean_{tt}.png')
62 |
63 |
64 | def atanh(x: torch.Tensor):
65 | x = x.clamp(-1 + 1e-7, 1 - 1e-7)
66 | return (torch.log(1 + x).sub(torch.log(1 - x))).mul(0.5)
67 |
68 |
69 | class LayerNormAffine2D(nn.Module):
70 | # B * T x num_chan x H x W
71 | def __init__(self, num_ch, norm_shape, zero_init=False):
72 | super(LayerNormAffine2D, self).__init__()
73 | self.scale = torch.nn.Parameter(torch.Tensor(size=[1,num_ch,1,1]))#.to(input.device)
74 | self.bias = torch.nn.Parameter(torch.Tensor(size=[1,num_ch,1,1]))#.to(input.device)
75 | self.norm = nn.LayerNorm(norm_shape, elementwise_affine=False)#.to(input.device)
76 |
77 | bias_init = 0
78 | scale_init = 1
79 | if zero_init:
80 | scale_init = 0
81 | nn.init.constant_(self.bias, bias_init)
82 | nn.init.constant_(self.scale, scale_init)
83 |
84 | def forward(self, input):
85 | input = self.norm(input) * self.scale + self.bias
86 | return input
87 |
88 | class LayerNormAffine1D(nn.Module):
89 | # B * T x num_chan x N
90 | def __init__(self, num_ch, norm_shape, zero_init=False):
91 | super(LayerNormAffine1D, self).__init__()
92 | self.scale = torch.nn.Parameter(torch.Tensor(size=[1,num_ch,1]))#.to(input.device)
93 | self.bias = torch.nn.Parameter(torch.Tensor(size=[1,num_ch,1]))#.to(input.device)
94 | self.norm = nn.LayerNorm(norm_shape, elementwise_affine=False)#.to(input.device)
95 |
96 | bias_init = 0
97 | scale_init = 1
98 | if zero_init:
99 | scale_init = 0
100 | nn.init.constant_(self.bias, bias_init)
101 | nn.init.constant_(self.scale, scale_init)
102 |
103 | def forward(self, input):
104 | input = self.norm(input) * self.scale + self.bias
105 | return input
106 |
107 | # input: D1 x D2 x ... x C
108 | # ex B x N x T x C
109 | # or B x N x N x T x C
110 | class LayerNormAffineXC(nn.Module):
111 | def __init__(self, num_ch, norm_shape):
112 | super(LayerNormAffineXC, self).__init__()
113 | self.scale = torch.nn.Parameter(torch.Tensor(size=[num_ch]))
114 | self.bias = torch.nn.Parameter(torch.Tensor(size=[num_ch]))
115 | self.norm = nn.LayerNorm(norm_shape, elementwise_affine=False)
116 |
117 | nn.init.constant_(self.bias, 0)
118 | nn.init.constant_(self.scale, 1)
119 |
120 | def forward(self, input):
121 | input = self.norm(input) * self.scale + self.bias
122 | return input
123 |
124 |
125 |
126 |
127 | class GloRe(nn.Module):
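    # Global-reasoning-style region pooling (summary comment): a 1x1 conv predicts
    # a per-location soft assignment over the h*w = num_nodes nodes (softmax over
    # nodes), sinusoidal positional embeddings (projected by a 1x1 conv) are added
    # to the features, and differentiable_resize_area pools the features into one
    # descriptor per node, which self.update projects to offset_lstm_dim channels.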
128 | def __init__(self, input_height=14, input_channels=16, h=3,w=3):
129 | super(GloRe, self).__init__()
130 |
131 | self.h = h
132 | self.w = w
133 | self.H = input_height
134 | self.W = input_height
135 | self.input_channels = input_channels
136 | self.num_nodes = h * w
137 |
138 | self.norm_dict = nn.ModuleDict({})
139 | self.norm_dict[f'offset_location_emb'] = LayerNormAffine2D(
140 | input_channels, (input_channels, self.h, self.w)
141 | )
142 | # reduce channels
143 | self.mask_conv = nn.Conv2d(self.input_channels, self.num_nodes, [1, 1])
144 | self.update = nn.Conv1d(self.input_channels, args.offset_lstm_dim, [1])
145 |
146 | # positional embeding
147 | pos_emb = self.positional_emb(self.input_channels)
148 | self.register_buffer('pos_emb_buf', pos_emb)
149 | self.sin_pos_emb_proj = nn.Conv2d(self.input_channels, self.input_channels, [1,1])
150 |
151 | def get_norm(self, input, name):
152 | # input: B * T x C x H x W
153 | norm = self.norm_dict[name]
154 | input = norm(input)
155 | return input
156 | def apply_sin_positional_emb(self, x):
157 | # pos_emb_buf: C x H x W
158 | # x:B * T x C x H x W
159 | emb = self.pos_emb_buf.unsqueeze(0)
160 |
161 | emb = emb.repeat(x.shape[0],1,1,1)
162 | emb = self.sin_pos_emb_proj(emb)
163 | out = x + emb
164 | return out #emb- just pos
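    # positional_emb (below) builds a fixed (num_channels, H, W) buffer: the first
    # half of the channels encodes the row index and the second half the column
    # index, each as interleaved sin/cos at different frequencies (a standard
    # sinusoidal embedding, stated here as a summary comment).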
165 | def positional_emb(self, num_channels=2*32):
166 | # pos_h: H x 1 x 1
167 | # T: 1 x C x 1
168 | # T_even: 1 x C/2 x 1
169 | # emb_h_even: H x C/2 x 1
170 | channels = num_channels // 2
171 | pos_h = torch.linspace(0,1,self.H).unsqueeze(1).unsqueeze(2)
172 | T = torch.pow(10000, torch.arange(0,channels).float() / channels)
173 | T_even = T.view(-1,2)[:,:1].unsqueeze(0)
174 | emb_h_even = torch.sin(pos_h / T_even)
175 |
176 | T_odd = T.view(-1,2)[:,1:].unsqueeze(0)
177 | emb_h_odd = torch.cos(pos_h / T_odd)
178 |
179 | emb_h = torch.cat((emb_h_even, emb_h_odd), dim=2).view(self.H,1, channels)
180 | emb_h = emb_h.repeat(1,self.W,1)
181 |
182 | # pos_w: W x 1 x 1
183 | # T: 1 x C x 1
184 | # T_even: 1 x C/2 x 1
185 | # emb_h_even: W x C/2 x 1
186 | pos_w = torch.linspace(0,1,self.W).unsqueeze(1).unsqueeze(2)
187 | emb_w_even = torch.sin(pos_w / T_even)
188 | emb_w_odd = torch.cos(pos_w / T_odd)
189 | emb_w = torch.cat((emb_w_even, emb_w_odd), dim=2).view(1, self.W, channels)
190 | emb_w = emb_w.repeat(self.H,1,1)
191 |
192 | emb = torch.cat((emb_h, emb_w), dim=2).permute(2,0,1)
193 | return emb
194 | def forward(self, x):
195 | # x: B * T x C x H x W
196 | # mask: B * T x num_nodes x H x W
197 | mask = self.mask_conv(x)
198 | x = self.apply_sin_positional_emb(x)
199 |
200 | # softmax over the number of receiving nodes: each of the
201 | # H * W location send predominantly to one node
202 | mask = F.softmax(mask, dim=1)
203 |         mask = mask.view(mask.shape[0], self.h, self.w, mask.shape[2], mask.shape[3])
204 |
205 |         # offset_nodes: B*T x C x h x w
206 |         offset_nodes = differentiable_resize_area(x, mask)
207 | 
208 |         # flatten the h x w node grid: B*T x C x num_nodes
209 |         offset_nodes = offset_nodes.view(offset_nodes.shape[0], offset_nodes.shape[1], offset_nodes.shape[2] * offset_nodes.shape[3])
210 | offset_nodes = self.update(offset_nodes)
211 | # nodes: B*T x offset_lstm_dim x num_nodes
212 | return offset_nodes
213 |
214 |
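# Worked example (for orientation; the numbers follow from the branches below):
# an input_height of 14 falls in the "> 10" branch, so the two tail convs use
# strides (2,2) then (1,1) with paddings (0,0) and (1,1); a 3x3 conv with stride 2
# and no padding maps 14 -> 6, the second conv keeps 6 -> 6, matching
# norm_size = [6, 6, 6].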
215 | def get_fishnet_params(input_height,keep_spatial_size = False):
216 | first_height = input_height
217 | if input_height == 32:
218 | first_height = 32
219 | strides = [(2, 2), (2,2)]
220 | padding = [(0,0), (0,0)]
221 | norm_size = [15, 7, 7]
222 | elif input_height == 16:
223 | first_height = 16
224 | strides = [(2, 2), (1,1)]
225 | padding = [(0,0), (1,1)]
226 | norm_size = [7, 7, 7]
227 | elif input_height == 8:
228 | strides = [(1, 1), (1,1)]
229 | padding = [(0,0), (1,1)]
230 | norm_size = [6, 6, 6]
231 | elif input_height > 50:
232 | strides = [(2, 2), (2,2)]
233 | padding = [(0,0), (0,0)]
234 | norm_size = [13, 6, 6]
235 |         # inputs taller than 50 are max-pooled by a factor of 2 inside Fishnet.forward
236 | first_height = 28
237 | elif input_height > 20:
238 | strides = [(2, 2), (2,2)]
239 | padding = [(0,0), (0,0)]
240 | norm_size = [13, 6, 6]
241 | elif input_height > 10:
242 | strides = [(2, 2), (1,1)]
243 | padding = [(0,0), (1,1)]
244 | norm_size = [6, 6, 6]
245 | else:
246 | strides = [(1, 1), (1,1)]
247 | padding = [(0,0), (1,1)]
248 | norm_size = [5, 5, 5]
249 |
250 | output_padding = (0,0)
251 | if keep_spatial_size:
252 | if input_height == 16:
253 | padding = [(1,1), (1,1)]
254 | output_padding = (1,1)
255 | norm_size = [8, 8, 16]
256 | return first_height, strides, padding, norm_size, output_padding
257 |
258 | class Fishnet(nn.Module):
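    # A small "fish"-shaped encoder/decoder over B*T x C x H x W feature maps
    # (summary comment): a tail of strided convs downsamples, a body of transposed
    # convs upsamples with additive skip connections to the earlier tail
    # activations, and (unless keep_spatial_size is set) a head of strided convs
    # downsamples again; each conv is followed by ReLU and a LayerNorm.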
259 | def __init__(self, input_height=14, input_channels=16, keep_spatial_size=False):
260 | super(Fishnet, self).__init__()
261 |
262 | self.offset_layers = 2
263 | self.offset_channels = [32, 16, 16]
264 | self.offset_channels_tr = [32, 32, 16]
265 | self.input_height = input_height
266 | self.keep_spatial_size = keep_spatial_size
267 |
268 | first_height, strides, padding, norm_size, output_padding = get_fishnet_params(
269 | input_height, keep_spatial_size)
270 | self.norm_size = norm_size
271 |
272 | self.norm_dict = nn.ModuleDict({})
273 |
274 | # reduce channels
275 | self.conv1 = nn.Conv2d(input_channels, self.offset_channels[0], [1, 1]) # de pus relu
276 | self.norm_dict[f'norm1'] = LayerNormAffine2D(
277 | self.offset_channels[0],
278 | (self.offset_channels[0], first_height, first_height)
279 | )
280 |
281 | self.tail_conv = nn.ModuleList()
282 | self.body_trans_conv = nn.ModuleList()
283 | self.head_conv = nn.ModuleList()
284 | for i in range(self.offset_layers):
285 | self.tail_conv.append(nn.Conv2d(
286 | self.offset_channels[max(0,i-1)], self.offset_channels[i],
287 | [3, 3], padding=padding[i], stride=strides[i]
288 | )
289 | )
290 | self.norm_dict[f'tail_norm_{i}'] = LayerNormAffine2D(
291 | self.offset_channels[i],
292 | (self.offset_channels[i], norm_size[i], norm_size[i])
293 | )
294 | stop = 0
295 | if self.keep_spatial_size:
296 | stop = -1
297 | for ind, i in enumerate(range(self.offset_layers - 1, stop,-1)):
298 | if keep_spatial_size and i == 0:
299 | self.body_trans_conv.append(nn.ConvTranspose2d(
300 | self.offset_channels_tr[i+1], self.offset_channels_tr[i],
301 | [3, 3], padding=padding[i], output_padding=output_padding,
302 | stride=strides[i]
303 | )
304 | )
305 | else:
306 | self.body_trans_conv.append(nn.ConvTranspose2d(
307 | self.offset_channels_tr[i+1], self.offset_channels_tr[i],
308 | [3, 3], padding=padding[i], stride=strides[i]
309 | )
310 | )
311 |
312 | self.norm_dict[f'body_trans_norm_{ind}'] = LayerNormAffine2D(
313 | self.offset_channels_tr[i],
314 | (self.offset_channels_tr[i], norm_size[i-1], norm_size[i-1])
315 | )
316 |
317 | if not self.keep_spatial_size:
318 | for i in range(1, self.offset_layers):
319 | self.head_conv.append(nn.Conv2d(
320 | self.offset_channels[i-1], self.offset_channels[i],
321 | [3, 3], padding=padding[i], stride=strides[i]
322 | )
323 | )
324 | self.norm_dict[f'head_norm_{i-1}'] = LayerNormAffine2D(
325 | self.offset_channels[i],
326 | (self.offset_channels[i], norm_size[i], norm_size[i])
327 | )
328 |
329 | def get_norm(self, input, name, zero_init=False):
330 | # input: B * T x C x H x W
331 | norm = self.norm_dict[name]
332 | input = norm(input)
333 | return input
334 |
335 | def forward(self, x):
336 | # x: B * T x C x H x W
337 | x = F.relu(self.conv1(x))
338 | if self.input_height > 50:
339 |             x = F.max_pool2d(x, kernel_size=2, stride=2)
340 |
341 | x = self.get_norm(x, 'norm1')
342 | all_x = []
343 | all_x = [x]
344 |
345 | for i in range(self.offset_layers):
346 | x = F.relu(self.tail_conv[i](x))
347 | x = self.get_norm(x, f'tail_norm_{i}')
348 | all_x.append(x)
349 |
350 | stop = 0
351 | if self.keep_spatial_size:
352 | stop = -1
353 |
354 | for ind, i in enumerate(range(self.offset_layers - 1, stop,-1)):
355 | x = F.relu(self.body_trans_conv[ind](x))
356 | x = self.get_norm(x, f'body_trans_norm_{ind}')
357 | x = x + all_x[i]
358 |
359 | if not self.keep_spatial_size:
360 | for i in range(1, self.offset_layers):
361 | x = F.relu(self.head_conv[i-1](x))
362 | x = self.get_norm(x, f'head_norm_{i-1}')
363 |
364 | return x
--------------------------------------------------------------------------------
/test_models.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 |
6 | # Notice that this file has been modified to support ensemble testing
7 |
8 | import argparse
9 | import time
10 |
11 | import torch.nn.parallel
12 | import torch.optim
13 | from sklearn.metrics import confusion_matrix
14 | from ops.dataset import TSNDataSet
15 | from ops.models import TSN
16 | from ops.transforms import *
17 | from ops import dataset_config
18 | from torch.nn import functional as F
19 | import pdb
20 | # options
21 | import os
22 | from opts import parser
23 | from ops.utils import save_kernels
24 |
25 | global args, best_prec1
26 | args = parser.parse_args()
27 | print(args)
28 |
29 | class AverageMeter(object):
30 | """Computes and stores the average and current value"""
31 | def __init__(self):
32 | self.reset()
33 |
34 | def reset(self):
35 | self.val = 0
36 | self.avg = 0
37 | self.sum = 0
38 | self.count = 0
39 |
40 | def update(self, val, n=1):
41 | self.val = val
42 | self.sum += val * n
43 | self.count += n
44 | self.avg = self.sum / self.count
45 |
46 |
47 | def accuracy(output, target, topk=(1,)):
48 | """Computes the precision@k for the specified values of k"""
49 | maxk = max(topk)
50 | batch_size = target.size(0)
51 | _, pred = output.topk(maxk, 1, True, True)
52 | pred = pred.t()
53 | correct = pred.eq(target.view(1, -1).expand_as(pred))
54 | res = []
55 | for k in topk:
56 | # correct_k = correct[:k].view(-1).float().sum(0)
57 | correct_k = correct[:k].float().sum()
58 | res.append(correct_k.mul_(100.0 / batch_size))
59 | return res
60 |
61 |
62 | def parse_shift_option_from_log_name(log_name):
63 | if 'shift' in log_name:
64 | strings = log_name.split('_')
65 | for i, s in enumerate(strings):
66 | if 'shift' in s:
67 | break
68 | return True, int(strings[i].replace('shift', '')), strings[i + 1]
69 | else:
70 | return False, None, None
71 |
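# For example (illustrative log name, not from this repo's runs):
# parse_shift_option_from_log_name('TSM_somethingv2_RGB_resnet50_shift8_blockres')
# returns (True, 8, 'blockres').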
72 |
73 | weights_list = args.weights.split(',')
74 | test_segments_list = [int(s) for s in args.test_segments.split(',')]
75 | assert len(weights_list) == len(test_segments_list)
76 | if args.coeff is None:
77 | coeff_list = [1] * len(weights_list)
78 | else:
79 | coeff_list = [float(c) for c in args.coeff.split(',')]
80 |
81 | if args.test_list is not None:
82 | test_file_list = args.test_list.split(',')
83 | else:
84 | test_file_list = [None] * len(weights_list)
85 |
86 |
87 | data_iter_list = []
88 | net_list = []
89 | modality_list = []
90 |
91 | total_num = None
92 | for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
93 | is_shift, shift_div, shift_place = args.shift, args.shift_div, args.shift_place
94 | modality = args.modality
95 | #this_arch = this_weights.split('TSM_')[1].split('_')[2]
96 | this_arch = args.arch
97 | modality_list.append(modality)
98 | num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
99 | modality)
100 | print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
101 | net = TSN(num_class, this_test_segments if is_shift else 1, modality,
102 | base_model=this_arch,
103 | consensus_type=args.crop_fusion_type,
104 | dropout=args.dropout,
105 | img_feature_dim=args.img_feature_dim,
106 | partial_bn=not args.no_partialbn,
107 | fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
108 | pretrain=args.pretrain,
109 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
110 | temporal_pool=args.temporal_pool,
111 | non_local='_nl' in this_weights,
112 | )
113 |
114 | if 'tpool' in this_weights:
115 | from ops.temporal_shift import make_temporal_pool
116 | make_temporal_pool(net.base_model, this_test_segments) # since DataParallel
117 |
118 | print(f'Loading weights from: {this_weights}')
119 | checkpoint = torch.load(this_weights, map_location=torch.device('cpu'))
120 | epoch = checkpoint['epoch']
121 | checkpoint = checkpoint['state_dict']
122 | print(f'Evaluating epoch: {epoch}')
123 | ckpt_dict = {}
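    # Strip the leading 'module.' (added by DataParallel when the checkpoint was
    # saved) and re-insert the sub-module indices this model expects, e.g.
    # (illustrative key) 'module.base_model.layer2.2.block.conv.weight'
    # -> 'base_model.layer2.2.0.block.conv.weight'.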
124 | for k, v in list(checkpoint.items()):
125 | if ('dynamic_graph.ph' not in k and 'dynamic_graph.pw' not in k and 'dynamic_graph.arange_h' not in k and 'dynamic_graph.arange_w' not in k
126 | and 'const_dh_ones' not in k and 'const_dw_ones' not in k and 'fix_offsets' not in k):
127 | # remove first tag ('module')
128 | key_name = '.'.join(k.split('.')[1:])
129 | # remove tags from checkpoints vars
130 | key_name = key_name.replace('base_model.map_final_project','map_final_project')
131 | key_name = key_name.replace('.block','.0.block')
132 | key_name = key_name.replace('.dynamic_graph.','.1.dynamic_graph.')
133 | key_name = key_name.replace('.norm_dict.residual_norm','.1.norm_dict.residual_norm')
134 | ckpt_dict[key_name] = v
135 |
136 | model_dict = net.state_dict()
137 |
138 | print('Model parameters')
139 | [print(k) for k in model_dict.keys()]
140 |
141 | print('checkpoint parameters')
142 | [print(k) for k in checkpoint.keys()]
143 |
144 | #
145 | for k, v in model_dict.items():
146 | if 'ignore' in k:
147 | old_name = k.replace('ignore.', '')
148 | ckpt_val = checkpoint['module.' +old_name]
149 | del ckpt_dict[old_name]
150 | ckpt_dict[k] = ckpt_val
151 |
152 | model_dict.update(ckpt_dict)
153 | net.load_state_dict(model_dict)
154 |
155 | input_size = net.scale_size if args.full_res else net.input_size
156 |
157 |
158 | scale1 = net.scale_size
159 | crop1 = input_size
160 | if args.test_crops == 1:
161 | cropping = torchvision.transforms.Compose([
162 | GroupScale(scale1),
163 | GroupCenterCrop(crop1),
164 | ])
165 |     elif args.test_crops == 3:  # do not flip, so only 3 crops
166 | cropping = torchvision.transforms.Compose([
167 | GroupFullResSample(input_size, net.scale_size, flip=False)
168 | ])
169 | elif args.test_crops == 5: # do not flip, so only 5 crops
170 | cropping = torchvision.transforms.Compose([
171 | GroupOverSample(input_size, net.scale_size, flip=False)
172 | ])
173 | elif args.test_crops == 10:
174 | cropping = torchvision.transforms.Compose([
175 | GroupOverSample(input_size, net.scale_size)
176 | ])
177 | else:
178 |         raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(args.test_crops))
179 | data_loader = torch.utils.data.DataLoader(
180 | TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
181 | new_length=1,
182 | modality=modality,
183 | image_tmpl=prefix,
184 | test_mode=True,
185 | remove_missing=len(weights_list) == 1,
186 | transform=torchvision.transforms.Compose([
187 | cropping,
188 | Stack(roll=(False)),
189 | ToTorchFormatTensor(div=(True)),
190 | GroupNormalize(net.input_mean, net.input_std),
191 | ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
192 | batch_size=args.batch_size, shuffle=False,
193 | num_workers=args.workers if torch.cuda.is_available() else 0, pin_memory=True
194 | )
195 |
196 | if args.gpus is not None:
197 | devices = [args.gpus[i] for i in range(args.workers)]
198 | else:
199 | devices = list(range(args.workers))
200 |
201 | if torch.cuda.is_available():
202 | net = torch.nn.DataParallel(net.cuda())
203 | net.eval()
204 | data_gen = enumerate(data_loader)
205 | if total_num is None:
206 | total_num = len(data_loader.dataset)
207 | else:
208 | assert total_num == len(data_loader.dataset)
209 | data_iter_list.append(data_gen)
210 | net_list.append(net)
211 | output = []
212 |
213 | def eval_video(video_data, net, this_test_segments, modality, mode='eval'):
214 | if mode == 'eval':
215 | net.eval()
216 | else:
217 |         print(f"Warning: unexpected mode '{mode}', double-check the flags")
218 | net.train()
219 | with torch.no_grad():
220 | i, data, label, _ = video_data
221 | batch_size = label.numel()
222 | num_crop = args.test_crops
223 | if args.dense_sample:
224 | num_crop *= 10 # 10 clips for testing when using dense sample
225 |
226 | if args.twice_sample:
227 | num_crop *= 2
228 |
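        # Illustrative count: with args.test_crops == 3 and args.twice_sample set,
        # num_crop = 3 * 2 = 6 clips are evaluated per video and averaged below.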
229 | length = 3
230 | data_in = data
231 |
232 | if is_shift:
233 | data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
234 |
235 |
236 | rst, model_aux_feats = net(data_in)
237 | aux = model_aux_feats['interm_feats']
238 | offsets = model_aux_feats['offsets']
239 |
240 | list_save_iter = [0,100,1000]
241 | if args.save_kernels and i in list_save_iter:
242 | folder = os.path.join(args.model_dir, args.store_name, args.root_log,'kernels')
243 | save_kernels(data_in, aux, folder=folder, name=f'validation_iter_{i}', predicted_offsets=offsets)
244 | rst = rst.reshape(batch_size, num_crop, -1).mean(1)
245 |
246 | if args.softmax:
247 | # take the softmax to normalize the output to probability
248 | rst = F.softmax(rst, dim=1)
249 | rst = rst.data.cpu().numpy().copy()
250 |
251 | if torch.cuda.is_available():
252 | if net.module.is_shift:
253 | rst = rst.reshape(batch_size, num_class)
254 | else:
255 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))
256 | else:
257 | if net.is_shift:
258 | rst = rst.reshape(batch_size, num_class)
259 | else:
260 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))
261 | return i, rst, label
262 |
263 |
264 | proc_start_time = time.time()
265 | max_num = args.max_num if args.max_num > 0 else total_num
266 |
267 | top1 = AverageMeter()
268 | top5 = AverageMeter()
269 |
270 | all_batch_time = []
271 | for i, data_label_pairs in enumerate(zip(*data_iter_list)):
272 | with torch.no_grad():
273 | if i >= max_num:
274 | break
275 | this_rst_list = []
276 | this_label = None
277 | begin_proc = time.time()
278 | for n_seg, (_, (_, data, label,detector_out)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
279 | rst = eval_video((i, data, label,detector_out), net, n_seg, modality, mode='eval')
280 | this_rst_list.append(rst[1])
281 | this_label = label
282 | assert len(this_rst_list) == len(coeff_list)
283 | for i_coeff in range(len(this_rst_list)):
284 | this_rst_list[i_coeff] *= coeff_list[i_coeff]
285 | ensembled_predict = sum(this_rst_list) / len(this_rst_list)
286 |
287 | for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
288 | output.append([p[None, ...], g])
289 | cnt_time = time.time() - proc_start_time
290 | batch_time = time.time() - begin_proc
291 | if i > 0:
292 | all_batch_time.append(batch_time)
293 | prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
294 | top1.update(prec1.item(), this_label.numel())
295 | top5.update(prec5.item(), this_label.numel())
296 | if i % 5 == 0:
297 | print('video {} done, total {}/{}, average {:.3f} sec/video, '
298 | 'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
299 | float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg))
300 | print(f"average sec/video: {np.array(all_batch_time).mean() / args.batch_size} ")
301 | video_pred = [np.argmax(x[0]) for x in output]
302 | video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]
303 |
304 | video_labels = [x[1] for x in output]
305 |
306 |
307 | if args.csv_file is not None:
308 | print('=> Writing result to csv file: {}'.format(args.csv_file))
309 | with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
310 | categories = f.readlines()
311 | categories = [f.strip() for f in categories]
312 | with open(test_file_list[0]) as f:
313 | vid_names = f.readlines()
314 | vid_names = [n.split(' ')[0] for n in vid_names]
315 | assert len(vid_names) == len(video_pred)
316 | if args.dataset != 'somethingv2': # only output top1
317 | with open(args.csv_file, 'w') as f:
318 | for n, pred in zip(vid_names, video_pred):
319 | f.write('{};{}\n'.format(n, categories[pred]))
320 | else:
321 | with open(args.csv_file, 'w') as f:
322 | for n, pred5 in zip(vid_names, video_pred_top5):
323 | fill = [n]
324 | for p in list(pred5):
325 | fill.append(p)
326 | f.write('{};{};{};{};{};{}\n'.format(*fill))
327 |
328 |
329 | cf = confusion_matrix(video_labels, video_pred).astype(float)
330 |
331 | np.save('cm.npy', cf)
332 | cls_cnt = cf.sum(axis=1)
333 | cls_hit = np.diag(cf)
334 | eps = 0.0000001
335 | cls_acc = cls_hit / (cls_cnt + eps)
336 | print(cls_acc)
337 | upper = np.mean(np.max(cf, axis=1) / (cls_cnt + eps))
338 | print('upper bound: {}'.format(upper))
339 |
340 | print('-----Evaluation is finished------')
341 | print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
342 | print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
343 |
344 |
345 |
--------------------------------------------------------------------------------
/ops/resnet3d_xl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import math
6 | import numpy as np
7 |
8 | import pdb
9 | from functools import partial
10 |
11 | from opts import parser
12 | args = parser.parse_args()
13 | from ops.rstg import *
14 |
15 | __all__ = [
16 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
17 | 'resnet152', 'resnet200',
18 | ]
19 |
20 |
21 | def conv3x3x3(in_planes, out_planes, stride=1):
22 | # 3x3x3 convolution with padding
23 | return nn.Conv3d(
24 | in_planes,
25 | out_planes,
26 | kernel_size=3,
27 | stride=stride,
28 | padding=1,
29 | bias=False)
30 |
31 |
32 | def downsample_basic_block(x, planes, stride):
33 | out = F.avg_pool3d(x, kernel_size=1, stride=stride)
34 | zero_pads = torch.Tensor(
35 | out.size(0), planes - out.size(1), out.size(2), out.size(3),
36 | out.size(4)).zero_()
37 | if isinstance(out.data, torch.cuda.FloatTensor):
38 | zero_pads = zero_pads.cuda()
39 |
40 | out = Variable(torch.cat([out.data, zero_pads], dim=1))
41 |
42 | return out
43 |
44 |
45 | class BasicBlock(nn.Module):
46 | expansion = 1
47 |
48 | def __init__(self, inplanes, planes, stride=1, downsample=None):
49 | super(BasicBlock, self).__init__()
50 | self.conv1 = conv3x3x3(inplanes, planes, stride)
51 | self.bn1 = nn.BatchNorm3d(planes)
52 | self.relu = nn.ReLU(inplace=True)
53 | self.conv2 = conv3x3x3(planes, planes)
54 | self.bn2 = nn.BatchNorm3d(planes)
55 | self.downsample = downsample
56 | self.stride = stride
57 |
58 | def forward(self, x):
59 | residual = x
60 |
61 | out = self.conv1(x)
62 | out = self.bn1(out)
63 | out = self.relu(out)
64 |
65 | out = self.conv2(out)
66 | out = self.bn2(out)
67 |
68 | if self.downsample is not None:
69 | residual = self.downsample(x)
70 |
71 | out += residual
72 | out = self.relu(out)
73 |
74 | return out
75 |
76 |
77 | class Bottleneck(nn.Module):
78 | conv_op = None
79 | offset_groups = 1
80 |
81 | def __init__(self, dim_in, dim_out, stride, dim_inner, group=1, use_temp_conv=1, temp_stride=1, dcn=False,
82 | shortcut_type='B'):
83 | super(Bottleneck, self).__init__()
84 | # 1 x 1 layer
85 | self.with_dcn = dcn
86 | self.conv1 = self.Conv3dBN(dim_in, dim_inner, (1 + use_temp_conv * 2, 1, 1), (temp_stride, 1, 1),
87 | (use_temp_conv, 0, 0))
88 | self.relu = nn.ReLU(inplace=True)
89 | # 3 x 3 layer
90 | self.conv2 = self.Conv3dBN(dim_inner, dim_inner, (1, 3, 3), (1, stride, stride), (0, 1, 1))
91 | # 1 x 1 layer
92 | self.conv3 = self.Conv3dBN(dim_inner, dim_out, (1, 1, 1), (1, 1, 1), (0, 0, 0))
93 |
94 | self.shortcut_type = shortcut_type
95 | self.dim_in = dim_in
96 | self.dim_out = dim_out
97 | self.temp_stride = temp_stride
98 | self.stride = stride
99 | # nn.Conv3d(dim_in, dim_out, (1,1,1),(temp_stride,stride,stride),(0,0,0))
100 | if self.shortcut_type == 'B':
101 | if self.dim_in == self.dim_out and self.temp_stride == 1 and self.stride == 1: # or (self.dim_in == self.dim_out and self.dim_in == 64 and self.stride ==1):
102 |
103 | pass
104 | else:
105 | # pass
106 | self.shortcut = self.Conv3dBN(dim_in, dim_out, (1, 1, 1), (temp_stride, stride, stride), (0, 0, 0))
107 |
108 | # nn.Conv3d(dim_in,dim_inner,kernel_size=(1+use_temp_conv*2,1,1),stride = (temp_stride,1,1),padding = )
109 |
110 | def forward(self, x):
111 | residual = x
112 | out = self.conv1(x)
113 | out = self.relu(out)
114 | out = self.conv2(out)
115 | out = self.relu(out)
116 | out = self.conv3(out)
117 | if self.dim_in == self.dim_out and self.temp_stride == 1 and self.stride == 1:
118 | pass
119 | else:
120 | residual = self.shortcut(residual)
121 | out += residual
122 | out = self.relu(out)
123 | return out
124 |
125 | def Conv3dBN(self, dim_in, dim_out, kernels, strides, pads, group=1):
126 | if self.with_dcn and kernels[0] > 1:
127 | # use deformable conv
128 | return nn.Sequential(
129 | self.conv_op(dim_in, dim_out, kernel_size=kernels, stride=strides, padding=pads, bias=False,
130 | offset_groups=self.offset_groups),
131 | nn.BatchNorm3d(dim_out)
132 | )
133 | else:
134 | return nn.Sequential(
135 | nn.Conv3d(dim_in, dim_out, kernel_size=kernels, stride=strides, padding=pads, bias=False),
136 | nn.BatchNorm3d(dim_out)
137 | )
138 |
139 |
140 | class ResNet(nn.Module):
141 |
142 | def __init__(self,
143 | block,
144 | layers,
145 | use_temp_convs_set,
146 | temp_strides_set,
147 | sample_size,
148 | sample_duration,
149 | shortcut_type='B',
150 | num_classes=400,
151 | stage_with_dcn=(False, False, False, False),
152 | extract_features=False,
153 | loss_type='softmax'):
154 | super(ResNet, self).__init__()
155 | self.extract_features = extract_features
156 | self.stage_with_dcn = stage_with_dcn
157 | self.group = 1
158 | self.width_per_group = 64
159 | self.dim_inner = self.group * self.width_per_group
160 | # self.shortcut_type = shortcut_type
161 | self.conv1 = nn.Conv3d(
162 | 3,
163 | 64,
164 | kernel_size=(1 + use_temp_convs_set[0][0] * 2, 7, 7),
165 | stride=(temp_strides_set[0][0], 2, 2),
166 | padding=(use_temp_convs_set[0][0], 3, 3),
167 | bias=False)
168 | self.bn1 = nn.BatchNorm3d(64)
169 | self.relu = nn.ReLU(inplace=True)
170 | self.maxpool1 = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 0, 0))
171 | with_dcn = True if self.stage_with_dcn[0] else False
172 | self.layer1 = self._make_layer(block, 64, 256, shortcut_type, stride=1, num_blocks=layers[0],
173 | dim_inner=self.dim_inner, group=self.group, use_temp_convs=use_temp_convs_set[1],
174 | temp_strides=temp_strides_set[1], dcn=with_dcn)
175 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
176 | with_dcn = True if self.stage_with_dcn[1] else False
177 | self.layer2 = self._make_layer(block, 256, 512, shortcut_type, stride=2, num_blocks=layers[1],
178 | dim_inner=self.dim_inner * 2, group=self.group,
179 | use_temp_convs=use_temp_convs_set[2], temp_strides=temp_strides_set[2],
180 | dcn=with_dcn)
181 | with_dcn = True if self.stage_with_dcn[2] else False
182 | self.layer3 = self._make_layer(block, 512, 1024, shortcut_type, stride=2, num_blocks=layers[2],
183 | dim_inner=self.dim_inner * 4, group=self.group,
184 | use_temp_convs=use_temp_convs_set[3], temp_strides=temp_strides_set[3],
185 | dcn=with_dcn)
186 | with_dcn = True if self.stage_with_dcn[3] else False
187 | self.layer4 = self._make_layer(block, 1024, 2048, shortcut_type, stride=1, num_blocks=layers[3],
188 | dim_inner=self.dim_inner * 8, group=self.group,
189 | use_temp_convs=use_temp_convs_set[4], temp_strides=temp_strides_set[4],
190 | dcn=with_dcn)
191 | last_duration = int(math.ceil(sample_duration / 2)) # int(math.ceil(sample_duration / 8))
192 | last_size = int(math.ceil(sample_size / 16))
193 | # self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) #nn.AdaptiveAvgPool3d((1, 1, 1)) #
194 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
195 | self.dropout = torch.nn.Dropout(p=0.5)
196 | self.classifier = nn.Linear(2048, num_classes)
197 |
198 | for m in self.modules():
199 | # if isinstance(m, nn.Conv3d):
200 | # m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
201 | # elif isinstance(m,nn.Linear):
202 | # m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')
203 | # elif
204 | if isinstance(m, nn.BatchNorm3d):
205 | m.weight.data.fill_(1)
206 | m.bias.data.zero_()
207 |
208 | def _make_layer(self, block, dim_in, dim_out, shortcut_type, stride, num_blocks, dim_inner=None, group=None,
209 | use_temp_convs=None, temp_strides=None, dcn=False):
210 | if use_temp_convs is None:
211 | use_temp_convs = np.zeros(num_blocks).astype(int)
212 | if temp_strides is None:
213 | temp_strides = np.ones(num_blocks).astype(int)
214 | if len(use_temp_convs) < num_blocks:
215 | for _ in range(num_blocks - len(use_temp_convs)):
216 | use_temp_convs.append(0)
217 | temp_strides.append(1)
218 | layers = []
219 | for idx in range(num_blocks):
220 | block_stride = 2 if (idx == 0 and stride == 2) else 1
221 |
222 | layers.append(
223 | block(dim_in, dim_out, block_stride, dim_inner, group, use_temp_convs[idx], temp_strides[idx], dcn))
224 | dim_in = dim_out
225 | return nn.Sequential(*layers)
226 |
227 | def forward_single(self, x):
228 | x = self.conv1(x)
229 |
230 | x = self.bn1(x)
231 | x = self.relu(x)
232 | x = self.maxpool1(x)
233 |
234 | x = self.layer1(x)
235 | x = self.maxpool2(x)
236 | x = self.layer2(x)
237 |
238 | x = self.layer3(x)
239 | features = self.layer4(x)
240 |
241 | x = self.avgpool(features)
242 |
243 | y = x
244 | # x = x.view(x.size(0), -1)
245 | # x = self.dropout(x)
246 |
247 | # y = self.classifier(x)
248 | if self.extract_features:
249 | return y, features
250 | else:
251 | return y
252 |
253 | def forward_multi(self, x):
254 | clip_preds = []
255 | # import ipdb;ipdb.set_trace()
256 | for clip_idx in range(x.shape[1]): # B, 10, 3, 3, 32, 224, 224
257 | spatial_crops = []
258 | for crop_idx in range(x.shape[2]):
259 | clip = x[:, clip_idx, crop_idx]
260 | clip = self.forward_single(clip)
261 | spatial_crops.append(clip)
262 | spatial_crops = torch.stack(spatial_crops, 1).mean(1) # (B, 400)
263 | clip_preds.append(spatial_crops)
264 | clip_preds = torch.stack(clip_preds, 1).mean(1) # (B, 400)
265 | return clip_preds
266 |
267 | def forward(self, x):
268 | # pdb.set_trace()
269 | # x: BT x 3 x H x W -> B x T x 3 x H x W
270 | # pdb.set_trace()
271 |         x = x.view([-1, args.num_segments, x.shape[-3], x.shape[-2], x.shape[-1]])  # -1 so a smaller last batch still works
272 | x = x.permute([0,2,1,3,4])
273 | # 5D tensor == single clip
274 | if x.dim() == 5:
275 | pred = self.forward_single(x)
276 |
277 | # 7D tensor == 3 crops/10 clips
278 | elif x.dim() == 7:
279 | pred = self.forward_multi(x)
280 |
281 | # loss_dict = {}
282 | # if 'label' in batch:
283 | # loss = F.cross_entropy(pred, batch['label'], reduction='none')
284 | # loss_dict = {'clf': loss}
285 |
286 | return pred
287 |
288 |
289 | def get_fine_tuning_parameters(model, ft_begin_index):
290 | if ft_begin_index == 0:
291 | return model.parameters()
292 |
293 | ft_module_names = []
294 | for i in range(ft_begin_index, 5):
295 | ft_module_names.append('layer{}'.format(i))
296 | ft_module_names.append('fc')
297 | # import ipdb;ipdb.set_trace()
298 | parameters = []
299 | for k, v in model.named_parameters():
300 | for ft_module in ft_module_names:
301 | if ft_module in k:
302 | parameters.append({'params': v})
303 | break
304 | else:
305 | parameters.append({'params': v, 'lr': 0.0})
306 |
307 | return parameters
308 |
309 |
310 | def obtain_arc(arc_type):
311 | # c2d, ResNet50
312 | if arc_type == 1:
313 | use_temp_convs_1 = [0]
314 | temp_strides_1 = [2]
315 | use_temp_convs_2 = [0, 0, 0]
316 | temp_strides_2 = [1, 1, 1]
317 | use_temp_convs_3 = [0, 0, 0, 0]
318 | temp_strides_3 = [1, 1, 1, 1]
319 | use_temp_convs_4 = [0, ] * 6
320 | temp_strides_4 = [1, ] * 6
321 | use_temp_convs_5 = [0, 0, 0]
322 | temp_strides_5 = [1, 1, 1]
323 |
324 | # i3d, ResNet50
325 | if arc_type == 2:
326 | use_temp_convs_1 = [2]
327 | temp_strides_1 = [1]
328 | use_temp_convs_2 = [1, 1, 1]
329 | temp_strides_2 = [1, 1, 1]
330 | use_temp_convs_3 = [1, 0, 1, 0]
331 | temp_strides_3 = [1, 1, 1, 1]
332 | use_temp_convs_4 = [1, 0, 1, 0, 1, 0]
333 | temp_strides_4 = [1, 1, 1, 1, 1, 1]
334 | use_temp_convs_5 = [0, 1, 0]
335 | temp_strides_5 = [1, 1, 1]
336 |
337 | # c2d, ResNet101
338 | if arc_type == 3:
339 | use_temp_convs_1 = [0]
340 | temp_strides_1 = [2]
341 | use_temp_convs_2 = [0, 0, 0]
342 | temp_strides_2 = [1, 1, 1]
343 | use_temp_convs_3 = [0, 0, 0, 0]
344 | temp_strides_3 = [1, 1, 1, 1]
345 | use_temp_convs_4 = [0, ] * 23
346 | temp_strides_4 = [1, ] * 23
347 | use_temp_convs_5 = [0, 0, 0]
348 | temp_strides_5 = [1, 1, 1]
349 |
350 | # i3d, ResNet101
351 | if arc_type == 4:
352 | use_temp_convs_1 = [2]
353 | temp_strides_1 = [2]
354 | use_temp_convs_2 = [1, 1, 1]
355 | temp_strides_2 = [1, 1, 1]
356 | use_temp_convs_3 = [1, 0, 1, 0]
357 | temp_strides_3 = [1, 1, 1, 1]
358 | use_temp_convs_4 = []
359 | for i in range(23):
360 | if i % 2 == 0:
361 | use_temp_convs_4.append(1)
362 | else:
363 | use_temp_convs_4.append(0)
364 |
365 | temp_strides_4 = [1, ] * 23
366 | use_temp_convs_5 = [0, 1, 0]
367 | temp_strides_5 = [1, 1, 1]
368 |
369 | use_temp_convs_set = [use_temp_convs_1, use_temp_convs_2, use_temp_convs_3, use_temp_convs_4, use_temp_convs_5]
370 | temp_strides_set = [temp_strides_1, temp_strides_2, temp_strides_3, temp_strides_4, temp_strides_5]
371 |
372 | return use_temp_convs_set, temp_strides_set
373 |
374 |
375 | def resnet10(**kwargs):
376 | """Constructs a ResNet-18 model.
377 | """
378 | use_temp_convs_set = []
379 | temp_strides_set = []
380 | model = ResNet(BasicBlock, [1, 1, 1, 1], use_temp_convs_set, temp_strides_set, **kwargs)
381 | return model
382 |
383 |
384 | def resnet18(**kwargs):
385 | """Constructs a ResNet-18 model.
386 | """
387 | use_temp_convs_set = []
388 | temp_strides_set = []
389 | model = ResNet(BasicBlock, [2, 2, 2, 2], use_temp_convs_set, temp_strides_set, **kwargs)
390 | return model
391 |
392 |
393 | def resnet34(**kwargs):
394 | """Constructs a ResNet-34 model.
395 | """
396 | use_temp_convs_set = []
397 | temp_strides_set = []
398 | model = ResNet(BasicBlock, [3, 4, 6, 3], use_temp_convs_set, temp_strides_set, **kwargs)
399 | return model
400 |
401 |
402 | def resnet50(extract_features, **kwargs):
403 | """Constructs a ResNet-50 model.
404 | """
405 | use_temp_convs_set, temp_strides_set = obtain_arc(2)
406 | model = ResNet(Bottleneck, [3, 4, 6, 3], use_temp_convs_set, temp_strides_set,
407 | extract_features=extract_features, **kwargs)
408 | return model
409 |
410 |
411 | def resnet101(**kwargs):
412 | """Constructs a ResNet-101 model.
413 | """
414 | use_temp_convs_set, temp_strides_set = obtain_arc(4)
415 | model = ResNet(Bottleneck, [3, 4, 23, 3], use_temp_convs_set, temp_strides_set, **kwargs)
416 | return model
417 |
418 |
419 | def resnet152(**kwargs):
420 | """Constructs a ResNet-101 model.
421 | """
422 | use_temp_convs_set = []
423 | temp_strides_set = []
424 | model = ResNet(Bottleneck, [3, 8, 36, 3], use_temp_convs_set, temp_strides_set, **kwargs)
425 | return model
426 |
427 |
428 | def resnet200(**kwargs):
429 | """Constructs a ResNet-101 model.
430 | """
431 | use_temp_convs_set = []
432 | temp_strides_set = []
433 | model = ResNet(Bottleneck, [3, 24, 36, 3], use_temp_convs_set, temp_strides_set, **kwargs)
434 | return model
435 |
436 |
437 | def Net(num_classes, extract_features=False, loss_type='softmax',
438 | weights=None, freeze_all_but_cls=False):
439 | net = globals()['resnet' + str(50)](
440 | num_classes=num_classes,
441 | sample_size=50,
442 | sample_duration=32,
443 | extract_features=extract_features,
444 | loss_type=loss_type,
445 | )
446 |
447 | if weights is not None:
448 | kinetics_weights = torch.load(weights)['state_dict']
449 | print("Found weights in {}.".format(weights))
450 | cls_name = 'fc'
451 | else:
452 | kinetics_weights = torch.load('kinetics-res50.pth')
453 | cls_name = 'fc'
454 |         print('\n Restoring Kinetics \n')
455 |
456 | new_weights = {}
457 | for k, v in kinetics_weights.items():
458 | if not k.startswith('module.' + cls_name):
459 | new_weights[k.replace('module.', '')] = v
460 | else:
461 | print(f"!!! Smt wrong with restore {k}")
462 | net.load_state_dict(new_weights, strict=False)
463 |
464 | if freeze_all_but_cls:
465 | for name, par in net.named_parameters():
466 | if not name.startswith('classifier'):
467 | par.requires_grad = False
468 | return net
469 |
--------------------------------------------------------------------------------
/ops/resnet2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | # from .utils import load_state_dict_from_url
5 | from typing import Type, Any, Callable, Union, List, Optional
6 | import pdb
7 | try:
8 | from torch.hub import load_state_dict_from_url
9 | except ImportError:
10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
11 |
12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
13 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
14 | 'wide_resnet50_2', 'wide_resnet101_2']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
27 | }
28 |
29 |
30 |
31 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
32 | """3x3 convolution with padding"""
33 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
34 | padding=dilation, groups=groups, bias=False, dilation=dilation)
35 |
36 |
37 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
38 | """1x1 convolution"""
39 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
40 |
41 |
42 | class BasicBlock(nn.Module):
43 | expansion: int = 1
44 |
45 | def __init__(
46 | self,
47 | inplanes: int,
48 | planes: int,
49 | stride: int = 1,
50 | downsample: Optional[nn.Module] = None,
51 | groups: int = 1,
52 | base_width: int = 64,
53 | dilation: int = 1,
54 | norm_layer: Optional[Callable[..., nn.Module]] = None
55 | ) -> None:
56 | super(BasicBlock, self).__init__()
57 | if norm_layer is None:
58 | norm_layer = nn.BatchNorm2d
59 | if groups != 1 or base_width != 64:
60 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
61 | if dilation > 1:
62 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
63 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
64 | self.conv1 = conv3x3(inplanes, planes, stride)
65 | self.bn1 = norm_layer(planes)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.conv2 = conv3x3(planes, planes)
68 | self.bn2 = norm_layer(planes)
69 | self.downsample = downsample
70 | self.stride = stride
71 |
72 | def forward(self, x: Tensor) -> Tensor:
73 | identity = x
74 |
75 | out = self.conv1(x)
76 | out = self.bn1(out)
77 | out = self.relu(out)
78 |
79 | out = self.conv2(out)
80 | out = self.bn2(out)
81 |
82 | if self.downsample is not None:
83 | identity = self.downsample(x)
84 |
85 | out += identity
86 | out = self.relu(out)
87 |
88 | return out
89 |
90 |
91 | class Bottleneck(nn.Module):
92 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
93 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
94 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
95 | # This variant is also known as ResNet V1.5 and improves accuracy according to
96 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
97 |
98 | expansion: int = 4
99 |
100 | def __init__(
101 | self,
102 | inplanes: int,
103 | planes: int,
104 | stride: int = 1,
105 | downsample: Optional[nn.Module] = None,
106 | groups: int = 1,
107 | base_width: int = 64,
108 | dilation: int = 1,
109 | norm_layer: Optional[Callable[..., nn.Module]] = None
110 | ) -> None:
111 | super(Bottleneck, self).__init__()
112 | if norm_layer is None:
113 | norm_layer = nn.BatchNorm2d
114 | width = int(planes * (base_width / 64.)) * groups
115 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
116 | self.conv1 = conv1x1(inplanes, width)
117 | self.bn1 = norm_layer(width)
118 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
119 | self.bn2 = norm_layer(width)
120 | self.conv3 = conv1x1(width, planes * self.expansion)
121 | self.bn3 = norm_layer(planes * self.expansion)
122 | self.relu = nn.ReLU(inplace=True)
123 | self.downsample = downsample
124 | self.stride = stride
125 |
126 | def forward(self, x: Tensor) -> Tensor:
127 | identity = x
128 |
129 | out = self.conv1(x)
130 | out = self.bn1(out)
131 | out = self.relu(out)
132 |
133 | out = self.conv2(out)
134 | out = self.bn2(out)
135 | out = self.relu(out)
136 |
137 | out = self.conv3(out)
138 | out = self.bn3(out)
139 |
140 | if self.downsample is not None:
141 | identity = self.downsample(x)
142 |
143 | out += identity
144 | out = self.relu(out)
145 |
146 | return out
147 |
148 |
149 | class ResNet(nn.Module):
150 |
151 | def __init__(
152 | self,
153 | block: Type[Union[BasicBlock, Bottleneck]],
154 | layers: List[int],
155 | num_classes: int = 1000,
156 | zero_init_residual: bool = False,
157 | groups: int = 1,
158 | width_per_group: int = 64,
159 | replace_stride_with_dilation: Optional[List[bool]] = None,
160 | norm_layer: Optional[Callable[..., nn.Module]] = None
161 | ) -> None:
162 | super(ResNet, self).__init__()
163 | if norm_layer is None:
164 | norm_layer = nn.BatchNorm2d
165 | self._norm_layer = norm_layer
166 |
167 | self.inplanes = 64
168 | self.dilation = 1
169 | if replace_stride_with_dilation is None:
170 | # each element in the tuple indicates if we should replace
171 | # the 2x2 stride with a dilated convolution instead
172 | replace_stride_with_dilation = [False, False, False]
173 | if len(replace_stride_with_dilation) != 3:
174 | raise ValueError("replace_stride_with_dilation should be None "
175 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
176 | self.groups = groups
177 | self.base_width = width_per_group
178 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
179 | bias=False)
180 | self.bn1 = norm_layer(self.inplanes)
181 | self.relu = nn.ReLU(inplace=True)
182 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
183 | self.layer1 = self._make_layer(block, 64, layers[0])
184 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
185 | dilate=replace_stride_with_dilation[0])
186 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
187 | dilate=replace_stride_with_dilation[1])
188 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
189 | dilate=replace_stride_with_dilation[2])
190 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
191 | self.fc = nn.Linear(512 * block.expansion, num_classes)
192 |
193 | for m in self.modules():
194 | if isinstance(m, nn.Conv2d):
195 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
196 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
197 | nn.init.constant_(m.weight, 1)
198 | nn.init.constant_(m.bias, 0)
199 |
200 | # Zero-initialize the last BN in each residual branch,
201 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
202 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
203 | if zero_init_residual:
204 | for m in self.modules():
205 | if isinstance(m, Bottleneck):
206 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
207 | elif isinstance(m, BasicBlock):
208 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
209 |
210 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
211 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
212 | norm_layer = self._norm_layer
213 | downsample = None
214 | previous_dilation = self.dilation
215 | if dilate:
216 | self.dilation *= stride
217 | stride = 1
218 | if stride != 1 or self.inplanes != planes * block.expansion:
219 | downsample = nn.Sequential(
220 | conv1x1(self.inplanes, planes * block.expansion, stride),
221 | norm_layer(planes * block.expansion),
222 | )
223 |
224 | layers = []
225 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
226 | self.base_width, previous_dilation, norm_layer))
227 | self.inplanes = planes * block.expansion
228 | for _ in range(1, blocks):
229 | layers.append(block(self.inplanes, planes, groups=self.groups,
230 | base_width=self.base_width, dilation=self.dilation,
231 | norm_layer=norm_layer))
232 |
233 | return nn.Sequential(*layers)
234 |
235 | def _forward_impl(self, x) -> Tensor:
236 | # See note [TorchScript super()]
237 | x = self.conv1(x)
238 | x = self.bn1(x)
239 | x = self.relu(x)
240 | x = self.maxpool(x)
241 | x = self.layer1(x)
242 | x = self.layer2(x)
243 | x = self.layer3(x)
244 | x = self.layer4(x)
245 | x = self.avgpool(x)
246 | x = torch.flatten(x, 1)
247 | x = self.fc(x)
248 |
249 | return x
250 |
251 | def forward(self, x: Tensor) -> Tensor:
252 | return self._forward_impl(x)
253 |
254 |
255 | def _resnet(
256 | arch: str,
257 | block: Type[Union[BasicBlock, Bottleneck]],
258 | layers: List[int],
259 | pretrained: bool,
260 | progress: bool,
261 | **kwargs: Any
262 | ) -> ResNet:
263 | model = ResNet(block, layers, **kwargs)
264 | if pretrained:
265 | ckpt_dict = load_state_dict_from_url(model_urls[arch],
266 | progress=progress)
267 |
268 | model_dict = model.state_dict()
269 | restore_dict = {}
270 | for k, v in ckpt_dict.items():
271 | restore_dict[k] = v
272 |
273 | for k, v in model_dict.items():
274 | if 'ignore' in k:
275 | old_name = k.replace('ignore.', '')
276 | ckpt_val = ckpt_dict[old_name]
277 | del restore_dict[old_name]
278 | restore_dict[k] = ckpt_val
279 |
280 | model.load_state_dict(restore_dict)
281 | return model
282 |
283 |
284 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
285 | r"""ResNet-18 model from
286 | `"Deep Residual Learning for Image Recognition" `_
287 |
288 | Args:
289 | pretrained (bool): If True, returns a model pre-trained on ImageNet
290 | progress (bool): If True, displays a progress bar of the download to stderr
291 | """
292 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
293 | **kwargs)
294 |
295 |
296 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
297 | r"""ResNet-34 model from
298 | `"Deep Residual Learning for Image Recognition" `_
299 |
300 | Args:
301 | pretrained (bool): If True, returns a model pre-trained on ImageNet
302 | progress (bool): If True, displays a progress bar of the download to stderr
303 | """
304 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
305 | **kwargs)
306 |
307 |
308 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
309 | r"""ResNet-50 model from
310 | `"Deep Residual Learning for Image Recognition" `_
311 |
312 | Args:
313 | pretrained (bool): If True, returns a model pre-trained on ImageNet
314 | progress (bool): If True, displays a progress bar of the download to stderr
315 | """
316 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
317 | **kwargs)
318 |
319 |
320 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
321 | r"""ResNet-101 model from
322 | `"Deep Residual Learning for Image Recognition" `_
323 |
324 | Args:
325 | pretrained (bool): If True, returns a model pre-trained on ImageNet
326 | progress (bool): If True, displays a progress bar of the download to stderr
327 | """
328 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
329 | **kwargs)
330 |
331 |
332 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
333 | r"""ResNet-152 model from
334 | `"Deep Residual Learning for Image Recognition" `_
335 |
336 | Args:
337 | pretrained (bool): If True, returns a model pre-trained on ImageNet
338 | progress (bool): If True, displays a progress bar of the download to stderr
339 | """
340 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
341 | **kwargs)
342 |
343 |
344 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
345 | r"""ResNeXt-50 32x4d model from
346 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
347 |
348 | Args:
349 | pretrained (bool): If True, returns a model pre-trained on ImageNet
350 | progress (bool): If True, displays a progress bar of the download to stderr
351 | """
352 | kwargs['groups'] = 32
353 | kwargs['width_per_group'] = 4
354 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
355 | pretrained, progress, **kwargs)
356 |
357 |
358 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
359 | r"""ResNeXt-101 32x8d model from
360 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
361 |
362 | Args:
363 | pretrained (bool): If True, returns a model pre-trained on ImageNet
364 | progress (bool): If True, displays a progress bar of the download to stderr
365 | """
366 | kwargs['groups'] = 32
367 | kwargs['width_per_group'] = 8
368 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
369 | pretrained, progress, **kwargs)
370 |
371 |
372 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
373 | r"""Wide ResNet-50-2 model from
374 | `"Wide Residual Networks" `_
375 |
376 | The model is the same as ResNet except that the number of bottleneck channels
377 | is twice as large in every block. The number of channels in the outer 1x1
378 | convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
379 | channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
380 |
381 | Args:
382 | pretrained (bool): If True, returns a model pre-trained on ImageNet
383 | progress (bool): If True, displays a progress bar of the download to stderr
384 | """
385 | kwargs['width_per_group'] = 64 * 2
386 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
387 | pretrained, progress, **kwargs)
388 |
389 |
390 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
391 | r"""Wide ResNet-101-2 model from
392 | `"Wide Residual Networks" `_
393 |
394 | The model is the same as ResNet except that the number of bottleneck channels
395 | is twice as large in every block. The number of channels in the outer 1x1
396 | convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
397 | channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
398 |
399 | Args:
400 | pretrained (bool): If True, returns a model pre-trained on ImageNet
401 | progress (bool): If True, displays a progress bar of the download to stderr
402 | """
403 | kwargs['width_per_group'] = 64 * 2
404 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
405 | pretrained, progress, **kwargs)
406 |
407 |
408 |
409 | # def ResNet2D(num_classes, extract_features=False, loss_type='softmax',
410 | # weights=None, freeze_all_but_cls=False):
411 | # net = globals()['resnet' + str(50)](
412 | # num_classes=num_classes,
413 | # sample_size=50,
414 | # sample_duration=32,
415 | # extract_features=extract_features,
416 | # loss_type=loss_type,
417 | # )
418 |
419 | # if weights is not None:
420 | # kinetics_weights = torch.load(weights)['state_dict']
421 | # print("Found weights in {}.".format(weights))
422 | # cls_name = 'fc'
423 | # else:
424 | # kinetics_weights = torch.load(model_urls['resnet50'])
425 | # cls_name = 'fc'
426 | # print('\n Restoring Kinetics \n')
427 |
428 | # new_weights = {}
429 | # for k, v in kinetics_weights.items():
430 | # if not k.startswith('module.' + cls_name):
431 | # new_weights[k.replace('module.', '')] = v
432 | # else:
433 | # print(f"!!! Something went wrong while restoring {k}")
434 | # net.load_state_dict(new_weights, strict=False)
435 |
436 | # if freeze_all_but_cls:
437 | # for name, par in net.named_parameters():
438 | # if not name.startswith('classifier'):
439 | # par.requires_grad = False
440 | # return net
441 |
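# A minimal usage sketch (illustrative only; it assumes the repo root is on
# PYTHONPATH and that the default 1000-class head is kept so the pretrained
# weights load cleanly): the constructors above forward **kwargs to ResNet,
# so a 2D ImageNet backbone can be built and run roughly like this.
#
#   import torch
#   from ops.resnet2d import resnet50
#
#   backbone = resnet50(pretrained=True)             # ImageNet weights via _resnet()
#   logits = backbone(torch.randn(2, 3, 224, 224))   # -> (2, 1000) class scores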
--------------------------------------------------------------------------------
/ops/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import time
4 | import numpy as np
5 |
6 | import matplotlib
7 | matplotlib.use("Agg")  # select the non-interactive backend before pyplot is imported
8 | import matplotlib.pyplot as plt
9 | import matplotlib.colors as colors
10 | import PIL
11 | from PIL import Image
12 | from moviepy.editor import ImageSequenceClip
13 | from opts import parse_args
14 | args = parse_args()  # command-line options are parsed once, at module import time
15 |
16 |
17 | def offsets_to_boxes(offsets, frame_size=224):
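    # reading off the indexing below: each offset is (half_height, half_width,
    # center_h, center_w) in feature-map cells, one entry per predicted region;
    # it is rescaled by frame_size / graph_params[layer]['H'] and turned into a
    # clipped [y1, x1, y2, x2] box in pixel coordinates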
18 | pred_boxes_dict = {}
19 | places = args.place_graph.replace('layer','').split('_')
20 | all_positions = []
21 | for place in places:
22 | layer = int(place.split('.')[0])
23 | block = int(place.split('.')[1])
24 | pred_position = f'layer{layer}_{block}'
25 | lb = (layer,block)
26 | all_positions.append(lb)
27 | for (layer,block) in all_positions:
28 | pred_position = f'layer{layer}_block{block}'
29 | num_layer = int(layer)
30 | pred_offset = offsets[pred_position].squeeze().detach().cpu().numpy()
31 | pred_offsets = pred_offset / args.graph_params[num_layer]['H'] * frame_size #224.0
32 | regions_h = pred_offsets[:,:,0]
33 | regions_w = pred_offsets[:,:,1]
34 | kernel_center_h = pred_offsets[:,:,2]
35 | kernel_center_w = pred_offsets[:,:,3]
36 |
37 | y1 = np.minimum(frame_size-1,np.maximum(0,kernel_center_h - regions_h))
38 | x1 = np.minimum(frame_size-1,np.maximum(0,kernel_center_w - regions_w))
39 | y2 = np.minimum(frame_size-1, kernel_center_h + regions_h)
40 | x2 = np.minimum(frame_size-1, kernel_center_w + regions_w)
41 |
42 | pred_boxes = np.stack([y1, x1, y2, x2], axis=-1)
43 | pred_boxes_dict[f'layer{layer}_{block}'] = pred_boxes
44 | return pred_boxes_dict
45 |
46 |
47 |
48 | def draw_box(frame, box, color=[1,1,1]):
49 | H = args.input_size
50 | (h,w,dh,dw) = box
51 |
52 | left = max(0,int(dw-w))
53 | right = min(int(dw+w),H-1)
54 | up = max(int(dh-h),0)
55 | down = min(int(dh+h), H-1)
56 |
57 | # print(left,right,up,down)
58 | frame[left,up:down] = color#(0,0,1)
59 |
60 | frame[right,up:down] = color#(0,0,1)
61 | frame[left:right, up] = color#(0,0,1)
62 | frame[left:right, down] = color#(0,0,1)
63 | return frame
64 |
65 | def draw_box_gt(frame, target_boxes, colors=[1,1,1], partial=False, line_width=1):
66 | if len(colors) < target_boxes.shape[0]:
67 | colors = colors * target_boxes.shape[0]
68 | for b in range(target_boxes.shape[0]):
69 | if partial and b % 3 != 0:
70 | continue
71 | box = target_boxes[b]
72 |
73 | top_h = int(box[0])
74 | left_w = int(box[1])
75 | bot_h = int(box[2])
76 | right_w = int(box[3])
77 | frame[top_h:bot_h, left_w:left_w + line_width] = colors[b]
78 | frame[top_h:bot_h, right_w:right_w + line_width] = colors[b]
79 |
80 | frame[top_h: top_h + line_width, left_w : right_w] = colors[b]
81 | frame[bot_h : bot_h + line_width, left_w : right_w] = colors[b]
82 |
83 | return frame
84 |
85 | def save_grad_cam(input, grad_cams, folder, name='initial'):
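    # overlays Grad-CAM maps on the de-normalized input frames: each map is
    # min-max scaled to [0, 255], resized to the frame resolution, colored with
    # the JET colormap and blended with the frame before being written to disk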
86 |
87 | import cv2
88 | folder_dir = folder + f'/viz_grad_cam_{name}/'
89 | if not os.path.exists(folder_dir):
90 | os.makedirs(folder_dir)
91 |
92 | video_input_val = input.detach().cpu().numpy()
93 | video_input_val = np.reshape(video_input_val, [args.batch_size, -1, 3 , video_input_val.shape[3],video_input_val.shape[4]])
94 | video_input_val = np.transpose(video_input_val, [0,1,3,4,2])
95 | for ind_grad, gr in enumerate(grad_cams):
96 | name = gr[0]
97 | grad_cam = gr[1]
98 | grad_cam = grad_cam.view((-1,16) + grad_cam.shape[1:])
99 | num_saved_videos = 2
100 | for video_idx in range(min(num_saved_videos, args.batch_size)):
101 | max_t = 16
102 | all_frames = []
103 | for tt in range(max_t):
104 | if 'resnet50_smt_else' in args.arch:
105 | real_tt = 2*tt
106 | else:
107 | real_tt = tt
108 |
109 | frame = video_input_val[video_idx,real_tt,:,:]
110 | if args.modality == 'RGB':
111 | frame = frame * 57.375 + 114.75
112 |
113 | grad_c = grad_cam[video_idx, tt].cpu().detach().numpy()
114 |
115 | grad_min = grad_c.min()
116 | grad_max = (grad_c - grad_min).max()
117 | grad_c = (grad_c - grad_min) / grad_max
118 | grad_c = (grad_c * 255).astype(np.uint8)
119 | grad_c = np.array(Image.fromarray(grad_c).resize((frame.shape[0],frame.shape[1]), resample=PIL.Image.BILINEAR))
120 |
121 | heatmap = cv2.applyColorMap(grad_c, cv2.COLORMAP_JET)
122 | cam = np.float32(heatmap) + np.float32(frame)
123 | cam = cam / cam.max()
124 |
125 | combined_img =np.concatenate((np.uint8(frame), np.uint8(255 * cam)), axis=1)
126 | cv2.imwrite(os.path.join(folder_dir, f'video_{video_idx}_frame_{tt}_{ind_grad}_grad_{name}.jpg'), combined_img)
127 |
128 |
129 | def save_mean_kernels(all_kernels,epoch=0,folder=''):
130 | places = args.place_graph.replace('layer','').split('_')
131 | placement_all_models = []
132 | for place in places:
133 | layer = int(place.split('.')[0])
134 | block = int(place.split('.')[1])
135 | placement_all_models.append(f'layer{layer}_block{block}')
136 |
137 | for placement in placement_all_models:
138 | kernel_val = all_kernels[placement]
139 |
140 | kernel_val = np.reshape(kernel_val, [args.num_segments , 3,3,kernel_val.shape[-2],kernel_val.shape[-1]])
141 | for tt in range(args.num_segments):
142 | f, axarr = plt.subplots(3,3)
143 | for ii in range(3):
144 | for jj in range(3):
145 | curent_kernel = kernel_val[tt][ii][jj]
146 | curent_kernel_max = curent_kernel.max()
147 | curent_kernel = curent_kernel / curent_kernel_max
148 | curent_kernel = (curent_kernel * 255).astype(np.uint8)
149 | curent_kernel = np.array(Image.fromarray(curent_kernel).resize((224,224), resample=PIL.Image.BILINEAR))
150 | curent_kernel = curent_kernel.astype(np.float32) / 255.0
151 |
152 | axarr[ii][jj].imshow(curent_kernel)
153 |
154 | placement = placement.replace('block','')
155 | folder_dir = folder + f'/viz_kernel_{placement}/'
156 | if not os.path.exists(folder_dir):
157 | os.makedirs(folder_dir)
158 | plt.savefig(f'{folder_dir}/mean_kernels_time_epoch_{epoch}_time_{tt}.png')
159 |
160 | # offsets: predicted offsets
161 | # target_offset: distilled offsets
162 | # input: B x TC x 224 x 224
163 | def save_kernels(input, interm_feats, folder, name='initial', target_offset=None, target_boxes_val=None, predicted_offsets=None):
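    # visualizes the per-node kernels on top of the input frames: each node kernel
    # gets its own hue, the HSV value channel is modulated by the (resized,
    # thresholded) kernel magnitude, the colored masks are added onto the RGB
    # frame, and the per-video frame sequence is also exported as a GIF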
164 | # predicted offsets is a dict. detach when used
165 | predicted_boxes_dict = offsets_to_boxes(predicted_offsets,frame_size=input.shape[-1])
166 |
167 | places = args.place_graph.replace('layer','').split('_')
168 | placement_all_models = []# ['layer4', 'layer3','layer2','layer1']
169 | for place in places:
170 | layer = int(place.split('.')[0])
171 | block = int(place.split('.')[1])
172 | placement_all_models.append(f'layer{layer}_{block}')
173 | input_ch = 3
174 | if args.modality == 'gray':
175 | input_ch = 1
176 | for placement in placement_all_models:
177 | predicted_boxes = predicted_boxes_dict[placement]
178 | predicted_boxes = predicted_boxes.reshape(input.shape[0], args.num_segments, predicted_boxes.shape[1], predicted_boxes.shape[2])
179 |
180 | kernel_val = interm_feats[placement+'_kernels'].detach().cpu().numpy()
181 | video_input_val = input.detach().cpu().numpy()
182 |
183 | video_input_val = np.reshape(video_input_val, [args.batch_size * args.test_crops, -1, input_ch , video_input_val.shape[-2],video_input_val.shape[-1]])
184 | video_input_val = np.transpose(video_input_val, [0,1,3,4,2])
185 | if args.modality == 'gray':
186 | tmp_zeros = -1 * np.ones((video_input_val.shape[0], video_input_val.shape[1], video_input_val.shape[2], video_input_val.shape[3], 3))
187 | tmp_zeros[:,:,:,:,0] = video_input_val[:,:,:,:,0]
188 | video_input_val = tmp_zeros
189 |
190 |
191 | kernel_val = np.reshape(kernel_val, [args.batch_size * args.test_crops, -1 , kernel_val.shape[1],kernel_val.shape[2],kernel_val.shape[3],kernel_val.shape[4]])
192 |
193 | folder_dir = folder + f'/viz_kernel_{placement}/'
194 | if not os.path.exists(folder_dir):
195 | os.makedirs(folder_dir)
196 |
197 | folder_dir = folder_dir + f'/{name}/'
198 | if not os.path.exists(folder_dir):
199 | os.makedirs(folder_dir)
200 | tt = 2
201 |
202 | num_rows = kernel_val.shape[2]
203 | save_individual_frames = True
204 |
205 | num_saved_videos = 2
206 |
207 | for video_idx in range(min(num_saved_videos, args.test_crops * args.batch_size)):
208 | max_t = kernel_val.shape[1]
209 | all_frames = []
210 | for tt in range(max_t):
211 | if 'resnet50_smt_else' in args.arch:
212 | real_tt = 2*tt
213 | else:
214 | real_tt = tt
215 | frame = video_input_val[video_idx,real_tt,:,:]
216 | if args.modality == 'RGB':
217 | frame = frame * 57.375 + 114.75
218 | frame = frame / 255.0
219 | else:
220 | frame = (frame + 1.0) / 2.0
221 |
222 | frame_hsv = colors.rgb_to_hsv(frame)
223 |
224 | rgb = np.array([ [1, 0, 0 ], [0, 1, 0 ], [0, 0, 1 ] ])
225 | hsv = colors.rgb_to_hsv(rgb)
226 |
227 | f, axarr = plt.subplots(num_rows,num_rows)
228 | N = kernel_val.shape[2] * kernel_val.shape[3]
229 |
230 | max_s = frame_hsv[:,:,1].max()
231 | max_v = frame_hsv[:,:,2].mean() + 0.85 * (frame_hsv[:,:,2].max() - frame_hsv[:,:,2].mean()) # maybe use the mean here instead
232 | HSV_tuples = [(x*1.0/N, max_s, max_v) for x in range(N)]
233 |
234 | color_kernels = np.zeros( frame.shape[:2] + (3,) )
235 |
236 | cc = 0
237 | for ii in range(kernel_val.shape[2]):
238 | for jj in range(kernel_val.shape[3]):
239 |
240 | if save_individual_frames:
241 | axarr[ii][jj].imshow(kernel_val[video_idx][tt][ii][jj])
242 | curent_kernel = kernel_val[video_idx][tt][ii][jj]
243 | curent_kernel_max = curent_kernel.max()
244 | curent_kernel = curent_kernel / curent_kernel_max
245 | curent_kernel = (curent_kernel * 255).astype(np.uint8)
246 | curent_kernel = np.array(Image.fromarray(curent_kernel).resize((frame.shape[0],frame.shape[1]), resample=PIL.Image.BILINEAR))
247 | curent_kernel = curent_kernel.astype(np.float32) / 255.0
248 |
249 | mask = ((curent_kernel / curent_kernel.max()) > 0.3)
250 |
251 | curent_kernel = curent_kernel
252 |
253 | color_kernels[:,:,0] = (1.0 - mask) * color_kernels[:,:,0] + mask * HSV_tuples[cc][0]
254 | color_kernels[:,:,1] = (1.0 - mask) * color_kernels[:,:,1] + mask * HSV_tuples[cc][1]
255 | color_kernels[:,:,2] = (1.0 - mask) * color_kernels[:,:,2] + mask * curent_kernel * HSV_tuples[cc][2]
256 |
257 | cc += 1
258 |
259 | if frame.shape[2] == 1:
260 | # gray to rgb
261 | frame = np.tile(frame, [1,1,3])
262 |
263 |
264 | if target_offset and 'layer1' not in placement:
265 | frame_distill = frame.copy()
266 | for ii in range(kernel_val.shape[2] * kernel_val.shape[3]):
267 | place = placement.replace('layer', '')
268 | layer = int(place.split('_')[0])
269 | block = int(place.split('_')[1])
270 | # TODO: un-hardcode the 224 frame size
271 | h,w,dh,dw = target_offset[f'layer{layer}_block{block}'][video_idx, tt,ii] / kernel_val.shape[-1] * 224
272 | frame_distill = draw_box(frame_distill, (h,w,dh,dw))
273 | frame_distill = np.clip(frame_distill, 0.0, 1.0)
274 | rgb_color_kernels = colors.hsv_to_rgb(color_kernels)
275 | frame[:,:] += 1 * rgb_color_kernels#10 * kernel_sum
276 | frame = np.clip(frame, 0.0, 1.0)
277 |
278 | if save_individual_frames:
279 | plt.savefig(f'{folder_dir}/kernel_{video_idx}_{tt}.png')
280 | f, axarr = plt.subplots(1,1)
281 | axarr.imshow(frame)
282 | plt.savefig(f'{folder_dir}/frame_{video_idx}_{tt}.png')
283 | if target_offset and 'layer1' not in placement:
284 | f, axarr = plt.subplots(1,1)
285 | axarr.imshow(frame_distill)
286 | plt.savefig(f'{folder_dir}/z_distill_{video_idx}_{tt}.png')
287 | plt.close(f)
288 |
289 |
290 | frame = (frame / frame.max() * 255.0).astype(np.uint8)
291 | all_frames.append(frame)
292 |
293 | clip = ImageSequenceClip(list(all_frames), fps=3)
294 | clip.write_gif(f'{folder_dir}/video_{video_idx}.gif', fps=3)
295 |
296 | plt.close('all')
297 |
298 |
299 | def draw_frames(input, folder, ref_point, ref_dim, video_id, name=''):
300 | print(input.shape)
301 | video_input_val = input
302 | if name == 'crop':
303 | video_input_val = np.reshape(video_input_val, [-1, 3 , video_input_val.shape[1],video_input_val.shape[2]])
304 | video_input_val = np.transpose(video_input_val, [0,2,3,1])
305 |
306 | folder_dir = folder + f'/viz_kernel_tmp/'
307 | if not os.path.exists(folder_dir):
308 | os.makedirs(folder_dir)
309 |
310 | folder_dir = folder_dir + f'/' #'./viz_kernel_center/'
311 | if not os.path.exists(folder_dir):
312 | os.makedirs(folder_dir)
313 | tt = 2
314 |
315 | video_idx = int(video_id)
316 | for tt in range(16):
317 | frame = video_input_val[tt,:,:]
318 | if args.modality == 'RGB' and name == 'crop':
319 | frame = frame * 57.375 + 114.75
320 | frame = frame / 255.0
321 | elif name == 'crop':
322 | frame = (frame + 1.0) / 2.0
323 | else:
324 | frame = frame / 255.0
325 |
326 | frame_hsv = colors.rgb_to_hsv(frame)
327 |
328 | rgb = np.array([ [1, 0, 0 ], [0, 1, 0 ], [0, 0, 1 ] ])
329 |
330 | frame = np.clip(frame, 0.0, 1.0)
331 | # draw center
332 | frame[ref_point[0]-3:ref_point[0]+3, ref_point[1]-3:ref_point[1]+3] = (1.0, 0.0, 0.0)
333 | # draw box
334 | left = ref_point[1] - ref_dim[1]
335 | right = ref_point[1] + ref_dim[1]
336 | up = ref_point[0] - ref_dim[0]
337 | down = ref_point[0] + ref_dim[0]
338 | frame[up-1:up+1,left:right] = (0,1,0)
339 | frame[down-1:down+1,left:right] = (0,1,0)
340 | frame[up:down,left-1:left+1] = (0,1,0)
341 | frame[up:down,right-1:right+1] = (0,1,0)
342 |
343 | f, axarr = plt.subplots(1,1)
344 | axarr.imshow(frame)
345 | plt.savefig(f'{folder_dir}/frame_{video_idx}_{tt}_{name}.png')
346 | plt.close(f)
347 | plt.close('all')
348 |
349 |
350 |
351 | def count_params(params, contains=[''], ignores=['nonenone']):
352 | total_params = 0
353 | for name, shape in params:
354 | # print(f'{name} shape: {shape}')
355 | ok = True
356 | for ignore_name in ignores:
357 | if ignore_name in name:
358 | ok = False
359 | if not ok:
360 | continue
361 | prod = 1
362 | for d in shape:
363 | prod *= d
364 | selected_param = False
365 | for c in contains:
366 | if c in name:
367 | selected_param = True
368 | if selected_param:
369 | total_params += prod
370 | return total_params
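# usage sketch for count_params (illustrative; the 'dyreg' substring is only an
# example filter): it expects (name, shape) pairs, so named_parameters() can be
# fed to it like this:
#
#   shapes = [(n, p.shape) for n, p in model.named_parameters()]
#   graph_params = count_params(shapes, contains=['dyreg'], ignores=['fc'])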
371 | def softmax(scores):
372 | es = np.exp(scores - scores.max(axis=-1)[..., None])
373 | return es / es.sum(axis=-1)[..., None]
374 |
375 |
376 | class LearnedParamChecker():
377 | def __init__(self,model):
378 | self.model = model
379 | self.initial_params = self.save_initial_params()
380 |
381 | def save_initial_params(self):
382 | initial_params = {}
383 | for name, p in self.model.named_parameters():
384 | initial_params[name] = p.detach().cpu().numpy()
385 | return initial_params
386 |
387 | def compare_current_initial_params(self):
388 | for name, current_p in self.model.named_parameters():
389 | initial_p = self.initial_params[name]
390 | current_p = current_p.detach().cpu().numpy()
391 | diff = np.mean(np.abs(initial_p - current_p))
392 | print(f'params: {name} mean change : {diff}')
393 |
394 |
395 |
396 | class AverageMeter(object):
397 | """Computes and stores the average and current value"""
398 |
399 | def __init__(self):
400 | self.reset()
401 |
402 | def reset(self):
403 | self.val = 0
404 | self.avg = 0
405 | self.sum = 0
406 | self.count = 0
407 |
408 | def update(self, val, n=1):
409 | self.val = val
410 | self.sum += val * n
411 | self.count += n
412 | self.avg = self.sum / self.count
413 |
414 |
415 | def accuracy(output, target, topk=(1,)):
416 | """Computes the precision@k for the specified values of k"""
417 | maxk = max(topk)
418 | batch_size = target.size(0)
419 |
420 | _, pred = output.topk(maxk, 1, True, True)
421 | pred = pred.t()
422 | correct = pred.eq(target.view(1, -1).expand_as(pred))
423 |
424 | res = []
425 | for k in topk:
426 | # correct_k = correct[:k].view(-1).float().sum(0)
427 | correct_k = correct[:k].reshape(-1).float().sum(0)
428 |
429 | res.append(correct_k.mul_(100.0 / batch_size))
430 | return res
431 |
432 |
433 |
434 |
435 | def adjust_kernel(pred_kernel_val, pred_boxes):
436 | # pred_kernel_val: (BT, 3, 3, 14, 14)
437 | # pred_boxes: (BT, 9, 4) ([left_h, left_w, right_h, right_w])
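    # for every box the kernel is upsampled to 224x224 and normalized to sum to 1,
    # then the box is shrunk symmetrically (in 1% steps of its half-size) until the
    # enclosed kernel mass would fall below 0.90; the last box still holding at
    # least 90% of the mass is kept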
438 |
439 | pred_kernel_val = np.reshape(pred_kernel_val,(pred_kernel_val.shape[0], pred_kernel_val.shape[1]*pred_kernel_val.shape[2], pred_kernel_val.shape[3], pred_kernel_val.shape[4]))
440 | all_adjust_kernel_boxes = np.zeros_like(pred_boxes)
441 |
442 | time1 = time.time()
443 | for b in range(pred_kernel_val.shape[0]):
444 | for i in range(9):
445 | # kernel from 14x14 ->224x224
446 | curent_kernel = pred_kernel_val[b,i] / pred_kernel_val[b,i].max()
447 | curent_kernel = (curent_kernel * 255).astype(np.uint8)
448 |
449 | time5 = time.time()
450 | curent_kernel = np.array(Image.fromarray(curent_kernel).resize((224,224), resample=PIL.Image.BILINEAR))
451 | time6 = time.time()
452 | # print(f'Resize time: {time6-time5}')
453 | curent_kernel = curent_kernel.astype(np.float32) / 255.0
454 | curent_kernel = curent_kernel / curent_kernel.sum()
455 |
456 | (left_h, left_w, right_h, right_w) = pred_boxes[b,i]
457 | h = right_h - left_h
458 | w = right_w - left_w
459 | # print(f'New frame h={h}, w={w}')
460 | prev_sums = curent_kernel.sum()
461 | prev_adjust_kernel_boxes = np.array([left_h, left_w, right_h, right_w])
462 | prev_dx = 0.0
463 | prev_dy = 0.0
464 | time7 = time.time()
465 | for j in np.arange(0.01,1, 0.01):
466 | dx = j * (h/2)
467 | dy = j * (w/2)
468 | new_left_h = int(left_h+dx)
469 | new_left_w = int(left_w+dy)
470 | new_right_h = int(right_h-dx)
471 | new_right_w = int(right_w-dy)
472 | adjust_kernel_boxes = np.array([new_left_h, new_left_w, new_right_h, new_right_w])
473 | crt_sums = curent_kernel[new_left_h:new_right_h+1, new_left_w:new_right_w+1].sum()
474 | if crt_sums < 0.90 and prev_sums >= 0.90:
475 | # print(prev_dx, prev_dy, prev_sums)
476 | # print(dx, dy, crt_sums)
477 |
478 | # print(f'[ {left_h}, {left_w}, {right_h}, {right_w}] -> [{new_left_h}, {new_left_w}, {new_right_h}, {new_right_w}]')
479 | all_adjust_kernel_boxes[b,i] = prev_adjust_kernel_boxes
480 | break
481 |
482 | prev_sums = crt_sums
483 | prev_adjust_kernel_boxes = np.array([new_left_h, new_left_w, new_right_h, new_right_w])
484 | prev_dx = dx
485 | prev_dy = dy
486 | time8 = time.time()
487 | # print(f'find adjustment: {time8-time7}')
488 | time2 = time.time()
489 | print(f'Adjust kernel: {time2-time1}')
490 | return all_adjust_kernel_boxes
491 |
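# usage sketch for the metric helpers above (illustrative; `model` and
# `val_loader` are assumed to be defined by the caller):
#
#   top1 = AverageMeter()
#   for images, target in val_loader:
#       output = model(images)
#       prec1, = accuracy(output, target, topk=(1,))
#       top1.update(prec1.item(), images.size(0))
#   print(f'top-1 accuracy: {top1.avg:.2f}%')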
--------------------------------------------------------------------------------