├── AOC-Net
├── __init__.py
├── adaptive_embedding_for_matching.py
├── complete_project
│ ├── AOCNet
│ │ ├── configs
│ │ │ ├── resnet101_aocnet.py
│ │ │ └── resnet101_aocnet_2.py
│ │ ├── dataloaders
│ │ │ ├── __init__.py
│ │ │ ├── custom_transforms.py
│ │ │ └── datasets_m.py
│ │ ├── networks
│ │ │ ├── __init__.py
│ │ │ ├── aoc
│ │ │ │ ├── __init__.py
│ │ │ │ ├── aocnet.py
│ │ │ │ ├── conditioning_layer.py
│ │ │ │ └── decoding_module.py
│ │ │ ├── deeplab
│ │ │ │ ├── __init__.py
│ │ │ │ ├── aspp.py
│ │ │ │ ├── backbone
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── __pycache__
│ │ │ │ │ │ ├── __init__.cpython-36.pyc
│ │ │ │ │ │ ├── mobilenet.cpython-36.pyc
│ │ │ │ │ │ └── resnet.cpython-36.pyc
│ │ │ │ │ ├── mobilenet.py
│ │ │ │ │ └── resnet.py
│ │ │ │ ├── decoder.py
│ │ │ │ └── deeplab.py
│ │ │ ├── engine
│ │ │ │ ├── __init__.py
│ │ │ │ ├── eval_manager_mm.py
│ │ │ │ └── train_manager_mm.py
│ │ │ └── layers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── aspp.py
│ │ │ │ ├── attention.py
│ │ │ │ ├── conv_gru.py
│ │ │ │ ├── gct.py
│ │ │ │ ├── gru_conv.py
│ │ │ │ ├── loss.py
│ │ │ │ ├── matching.py
│ │ │ │ ├── normalization.py
│ │ │ │ ├── refiner.py
│ │ │ │ ├── shannon_entropy.py
│ │ │ │ └── shanoon_entropy.py
│ │ ├── scripts
│ │ │ ├── eval.sh
│ │ │ └── train.sh
│ │ ├── tools
│ │ │ ├── eval_net_mm_rpa.py
│ │ │ └── train_net_mm.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint.py
│ │ │ ├── eval.py
│ │ │ ├── image.py
│ │ │ ├── learning.py
│ │ │ ├── meters.py
│ │ │ └── metric.py
│ └── README.md
└── conditioning_layer.py
├── Dockerfile
├── README.md
├── Robust-VOS-Benchmark
├── AOT
│ └── eval_datasets.py
├── CFBI&AOC(ours)
│ └── datasets_robustness.py
└── __init__.py
└── figs
├── FIGS.MD
├── mm22_345_poster_a0.pdf
└── mm22_345_poster_a0.pptx
/AOC-Net/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/configs/resnet101_aocnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import argparse
4 | import os
5 | import sys
6 | import cv2
7 | import time
8 | import socket
9 | import random
10 |
11 | class Configuration():
12 | def __init__(self):
13 | self.EXP_NAME = 'aoc_stage_1'
14 |
15 | self.EVAL_AUTO_RESUME = False
16 | self.UNC_RATIO=1.0
17 | self.MEM_EVERY=5
18 | self.USE_SF=False
19 | self.PAST_FRAME_NUM=4
20 | self.COMPRESSION_RATE=4
21 | self.BLOCK_NUM=2
22 |
23 | self.DIR_ROOT = '/datasets/'
24 | self.DIR_DATA = '/datasets/datasets/'
25 | self.DIR_TEMP_RESULT = '/datasets/result_evaluation/'
26 | self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS')
27 | self.DIR_PORTRAIT_DATA ='/datasets/Portrait_dataset/dataset_post/'
28 | self.DIR_YTB = os.path.join(self.DIR_DATA, 'train_all_frames/train_all_frames/')
29 | self.DIR_YTB_EVAL = os.path.join(self.DIR_DATA, 'valid_all_frames/valid_all_frames/')
30 | self.DIR_YTB_EVAL18 = os.path.join(self.DIR_ROOT, 'MODELS/STCN-ALL/YouTube2018/valid/')
31 | self.DIR_YTB_EVAL19 = os.path.join(self.DIR_ROOT,'MODELS/STCN-ALL/YouTube/valid/')
32 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME)
33 |
34 |
35 |
36 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt')
37 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log')
38 | self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img')
39 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard')
40 | self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval')
41 |
42 | self.DATASETS = ['youtubevos']
43 | self.DATA_WORKERS = 4
44 | self.DATA_RANDOMCROP = (465, 465)
45 | self.DATA_RANDOMFLIP = 0.5
46 | self.DATA_MAX_CROP_STEPS = 5
47 | self.DATA_MIN_SCALE_FACTOR = 1.
48 | self.DATA_MAX_SCALE_FACTOR = 1.3
49 | self.DATA_SHORT_EDGE_LEN = 480
50 | self.DATA_RANDOM_REVERSE_SEQ = True
51 | self.DATA_DAVIS_REPEAT = 30
52 | self.DATA_CURR_SEQ_LEN = 5
53 | self.DATA_RANDOM_GAP_DAVIS = 3
54 | self.DATA_RANDOM_GAP_YTB = 3
55 |
56 |
57 | self.PRETRAIN = True
58 | self.PRETRAIN_FULL = True
59 | self.PRETRAIN_MODEL = '/datasets/result/resnet101_cfbi_p2t_lr02_8C_pt_pxmc_pxm/ckpt/save_step_400000.pth'
60 | self.MODEL_BACKBONE = 'resnet'
61 | self.MODEL_MODULE = 'networks.aoc.aocnet'
62 | self.MODEL_OUTPUT_STRIDE = 16
63 | self.MODEL_ASPP_OUTDIM = 256
64 | self.MODEL_SHORTCUT_DIM = 48
65 | self.MODEL_SEMANTIC_EMBEDDING_DIM = 100
66 | self.MODEL_HEAD_EMBEDDING_DIM = 256
67 | self.MODEL_PRE_HEAD_EMBEDDING_DIM = 64
68 | self.MODEL_GN_GROUPS = 32
69 | self.MODEL_GN_EMB_GROUPS = 25
70 | self.MODEL_MULTI_LOCAL_DISTANCE = [2, 4, 6, 8, 10, 12]
71 | self.MODEL_LOCAL_DOWNSAMPLE = True
72 | self.MODEL_REFINE_CHANNELS = 64 # n * 32
73 | self.MODEL_LOW_LEVEL_INPLANES = 256 if self.MODEL_BACKBONE == 'resnet' else 24
74 | self.MODEL_RELATED_CHANNELS = 64
75 | self.MODEL_EPSILON = 1e-5
76 | self.MODEL_MATCHING_BACKGROUND = True
77 | self.MODEL_GCT_BETA_WD = True
78 | self.MODEL_FLOAT16_MATCHING = False
79 | self.MODEL_FREEZE_BN = True
80 | self.MODEL_FREEZE_BACKBONE = False
81 |
82 | self.TRAIN_TOTAL_STEPS = 50000
83 | self.TRAIN_START_STEP = 0
84 | self.TRAIN_LR = 0.01
85 | self.TRAIN_MOMENTUM = 0.9
86 | self.TRAIN_COSINE_DECAY = False
87 | self.TRAIN_WARM_UP_STEPS = 1000
88 | self.TRAIN_WEIGHT_DECAY = 15e-5
89 | self.TRAIN_POWER = 0.9
90 | self.TRAIN_GPUS = 8
91 | self.TRAIN_BATCH_SIZE = 8
92 | self.TRAIN_START_SEQ_TRAINING_STEPS = self.TRAIN_TOTAL_STEPS / 2
93 | self.TRAIN_TBLOG = False
94 | self.TRAIN_TBLOG_STEP = 60
95 | self.TRAIN_LOG_STEP = 20
96 | self.TRAIN_IMG_LOG = False
97 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15
98 | self.TRAIN_HARD_MINING_STEP = self.TRAIN_TOTAL_STEPS / 2
99 | self.TRAIN_CLIP_GRAD_NORM = 5.
100 | self.TRAIN_SAVE_STEP = 2000
101 | self.TRAIN_MAX_KEEP_CKPT = 8000
102 | self.TRAIN_RESUME = False
103 | self.TRAIN_RESUME_CKPT = None
104 | self.TRAIN_RESUME_STEP = 0
105 | self.TRAIN_AUTO_RESUME = True
106 | self.TRAIN_GLOBAL_ATROUS_RATE = 1
107 | self.TRAIN_LOCAL_ATROUS_RATE = 1
108 | self.TRAIN_LOCAL_PARALLEL = True
109 | self.TRAIN_GLOBAL_CHUNKS = 1
110 | self.TRAIN_DATASET_FULL_RESOLUTION = True
111 |
112 |
113 | self.TEST_GPU_ID = 0
114 | self.TEST_DATASET = 'youtubevos'
115 | self.TEST_DATASET_FULL_RESOLUTION = True
116 | self.TEST_DATASET_SPLIT = ['val']
117 | self.TEST_CKPT_PATH = None
118 | self.TEST_CKPT_STEP = None # if "None", evaluate the latest checkpoint.
119 | self.TEST_FLIP = False
120 | self.TEST_MULTISCALE = [1]
121 | self.TEST_MIN_SIZE = None
122 | self.TEST_MAX_SIZE = 800 * 1.3 if self.TEST_MULTISCALE == [1.] else 800
123 | self.TEST_WORKERS = 4
124 | self.TEST_GLOBAL_CHUNKS = 4
125 | self.TEST_GLOBAL_ATROUS_RATE = 1
126 | self.TEST_LOCAL_ATROUS_RATE = 1
127 | self.TEST_LOCAL_PARALLEL = True
128 |
129 | # dist
130 | self.DIST_ENABLE = True
131 | self.DIST_BACKEND = "nccl"
132 |
133 | myname = socket.getfqdn(socket.gethostname( ))
134 | myaddr = socket.gethostbyname(myname)
135 |
136 | self.DIST_URL = "tcp://"+myaddr+":"+str(random.randint(30000,50000))
137 | self.DIST_START_GPU = 0
138 |
139 | self.__check()
140 |
141 | def __check(self):
142 | if not torch.cuda.is_available():
143 | raise ValueError('config.py: cuda is not available')
144 | if self.TRAIN_GPUS == 0:
145 | raise ValueError('config.py: the number of GPU is 0')
146 | for path in [self.DIR_RESULT, self.DIR_CKPT, self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, self.DIR_TB_LOG]:
147 | if not os.path.isdir(path):
148 | os.makedirs(path)
149 |
150 |
151 |
152 | cfg = Configuration()
153 |
--------------------------------------------------------------------------------
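Note: this stage-1 config and the stage-2 config that follows differ mainly in the experiment name, the total training steps, the pretraining source, and the test resolution; each module instantiates `cfg = Configuration()` at import time, which checks for CUDA and creates the result/ckpt/log/eval directories. The distributed init address is built from the host's own IP plus a random port; a standalone, illustrative repro of just that logic (not part of the repository):

import socket
import random

# Resolve the host's own address and append a random TCP port in [30000, 50000],
# mirroring how DIST_URL is derived in the Configuration class above.
myname = socket.getfqdn(socket.gethostname())
myaddr = socket.gethostbyname(myname)
dist_url = "tcp://" + myaddr + ":" + str(random.randint(30000, 50000))
print(dist_url)  # e.g. tcp://192.0.2.10:41234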
/AOC-Net/complete_project/AOCNet/configs/resnet101_aocnet_2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 | import sys
5 | import cv2
6 | import time
7 | import socket
8 | import random
9 |
10 | class Configuration():
11 | def __init__(self):
12 | self.EXP_NAME = 'aocnet_stage_2'
13 |
14 | self.EVAL_AUTO_RESUME = False
15 | self.UNC_RATIO=1.0
16 | self.MEM_EVERY=5
17 | self.USE_SF=False
18 | self.PAST_FRAME_NUM=4
19 | self.COMPRESSION_RATE=4
20 | self.BLOCK_NUM=2
21 |
22 | self.DIR_ROOT = '/datasets/'
23 | self.DIR_DATA = '/datasets/datasets/'
24 | self.DIR_TEMP_RESULT = '/datasets/result_evaluation/'
25 | self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS')
26 | self.DIR_PORTRAIT_DATA ='/datasets/Portrait_dataset/dataset_post/'
27 | self.DIR_YTB = os.path.join(self.DIR_DATA, 'train_all_frames/train_all_frames/')
28 | self.DIR_YTB_EVAL = os.path.join(self.DIR_DATA, 'valid_all_frames/valid_all_frames/')
29 | self.DIR_YTB_EVAL18 = os.path.join(self.DIR_ROOT, 'MODELS/STCN-ALL/YouTube2018/valid/')
30 | self.DIR_YTB_EVAL19 = os.path.join(self.DIR_ROOT,'MODELS/STCN-ALL/YouTube/valid/')
31 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME)
32 |
33 |
34 |
35 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt')
36 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log')
37 | self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img')
38 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard')
39 | self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval')
40 |
41 | self.DATASETS = ['youtubevos']
42 | self.DATA_WORKERS = 4
43 | self.DATA_RANDOMCROP = (465, 465)
44 | self.DATA_RANDOMFLIP = 0.5
45 | self.DATA_MAX_CROP_STEPS = 5
46 | self.DATA_MIN_SCALE_FACTOR = 1.
47 | self.DATA_MAX_SCALE_FACTOR = 1.3
48 | self.DATA_SHORT_EDGE_LEN = 480
49 | self.DATA_RANDOM_REVERSE_SEQ = True
50 | self.DATA_DAVIS_REPEAT = 30
51 | self.DATA_CURR_SEQ_LEN = 5
52 | self.DATA_RANDOM_GAP_DAVIS = 3
53 | self.DATA_RANDOM_GAP_YTB = 3
54 |
55 |
56 | self.PRETRAIN = True
57 | self.PRETRAIN_FULL = False
58 | self.PRETRAIN_MODEL = '/datasets/MODELS/CFBI/resnet101-deeplabv3p.pth.tar'
59 |
60 | self.MODEL_BACKBONE = 'resnet'
61 | self.MODEL_MODULE = 'networks.aoc.aocnet'
62 | self.MODEL_OUTPUT_STRIDE = 16
63 | self.MODEL_ASPP_OUTDIM = 256
64 | self.MODEL_SHORTCUT_DIM = 48
65 | self.MODEL_SEMANTIC_EMBEDDING_DIM = 100
66 | self.MODEL_HEAD_EMBEDDING_DIM = 256
67 | self.MODEL_PRE_HEAD_EMBEDDING_DIM = 64
68 | self.MODEL_GN_GROUPS = 32
69 | self.MODEL_GN_EMB_GROUPS = 25
70 | self.MODEL_MULTI_LOCAL_DISTANCE = [2, 4, 6, 8, 10, 12]
71 | self.MODEL_LOCAL_DOWNSAMPLE = True
72 | self.MODEL_REFINE_CHANNELS = 64 # n * 32
73 | self.MODEL_LOW_LEVEL_INPLANES = 256 if self.MODEL_BACKBONE == 'resnet' else 24
74 | self.MODEL_RELATED_CHANNELS = 64
75 | self.MODEL_EPSILON = 1e-5
76 | self.MODEL_MATCHING_BACKGROUND = True
77 | self.MODEL_GCT_BETA_WD = True
78 | self.MODEL_FLOAT16_MATCHING = False
79 | self.MODEL_FREEZE_BN = True
80 | self.MODEL_FREEZE_BACKBONE = False
81 |
82 | self.TRAIN_TOTAL_STEPS = 400000
83 | self.TRAIN_START_STEP = 0
84 | self.TRAIN_LR = 0.01
85 | self.TRAIN_MOMENTUM = 0.9
86 | self.TRAIN_COSINE_DECAY = False
87 | self.TRAIN_WARM_UP_STEPS = 1000
88 | self.TRAIN_WEIGHT_DECAY = 15e-5
89 | self.TRAIN_POWER = 0.9
90 | self.TRAIN_GPUS = 8
91 | self.TRAIN_BATCH_SIZE = 8
92 | self.TRAIN_START_SEQ_TRAINING_STEPS = self.TRAIN_TOTAL_STEPS / 2
93 | self.TRAIN_TBLOG = False
94 | self.TRAIN_TBLOG_STEP = 60
95 | self.TRAIN_LOG_STEP = 20
96 | self.TRAIN_IMG_LOG = False
97 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15
98 | self.TRAIN_HARD_MINING_STEP = self.TRAIN_TOTAL_STEPS / 2
99 | self.TRAIN_CLIP_GRAD_NORM = 5.
100 | self.TRAIN_SAVE_STEP = 2000
101 | self.TRAIN_MAX_KEEP_CKPT = 8000
102 | self.TRAIN_RESUME = False
103 | self.TRAIN_RESUME_CKPT = None
104 | self.TRAIN_RESUME_STEP = 0
105 | self.TRAIN_AUTO_RESUME = True
106 | self.TRAIN_GLOBAL_ATROUS_RATE = 1
107 | self.TRAIN_LOCAL_ATROUS_RATE = 1
108 | self.TRAIN_LOCAL_PARALLEL = True
109 | self.TRAIN_GLOBAL_CHUNKS = 1
110 | self.TRAIN_DATASET_FULL_RESOLUTION = True
111 |
112 |
113 | self.TEST_GPU_ID = 0
114 | self.TEST_DATASET = 'youtubevos'
115 | self.TEST_DATASET_FULL_RESOLUTION = False
116 | self.TEST_DATASET_SPLIT = ['val']
117 | self.TEST_CKPT_PATH = None
118 | self.TEST_CKPT_STEP = None # if "None", evaluate the latest checkpoint.
119 | self.TEST_FLIP = False
120 | self.TEST_MULTISCALE = [1]
121 | self.TEST_MIN_SIZE = None
122 | self.TEST_MAX_SIZE = 800 * 1.3 if self.TEST_MULTISCALE == [1.] else 800
123 | self.TEST_WORKERS = 4
124 | self.TEST_GLOBAL_CHUNKS = 4
125 | self.TEST_GLOBAL_ATROUS_RATE = 1
126 | self.TEST_LOCAL_ATROUS_RATE = 1
127 | self.TEST_LOCAL_PARALLEL = True
128 |
129 | # dist
130 | self.DIST_ENABLE = True
131 | self.DIST_BACKEND = "nccl"
132 |
133 | myname = socket.getfqdn(socket.gethostname( ))
134 | myaddr = socket.gethostbyname(myname)
135 |
136 | self.DIST_URL = "tcp://"+myaddr+":"+str(random.randint(30000,50000))
137 | self.DIST_START_GPU = 0
138 |
139 | self.__check()
140 |
141 | def __check(self):
142 | if not torch.cuda.is_available():
143 | raise ValueError('config.py: cuda is not available')
144 | if self.TRAIN_GPUS == 0:
145 | raise ValueError('config.py: the number of GPU is 0')
146 | for path in [self.DIR_RESULT, self.DIR_CKPT, self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, self.DIR_TB_LOG]:
147 | if not os.path.isdir(path):
148 | os.makedirs(path)
149 |
150 |
151 |
152 | cfg = Configuration()
153 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/dataloaders/__init__.py
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/dataloaders/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | cv2.setNumThreads(0)
8 |
9 | class Resize(object):
10 | """Rescale the image in a sample to a given size.
11 |
12 | Args:
13 | output_size (tuple or int): Desired output size. If tuple, output is
14 | matched to output_size. If int, smaller of image edges is matched
15 | to output_size keeping aspect ratio the same.
16 | """
17 |
18 | def __init__(self, output_size):
19 | assert isinstance(output_size, (int, tuple))
20 | if isinstance(output_size, int):
21 | self.output_size = (output_size, output_size)
22 | else:
23 | self.output_size = output_size
24 |
25 | def __call__(self, sample):
26 | prev_img = sample['prev_img']
27 | h, w = prev_img.shape[:2]
28 | if self.output_size == (h, w):
29 | return sample
30 | else:
31 | new_h, new_w = self.output_size
32 |
33 | for elem in sample.keys():
34 | if 'meta' in elem:
35 | continue
36 | tmp = sample[elem]
37 |
38 | if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img':
39 | flagval = cv2.INTER_CUBIC
40 | else:
41 | flagval = cv2.INTER_NEAREST
42 |
43 | if elem == 'curr_img' or elem == 'curr_label':
44 | new_tmp = []
45 | all_tmp = tmp
46 | for tmp in all_tmp:
47 | tmp = cv2.resize(tmp, dsize=(new_w, new_h),
48 | interpolation=flagval)
49 | new_tmp.append(tmp)
50 | tmp = new_tmp
51 | else:
52 | tmp = cv2.resize(tmp, dsize=(new_w, new_h),
53 | interpolation=flagval)
54 |
55 | sample[elem] = tmp
56 |
57 | return sample
58 |
59 | class BalancedRandomCrop(object):
60 | """Crop randomly the image in a sample.
61 |
62 | Args:
63 | output_size (tuple or int): Desired output size. If int, square crop
64 | is made.
65 | """
66 |
67 | def __init__(self, output_size, max_step=5, max_obj_num=5, min_obj_pixel_num=100):
68 | assert isinstance(output_size, (int, tuple))
69 | if isinstance(output_size, int):
70 | self.output_size = (output_size, output_size)
71 | else:
72 | assert len(output_size) == 2
73 | self.output_size = output_size
74 | self.max_step = max_step
75 | self.max_obj_num = max_obj_num
76 | self.min_obj_pixel_num = min_obj_pixel_num
77 |
78 | def __call__(self, sample):
79 |
80 | image = sample['prev_img']
81 | h, w = image.shape[:2]
82 | new_h, new_w = self.output_size
83 | new_h = h if new_h >= h else new_h
84 | new_w = w if new_w >= w else new_w
85 | ref_label = sample["ref_label"]
86 | prev_label = sample["prev_label"]
87 | curr_label = sample["curr_label"]
88 |
89 | is_contain_obj = False
90 | step = 0
91 | while (not is_contain_obj) and (step < self.max_step):
92 | step += 1
93 | top = np.random.randint(0, h - new_h + 1)
94 | left = np.random.randint(0, w - new_w + 1)
95 | after_crop = []
96 | contains = []
97 | for elem in ([ref_label, prev_label] + curr_label):
98 | tmp = elem[top: top + new_h, left:left + new_w]
99 | contains.append(np.unique(tmp))
100 | after_crop.append(tmp)
101 |
102 |
103 | all_obj = list(np.sort(contains[0]))
104 |
105 | if all_obj[-1] == 0:
106 | continue
107 |
108 | # remove background
109 | if all_obj[0] == 0:
110 | all_obj = all_obj[1:]
111 | # remove small obj
112 | new_all_obj = []
113 | for obj_id in all_obj:
114 | after_crop_pixels = np.sum(after_crop[0] == obj_id)
115 | if after_crop_pixels > self.min_obj_pixel_num:
116 | new_all_obj.append(obj_id)
117 |
118 | if len(new_all_obj) == 0:
119 | is_contain_obj = False
120 | else:
121 | is_contain_obj = True
122 |
123 | if len(new_all_obj) > self.max_obj_num:
124 | random.shuffle(new_all_obj)
125 | new_all_obj = new_all_obj[:self.max_obj_num]
126 |
127 | all_obj = [0] + new_all_obj
128 |
129 |
130 | post_process = []
131 | for elem in after_crop:
132 | new_elem = elem * 0
133 | for idx in range(len(all_obj)):
134 | obj_id = all_obj[idx]
135 | if obj_id == 0:
136 | continue
137 | mask = elem == obj_id
138 |
139 | new_elem += (mask * idx).astype(np.uint8)
140 | post_process.append(new_elem.astype(np.uint8))
141 |
142 | sample["ref_label"] = post_process[0]
143 | sample["prev_label"] = post_process[1]
144 | curr_len = len(sample["curr_img"])
145 | sample["curr_label"] = []
146 | for idx in range(curr_len):
147 | sample["curr_label"].append(post_process[idx + 2])
148 |
149 | for elem in sample.keys():
150 | if 'meta' in elem or 'label' in elem:
151 | continue
152 | if elem == 'curr_img':
153 | new_tmp = []
154 | for tmp_ in sample[elem]:
155 | tmp_ = tmp_[top: top + new_h, left:left + new_w]
156 | new_tmp.append(tmp_)
157 | sample[elem] = new_tmp
158 | else:
159 | tmp = sample[elem]
160 | tmp = tmp[top: top + new_h, left:left + new_w]
161 | sample[elem] = tmp
162 |
163 | obj_num = len(all_obj) - 1
164 |
165 | sample['meta']['obj_num'] = obj_num
166 |
167 | return sample
168 |
169 |
170 | class RandomScale(object):
171 | """Randomly resize the image and the ground truth to specified scales.
172 | Args:
173 | scales (list): the list of scales
174 | """
175 |
176 | def __init__(self, min_scale=1., max_scale=1.3, short_edge=None):
177 | self.min_scale = min_scale
178 | self.max_scale = max_scale
179 | self.short_edge = short_edge
180 |
181 | def __call__(self, sample):
182 | # Fixed range of scales
183 | sc = np.random.uniform(self.min_scale, self.max_scale)
184 | # Align short edge
185 | if not (self.short_edge is None):
186 | image = sample['prev_img']
187 | h, w = image.shape[:2]
188 | if h > w:
189 | sc *= float(self.short_edge) / w
190 | else:
191 | sc *= float(self.short_edge) / h
192 |
193 |
194 | for elem in sample.keys():
195 | if 'meta' in elem:
196 | continue
197 | tmp = sample[elem]
198 |
199 | if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img':
200 | flagval = cv2.INTER_CUBIC
201 | else:
202 | flagval = cv2.INTER_NEAREST
203 |
204 | if elem == 'curr_img' or elem == 'curr_label':
205 | new_tmp = []
206 | for tmp_ in tmp:
207 | tmp_ = cv2.resize(tmp_, None, fx=sc, fy=sc, interpolation=flagval)
208 | new_tmp.append(tmp_)
209 | tmp = new_tmp
210 | else:
211 | tmp = cv2.resize(tmp, None, fx=sc, fy=sc, interpolation=flagval)
212 |
213 | sample[elem] = tmp
214 |
215 | return sample
216 |
217 | class RestrictSize(object):
218 | """Randomly resize the image and the ground truth to specified scales.
219 | Args:
220 | scales (list): the list of scales
221 | """
222 |
223 | def __init__(self, min_size=None, max_size=800*1.3):
224 | self.min_size = min_size
225 | self.max_size = max_size
226 | assert ((min_size is None)) or ((max_size is None))
227 |
228 | def __call__(self, sample):
229 |
230 | # Fixed range of scales
231 | sc = None
232 | image = sample['ref_img']
233 | h, w = image.shape[:2]
234 | # Align short edge
235 | if not (self.min_size is None):
236 | if h > w:
237 | short_edge = w
238 | else:
239 | short_edge = h
240 | if short_edge < self.min_size:
241 | sc = float(self.min_size) / short_edge
242 | else:
243 | if h > w:
244 | long_edge = h
245 | else:
246 | long_edge = w
247 | if long_edge > self.max_size:
248 | sc = float(self.max_size) / long_edge
249 |
250 | if sc is None:
251 | new_h = h
252 | new_w = w
253 | else:
254 | new_h = int(sc * h)
255 | new_w = int(sc * w)
256 | new_h = new_h - (new_h - 1) % 4
257 | new_w = new_w - (new_w - 1) % 4
258 | if new_h == h and new_w == w:
259 | return sample
260 |
261 |
262 | for elem in sample.keys():
263 | if 'meta' in elem:
264 | continue
265 | tmp = sample[elem]
266 |
267 | if 'label' in elem:
268 | flagval = cv2.INTER_NEAREST
269 | else:
270 | flagval = cv2.INTER_CUBIC
271 |
272 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), interpolation=flagval)
273 |
274 | sample[elem] = tmp
275 |
276 | return sample
277 |
278 |
279 | class RandomHorizontalFlip(object):
280 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5."""
281 |
282 | def __init__(self, prob):
283 | self.p = prob
284 |
285 | def __call__(self, sample):
286 |
287 | if random.random() < self.p:
288 | for elem in sample.keys():
289 | if 'meta' in elem:
290 | continue
291 | if elem == 'curr_img' or elem == 'curr_label':
292 | new_tmp = []
293 | for tmp_ in sample[elem]:
294 | tmp_ = cv2.flip(tmp_, flipCode=1)
295 | new_tmp.append(tmp_)
296 | sample[elem] = new_tmp
297 | else:
298 | tmp = sample[elem]
299 | tmp = cv2.flip(tmp, flipCode=1)
300 | sample[elem] = tmp
301 |
302 | return sample
303 |
304 |
305 | class RandomGaussianBlur(object):
306 |
307 | def __init__(self, prob=0.2):
308 | self.p = prob
309 |
310 | def __call__(self, sample):
311 |
312 |
313 | for elem in sample.keys():
314 | if 'meta' in elem or 'label' in elem:
315 | continue
316 |
317 | if elem == 'curr_img':
318 | new_tmp = []
319 | for tmp_ in sample[elem]:
320 | if random.random() < self.p:
321 | std = random.random() * 1.9 + 0.1 # [0.1, 2]
322 | tmp_ = cv2.GaussianBlur(tmp_, (9, 9), sigmaX=std, sigmaY=std)
323 | new_tmp.append(tmp_)
324 | sample[elem] = new_tmp
325 | else:
326 | tmp = sample[elem]
327 | if random.random() < self.p:
328 | std = random.random() * 1.9 + 0.1 # [0.1, 2]
329 | tmp = cv2.GaussianBlur(tmp, (9, 9), sigmaX=std, sigmaY=std)
330 | sample[elem] = tmp
331 |
332 | return sample
333 |
334 | class SubtractMeanImage(object):
335 | def __init__(self, mean, change_channels=False):
336 | self.mean = mean
337 | self.change_channels = change_channels
338 |
339 | def __call__(self, sample):
340 | for elem in sample.keys():
341 | if 'image' in elem:
342 | if self.change_channels:
343 | sample[elem] = sample[elem][:, :, [2, 1, 0]]
344 | sample[elem] = np.subtract(
345 | sample[elem], np.array(self.mean, dtype=np.float32))
346 | return sample
347 |
348 | def __str__(self):
349 | return 'SubtractMeanImage' + str(self.mean)
350 |
351 |
352 | class ToTensor(object):
353 | """Convert ndarrays in sample to Tensors."""
354 |
355 | def __call__(self, sample):
356 |
357 | for elem in sample.keys():
358 | if 'meta' in elem:
359 | continue
360 | tmp = sample[elem]
361 |
362 | if elem == 'curr_img' or elem == 'curr_label':
363 | new_tmp = []
364 | for tmp_ in tmp:
365 | if tmp_.ndim == 2:
366 | tmp_ = tmp_[:, :, np.newaxis]
367 | else:
368 | tmp_ = tmp_ / 255.
369 | tmp_ -= (0.485, 0.456, 0.406)
370 | tmp_ /= (0.229, 0.224, 0.225)
371 | tmp_ = tmp_.transpose((2, 0, 1))
372 | new_tmp.append(torch.from_numpy(tmp_))
373 | tmp = new_tmp
374 | else:
375 | if tmp.ndim == 2:
376 | tmp = tmp[:, :, np.newaxis]
377 | else:
378 | tmp = tmp / 255.
379 | tmp -= (0.485, 0.456, 0.406)
380 | tmp /= (0.229, 0.224, 0.225)
381 | tmp = tmp.transpose((2, 0, 1))
382 | tmp = torch.from_numpy(tmp)
383 | sample[elem] = tmp
384 |
385 | return sample
386 |
387 | class MultiRestrictSize(object):
388 | def __init__(self, min_size=None, max_size=800, flip=False, multi_scale=[1.3]):
389 | self.min_size = min_size
390 | self.max_size = max_size
391 | self.multi_scale = multi_scale
392 | self.flip = flip
393 | assert ((min_size is None)) or ((max_size is None))
394 |
395 | def __call__(self, sample):
396 | samples = []
397 | image = sample['current_img']
398 | h, w = image.shape[:2]
399 | for scale in self.multi_scale:
400 | # Fixed range of scales
401 | sc = None
402 | # Align short edge
403 | if not (self.min_size is None):
404 | if h > w:
405 | short_edge = w
406 | else:
407 | short_edge = h
408 | if short_edge > self.min_size:
409 | sc = float(self.min_size) / short_edge
410 | else:
411 | if h > w:
412 | long_edge = h
413 | else:
414 | long_edge = w
415 | if long_edge > self.max_size:
416 | sc = float(self.max_size) / long_edge
417 |
418 | if sc is None:
419 | new_h = h
420 | new_w = w
421 | else:
422 | new_h = sc * h
423 | new_w = sc * w
424 | new_h = int(new_h * scale)
425 | new_w = int(new_w * scale)
426 |
427 | if (new_h - 1) % 16 != 0:
428 | new_h = int(np.around((new_h - 1) / 16.) * 16 + 1)
429 | if (new_w - 1) % 16 != 0:
430 | new_w = int(np.around((new_w - 1) / 16.) * 16 + 1)
431 |
432 | if new_h == h and new_w == w:
433 | samples.append(sample)
434 | else:
435 | new_sample = {}
436 | for elem in sample.keys():
437 | if 'meta' in elem:
438 | new_sample[elem] = sample[elem]
439 | continue
440 | tmp = sample[elem]
441 | if 'label' in elem:
442 | new_sample[elem] = sample[elem]
443 | continue
444 | else:
445 | flagval = cv2.INTER_CUBIC
446 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), interpolation=flagval)
447 | new_sample[elem] = tmp
448 | samples.append(new_sample)
449 |
450 | if self.flip:
451 | now_sample = samples[-1]
452 | new_sample = {}
453 | for elem in now_sample.keys():
454 | if 'meta' in elem:
455 | new_sample[elem] = now_sample[elem].copy()
456 | new_sample[elem]['flip'] = True
457 | continue
458 | tmp = now_sample[elem]
459 | tmp = tmp[:, ::-1].copy()
460 | new_sample[elem] = tmp
461 | samples.append(new_sample)
462 |
463 | return samples
464 |
465 | class MultiToTensor(object):
466 | def __call__(self, samples):
467 | for idx in range(len(samples)):
468 | sample = samples[idx]
469 | for elem in sample.keys():
470 | if 'meta' in elem:
471 | continue
472 | tmp = sample[elem]
473 | if tmp is None:
474 | continue
475 |
476 | if tmp.ndim == 2:
477 | tmp = tmp[:, :, np.newaxis]
478 | else:
479 | tmp = tmp / 255.
480 | tmp -= (0.485, 0.456, 0.406)
481 | tmp /= (0.229, 0.224, 0.225)
482 |
483 | tmp = tmp.transpose((2, 0, 1))
484 | samples[idx][elem] = torch.from_numpy(tmp)
485 |
486 | return samples
487 |
488 |
--------------------------------------------------------------------------------
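A minimal smoke test for the transform pipeline above, written against the sample-dict layout these transforms assume ('ref_img'/'prev_img' arrays, lists for 'curr_img'/'curr_label', and a 'meta' dict). The import path and the exact composition are assumptions for illustration, not taken from the repository's dataset code:

import numpy as np
from torchvision import transforms
import dataloaders.custom_transforms as tr  # import path assumed (run from AOCNet/)

h, w = 241, 321
def fake_img():
    return np.random.randint(0, 255, (h, w, 3)).astype(np.float32)
def fake_label():
    return np.random.randint(0, 3, (h, w)).astype(np.uint8)

sample = {
    'ref_img': fake_img(), 'prev_img': fake_img(), 'curr_img': [fake_img()],
    'ref_label': fake_label(), 'prev_label': fake_label(), 'curr_label': [fake_label()],
    'meta': {'seq_name': 'demo', 'obj_num': 2},
}
pipeline = transforms.Compose([
    tr.RandomScale(min_scale=1.0, max_scale=1.3, short_edge=480),  # cfg.DATA_* values
    tr.BalancedRandomCrop((465, 465)),
    tr.RandomHorizontalFlip(prob=0.5),
    tr.ToTensor(),
])
out = pipeline(sample)
print(out['ref_img'].shape, out['curr_label'][0].shape)  # (3, 465, 465) (1, 465, 465)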
/AOC-Net/complete_project/AOCNet/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/__init__.py
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/aoc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/aoc/__init__.py
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/aoc/aocnet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from networks.layers.loss import Concat_CrossEntropyLoss
6 | from networks.layers.matching import global_matching, global_matching_for_eval, local_matching, foreground2background, global_matching_proxy,local_matching_proxy,global_matching_cluster2, global_matching_cluster,global_matching_for_eval_cluster, global_matching_for_eval_proxy
7 | from networks.layers.attention import calculate_attention_head, calculate_attention_head_for_eval, calculate_attention_head_p_m, calculate_attention_head_for_eval_p_m
8 | from networks.aoc.decoding_module import CalibrationDecoding, DynamicPreHead
9 |
10 |
11 | class AOCNet(nn.Module):
12 | def __init__(self, cfg, feature_extracter):
13 | super(AOCNet, self).__init__()
14 | self.cfg = cfg
15 | self.epsilon = cfg.MODEL_EPSILON
16 |
17 | self.feature_extracter=feature_extracter
18 |
19 | self.seperate_conv = nn.Conv2d(cfg.MODEL_ASPP_OUTDIM, cfg.MODEL_ASPP_OUTDIM, kernel_size=3, stride=1, padding=1, groups=cfg.MODEL_ASPP_OUTDIM)
20 | self.bn1 = nn.GroupNorm(cfg.MODEL_GN_GROUPS, cfg.MODEL_ASPP_OUTDIM)
21 | self.relu1 = nn.ReLU(True)
22 | self.embedding_conv = nn.Conv2d(cfg.MODEL_ASPP_OUTDIM, cfg.MODEL_SEMANTIC_EMBEDDING_DIM, 1, 1)
23 | self.bn2 = nn.GroupNorm(cfg.MODEL_GN_EMB_GROUPS, cfg.MODEL_SEMANTIC_EMBEDDING_DIM)
24 | self.relu2 = nn.ReLU(True)
25 | self.semantic_embedding=nn.Sequential(*[self.seperate_conv, self.bn1, self.relu1, self.embedding_conv, self.bn2, self.relu2])
26 |
27 | self.bg_bias = nn.Parameter(torch.zeros(1, 1, 1, 1))
28 | self.fg_bias = nn.Parameter(torch.zeros(1, 1, 1, 1))
29 |
30 | self.criterion = Concat_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
31 |
32 | for m in self.semantic_embedding:
33 | if isinstance(m, nn.Conv2d):
34 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
35 |
36 | self.dynamic_seghead = CalibrationDecoding(
37 | in_dim=cfg.MODEL_SEMANTIC_EMBEDDING_DIM + cfg.MODEL_PRE_HEAD_EMBEDDING_DIM,
38 | attention_dim=cfg.MODEL_SEMANTIC_EMBEDDING_DIM * 4,
39 | embed_dim=cfg.MODEL_HEAD_EMBEDDING_DIM,
40 | refine_dim=cfg.MODEL_REFINE_CHANNELS,
41 | low_level_dim=cfg.MODEL_LOW_LEVEL_INPLANES)
42 |
43 | in_dim = 2*(2 + len(cfg.MODEL_MULTI_LOCAL_DISTANCE))-1+2
44 |
45 | if cfg.MODEL_MATCHING_BACKGROUND:
46 | in_dim += 1 +len(cfg.MODEL_MULTI_LOCAL_DISTANCE)
47 |
48 |
49 | self.dynamic_prehead = DynamicPreHead(
50 | in_dim=in_dim,
51 | embed_dim=cfg.MODEL_PRE_HEAD_EMBEDDING_DIM)
52 |
53 |
54 | def forward(self, input,memory_prev_list, ref_frame_label, previous_frame_mask, current_frame_mask,
55 | gt_ids, step=0, tf_board=False):
56 | x, low_level = self.extract_feature(input)
57 | ref_frame_embedding, previous_frame_embedding, current_frame_embedding = torch.split(x, split_size_or_sections=int(x.size(0)/3), dim=0)
58 | _, _, current_low_level = torch.split(low_level, split_size_or_sections=int(x.size(0)/3), dim=0)
59 | bs, c, h, w = current_frame_embedding.size()
60 | tmp_dic, boards, memory_cur_list= self.before_seghead_process(
61 | memory_prev_list,
62 | ref_frame_embedding,
63 | previous_frame_embedding,
64 | current_frame_embedding,
65 | ref_frame_label,
66 | previous_frame_mask,
67 | gt_ids,
68 | current_low_level=current_low_level,tf_board=tf_board)
69 | label_dic=[]
70 | all_pred = []
71 | for i in range(bs):
72 | tmp_pred_logits = tmp_dic[i]
73 | tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits, size=(input.shape[2],input.shape[3]), mode='bilinear', align_corners=True)
74 | tmp_dic[i] = tmp_pred_logits
75 | label_tmp, obj_num = current_frame_mask[i], gt_ids[i]
76 | label_dic.append(label_tmp.long())
77 | pred = tmp_pred_logits
78 | preds_s = torch.argmax(pred,dim=1)
79 | all_pred.append(preds_s)
80 | all_pred = torch.cat(all_pred, dim=0)
81 |
82 | return self.criterion(tmp_dic, label_dic, step), all_pred, boards, memory_cur_list
83 |
84 | def forward_for_eval(self,memory_prev_list, ref_embeddings, ref_masks, prev_embedding, prev_mask, current_frame, pred_size, gt_ids):
85 | current_frame_embedding, current_low_level = self.extract_feature(current_frame)
86 | if prev_embedding is None:
87 | return None, current_frame_embedding,memory_prev_list
88 | else:
89 | bs,c,h,w = current_frame_embedding.size()
90 | tmp_dic, _ ,memory_cur_list= self.before_seghead_process(
91 | memory_prev_list,
92 | ref_embeddings,
93 | prev_embedding,
94 | current_frame_embedding,
95 | ref_masks,
96 | prev_mask,
97 | gt_ids,
98 | current_low_level=current_low_level,
99 | tf_board=False)
100 | all_pred = []
101 | for i in range(bs):
102 | pred = tmp_dic[i]
103 | pred = nn.functional.interpolate(pred, size=(pred_size[0],pred_size[1]), mode='bilinear',align_corners=True)
104 | all_pred.append(pred)
105 | all_pred = torch.cat(all_pred, dim=0)
106 | all_pred = torch.softmax(all_pred, dim=1)
107 | return all_pred, current_frame_embedding, memory_cur_list
108 |
109 | def extract_feature(self, x):
110 | x, low_level=self.feature_extracter(x)
111 | x = self.semantic_embedding(x)
112 | return x, low_level
113 |
114 | def before_seghead_process(self, memory_prev_list=None,
115 | ref_frame_embedding=None, previous_frame_embedding=None, current_frame_embedding=None,
116 | ref_frame_label=None, previous_frame_mask=None,
117 | gt_ids=None, current_low_level=None, tf_board=False):
118 |
119 | cfg = self.cfg
120 |
121 | dic_tmp=[]
122 | bs,c,h,w = current_frame_embedding.size()
123 |
124 | if self.training:
125 | scale_ref_frame_label = torch.nn.functional.interpolate(ref_frame_label.float(),size=(h,w),mode='nearest')
126 | scale_ref_frame_label = scale_ref_frame_label.int()
127 | else:
128 | scale_ref_frame_labels = []
129 | for each_ref_frame_label in ref_frame_label:
130 | each_scale_ref_frame_label = torch.nn.functional.interpolate(each_ref_frame_label.float(),size=(h,w),mode='nearest')
131 | each_scale_ref_frame_label = each_scale_ref_frame_label.int()
132 | scale_ref_frame_labels.append(each_scale_ref_frame_label)
133 |
134 | scale_previous_frame_label=torch.nn.functional.interpolate(previous_frame_mask.float(),size=(h,w),mode='nearest')
135 | scale_previous_frame_label=scale_previous_frame_label.int()
136 |
137 | boards = {'image': {}, 'scalar': {}}
138 | memory_cur_list = []
139 |
140 | for n in range(bs):
141 | ref_obj_ids = torch.arange(0, gt_ids[n] + 1, device=current_frame_embedding.device).int().view(-1, 1, 1, 1)
142 | obj_num = ref_obj_ids.size(0)
143 | if gt_ids[n] > 0:
144 | dis_bias = torch.cat([self.bg_bias, self.fg_bias.expand(gt_ids[n], -1, -1, -1)], dim=0)
145 | else:
146 | dis_bias = self.bg_bias
147 |
148 | seq_current_frame_embedding = current_frame_embedding[n]
149 | seq_current_frame_embedding = seq_current_frame_embedding.permute(1,2,0)
150 |
151 |
152 | seq_prev_frame_embedding = previous_frame_embedding[n]
153 | seq_prev_frame_embedding = seq_prev_frame_embedding.permute(1,2,0)
154 | seq_previous_frame_label = (scale_previous_frame_label[n].int() == ref_obj_ids).float()
155 | to_cat_previous_frame = seq_previous_frame_label
156 | seq_previous_frame_label = seq_previous_frame_label.squeeze(1).permute(1,2,0)
157 | to_cat_current_frame_embedding = current_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1))
158 | to_cat_prev_frame_embedding = previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1))
159 |
160 |
161 | ## start global matching
162 | if self.training:
163 | seq_ref_frame_embedding = ref_frame_embedding[n]
164 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0)
165 |
166 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float()
167 | to_cat_ref_frame = seq_ref_frame_label
168 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0)
169 |
170 | global_matching_fg = global_matching(
171 | reference_embeddings=seq_ref_frame_embedding,
172 | query_embeddings=seq_current_frame_embedding,
173 | reference_labels=seq_ref_frame_label,
174 | n_chunks=cfg.TRAIN_GLOBAL_CHUNKS,
175 | dis_bias=dis_bias,
176 | atrous_rate=cfg.TRAIN_GLOBAL_ATROUS_RATE,
177 | use_float16=cfg.MODEL_FLOAT16_MATCHING)
178 |
179 | else:
180 | all_reference_embeddings = []
181 | all_reference_labels = []
182 | seq_ref_frame_labels = []
183 | for idx in range(len(scale_ref_frame_labels)):
184 | each_ref_frame_embedding = ref_frame_embedding[idx]
185 | scale_ref_frame_label = scale_ref_frame_labels[idx]
186 |
187 | seq_ref_frame_embedding = each_ref_frame_embedding[n]
188 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0)
189 | all_reference_embeddings.append(seq_ref_frame_embedding)
190 |
191 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float()
192 | seq_ref_frame_labels.append(seq_ref_frame_label)
193 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0)
194 | all_reference_labels.append(seq_ref_frame_label)
195 |
196 | global_matching_fg = global_matching_for_eval(
197 | all_reference_embeddings=all_reference_embeddings,
198 | query_embeddings=seq_current_frame_embedding,
199 | all_reference_labels=all_reference_labels,
200 | n_chunks=cfg.TEST_GLOBAL_CHUNKS,
201 | dis_bias=dis_bias,
202 | atrous_rate=cfg.TEST_GLOBAL_ATROUS_RATE,
203 | use_float16=cfg.MODEL_FLOAT16_MATCHING)
204 |
205 | ## end global matching
206 |
207 | ## start global matching cluster (default num: 16)
208 | if self.training:
209 | seq_ref_frame_embedding = ref_frame_embedding[n]
210 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0)
211 |
212 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float()
213 | to_cat_ref_frame = seq_ref_frame_label
214 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0)
215 |
216 | global_matching_fg_cluster = global_matching_cluster2(
217 | reference_embeddings=seq_ref_frame_embedding,
218 | query_embeddings=seq_current_frame_embedding,
219 | reference_labels=seq_ref_frame_label,
220 | n_chunks=cfg.TRAIN_GLOBAL_CHUNKS,
221 | dis_bias=dis_bias,
222 | atrous_rate=cfg.TRAIN_GLOBAL_ATROUS_RATE,
223 | use_float16=cfg.MODEL_FLOAT16_MATCHING)
224 |
225 | else:
226 | all_reference_embeddings = []
227 | all_reference_labels = []
228 | seq_ref_frame_labels = []
229 | for idx in range(len(scale_ref_frame_labels)):
230 | each_ref_frame_embedding = ref_frame_embedding[idx]
231 | scale_ref_frame_label = scale_ref_frame_labels[idx]
232 |
233 | seq_ref_frame_embedding = each_ref_frame_embedding[n]
234 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0)
235 | all_reference_embeddings.append(seq_ref_frame_embedding)
236 |
237 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float()
238 | seq_ref_frame_labels.append(seq_ref_frame_label)
239 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0)
240 | all_reference_labels.append(seq_ref_frame_label)
241 |
242 | global_matching_fg_cluster = global_matching_for_eval_cluster(
243 | all_reference_embeddings=all_reference_embeddings,
244 | query_embeddings=seq_current_frame_embedding,
245 | all_reference_labels=all_reference_labels,
246 | n_chunks=cfg.TEST_GLOBAL_CHUNKS,
247 | dis_bias=dis_bias,
248 | atrous_rate=cfg.TEST_GLOBAL_ATROUS_RATE,
249 | use_float16=cfg.MODEL_FLOAT16_MATCHING)
250 |
251 | ## end global matching cluster (default num: 16)
252 |
253 | ## start local matching
254 |
255 | local_matching_fg = local_matching(
256 | prev_frame_embedding=seq_prev_frame_embedding,
257 | query_embedding=seq_current_frame_embedding,
258 | prev_frame_labels=seq_previous_frame_label,
259 | multi_local_distance=cfg.MODEL_MULTI_LOCAL_DISTANCE,
260 | dis_bias=dis_bias,
261 | use_float16=cfg.MODEL_FLOAT16_MATCHING,
262 | atrous_rate=cfg.TRAIN_LOCAL_ATROUS_RATE if self.training else cfg.TEST_LOCAL_ATROUS_RATE,
263 | allow_downsample=cfg.MODEL_LOCAL_DOWNSAMPLE,
264 | allow_parallel=cfg.TRAIN_LOCAL_PARALLEL if self.training else cfg.TEST_LOCAL_PARALLEL)
265 |
266 | ## end local matching
267 |
268 |
269 |
270 | ## start referenced instance-level object representation calculation (proxy indicates cluster_num = 1)
271 | if self.training:
272 | attention_head = calculate_attention_head(
273 | ref_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)),
274 | to_cat_ref_frame,
275 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)),
276 | to_cat_previous_frame,
277 | epsilon=self.epsilon)
278 | else:
279 | attention_head = calculate_attention_head_for_eval(
280 | ref_frame_embedding,
281 | seq_ref_frame_labels,
282 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)),
283 | to_cat_previous_frame,
284 | epsilon=self.epsilon)
285 | ## end referenced instance-level object representation calculation (proxy indicates cluster_num = 1)
286 |
287 |
288 | ## start proxy-based matching
289 | if self.training:
290 | attention_head,ref_head_pos, ref_head_neg, prev_head_pos, prev_head_neg= calculate_attention_head_p_m(
291 | ref_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)),
292 | to_cat_ref_frame,
293 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)),
294 | to_cat_previous_frame,
295 | epsilon=self.epsilon)
296 | else:
297 | attention_head,ref_head_pos, ref_head_neg, prev_head_pos, prev_head_neg = calculate_attention_head_for_eval_p_m(
298 | ref_frame_embedding,
299 | seq_ref_frame_labels,
300 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)),
301 | to_cat_previous_frame,
302 | epsilon=self.epsilon)
303 |
304 | if self.training:
305 | global_matching_fg_proxy = global_matching_proxy(
306 | reference_embeddings=ref_head_pos,
307 | query_embeddings=seq_current_frame_embedding,
308 | reference_labels=seq_ref_frame_label,
309 | n_chunks=cfg.TRAIN_GLOBAL_CHUNKS,
310 | dis_bias=dis_bias,
311 | atrous_rate=cfg.TRAIN_GLOBAL_ATROUS_RATE,
312 | use_float16=cfg.MODEL_FLOAT16_MATCHING)
313 | else:
314 | global_matching_fg_proxy = global_matching_for_eval_proxy(
315 | all_reference_embeddings=ref_head_pos,
316 | query_embeddings=seq_current_frame_embedding,
317 | all_reference_labels=all_reference_labels,
318 | n_chunks=cfg.TEST_GLOBAL_CHUNKS,
319 | dis_bias=dis_bias,
320 | atrous_rate=cfg.TEST_GLOBAL_ATROUS_RATE,
321 | use_float16=cfg.MODEL_FLOAT16_MATCHING)
322 |
323 |
324 |
325 | seq_prev_frame_embedding_inst = torch.matmul(seq_previous_frame_label, prev_head_pos)
326 |
327 |
328 | local_matching_fg_proxy = local_matching_proxy(
329 | prev_frame_embedding=seq_prev_frame_embedding_inst,
330 | query_embedding=seq_current_frame_embedding,
331 | prev_frame_labels=seq_previous_frame_label,
332 | multi_local_distance=cfg.MODEL_MULTI_LOCAL_DISTANCE,
333 | dis_bias=dis_bias,
334 | use_float16=cfg.MODEL_FLOAT16_MATCHING,
335 | atrous_rate=cfg.TRAIN_LOCAL_ATROUS_RATE if self.training else cfg.TEST_LOCAL_ATROUS_RATE,
336 | allow_downsample=cfg.MODEL_LOCAL_DOWNSAMPLE,
337 | allow_parallel=cfg.TRAIN_LOCAL_PARALLEL if self.training else cfg.TEST_LOCAL_PARALLEL)
338 |
339 |
340 |
341 | to_cat_global_matching_fg_proxy = global_matching_fg_proxy.squeeze(0).permute(2,3,0,1)
342 | to_cat_global_matching_fg_cluster = global_matching_fg_cluster.squeeze(0).permute(2,3,0,1)
343 | to_cat_global_matching_fg = global_matching_fg.squeeze(0).permute(2,3,0,1)
344 | to_cat_local_matching_fg_proxy = local_matching_fg_proxy.squeeze(0).permute(2,3,0,1)
345 | to_cat_local_matching_fg = local_matching_fg.squeeze(0).permute(2,3,0,1)
346 |
347 |
348 |
349 | if cfg.MODEL_MATCHING_BACKGROUND:
350 | to_cat_global_matching_bg = foreground2background(to_cat_global_matching_fg, gt_ids[n] + 1)
351 | reshaped_prev_nn_feature_n = to_cat_local_matching_fg.permute(0, 2, 3, 1).unsqueeze(1)
352 | to_cat_local_matching_bg = foreground2background(reshaped_prev_nn_feature_n, gt_ids[n] + 1)
353 | to_cat_local_matching_bg = to_cat_local_matching_bg.permute(0, 4, 2, 3, 1).squeeze(-1)
354 |
355 | pre_to_cat = torch.cat((to_cat_global_matching_fg, to_cat_global_matching_fg_cluster, to_cat_global_matching_fg_proxy,to_cat_local_matching_fg, to_cat_local_matching_fg_proxy, to_cat_previous_frame), 1)
356 |
357 | if cfg.MODEL_MATCHING_BACKGROUND:
358 | pre_to_cat = torch.cat([pre_to_cat, to_cat_local_matching_bg, to_cat_global_matching_bg], 1)
359 |
360 | pre_to_cat = self.dynamic_prehead(pre_to_cat)
361 |
362 | to_cat = torch.cat((to_cat_current_frame_embedding, pre_to_cat),1)
363 |
364 |
365 | low_level_feat = current_low_level[n].unsqueeze(0)
366 |
367 | pred,memory_tmp_list = self.dynamic_seghead(to_cat, attention_head, memory_prev_list[n], low_level_feat,to_cat_previous_frame)
368 | memory_cur_list.append(memory_tmp_list)
369 | dic_tmp.append(pred)
370 |
371 |
372 | return dic_tmp, boards, memory_cur_list
373 |
374 | def get_module():
375 | return AOCNet
376 |
--------------------------------------------------------------------------------
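As a reading aid for the training forward pass above: the backbone is fed the reference, previous and current frames stacked along the batch dimension, and the three embeddings are recovered with torch.split. A tiny, self-contained shape demo (shapes are illustrative):

import torch

bs, c, h, w = 2, 100, 30, 30          # MODEL_SEMANTIC_EMBEDDING_DIM = 100 in the configs
x = torch.randn(3 * bs, c, h, w)      # ref | prev | curr stacked on dim 0
ref, prev, curr = torch.split(x, split_size_or_sections=x.size(0) // 3, dim=0)
assert ref.shape == prev.shape == curr.shape == (bs, c, h, w)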
/AOC-Net/complete_project/AOCNet/networks/aoc/conditioning_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class conditioning_layer(nn.Module):
7 |
8 | # Equation (7) of the main paper
9 |
10 | def __init__(self,
11 | in_dim=256,
12 | beta_percentage=0.3):
13 | super(conditioning_layer,self).__init__()
14 |
15 | self.beta_percentage = beta_percentage
16 |
17 | kernel_size = 1
18 | self.phi_layer = nn.Conv2d(in_dim,1,kernel_size=kernel_size,stride=1,padding=int((kernel_size-1)/2))
19 | self.mlp_layer = nn.Linear(in_dim, in_dim)
20 |
21 | nn.init.kaiming_normal_(self.phi_layer.weight,mode='fan_out',nonlinearity='relu')
22 |
23 |
24 | def forward(self, z_in):
25 |
26 | # Step 1: phi(z_in)
27 | x = self.phi_layer(z_in)
28 |
29 | # Step 2: beta
30 | x = x.reshape(x.size()[0],x.size()[1],-1)
31 | z_in_reshape = z_in.reshape(z_in.size()[0],z_in.size()[1],-1)
32 | beta_rank = int(self.beta_percentage*z_in.size()[-1]*z_in.size()[-2])
33 | beta_val, _ = torch.topk(x, k=beta_rank, dim=-1, sorted=True)
34 |
35 | # Step 3: pi_beta(phi(z_in))
36 | x = x > beta_val[...,-1,None]
37 |
38 | # Step 4: z_in \odot pi_beta(phi(z_in))
39 | z_in_masked = z_in_reshape * x
40 |
41 | # Step 5: GAP (z_in \odot pi_beta(phi(z_in)))
42 | z_in_masked_gap = torch.nn.functional.avg_pool1d(z_in_masked,
43 | kernel_size=z_in_masked.size()[-1]).squeeze(-1)
44 |
45 | # Step 6: MLP ( GAP (z_in \odot pi_beta(phi(z_in))) )
46 | out = self.mlp_layer(z_in_masked_gap)
47 |
48 | return out
49 |
50 | class conditioning_block(nn.Module):
51 |
52 | # Equation (5) of the main paper
53 |
54 | def __init__(self,
55 | in_dim=256,
56 | proxy_dim = 400,
57 | beta_percentage=0.3):
58 | super(conditioning_block,self).__init__()
59 |
60 | self.CL_1 = conditioning_layer(in_dim, beta_percentage)
61 | self.CL_2 = conditioning_layer(in_dim, beta_percentage)
62 | self.CL_3 = conditioning_layer(proxy_dim, 1)
63 |
64 | self.mlp_layer = nn.Linear(in_dim * 2 + proxy_dim, in_dim)
65 |
66 | def forward(self, x, proxy_IA_head):
67 |
68 | px1 = torch.nn.functional.avg_pool2d(x,kernel_size=(x.size()[-2],x.size()[-1]),padding = 0)
69 | x_delta = (torch.sum(px1,dim=0,keepdim=True)-px1).squeeze(-1).squeeze(-1)
70 |
71 | # Step 1: compute the intra-object conditioning code
72 | cl_out_1 = self.CL_1(x)
73 |
74 | # Step 2: compute the inter-object conditioning code
75 | cl_out_2 = self.CL_2(x_delta)
76 |
77 | # Step 3: compute the conditioning code from the proxies
78 | cl_out_3 = self.CL_3(proxy_IA_head)
79 |
80 | # Step 4: conduct calibration
81 | a = self.mlp_layer(torch.cat([cl_out_1,cl_out_2,cl_out_3],dim=1))
82 | a = 1. + torch.tanh(a)
83 | a = a.unsqueeze(-1).unsqueeze(-1)
84 | x = a * x
85 |
86 | return x
87 |
--------------------------------------------------------------------------------
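A toy, torch-only illustration of the top-beta gating performed in conditioning_layer above (Equation 7 of the paper): keep the spatial positions whose phi-score falls in the top beta fraction, zero out the rest, then global-average-pool. The shapes and the random stand-in for phi are illustrative only:

import torch

torch.manual_seed(0)
n, c, h, w = 2, 4, 5, 5
z_in = torch.randn(n, c, h, w)
phi = torch.randn(n, 1, h * w)                    # stands in for self.phi_layer(z_in), flattened
beta_rank = int(0.3 * h * w)                      # beta_percentage = 0.3
beta_val, _ = torch.topk(phi, k=beta_rank, dim=-1, sorted=True)
mask = phi > beta_val[..., -1, None]              # binary mask over spatial positions
z_masked = z_in.reshape(n, c, -1) * mask          # broadcast the mask over channels
code = torch.nn.functional.avg_pool1d(z_masked, kernel_size=h * w).squeeze(-1)
print(code.shape)                                 # torch.Size([2, 4]); one code vector per object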
/AOC-Net/complete_project/AOCNet/networks/aoc/decoding_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from networks.layers.attention import IA_gate
5 | from networks.layers.gct import Bottleneck, GCT
6 | from networks.layers.aspp import ASPP
7 | from networks.aoc.conditioning_layer import conditioning_layer, conditioning_block
8 |
9 |
10 | class CalibrationDecoding(nn.Module):
11 | def __init__(self,
12 | in_dim=256,
13 | attention_dim=400,
14 | embed_dim=100,
15 | refine_dim=48,
16 | low_level_dim=256,
17 | beta_percentage=0.3):
18 | super(CalibrationDecoding,self).__init__()
19 | self.embed_dim = embed_dim
20 | IA_in_dim = attention_dim
21 | self.beta_percentage = beta_percentage
22 | self.IA1 = IA_gate(IA_in_dim, in_dim)
23 | self.layer1 = Bottleneck(in_dim, embed_dim)
24 |
25 |
26 | self.layer2 = Bottleneck(embed_dim, embed_dim, 1, 2)
27 | self.CLB2 = conditioning_block(
28 | in_dim=embed_dim,
29 | proxy_dim=IA_in_dim,
30 | beta_percentage=self.beta_percentage)
31 |
32 |
33 | self.layer3 = Bottleneck(embed_dim, embed_dim * 2, 2)
34 | self.CLB3 = conditioning_block(
35 | in_dim=embed_dim,
36 | proxy_dim=IA_in_dim,
37 | beta_percentage=self.beta_percentage)
38 |
39 |
40 | self.layer4 = Bottleneck(embed_dim * 2, embed_dim * 2, 1, 2)
41 | self.CLB4 = conditioning_block(
42 | in_dim=embed_dim*2,
43 | proxy_dim=IA_in_dim,
44 | beta_percentage=self.beta_percentage)
45 |
46 | self.layer5 = Bottleneck(embed_dim * 2, embed_dim * 2, 1, 4)
47 | self.CLB5 = conditioning_block(
48 | in_dim=embed_dim*2,
49 | proxy_dim=IA_in_dim,
50 | beta_percentage=self.beta_percentage)
51 |
52 | self.IA9 = IA_gate(IA_in_dim+embed_dim * 2, embed_dim * 2)
53 | self.ASPP = ASPP()
54 |
55 | self.M1_Reweight_Layer_1 = IA_gate(IA_in_dim, embed_dim * 2)
56 | self.M1_Bottleneck_1 = Bottleneck(embed_dim*2, embed_dim * 2, 1)
57 |
58 | self.M1_Reweight_Layer_2 = IA_gate(IA_in_dim, embed_dim * 2)
59 | self.M1_Bottleneck_2 = Bottleneck(embed_dim*2, embed_dim * 1, 1)
60 |
61 | self.M1_Reweight_Layer_3 = IA_gate(IA_in_dim, embed_dim * 1)
62 | self.M1_Bottleneck_3 = Bottleneck(embed_dim*1, embed_dim * 1, 1)
63 |
64 | self.M2_Reweight_Layer_1 = IA_gate(IA_in_dim, embed_dim * 2)
65 | self.M2_Bottleneck_1 = Bottleneck(embed_dim*2, embed_dim * 2, 1)
66 |
67 | self.M2_Reweight_Layer_2 = IA_gate(IA_in_dim, embed_dim * 2)
68 | self.M2_Bottleneck_2 = Bottleneck(embed_dim*2, embed_dim * 1, 1)
69 |
70 | self.M2_Reweight_Layer_3 = IA_gate(IA_in_dim, embed_dim *1)
71 | self.M2_Bottleneck_3 = Bottleneck(embed_dim*1, embed_dim * 1, 1)
72 |
73 |
74 | self.GCT_sc = GCT(low_level_dim + embed_dim)
75 | self.conv_sc = nn.Conv2d(low_level_dim + embed_dim, refine_dim, 1, bias=False)
76 | self.bn_sc = nn.GroupNorm(int(refine_dim / 4), refine_dim)
77 | self.relu = nn.ReLU(inplace=True)
78 |
79 | self.IA10 = IA_gate(IA_in_dim+embed_dim + refine_dim, embed_dim + refine_dim)
80 | self.conv1 = nn.Conv2d(embed_dim + refine_dim, int(embed_dim / 2), kernel_size=3, padding=1, bias=False)
81 | self.bn1 = nn.GroupNorm(32, int(embed_dim / 2))
82 |
83 |
84 | self.IA11 = IA_gate(IA_in_dim+int(embed_dim / 2), int(embed_dim / 2))
85 | self.conv2 = nn.Conv2d(int(embed_dim / 2), int(embed_dim / 2), kernel_size=3, padding=1, bias=False)
86 | self.bn2 = nn.GroupNorm(32, int(embed_dim / 2))
87 |
88 | self.IA_final_fg = nn.Linear(IA_in_dim, int(embed_dim / 2) + 1)
89 | self.IA_final_bg = nn.Linear(IA_in_dim, int(embed_dim / 2) + 1)
90 |
91 | nn.init.kaiming_normal_(self.conv_sc.weight,mode='fan_out', nonlinearity='relu')
92 | nn.init.kaiming_normal_(self.conv1.weight,mode='fan_out', nonlinearity='relu')
93 | nn.init.kaiming_normal_(self.conv2.weight,mode='fan_out', nonlinearity='relu')
94 |
95 |
96 | def forward(self, x, IA_head=None,memory_list=None,low_level_feat=None,to_cat_previous_frame=None):
97 |
98 |
99 | x = self.IA1(x, IA_head)
100 | x = self.layer1(x)
101 |
102 | # start: Calibration #1
103 | x = self.CLB2(x,IA_head)
104 | # end: Calibration #1
105 |
106 | x = self.layer2(x)
107 |
108 | # start: Calibration #2
109 | x = self.CLB3(x,IA_head)
110 | # end: Calibration #2
111 |
112 | x = self.layer3(x)
113 |
114 | # start: Calibration #3
115 | x = self.CLB4(x,IA_head)
116 | # end: Calibration #3
117 |
118 | x = self.layer4(x)
119 |
120 | # start: Calibration #4
121 | x = self.CLB5(x,IA_head)
122 | # end: Calibration #4
123 |
124 | x = self.layer5(x)
125 |
126 | px1 = torch.nn.functional.avg_pool2d(x,kernel_size=(x.size()[-2],x.size()[-1]),padding = 0)
127 | px1_sum = torch.sum(px1,dim=0,keepdim=True)
128 | px1_delta = (px1_sum-px1).squeeze(-1).squeeze(-1)
129 |
130 | x = self.IA9(x, torch.cat([IA_head,px1_delta],dim=1))
131 | x = self.ASPP(x)
132 |
133 | x_emb_cur_1 = x.detach()
134 | if memory_list[0] is None or x_emb_cur_1.size()!=memory_list[0].size():
135 | memory_list[0] = x_emb_cur_1
136 | x = self.Modulator_1(x, memory_list[0].cuda(x.device),IA_head)
137 | x_emb_cur_2 = x.detach()
138 | if memory_list[1] is None or x_emb_cur_2.size()!=memory_list[1].size():
139 | memory_list[1] = x_emb_cur_2
140 | x = self.Modulator_2(x, memory_list[1].cuda(x.device),IA_head)
141 |
142 | x = self.decoder_final(x, low_level_feat, IA_head)
143 |
144 | fg_logit = self.IA_logit(x, IA_head, self.IA_final_fg)
145 | bg_logit = self.IA_logit(x, IA_head, self.IA_final_bg)
146 |
147 | pred = self.augment_background_logit(fg_logit, bg_logit)
148 | memory_list =[x_emb_cur_1.cpu(),memory_list[1].cpu()]
149 | return pred,memory_list
150 |
151 | def IA_logit(self, x, IA_head, IA_final):
152 | n, c, h, w = x.size()
153 | x = x.view(1, n * c, h, w)
154 | IA_output = IA_final(IA_head)
155 | IA_weight = IA_output[:, :c]
156 | IA_bias = IA_output[:, -1]
157 | IA_weight = IA_weight.view(n, c, 1, 1)
158 | IA_bias = IA_bias.view(-1)
159 | logit = F.conv2d(x, weight=IA_weight, bias=IA_bias, groups=n).view(n, 1, h, w)
160 | return logit
161 |
162 | def decoder_final(self, x, low_level_feat, IA_head):
163 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bicubic', align_corners=True)
164 |
165 | low_level_feat = self.GCT_sc(low_level_feat)
166 | low_level_feat = self.conv_sc(low_level_feat)
167 | low_level_feat = self.bn_sc(low_level_feat)
168 | low_level_feat = self.relu(low_level_feat)
169 |
170 | x = torch.cat([x, low_level_feat], dim=1)
171 |
172 | px1 = torch.nn.functional.avg_pool2d(x,kernel_size=(x.size()[-2],x.size()[-1]),padding = 0)
173 | px1_sum = torch.sum(px1,dim=0,keepdim=True)
174 | px1_delta = (px1_sum-px1).squeeze(-1).squeeze(-1)
175 |
176 | x = self.IA10(x, torch.cat([IA_head,px1_delta],dim=1))
177 | x = self.conv1(x)
178 | x = self.bn1(x)
179 | x = self.relu(x)
180 |
181 | px1 = torch.nn.functional.avg_pool2d(x,kernel_size=(x.size()[-2],x.size()[-1]),padding = 0)
182 | px1_sum = torch.sum(px1,dim=0,keepdim=True)
183 | px1_delta = (px1_sum-px1).squeeze(-1).squeeze(-1)
184 |
185 | x = self.IA11(x, torch.cat([IA_head,px1_delta],dim=1))
186 | x = self.conv2(x)
187 | x = self.bn2(x)
188 | x = self.relu(x)
189 |
190 | return x
191 |
192 | def Modulator_1(self, x, x_memory,IA_head):
193 | x = torch.cat([x, x_memory], dim=1)
194 | x = self.M1_Reweight_Layer_1(x, IA_head)
195 | x = self.M1_Bottleneck_1(x)
196 | x = self.M1_Reweight_Layer_2(x, IA_head)
197 | x = self.M1_Bottleneck_2(x)
198 | x = self.M1_Reweight_Layer_3(x, IA_head)
199 | x = self.M1_Bottleneck_3(x)
200 | return x
201 |
202 | def Modulator_2(self, x, x_memory,IA_head):
203 | x = torch.cat([x, x_memory], dim=1)
204 | x = self.M2_Reweight_Layer_1(x, IA_head)
205 | x = self.M2_Bottleneck_1(x)
206 | x = self.M2_Reweight_Layer_2(x, IA_head)
207 | x = self.M2_Bottleneck_2(x)
208 | x = self.M2_Reweight_Layer_3(x, IA_head)
209 | x = self.M2_Bottleneck_3(x)
210 | return x
211 |
212 |
213 | def augment_background_logit(self, fg_logit, bg_logit):
214 | # Augment the logit of absolute background by using the relative background logit of all the
215 | # foreground objects.
216 | obj_num = fg_logit.size(0)
217 | pred = fg_logit
218 | if obj_num > 1:
219 | bg_logit = bg_logit[1:obj_num, :, :, :]
220 | aug_bg_logit, _ = torch.min(bg_logit, dim=0, keepdim=True)
221 | pad = torch.zeros(aug_bg_logit.size(), device=aug_bg_logit.device).expand(obj_num - 1, -1, -1, -1)
222 | aug_bg_logit = torch.cat([aug_bg_logit, pad], dim=0)
223 | pred = pred + aug_bg_logit
224 | pred = pred.permute(1,0,2,3)
225 | return pred
226 |
227 |
228 | class DynamicPreHead(nn.Module):
229 | def __init__(self, in_dim=3, embed_dim=100, kernel_size=1):
230 | super(DynamicPreHead,self).__init__()
231 | self.conv=nn.Conv2d(in_dim,embed_dim,kernel_size=kernel_size,stride=1,padding=int((kernel_size-1)/2))
232 | self.bn = nn.GroupNorm(int(embed_dim / 4), embed_dim)
233 | self.relu = nn.ReLU(True)
234 | nn.init.kaiming_normal_(self.conv.weight,mode='fan_out',nonlinearity='relu')
235 |
236 | def forward(self, x):
237 | x = self.conv(x)
238 | x = self.bn(x)
239 | x = self.relu(x)
240 | return x
241 |
--------------------------------------------------------------------------------
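The decoder above carries a two-slot memory_list of detached feature maps between frames; a slot is (re)filled whenever it is empty or its spatial size no longer matches, for example at the first frame of a sequence or under multi-scale testing. A stripped-down sketch of that lazy-update pattern (illustrative, not the repository's API):

import torch

def update_slot(memory_list, idx, current):
    # Refresh the cached tensor when it is missing or when the spatial size changed.
    if memory_list[idx] is None or current.size() != memory_list[idx].size():
        memory_list[idx] = current.detach()
    return memory_list[idx]

memory = [None, None]
first = torch.randn(3, 200, 30, 30)
update_slot(memory, 0, first)                 # first frame: slot 0 is filled
rescaled = torch.randn(3, 200, 40, 40)
cached = update_slot(memory, 0, rescaled)     # size changed: slot 0 is refreshed
print(cached.shape)                           # torch.Size([3, 200, 40, 40])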
/AOC-Net/complete_project/AOCNet/networks/deeplab/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/deeplab/__init__.py
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/deeplab/aspp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class _ASPPModule(nn.Module):
7 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
8 | super(_ASPPModule, self).__init__()
9 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
10 | stride=1, padding=padding, dilation=dilation, bias=False)
11 | self.bn = BatchNorm(planes)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | self._init_weight()
15 |
16 | def forward(self, x):
17 | x = self.atrous_conv(x)
18 | x = self.bn(x)
19 |
20 | return self.relu(x)
21 |
22 | def _init_weight(self):
23 | for m in self.modules():
24 | if isinstance(m, nn.Conv2d):
25 | torch.nn.init.kaiming_normal_(m.weight)
26 | elif isinstance(m, nn.BatchNorm2d):
27 | m.weight.data.fill_(1)
28 | m.bias.data.zero_()
29 |
30 | class ASPP(nn.Module):
31 | def __init__(self, backbone, output_stride, BatchNorm):
32 | super(ASPP, self).__init__()
33 | if backbone == 'drn':
34 | inplanes = 512
35 | elif backbone == 'mobilenet':
36 | inplanes = 320
37 | else:
38 | inplanes = 2048
39 | if output_stride == 16:
40 | dilations = [1, 6, 12, 18]
41 | elif output_stride == 8:
42 | dilations = [1, 12, 24, 36]
43 | else:
44 | raise NotImplementedError
45 |
46 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
47 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
48 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
49 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
50 |
51 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
52 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
53 | BatchNorm(256),
54 | nn.ReLU(inplace=True))
55 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
56 | self.bn1 = BatchNorm(256)
57 | self.relu = nn.ReLU(inplace=True)
58 | self.dropout = nn.Dropout(0.1)
59 | self._init_weight()
60 |
61 | def forward(self, x):
62 | x1 = self.aspp1(x)
63 | x2 = self.aspp2(x)
64 | x3 = self.aspp3(x)
65 | x4 = self.aspp4(x)
66 | x5 = self.global_avg_pool(x)
67 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
68 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
69 |
70 | x = self.conv1(x)
71 | x = self.bn1(x)
72 | x = self.relu(x)
73 |
74 | return self.dropout(x)
75 |
76 | def _init_weight(self):
77 | for m in self.modules():
78 | if isinstance(m, nn.Conv2d):
79 | torch.nn.init.kaiming_normal_(m.weight)
80 | elif isinstance(m, nn.BatchNorm2d):
81 | m.weight.data.fill_(1)
82 | m.bias.data.zero_()
83 |
84 |
85 | def build_aspp(backbone, output_stride, BatchNorm):
86 | return ASPP(backbone, output_stride, BatchNorm)
87 |
--------------------------------------------------------------------------------
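A brief usage sketch of the ASPP block above, assuming the AOCNet project root is on PYTHONPATH: with a ResNet backbone at output stride 16 the module sees 2048-channel features, runs four atrous branches plus global average pooling at 256 channels each, and fuses the concatenated 1280 channels back to 256 without changing the spatial size.

import torch
import torch.nn as nn
from networks.deeplab.aspp import build_aspp

# ResNet backbone at output stride 16: 2048 input channels, dilations [1, 6, 12, 18].
aspp = build_aspp(backbone='resnet', output_stride=16, BatchNorm=nn.BatchNorm2d)
feat = torch.randn(2, 2048, 30, 30)   # e.g. a 480x480 crop at stride 16
out = aspp(feat)
print(out.shape)                      # torch.Size([2, 256, 30, 30])
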
/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from networks.deeplab.backbone import resnet, mobilenet
2 |
3 | def build_backbone(backbone, output_stride, BatchNorm):
4 | if backbone == 'resnet':
5 | return resnet.ResNet101(output_stride, BatchNorm)
6 | elif backbone == 'mobilenet':
7 | return mobilenet.MobileNetV2(output_stride, BatchNorm)
8 | else:
9 | raise NotImplementedError
10 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 | import torch.utils.model_zoo as model_zoo
6 |
7 | def conv_bn(inp, oup, stride, BatchNorm):
8 | return nn.Sequential(
9 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
10 | BatchNorm(oup),
11 | nn.ReLU6(inplace=True)
12 | )
13 |
14 |
15 | def fixed_padding(inputs, kernel_size, dilation):
16 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
17 | pad_total = kernel_size_effective - 1
18 | pad_beg = pad_total // 2
19 | pad_end = pad_total - pad_beg
20 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
21 | return padded_inputs
22 |
23 |
24 | class InvertedResidual(nn.Module):
25 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
26 | super(InvertedResidual, self).__init__()
27 | self.stride = stride
28 | assert stride in [1, 2]
29 |
30 | hidden_dim = round(inp * expand_ratio)
31 | self.use_res_connect = self.stride == 1 and inp == oup
32 | self.kernel_size = 3
33 | self.dilation = dilation
34 |
35 | if expand_ratio == 1:
36 | self.conv = nn.Sequential(
37 | # dw
38 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
39 | BatchNorm(hidden_dim),
40 | nn.ReLU6(inplace=True),
41 | # pw-linear
42 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
43 | BatchNorm(oup),
44 | )
45 | else:
46 | self.conv = nn.Sequential(
47 | # pw
48 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
49 | BatchNorm(hidden_dim),
50 | nn.ReLU6(inplace=True),
51 | # dw
52 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
53 | BatchNorm(hidden_dim),
54 | nn.ReLU6(inplace=True),
55 | # pw-linear
56 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
57 | BatchNorm(oup),
58 | )
59 |
60 | def forward(self, x):
61 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
62 | if self.use_res_connect:
63 | x = x + self.conv(x_pad)
64 | else:
65 | x = self.conv(x_pad)
66 | return x
67 |
68 |
69 | class MobileNetV2(nn.Module):
70 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=False):
71 | super(MobileNetV2, self).__init__()
72 | block = InvertedResidual
73 | input_channel = 32
74 | current_stride = 1
75 | rate = 1
76 | interverted_residual_setting = [
77 | # t, c, n, s
78 | [1, 16, 1, 1],
79 | [6, 24, 2, 2],
80 | [6, 32, 3, 2],
81 | [6, 64, 4, 2],
82 | [6, 96, 3, 1],
83 | [6, 160, 3, 2],
84 | [6, 320, 1, 1],
85 | ]
86 |
87 | # building first layer
88 | input_channel = int(input_channel * width_mult)
89 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
90 | current_stride *= 2
91 | # building inverted residual blocks
92 | for t, c, n, s in interverted_residual_setting:
93 | if current_stride == output_stride:
94 | stride = 1
95 | dilation = rate
96 | rate *= s
97 | else:
98 | stride = s
99 | dilation = 1
100 | current_stride *= s
101 | output_channel = int(c * width_mult)
102 | for i in range(n):
103 | if i == 0:
104 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
105 | else:
106 | self.features.append(block(input_channel, output_channel, 1, rate, t, BatchNorm))
107 | input_channel = output_channel
108 | self.features = nn.Sequential(*self.features)
109 | self._initialize_weights()
110 |
111 | if pretrained:
112 | self._load_pretrained_model()
113 |
114 | self.low_level_features = self.features[0:4]
115 | self.high_level_features = self.features[4:]
116 |
117 |         self.feature_8x = self.features[4:7]
118 | self.feature_16x = self.features[7:14]
119 | self.feature_32x = self.features[14:]
120 |
121 | def forward(self, x, return_mid_level=False):
122 | if return_mid_level:
123 | low_level_feat = self.low_level_features(x)
124 |             mid_level_feat = self.feature_8x(low_level_feat)
125 | x = self.feature_16x(mid_level_feat)
126 | x = self.feature_32x(x)
127 | return x, low_level_feat, mid_level_feat
128 | else:
129 | low_level_feat = self.low_level_features(x)
130 | x = self.high_level_features(low_level_feat)
131 | return x, low_level_feat
132 |
133 | def _load_pretrained_model(self):
134 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
135 | model_dict = {}
136 | state_dict = self.state_dict()
137 | for k, v in pretrain_dict.items():
138 | if k in state_dict:
139 | model_dict[k] = v
140 | state_dict.update(model_dict)
141 | self.load_state_dict(state_dict)
142 |
143 | def _initialize_weights(self):
144 | for m in self.modules():
145 | if isinstance(m, nn.Conv2d):
146 | torch.nn.init.kaiming_normal_(m.weight)
147 | elif isinstance(m, nn.BatchNorm2d):
148 | m.weight.data.fill_(1)
149 | m.bias.data.zero_()
150 |
151 | if __name__ == "__main__":
152 | input = torch.rand(1, 3, 512, 512)
153 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
154 | output, low_level_feat = model(input)
155 | print(output.size())
156 | print(low_level_feat.size())
157 |
--------------------------------------------------------------------------------
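fixed_padding above pads the input so that the depthwise convolutions (declared with padding=0) keep the spatial size for any dilation: the effective kernel is kernel_size + (kernel_size - 1) * (dilation - 1) and the resulting pad is split as evenly as possible between the two sides. A quick, self-contained check of that arithmetic with arbitrary sizes:

import torch
import torch.nn.functional as F

def fixed_padding(inputs, kernel_size, dilation):
    # same helper as in mobilenet.py above
    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    return F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))

# kernel 3, dilation 2 -> effective kernel 5 -> pad 2 on each side, so a
# stride-1, padding-0 depthwise conv afterwards preserves the spatial size.
x = torch.randn(1, 8, 17, 17)
x_pad = fixed_padding(x, kernel_size=3, dilation=2)
y = F.conv2d(x_pad, torch.randn(8, 1, 3, 3), dilation=2, groups=8)
print(x_pad.shape, y.shape)   # (1, 8, 21, 21) and (1, 8, 17, 17)
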
/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | class Bottleneck(nn.Module):
6 | expansion = 4
7 |
8 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
9 | super(Bottleneck, self).__init__()
10 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
11 | self.bn1 = BatchNorm(planes)
12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
13 | dilation=dilation, padding=dilation, bias=False)
14 | self.bn2 = BatchNorm(planes)
15 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
16 | self.bn3 = BatchNorm(planes * 4)
17 | self.relu = nn.ReLU(inplace=True)
18 | self.downsample = downsample
19 | self.stride = stride
20 | self.dilation = dilation
21 |
22 | def forward(self, x):
23 | residual = x
24 |
25 | out = self.conv1(x)
26 | out = self.bn1(out)
27 | out = self.relu(out)
28 |
29 | out = self.conv2(out)
30 | out = self.bn2(out)
31 | out = self.relu(out)
32 |
33 | out = self.conv3(out)
34 | out = self.bn3(out)
35 |
36 | if self.downsample is not None:
37 | residual = self.downsample(x)
38 |
39 | out += residual
40 | out = self.relu(out)
41 |
42 | return out
43 |
44 | class ResNet(nn.Module):
45 |
46 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=False):
47 | self.inplanes = 64
48 | super(ResNet, self).__init__()
49 | blocks = [1, 2, 4]
50 | if output_stride == 16:
51 | strides = [1, 2, 2, 1]
52 | dilations = [1, 1, 1, 2]
53 | elif output_stride == 8:
54 | strides = [1, 2, 1, 1]
55 | dilations = [1, 1, 2, 4]
56 | else:
57 | raise NotImplementedError
58 |
59 | # Modules
60 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
61 | bias=False)
62 | self.bn1 = BatchNorm(64)
63 | self.relu = nn.ReLU(inplace=True)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 |
66 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
67 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
68 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
69 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
70 | self._init_weight()
71 |         if pretrained: self._load_pretrained_model()  # load ImageNet weights when requested
72 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
73 | downsample = None
74 | if stride != 1 or self.inplanes != planes * block.expansion:
75 | downsample = nn.Sequential(
76 | nn.Conv2d(self.inplanes, planes * block.expansion,
77 | kernel_size=1, stride=stride, bias=False),
78 | BatchNorm(planes * block.expansion),
79 | )
80 |
81 | layers = []
82 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
83 | self.inplanes = planes * block.expansion
84 | for i in range(1, blocks):
85 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
86 |
87 | return nn.Sequential(*layers)
88 |
89 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
90 | downsample = None
91 | if stride != 1 or self.inplanes != planes * block.expansion:
92 | downsample = nn.Sequential(
93 | nn.Conv2d(self.inplanes, planes * block.expansion,
94 | kernel_size=1, stride=stride, bias=False),
95 | BatchNorm(planes * block.expansion),
96 | )
97 |
98 | layers = []
99 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
100 | downsample=downsample, BatchNorm=BatchNorm))
101 | self.inplanes = planes * block.expansion
102 | for i in range(1, len(blocks)):
103 | layers.append(block(self.inplanes, planes, stride=1,
104 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
105 |
106 | return nn.Sequential(*layers)
107 |
108 | def forward(self, input, return_mid_level=False):
109 | x = self.conv1(input)
110 | x = self.bn1(x)
111 | x = self.relu(x)
112 | x = self.maxpool(x)
113 |
114 | x = self.layer1(x)
115 | low_level_feat = x
116 | x = self.layer2(x)
117 | mid_level_feat = x
118 | x = self.layer3(x)
119 | x = self.layer4(x)
120 | if return_mid_level:
121 | return x, low_level_feat, mid_level_feat
122 | else:
123 | return x, low_level_feat
124 | def _init_weight(self):
125 | for m in self.modules():
126 | if isinstance(m, nn.Conv2d):
127 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
128 | m.weight.data.normal_(0, math.sqrt(2. / n))
129 | elif isinstance(m, nn.BatchNorm2d):
130 | m.weight.data.fill_(1)
131 | m.bias.data.zero_()
132 |
133 | def _load_pretrained_model(self):
134 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
135 | model_dict = {}
136 | state_dict = self.state_dict()
137 | for k, v in pretrain_dict.items():
138 | if k in state_dict:
139 | model_dict[k] = v
140 | state_dict.update(model_dict)
141 | self.load_state_dict(state_dict)
142 |
143 | def ResNet101(output_stride, BatchNorm, pretrained=True):
144 | """Constructs a ResNet-101 model.
145 | Args:
146 | pretrained (bool): If True, returns a model pre-trained on ImageNet
147 | """
148 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
149 | return model
150 |
151 | if __name__ == "__main__":
152 | import torch
153 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
154 | input = torch.rand(1, 3, 512, 512)
155 | output, low_level_feat = model(input)
156 | print(output.size())
157 | print(low_level_feat.size())
158 |
--------------------------------------------------------------------------------
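In the ResNet above, layer4 is built by _make_MG_unit, which scales the last stage's base dilation by the multi-grid factors [1, 2, 4]; with the strides/dilations tables in __init__ this yields dilations 2/4/8 at output stride 16 and 4/8/16 at output stride 8, while the low-level features handed to the decoder always come from layer1 at 1/4 resolution. A small shape sanity check, with pretrained=False only to avoid downloading weights:

import torch
import torch.nn as nn
from networks.deeplab.backbone.resnet import ResNet101

# output_stride=16: top features at 1/16 of the input, low-level features at 1/4.
model = ResNet101(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
x = torch.rand(1, 3, 512, 512)
top, low_level = model(x)
print(top.shape)        # torch.Size([1, 2048, 32, 32])
print(low_level.shape)  # torch.Size([1, 256, 128, 128])
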
/AOC-Net/complete_project/AOCNet/networks/deeplab/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class Decoder(nn.Module):
7 | def __init__(self, backbone, BatchNorm):
8 | super(Decoder, self).__init__()
9 | if backbone == 'resnet':
10 | low_level_inplanes = 256
11 | elif backbone == 'mobilenet':
12 | low_level_inplanes = 24
13 | else:
14 | raise NotImplementedError
15 |
16 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
17 | self.bn1 = BatchNorm(48)
18 | self.relu = nn.ReLU(inplace=True)
19 |
20 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
21 | BatchNorm(256),
22 | nn.ReLU(inplace=True),
23 | nn.Sequential(),
24 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
25 | BatchNorm(256),
26 | nn.ReLU(inplace=True),
27 | nn.Sequential())
28 |
29 | self._init_weight()
30 |
31 |
32 | def forward(self, x, low_level_feat):
33 | low_level_feat = self.conv1(low_level_feat)
34 | low_level_feat = self.bn1(low_level_feat)
35 | low_level_feat = self.relu(low_level_feat)
36 |
37 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
38 | x = torch.cat((x, low_level_feat), dim=1)
39 | x = self.last_conv(x)
40 |
41 | return x
42 |
43 | def _init_weight(self):
44 | for m in self.modules():
45 | if isinstance(m, nn.Conv2d):
46 | torch.nn.init.kaiming_normal_(m.weight)
47 | elif isinstance(m, nn.BatchNorm2d):
48 | m.weight.data.fill_(1)
49 | m.bias.data.zero_()
50 |
51 | def build_decoder(backbone, BatchNorm):
52 | return Decoder(backbone, BatchNorm)
53 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/deeplab/deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from networks.deeplab.aspp import build_aspp
5 | from networks.deeplab.decoder import build_decoder
6 | from networks.deeplab.backbone import build_backbone
7 | from networks.layers.normalization import FrozenBatchNorm2d
8 |
9 | class DeepLab(nn.Module):
10 | def __init__(self,
11 | backbone='resnet',
12 | output_stride=16,
13 | freeze_bn=True):
14 | super(DeepLab, self).__init__()
15 |
16 | if freeze_bn == True:
17 | print("Use frozen BN in DeepLab!")
18 | BatchNorm = FrozenBatchNorm2d
19 | else:
20 | BatchNorm = nn.BatchNorm2d
21 |
22 | self.backbone = build_backbone(backbone, output_stride, BatchNorm)
23 | self.aspp = build_aspp(backbone, output_stride, BatchNorm)
24 | self.decoder = build_decoder(backbone, BatchNorm)
25 |
26 |
27 | def forward(self, input, return_aspp=False):
28 | if return_aspp:
29 | x, low_level_feat, mid_level_feat = self.backbone(input, True)
30 | else:
31 | x, low_level_feat = self.backbone(input)
32 | aspp_x = self.aspp(x)
33 | x = self.decoder(aspp_x, low_level_feat)
34 |
35 | if return_aspp:
36 | return x, aspp_x, low_level_feat, mid_level_feat
37 | else:
38 | return x, low_level_feat
39 |
40 |
41 | def get_1x_lr_params(self):
42 | modules = [self.backbone]
43 | for i in range(len(modules)):
44 | for m in modules[i].named_modules():
45 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d):
46 | for p in m[1].parameters():
47 | if p.requires_grad:
48 | yield p
49 |
50 | def get_10x_lr_params(self):
51 | modules = [self.aspp, self.decoder]
52 | for i in range(len(modules)):
53 | for m in modules[i].named_modules():
54 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d):
55 | for p in m[1].parameters():
56 | if p.requires_grad:
57 | yield p
58 |
59 |
60 | if __name__ == "__main__":
61 | model = DeepLab(backbone='resnet', output_stride=16)
62 | model.eval()
63 | input = torch.rand(2, 3, 513, 513)
64 |     output, low_level_feat = model(input)
65 | print(output.size())
66 |
67 |
68 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/engine/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/engine/__init__.py
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/engine/eval_manager_mm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | import time
4 | import datetime as datetime
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | import numpy as np
11 | from dataloaders.datasets_m import YOUTUBE_VOS_Test,DAVIS_Test
12 | import dataloaders.custom_transforms as tr
13 | from networks.deeplab.deeplab import DeepLab
14 | from utils.meters import AverageMeter
15 | from utils.image import flip_tensor, save_mask, save_matching_result
16 | from utils.checkpoint import load_network
17 | from utils.eval import zip_folder
18 | from networks.layers.shannon_entropy import cal_shannon_entropy
19 | import math
20 |
21 | class Evaluator(object):
22 | def __init__(self, cfg):
23 |
24 | self.mem_every = cfg.MEM_EVERY
25 | self.unc_ratio = cfg.UNC_RATIO
26 |
27 | self.gpu = cfg.TEST_GPU_ID
28 | self.cfg = cfg
29 | self.print_log(cfg.__dict__)
30 | print("Use GPU {} for evaluating".format(self.gpu))
31 | torch.cuda.set_device(self.gpu)
32 |
33 | self.print_log('Build backbone.')
34 | self.feature_extracter = DeepLab(
35 | backbone=cfg.MODEL_BACKBONE,
36 | freeze_bn=cfg.MODEL_FREEZE_BN).cuda(self.gpu)
37 |
38 | self.print_log('Build VOS model.')
39 | CFBI = importlib.import_module(cfg.MODEL_MODULE)
40 | self.model = CFBI.get_module()(
41 | cfg,
42 | self.feature_extracter).cuda(self.gpu)
43 |
44 | self.process_pretrained_model()
45 |
46 | self.prepare_dataset()
47 |
48 | def process_pretrained_model(self):
49 | cfg = self.cfg
50 | if cfg.TEST_CKPT_PATH == 'test':
51 | self.ckpt = 'test'
52 | self.print_log('Test evaluation.')
53 | return
54 | if cfg.TEST_CKPT_PATH is None:
55 | if cfg.TEST_CKPT_STEP is not None:
56 | ckpt = str(cfg.TEST_CKPT_STEP)
57 | else:
58 | ckpts = os.listdir(cfg.DIR_CKPT)
59 | if len(ckpts) > 0:
60 | ckpts = list(map(lambda x: int(x.split('_')[-1].split('.')[0]), ckpts))
61 | ckpt = np.sort(ckpts)[-1]
62 | else:
63 | self.print_log('No checkpoint in {}.'.format(cfg.DIR_CKPT))
64 | exit()
65 | self.ckpt = ckpt
66 | cfg.TEST_CKPT_PATH = os.path.join(cfg.DIR_CKPT, 'save_step_%s.pth' % ckpt)
67 | self.model, removed_dict = load_network(self.model, cfg.TEST_CKPT_PATH, self.gpu)
68 | if len(removed_dict) > 0:
69 | self.print_log('Remove {} from pretrained model.'.format(removed_dict))
70 | self.print_log('Load latest checkpoint from {}'.format(cfg.TEST_CKPT_PATH))
71 | else:
72 | self.ckpt = 'unknown'
73 | self.model, removed_dict = load_network(self.model, cfg.TEST_CKPT_PATH, self.gpu)
74 | if len(removed_dict) > 0:
75 | self.print_log('Remove {} from pretrained model.'.format(removed_dict))
76 | self.print_log('Load checkpoint from {}'.format(cfg.TEST_CKPT_PATH))
77 |
78 | def prepare_dataset(self):
79 | cfg = self.cfg
80 | self.print_log('Process dataset...')
81 | eval_transforms = transforms.Compose([
82 | tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE, cfg.TEST_FLIP, cfg.TEST_MULTISCALE),
83 | tr.MultiToTensor()])
84 |
85 | eval_name = '{}_{}_ckpt_{}'.format(cfg.TEST_DATASET, cfg.EXP_NAME, self.ckpt)
86 | if cfg.TEST_FLIP:
87 | eval_name += '_flip'
88 | if len(cfg.TEST_MULTISCALE) > 1:
89 | eval_name += '_ms'
90 | for ss in cfg.TEST_MULTISCALE:
91 | eval_name +="_"
92 | eval_name +=str(ss)
93 |
94 | eval_name+="_m_"+str(self.mem_every)+"_u_"+str(self.unc_ratio)+"_r_"+str(cfg.TEST_MAX_SIZE)+"_RPA"
95 |
96 | if cfg.TEST_DATASET == 'youtubevos19':
97 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations')
98 | self.dataset = YOUTUBE_VOS_Test(
99 | root=cfg.DIR_YTB_EVAL,
100 | transform=eval_transforms,
101 | result_root=self.result_root)
102 |
103 | elif cfg.TEST_DATASET == 'youtubevos18':
104 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations')
105 | self.dataset = YOUTUBE_VOS_Test(
106 | root=cfg.DIR_YTB_EVAL18,
107 | transform=eval_transforms,
108 | result_root=self.result_root)
109 | elif cfg.TEST_DATASET == 'youtubevos': ## predict all frame for youtubevos 2019
110 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations')
111 | self.dataset = YOUTUBE_VOS_Test(
112 | root=cfg.DIR_YTB_EVAL,
113 | transform=eval_transforms,
114 | result_root=self.result_root,
115 | use_all=True)
116 | elif cfg.TEST_DATASET == 'davis2017':
117 | resolution = 'Full-Resolution' if cfg.TEST_DATASET_FULL_RESOLUTION else '480p'
118 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations', resolution)
119 | self.dataset = DAVIS_Test(
120 | split=cfg.TEST_DATASET_SPLIT,
121 | root=cfg.DIR_DAVIS,
122 | year=2017,
123 | transform=eval_transforms,
124 | full_resolution=cfg.TEST_DATASET_FULL_RESOLUTION,
125 | result_root=self.result_root)
126 | elif cfg.TEST_DATASET == 'davis2017-label':
127 | resolution = 'Full-Resolution' if cfg.TEST_DATASET_FULL_RESOLUTION else '480p'
128 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations', resolution)
129 | self.dataset = DAVIS_Test_w_label(
130 | split=cfg.TEST_DATASET_SPLIT,
131 | root=cfg.DIR_DAVIS,
132 | year=2017,
133 | transform=eval_transforms,
134 | full_resolution=cfg.TEST_DATASET_FULL_RESOLUTION,
135 | result_root=self.result_root)
136 | elif cfg.TEST_DATASET == 'davis2016':
137 | resolution = 'Full-Resolution' if cfg.TEST_DATASET_FULL_RESOLUTION else '480p'
138 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations', resolution)
139 | self.dataset = DAVIS_Test(
140 | split=cfg.TEST_DATASET_SPLIT,
141 | root=cfg.DIR_DAVIS,
142 | year=2016,
143 | transform=eval_transforms,
144 | full_resolution=cfg.TEST_DATASET_FULL_RESOLUTION,
145 | result_root=self.result_root)
146 | elif cfg.TEST_DATASET == 'test':
147 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations')
148 | self.dataset = EVAL_TEST(eval_transforms, self.result_root)
149 | else:
150 | print('Unknown dataset!')
151 | exit()
152 |
153 | print('Eval {} on {}:'.format(cfg.EXP_NAME, cfg.TEST_DATASET))
154 | self.source_folder = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations')
155 | self.zip_dir = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, '{}.zip'.format(eval_name))
156 | if not os.path.exists(self.result_root):
157 | os.makedirs(self.result_root)
158 | self.print_log('Done!')
159 |
160 | def evaluating(self):
161 | cfg = self.cfg
162 | self.model.eval()
163 | video_num = 0
164 | total_time = 0
165 | total_frame = 0
166 | total_sfps = 0
167 | total_video_num = len(self.dataset)
168 | PlaceHolder=[]
169 | for i in range(cfg.BLOCK_NUM):
170 | PlaceHolder.append(None)
171 |
172 | for seq_idx, seq_dataset in enumerate(self.dataset):
173 | video_num += 1
174 | seq_name = seq_dataset.seq_name
175 |
176 |             print('Processing Seq {} [{}/{}]:'.format(seq_name, video_num, total_video_num))
177 |
178 | torch.cuda.empty_cache()
179 |
180 | seq_dataloader=DataLoader(seq_dataset, batch_size=1, shuffle=False, num_workers=cfg.TEST_WORKERS, pin_memory=True)
181 |
182 | seq_total_time = 0
183 | seq_total_frame = 0
184 | ref_embeddings = []
185 | ref_masks = []
186 | prev_embedding = []
187 | prev_mask = []
188 | ref_mask_confident = []
189 | memory_prev_all_list=[]
190 | memory_cur_all_list=[]
191 | memory_prev_list=[]
192 | memory_cur_list=[]
193 | label_all_list=[]
194 |
195 | with torch.no_grad():
196 | for frame_idx, samples in enumerate(seq_dataloader):
197 |
198 | time_start = time.time()
199 | all_preds = []
200 |
201 | join_label = None
202 | UPDATE=False
203 |
204 |
205 | if frame_idx==0:
206 | for aug_idx in range(len(samples)):
207 | memory_prev_all_list.append([PlaceHolder])
208 | else:
209 | memory_prev_all_list=memory_cur_all_list
210 |
211 | memory_cur_all_list=[]
212 | for aug_idx in range(len(samples)):
213 | if len(ref_embeddings) <= aug_idx:
214 | ref_embeddings.append([])
215 | ref_masks.append([])
216 | prev_embedding.append(None)
217 | prev_mask.append(None)
218 | ref_mask_confident.append([])
219 |
220 | sample = samples[aug_idx]
221 | ref_emb = ref_embeddings[aug_idx]
222 | #ref_m = ref_masks[aug_idx]
223 |
224 | ## use confident mask for correlation
225 | ref_m = ref_mask_confident[aug_idx]
226 |
227 | prev_emb = prev_embedding[aug_idx]
228 | prev_m = prev_mask[aug_idx]
229 |
230 |
231 | current_img = sample['current_img']
232 | if 'current_label' in sample.keys():
233 | current_label = sample['current_label'].cuda(self.gpu)
234 | else:
235 | current_label = None
236 |
237 | obj_list = sample['meta']['obj_list']
238 | obj_num = sample['meta']['obj_num']
239 | imgname = sample['meta']['current_name']
240 | ori_height = sample['meta']['height']
241 | ori_width = sample['meta']['width']
242 | current_img = current_img.cuda(self.gpu)
243 | obj_num = obj_num.cuda(self.gpu)
244 | bs, _, h, w = current_img.size()
245 |
246 | all_pred, current_embedding,memory_cur_list = self.model.forward_for_eval(memory_prev_all_list[aug_idx], ref_emb,
247 | ref_m, prev_emb, prev_m,
248 | current_img, gt_ids=obj_num,
249 | pred_size=[ori_height,ori_width])
250 | memory_cur_all_list.append(memory_cur_list)
251 |
252 | # delete the label that hasn't existed in the GT label
253 | all_pred_remake = []
254 | all_pred_exist = []
255 |                         if all_pred is not None:
256 | all_pred_split = all_pred.split(all_pred.size()[1],dim=1)[0]
257 |
258 | for i in range(all_pred.size()[1]):
259 | if i not in label_all_list:
260 | all_pred_remake.append(torch.zeros_like(all_pred_split[0][i]).unsqueeze(0))
261 | else:
262 | all_pred_remake.append(all_pred_split[0][i].unsqueeze(0))
263 | all_pred_exist.append(all_pred_split[0][i].unsqueeze(0))
264 | all_pred = torch.cat(all_pred_remake,dim=0).unsqueeze(0)
265 | all_pred_exist = torch.cat(all_pred_exist,dim=0).unsqueeze(0)
266 |
267 |
268 | if 'current_label' in sample.keys():
269 | label_cur_list = np.unique(sample['current_label'].cpu().detach().numpy()).tolist()
270 | for i in label_cur_list:
271 | if i not in label_all_list:
272 | label_all_list.append(i)
273 |
274 | if frame_idx == 0:
275 | if current_label is None:
276 | print("No first frame label in Seq {}.".format(seq_name))
277 | ref_embeddings[aug_idx].append(current_embedding)
278 | ref_masks[aug_idx].append(current_label)
279 | ref_mask_confident[aug_idx].append(current_label)
280 |
281 | prev_embedding[aug_idx] = current_embedding
282 | prev_mask[aug_idx] = current_label
283 |
284 | else:
285 | if sample['meta']['flip']:
286 | all_pred = flip_tensor(all_pred, 3)
287 | # In YouTube-VOS, not all the objects appear in the first frame for the first time. Thus, we
288 | # have to introduce new labels for new objects, if necessary.
289 | if not sample['meta']['flip'] and not(current_label is None) and join_label is None: # gt exists here
290 | join_label = current_label
291 | all_preds.append(all_pred)
292 |
293 | all_pred_org = all_pred
294 | current_label_0 = None
295 |
296 | if current_label is not None:
297 | ref_embeddings[aug_idx].append(current_embedding)
298 |
299 | else:
300 | all_preds_0 = torch.cat(all_preds, dim=0)
301 | all_preds_0 = torch.mean(all_preds_0, dim=0)
302 | pred_label_0 = torch.argmax(all_preds_0, dim=0)
303 | current_label_0 = pred_label_0.view(1, 1, ori_height, ori_width)
304 |
305 | # uncertainty region filter
306 | uncertainty_org,uncertainty_norm = cal_shannon_entropy(all_pred_exist)
307 |
308 | # we set mem_every==-1 to indicate we don't use extra confident candidate pool
309 |                         if self.mem_every > -1 and frame_idx % self.mem_every == 0 and frame_idx != 0 and current_embedding is not None and current_label_0 is not None:
310 | ref_embeddings[aug_idx].append(current_embedding)
311 | ref_masks[aug_idx].append(current_label_0)
312 | UPDATE=True
313 |
314 |
315 | prev_embedding[aug_idx] = current_embedding
316 |
317 | if frame_idx > 0:
318 | all_preds = torch.cat(all_preds, dim=0)
319 | all_preds = torch.mean(all_preds, dim=0)
320 | pred_label = torch.argmax(all_preds, dim=0)
321 | if join_label is not None:
322 | join_label = join_label.squeeze(0).squeeze(0)
323 | keep = (join_label == 0).long()
324 | pred_label = pred_label * keep + join_label * (1 - keep)
325 | pred_label = pred_label
326 | current_label = pred_label.view(1, 1, ori_height, ori_width)
327 | if samples[aug_idx]['meta']['flip']:
328 | flip_pred_label = flip_tensor(pred_label, 1)
329 | flip_current_label = flip_pred_label.view(1, 1, ori_height, ori_width)
330 |
331 | for aug_idx in range(len(samples)):
332 | if join_label is not None:
333 | if samples[aug_idx]['meta']['flip']:
334 | ref_masks[aug_idx].append(flip_current_label)
335 | ref_mask_confident[aug_idx].append(flip_current_label)
336 | else:
337 | ref_masks[aug_idx].append(current_label)
338 |
339 | uncertainty_org,uncertainty_norm = cal_shannon_entropy(all_pred_exist)
340 | join_label = join_label.squeeze(0).squeeze(0)
341 | keep = (join_label == 0).long()
342 | join_uncertainty_map = (join_label <0).long()
343 | uncertainty_org = uncertainty_org * keep + join_uncertainty_map * (1 - keep)
344 |
345 | uncertainty_region = (uncertainty_org>self.unc_ratio ).long()
346 | pred_label_c = pred_label* (1 - uncertainty_region) + (125)* uncertainty_region
347 | pred_label_c = pred_label_c.view(1, 1, ori_height, ori_width)
348 |
349 | ref_mask_confident[aug_idx].append(pred_label_c)
350 |
351 | if samples[aug_idx]['meta']['flip']:
352 | prev_mask[aug_idx] = flip_current_label
353 | else:
354 | prev_mask[aug_idx] = current_label
355 |
356 | if UPDATE:
357 |                         if self.mem_every > -1 and frame_idx % self.mem_every == 0 and frame_idx != 0 and current_embedding is not None and current_label_0 is not None:
358 | uncertainty_region = (uncertainty_org>self.unc_ratio ).long()
359 | pred_label_c = pred_label* (1 - uncertainty_region) + (125)* uncertainty_region
360 | pred_label_c = pred_label_c.view(1, 1, ori_height, ori_width)
361 | ref_mask_confident[aug_idx].append(pred_label_c)
362 |
363 | one_frametime = time.time() - time_start
364 | seq_total_time += one_frametime
365 | seq_total_frame += 1
366 | obj_num = obj_num[0].item()
367 | print('Frame: {}, Obj Num: {}, Time: {}'.format(imgname[0], obj_num, one_frametime))
368 | # Save result
369 | save_mask(pred_label, os.path.join(self.result_root, seq_name, imgname[0].split('.')[0]+'.png'))
370 |
371 | else:
372 | one_frametime = time.time() - time_start
373 | seq_total_time += one_frametime
374 | print('Ref Frame: {}, Time: {}'.format(imgname[0], one_frametime))
375 |
376 | del(ref_embeddings)
377 | del(ref_masks)
378 | del(prev_embedding)
379 | del(prev_mask)
380 | del(seq_dataset)
381 | del(seq_dataloader)
382 | del(memory_cur_all_list)
383 |
384 |
385 | seq_avg_time_per_frame = seq_total_time / seq_total_frame
386 | total_time += seq_total_time
387 | total_frame += seq_total_frame
388 | total_avg_time_per_frame = total_time / total_frame
389 | total_sfps += seq_avg_time_per_frame
390 | avg_sfps = total_sfps / (seq_idx + 1)
391 | print("Seq {} FPS: {}, Total FPS: {}, FPS per Seq: {}".format(seq_name, 1./seq_avg_time_per_frame, 1./total_avg_time_per_frame, 1./avg_sfps))
392 |
393 | zip_folder(self.source_folder, self.zip_dir)
394 | self.print_log('Save result to {}.'.format(self.zip_dir))
395 |
396 |
397 | def print_log(self, string):
398 | print(string)
399 |
400 |
401 |
402 |
403 |
404 |
--------------------------------------------------------------------------------
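The evaluator above keeps a "confident" copy of each reference mask: it computes a per-pixel Shannon entropy over the predicted object probabilities (cal_shannon_entropy from networks/layers/shannon_entropy.py, not reproduced in this dump), thresholds it with unc_ratio, and overwrites uncertain pixels with the placeholder label 125 before the mask is reused for matching. The sketch below is only an assumed stand-in for that helper, written to illustrate the filtering step; the real implementation may differ.

import torch

def shannon_entropy_map(logits):
    # Assumed behaviour of cal_shannon_entropy: per-pixel entropy of the
    # softmax over object channels, plus a copy normalised by log(K).
    prob = torch.softmax(logits, dim=1)                          # (1, K, H, W)
    ent = -(prob * (prob + 1e-8).log()).sum(dim=1).squeeze(0)    # (H, W)
    ent_norm = ent / torch.log(torch.tensor(float(logits.size(1))))
    return ent, ent_norm

# Toy prediction over 3 labels: confident pixels keep their predicted label,
# uncertain ones are replaced by the ignore value 125 used by the evaluator.
logits = torch.randn(1, 3, 4, 4) * 5
pred_label = logits.argmax(dim=1).squeeze(0)
uncertainty, _ = shannon_entropy_map(logits)
uncertain = (uncertainty > 1.0).long()          # 1.0 plays the role of unc_ratio
confident_label = pred_label * (1 - uncertain) + 125 * uncertain
print(confident_label)
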
/AOC-Net/complete_project/AOCNet/networks/engine/train_manager_mm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | import time
4 | import datetime as datetime
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.distributed as dist
9 | from torch.utils.data import DataLoader
10 | from torchvision import transforms
11 | import numpy as np
12 | from dataloaders.datasets_m import DAVIS2017_Train, YOUTUBE_VOS_Train, TEST
13 | import dataloaders.custom_transforms as tr
14 | from networks.deeplab.deeplab import DeepLab
15 | from utils.meters import AverageMeter
16 | from utils.image import label2colormap, masked_image, save_image
17 | from utils.checkpoint import load_network_and_optimizer, load_network, save_network
18 | from utils.learning import adjust_learning_rate, get_trainable_params
19 | from utils.metric import pytorch_iou
20 | #torch.backends.cudnn.enabled = True
21 | #torch.backends.cudnn.benchmark = True
22 | class Trainer(object):
23 | def __init__(self , rank, cfg):
24 | self.gpu = rank + cfg.DIST_START_GPU
25 | self.rank = rank
26 | self.cfg = cfg
27 | self.print_log(cfg.__dict__)
28 | print("Use GPU {} for training".format(self.gpu))
29 | torch.cuda.set_device(self.gpu)
30 |
31 | self.print_log('Build backbone.')
32 | self.feature_extracter = DeepLab(
33 | backbone=cfg.MODEL_BACKBONE,
34 | freeze_bn=cfg.MODEL_FREEZE_BN).cuda(self.gpu)
35 |
36 | if cfg.MODEL_FREEZE_BACKBONE:
37 | for param in self.feature_extracter.parameters():
38 | param.requires_grad = False
39 |
40 | self.print_log('Build VOS model.')
41 | CFBI = importlib.import_module(cfg.MODEL_MODULE)
42 |
43 | self.model = CFBI.get_module()(
44 | cfg,
45 | self.feature_extracter).cuda(self.gpu)
46 |
47 | if cfg.DIST_ENABLE:
48 | dist.init_process_group(
49 | backend=cfg.DIST_BACKEND,
50 | init_method=cfg.DIST_URL,
51 | world_size=cfg.TRAIN_GPUS,
52 | rank=rank,
53 | timeout=datetime.timedelta(seconds=300))
54 | self.dist_model = torch.nn.parallel.DistributedDataParallel(
55 | self.model,
56 | device_ids=[self.gpu],
57 | find_unused_parameters=True)
58 | else:
59 | self.dist_model = self.model
60 |
61 | self.print_log('Build optimizer.')
62 | trainable_params = get_trainable_params(
63 | model=self.dist_model,
64 | base_lr=cfg.TRAIN_LR,
65 | weight_decay=cfg.TRAIN_WEIGHT_DECAY,
66 | beta_wd=cfg.MODEL_GCT_BETA_WD)
67 |
68 | self.optimizer = optim.SGD(
69 | trainable_params,
70 | lr=cfg.TRAIN_LR,
71 | momentum=cfg.TRAIN_MOMENTUM,
72 | nesterov=True)
73 |
74 | self.prepare_dataset()
75 | self.process_pretrained_model()
76 |
77 | if cfg.TRAIN_TBLOG and self.rank == 0:
78 | from tensorboardX import SummaryWriter
79 | self.tblogger = SummaryWriter(cfg.DIR_TB_LOG)
80 |
81 | def process_pretrained_model(self):
82 | cfg = self.cfg
83 |
84 | self.step = cfg.TRAIN_START_STEP
85 | self.epoch = 0
86 |
87 | if cfg.TRAIN_AUTO_RESUME:
88 | ckpts = os.listdir(cfg.DIR_CKPT)
89 | if len(ckpts) > 0:
90 | ckpts = list(map(lambda x: int(x.split('_')[-1].split('.')[0]), ckpts))
91 | ckpt = np.sort(ckpts)[-1]
92 | cfg.TRAIN_RESUME = True
93 | cfg.TRAIN_RESUME_CKPT = ckpt
94 | cfg.TRAIN_RESUME_STEP = ckpt + 1
95 | else:
96 | cfg.TRAIN_RESUME = False
97 |
98 | if cfg.TRAIN_RESUME:
99 | resume_ckpt = os.path.join(cfg.DIR_CKPT, 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT))
100 |
101 | self.model, self.optimizer, removed_dict = load_network_and_optimizer(self.model, self.optimizer, resume_ckpt, self.gpu)
102 |
103 | if len(removed_dict) > 0:
104 | self.print_log('Remove {} from checkpoint.'.format(removed_dict))
105 |
106 | self.step = cfg.TRAIN_RESUME_STEP
107 | if cfg.TRAIN_TOTAL_STEPS <= self.step:
108 | self.print_log("Your training has finished!")
109 | exit()
110 | self.epoch = int(np.ceil(self.step / len(self.trainloader)))
111 |
112 | self.print_log('Resume from step {}'.format(self.step))
113 |
114 | elif cfg.PRETRAIN:
115 | if cfg.PRETRAIN_FULL:
116 | self.model, removed_dict = load_network(self.model, cfg.PRETRAIN_MODEL, self.gpu)
117 | if len(removed_dict) > 0:
118 | self.print_log('Remove {} from pretrained model.'.format(removed_dict))
119 | self.print_log('Load pretrained VOS model from {}.'.format(cfg.PRETRAIN_MODEL))
120 | else:
121 | feature_extracter, removed_dict = load_network(self.feature_extracter, cfg.PRETRAIN_MODEL, self.gpu)
122 | if len(removed_dict) > 0:
123 | self.print_log('Remove {} from pretrained model.'.format(removed_dict))
124 | self.print_log('Load pretrained backbone model from {}.'.format(cfg.PRETRAIN_MODEL))
125 |
126 | def prepare_dataset(self):
127 | cfg = self.cfg
128 | self.print_log('Process dataset...')
129 | composed_transforms = transforms.Compose([
130 | tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR, cfg.DATA_MAX_SCALE_FACTOR, cfg.DATA_SHORT_EDGE_LEN),
131 | tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP),
132 | tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
133 | tr.Resize(cfg.DATA_RANDOMCROP),
134 | tr.ToTensor()])
135 |
136 | train_datasets = []
137 | if 'davis2017' in cfg.DATASETS:
138 | train_davis_dataset = DAVIS2017_Train(
139 | root=cfg.DIR_DAVIS,
140 | full_resolution=cfg.TRAIN_DATASET_FULL_RESOLUTION,
141 | transform=composed_transforms,
142 | repeat_time=cfg.DATA_DAVIS_REPEAT,
143 | curr_len=cfg.DATA_CURR_SEQ_LEN,
144 | rand_gap=cfg.DATA_RANDOM_GAP_DAVIS,
145 | rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ)
146 | train_datasets.append(train_davis_dataset)
147 |
148 | if 'youtubevos' in cfg.DATASETS:
149 | train_ytb_dataset = YOUTUBE_VOS_Train(
150 | root=cfg.DIR_YTB,
151 | transform=composed_transforms,
152 | curr_len=cfg.DATA_CURR_SEQ_LEN,
153 | rand_gap=cfg.DATA_RANDOM_GAP_YTB,
154 | rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ)
155 | train_datasets.append(train_ytb_dataset)
156 |
157 | if 'test' in cfg.DATASETS:
158 | test_dataset = TEST(
159 | transform=composed_transforms,
160 | curr_len=cfg.DATA_CURR_SEQ_LEN)
161 | train_datasets.append(test_dataset)
162 |
163 | if len(train_datasets) > 1:
164 | train_dataset = torch.utils.data.ConcatDataset(train_datasets)
165 | elif len(train_datasets) == 1:
166 | train_dataset = train_datasets[0]
167 | else:
168 | self.print_log('No dataset!')
169 | exit(0)
170 |
171 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
172 | self.trainloader = DataLoader(
173 | train_dataset,
174 | batch_size=int(cfg.TRAIN_BATCH_SIZE / cfg.TRAIN_GPUS),
175 | shuffle=False,
176 | num_workers=cfg.DATA_WORKERS,
177 | pin_memory=True,
178 | sampler=self.train_sampler)
179 |
180 | self.print_log('Done!')
181 |
182 | def sequential_training(self):
183 |
184 | cfg = self.cfg
185 |
186 | running_losses = []
187 | running_ious = []
188 | for _ in range(cfg.DATA_CURR_SEQ_LEN):
189 | running_losses.append(AverageMeter())
190 | running_ious.append(AverageMeter())
191 | batch_time = AverageMeter()
192 | avg_obj = AverageMeter()
193 |
194 | optimizer = self.optimizer
195 | model = self.dist_model
196 | train_sampler = self.train_sampler
197 | trainloader = self.trainloader
198 | step = self.step
199 | epoch = self.epoch
200 | max_itr = cfg.TRAIN_TOTAL_STEPS
201 |
202 | PlaceHolder=[]
203 | for i in range(cfg.BLOCK_NUM):
204 | PlaceHolder.append(None)
205 |
206 | self.print_log('Start training.')
207 | model.train()
208 | while step < cfg.TRAIN_TOTAL_STEPS:
209 | train_sampler.set_epoch(epoch)
210 | epoch += 1
211 | last_time = time.time()
212 | for frame_idx, sample in enumerate(trainloader):
213 | now_lr = adjust_learning_rate(
214 | optimizer=optimizer,
215 | base_lr=cfg.TRAIN_LR,
216 | p=cfg.TRAIN_POWER,
217 | itr=step,
218 | max_itr=max_itr,
219 | warm_up_steps=cfg.TRAIN_WARM_UP_STEPS,
220 | is_cosine_decay=cfg.TRAIN_COSINE_DECAY)
221 |
222 | ref_imgs = sample['ref_img'] # batch_size * 3 * h * w
223 | prev_imgs = sample['prev_img']
224 | curr_imgs = sample['curr_img'][0]
225 | ref_labels = sample['ref_label'] # batch_size * 1 * h * w
226 | prev_labels = sample['prev_label']
227 | curr_labels = sample['curr_label'][0]
228 | obj_nums = sample['meta']['obj_num']
229 | bs, _, h, w = curr_imgs.size()
230 |
231 | ref_labels = ref_labels.cuda(self.gpu)
232 | prev_labels = prev_labels.cuda(self.gpu)
233 | curr_labels = curr_labels.cuda(self.gpu)
234 | obj_nums = obj_nums.cuda(self.gpu)
235 |
236 | if step % cfg.TRAIN_TBLOG_STEP == 0 and self.rank == 0 and cfg.TRAIN_TBLOG:
237 | tf_board = True
238 | else:
239 | tf_board = False
240 |
241 | # Sequential training
242 | all_boards = []
243 | curr_imgs = prev_imgs
244 | curr_labels = prev_labels
245 | all_pred = prev_labels.squeeze(1)
246 | optimizer.zero_grad()
247 | memory_cur_list=[]
248 | memory_prev_list=[]
249 | for iii in range(int(cfg.TRAIN_BATCH_SIZE//cfg.TRAIN_GPUS)):
250 | memory_cur_list.append(PlaceHolder)
251 | memory_prev_list.append(PlaceHolder)
252 |
253 | for idx in range(cfg.DATA_CURR_SEQ_LEN):
254 | prev_imgs = curr_imgs
255 | curr_imgs = sample['curr_img'][idx]
256 | inputs = torch.cat((ref_imgs, prev_imgs, curr_imgs), 0).cuda(self.gpu)
257 | if step > cfg.TRAIN_START_SEQ_TRAINING_STEPS:
258 | # Use previous prediction instead of ground-truth mask
259 | prev_labels = all_pred.unsqueeze(1)
260 | else:
261 | # Use previous ground-truth mask
262 | prev_labels = curr_labels
263 | curr_labels = sample['curr_label'][idx].cuda(self.gpu)
264 |
265 | loss, all_pred, boards,memory_cur_list = model(
266 | inputs,
267 | memory_prev_list,
268 | ref_labels,
269 | prev_labels,
270 | curr_labels,
271 | gt_ids=obj_nums,
272 | step=step,
273 | tf_board=tf_board)
274 |
275 | memory_prev_list = memory_cur_list
276 |
277 | iou = pytorch_iou(all_pred.unsqueeze(1), curr_labels, obj_nums)
278 | loss = torch.mean(loss) / cfg.DATA_CURR_SEQ_LEN
279 | loss.backward()
280 | all_boards.append(boards)
281 | running_losses[idx].update(loss.item() * cfg.DATA_CURR_SEQ_LEN)
282 | running_ious[idx].update(iou.item())
283 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.TRAIN_CLIP_GRAD_NORM)
284 | optimizer.step()
285 | batch_time.update(time.time() - last_time)
286 | avg_obj.update(obj_nums.float().mean().item())
287 | last_time = time.time()
288 |
289 | if step % cfg.TRAIN_TBLOG_STEP == 0 and self.rank == 0:
290 | self.process_log(
291 | ref_imgs, prev_imgs, curr_imgs,
292 | ref_labels, prev_labels, curr_labels,
293 | all_pred, all_boards, running_losses, running_ious, now_lr, step)
294 |
295 | if step % cfg.TRAIN_LOG_STEP == 0 and self.rank == 0:
296 | strs = 'Itr:{}, LR:{:.7f}, Time:{:.3f}, Obj:{:.1f}'.format(step, now_lr, batch_time.avg, avg_obj.avg)
297 | batch_time.reset()
298 | avg_obj.reset()
299 | for idx in range(cfg.DATA_CURR_SEQ_LEN):
300 | strs += ', S{}: L {:.3f}({:.3f}) IoU {:.3f}({:.3f})'.format(idx, running_losses[idx].val, running_losses[idx].avg,
301 | running_ious[idx].val, running_ious[idx].avg)
302 | running_losses[idx].reset()
303 | running_ious[idx].reset()
304 |
305 | self.print_log(strs)
306 |
307 | if step % cfg.TRAIN_SAVE_STEP == 0 and step != 0 and self.rank == 0:
308 | self.print_log('Save CKPT (Step {}).'.format(step))
309 | save_network(self.model, optimizer, step, cfg.DIR_CKPT, cfg.TRAIN_MAX_KEEP_CKPT)
310 |
311 | step += 1
312 | if step > cfg.TRAIN_TOTAL_STEPS:
313 | break
314 |
315 | if self.rank == 0:
316 | self.print_log('Save final CKPT (Step {}).'.format(step - 1))
317 | save_network(self.model, optimizer, step - 1, cfg.DIR_CKPT, cfg.TRAIN_MAX_KEEP_CKPT)
318 |
319 | def print_log(self, string):
320 | if self.rank == 0:
321 | print(string)
322 |
323 |
324 | def process_log(self,
325 | ref_imgs, prev_imgs, curr_imgs,
326 | ref_labels, prev_labels, curr_labels,
327 | curr_pred, all_boards, running_losses, running_ious, now_lr, step):
328 | cfg = self.cfg
329 |
330 | mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
331 | sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])
332 |
333 | show_ref_img, show_prev_img, show_curr_img = [img.cpu().numpy()[0] * sigma + mean for img in [ref_imgs, prev_imgs, curr_imgs]]
334 |
335 | show_gt, show_prev_gt, show_ref_gt, show_preds_s = [label.cpu()[0].squeeze(0).numpy() for label in [curr_labels, prev_labels, ref_labels, curr_pred]]
336 |
337 | show_gtf, show_prev_gtf, show_ref_gtf, show_preds_sf = [label2colormap(label).transpose((2,0,1)) for label in [show_gt, show_prev_gt, show_ref_gt, show_preds_s]]
338 |
339 | if cfg.TRAIN_IMG_LOG or cfg.TRAIN_TBLOG:
340 |
341 | show_ref_img = masked_image(show_ref_img, show_ref_gtf, show_ref_gt)
342 | if cfg.TRAIN_IMG_LOG:
343 | save_image(show_ref_img, os.path.join(cfg.DIR_IMG_LOG, '%06d_ref_img.jpeg' % (step)))
344 |
345 | show_prev_img = masked_image(show_prev_img, show_prev_gtf, show_prev_gt)
346 | if cfg.TRAIN_IMG_LOG:
347 | save_image(show_prev_img, os.path.join(cfg.DIR_IMG_LOG, '%06d_prev_img.jpeg' % (step)))
348 |
349 | show_img_pred = masked_image(show_curr_img, show_preds_sf, show_preds_s)
350 | if cfg.TRAIN_IMG_LOG:
351 | save_image(show_img_pred, os.path.join(cfg.DIR_IMG_LOG, '%06d_prediction.jpeg' % (step)))
352 |
353 | show_curr_img = masked_image(show_curr_img, show_gtf, show_gt)
354 | if cfg.TRAIN_IMG_LOG:
355 | save_image(show_curr_img, os.path.join(cfg.DIR_IMG_LOG, '%06d_groundtruth.jpeg' % (step)))
356 |
357 | if cfg.TRAIN_TBLOG:
358 | for seq_step, running_loss, running_iou in zip(range(len(running_losses)), running_losses, running_ious):
359 | self.tblogger.add_scalar('S{}/Loss'.format(seq_step), running_loss.avg, step)
360 | self.tblogger.add_scalar('S{}/IoU'.format(seq_step), running_iou.avg, step)
361 |
362 | self.tblogger.add_scalar('LR', now_lr, step)
363 | self.tblogger.add_image('Ref/Image', show_ref_img, step)
364 | self.tblogger.add_image('Ref/GT', show_ref_gtf, step)
365 |
366 | self.tblogger.add_image('Prev/Image', show_prev_img, step)
367 | self.tblogger.add_image('Prev/GT', show_prev_gtf, step)
368 |
369 | self.tblogger.add_image('Curr/Image_GT', show_curr_img, step)
370 | self.tblogger.add_image('Curr/Image_Pred', show_img_pred, step)
371 |
372 | self.tblogger.add_image('Curr/Mask_GT', show_gtf, step)
373 | self.tblogger.add_image('Curr/Mask_Pred', show_preds_sf, step)
374 |
375 | for seq_step, boards in enumerate(all_boards):
376 | for key in boards['image'].keys():
377 | tmp = boards['image'][key].cpu().numpy()
378 |                     self.tblogger.add_image('S{}/'.format(seq_step) + key, tmp, step)
379 | for key in boards['scalar'].keys():
380 | tmp = boards['scalar'][key].cpu().numpy()
381 |                     self.tblogger.add_scalar('S{}/'.format(seq_step) + key, tmp, step)
382 |
383 | self.tblogger.flush()
384 |
385 | del(all_boards)
386 |
387 |
388 |
--------------------------------------------------------------------------------
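sequential_training above optimises over a short clip per iteration: each of the cfg.DATA_CURR_SEQ_LEN frames contributes loss / DATA_CURR_SEQ_LEN, backward() runs once per frame so the graph can be freed, and a single gradient-clipped optimizer step is taken for the whole clip. Below is a minimal sketch of that accumulation pattern with placeholder model and data, not the real AOCNet interfaces.

import torch

T = 3                                        # stands in for cfg.DATA_CURR_SEQ_LEN
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()

optimizer.zero_grad()
for t in range(T):
    x = torch.randn(8, 4)
    y = torch.randint(0, 2, (8,))
    loss = criterion(model(x), y) / T        # average the loss over the clip
    loss.backward()                          # gradients accumulate across frames
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()                             # one parameter update per clip
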
/AOC-Net/complete_project/AOCNet/networks/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/layers/__init__.py
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/layers/aspp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import math
4 | from torch import nn
5 | from networks.layers.gct import GCT
6 |
7 | class _ASPPModule(nn.Module):
8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation):
9 | super(_ASPPModule, self).__init__()
10 | self.GCT = GCT(inplanes)
11 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
12 | stride=1, padding=padding, dilation=dilation, bias=False)
13 | self.bn = nn.GroupNorm(int(planes / 4), planes)
14 | self.relu = nn.ReLU(inplace=True)
15 |
16 | self._init_weight()
17 |
18 | def forward(self, x):
19 | x = self.GCT(x)
20 | x = self.atrous_conv(x)
21 | x = self.bn(x)
22 |
23 | return self.relu(x)
24 |
25 | def _init_weight(self):
26 | for m in self.modules():
27 | if isinstance(m, nn.Conv2d):
28 | torch.nn.init.kaiming_normal_(m.weight)
29 | elif isinstance(m, nn.BatchNorm2d):
30 | m.weight.data.fill_(1)
31 | m.bias.data.zero_()
32 |
33 | class ASPP(nn.Module):
34 | def __init__(self):
35 | super(ASPP, self).__init__()
36 |
37 | inplanes = 512
38 | dilations = [1, 6, 12, 18]
39 |
40 |
41 | self.aspp1 = _ASPPModule(inplanes, 128, 1, padding=0, dilation=dilations[0])
42 | self.aspp2 = _ASPPModule(inplanes, 128, 3, padding=dilations[1], dilation=dilations[1])
43 | self.aspp3 = _ASPPModule(inplanes, 128, 3, padding=dilations[2], dilation=dilations[2])
44 | self.aspp4 = _ASPPModule(inplanes, 128, 3, padding=dilations[3], dilation=dilations[3])
45 |
46 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
47 | nn.Conv2d(inplanes, 128, 1, stride=1, bias=False),
48 | nn.ReLU(inplace=True))
49 |
50 | self.GCT = GCT(640)
51 | self.conv1 = nn.Conv2d(640, 256, 1, bias=False)
52 | self.bn1 = nn.GroupNorm(32, 256)
53 | self.relu = nn.ReLU(inplace=True)
54 | self._init_weight()
55 |
56 | def forward(self, x):
57 | x1 = self.aspp1(x)
58 | x2 = self.aspp2(x)
59 | x3 = self.aspp3(x)
60 | x4 = self.aspp4(x)
61 | x5 = self.global_avg_pool(x)
62 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
63 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
64 |
65 | x = self.GCT(x)
66 | x = self.conv1(x)
67 | x = self.bn1(x)
68 | x = self.relu(x)
69 |
70 | return x
71 |
72 | def _init_weight(self):
73 | for m in self.modules():
74 | if isinstance(m, nn.Conv2d):
75 | torch.nn.init.kaiming_normal_(m.weight)
76 | elif isinstance(m, nn.BatchNorm2d):
77 | m.weight.data.fill_(1)
78 | m.bias.data.zero_()
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/layers/conv_gru.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from torch.nn import init
6 |
7 |
8 | class ConvGRUCell(nn.Module):
9 | """
10 | Generate a convolutional GRU cell
11 | """
12 |
13 | def __init__(self, input_size, hidden_size, kernel_size):
14 | super().__init__()
15 | padding = kernel_size // 2
16 | self.input_size = input_size
17 | self.hidden_size = hidden_size
18 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
19 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
20 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
21 |
22 |         init.orthogonal_(self.reset_gate.weight)
23 |         init.orthogonal_(self.update_gate.weight)
24 |         init.orthogonal_(self.out_gate.weight)
25 |         init.constant_(self.reset_gate.bias, 0.)
26 |         init.constant_(self.update_gate.bias, 0.)
27 |         init.constant_(self.out_gate.bias, 0.)
28 |
29 |
30 | def forward(self, input_, prev_state):
31 |
32 | # get batch and spatial sizes
33 | batch_size = input_.data.size()[0]
34 | spatial_size = input_.data.size()[2:]
35 |
36 | # generate empty prev_state, if None is provided
37 | if prev_state is None:
38 | state_size = [batch_size, self.hidden_size] + list(spatial_size)
39 | if torch.cuda.is_available():
40 | prev_state = torch.zeros(state_size).cuda()
41 | else:
42 | prev_state = torch.zeros(state_size)
43 | print("input_.size()",input_.size())
44 | print("prev_state.size()",prev_state.size())
45 | if input_.size()[0]!=prev_state.size()[0]:
46 |
47 | state_size = [input_.size()[0]-prev_state.size()[0], self.hidden_size] + list(spatial_size)
48 | if torch.cuda.is_available():
49 | prev_state_tmp = torch.zeros(state_size).cuda()
50 | else:
51 | prev_state_tmp = torch.zeros(state_size)
52 |
53 | prev_state = torch.cat([prev_state, prev_state_tmp],dim=0)
54 | print("add_on prev_state.size()", prev_state.size())
55 | # data size is [batch, channel, height, width]
56 | stacked_inputs = torch.cat([input_, prev_state], dim=1)
57 |         update = torch.sigmoid(self.update_gate(stacked_inputs))
58 |         reset = torch.sigmoid(self.reset_gate(stacked_inputs))
59 |         out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
60 | new_state = prev_state * (1 - update) + out_inputs * update
61 |
62 | return new_state
63 |
64 |
65 | class ConvGRU(nn.Module):
66 |
67 | def __init__(self, input_size, hidden_sizes, kernel_sizes, n_layers):
68 | '''
69 | Generates a multi-layer convolutional GRU.
70 | Preserves spatial dimensions across cells, only altering depth.
71 |
72 | Parameters
73 | ----------
74 | input_size : integer. depth dimension of input tensors.
75 | hidden_sizes : integer or list. depth dimensions of hidden state.
76 | if integer, the same hidden size is used for all cells.
77 | kernel_sizes : integer or list. sizes of Conv2d gate kernels.
78 | if integer, the same kernel size is used for all cells.
79 | n_layers : integer. number of chained `ConvGRUCell`.
80 | '''
81 |
82 | super(ConvGRU, self).__init__()
83 |
84 | self.input_size = input_size
85 |
86 | if type(hidden_sizes) != list:
87 | self.hidden_sizes = [hidden_sizes]*n_layers
88 | else:
89 | assert len(hidden_sizes) == n_layers, '`hidden_sizes` must have the same length as n_layers'
90 | self.hidden_sizes = hidden_sizes
91 | if type(kernel_sizes) != list:
92 | self.kernel_sizes = [kernel_sizes]*n_layers
93 | else:
94 | assert len(kernel_sizes) == n_layers, '`kernel_sizes` must have the same length as n_layers'
95 | self.kernel_sizes = kernel_sizes
96 |
97 | self.n_layers = n_layers
98 |
99 | cells = []
100 | for i in range(self.n_layers):
101 | if i == 0:
102 | input_dim = self.input_size
103 | else:
104 | input_dim = self.hidden_sizes[i-1]
105 |
106 | cell = ConvGRUCell(input_dim, self.hidden_sizes[i], self.kernel_sizes[i])
107 | name = 'ConvGRUCell_' + str(i).zfill(2)
108 |
109 | setattr(self, name, cell)
110 | cells.append(getattr(self, name))
111 |
112 | self.cells = cells
113 |
114 |
115 | def forward(self, x, hidden=None):
116 | '''
117 | Parameters
118 | ----------
119 | x : 4D input tensor. (batch, channels, height, width).
120 | hidden : list of 4D hidden state representations. (batch, channels, height, width).
121 |
122 | Returns
123 | -------
124 |         (output, upd_hidden) : output of the last layer (batch, channels, height, width) and a list of per-layer 4D hidden states.
125 | '''
126 | if not hidden:
127 | hidden = [None]*self.n_layers
128 |
129 | input_ = x
130 |
131 | upd_hidden = []
132 |
133 | for layer_idx in range(self.n_layers):
134 | cell = self.cells[layer_idx]
135 | cell_hidden = hidden[layer_idx]
136 |
137 | # pass through layer
138 | upd_cell_hidden = cell(input_, cell_hidden)
139 | upd_hidden.append(upd_cell_hidden.detach())
140 | # update input_ to the last updated hidden layer for next pass
141 | input_ = upd_cell_hidden
142 |
143 | # retain tensors in list to allow different hidden sizes
144 | return input_, upd_hidden
--------------------------------------------------------------------------------
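
A minimal usage sketch (not part of the repository) for the ConvGRU defined in conv_gru.py above; shapes and sizes are illustrative:

import torch

# two stacked ConvGRU cells; per-layer hidden states are created lazily on the first call
gru = ConvGRU(input_size=256, hidden_sizes=[64, 128], kernel_sizes=3, n_layers=2)
x = torch.rand(2, 256, 24, 24)                          # (batch, channels, height, width)
out, hidden = gru(x)                                    # hidden: one state tensor per layer
out, hidden = gru(torch.rand(2, 256, 24, 24), hidden)   # the next frame reuses the kept states
print(out.shape)                                        # torch.Size([2, 128, 24, 24])
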
/AOC-Net/complete_project/AOCNet/networks/layers/gct.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import math
4 | from torch import nn
5 | # from networks.p2t.center_module import SpatialProp  # unused; the p2t module is not part of this project
6 |
7 | class GCT(nn.Module):
8 | def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):
9 | super(GCT, self).__init__()
10 | self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
11 | self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
12 | self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
13 | self.epsilon = epsilon
14 | self.mode = mode
15 | self.after_relu = after_relu
16 |
17 | def forward(self, x):
18 |
19 | if self.mode == 'l2':
20 | embedding = (x.pow(2).sum((2,3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
21 | norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
22 |
23 | elif self.mode == 'l1':
24 | if not self.after_relu:
25 | _x = torch.abs(x)
26 | else:
27 | _x = x
28 | embedding = _x.sum((2,3), keepdim=True) * self.alpha
29 | norm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)
30 | else:
31 | print('Unknown mode!')
32 | exit()
33 |
34 | gate = 1. + torch.tanh(embedding * norm + self.beta)
35 |
36 | return x * gate
37 |
38 | class Bottleneck(nn.Module):
39 | def __init__(self, inplanes, outplanes, stride=1, dilation=1):
40 | super(Bottleneck, self).__init__()
41 | expansion = 4
42 | planes = int(outplanes / expansion)
43 | self.GCT1 = GCT(inplanes)
44 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
45 | self.bn1 = nn.GroupNorm(32, planes)
46 |
47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
48 | dilation=dilation, padding=dilation, bias=False)
49 | self.bn2 = nn.GroupNorm(32, planes)
50 |
51 | self.conv3 = nn.Conv2d(planes, planes * expansion, kernel_size=1, bias=False)
52 | self.bn3 = nn.GroupNorm(32, planes * expansion)
53 | self.relu = nn.ReLU(inplace=True)
54 | if stride != 1 or inplanes != planes * expansion:
55 | downsample = nn.Sequential(
56 | nn.Conv2d(inplanes, planes * expansion,
57 | kernel_size=1, stride=stride, bias=False),
58 | nn.GroupNorm(32, planes * expansion),
59 | )
60 | else:
61 | downsample = None
62 | self.downsample = downsample
63 | self.stride = stride
64 | self.dilation = dilation
65 |
66 | for m in self.modules():
67 | if isinstance(m, nn.Conv2d):
68 | nn.init.kaiming_normal_(m.weight,mode='fan_out', nonlinearity='relu')
69 |
70 | def forward(self, x):
71 | residual = x
72 |
73 | out = self.GCT1(x)
74 | out = self.conv1(out)
75 | out = self.bn1(out)
76 | out = self.relu(out)
77 |
78 | out = self.conv2(out)
79 | out = self.bn2(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv3(out)
83 | out = self.bn3(out)
84 |
85 | if self.downsample is not None:
86 | residual = self.downsample(x)
87 |
88 | out += residual
89 | out = self.relu(out)
90 |
91 | return out
92 |
93 | class GCT_Spatial(nn.Module):
94 | def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):
95 | super(GCT_Spatial, self).__init__()
96 |
97 | self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
98 | self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
99 | self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
100 | self.epsilon = epsilon
101 | self.mode = mode
102 | self.after_relu = after_relu
103 |
104 | def forward(self, x,spatial_embedding):
105 |
106 | if self.mode == 'l2':
107 | embedding = (x.pow(2).sum((2,3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
108 | norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
109 |
110 | elif self.mode == 'l1':
111 | if not self.after_relu:
112 | _x = torch.abs(x)
113 | else:
114 | _x = x
115 | embedding = _x.sum((2,3), keepdim=True) * self.alpha
116 | norm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)
117 | else:
118 | print('Unknown mode!')
119 | exit()
120 |
121 | gate = 1. + torch.tanh(embedding * norm + self.beta)
122 |
123 | return x * gate
--------------------------------------------------------------------------------
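
A minimal sketch of applying the GCT gate above to a feature map (illustrative shapes, not part of the repository):

import torch

gct = GCT(num_channels=256)
feat = torch.rand(2, 256, 32, 32)
gated = gct(feat)                     # per-channel gating: x * (1 + tanh(embedding * norm + beta))
assert gated.shape == feat.shape      # the gate only rescales channels, shape is unchanged
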
/AOC-Net/complete_project/AOCNet/networks/layers/gru_conv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn
4 | from torch.autograd import Variable
5 |
6 |
7 | class ConvGRUCell(nn.Module):
8 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias, dtype):
9 | """
10 |         Initialize the ConvGRU cell
11 | :param input_size: (int, int)
12 | Height and width of input tensor as (height, width).
13 | :param input_dim: int
14 | Number of channels of input tensor.
15 | :param hidden_dim: int
16 | Number of channels of hidden state.
17 | :param kernel_size: (int, int)
18 | Size of the convolutional kernel.
19 | :param bias: bool
20 | Whether or not to add the bias.
21 | :param dtype: torch.cuda.FloatTensor or torch.FloatTensor
22 | Whether or not to use cuda.
23 | """
24 | super(ConvGRUCell, self).__init__()
25 | self.height, self.width = input_size
26 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2
27 | self.hidden_dim = hidden_dim
28 | self.bias = bias
29 | self.dtype = dtype
30 |
31 | self.conv_gates = nn.Conv2d(in_channels=input_dim + hidden_dim,
32 | out_channels=2*self.hidden_dim, # for update_gate,reset_gate respectively
33 | kernel_size=kernel_size,
34 | padding=self.padding,
35 | bias=self.bias)
36 |
37 | self.conv_can = nn.Conv2d(in_channels=input_dim+hidden_dim,
38 | out_channels=self.hidden_dim, # for candidate neural memory
39 | kernel_size=kernel_size,
40 | padding=self.padding,
41 | bias=self.bias)
42 |
43 | def init_hidden(self, batch_size):
44 | return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).type(self.dtype))
45 |
46 | def forward(self, input_tensor, h_cur):
47 | """
48 |
49 | :param self:
50 | :param input_tensor: (b, c, h, w)
51 | input is actually the target_model
52 | :param h_cur: (b, c_hidden, h, w)
53 | current hidden and cell states respectively
54 | :return: h_next,
55 | next hidden state
56 | """
57 | combined = torch.cat([input_tensor, h_cur], dim=1)
58 | combined_conv = self.conv_gates(combined)
59 |
60 | gamma, beta = torch.split(combined_conv, self.hidden_dim, dim=1)
61 | reset_gate = torch.sigmoid(gamma)
62 | update_gate = torch.sigmoid(beta)
63 |
64 | combined = torch.cat([input_tensor, reset_gate*h_cur], dim=1)
65 | cc_cnm = self.conv_can(combined)
66 | cnm = torch.tanh(cc_cnm)
67 |
68 | h_next = (1 - update_gate) * h_cur + update_gate * cnm
69 | return h_next
70 |
71 |
72 | class ConvGRU(nn.Module):
73 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
74 | dtype, batch_first=False, bias=True, return_all_layers=False):
75 | """
76 |
77 | :param input_size: (int, int)
78 | Height and width of input tensor as (height, width).
79 | :param input_dim: int e.g. 256
80 | Number of channels of input tensor.
81 | :param hidden_dim: int e.g. 1024
82 | Number of channels of hidden state.
83 | :param kernel_size: (int, int)
84 | Size of the convolutional kernel.
85 | :param num_layers: int
86 |             Number of ConvGRU layers
87 | :param dtype: torch.cuda.FloatTensor or torch.FloatTensor
88 | Whether or not to use cuda.
89 | :param alexnet_path: str
90 | pretrained alexnet parameters
91 | :param batch_first: bool
92 | if the first position of array is batch or not
93 | :param bias: bool
94 | Whether or not to add the bias.
95 | :param return_all_layers: bool
96 | if return hidden and cell states for all layers
97 | """
98 | super(ConvGRU, self).__init__()
99 |
100 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
101 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
102 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
103 | if not len(kernel_size) == len(hidden_dim) == num_layers:
104 | raise ValueError('Inconsistent list length.')
105 |
106 | self.height, self.width = input_size
107 | self.input_dim = input_dim
108 | self.hidden_dim = hidden_dim
109 | self.kernel_size = kernel_size
110 | self.dtype = dtype
111 | self.num_layers = num_layers
112 | self.batch_first = batch_first
113 | self.bias = bias
114 | self.return_all_layers = return_all_layers
115 |
116 | cell_list = []
117 | for i in range(0, self.num_layers):
118 | cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1]
119 | cell_list.append(ConvGRUCell(input_size=(self.height, self.width),
120 | input_dim=cur_input_dim,
121 | hidden_dim=self.hidden_dim[i],
122 | kernel_size=self.kernel_size[i],
123 | bias=self.bias,
124 | dtype=self.dtype))
125 |
126 | # convert python list to pytorch module
127 | self.cell_list = nn.ModuleList(cell_list)
128 |
129 | def forward(self, input_tensor, hidden_state=None):
130 | """
131 |
132 | :param input_tensor: (b, t, c, h, w) or (t,b,c,h,w) depends on if batch first or not
133 | extracted features from alexnet
134 | :param hidden_state:
135 | :return: layer_output_list, last_state_list
136 | """
137 | print("input_tensor.device",input_tensor.device)
138 | if not self.batch_first:
139 | # (t, b, c, h, w) -> (b, t, c, h, w)
140 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
141 |
142 |         # Stateful ConvGRU (passing an initial hidden state) is not implemented
143 | if hidden_state is not None:
144 | raise NotImplementedError()
145 | else:
146 | hidden_state = self._init_hidden(batch_size=input_tensor.size(0))
147 |
148 | layer_output_list = []
149 | last_state_list = []
150 |
151 | seq_len = input_tensor.size(1)
152 | cur_layer_input = input_tensor
153 |
154 | for layer_idx in range(self.num_layers):
155 | h = hidden_state[layer_idx]
156 | output_inner = []
157 | for t in range(seq_len):
158 |                 # feed the current input and hidden state through the ConvGRUCell to get the next hidden state
159 | h = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], # (b,t,c,h,w)
160 | h_cur=h)
161 | output_inner.append(h)
162 |
163 | layer_output = torch.stack(output_inner, dim=1)
164 | cur_layer_input = layer_output
165 |
166 | layer_output_list.append(layer_output)
167 | last_state_list.append([h])
168 |
169 | if not self.return_all_layers:
170 | layer_output_list = layer_output_list[-1:]
171 | last_state_list = last_state_list[-1:]
172 |
173 | return layer_output_list, last_state_list
174 |
175 | def _init_hidden(self, batch_size):
176 | init_states = []
177 | for i in range(self.num_layers):
178 | init_states.append(self.cell_list[i].init_hidden(batch_size))
179 | return init_states
180 |
181 | @staticmethod
182 | def _check_kernel_size_consistency(kernel_size):
183 | if not (isinstance(kernel_size, tuple) or
184 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
185 | raise ValueError('`kernel_size` must be tuple or list of tuples')
186 |
187 | @staticmethod
188 | def _extend_for_multilayer(param, num_layers):
189 | if not isinstance(param, list):
190 | param = [param] * num_layers
191 | return param
192 |
193 |
194 | if __name__ == '__main__':
195 | # set CUDA device
196 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
197 |
198 | # detect if CUDA is available or not
199 | use_gpu = torch.cuda.is_available()
200 | if use_gpu:
201 | dtype = torch.cuda.FloatTensor # computation in GPU
202 | else:
203 | dtype = torch.FloatTensor
204 |     # dtype = torch.FloatTensor  # debug override removed so the dtype chosen above is kept
205 | height = width = 6
206 | channels = 256
207 | hidden_dim = [32, 64]
208 | kernel_size = (3,3) # kernel size for two stacked hidden layer
209 | num_layers = 2 # number of stacked hidden layer
210 | model = ConvGRU(input_size=(height, width),
211 | input_dim=channels,
212 | hidden_dim=hidden_dim,
213 | kernel_size=kernel_size,
214 | num_layers=num_layers,
215 | dtype=dtype,
216 | batch_first=True,
217 | bias = True,
218 | return_all_layers = False)
219 |
220 | batch_size = 1
221 | time_steps = 1
222 | input_tensor = torch.rand(batch_size, time_steps, channels, height, width) # (b,t,c,h,w)
223 | layer_output_list, last_state_list = model(input_tensor)
224 | for i in layer_output_list:
225 | print("i.size()",i.size())
226 | for i in last_state_list:
227 | for j in i:
228 | print("j.size()",j.size())
229 |     # layer_output_list, last_state_list2 = model(input_tensor, last_state_list)  # would raise NotImplementedError: stateful forward is not implemented above
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/networks/layers/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 |
5 | class Concat_BCEWithLogitsLoss(nn.Module):
6 | def __init__(self, top_k_percent_pixels=None,
7 | hard_example_mining_step=100000):
8 | super(Concat_BCEWithLogitsLoss, self).__init__()
9 | self.top_k_percent_pixels = top_k_percent_pixels
10 | if top_k_percent_pixels is not None:
11 | assert(top_k_percent_pixels > 0 and top_k_percent_pixels < 1)
12 | self.hard_example_mining_step = hard_example_mining_step
13 | if self.top_k_percent_pixels == None:
14 | self.bceloss = nn.BCEWithLogitsLoss(reduction='mean')
15 | else:
16 | self.bceloss = nn.BCEWithLogitsLoss(reduction='none')
17 |
18 | def forward(self, dic_tmp, y, step):
19 | total_loss = []
20 | for i in range(len(dic_tmp)):
21 | pred_logits = dic_tmp[i]
22 | gts = y[i]
23 | if self.top_k_percent_pixels == None:
24 | final_loss = self.bceloss(pred_logits, gts)
25 | else:
26 | # Only compute the loss for top k percent pixels.
27 | # First, compute the loss for all pixels. Note we do not put the loss
28 | # to loss_collection and set reduction = None to keep the shape.
29 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3))
30 | pred_logits = pred_logits.view(-1, pred_logits.size(
31 | 1), pred_logits.size(2) * pred_logits.size(3))
32 | gts = gts.view(-1, gts.size(1), gts.size(2) * gts.size(3))
33 | pixel_losses = self.bceloss(pred_logits, gts)
34 | if self.hard_example_mining_step == 0:
35 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels)
36 | else:
37 | ratio = min(
38 | 1.0, step / float(self.hard_example_mining_step))
39 | top_k_pixels = int(
40 | (ratio * self.top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
41 |                 top_k_loss, top_k_indices = torch.topk(
42 |                     pixel_losses, k=top_k_pixels, dim=2)
43 |
44 |                 # average only the selected hard pixels (mirrors Concat_CrossEntropyLoss below)
45 |                 final_loss = torch.mean(top_k_loss)
46 | final_loss = final_loss.unsqueeze(0)
47 | total_loss.append(final_loss)
48 | total_loss = torch.cat(total_loss, dim=0)
49 | return total_loss
50 |
51 |
52 | class Concat_CrossEntropyLoss(nn.Module):
53 | def __init__(self, top_k_percent_pixels=None,
54 | hard_example_mining_step=100000):
55 | super(Concat_CrossEntropyLoss, self).__init__()
56 | self.top_k_percent_pixels = top_k_percent_pixels
57 | if top_k_percent_pixels is not None:
58 | assert(top_k_percent_pixels > 0 and top_k_percent_pixels < 1)
59 | self.hard_example_mining_step = hard_example_mining_step
60 | if self.top_k_percent_pixels == None:
61 | self.celoss = nn.CrossEntropyLoss(
62 | ignore_index=255, reduction='mean')
63 | else:
64 | self.celoss = nn.CrossEntropyLoss(
65 | ignore_index=255, reduction='none')
66 |
67 | def forward(self, dic_tmp, y, step):
68 | total_loss = []
69 | for i in range(len(dic_tmp)):
70 | pred_logits = dic_tmp[i]
71 | gts = y[i]
72 | if self.top_k_percent_pixels == None:
73 | final_loss = self.celoss(pred_logits, gts)
74 | else:
75 | # Only compute the loss for top k percent pixels.
76 | # First, compute the loss for all pixels. Note we do not put the loss
77 | # to loss_collection and set reduction = None to keep the shape.
78 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3))
79 | pred_logits = pred_logits.view(-1, pred_logits.size(
80 | 1), pred_logits.size(2) * pred_logits.size(3))
81 | gts = gts.view(-1, gts.size(1) * gts.size(2))
82 | pixel_losses = self.celoss(pred_logits, gts)
83 | if self.hard_example_mining_step == 0:
84 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels)
85 | else:
86 | ratio = min(
87 | 1.0, step / float(self.hard_example_mining_step))
88 | top_k_pixels = int(
89 | (ratio * self.top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
90 | top_k_loss, top_k_indices = torch.topk(
91 | pixel_losses, k=top_k_pixels, dim=1)
92 |
93 | final_loss = torch.mean(top_k_loss)
94 | final_loss = final_loss.unsqueeze(0)
95 | total_loss.append(final_loss)
96 | total_loss = torch.cat(total_loss, dim=0)
97 | return total_loss
98 |
--------------------------------------------------------------------------------
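
A minimal sketch of the hard-example-mining path in Concat_CrossEntropyLoss above (shapes and class counts are illustrative assumptions, not part of the repository):

import torch

criterion = Concat_CrossEntropyLoss(top_k_percent_pixels=0.15, hard_example_mining_step=0)
logits = [torch.randn(2, 11, 64, 64)]            # list of (batch, classes, H, W) prediction chunks
labels = [torch.randint(0, 11, (2, 64, 64))]     # matching integer ground-truth masks
loss = criterion(logits, labels, step=0)         # averaged top-k pixel loss per chunk
print(loss.shape)                                # torch.Size([1])
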
/AOC-Net/complete_project/AOCNet/networks/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FrozenBatchNorm2d(nn.Module):
7 | """
8 | BatchNorm2d where the batch statistics and the affine parameters
9 | are fixed
10 | """
11 | def __init__(self, n, epsilon=1e-5):
12 | super(FrozenBatchNorm2d, self).__init__()
13 | self.register_buffer("weight", torch.ones(n))
14 | self.register_buffer("bias", torch.zeros(n))
15 | self.register_buffer("running_mean", torch.zeros(n))
16 | self.register_buffer("running_var", torch.ones(n))
17 | self.epsilon = epsilon
18 |
19 | def forward(self, x):
20 | scale = self.weight * (self.running_var + self.epsilon).rsqrt()
21 | bias = self.bias - self.running_mean * scale
22 | scale = scale.reshape(1, -1, 1, 1)
23 | bias = bias.reshape(1, -1, 1, 1)
24 | return x * scale + bias
--------------------------------------------------------------------------------
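
A minimal sketch of FrozenBatchNorm2d above: with the default buffers it is numerically the identity, and loaded statistics stay fixed during training (illustrative values only):

import torch

fbn = FrozenBatchNorm2d(64)
x = torch.rand(4, 64, 16, 16)
y = fbn(x)                                # x * weight/sqrt(running_var + eps) + (bias - running_mean * scale)
assert torch.allclose(y, x, atol=1e-3)    # default buffers give (almost) the identity transform
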
/AOC-Net/complete_project/AOCNet/networks/layers/refiner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from typing import Tuple
6 |
7 |
8 | class Refiner(nn.Module):
9 | """
10 | Refiner refines the coarse output to full resolution.
11 |
12 | Args:
13 | mode: area selection mode. Options:
14 | "full" - No area selection, refine everywhere using regular Conv2d.
15 | "sampling" - Refine fixed amount of pixels ranked by the top most errors.
16 | "thresholding" - Refine varying amount of pixels that have greater error than the threshold.
17 | sample_pixels: number of pixels to refine. Only used when mode == "sampling".
18 | threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
19 | kernel_size: The convolution kernel_size. Options: [1, 3]
20 | prevent_oversampling: True for regular cases, False for speedtest.
21 |
22 | Compatibility Args:
23 | patch_crop_method: the method for cropping patches. Options:
24 | "unfold" - Best performance for PyTorch and TorchScript.
25 |             "roi_align" - Another way for cropping patches.
26 |             "gather" - Another way for cropping patches.
27 | patch_replace_method: the method for replacing patches. Options:
28 | "scatter_nd" - Best performance for PyTorch and TorchScript.
29 | "scatter_element" - Another way for replacing patches.
30 |
31 | Input:
32 | src: (B, 3, H, W) full resolution source image.
33 | bgr: (B, 3, H, W) full resolution background image.
34 | pha: (B, 1, Hc, Wc) coarse alpha prediction.
35 | fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.
36 |         err: (B, 1, Hc, Wc) coarse error prediction.
37 |         hid: (B, 32, Hc, Wc) coarse hidden encoding.
38 |
39 | Output:
40 | pha: (B, 1, H, W) full resolution alpha prediction.
41 | fgr: (B, 3, H, W) full resolution foreground residual prediction.
42 | ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations.
43 | """
44 |
45 | # For TorchScript export optimization.
46 | __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']
47 |
48 | def __init__(self,
49 | mode: str,
50 | sample_pixels: int,
51 | threshold: float,
52 | kernel_size: int = 3,
53 | prevent_oversampling: bool = True,
54 | patch_crop_method: str = 'unfold',
55 | patch_replace_method: str = 'scatter_nd'):
56 | super().__init__()
57 | assert mode in ['full', 'sampling', 'thresholding']
58 | assert kernel_size in [1, 3]
59 | assert patch_crop_method in ['unfold', 'roi_align', 'gather']
60 | assert patch_replace_method in ['scatter_nd', 'scatter_element']
61 |
62 | self.mode = mode
63 | self.sample_pixels = sample_pixels
64 | self.threshold = threshold
65 | self.kernel_size = kernel_size
66 | self.prevent_oversampling = prevent_oversampling
67 | self.patch_crop_method = patch_crop_method
68 | self.patch_replace_method = patch_replace_method
69 |
70 | channels = [128, 64, 32, 16, 1] #4->1
71 | self.conv1 = nn.Conv2d(channels[0] + 4, channels[1], kernel_size, bias=False) # 6->3
72 | self.bn1 = nn.BatchNorm2d(channels[1])
73 | self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
74 | self.bn2 = nn.BatchNorm2d(channels[2])
75 | self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)
76 | self.bn3 = nn.BatchNorm2d(channels[3])
77 | self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
78 | self.relu = nn.ReLU(True)
79 |
80 | def forward(self,
81 | src: torch.Tensor, #image
82 | #bgr: torch.Tensor,
83 | pha: torch.Tensor, #probability_map (B, 1, H, W)
84 | #fgr: torch.Tensor, #warp_result
85 | err: torch.Tensor, #uncertainty map (B, 1, H, W)
86 | hid: torch.Tensor): #feature
87 |
88 | # 1.get size
89 | H_full, W_full = src.shape[2:]
90 | H_half, W_half = H_full // 2, W_full // 2
91 | H_quat, W_quat = H_full // 4, W_full // 4
92 |
93 | src_bgr = src #torch.cat([src, bgr], dim=1)
94 |
95 | if self.mode != 'full':
96 | err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False) # downsample error map to 1/4
97 | ref = self.select_refinement_regions(err) #select top error regions
98 | idx = torch.nonzero(ref.squeeze(1))
99 | idx = idx[:, 0], idx[:, 1], idx[:, 2]
100 |
101 | if idx[0].size(0) > 0:
102 | x = torch.cat([hid, pha], dim=1) #32 + 1
103 | x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
104 | x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)
105 |
106 | y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False) # 6->3
107 | y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)
108 |
109 | x = self.conv1(torch.cat([x, y], dim=1))
110 | x = self.bn1(x)
111 | x = self.relu(x)
112 | x = self.conv2(x)
113 | x = self.bn2(x)
114 | x = self.relu(x)
115 |
116 | x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')
117 | y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)
118 |
119 | x = self.conv3(torch.cat([x, y], dim=1))
120 | x = self.bn3(x)
121 | x = self.relu(x)
122 | x = self.conv4(x)
123 |
124 | out = pha # torch.cat([pha, fgr], dim=1)
125 | out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
126 | out = self.replace_patch(out, x, idx)
127 | pha = out #[:, :1]
128 | #fgr = out[:, 1:]
129 | else:
130 | pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)
131 | #fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)
132 | else:
133 | x = torch.cat([hid, pha], dim=1)
134 | x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
135 | y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
136 | if self.kernel_size == 3:
137 | x = F.pad(x, (3, 3, 3, 3))
138 | y = F.pad(y, (3, 3, 3, 3))
139 |
140 | x = self.conv1(torch.cat([x, y], dim=1))
141 | x = self.bn1(x)
142 | x = self.relu(x)
143 | x = self.conv2(x)
144 | x = self.bn2(x)
145 | x = self.relu(x)
146 |
147 | if self.kernel_size == 3:
148 | x = F.interpolate(x, (H_full + 4, W_full + 4))
149 | y = F.pad(src_bgr, (2, 2, 2, 2))
150 | else:
151 | x = F.interpolate(x, (H_full, W_full), mode='nearest')
152 | y = src_bgr
153 |
154 | x = self.conv3(torch.cat([x, y], dim=1))
155 | x = self.bn3(x)
156 | x = self.relu(x)
157 | x = self.conv4(x)
158 |
159 | pha = x[:, :1]
160 | #fgr = x[:, 1:]
161 | ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)
162 |
163 | return pha, ref #fgr, ref
164 |
165 | def select_refinement_regions(self, err: torch.Tensor):
166 | """
167 | Select refinement regions.
168 | Input:
169 | err: error map (B, 1, H, W)
170 | Output:
171 | ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
172 | """
173 | if self.mode == 'sampling':
174 | # Sampling mode.
175 | b, _, h, w = err.shape
176 | err = err.view(b, -1)
177 | idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
178 | ref = torch.zeros_like(err)
179 | ref.scatter_(1, idx, 1.)
180 | if self.prevent_oversampling:
181 | ref.mul_(err.gt(0).float())
182 | ref = ref.view(b, 1, h, w)
183 | else:
184 | # Thresholding mode.
185 | ref = err.gt(self.threshold).float()
186 | return ref
187 |
188 | def crop_patch(self,
189 | x: torch.Tensor,
190 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
191 | size: int,
192 | padding: int):
193 | """
194 | Crops selected patches from image given indices.
195 |
196 | Inputs:
197 | x: image (B, C, H, W).
198 | idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
199 | size: center size of the patch, also stride of the crop.
200 | padding: expansion size of the patch.
201 | Output:
202 | patch: (P, C, h, w), where h = w = size + 2 * padding.
203 | """
204 | if padding != 0:
205 | x = F.pad(x, (padding,) * 4)
206 |
207 | if self.patch_crop_method == 'unfold':
208 | # Use unfold. Best performance for PyTorch and TorchScript.
209 | return x.permute(0, 2, 3, 1) \
210 | .unfold(1, size + 2 * padding, size) \
211 | .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
212 | elif self.patch_crop_method == 'roi_align':
213 | # Use roi_align. Best compatibility for ONNX.
214 | idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
215 | b = idx[0]
216 | x1 = idx[2] * size - 0.5
217 | y1 = idx[1] * size - 0.5
218 | x2 = idx[2] * size + size + 2 * padding - 0.5
219 | y2 = idx[1] * size + size + 2 * padding - 0.5
220 | boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
221 | return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
222 | else:
223 | # Use gather. Crops out patches pixel by pixel.
224 | idx_pix = self.compute_pixel_indices(x, idx, size, padding)
225 | pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))
226 | pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
227 | return pat
228 |
229 | def replace_patch(self,
230 | x: torch.Tensor,
231 | y: torch.Tensor,
232 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
233 | """
234 | Replaces patches back into image given index.
235 |
236 | Inputs:
237 | x: image (B, C, H, W)
238 | y: patches (P, C, h, w)
239 | idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index.
240 |
241 | Output:
242 | image: (B, C, H, W), where patches at idx locations are replaced with y.
243 | """
244 | xB, xC, xH, xW = x.shape
245 | yB, yC, yH, yW = y.shape
246 | if self.patch_replace_method == 'scatter_nd':
247 | # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch.
248 | x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
249 | x[idx[0], idx[1], idx[2]] = y
250 | x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
251 | return x
252 | else:
253 | # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
254 | idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
255 | return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)
256 |
257 | def compute_pixel_indices(self,
258 | x: torch.Tensor,
259 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
260 | size: int,
261 | padding: int):
262 | """
263 | Compute selected pixel indices in the tensor.
264 | Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel.
265 | Input:
266 | x: image: (B, C, H, W)
267 | idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
268 | size: center size of the patch, also stride of the crop.
269 | padding: expansion size of the patch.
270 | Output:
271 | idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches.
272 | the element are indices pointing to the input x.view(-1).
273 | """
274 | B, C, H, W = x.shape
275 | S, P = size, padding
276 | O = S + 2 * P
277 | b, y, x = idx
278 | n = b.size(0)
279 | c = torch.arange(C)
280 | o = torch.arange(O)
281 | idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])
282 | idx_loc = b * W * H + y * W * S + x * S
283 | idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
284 | return idx_pix
285 |
--------------------------------------------------------------------------------
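
A minimal sketch of the area-selection step of the Refiner above in thresholding mode (illustrative values, not part of the repository; a full refinement pass additionally needs coarse inputs whose channel counts match conv1/conv3):

import torch

refiner = Refiner(mode='thresholding', sample_pixels=0, threshold=0.3, kernel_size=3)
err = torch.rand(1, 1, 64, 64)                    # coarse uncertainty map
ref = refiner.select_refinement_regions(err)      # 1 where err > threshold, else 0
print(int(ref.sum().item()), 'patch locations selected for refinement')
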
/AOC-Net/complete_project/AOCNet/networks/layers/shannon_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 |
5 | def normalize(image, MIN_BOUND, MAX_BOUND):
6 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
7 | reverse_image = 1 - image
8 | return reverse_image
9 |
10 | def cal_shannon_entropy(preds): # (batch, obj_num, 128, 128)
11 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) # (batch, 1, 128, 128)
12 | uncertainty_norm = normalize(uncertainty, 0, np.log(2)) * 7
13 | return uncertainty,uncertainty_norm
14 |
15 |
16 | def normalize_train(image, MIN_BOUND, MAX_BOUND):
17 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
18 | #reverse_image = 1 - image
19 | return image #reverse_image
20 |
21 | def cal_shannon_entropy_train(preds): # (batch, obj_num, 128, 128)
22 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) # (batch, 1, 128, 128)
23 | uncertainty_norm = normalize_train(uncertainty, 0, np.log(2)) * 7
24 | return uncertainty,uncertainty_norm
25 |
26 |
27 |
28 | def show_shannon_entropy(uncertainty,uncertainty_norm,unc_rate=0.5): #torch.tensor, torch.tensor
29 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32')
30 | uncertainty = uncertainty * (uncertainty > unc_rate)
31 | uncertainty_norm = uncertainty_norm.cpu().numpy().squeeze().astype('float32')
32 | uncertainty_norm = uncertainty_norm
33 |
34 |     plt.figure()
35 |     # The original 6-panel layout referenced variables that are never defined in
36 |     # this module (save_pre_den, density_gt, save_pre_dmp_to_att, save_pre_att_2);
37 |     # those panels are kept below as comments and only the two uncertainty maps
38 |     # computed in this module are shown.
39 |     # plt.subplot(1, 6, 1); plt.imshow(save_pre_den)
40 |     # plt.subplot(1, 6, 2); plt.imshow(density_gt)
41 |     # plt.subplot(1, 6, 3); plt.imshow(save_pre_dmp_to_att)
42 |     # plt.subplot(1, 6, 4); plt.imshow(save_pre_att_2)
43 |     plt.subplot(1, 2, 1)
44 |     plt.imshow(uncertainty, cmap='inferno')
45 |     plt.subplot(1, 2, 2)
46 |     plt.imshow(uncertainty_norm, cmap='inferno')
47 |     plt.show()
48 |
49 | def save_shannon_entropy(uncertainty,uncertainty_norm,save_path,unc_rate=0.5): #torch.tensor, torch.tensor
50 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32')
51 | uncertainty = uncertainty * (uncertainty > unc_rate)
52 | uncertainty_norm = uncertainty_norm.cpu().numpy().squeeze().astype('float32')
53 | uncertainty_norm = uncertainty_norm
54 |
55 | plt.figure()
56 | plt.subplot(1, 2, 1)
57 | plt.imshow(uncertainty, cmap='inferno')
58 | plt.subplot(1, 2, 2)
59 | plt.imshow(uncertainty_norm, cmap='inferno')
60 | #plt.show()
61 | plt.savefig(save_path)
62 | plt.close()
63 |
64 | def save_shannon_entropy_calculated(uncertainty,save_path,unc_rate=1): #torch.tensor, torch.tensor
65 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32')
66 |
67 |
68 |
69 | plt.figure()
70 | unc_rate=0
71 | uncertainty_org1 = uncertainty
72 | uncertainty_thresh1 = (uncertainty > unc_rate)
73 | uncertainty1 = uncertainty * (uncertainty > unc_rate)
74 | plt.subplot(3,3, 1)
75 | plt.imshow(uncertainty_org1, cmap='inferno')
76 | plt.subplot(3,3, 2)
77 | plt.imshow(uncertainty_thresh1, cmap='inferno')
78 | plt.subplot(3,3, 3)
79 | plt.imshow(uncertainty1, cmap='inferno')
80 |
81 | unc_rate=1
82 | uncertainty_org2 = uncertainty
83 | uncertainty_thresh2 = (uncertainty > unc_rate)
84 | uncertainty2 = uncertainty * (uncertainty > unc_rate)
85 | plt.subplot(3,3, 4)
86 | plt.imshow(uncertainty_org2, cmap='inferno')
87 | plt.subplot(3,3, 5)
88 | plt.imshow(uncertainty_thresh2, cmap='inferno')
89 | plt.subplot(3,3, 6)
90 | plt.imshow(uncertainty2, cmap='inferno')
91 |
92 | unc_rate=10
93 | uncertainty_org3 = uncertainty
94 | uncertainty_thresh3 = (uncertainty > unc_rate)
95 | uncertainty3 = uncertainty * (uncertainty > unc_rate)
96 |     plt.subplot(3,3, 7)
97 |     plt.imshow(uncertainty_org3, cmap='inferno')
98 |     plt.subplot(3,3, 8)
99 |     plt.imshow(uncertainty_thresh3, cmap='inferno')
100 |     plt.subplot(3,3, 9)
101 |     plt.imshow(uncertainty3, cmap='inferno')
102 |
103 |
104 |
105 | #plt.show()
106 | plt.savefig(save_path)
107 | plt.close()
--------------------------------------------------------------------------------
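
A minimal sketch of the entropy-based uncertainty above, applied to per-object probabilities (illustrative shapes, not part of the repository):

import torch

probs = torch.softmax(torch.randn(2, 3, 128, 128), dim=1)   # (batch, obj_num, H, W)
uncertainty, uncertainty_norm = cal_shannon_entropy(probs)   # per-pixel Shannon entropy and its normalized map
print(uncertainty.shape)                                     # torch.Size([2, 1, 128, 128])
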
/AOC-Net/complete_project/AOCNet/networks/layers/shanoon_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 |
5 | def normalize(image, MIN_BOUND, MAX_BOUND):
6 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
7 | reverse_image = 1 - image
8 | return reverse_image
9 |
10 | def cal_shanoon_entropy(preds): # (batch, obj_num, 128, 128)
11 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) # (batch, 1, 128, 128)
12 | uncertainty_norm = normalize(uncertainty, 0, np.log(2)) * 7
13 | return uncertainty,uncertainty_norm
14 |
15 |
16 | def normalize_train(image, MIN_BOUND, MAX_BOUND):
17 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
18 | #reverse_image = 1 - image
19 | return image #reverse_image
20 |
21 | def cal_shanoon_entropy_train(preds): # (batch, obj_num, 128, 128)
22 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) # (batch, 1, 128, 128)
23 | uncertainty_norm = normalize_train(uncertainty, 0, np.log(2)) * 7
24 | return uncertainty,uncertainty_norm
25 |
26 |
27 |
28 | def show_shanoon_entropy(uncertainty,uncertainty_norm,unc_rate=0.5): #torch.tensor, torch.tensor
29 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32')
30 | uncertainty = uncertainty * (uncertainty > unc_rate)
31 | uncertainty_norm = uncertainty_norm.cpu().numpy().squeeze().astype('float32')
32 | uncertainty_norm = uncertainty_norm
33 |
34 |     plt.figure()
35 |     # The original 6-panel layout referenced variables that are never defined in
36 |     # this module (save_pre_den, density_gt, save_pre_dmp_to_att, save_pre_att_2);
37 |     # those panels are kept below as comments and only the two uncertainty maps
38 |     # computed in this module are shown.
39 |     # plt.subplot(1, 6, 1); plt.imshow(save_pre_den)
40 |     # plt.subplot(1, 6, 2); plt.imshow(density_gt)
41 |     # plt.subplot(1, 6, 3); plt.imshow(save_pre_dmp_to_att)
42 |     # plt.subplot(1, 6, 4); plt.imshow(save_pre_att_2)
43 |     plt.subplot(1, 2, 1)
44 |     plt.imshow(uncertainty, cmap='inferno')
45 |     plt.subplot(1, 2, 2)
46 |     plt.imshow(uncertainty_norm, cmap='inferno')
47 |     plt.show()
48 |
49 | def save_shanoon_entropy(uncertainty,uncertainty_norm,save_path,unc_rate=0.5): #torch.tensor, torch.tensor
50 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32')
51 | uncertainty = uncertainty * (uncertainty > unc_rate)
52 | uncertainty_norm = uncertainty_norm.cpu().numpy().squeeze().astype('float32')
53 | uncertainty_norm = uncertainty_norm
54 |
55 | plt.figure()
56 | plt.subplot(1, 2, 1)
57 | plt.imshow(uncertainty, cmap='inferno')
58 | plt.subplot(1, 2, 2)
59 | plt.imshow(uncertainty_norm, cmap='inferno')
60 | #plt.show()
61 | plt.savefig(save_path)
62 | plt.close()
63 |
64 | def save_shanoon_entropy_calculated(uncertainty,save_path,unc_rate=1): #torch.tensor, torch.tensor
65 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32')
66 |
67 |
68 |
69 | plt.figure()
70 | unc_rate=0
71 | uncertainty_org1 = uncertainty
72 | uncertainty_thresh1 = (uncertainty > unc_rate)
73 | uncertainty1 = uncertainty * (uncertainty > unc_rate)
74 | plt.subplot(3,3, 1)
75 | plt.imshow(uncertainty_org1, cmap='inferno')
76 | plt.subplot(3,3, 2)
77 | plt.imshow(uncertainty_thresh1, cmap='inferno')
78 | plt.subplot(3,3, 3)
79 | plt.imshow(uncertainty1, cmap='inferno')
80 |
81 | unc_rate=1
82 | uncertainty_org2 = uncertainty
83 | uncertainty_thresh2 = (uncertainty > unc_rate)
84 | uncertainty2 = uncertainty * (uncertainty > unc_rate)
85 | plt.subplot(3,3, 4)
86 | plt.imshow(uncertainty_org2, cmap='inferno')
87 | plt.subplot(3,3, 5)
88 | plt.imshow(uncertainty_thresh2, cmap='inferno')
89 | plt.subplot(3,3, 6)
90 | plt.imshow(uncertainty2, cmap='inferno')
91 |
92 | unc_rate=10
93 | uncertainty_org3 = uncertainty
94 | uncertainty_thresh3 = (uncertainty > unc_rate)
95 | uncertainty3 = uncertainty * (uncertainty > unc_rate)
96 |     plt.subplot(3,3, 7)
97 |     plt.imshow(uncertainty_org3, cmap='inferno')
98 |     plt.subplot(3,3, 8)
99 |     plt.imshow(uncertainty_thresh3, cmap='inferno')
100 |     plt.subplot(3,3, 9)
101 |     plt.imshow(uncertainty3, cmap='inferno')
102 |
103 |
104 |
105 | #plt.show()
106 | plt.savefig(save_path)
107 | plt.close()
108 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | datasets="youtubevos"
2 | config="configs.resnet101_aocnet_2"
3 | python ../tools/eval_net_mm_rpa.py --config ${config} --dataset ${datasets} --ckpt_step 400000 --global_chunks 16 --gpu_id 0 --mem_every 5
4 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/scripts/train.sh:
--------------------------------------------------------------------------------
1 | datasets="youtubevos"
2 | config="configs.resnet101_aocnet"
3 | # training for 200k with lr=0.2
4 | python ../tools/train_net_mm.py --config ${config} --datasets ${datasets} --global_chunks 1
5 |
6 | # continue training for another 200k steps with lr=0.1
7 | config="configs.resnet101_aocnet_2"
8 | python ../tools/train_net_mm.py --config ${config} --datasets ${datasets} --global_chunks 1
9 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/tools/eval_net_mm_rpa.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 | sys.path.append('..')
4 | from networks.engine.eval_manager_mm_rpa import Evaluator
5 | import importlib
6 |
7 | def main():
8 | import argparse
9 | parser = argparse.ArgumentParser(description="Eval")
10 | parser.add_argument('--exp_name', type=str, default='')
11 |
12 | parser.add_argument('--config', type=str, default='configs.resnet101')
13 |
14 | parser.add_argument('--gpu_id', type=int, default=0)
15 |
16 | parser.add_argument('--ckpt_path', type=str, default='')
17 | parser.add_argument('--ckpt_step', type=int, default=-1)
18 |
19 | parser.add_argument('--dataset', type=str, default='')
20 |
21 | parser.add_argument('--flip', action='store_true')
22 | parser.set_defaults(flip=False)
23 | parser.add_argument('--ms', nargs='+', type=float, default=[1.])
24 | parser.add_argument('--max_long_edge', type=int, default=-1)
25 | parser.add_argument('--mem_every', type=int, default=-1)
26 | parser.add_argument('--ucr', type=float, default=1.0)
27 | parser.add_argument('--float16', action='store_true')
28 | parser.add_argument('--vis', action='store_true')
29 | parser.set_defaults(float16=False)
30 | parser.add_argument('--global_atrous_rate', type=int, default=1)
31 | parser.add_argument('--global_chunks', type=int, default=4)
32 | parser.add_argument('--min_matching_pixels', type=int, default=0)
33 | parser.add_argument('--no_local_parallel', dest='local_parallel', action='store_false')
34 | parser.set_defaults(local_parallel=True)
35 | args = parser.parse_args()
36 |
37 | config = importlib.import_module(args.config)
38 | cfg = config.cfg
39 |
40 | cfg.TEST_GPU_ID = args.gpu_id
41 | if args.exp_name != '':
42 | cfg.EXP_NAME = args.exp_name
43 |     if args.mem_every > 0:
44 |         cfg.MEM_EVERY = args.mem_every
45 |     if args.ucr > 0:
46 |         cfg.UNC_RATIO = args.ucr
47 | if args.ckpt_path != '':
48 | cfg.TEST_CKPT_PATH = args.ckpt_path
49 | if args.ckpt_step > 0:
50 | cfg.TEST_CKPT_STEP = args.ckpt_step
51 | if args.dataset != '':
52 | cfg.TEST_DATASET = args.dataset
53 |
54 | cfg.UNC_VIS = args.vis
55 |
56 | cfg.TEST_FLIP = args.flip
57 | cfg.TEST_MULTISCALE = args.ms
58 | if args.max_long_edge > 0:
59 | cfg.TEST_MAX_SIZE = args.max_long_edge
60 | else:
61 | cfg.TEST_MAX_SIZE = 800 * 1.3 if cfg.TEST_MULTISCALE == [1.] else 800
62 |
63 | cfg.MODEL_FLOAT16_MATCHING = args.float16
64 | if 'cfbip' in cfg.MODEL_MODULE:
65 | cfg.TEST_GLOBAL_ATROUS_RATE = [args.global_atrous_rate, 1, 1]
66 | else:
67 | cfg.TEST_GLOBAL_ATROUS_RATE = args.global_atrous_rate
68 | cfg.TEST_GLOBAL_CHUNKS = args.global_chunks
69 | cfg.TEST_LOCAL_PARALLEL = args.local_parallel
70 |
71 | if args.min_matching_pixels > 0:
72 | cfg.TEST_GLOBAL_MATCHING_MIN_PIXEL = args.min_matching_pixels
73 |
74 | evaluator = Evaluator(cfg=cfg)
75 | evaluator.evaluating()
76 |
77 | if __name__ == '__main__':
78 | main()
79 |
80 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/tools/train_net_mm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 | sys.path.append('..')
4 | from networks.engine.train_manager_mm import Trainer
5 | import torch.multiprocessing as mp
6 | import importlib
7 |
8 | def main_worker(gpu, cfg):
9 | # Initiate a training manager
10 | trainer = Trainer(rank=gpu, cfg=cfg)
11 | # Start Training
12 | trainer.sequential_training()
13 |
14 | def main():
15 | import argparse
16 | parser = argparse.ArgumentParser(description="Train")
17 | parser.add_argument('--exp_name', type=str, default='')
18 | parser.add_argument('--config', type=str, default='configs.resnet101')
19 |
20 | parser.add_argument('--start_gpu', type=int, default=0)
21 | parser.add_argument('--gpu_num', type=int, default=-1)
22 | parser.add_argument('--batch_size', type=int, default=-1)
23 |
24 | parser.add_argument('--pretrained_path', type=str, default='')
25 |
26 | parser.add_argument('--datasets', nargs='+', type=str, default=['youtubevos'])
27 | parser.add_argument('--lr', type=float, default=-1.)
28 | parser.add_argument('--total_step', type=int, default=-1.)
29 | parser.add_argument('--start_step', type=int, default=-1.)
30 |
31 | parser.add_argument('--float16', action='store_true')
32 | parser.set_defaults(float16=False)
33 | parser.add_argument('--global_atrous_rate', type=int, default=1)
34 | parser.add_argument('--global_chunks', type=int, default=20)
35 | parser.add_argument('--no_local_parallel', dest='local_parallel', action='store_false')
36 | parser.set_defaults(local_parallel=True)
37 | args = parser.parse_args()
38 |
39 | config = importlib.import_module(args.config)
40 | cfg = config.cfg
41 |
42 | if args.exp_name != '':
43 | cfg.EXP_NAME = args.exp_name
44 |
45 | cfg.DIST_START_GPU = args.start_gpu
46 | if args.gpu_num > 0:
47 | cfg.TRAIN_GPUS = args.gpu_num
48 | if args.batch_size > 0:
49 | cfg.TRAIN_BATCH_SIZE = args.batch_size
50 |
51 | if args.pretrained_path != '':
52 | cfg.PRETRAIN_MODEL = args.pretrained_path
53 |
54 | if args.lr > 0:
55 | cfg.TRAIN_LR = args.lr
56 | if args.total_step > 0:
57 | cfg.TRAIN_TOTAL_STEPS = args.total_step
58 | cfg.TRAIN_START_SEQ_TRAINING_STEPS = int(args.total_step / 2)
59 | cfg.TRAIN_HARD_MINING_STEP = int(args.total_step / 2)
60 | if args.start_step > 0:
61 | cfg.TRAIN_START_STEP = args.start_step
62 |
63 | cfg.MODEL_FLOAT16_MATCHING = args.float16
64 | if 'cfbip' in cfg.MODEL_MODULE:
65 | cfg.TRAIN_GLOBAL_ATROUS_RATE = [args.global_atrous_rate, 1, 1]
66 | else:
67 | cfg.TRAIN_GLOBAL_ATROUS_RATE = args.global_atrous_rate
68 | cfg.TRAIN_GLOBAL_CHUNKS = args.global_chunks
69 | cfg.TRAIN_LOCAL_PARALLEL = args.local_parallel
70 |
71 | # Use torch.multiprocessing.spawn to launch distributed processes
72 | mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg,))
73 |
74 | if __name__ == '__main__':
75 | main()
76 |
77 |
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/utils/__init__.py
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 |
5 | def load_network_and_optimizer(net, opt, pretrained_dir, gpu):
6 | pretrained = torch.load(
7 | pretrained_dir,
8 | map_location=torch.device("cuda:"+str(gpu)))
9 | pretrained_dict = pretrained['state_dict']
10 | model_dict = net.state_dict()
11 | pretrained_dict_update = {}
12 | pretrained_dict_remove = []
13 | for k, v in pretrained_dict.items():
14 | if k in model_dict:
15 | pretrained_dict_update[k] = v
16 | elif k[:7] == 'module.':
17 | if k[7:] in model_dict:
18 | pretrained_dict_update[k[7:]] = v
19 | else:
20 | pretrained_dict_remove.append(k)
21 | model_dict.update(pretrained_dict_update)
22 | net.load_state_dict(model_dict)
23 | opt.load_state_dict(pretrained['optimizer'])
24 | del(pretrained)
25 | return net.cuda(gpu), opt, pretrained_dict_remove
26 |
27 | def load_network_and_not_optimizer(net, opt, pretrained_dir, gpu):
28 | pretrained = torch.load(
29 | pretrained_dir,
30 | map_location=torch.device("cuda:"+str(gpu)))
31 | pretrained_dict = pretrained['state_dict']
32 | model_dict = net.state_dict()
33 | pretrained_dict_update = {}
34 | pretrained_dict_remove = []
35 | for k, v in pretrained_dict.items():
36 | if k in model_dict:
37 | pretrained_dict_update[k] = v
38 | elif k[:7] == 'module.':
39 | if k[7:] in model_dict:
40 | pretrained_dict_update[k[7:]] = v
41 | else:
42 | pretrained_dict_remove.append(k)
43 | model_dict.update(pretrained_dict_update)
44 | net.load_state_dict(model_dict)
45 | #opt.load_state_dict(pretrained['optimizer'])
46 | del(pretrained)
47 | return net.cuda(gpu), opt, pretrained_dict_remove
48 |
49 | def load_network(net, pretrained_dir, gpu):
50 | pretrained = torch.load(
51 | pretrained_dir,
52 | map_location=torch.device("cuda:"+str(gpu)))
53 | pretrained_dict = pretrained['state_dict']
54 | model_dict = net.state_dict()
55 | pretrained_dict_update = {}
56 | pretrained_dict_remove = []
57 | for k, v in pretrained_dict.items():
58 | if k in model_dict and model_dict[k].size()==pretrained_dict[k].size():
59 | pretrained_dict_update[k] = v
60 | #if model_dict[k].size()!=pretrained_dict[k].size():
61 | # print("ERROR->",k)
62 | elif k[:7] == 'module.':
63 | if k[7:] in model_dict:
64 | pretrained_dict_update[k[7:]] = v
65 | else:
66 | pretrained_dict_remove.append(k)
67 | model_dict.update(pretrained_dict_update)
68 | net.load_state_dict(model_dict)
69 | del(pretrained)
70 | return net.cuda(gpu), pretrained_dict_remove
71 |
72 |
73 | def load_network_P2T(net, pretrained_dir, gpu):
74 | pretrained = torch.load(
75 | pretrained_dir,
76 | map_location=torch.device("cuda:"+str(gpu)))
77 | pretrained_dict = pretrained['state_dict']
78 | model_dict = net.state_dict()
79 | pretrained_dict_update = {}
80 | pretrained_dict_remove = []
81 | for k, v in pretrained_dict.items():
82 | print("pretrained_dict_k",k)
83 | if k in model_dict:
84 | print("k in new model")
85 | pretrained_dict_update[k] = v
86 | #for para in pretrained_dict_update[k].parameters():
87 | # para.requires_grad = False
88 | v.requires_grad = False
89 | elif k[:7] == 'module.':
90 | print("k in new model (module)")
91 | if k[7:] in model_dict:
92 | pretrained_dict_update[k[7:]] = v
93 | #for para in pretrained_dict_update[k].parameters():
94 | v.requires_grad = False
95 | else:
96 | print("k not in new model")
97 | pretrained_dict_remove.append(k)
98 |
99 | model_dict.update(pretrained_dict_update)
100 | net.load_state_dict(model_dict)
101 | del(pretrained)
102 | return net.cuda(gpu), pretrained_dict_remove
103 |
104 |
105 | def save_network(net, opt, step, save_path, max_keep=8):
106 | try:
107 | if not os.path.exists(save_path):
108 | os.makedirs(save_path)
109 | save_file = 'save_step_%s.pth' % (step)
110 | save_dir = os.path.join(save_path, save_file)
111 | torch.save({'state_dict': net.state_dict(), 'optimizer': opt.state_dict()}, save_dir)
112 | except:
113 | save_path = './saved_models'
114 | if not os.path.exists(save_path):
115 | os.makedirs(save_path)
116 | save_file = 'save_step_%s.pth' % (step)
117 | save_dir = os.path.join(save_path, save_file)
118 | torch.save({'state_dict': net.state_dict(), 'optimizer': opt.state_dict()}, save_dir)
119 |
120 | all_ckpt = os.listdir(save_path)
121 | if len(all_ckpt) > max_keep:
122 | all_step = []
123 | for ckpt_name in all_ckpt:
124 | step = int(ckpt_name.split('_')[-1].split('.')[0])
125 | all_step.append(step)
126 | all_step = list(np.sort(all_step))[:-max_keep]
127 | for step in all_step:
128 | ckpt_path = os.path.join(save_path, 'save_step_%s.pth' % (step))
129 | os.system('rm {}'.format(ckpt_path))
130 |
--------------------------------------------------------------------------------
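
A minimal save/load round trip with the checkpoint helpers above (a sketch with a hypothetical path; assumes GPU 0 is available since load_network moves the model to CUDA):

import torch

net = torch.nn.Conv2d(3, 8, 3)
opt = torch.optim.SGD(net.parameters(), lr=0.01)
save_network(net, opt, step=1000, save_path='./ckpt_demo')
net, removed_keys = load_network(net, './ckpt_demo/save_step_1000.pth', gpu=0)
print(removed_keys)   # checkpoint keys that did not match the model (empty here)
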
/AOC-Net/complete_project/AOCNet/utils/eval.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import zipfile
3 | import os
4 |
5 | def zip_folder(source_folder, zip_dir):
6 | f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED)
7 | pre_len = len(os.path.dirname(source_folder))
8 | for dirpath, dirnames, filenames in os.walk(source_folder):
9 | for filename in filenames:
10 | pathfile = os.path.join(dirpath, filename)
11 | arcname = pathfile[pre_len:].strip(os.path.sep)
12 | f.write(pathfile, arcname)
13 | f.close()
--------------------------------------------------------------------------------
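
A one-line sketch of zip_folder above, with hypothetical paths:

zip_folder('./eval/demo_results', './eval/demo_results.zip')   # archive predictions for submission
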
/AOC-Net/complete_project/AOCNet/utils/image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 |
5 | ## for visualization
6 | import cv2
7 | import time
8 | import gc
9 | import matplotlib.pyplot as plt
10 | import seaborn as sns
11 |
12 | plt.rcParams['font.sans-serif']=['SimHei']
13 | plt.rcParams['axes.unicode_minus'] = False
14 | ##
15 |
16 | _palette = [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, 255, 
255, 255]
17 |
18 | cluster_map_palette = [0, 0, 0, 200, 200, 200, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 
254, 254, 255, 255, 255]
19 |
20 | def label2colormap(label):
21 |
22 | m = label.astype(np.uint8)
23 | r,c = m.shape
24 | cmap = np.zeros((r,c,3), dtype=np.uint8)
25 | cmap[:,:,0] = (m&1)<<7 | (m&8)<<3 | (m&64)>>1
26 | cmap[:,:,1] = (m&2)<<6 | (m&16)<<2 | (m&128)>>2
27 | cmap[:,:,2] = (m&4)<<5 | (m&32)<<1
28 | return cmap
29 |
30 | def masked_image(image, colored_mask, mask, alpha = 0.7):
31 | mask = np.expand_dims(mask > 0, axis=0)
32 | mask = np.repeat(mask, 3, axis=0)
33 | show_img = (image * alpha + colored_mask * (1 - alpha)) * mask + image * (1 - mask)
34 | return show_img
35 |
36 | def save_image(image, path):
37 | im = Image.fromarray(np.uint8(image * 255.).transpose((1, 2, 0)))
38 | im.save(path)
39 |
40 | def save_mask(mask_tensor, path):
41 | mask = mask_tensor.cpu().numpy().astype('uint8')
42 | mask = Image.fromarray(mask).convert('P')
43 | mask.putpalette(_palette)
44 | mask.save(path)
45 |
46 | def save_cluster_map(mask_tensor, path):
47 | mask = mask_tensor.cpu().numpy().astype('uint8')
48 | mask = Image.fromarray(mask).convert('P')
49 | mask.putpalette(cluster_map_palette)
50 | mask.save(path)
51 |
52 | def flip_tensor(tensor, dim=0):
53 | inv_idx = torch.arange(tensor.size(dim) - 1, -1, -1, device=tensor.device).long()
54 | tensor = tensor.index_select(dim, inv_idx)
55 | return tensor
56 |
57 | def save_matching_result(img,path):
58 |
59 | cv2.imwrite(path,img)
60 |
61 | def save_uncertain_mask(mask, path):
62 | #mask = mask_tensor.cpu().numpy().astype('uint8')
63 | mask = Image.fromarray(mask).convert('P')
64 | mask.putpalette(_palette)
65 | mask.save(path)
--------------------------------------------------------------------------------
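A minimal usage sketch for the visualization helpers above (assumed, not taken from the repo; the import path is a guess at how `utils/image.py` would be imported, and images follow the CHW, [0, 1] convention the code expects):

```python
import numpy as np
from AOCNet.utils.image import label2colormap, masked_image, save_image  # assumed path

# Dummy 2-object label map and a dummy RGB frame in CHW format, values in [0, 1].
label = np.zeros((480, 854), dtype=np.uint8)
label[100:200, 100:300] = 1
label[250:400, 400:700] = 2
image = np.random.rand(3, 480, 854).astype(np.float32)

colored = label2colormap(label).transpose((2, 0, 1)) / 255.  # HWC uint8 -> CHW [0, 1]
overlay = masked_image(image, colored, label, alpha=0.7)     # blend colors over objects
save_image(overlay, 'overlay.png')                           # writes an RGB PNG
```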
/AOC-Net/complete_project/AOCNet/utils/learning.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def adjust_learning_rate(optimizer, base_lr, p, itr, max_itr, warm_up_steps=1000, is_cosine_decay=False, min_lr=1e-5):
5 |
6 | if itr < warm_up_steps:
7 | now_lr = base_lr * itr / warm_up_steps
8 | else:
9 | itr = itr - warm_up_steps
10 | max_itr = max_itr - warm_up_steps
11 | if is_cosine_decay:
12 | now_lr = base_lr * (math.cos(math.pi * itr / (max_itr + 1)) + 1.) * 0.5
13 | else:
14 | now_lr = base_lr * (1 - itr / (max_itr + 1)) ** p
15 |
16 | if now_lr < min_lr:
17 | now_lr = min_lr
18 |
19 | for param_group in optimizer.param_groups:
20 | param_group['lr'] = now_lr
21 | return now_lr
22 |
23 |
24 | def get_trainable_params(model, base_lr, weight_decay, beta_wd=True):
25 | params = []
26 | for key, value in model.named_parameters():
27 | if not value.requires_grad:
28 | continue
29 | wd = weight_decay
30 | if 'beta' in key:
31 | if not beta_wd:
32 | wd = 0.
33 | params += [{"params": [value], "lr": base_lr, "weight_decay": wd}]
34 | return params
--------------------------------------------------------------------------------
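A minimal sketch of how these two helpers are typically combined (assumed usage with placeholder hyperparameters; not taken from the repo's training scripts):

```python
import torch
import torch.nn as nn
from AOCNet.utils.learning import adjust_learning_rate, get_trainable_params  # assumed path

model = nn.Conv2d(3, 8, 3)        # stand-in for the real network
base_lr, max_itr = 0.01, 100000
params = get_trainable_params(model, base_lr=base_lr, weight_decay=1e-4)
optimizer = torch.optim.SGD(params, lr=base_lr, momentum=0.9)

for itr in range(max_itr):
    # Warm up for the first 1000 iterations, then decay with a cosine schedule.
    now_lr = adjust_learning_rate(optimizer, base_lr, p=0.9, itr=itr, max_itr=max_itr,
                                  warm_up_steps=1000, is_cosine_decay=True)
    # ... forward / backward / optimizer.step() would go here ...
```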
/AOC-Net/complete_project/AOCNet/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self):
8 | self.val = 0
9 | self.avg = 0
10 | self.sum = 0
11 | self.count = 0
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
--------------------------------------------------------------------------------
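A small usage sketch (assumed, not from the repo): `update(val, n)` accumulates a batch-size-weighted sum, so `avg` is the running mean over samples rather than over batches.

```python
from AOCNet.utils.meters import AverageMeter  # assumed import path

loss_meter = AverageMeter()
for batch_loss, batch_size in [(0.9, 8), (0.7, 8), (0.5, 4)]:
    loss_meter.update(batch_loss, n=batch_size)   # sum += val * n, count += n
print(loss_meter.val, loss_meter.avg)             # last value and sample-weighted mean
```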
/AOC-Net/complete_project/AOCNet/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def pytorch_iou(pred, target, obj_num, epsilon=1e-6):
4 | '''
5 | pred: [bs, h, w]
6 | target: [bs, h, w]
7 | obj_num: [bs]
8 | '''
9 | bs = obj_num.size(0)
10 | all_iou = []
11 | for idx in range(bs):
12 | now_pred = pred[idx].unsqueeze(0)
13 | now_target = target[idx].unsqueeze(0)
14 | now_obj_num = obj_num[idx]
15 |
16 | obj_ids = torch.arange(0, now_obj_num + 1, device=now_pred.device).int().view(-1, 1, 1)
17 | if obj_ids.size(0) == 1: # only contain background
18 | continue
19 | else:
20 | obj_ids = obj_ids[1:]
21 | now_pred = (now_pred == obj_ids).float()
22 | now_target = (now_target == obj_ids).float()
23 |
24 | intersection = (now_pred * now_target).sum((1, 2))
25 | union = ((now_pred + now_target) > 0).float().sum((1, 2))
26 |
27 | now_iou = (intersection + epsilon) / (union + epsilon)
28 |
29 | all_iou.append(now_iou.mean())
30 | if len(all_iou) > 0:
31 | all_iou = torch.stack(all_iou).mean()
32 | else:
33 | all_iou = torch.ones((1), device=pred.device)
34 | return all_iou
35 |
--------------------------------------------------------------------------------
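A minimal usage sketch (assumed): `pytorch_iou` expects integer label maps in which 0 is background and 1..obj_num are object ids, and it returns the IoU averaged over objects and over the batch.

```python
import torch
from AOCNet.utils.metric import pytorch_iou  # assumed import path

pred = torch.zeros(2, 64, 64, dtype=torch.long)
target = torch.zeros(2, 64, 64, dtype=torch.long)
pred[0, 10:30, 10:30] = 1       # sample 0: partially overlapping object 1
target[0, 15:35, 10:30] = 1
pred[1, 0:20, 0:20] = 1         # sample 1: perfect match for object 1
target[1, 0:20, 0:20] = 1
obj_num = torch.tensor([1, 1])  # one foreground object per sample

print(pytorch_iou(pred, target, obj_num))  # mean IoU over both samples
```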
/AOC-Net/complete_project/README.md:
--------------------------------------------------------------------------------
1 | The code has been cleaned to reduce redundancy. Since I currently do not have GPU resources for re-testing, it may still contain bugs.
2 |
--------------------------------------------------------------------------------
/AOC-Net/conditioning_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class conditioning_layer(nn.Module):
7 | def __init__(self,
8 | in_dim=256,
9 | beta_percentage=0.3):
10 | super(conditioning_layer,self).__init__()
11 |
12 | self.beta_percentage = beta_percentage
13 |
14 | kernel_size = 1
15 | self.phi_layer = nn.Conv2d(in_dim,1,kernel_size=kernel_size,stride=1,padding=int((kernel_size-1)/2))
16 | self.mlp_layer = nn.Linear(in_dim, in_dim)
17 |
18 | nn.init.kaiming_normal_(self.phi_layer.weight,mode='fan_out',nonlinearity='relu')
19 |
20 |
21 | def forward(self, z_in):
22 |
23 | # Step 1: phi(z_in)
24 | x = self.phi_layer(z_in)
25 |
26 | # Step 2: beta
27 | x = x.reshape(x.size()[0],x.size()[1],-1)
28 | z_in_reshape = z_in.reshape(z_in.size()[0],z_in.size()[1],-1)
29 | beta_rank = int(self.beta_percentage*z_in.size()[-1]*z_in.size()[-2])
30 | beta_val, _ = torch.topk(x, k=beta_rank, dim=-1, sorted=True)
31 |
32 | # Step 3: pi_beta(phi(z_in))
33 |         x = (x > beta_val[...,-1,None]).float()  # binary selection mask; cast to float so it can scale z_in below
34 |
35 | # Step 4: z_in \odot pi_beta(phi(z_in))
36 | z_in_masked = z_in_reshape * x
37 |
38 | # Step 5: GAP (z_in \odot pi_beta(phi(z_in)))
39 | z_in_masked_gap = torch.nn.functional.avg_pool1d(z_in_masked,
40 | kernel_size=z_in_masked.size()[-1]).squeeze(-1)
41 |
42 | # Step 6: MLP ( GAP (z_in \odot pi_beta(phi(z_in))) )
43 |         out = self.mlp_layer(z_in_masked_gap)
44 |
45 | return out
46 |
--------------------------------------------------------------------------------
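A minimal usage sketch for the layer above (assumed; the module path is a placeholder): the layer maps a [B, C, H, W] feature map to a [B, C] conditioning vector by scoring spatial positions with a 1x1 conv, keeping the top `beta_percentage` of them, global-average-pooling the selected features, and passing the result through an MLP.

```python
import torch
from conditioning_layer import conditioning_layer  # assumed module path

layer = conditioning_layer(in_dim=256, beta_percentage=0.3)
z_in = torch.randn(2, 256, 30, 30)   # dummy [B, C, H, W] feature map
out = layer(z_in)
print(out.shape)                     # torch.Size([2, 256]) conditioning vectors
```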
/Dockerfile:
--------------------------------------------------------------------------------
1 | # main image
2 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
3 |
4 | # tweaked azureml pytorch image
5 | # as of Aug 31, 2020 pt 1.6 doesn't seem to work with horovod on native mixed precision
6 |
7 | LABEL maintainer="Albert"
8 | LABEL maintainer_email="alsadovn@microsoft.com"
9 | LABEL version="0.1"
10 |
11 | USER root:root
12 |
13 | ENV com.nvidia.cuda.version $CUDA_VERSION
14 | ENV com.nvidia.volumes.needed nvidia_driver
15 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
16 | ENV DEBIAN_FRONTEND noninteractive
17 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64
18 | ENV NCCL_DEBUG=INFO
19 | ENV HOROVOD_GPU_ALLREDUCE=NCCL
20 |
21 | # Install Common Dependencies
22 | RUN apt-get update && \
23 | apt-get install -y --no-install-recommends \
24 | # SSH and RDMA
25 | libmlx4-1 \
26 | libmlx5-1 \
27 | librdmacm1 \
28 | libibverbs1 \
29 | libmthca1 \
30 | libdapl2 \
31 | dapl2-utils \
32 | openssh-client \
33 | openssh-server \
34 | iproute2 && \
35 | # Others
36 | apt-get install -y --no-install-recommends \
37 | build-essential \
38 | bzip2=1.0.6-8.1ubuntu0.2 \
39 | libbz2-1.0=1.0.6-8.1ubuntu0.2 \
40 | systemd \
41 | git=1:2.17.1-1ubuntu0.7 \
42 | wget \
43 | cpio \
44 | libsm6 \
45 | libxext6 \
46 | libxrender-dev \
47 | fuse && \
48 | apt-get clean -y && \
49 | rm -rf /var/lib/apt/lists/*
50 |
51 | # Conda Environment
52 | ENV MINICONDA_VERSION 4.7.12.1
53 | ENV PATH /opt/miniconda/bin:$PATH
54 | RUN wget -qO /tmp/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-${MINICONDA_VERSION}-Linux-x86_64.sh && \
55 | bash /tmp/miniconda.sh -bf -p /opt/miniconda && \
56 | conda clean -ay && \
57 | rm -rf /opt/miniconda/pkgs && \
58 | rm /tmp/miniconda.sh && \
59 | find / -type d -name __pycache__ | xargs rm -rf
60 |
61 | # To resolve horovod hangs due to a known NCCL issue in version 2.4.
62 | # Can remove it once we upgrade NCCL to 2.5+.
63 | # https://github.com/horovod/horovod/issues/893
64 | # ENV NCCL_TREE_THRESHOLD=0
65 | ENV PIP="pip install --no-cache-dir"
66 |
67 | RUN conda install -y conda=4.8.5 python=3.6.2 && conda clean -ay && \
68 | conda install -y mkl=2020.1 && \
69 | conda install -y numpy scipy scikit-learn scikit-image imageio protobuf && \
70 | conda install -y ruamel.yaml==0.16.10 && \
71 | # ruamel_yaml is a copy of ruamel.yaml package
72 | # conda installs version ruamel_yaml v0.15.87 which is vulnerable
73 | # force uninstall it leaving other packages intact
74 | conda remove --force -y ruamel_yaml && \
75 | conda clean -ay && \
76 | # Install AzureML SDK
77 | ${PIP} azureml-defaults && \
78 | # Install PyTorch
79 | ${PIP} torch==1.4.0 && \
80 | ${PIP} torchvision==0.2.1 && \
81 | ${PIP} wandb && \
82 | # # Install Horovod
83 | # HOROVOD_WITH_PYTORCH=1 ${PIP} horovod[pytorch]==0.19.5 && \
84 | # ldconfig && \
85 | ${PIP} tensorboard==1.15.0 && \
86 | ${PIP} future==0.17.1 && \
87 | ${PIP} onnxruntime==1.4.0 && \
88 | ${PIP} pytorch-lightning && \
89 | ${PIP} opencv-python-headless~=4.4.0 && \
90 | ${PIP} imgaug==0.4.0 --no-deps && \
91 | # hydra
92 | ${PIP} hydra-core --upgrade && \
93 | ${PIP} lmdb pyarrow
94 |
95 | RUN pip3 install --upgrade pip
96 | RUN pip3 install pipreqs
97 |
98 | RUN apt-get update
99 | RUN apt-get install -y --no-install-recommends libglib2.0-dev
100 | RUN apt-get install -y --no-install-recommends vim
101 |
102 | WORKDIR /
103 | RUN apt-get install -y --no-install-recommends libunwind8
104 | RUN apt-get install -y --no-install-recommends libicu-dev
105 | RUN apt-get install -y --no-install-recommends htop
106 | RUN apt-get install -y --no-install-recommends net-tools
107 | RUN apt-get install -y --no-install-recommends rsync
108 | RUN apt-get install -y --no-install-recommends tree
109 |
110 | # put the requirements file for your own repo under /app for pip-based installation!!!
111 | WORKDIR /app
112 | RUN pip3 install -r requirements.txt
113 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Towards Robust Video Object Segmentation with Adaptive Object Calibration (ACM Multimedia 2022)
2 |
3 |
4 | Preview version **paper** of this work is available at [Arxiv](https://arxiv.org/abs/2207.00887).
5 |
6 | The conference **poster** is available at [this github repo](https://github.com/JerryX1110/Robust-Video-Object-Segmentation/blob/main/figs/mm22_345_poster_a0.pdf).
7 |
8 | Long paper **presentation video** is available at [GoogleDrive](https://drive.google.com/drive/folders/18q0eWTB5lXSkZUHOuBfyQO2LbIdTJof5?usp=sharing) and [YouTube](https://youtu.be/tFMfZ4NATrw).
9 |
10 | **Qualitative results and comparisons** with previous SOTAs are available at [YouTube](https://www.youtube.com/watch?v=3F6n7tcwWkA).
11 |
12 | Welcome to stars ⭐ & comments 💹 & collaboration 😀 !!
13 |
14 | ```diff
15 | - 2022.11.16: All the codes are cleaned and released ~
16 | - 2022.10.21: Add the robustness evaluation dataloader for other models, e.g., AOT~
17 | - 2022.10.1: Add the code of key implementations of this work~
18 | - 2022.9.25: Add the poster of this work~
19 | - 2022.8.27: Add presentation video and PPT for this work~
20 | - 2022.7.10: Add future works towards robust VOS!
21 | - 2022.7.5: Our ArXiv-version paper is available.
22 | - 2022.7.1: Repo init. Please stay tuned~
23 | ```
24 | ---
25 |
26 | ## Motivation for Robust Video Object Segmentation
27 |
28 |
29 |
30 |
31 | ## Pipeline
32 |
33 |
34 |
35 | ### Adaptive Object Proxy Representation (Component1)
36 |
37 |
38 |
39 |
40 | ### Object Mask Calibration (Component2)
41 |
42 |
43 |
44 |
45 | ## Abstract
46 | In the booming video era, video segmentation attracts increasing research attention in the multimedia community.
47 |
48 | Semi-supervised video object segmentation (VOS) aims at segmenting objects in all target frames of a video, given annotated object masks of reference frames. **Most existing methods build pixel-wise reference-target correlations and then perform pixel-wise tracking to obtain target masks. Due to neglecting object-level cues, pixel-level approaches make the tracking vulnerable to perturbations, and even indiscriminate among similar objects.**
49 |
50 | Towards **robust VOS**, the key insight is to calibrate the representation and mask of each specific object to be expressive and discriminative. Accordingly, we propose a new deep network, which can adaptively construct object representations and calibrate object masks to achieve stronger robustness.
51 |
52 | First, we construct the object representations by applying an **adaptive object proxy (AOP) aggregation** method, where the proxies represent arbitrary-shaped segments **via clustering** at multi-levels for reference.
53 |
54 | Then, prototype masks are initially generated from the reference-target correlations based on AOP.
55 | Afterwards, such proto-masks are further calibrated through network modulation, conditioning on the object proxy representations.
56 | We consolidate this **conditional mask calibration** process in a progressive manner, where the object representations and proto-masks evolve to be discriminative iteratively.
57 |
58 | Extensive experiments are conducted on the standard VOS benchmarks, YouTube-VOS-18/19 and DAVIS-17.
59 | Our model achieves the state-of-the-art performance among existing published works, and also exhibits significantly superior robustness against perturbations.
60 |
61 | ## Requirements
62 | * Python3
63 | * pytorch >= 1.4.0
64 | * torchvision
65 | * opencv-python
66 | * Pillow
67 |
68 | You can also use the docker image below to set up your environment directly. However, this image may contain some redundant packages.
69 |
70 | ```latex
71 | docker image: xxiaoh/vos:10.1-cudnn7-torch1.4_v3
72 | ```
73 |
74 | A more lightweight version can be created by modifying the [Dockerfile](https://github.com/JerryX1110/RPCMVOS/blob/main/Dockerfile) provided. A quick sanity check of the resulting environment is sketched below.
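If you set up the environment manually instead of using Docker, a quick check along these lines can confirm the core dependencies (illustrative snippet, not part of the repo):

```python
# Minimal environment sanity check for the requirements listed above.
import torch, torchvision, cv2, PIL

print('torch:', torch.__version__)            # expect >= 1.4.0
print('torchvision:', torchvision.__version__)
print('opencv:', cv2.__version__)
print('Pillow:', PIL.__version__)
print('CUDA available:', torch.cuda.is_available())  # training assumes a CUDA GPU
```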
75 |
76 | ## Preparation
77 | * Datasets
78 |
79 | * **YouTube-VOS**
80 |
81 | A commonly-used large-scale VOS dataset.
82 |
83 | [datasets/YTB/2019](datasets/YTB/2019): version 2019, download [link](https://drive.google.com/drive/folders/1BWzrCWyPEmBEKm0lOHe5KLuBuQxUSwqz?usp=sharing). `train` is required for training. `valid` (6fps) and `valid_all_frames` (30fps, optional) are used for evaluation.
84 |
85 | [datasets/YTB/2018](datasets/YTB/2018): version 2018, download [link](https://drive.google.com/drive/folders/1bI5J1H3mxsIGo7Kp-pPZU8i6rnykOw7f?usp=sharing). Only `valid` (6fps) and `valid_all_frames` (30fps, optional) are required for this project and used for evaluation.
86 |
87 | * **DAVIS**
88 |
89 | A commonly-used small-scale VOS dataset.
90 |
91 | [datasets/DAVIS](datasets/DAVIS): [TrainVal](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) (480p) contains both the training and validation split. [Test-Dev](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) (480p) contains the Test-dev split. The [full-resolution version](https://davischallenge.org/davis2017/code.html) is also supported for training and evaluation but not required.
92 |
93 | * pretrained weights for the backbone
94 |
95 | [resnet101-deeplabv3p](https://drive.google.com/file/d/101jYpeGGG58Kywk03331PKKzv1wN5DL0/view?usp=sharing)
96 |
97 |
98 | ## Implementation
99 |
100 | The key implementation of matching with the adaptive-proxy-based representation is provided in [THIS FILE](https://github.com/JerryX1110/Robust-Video-Object-Segmentation/blob/main/AOC-Net/adaptive_embedding_for_matching.py). For other implementation and training/evaluation details, please refer to [RPCMVOS](https://github.com/JerryX1110/RPCMVOS) or [CFBI](https://github.com/z-x-yang/CFBI). A simplified, illustrative sketch of the proxy-based matching idea is given below.
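The sketch below is an assumed simplification, not the code in `adaptive_embedding_for_matching.py`: reference-frame pixel embeddings of one object are clustered into a small set of proxies with plain k-means, and target pixels are then correlated against the proxies rather than against every reference pixel. The function names, number of proxies, and iteration count are placeholders.

```python
import torch
import torch.nn.functional as F

def build_proxies(ref_emb, ref_mask, num_proxies=16, iters=10):
    """Cluster the embeddings of one reference object into proxies.
    ref_emb: [C, H, W] pixel embeddings; ref_mask: [H, W] binary object mask."""
    C = ref_emb.size(0)
    feats = ref_emb.permute(1, 2, 0).reshape(-1, C)[ref_mask.reshape(-1) > 0]  # [N, C]
    centers = feats[torch.randperm(feats.size(0))[:num_proxies]].clone()       # [K, C]
    for _ in range(iters):                                   # plain k-means updates
        assign = torch.cdist(feats, centers).argmin(dim=1)   # nearest proxy per pixel
        for k in range(centers.size(0)):
            members = feats[assign == k]
            if members.numel() > 0:
                centers[k] = members.mean(dim=0)
    return centers                                           # adaptive object proxies

def proxy_correlation(tgt_emb, proxies):
    """Cosine similarity between every target pixel and every proxy.
    tgt_emb: [C, H, W]; proxies: [K, C] -> returns [K, H, W]."""
    C, H, W = tgt_emb.shape
    tgt = F.normalize(tgt_emb.reshape(C, -1), dim=0)   # unit-norm per pixel
    prx = F.normalize(proxies, dim=1)                  # unit-norm per proxy
    return (prx @ tgt).reshape(-1, H, W)
```

Such proxy-to-pixel correlation maps would then play the role of the reference-target correlations from which the prototype masks are generated.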
101 |
102 | The key implementation of the preliminary robust VOS benchmark evaluation is provided in [THIS FILE](https://github.com/JerryX1110/Robust-Video-Object-Segmentation/tree/main/Robust-VOS-Benchmark).
103 |
104 | The whole project code is provided in [THIS FOLDER](https://github.com/JerryX1110/Robust-Video-Object-Segmentation/tree/main/AOC-Net/complete_project).
105 |
106 | Feel free to contact me if you have any problems with the implementation~
107 |
108 |
109 |
110 | * For evaluation, please use official YouTube-VOS servers ([2018 server](https://competitions.codalab.org/competitions/19544) and [2019 server](https://competitions.codalab.org/competitions/20127)), official [DAVIS toolkit](https://github.com/davisvideochallenge/davis-2017) (for Val), and official [DAVIS server](https://competitions.codalab.org/competitions/20516#learn_the_details) (for Test-dev).
111 |
112 |
113 |
114 | ## Limitation & Directions for further exploration towards Robust VOS!
115 | * Extension of the proposed clustering-based adaptive proxy representation to other dense-tracking tasks in a more efficient and robust way
116 | * Leverage the robust layered representation, i.e., intermediate masks, for robust mask calibration in other segmentation tasks
117 | * More diverse perturbation/corruption types can be studied for video segmentation tasks like VOS and VIS
118 | * Adversarial attack and defense for VOS models is still an open question for further exploration
119 | * VOS model robustness verification and theoretical analysis
120 | * Model enhancement from the perspective of data management
121 |
122 | (to be continued...)
123 |
124 | ## Citation
125 | If you find this work is useful for your research, please consider citing:
126 |
127 | ```latex
128 | @inproceedings{xu2022towards,
129 | title={Towards Robust Video Object Segmentation with Adaptive Object Calibration},
130 | author={Xu, Xiaohao and Wang, Jinglu and Ming, Xiang and Lu, Yan},
131 | booktitle={Proceedings of the 30th ACM International Conference on Multimedia},
132 | pages={2709--2718},
133 | year={2022}
134 | }
135 | ```
136 |
137 |
138 | ## Credit
139 |
140 | **CFBI**:
141 |
142 | **Deeplab**:
143 |
144 | **GCT**:
145 |
146 | ## Related Works in VOS
147 | **Semisupervised video object segmentation repo/paper link:**
148 |
149 | **ARKitTrack [CVPR 2023]**:
150 |
151 | **TarVis [Arxiv 2023]**:
152 |
153 | **LBLVOS [AAAI 2023]**:
154 |
155 | **DeAOT [NeurIPS 2022]**:
156 |
157 | **BATMAN [ECCV 2022 Oral]**:
158 |
159 | **XMEM [ECCV 2022]**:
160 |
161 | **TBD [ECCV 2022]**:
162 |
163 | **QDMN [ECCV 2022]**:
164 |
165 | **GSFM [ECCV 2022]**:
166 |
167 | **SWEM [CVPR 2022]**:
168 |
169 | **RDE [CVPR 2022]**:
170 |
171 | **COVOS [CVPR 2022]** :
172 |
173 | **RPCM [AAAI 2022 Oral]** :
174 |
175 | **AOT [NeurIPS 2021]**:
176 |
177 | **STCN [NeurIPS 2021]**:
178 |
179 | **JOINT [ICCV 2021]**:
180 |
181 | **HMMN [ICCV 2021]**:
182 |
183 | **DMN-AOA [ICCV 2021]**:
184 |
185 | **MiVOS [CVPR 2021]**:
186 |
187 | **SSTVOS [CVPR 2021 Oral]**:
188 |
189 | **GraphMemVOS [ECCV 2020]**:
190 |
191 | **AFB-URR [NeurIPS 2020]**:
192 |
193 | **CFBI [ECCV 2020]**:
194 |
195 | **FRTM-VOS [CVPR 2020]**:
196 |
197 | **STM [ICCV 2019]**:
198 |
199 | **FEELVOS [CVPR 2019]**:
200 |
201 | (The list may be incomplete; feel free to contact me by opening an issue and I'll add them!)
202 |
203 | ## Useful websites for VOS
204 | **The 1st Large-scale Video Object Segmentation Challenge**:
205 |
206 | **The 2nd Large-scale Video Object Segmentation Challenge - Track 1: Video Object Segmentation**:
207 |
208 | **The Semi-Supervised DAVIS Challenge on Video Object Segmentation @ CVPR 2020**:
209 |
210 | **DAVIS**:
211 |
212 | **YouTube-VOS**:
213 |
214 | **Papers with code for Semi-VOS**:
215 |
216 | ## Acknowledgement ❤️
217 | This work is heavily built upon CFBI and RPCMVOS. Thanks to the authors of CFBI for releasing such a wonderful code repo for further work to build upon!
218 |
219 | ## Welcome to comments and discussions!!
220 | Xiaohao Xu:
221 |
222 | ## License
223 | This project is released under the MIT license. See [LICENSE](LICENSE) for additional details.
224 |
--------------------------------------------------------------------------------
/Robust-VOS-Benchmark/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/figs/FIGS.MD:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/figs/mm22_345_poster_a0.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/figs/mm22_345_poster_a0.pdf
--------------------------------------------------------------------------------
/figs/mm22_345_poster_a0.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/figs/mm22_345_poster_a0.pptx
--------------------------------------------------------------------------------