├── AOC-Net ├── __init__.py ├── adaptive_embedding_for_matching.py ├── complete_project │ ├── AOCNet │ │ ├── configs │ │ │ ├── resnet101_aocnet.py │ │ │ └── resnet101_aocnet_2.py │ │ ├── dataloaders │ │ │ ├── __init__.py │ │ │ ├── custom_transforms.py │ │ │ └── datasets_m.py │ │ ├── networks │ │ │ ├── __init__.py │ │ │ ├── aoc │ │ │ │ ├── __init__.py │ │ │ │ ├── aocnet.py │ │ │ │ ├── conditioning_layer.py │ │ │ │ └── decoding_module.py │ │ │ ├── deeplab │ │ │ │ ├── __init__.py │ │ │ │ ├── aspp.py │ │ │ │ ├── backbone │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── __pycache__ │ │ │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ │ │ ├── mobilenet.cpython-36.pyc │ │ │ │ │ │ └── resnet.cpython-36.pyc │ │ │ │ │ ├── mobilenet.py │ │ │ │ │ └── resnet.py │ │ │ │ ├── decoder.py │ │ │ │ └── deeplab.py │ │ │ ├── engine │ │ │ │ ├── __init__.py │ │ │ │ ├── eval_manager_mm.py │ │ │ │ └── train_manager_mm.py │ │ │ └── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── aspp.py │ │ │ │ ├── attention.py │ │ │ │ ├── conv_gru.py │ │ │ │ ├── gct.py │ │ │ │ ├── gru_conv.py │ │ │ │ ├── loss.py │ │ │ │ ├── matching.py │ │ │ │ ├── normalization.py │ │ │ │ ├── refiner.py │ │ │ │ ├── shannon_entropy.py │ │ │ │ └── shanoon_entropy.py │ │ ├── scripts │ │ │ ├── eval.sh │ │ │ └── train.sh │ │ ├── tools │ │ │ ├── eval_net_mm_rpa.py │ │ │ └── train_net_mm.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── checkpoint.py │ │ │ ├── eval.py │ │ │ ├── image.py │ │ │ ├── learning.py │ │ │ ├── meters.py │ │ │ └── metric.py │ └── README.md └── conditioning_layer.py ├── Dockerfile ├── README.md ├── Robust-VOS-Benchmark ├── AOT │ └── eval_datasets.py ├── CFBI&AOC(ours) │ └── datasets_robustness.py └── __init__.py └── figs ├── FIGS.MD ├── mm22_345_poster_a0.pdf └── mm22_345_poster_a0.pptx /AOC-Net/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/configs/resnet101_aocnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch 3 | import argparse 4 | import os 5 | import sys 6 | import cv2 7 | import time 8 | import socket 9 | import random 10 | 11 | class Configuration(): 12 | def __init__(self): 13 | self.EXP_NAME = 'aoc_stage_1' 14 | 15 | self.EVAL_AUTO_RESUME = False 16 | self.UNC_RATIO=1.0 17 | self.MEM_EVERY=5 18 | self.USE_SF=False 19 | self.PAST_FRAME_NUM=4 20 | self.COMPRESSION_RATE=4 21 | self.BLOCK_NUM=2 22 | 23 | self.DIR_ROOT = '/datasets/' 24 | self.DIR_DATA = '/datasets/datasets/' 25 | self.DIR_TEMP_RESULT = '/datasets/result_evaluation/' 26 | self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS') 27 | self.DIR_PORTRAIT_DATA ='/datasets/Portrait_dataset/dataset_post/' 28 | self.DIR_YTB = os.path.join(self.DIR_DATA, 'train_all_frames/train_all_frames/') 29 | self.DIR_YTB_EVAL = os.path.join(self.DIR_DATA, 'valid_all_frames/valid_all_frames/') 30 | self.DIR_YTB_EVAL18 = os.path.join(self.DIR_ROOT, 'MODELS/STCN-ALL/YouTube2018/valid/') 31 | self.DIR_YTB_EVAL19 = os.path.join(self.DIR_ROOT,'MODELS/STCN-ALL/YouTube/valid/') 32 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME) 33 | 34 | 35 | 36 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt') 37 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log') 38 | self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img') 39 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard') 40 | self.DIR_EVALUATION = 
os.path.join(self.DIR_RESULT, 'eval') 41 | 42 | self.DATASETS = ['youtubevos'] 43 | self.DATA_WORKERS = 4 44 | self.DATA_RANDOMCROP = (465, 465) 45 | self.DATA_RANDOMFLIP = 0.5 46 | self.DATA_MAX_CROP_STEPS = 5 47 | self.DATA_MIN_SCALE_FACTOR = 1. 48 | self.DATA_MAX_SCALE_FACTOR = 1.3 49 | self.DATA_SHORT_EDGE_LEN = 480 50 | self.DATA_RANDOM_REVERSE_SEQ = True 51 | self.DATA_DAVIS_REPEAT = 30 52 | self.DATA_CURR_SEQ_LEN = 5 53 | self.DATA_RANDOM_GAP_DAVIS = 3 54 | self.DATA_RANDOM_GAP_YTB = 3 55 | 56 | 57 | self.PRETRAIN = True 58 | self.PRETRAIN_FULL = True 59 | self.PRETRAIN_MODEL = '/datasets/result/resnet101_cfbi_p2t_lr02_8C_pt_pxmc_pxm/ckpt/save_step_400000.pth' 60 | self.MODEL_BACKBONE = 'resnet' 61 | self.MODEL_MODULE = 'networks.aoc.aocnet' 62 | self.MODEL_OUTPUT_STRIDE = 16 63 | self.MODEL_ASPP_OUTDIM = 256 64 | self.MODEL_SHORTCUT_DIM = 48 65 | self.MODEL_SEMANTIC_EMBEDDING_DIM = 100 66 | self.MODEL_HEAD_EMBEDDING_DIM = 256 67 | self.MODEL_PRE_HEAD_EMBEDDING_DIM = 64 68 | self.MODEL_GN_GROUPS = 32 69 | self.MODEL_GN_EMB_GROUPS = 25 70 | self.MODEL_MULTI_LOCAL_DISTANCE = [2, 4, 6, 8, 10, 12] 71 | self.MODEL_LOCAL_DOWNSAMPLE = True 72 | self.MODEL_REFINE_CHANNELS = 64 # n * 32 73 | self.MODEL_LOW_LEVEL_INPLANES = 256 if self.MODEL_BACKBONE == 'resnet' else 24 74 | self.MODEL_RELATED_CHANNELS = 64 75 | self.MODEL_EPSILON = 1e-5 76 | self.MODEL_MATCHING_BACKGROUND = True 77 | self.MODEL_GCT_BETA_WD = True 78 | self.MODEL_FLOAT16_MATCHING = False 79 | self.MODEL_FREEZE_BN = True 80 | self.MODEL_FREEZE_BACKBONE = False 81 | 82 | self.TRAIN_TOTAL_STEPS = 50000 83 | self.TRAIN_START_STEP = 0 84 | self.TRAIN_LR = 0.01 85 | self.TRAIN_MOMENTUM = 0.9 86 | self.TRAIN_COSINE_DECAY = False 87 | self.TRAIN_WARM_UP_STEPS = 1000 88 | self.TRAIN_WEIGHT_DECAY = 15e-5 89 | self.TRAIN_POWER = 0.9 90 | self.TRAIN_GPUS = 8 91 | self.TRAIN_BATCH_SIZE = 8 92 | self.TRAIN_START_SEQ_TRAINING_STEPS = self.TRAIN_TOTAL_STEPS / 2 93 | self.TRAIN_TBLOG = False 94 | self.TRAIN_TBLOG_STEP = 60 95 | self.TRAIN_LOG_STEP = 20 96 | self.TRAIN_IMG_LOG = False 97 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15 98 | self.TRAIN_HARD_MINING_STEP = self.TRAIN_TOTAL_STEPS / 2 99 | self.TRAIN_CLIP_GRAD_NORM = 5. 100 | self.TRAIN_SAVE_STEP = 2000 101 | self.TRAIN_MAX_KEEP_CKPT = 8000 102 | self.TRAIN_RESUME = False 103 | self.TRAIN_RESUME_CKPT = None 104 | self.TRAIN_RESUME_STEP = 0 105 | self.TRAIN_AUTO_RESUME = True 106 | self.TRAIN_GLOBAL_ATROUS_RATE = 1 107 | self.TRAIN_LOCAL_ATROUS_RATE = 1 108 | self.TRAIN_LOCAL_PARALLEL = True 109 | self.TRAIN_GLOBAL_CHUNKS = 1 110 | self.TRAIN_DATASET_FULL_RESOLUTION = True 111 | 112 | 113 | self.TEST_GPU_ID = 0 114 | self.TEST_DATASET = 'youtubevos' 115 | self.TEST_DATASET_FULL_RESOLUTION = True 116 | self.TEST_DATASET_SPLIT = ['val'] 117 | self.TEST_CKPT_PATH = None 118 | self.TEST_CKPT_STEP = None # if "None", evaluate the latest checkpoint. 119 | self.TEST_FLIP = False 120 | self.TEST_MULTISCALE = [1] 121 | self.TEST_MIN_SIZE = None 122 | self.TEST_MAX_SIZE = 800 * 1.3 if self.TEST_MULTISCALE == [1.] 
else 800 123 | self.TEST_WORKERS = 4 124 | self.TEST_GLOBAL_CHUNKS = 4 125 | self.TEST_GLOBAL_ATROUS_RATE = 1 126 | self.TEST_LOCAL_ATROUS_RATE = 1 127 | self.TEST_LOCAL_PARALLEL = True 128 | 129 | # dist 130 | self.DIST_ENABLE = True 131 | self.DIST_BACKEND = "nccl" 132 | 133 | myname = socket.getfqdn(socket.gethostname( )) 134 | myaddr = socket.gethostbyname(myname) 135 | 136 | self.DIST_URL = "tcp://"+myaddr+":"+str(random.randint(30000,50000)) 137 | self.DIST_START_GPU = 0 138 | 139 | self.__check() 140 | 141 | def __check(self): 142 | if not torch.cuda.is_available(): 143 | raise ValueError('config.py: cuda is not avalable') 144 | if self.TRAIN_GPUS == 0: 145 | raise ValueError('config.py: the number of GPU is 0') 146 | for path in [self.DIR_RESULT, self.DIR_CKPT, self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, self.DIR_TB_LOG]: 147 | if not os.path.isdir(path): 148 | os.makedirs(path) 149 | 150 | 151 | 152 | cfg = Configuration() 153 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/configs/resnet101_aocnet_2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import sys 5 | import cv2 6 | import time 7 | import socket 8 | import random 9 | 10 | class Configuration(): 11 | def __init__(self): 12 | self.EXP_NAME = 'aocnet_stage_2' 13 | 14 | self.EVAL_AUTO_RESUME = False 15 | self.UNC_RATIO=1.0 16 | self.MEM_EVERY=5 17 | self.USE_SF=False 18 | self.PAST_FRAME_NUM=4 19 | self.COMPRESSION_RATE=4 20 | self.BLOCK_NUM=2 21 | 22 | self.DIR_ROOT = '/datasets/' 23 | self.DIR_DATA = '/datasets/datasets/' 24 | self.DIR_TEMP_RESULT = '/datasets/result_evaluation/' 25 | self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS') 26 | self.DIR_PORTRAIT_DATA ='/datasets/Portrait_dataset/dataset_post/' 27 | self.DIR_YTB = os.path.join(self.DIR_DATA, 'train_all_frames/train_all_frames/') 28 | self.DIR_YTB_EVAL = os.path.join(self.DIR_DATA, 'valid_all_frames/valid_all_frames/') 29 | self.DIR_YTB_EVAL18 = os.path.join(self.DIR_ROOT, 'MODELS/STCN-ALL/YouTube2018/valid/') 30 | self.DIR_YTB_EVAL19 = os.path.join(self.DIR_ROOT,'MODELS/STCN-ALL/YouTube/valid/') 31 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME) 32 | 33 | 34 | 35 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt') 36 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log') 37 | self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img') 38 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard') 39 | self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval') 40 | 41 | self.DATASETS = ['youtubevos'] 42 | self.DATA_WORKERS = 4 43 | self.DATA_RANDOMCROP = (465, 465) 44 | self.DATA_RANDOMFLIP = 0.5 45 | self.DATA_MAX_CROP_STEPS = 5 46 | self.DATA_MIN_SCALE_FACTOR = 1. 
47 | self.DATA_MAX_SCALE_FACTOR = 1.3 48 | self.DATA_SHORT_EDGE_LEN = 480 49 | self.DATA_RANDOM_REVERSE_SEQ = True 50 | self.DATA_DAVIS_REPEAT = 30 51 | self.DATA_CURR_SEQ_LEN = 5 52 | self.DATA_RANDOM_GAP_DAVIS = 3 53 | self.DATA_RANDOM_GAP_YTB = 3 54 | 55 | 56 | self.PRETRAIN = True 57 | self.PRETRAIN_FULL = False 58 | self.PRETRAIN_MODEL = '/datasets/MODELS/CFBI/resnet101-deeplabv3p.pth.tar' 59 | 60 | self.MODEL_BACKBONE = 'resnet' 61 | self.MODEL_MODULE = 'networks.aoc.aocnet' 62 | self.MODEL_OUTPUT_STRIDE = 16 63 | self.MODEL_ASPP_OUTDIM = 256 64 | self.MODEL_SHORTCUT_DIM = 48 65 | self.MODEL_SEMANTIC_EMBEDDING_DIM = 100 66 | self.MODEL_HEAD_EMBEDDING_DIM = 256 67 | self.MODEL_PRE_HEAD_EMBEDDING_DIM = 64 68 | self.MODEL_GN_GROUPS = 32 69 | self.MODEL_GN_EMB_GROUPS = 25 70 | self.MODEL_MULTI_LOCAL_DISTANCE = [2, 4, 6, 8, 10, 12] 71 | self.MODEL_LOCAL_DOWNSAMPLE = True 72 | self.MODEL_REFINE_CHANNELS = 64 # n * 32 73 | self.MODEL_LOW_LEVEL_INPLANES = 256 if self.MODEL_BACKBONE == 'resnet' else 24 74 | self.MODEL_RELATED_CHANNELS = 64 75 | self.MODEL_EPSILON = 1e-5 76 | self.MODEL_MATCHING_BACKGROUND = True 77 | self.MODEL_GCT_BETA_WD = True 78 | self.MODEL_FLOAT16_MATCHING = False 79 | self.MODEL_FREEZE_BN = True 80 | self.MODEL_FREEZE_BACKBONE = False 81 | 82 | self.TRAIN_TOTAL_STEPS = 400000 83 | self.TRAIN_START_STEP = 0 84 | self.TRAIN_LR = 0.01 85 | self.TRAIN_MOMENTUM = 0.9 86 | self.TRAIN_COSINE_DECAY = False 87 | self.TRAIN_WARM_UP_STEPS = 1000 88 | self.TRAIN_WEIGHT_DECAY = 15e-5 89 | self.TRAIN_POWER = 0.9 90 | self.TRAIN_GPUS = 8 91 | self.TRAIN_BATCH_SIZE = 8 92 | self.TRAIN_START_SEQ_TRAINING_STEPS = self.TRAIN_TOTAL_STEPS / 2 93 | self.TRAIN_TBLOG = False 94 | self.TRAIN_TBLOG_STEP = 60 95 | self.TRAIN_LOG_STEP = 20 96 | self.TRAIN_IMG_LOG = False 97 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15 98 | self.TRAIN_HARD_MINING_STEP = self.TRAIN_TOTAL_STEPS / 2 99 | self.TRAIN_CLIP_GRAD_NORM = 5. 100 | self.TRAIN_SAVE_STEP = 2000 101 | self.TRAIN_MAX_KEEP_CKPT = 8000 102 | self.TRAIN_RESUME = False 103 | self.TRAIN_RESUME_CKPT = None 104 | self.TRAIN_RESUME_STEP = 0 105 | self.TRAIN_AUTO_RESUME = True 106 | self.TRAIN_GLOBAL_ATROUS_RATE = 1 107 | self.TRAIN_LOCAL_ATROUS_RATE = 1 108 | self.TRAIN_LOCAL_PARALLEL = True 109 | self.TRAIN_GLOBAL_CHUNKS = 1 110 | self.TRAIN_DATASET_FULL_RESOLUTION = True 111 | 112 | 113 | self.TEST_GPU_ID = 0 114 | self.TEST_DATASET = 'youtubevos' 115 | self.TEST_DATASET_FULL_RESOLUTION = False 116 | self.TEST_DATASET_SPLIT = ['val'] 117 | self.TEST_CKPT_PATH = None 118 | self.TEST_CKPT_STEP = None # if "None", evaluate the latest checkpoint. 119 | self.TEST_FLIP = False 120 | self.TEST_MULTISCALE = [1] 121 | self.TEST_MIN_SIZE = None 122 | self.TEST_MAX_SIZE = 800 * 1.3 if self.TEST_MULTISCALE == [1.] 
else 800 123 | self.TEST_WORKERS = 4 124 | self.TEST_GLOBAL_CHUNKS = 4 125 | self.TEST_GLOBAL_ATROUS_RATE = 1 126 | self.TEST_LOCAL_ATROUS_RATE = 1 127 | self.TEST_LOCAL_PARALLEL = True 128 | 129 | # dist 130 | self.DIST_ENABLE = True 131 | self.DIST_BACKEND = "nccl" 132 | 133 | myname = socket.getfqdn(socket.gethostname( )) 134 | myaddr = socket.gethostbyname(myname) 135 | 136 | self.DIST_URL = "tcp://"+myaddr+":"+str(random.randint(30000,50000)) 137 | self.DIST_START_GPU = 0 138 | 139 | self.__check() 140 | 141 | def __check(self): 142 | if not torch.cuda.is_available(): 143 | raise ValueError('config.py: cuda is not avalable') 144 | if self.TRAIN_GPUS == 0: 145 | raise ValueError('config.py: the number of GPU is 0') 146 | for path in [self.DIR_RESULT, self.DIR_CKPT, self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, self.DIR_TB_LOG]: 147 | if not os.path.isdir(path): 148 | os.makedirs(path) 149 | 150 | 151 | 152 | cfg = Configuration() 153 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/dataloaders/__init__.py -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | cv2.setNumThreads(0) 8 | 9 | class Resize(object): 10 | """Rescale the image in a sample to a given size. 11 | 12 | Args: 13 | output_size (tuple or int): Desired output size. If tuple, output is 14 | matched to output_size. If int, smaller of image edges is matched 15 | to output_size keeping aspect ratio the same. 16 | """ 17 | 18 | def __init__(self, output_size): 19 | assert isinstance(output_size, (int, tuple)) 20 | if isinstance(output_size, int): 21 | self.output_size = (output_size, output_size) 22 | else: 23 | self.output_size = output_size 24 | 25 | def __call__(self, sample): 26 | prev_img = sample['prev_img'] 27 | h, w = prev_img.shape[:2] 28 | if self.output_size == (h, w): 29 | return sample 30 | else: 31 | new_h, new_w = self.output_size 32 | 33 | for elem in sample.keys(): 34 | if 'meta' in elem: 35 | continue 36 | tmp = sample[elem] 37 | 38 | if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img': 39 | flagval = cv2.INTER_CUBIC 40 | else: 41 | flagval = cv2.INTER_NEAREST 42 | 43 | if elem == 'curr_img' or elem == 'curr_label': 44 | new_tmp = [] 45 | all_tmp = tmp 46 | for tmp in all_tmp: 47 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), 48 | interpolation=flagval) 49 | new_tmp.append(tmp) 50 | tmp = new_tmp 51 | else: 52 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), 53 | interpolation=flagval) 54 | 55 | sample[elem] = tmp 56 | 57 | return sample 58 | 59 | class BalancedRandomCrop(object): 60 | """Crop randomly the image in a sample. 61 | 62 | Args: 63 | output_size (tuple or int): Desired output size. If int, square crop 64 | is made. 
65 | """ 66 | 67 | def __init__(self, output_size, max_step=5, max_obj_num=5, min_obj_pixel_num=100): 68 | assert isinstance(output_size, (int, tuple)) 69 | if isinstance(output_size, int): 70 | self.output_size = (output_size, output_size) 71 | else: 72 | assert len(output_size) == 2 73 | self.output_size = output_size 74 | self.max_step = max_step 75 | self.max_obj_num = max_obj_num 76 | self.min_obj_pixel_num = min_obj_pixel_num 77 | 78 | def __call__(self, sample): 79 | 80 | image = sample['prev_img'] 81 | h, w = image.shape[:2] 82 | new_h, new_w = self.output_size 83 | new_h = h if new_h >= h else new_h 84 | new_w = w if new_w >= w else new_w 85 | ref_label = sample["ref_label"] 86 | prev_label = sample["prev_label"] 87 | curr_label = sample["curr_label"] 88 | 89 | is_contain_obj = False 90 | step = 0 91 | while (not is_contain_obj) and (step < self.max_step): 92 | step += 1 93 | top = np.random.randint(0, h - new_h + 1) 94 | left = np.random.randint(0, w - new_w + 1) 95 | after_crop = [] 96 | contains = [] 97 | for elem in ([ref_label, prev_label] + curr_label): 98 | tmp = elem[top: top + new_h, left:left + new_w] 99 | contains.append(np.unique(tmp)) 100 | after_crop.append(tmp) 101 | 102 | 103 | all_obj = list(np.sort(contains[0])) 104 | 105 | if all_obj[-1] == 0: 106 | continue 107 | 108 | # remove background 109 | if all_obj[0] == 0: 110 | all_obj = all_obj[1:] 111 | # remove small obj 112 | new_all_obj = [] 113 | for obj_id in all_obj: 114 | after_crop_pixels = np.sum(after_crop[0] == obj_id) 115 | if after_crop_pixels > self.min_obj_pixel_num: 116 | new_all_obj.append(obj_id) 117 | 118 | if len(new_all_obj) == 0: 119 | is_contain_obj = False 120 | else: 121 | is_contain_obj = True 122 | 123 | if len(new_all_obj) > self.max_obj_num: 124 | random.shuffle(new_all_obj) 125 | new_all_obj = new_all_obj[:self.max_obj_num] 126 | 127 | all_obj = [0] + new_all_obj 128 | 129 | 130 | post_process = [] 131 | for elem in after_crop: 132 | new_elem = elem * 0 133 | for idx in range(len(all_obj)): 134 | obj_id = all_obj[idx] 135 | if obj_id == 0: 136 | continue 137 | mask = elem == obj_id 138 | 139 | new_elem += (mask * idx).astype(np.uint8) 140 | post_process.append(new_elem.astype(np.uint8)) 141 | 142 | sample["ref_label"] = post_process[0] 143 | sample["prev_label"] = post_process[1] 144 | curr_len = len(sample["curr_img"]) 145 | sample["curr_label"] = [] 146 | for idx in range(curr_len): 147 | sample["curr_label"].append(post_process[idx + 2]) 148 | 149 | for elem in sample.keys(): 150 | if 'meta' in elem or 'label' in elem: 151 | continue 152 | if elem == 'curr_img': 153 | new_tmp = [] 154 | for tmp_ in sample[elem]: 155 | tmp_ = tmp_[top: top + new_h, left:left + new_w] 156 | new_tmp.append(tmp_) 157 | sample[elem] = new_tmp 158 | else: 159 | tmp = sample[elem] 160 | tmp = tmp[top: top + new_h, left:left + new_w] 161 | sample[elem] = tmp 162 | 163 | obj_num = len(all_obj) - 1 164 | 165 | sample['meta']['obj_num'] = obj_num 166 | 167 | return sample 168 | 169 | 170 | class RandomScale(object): 171 | """Randomly resize the image and the ground truth to specified scales. 
172 | Args: 173 | scales (list): the list of scales 174 | """ 175 | 176 | def __init__(self, min_scale=1., max_scale=1.3, short_edge=None): 177 | self.min_scale = min_scale 178 | self.max_scale = max_scale 179 | self.short_edge = short_edge 180 | 181 | def __call__(self, sample): 182 | # Fixed range of scales 183 | sc = np.random.uniform(self.min_scale, self.max_scale) 184 | # Align short edge 185 | if not (self.short_edge is None): 186 | image = sample['prev_img'] 187 | h, w = image.shape[:2] 188 | if h > w: 189 | sc *= float(self.short_edge) / w 190 | else: 191 | sc *= float(self.short_edge) / h 192 | 193 | 194 | for elem in sample.keys(): 195 | if 'meta' in elem: 196 | continue 197 | tmp = sample[elem] 198 | 199 | if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img': 200 | flagval = cv2.INTER_CUBIC 201 | else: 202 | flagval = cv2.INTER_NEAREST 203 | 204 | if elem == 'curr_img' or elem == 'curr_label': 205 | new_tmp = [] 206 | for tmp_ in tmp: 207 | tmp_ = cv2.resize(tmp_, None, fx=sc, fy=sc, interpolation=flagval) 208 | new_tmp.append(tmp_) 209 | tmp = new_tmp 210 | else: 211 | tmp = cv2.resize(tmp, None, fx=sc, fy=sc, interpolation=flagval) 212 | 213 | sample[elem] = tmp 214 | 215 | return sample 216 | 217 | class RestrictSize(object): 218 | """Randomly resize the image and the ground truth to specified scales. 219 | Args: 220 | scales (list): the list of scales 221 | """ 222 | 223 | def __init__(self, min_size=None, max_size=800*1.3): 224 | self.min_size = min_size 225 | self.max_size = max_size 226 | assert ((min_size is None)) or ((max_size is None)) 227 | 228 | def __call__(self, sample): 229 | 230 | # Fixed range of scales 231 | sc = None 232 | image = sample['ref_img'] 233 | h, w = image.shape[:2] 234 | # Align short edge 235 | if not (self.min_size is None): 236 | if h > w: 237 | short_edge = w 238 | else: 239 | short_edge = h 240 | if short_edge < self.min_size: 241 | sc = float(self.min_size) / short_edge 242 | else: 243 | if h > w: 244 | long_edge = h 245 | else: 246 | long_edge = w 247 | if long_edge > self.max_size: 248 | sc = float(self.max_size) / long_edge 249 | 250 | if sc is None: 251 | new_h = h 252 | new_w = w 253 | else: 254 | new_h = int(sc * h) 255 | new_w = int(sc * w) 256 | new_h = new_h - (new_h - 1) % 4 257 | new_w = new_w - (new_w - 1) % 4 258 | if new_h == h and new_w == w: 259 | return sample 260 | 261 | 262 | for elem in sample.keys(): 263 | if 'meta' in elem: 264 | continue 265 | tmp = sample[elem] 266 | 267 | if 'label' in elem: 268 | flagval = cv2.INTER_NEAREST 269 | else: 270 | flagval = cv2.INTER_CUBIC 271 | 272 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), interpolation=flagval) 273 | 274 | sample[elem] = tmp 275 | 276 | return sample 277 | 278 | 279 | class RandomHorizontalFlip(object): 280 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5.""" 281 | 282 | def __init__(self, prob): 283 | self.p = prob 284 | 285 | def __call__(self, sample): 286 | 287 | if random.random() < self.p: 288 | for elem in sample.keys(): 289 | if 'meta' in elem: 290 | continue 291 | if elem == 'curr_img' or elem == 'curr_label': 292 | new_tmp = [] 293 | for tmp_ in sample[elem]: 294 | tmp_ = cv2.flip(tmp_, flipCode=1) 295 | new_tmp.append(tmp_) 296 | sample[elem] = new_tmp 297 | else: 298 | tmp = sample[elem] 299 | tmp = cv2.flip(tmp, flipCode=1) 300 | sample[elem] = tmp 301 | 302 | return sample 303 | 304 | 305 | class RandomGaussianBlur(object): 306 | 307 | def __init__(self, prob=0.2): 308 | self.p = prob 309 | 310 | 
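    # Note: each image entry in the sample is blurred independently with probability
    # `prob`, using a 9x9 Gaussian kernel whose sigma is drawn uniformly from [0.1, 2.0);
    # label maps and 'meta' entries are never blurred. For example, with the default
    # prob=0.2 roughly one frame in five is blurred.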
def __call__(self, sample): 311 | 312 | 313 | for elem in sample.keys(): 314 | if 'meta' in elem or 'label' in elem: 315 | continue 316 | 317 | if elem == 'curr_img': 318 | new_tmp = [] 319 | for tmp_ in sample[elem]: 320 | if random.random() < self.p: 321 | std = random.random() * 1.9 + 0.1 # [0.1, 2] 322 | tmp_ = cv2.GaussianBlur(tmp_, (9, 9), sigmaX=std, sigmaY=std) 323 | new_tmp.append(tmp_) 324 | sample[elem] = new_tmp 325 | else: 326 | tmp = sample[elem] 327 | if random.random() < self.p: 328 | std = random.random() * 1.9 + 0.1 # [0.1, 2] 329 | tmp = cv2.GaussianBlur(tmp, (9, 9), sigmaX=std, sigmaY=std) 330 | sample[elem] = tmp 331 | 332 | return sample 333 | 334 | class SubtractMeanImage(object): 335 | def __init__(self, mean, change_channels=False): 336 | self.mean = mean 337 | self.change_channels = change_channels 338 | 339 | def __call__(self, sample): 340 | for elem in sample.keys(): 341 | if 'image' in elem: 342 | if self.change_channels: 343 | sample[elem] = sample[elem][:, :, [2, 1, 0]] 344 | sample[elem] = np.subtract( 345 | sample[elem], np.array(self.mean, dtype=np.float32)) 346 | return sample 347 | 348 | def __str__(self): 349 | return 'SubtractMeanImage' + str(self.mean) 350 | 351 | 352 | class ToTensor(object): 353 | """Convert ndarrays in sample to Tensors.""" 354 | 355 | def __call__(self, sample): 356 | 357 | for elem in sample.keys(): 358 | if 'meta' in elem: 359 | continue 360 | tmp = sample[elem] 361 | 362 | if elem == 'curr_img' or elem == 'curr_label': 363 | new_tmp = [] 364 | for tmp_ in tmp: 365 | if tmp_.ndim == 2: 366 | tmp_ = tmp_[:, :, np.newaxis] 367 | else: 368 | tmp_ = tmp_ / 255. 369 | tmp_ -= (0.485, 0.456, 0.406) 370 | tmp_ /= (0.229, 0.224, 0.225) 371 | tmp_ = tmp_.transpose((2, 0, 1)) 372 | new_tmp.append(torch.from_numpy(tmp_)) 373 | tmp = new_tmp 374 | else: 375 | if tmp.ndim == 2: 376 | tmp = tmp[:, :, np.newaxis] 377 | else: 378 | tmp = tmp / 255. 379 | tmp -= (0.485, 0.456, 0.406) 380 | tmp /= (0.229, 0.224, 0.225) 381 | tmp = tmp.transpose((2, 0, 1)) 382 | tmp = torch.from_numpy(tmp) 383 | sample[elem] = tmp 384 | 385 | return sample 386 | 387 | class MultiRestrictSize(object): 388 | def __init__(self, min_size=None, max_size=800, flip=False, multi_scale=[1.3]): 389 | self.min_size = min_size 390 | self.max_size = max_size 391 | self.multi_scale = multi_scale 392 | self.flip = flip 393 | assert ((min_size is None)) or ((max_size is None)) 394 | 395 | def __call__(self, sample): 396 | samples = [] 397 | image = sample['current_img'] 398 | h, w = image.shape[:2] 399 | for scale in self.multi_scale: 400 | # Fixed range of scales 401 | sc = None 402 | # Align short edge 403 | if not (self.min_size is None): 404 | if h > w: 405 | short_edge = w 406 | else: 407 | short_edge = h 408 | if short_edge > self.min_size: 409 | sc = float(self.min_size) / short_edge 410 | else: 411 | if h > w: 412 | long_edge = h 413 | else: 414 | long_edge = w 415 | if long_edge > self.max_size: 416 | sc = float(self.max_size) / long_edge 417 | 418 | if sc is None: 419 | new_h = h 420 | new_w = w 421 | else: 422 | new_h = sc * h 423 | new_w = sc * w 424 | new_h = int(new_h * scale) 425 | new_w = int(new_w * scale) 426 | 427 | if (new_h - 1) % 16 != 0: 428 | new_h = int(np.around((new_h - 1) / 16.) * 16 + 1) 429 | if (new_w - 1) % 16 != 0: 430 | new_w = int(np.around((new_w - 1) / 16.) 
* 16 + 1) 431 | 432 | if new_h == h and new_w == w: 433 | samples.append(sample) 434 | else: 435 | new_sample = {} 436 | for elem in sample.keys(): 437 | if 'meta' in elem: 438 | new_sample[elem] = sample[elem] 439 | continue 440 | tmp = sample[elem] 441 | if 'label' in elem: 442 | new_sample[elem] = sample[elem] 443 | continue 444 | else: 445 | flagval = cv2.INTER_CUBIC 446 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), interpolation=flagval) 447 | new_sample[elem] = tmp 448 | samples.append(new_sample) 449 | 450 | if self.flip: 451 | now_sample = samples[-1] 452 | new_sample = {} 453 | for elem in now_sample.keys(): 454 | if 'meta' in elem: 455 | new_sample[elem] = now_sample[elem].copy() 456 | new_sample[elem]['flip'] = True 457 | continue 458 | tmp = now_sample[elem] 459 | tmp = tmp[:, ::-1].copy() 460 | new_sample[elem] = tmp 461 | samples.append(new_sample) 462 | 463 | return samples 464 | 465 | class MultiToTensor(object): 466 | def __call__(self, samples): 467 | for idx in range(len(samples)): 468 | sample = samples[idx] 469 | for elem in sample.keys(): 470 | if 'meta' in elem: 471 | continue 472 | tmp = sample[elem] 473 | if tmp is None: 474 | continue 475 | 476 | if tmp.ndim == 2: 477 | tmp = tmp[:, :, np.newaxis] 478 | else: 479 | tmp = tmp / 255. 480 | tmp -= (0.485, 0.456, 0.406) 481 | tmp /= (0.229, 0.224, 0.225) 482 | 483 | tmp = tmp.transpose((2, 0, 1)) 484 | samples[idx][elem] = torch.from_numpy(tmp) 485 | 486 | return samples 487 | 488 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/__init__.py -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/aoc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/aoc/__init__.py -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/aoc/aocnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from networks.layers.loss import Concat_CrossEntropyLoss 6 | from networks.layers.matching import global_matching, global_matching_for_eval, local_matching, foreground2background, global_matching_proxy,local_matching_proxy,global_matching_cluster2, global_matching_cluster,global_matching_for_eval_cluster, global_matching_for_eval_proxy 7 | from networks.layers.attention import calculate_attention_head, calculate_attention_head_for_eval, calculate_attention_head_p_m, calculate_attention_head_for_eval_p_m 8 | from networks.p2t.decoding_module import CalibrationDecoding,DynamicPreHead 9 | 10 | 11 | class AOCNet(nn.Module): 12 | def __init__(self, cfg, feature_extracter): 13 | super(AOCNet, self).__init__() 14 | self.cfg = cfg 15 | self.epsilon = cfg.MODEL_EPSILON 16 | 17 | self.feature_extracter=feature_extracter 18 | 19 | self.seperate_conv = nn.Conv2d(cfg.MODEL_ASPP_OUTDIM, cfg.MODEL_ASPP_OUTDIM, kernel_size=3, stride=1, padding=1, 
groups=cfg.MODEL_ASPP_OUTDIM) 20 | self.bn1 = nn.GroupNorm(cfg.MODEL_GN_GROUPS, cfg.MODEL_ASPP_OUTDIM) 21 | self.relu1 = nn.ReLU(True) 22 | self.embedding_conv = nn.Conv2d(cfg.MODEL_ASPP_OUTDIM, cfg.MODEL_SEMANTIC_EMBEDDING_DIM, 1, 1) 23 | self.bn2 = nn.GroupNorm(cfg.MODEL_GN_EMB_GROUPS, cfg.MODEL_SEMANTIC_EMBEDDING_DIM) 24 | self.relu2 = nn.ReLU(True) 25 | self.semantic_embedding=nn.Sequential(*[self.seperate_conv, self.bn1, self.relu1, self.embedding_conv, self.bn2, self.relu2]) 26 | 27 | self.bg_bias = nn.Parameter(torch.zeros(1, 1, 1, 1)) 28 | self.fg_bias = nn.Parameter(torch.zeros(1, 1, 1, 1)) 29 | 30 | self.criterion = Concat_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP) 31 | 32 | for m in self.semantic_embedding: 33 | if isinstance(m, nn.Conv2d): 34 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 35 | 36 | self.dynamic_seghead = CalibrationDecoding( 37 | in_dim=cfg.MODEL_SEMANTIC_EMBEDDING_DIM + cfg.MODEL_PRE_HEAD_EMBEDDING_DIM, 38 | attention_dim=cfg.MODEL_SEMANTIC_EMBEDDING_DIM * 4, 39 | embed_dim=cfg.MODEL_HEAD_EMBEDDING_DIM, 40 | refine_dim=cfg.MODEL_REFINE_CHANNELS, 41 | low_level_dim=cfg.MODEL_LOW_LEVEL_INPLANES) 42 | 43 | in_dim = 2*(2 + len(cfg.MODEL_MULTI_LOCAL_DISTANCE))-1+2 44 | 45 | if cfg.MODEL_MATCHING_BACKGROUND: 46 | in_dim += 1 +len(cfg.MODEL_MULTI_LOCAL_DISTANCE) 47 | 48 | 49 | self.dynamic_prehead = DynamicPreHead( 50 | in_dim=in_dim, 51 | embed_dim=cfg.MODEL_PRE_HEAD_EMBEDDING_DIM) 52 | 53 | 54 | def forward(self, input,memory_prev_list, ref_frame_label, previous_frame_mask, current_frame_mask, 55 | gt_ids, step=0, tf_board=False): 56 | x, low_level = self.extract_feature(input) 57 | ref_frame_embedding, previous_frame_embedding, current_frame_embedding = torch.split(x, split_size_or_sections=int(x.size(0)/3), dim=0) 58 | _, _, current_low_level = torch.split(low_level, split_size_or_sections=int(x.size(0)/3), dim=0) 59 | bs, c, h, w = current_frame_embedding.size() 60 | tmp_dic, boards, memory_cur_list= self.before_seghead_process( 61 | memory_prev_list, 62 | ref_frame_embedding, 63 | previous_frame_embedding, 64 | current_frame_embedding, 65 | ref_frame_label, 66 | previous_frame_mask, 67 | gt_ids, 68 | current_low_level=current_low_level,tf_board=tf_board) 69 | label_dic=[] 70 | all_pred = [] 71 | for i in range(bs): 72 | tmp_pred_logits = tmp_dic[i] 73 | tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits, size=(input.shape[2],input.shape[3]), mode='bilinear', align_corners=True) 74 | tmp_dic[i] = tmp_pred_logits 75 | label_tmp, obj_num = current_frame_mask[i], gt_ids[i] 76 | label_dic.append(label_tmp.long()) 77 | pred = tmp_pred_logits 78 | preds_s = torch.argmax(pred,dim=1) 79 | all_pred.append(preds_s) 80 | all_pred = torch.cat(all_pred, dim=0) 81 | 82 | return self.criterion(tmp_dic, label_dic, step), all_pred, boards, memory_cur_list 83 | 84 | def forward_for_eval(self,memory_prev_list, ref_embeddings, ref_masks, prev_embedding, prev_mask, current_frame, pred_size, gt_ids): 85 | current_frame_embedding, current_low_level = self.extract_feature(current_frame) 86 | if prev_embedding is None: 87 | return None, current_frame_embedding,memory_prev_list 88 | else: 89 | bs,c,h,w = current_frame_embedding.size() 90 | tmp_dic, _ ,memory_cur_list= self.before_seghead_process( 91 | memory_prev_list, 92 | ref_embeddings, 93 | prev_embedding, 94 | current_frame_embedding, 95 | ref_masks, 96 | prev_mask, 97 | gt_ids, 98 | current_low_level=current_low_level, 99 | tf_board=False) 100 | all_pred = [] 101 
| for i in range(bs): 102 | pred = tmp_dic[i] 103 | pred = nn.functional.interpolate(pred, size=(pred_size[0],pred_size[1]), mode='bilinear',align_corners=True) 104 | all_pred.append(pred) 105 | all_pred = torch.cat(all_pred, dim=0) 106 | all_pred = torch.softmax(all_pred, dim=1) 107 | return all_pred, current_frame_embedding, memory_cur_list 108 | 109 | def extract_feature(self, x): 110 | x, low_level=self.feature_extracter(x) 111 | x = self.semantic_embedding(x) 112 | return x, low_level 113 | 114 | def before_seghead_process(self, memory_prev_list=None, 115 | ref_frame_embedding=None, previous_frame_embedding=None, current_frame_embedding=None, 116 | ref_frame_label=None, previous_frame_mask=None, 117 | gt_ids=None, current_low_level=None, tf_board=False): 118 | 119 | cfg = self.cfg 120 | 121 | dic_tmp=[] 122 | bs,c,h,w = current_frame_embedding.size() 123 | 124 | if self.training: 125 | scale_ref_frame_label = torch.nn.functional.interpolate(ref_frame_label.float(),size=(h,w),mode='nearest') 126 | scale_ref_frame_label = scale_ref_frame_label.int() 127 | else: 128 | scale_ref_frame_labels = [] 129 | for each_ref_frame_label in ref_frame_label: 130 | each_scale_ref_frame_label = torch.nn.functional.interpolate(each_ref_frame_label.float(),size=(h,w),mode='nearest') 131 | each_scale_ref_frame_label = each_scale_ref_frame_label.int() 132 | scale_ref_frame_labels.append(each_scale_ref_frame_label) 133 | 134 | scale_previous_frame_label=torch.nn.functional.interpolate(previous_frame_mask.float(),size=(h,w),mode='nearest') 135 | scale_previous_frame_label=scale_previous_frame_label.int() 136 | 137 | boards = {'image': {}, 'scalar': {}} 138 | memory_cur_list = [] 139 | 140 | for n in range(bs): 141 | ref_obj_ids = torch.arange(0, gt_ids[n] + 1, device=current_frame_embedding.device).int().view(-1, 1, 1, 1) 142 | obj_num = ref_obj_ids.size(0) 143 | if gt_ids[n] > 0: 144 | dis_bias = torch.cat([self.bg_bias, self.fg_bias.expand(gt_ids[n], -1, -1, -1)], dim=0) 145 | else: 146 | dis_bias = self.bg_bias 147 | 148 | seq_current_frame_embedding = current_frame_embedding[n] 149 | seq_current_frame_embedding = seq_current_frame_embedding.permute(1,2,0) 150 | 151 | 152 | seq_prev_frame_embedding = previous_frame_embedding[n] 153 | seq_prev_frame_embedding = seq_prev_frame_embedding.permute(1,2,0) 154 | seq_previous_frame_label = (scale_previous_frame_label[n].int() == ref_obj_ids).float() 155 | to_cat_previous_frame = seq_previous_frame_label 156 | seq_previous_frame_label = seq_previous_frame_label.squeeze(1).permute(1,2,0) 157 | to_cat_current_frame_embedding = current_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)) 158 | to_cat_prev_frame_embedding = previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)) 159 | 160 | 161 | ## start global matching 162 | if self.training: 163 | seq_ref_frame_embedding = ref_frame_embedding[n] 164 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0) 165 | 166 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float() 167 | to_cat_ref_frame = seq_ref_frame_label 168 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0) 169 | 170 | global_matching_fg = global_matching( 171 | reference_embeddings=seq_ref_frame_embedding, 172 | query_embeddings=seq_current_frame_embedding, 173 | reference_labels=seq_ref_frame_label, 174 | n_chunks=cfg.TRAIN_GLOBAL_CHUNKS, 175 | dis_bias=dis_bias, 176 | atrous_rate=cfg.TRAIN_GLOBAL_ATROUS_RATE, 177 | use_float16=cfg.MODEL_FLOAT16_MATCHING) 178 | 179 | else: 180 | 
all_reference_embeddings = [] 181 | all_reference_labels = [] 182 | seq_ref_frame_labels = [] 183 | for idx in range(len(scale_ref_frame_labels)): 184 | each_ref_frame_embedding = ref_frame_embedding[idx] 185 | scale_ref_frame_label = scale_ref_frame_labels[idx] 186 | 187 | seq_ref_frame_embedding = each_ref_frame_embedding[n] 188 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0) 189 | all_reference_embeddings.append(seq_ref_frame_embedding) 190 | 191 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float() 192 | seq_ref_frame_labels.append(seq_ref_frame_label) 193 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0) 194 | all_reference_labels.append(seq_ref_frame_label) 195 | 196 | global_matching_fg = global_matching_for_eval( 197 | all_reference_embeddings=all_reference_embeddings, 198 | query_embeddings=seq_current_frame_embedding, 199 | all_reference_labels=all_reference_labels, 200 | n_chunks=cfg.TEST_GLOBAL_CHUNKS, 201 | dis_bias=dis_bias, 202 | atrous_rate=cfg.TEST_GLOBAL_ATROUS_RATE, 203 | use_float16=cfg.MODEL_FLOAT16_MATCHING) 204 | 205 | ## end global matching 206 | 207 | ## start global matching cluster (defualt num: 16) 208 | if self.training: 209 | seq_ref_frame_embedding = ref_frame_embedding[n] 210 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0) 211 | 212 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float() 213 | to_cat_ref_frame = seq_ref_frame_label 214 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0) 215 | 216 | global_matching_fg_cluster = global_matching_cluster2( 217 | reference_embeddings=seq_ref_frame_embedding, 218 | query_embeddings=seq_current_frame_embedding, 219 | reference_labels=seq_ref_frame_label, 220 | n_chunks=cfg.TRAIN_GLOBAL_CHUNKS, 221 | dis_bias=dis_bias, 222 | atrous_rate=cfg.TRAIN_GLOBAL_ATROUS_RATE, 223 | use_float16=cfg.MODEL_FLOAT16_MATCHING) 224 | 225 | else: 226 | all_reference_embeddings = [] 227 | all_reference_labels = [] 228 | seq_ref_frame_labels = [] 229 | for idx in range(len(scale_ref_frame_labels)): 230 | each_ref_frame_embedding = ref_frame_embedding[idx] 231 | scale_ref_frame_label = scale_ref_frame_labels[idx] 232 | 233 | seq_ref_frame_embedding = each_ref_frame_embedding[n] 234 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0) 235 | all_reference_embeddings.append(seq_ref_frame_embedding) 236 | 237 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float() 238 | seq_ref_frame_labels.append(seq_ref_frame_label) 239 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0) 240 | all_reference_labels.append(seq_ref_frame_label) 241 | 242 | global_matching_fg_cluster = global_matching_for_eval_cluster( 243 | all_reference_embeddings=all_reference_embeddings, 244 | query_embeddings=seq_current_frame_embedding, 245 | all_reference_labels=all_reference_labels, 246 | n_chunks=cfg.TEST_GLOBAL_CHUNKS, 247 | dis_bias=dis_bias, 248 | atrous_rate=cfg.TEST_GLOBAL_ATROUS_RATE, 249 | use_float16=cfg.MODEL_FLOAT16_MATCHING) 250 | 251 | ## end global matching cluster (defualt num: 16) 252 | 253 | ## start local matching 254 | 255 | local_matching_fg = local_matching( 256 | prev_frame_embedding=seq_prev_frame_embedding, 257 | query_embedding=seq_current_frame_embedding, 258 | prev_frame_labels=seq_previous_frame_label, 259 | multi_local_distance=cfg.MODEL_MULTI_LOCAL_DISTANCE, 260 | dis_bias=dis_bias, 261 | use_float16=cfg.MODEL_FLOAT16_MATCHING, 262 | 
atrous_rate=cfg.TRAIN_LOCAL_ATROUS_RATE if self.training else cfg.TEST_LOCAL_ATROUS_RATE, 263 | allow_downsample=cfg.MODEL_LOCAL_DOWNSAMPLE, 264 | allow_parallel=cfg.TRAIN_LOCAL_PARALLEL if self.training else cfg.TEST_LOCAL_PARALLEL) 265 | 266 | ## start end matching 267 | 268 | 269 | 270 | ## start refereced instance-level object representation calculation (proxy indicates cluster_num = 1) 271 | if self.training: 272 | attention_head = calculate_attention_head( 273 | ref_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)), 274 | to_cat_ref_frame, 275 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)), 276 | to_cat_previous_frame, 277 | epsilon=self.epsilon) 278 | else: 279 | attention_head = calculate_attention_head_for_eval( 280 | ref_frame_embedding, 281 | seq_ref_frame_labels, 282 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)), 283 | to_cat_previous_frame, 284 | epsilon=self.epsilon) 285 | ## end refereced instance-level object representation calculation (proxy indicates cluster_num = 1) 286 | 287 | 288 | ## start proxy-based matching 289 | if self.training: 290 | attention_head,ref_head_pos, ref_head_neg, prev_head_pos, prev_head_neg= calculate_attention_head_p_m( 291 | ref_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)), 292 | to_cat_ref_frame, 293 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)), 294 | to_cat_previous_frame, 295 | epsilon=self.epsilon) 296 | else: 297 | attention_head,ref_head_pos, ref_head_neg, prev_head_pos, prev_head_neg = calculate_attention_head_for_eval_p_m( 298 | ref_frame_embedding, 299 | seq_ref_frame_labels, 300 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)), 301 | to_cat_previous_frame, 302 | epsilon=self.epsilon) 303 | 304 | if self.training: 305 | global_matching_fg_proxy = global_matching_proxy( 306 | reference_embeddings=ref_head_pos, 307 | query_embeddings=seq_current_frame_embedding, 308 | reference_labels=seq_ref_frame_label, 309 | n_chunks=cfg.TRAIN_GLOBAL_CHUNKS, 310 | dis_bias=dis_bias, 311 | atrous_rate=cfg.TRAIN_GLOBAL_ATROUS_RATE, 312 | use_float16=cfg.MODEL_FLOAT16_MATCHING) 313 | else: 314 | global_matching_fg_proxy = global_matching_for_eval_proxy( 315 | all_reference_embeddings=ref_head_pos, 316 | query_embeddings=seq_current_frame_embedding, 317 | all_reference_labels=all_reference_labels, 318 | n_chunks=cfg.TEST_GLOBAL_CHUNKS, 319 | dis_bias=dis_bias, 320 | atrous_rate=cfg.TEST_GLOBAL_ATROUS_RATE, 321 | use_float16=cfg.MODEL_FLOAT16_MATCHING) 322 | 323 | 324 | 325 | seq_prev_frame_embedding_inst = torch.matmul(seq_previous_frame_label, prev_head_pos) 326 | 327 | 328 | local_matching_fg_proxy = local_matching_proxy( 329 | prev_frame_embedding=seq_prev_frame_embedding_inst, 330 | query_embedding=seq_current_frame_embedding, 331 | prev_frame_labels=seq_previous_frame_label, 332 | multi_local_distance=cfg.MODEL_MULTI_LOCAL_DISTANCE, 333 | dis_bias=dis_bias, 334 | use_float16=cfg.MODEL_FLOAT16_MATCHING, 335 | atrous_rate=cfg.TRAIN_LOCAL_ATROUS_RATE if self.training else cfg.TEST_LOCAL_ATROUS_RATE, 336 | allow_downsample=cfg.MODEL_LOCAL_DOWNSAMPLE, 337 | allow_parallel=cfg.TRAIN_LOCAL_PARALLEL if self.training else cfg.TEST_LOCAL_PARALLEL) 338 | 339 | 340 | 341 | to_cat_global_matching_fg_proxy = global_matching_fg_proxy.squeeze(0).permute(2,3,0,1) 342 | to_cat_global_matching_fg_cluster = global_matching_fg_cluster.squeeze(0).permute(2,3,0,1) 343 | to_cat_global_matching_fg = global_matching_fg.squeeze(0).permute(2,3,0,1) 344 | 
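    # Note: the matching outputs reshaped here appear to be CFBI-style maps of shape
    # (1, H, W, obj_num, C); squeeze(0) + permute(2, 3, 0, 1) turns each of them into
    # (obj_num, C, H, W) so they can be concatenated channel-wise per object and passed
    # to the dynamic pre-head.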
to_cat_local_matching_fg_proxy = local_matching_fg_proxy.squeeze(0).permute(2,3,0,1) 345 | to_cat_local_matching_fg = local_matching_fg.squeeze(0).permute(2,3,0,1) 346 | 347 | 348 | 349 | if cfg.MODEL_MATCHING_BACKGROUND: 350 | to_cat_global_matching_bg = foreground2background(to_cat_global_matching_fg, gt_ids[n] + 1) 351 | reshaped_prev_nn_feature_n = to_cat_local_matching_fg.permute(0, 2, 3, 1).unsqueeze(1) 352 | to_cat_local_matching_bg = foreground2background(reshaped_prev_nn_feature_n, gt_ids[n] + 1) 353 | to_cat_local_matching_bg = to_cat_local_matching_bg.permute(0, 4, 2, 3, 1).squeeze(-1) 354 | 355 | pre_to_cat = torch.cat((to_cat_global_matching_fg, to_cat_global_matching_fg_cluster, to_cat_global_matching_fg_proxy,to_cat_local_matching_fg, to_cat_local_matching_fg_proxy, to_cat_previous_frame), 1) 356 | 357 | if cfg.MODEL_MATCHING_BACKGROUND: 358 | pre_to_cat = torch.cat([pre_to_cat, to_cat_local_matching_bg, to_cat_global_matching_bg], 1) 359 | 360 | pre_to_cat = self.dynamic_prehead(pre_to_cat) 361 | 362 | to_cat = torch.cat((to_cat_current_frame_embedding, pre_to_cat),1) 363 | 364 | 365 | low_level_feat = current_low_level[n].unsqueeze(0) 366 | 367 | pred,memory_tmp_list = self.dynamic_seghead(to_cat, attention_head, memory_prev_list[n], low_level_feat,to_cat_previous_frame) 368 | memory_cur_list.append(memory_tmp_list) 369 | dic_tmp.append(pred) 370 | 371 | 372 | return dic_tmp, boards, memory_cur_list 373 | 374 | def get_module(): 375 | return AOCNet 376 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/aoc/conditioning_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class conditioning_layer(nn.Module): 7 | 8 | # Equation (7) of the main paper 9 | 10 | def __init__(self, 11 | in_dim=256, 12 | beta_percentage=0.3): 13 | super(conditioning_layer,self).__init__() 14 | 15 | self.beta_percentage = beta_percentage 16 | 17 | kernel_size = 1 18 | self.phi_layer = nn.Conv2d(in_dim,1,kernel_size=kernel_size,stride=1,padding=int((kernel_size-1)/2)) 19 | self.mlp_layer = nn.Linear(in_dim, in_dim) 20 | 21 | nn.init.kaiming_normal_(self.phi_layer.weight,mode='fan_out',nonlinearity='relu') 22 | 23 | 24 | def forward(self, z_in): 25 | 26 | # Step 1: phi(z_in) 27 | x = self.phi_layer(z_in) 28 | 29 | # Step 2: beta 30 | x = x.reshape(x.size()[0],x.size()[1],-1) 31 | z_in_reshape = z_in.reshape(z_in.size()[0],z_in.size()[1],-1) 32 | beta_rank = int(self.beta_percentage*z_in.size()[-1]*z_in.size()[-2]) 33 | beta_val, _ = torch.topk(x, k=beta_rank, dim=-1, sorted=True) 34 | 35 | # Step 3: pi_beta(phi(z_in)) 36 | x = x > beta_val[...,-1,None] 37 | 38 | # Step 4: z_in \odot pi_beta(phi(z_in)) 39 | z_in_masked = z_in_reshape * x 40 | 41 | # Step 5: GAP (z_in \odot pi_beta(phi(z_in))) 42 | z_in_masked_gap = torch.nn.functional.avg_pool1d(z_in_masked, 43 | kernel_size=z_in_masked.size()[-1]).squeeze(-1) 44 | 45 | # Step 6: MLP ( GAP (z_in \odot pi_beta(phi(z_in))) ) 46 | out = self.mlp_layer(z_in_masked_gap) 47 | 48 | return out 49 | 50 | class conditioning_block(nn.Module): 51 | 52 | # Equation (5) of the main paper 53 | 54 | def __init__(self, 55 | in_dim=256, 56 | proxy_dim = 400, 57 | beta_percentage=0.3): 58 | super(conditioning_block,self).__init__() 59 | 60 | self.CL_1 = conditioning_layer(in_dim, beta_percentage) 61 | self.CL_2 = conditioning_layer(in_dim, beta_percentage) 62 | self.CL_3 =
conditioning_layer(proxy_dim, 1) 63 | 64 | self.mlp_layer = nn.Linear(in_dim * 2 + proxy_dim, in_dim) 65 | 66 | def forward(self, x, proxy_IA_head): 67 | 68 | px1 = torch.nn.functional.avg_pool2d(x,kernel_size=(x.size()[-2],x.size()[-1]),padding = 0) 69 | x_delta = (torch.sum(px1,dim=0,keepdim=True)-px1).squeeze(-1).squeeze(-1) 70 | 71 | # Step 1: calculate the intra-object conditioning code 72 | cl_out_1 = self.CL_1(x) 73 | 74 | # Step 2: calculate the inter-object conditioning code 75 | cl_out_2 = self.CL_2(x_delta) 76 | 77 | # Step 3: calculate the conditioning code with proxies 78 | cl_out_3 = self.CL_3(proxy_IA_head) 79 | 80 | # Step 4: conduct calibration 81 | a = self.mlp_layer(torch.cat([cl_out_1,cl_out_2,cl_out_3],dim=1)) 82 | a = 1. + torch.tanh(a) 83 | a = a.unsqueeze(-1).unsqueeze(-1) 84 | x = a * x 85 | 86 | return x 87 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/aoc/decoding_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.layers.attention import IA_gate 5 | from networks.layers.gct import Bottleneck, GCT 6 | from networks.layers.aspp import ASPP 7 | from networks.aoc.conditioning_layer import conditioning_layer, conditioning_block 8 | 9 | 10 | class CalibrationDecoding(nn.Module): 11 | def __init__(self, 12 | in_dim=256, 13 | attention_dim=400, 14 | embed_dim=100, 15 | refine_dim=48, 16 | low_level_dim=256, 17 | beta_percentage=0.3): 18 | super(CalibrationDecoding,self).__init__() 19 | self.embed_dim = embed_dim 20 | IA_in_dim = attention_dim 21 | self.beta_percentage = beta_percentage 22 | self.IA1 = IA_gate(IA_in_dim, in_dim) 23 | self.layer1 = Bottleneck(in_dim, embed_dim) 24 | 25 | 26 | self.layer2 = Bottleneck(embed_dim, embed_dim, 1, 2) 27 | self.CLB2 = conditioning_block( 28 | in_dim=embed_dim, 29 | proxy_dim=IA_in_dim, 30 | beta_percentage=self.beta_percentage) 31 | 32 | 33 | self.layer3 = Bottleneck(embed_dim, embed_dim * 2, 2) 34 | self.CLB3 = conditioning_block( 35 | in_dim=embed_dim, 36 | proxy_dim=IA_in_dim, 37 | beta_percentage=self.beta_percentage) 38 | 39 | 40 | self.layer4 = Bottleneck(embed_dim * 2, embed_dim * 2, 1, 2) 41 | self.CLB4 = conditioning_block( 42 | in_dim=embed_dim*2, 43 | proxy_dim=IA_in_dim, 44 | beta_percentage=self.beta_percentage) 45 | 46 | self.layer5 = Bottleneck(embed_dim * 2, embed_dim * 2, 1, 4) 47 | self.CLB5 = conditioning_block( 48 | in_dim=embed_dim*2, 49 | proxy_dim=IA_in_dim, 50 | beta_percentage=self.beta_percentage) 51 | 52 | self.IA9 = IA_gate(IA_in_dim+embed_dim * 2, embed_dim * 2) 53 | self.ASPP = ASPP() 54 | 55 | self.M1_Reweight_Layer_1 = IA_gate(IA_in_dim, embed_dim * 2) 56 | self.M1_Bottleneck_1 = Bottleneck(embed_dim*2, embed_dim * 2, 1) 57 | 58 | self.M1_Reweight_Layer_2 = IA_gate(IA_in_dim, embed_dim * 2) 59 | self.M1_Bottleneck_2 = Bottleneck(embed_dim*2, embed_dim * 1, 1) 60 | 61 | self.M1_Reweight_Layer_3 = IA_gate(IA_in_dim, embed_dim * 1) 62 | self.M1_Bottleneck_3 = Bottleneck(embed_dim*1, embed_dim * 1, 1) 63 | 64 | self.M2_Reweight_Layer_1 = IA_gate(IA_in_dim, embed_dim * 2) 65 | self.M2_Bottleneck_1 = Bottleneck(embed_dim*2, embed_dim * 2, 1) 66 | 67 | self.M2_Reweight_Layer_2 = IA_gate(IA_in_dim, embed_dim * 2) 68 | self.M2_Bottleneck_2 = Bottleneck(embed_dim*2, embed_dim * 1, 1) 69 | 70 | self.M2_Reweight_Layer_3 = IA_gate(IA_in_dim, embed_dim *1) 71 | self.M2_Bottleneck_3 = Bottleneck(embed_dim*1, embed_dim * 1, 1) 72 |
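    # Note: the M1_*/M2_* layers above form the two memory modulators applied in
    # forward() (Modulator_1 / Modulator_2); the GCT and convolution layers defined
    # below build the low-level-feature refinement head used in decoder_final().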
73 | 74 | self.GCT_sc = GCT(low_level_dim + embed_dim) 75 | self.conv_sc = nn.Conv2d(low_level_dim + embed_dim, refine_dim, 1, bias=False) 76 | self.bn_sc = nn.GroupNorm(int(refine_dim / 4), refine_dim) 77 | self.relu = nn.ReLU(inplace=True) 78 | 79 | self.IA10 = IA_gate(IA_in_dim+embed_dim + refine_dim, embed_dim + refine_dim) 80 | self.conv1 = nn.Conv2d(embed_dim + refine_dim, int(embed_dim / 2), kernel_size=3, padding=1, bias=False) 81 | self.bn1 = nn.GroupNorm(32, int(embed_dim / 2)) 82 | 83 | 84 | self.IA11 = IA_gate(IA_in_dim+int(embed_dim / 2), int(embed_dim / 2)) 85 | self.conv2 = nn.Conv2d(int(embed_dim / 2), int(embed_dim / 2), kernel_size=3, padding=1, bias=False) 86 | self.bn2 = nn.GroupNorm(32, int(embed_dim / 2)) 87 | 88 | self.IA_final_fg = nn.Linear(IA_in_dim, int(embed_dim / 2) + 1) 89 | self.IA_final_bg = nn.Linear(IA_in_dim, int(embed_dim / 2) + 1) 90 | 91 | nn.init.kaiming_normal_(self.conv_sc.weight,mode='fan_out', nonlinearity='relu') 92 | nn.init.kaiming_normal_(self.conv1.weight,mode='fan_out', nonlinearity='relu') 93 | nn.init.kaiming_normal_(self.conv2.weight,mode='fan_out', nonlinearity='relu') 94 | 95 | 96 | def forward(self, x, IA_head=None,memory_list=None,low_level_feat=None,to_cat_previous_frame=None): 97 | 98 | 99 | x = self.IA1(x, IA_head) 100 | x = self.layer1(x) 101 | 102 | # start: Calibration #1 103 | x = self.CLB2(x,IA_head) 104 | # end: Calibration #1 105 | 106 | x = self.layer2(x) 107 | 108 | # start: Calibration #2 109 | x = self.CLB3(x,IA_head) 110 | # end: Calibration #2 111 | 112 | x = self.layer3(x) 113 | 114 | # start: Calibration #3 115 | x = self.CLB4(x,IA_head) 116 | # end: Calibration #3 117 | 118 | x = self.layer4(x) 119 | 120 | # start: Calibration #4) 121 | x = self.CLB5(x,IA_head) 122 | # end: Calibration #4 123 | 124 | x = self.layer5(x) 125 | 126 | px1 = torch.nn.functional.avg_pool2d(x,kernel_size=(x.size()[-2],x.size()[-1]),padding = 0) 127 | px1_sum = torch.sum(px1,dim=0,keepdim=True) 128 | px1_delta = (px1_sum-px1).squeeze(-1).squeeze(-1) 129 | 130 | x = self.IA9(x, torch.cat([IA_head,px1_delta],dim=1)) 131 | x = self.ASPP(x) 132 | 133 | x_emb_cur_1 = x.detach() 134 | if memory_list[0]==None or x_emb_cur_1.size()!=memory_list[0].size(): 135 | memory_list[0] = x_emb_cur_1 136 | x = self.Modulator_1(x, memory_list[0].cuda(x.device),IA_head) 137 | x_emb_cur_2 = x.detach() 138 | if memory_list[1]==None or x_emb_cur_2.size()!=memory_list[1].size(): 139 | memory_list[1] = x_emb_cur_2 140 | x = self.Modulator_2(x, memory_list[1].cuda(x.device),IA_head) 141 | 142 | x = self.decoder_final(x, low_level_feat, IA_head) 143 | 144 | fg_logit = self.IA_logit(x, IA_head, self.IA_final_fg) 145 | bg_logit = self.IA_logit(x, IA_head, self.IA_final_bg) 146 | 147 | pred = self.augment_background_logit(fg_logit, bg_logit) 148 | memory_list =[x_emb_cur_1.cpu(),memory_list[1].cpu()] 149 | return pred,memory_list 150 | 151 | def IA_logit(self, x, IA_head, IA_final): 152 | n, c, h, w = x.size() 153 | x = x.view(1, n * c, h, w) 154 | IA_output = IA_final(IA_head) 155 | IA_weight = IA_output[:, :c] 156 | IA_bias = IA_output[:, -1] 157 | IA_weight = IA_weight.view(n, c, 1, 1) 158 | IA_bias = IA_bias.view(-1) 159 | logit = F.conv2d(x, weight=IA_weight, bias=IA_bias, groups=n).view(n, 1, h, w) 160 | return logit 161 | 162 | def decoder_final(self, x, low_level_feat, IA_head): 163 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bicubic', align_corners=True) 164 | 165 | low_level_feat = self.GCT_sc(low_level_feat) 166 | low_level_feat = 
self.conv_sc(low_level_feat) 167 | low_level_feat = self.bn_sc(low_level_feat) 168 | low_level_feat = self.relu(low_level_feat) 169 | 170 | x = torch.cat([x, low_level_feat], dim=1) 171 | 172 | px1 = torch.nn.functional.avg_pool2d(x,kernel_size=(x.size()[-2],x.size()[-1]),padding = 0) 173 | px1_sum = torch.sum(px1,dim=0,keepdim=True) 174 | px1_delta = (px1_sum-px1).squeeze(-1).squeeze(-1) 175 | 176 | x = self.IA10(x, torch.cat([IA_head,px1_delta],dim=1)) 177 | x = self.conv1(x) 178 | x = self.bn1(x) 179 | x = self.relu(x) 180 | 181 | px1 = torch.nn.functional.avg_pool2d(x,kernel_size=(x.size()[-2],x.size()[-1]),padding = 0) 182 | px1_sum = torch.sum(px1,dim=0,keepdim=True) 183 | px1_delta = (px1_sum-px1).squeeze(-1).squeeze(-1) 184 | 185 | x = self.IA11(x, torch.cat([IA_head,px1_delta],dim=1)) 186 | x = self.conv2(x) 187 | x = self.bn2(x) 188 | x = self.relu(x) 189 | 190 | return x 191 | 192 | def Modulator_1(self, x, x_memory,IA_head): 193 | x = torch.cat([x, x_memory], dim=1) 194 | x = self.M1_Reweight_Layer_1(x, IA_head) 195 | x = self.M1_Bottleneck_1(x) 196 | x = self.M1_Reweight_Layer_2(x, IA_head) 197 | x = self.M1_Bottleneck_2(x) 198 | x = self.M1_Reweight_Layer_3(x, IA_head) 199 | x = self.M1_Bottleneck_3(x) 200 | return x 201 | 202 | def Modulator_2(self, x, x_memory,IA_head): 203 | x = torch.cat([x, x_memory], dim=1) 204 | x = self.M2_Reweight_Layer_1(x, IA_head) 205 | x = self.M2_Bottleneck_1(x) 206 | x = self.M2_Reweight_Layer_2(x, IA_head) 207 | x = self.M2_Bottleneck_2(x) 208 | x = self.M2_Reweight_Layer_3(x, IA_head) 209 | x = self.M2_Bottleneck_3(x) 210 | return x 211 | 212 | 213 | def augment_background_logit(self, fg_logit, bg_logit): 214 | # Augment the logit of absolute background by using the relative background logit of all the 215 | # foreground objects. 
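        # Note: concretely, only the absolute-background channel (index 0) is modified:
        # it is augmented with the minimum relative-background logit taken over all
        # foreground objects, so a strong foreground response from any object also
        # suppresses the background prediction at that pixel, while the zero padding
        # leaves the foreground channels untouched.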
216 | obj_num = fg_logit.size(0) 217 | pred = fg_logit 218 | if obj_num > 1: 219 | bg_logit = bg_logit[1:obj_num, :, :, :] 220 | aug_bg_logit, _ = torch.min(bg_logit, dim=0, keepdim=True) 221 | pad = torch.zeros(aug_bg_logit.size(), device=aug_bg_logit.device).expand(obj_num - 1, -1, -1, -1) 222 | aug_bg_logit = torch.cat([aug_bg_logit, pad], dim=0) 223 | pred = pred + aug_bg_logit 224 | pred = pred.permute(1,0,2,3) 225 | return pred 226 | 227 | 228 | class DynamicPreHead(nn.Module): 229 | def __init__(self, in_dim=3, embed_dim=100, kernel_size=1): 230 | super(DynamicPreHead,self).__init__() 231 | self.conv=nn.Conv2d(in_dim,embed_dim,kernel_size=kernel_size,stride=1,padding=int((kernel_size-1)/2)) 232 | self.bn = nn.GroupNorm(int(embed_dim / 4), embed_dim) 233 | self.relu = nn.ReLU(True) 234 | nn.init.kaiming_normal_(self.conv.weight,mode='fan_out',nonlinearity='relu') 235 | 236 | def forward(self, x): 237 | x = self.conv(x) 238 | x = self.bn(x) 239 | x = self.relu(x) 240 | return x 241 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/deeplab/__init__.py -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class _ASPPModule(nn.Module): 7 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 8 | super(_ASPPModule, self).__init__() 9 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 10 | stride=1, padding=padding, dilation=dilation, bias=False) 11 | self.bn = BatchNorm(planes) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | self._init_weight() 15 | 16 | def forward(self, x): 17 | x = self.atrous_conv(x) 18 | x = self.bn(x) 19 | 20 | return self.relu(x) 21 | 22 | def _init_weight(self): 23 | for m in self.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | torch.nn.init.kaiming_normal_(m.weight) 26 | elif isinstance(m, nn.BatchNorm2d): 27 | m.weight.data.fill_(1) 28 | m.bias.data.zero_() 29 | 30 | class ASPP(nn.Module): 31 | def __init__(self, backbone, output_stride, BatchNorm): 32 | super(ASPP, self).__init__() 33 | if backbone == 'drn': 34 | inplanes = 512 35 | elif backbone == 'mobilenet': 36 | inplanes = 320 37 | else: 38 | inplanes = 2048 39 | if output_stride == 16: 40 | dilations = [1, 6, 12, 18] 41 | elif output_stride == 8: 42 | dilations = [1, 12, 24, 36] 43 | else: 44 | raise NotImplementedError 45 | 46 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 47 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 48 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 49 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 50 | 51 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 52 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 53 | BatchNorm(256), 54 | nn.ReLU(inplace=True)) 55 | self.conv1 = 
nn.Conv2d(1280, 256, 1, bias=False) 56 | self.bn1 = BatchNorm(256) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.dropout = nn.Dropout(0.1) 59 | self._init_weight() 60 | 61 | def forward(self, x): 62 | x1 = self.aspp1(x) 63 | x2 = self.aspp2(x) 64 | x3 = self.aspp3(x) 65 | x4 = self.aspp4(x) 66 | x5 = self.global_avg_pool(x) 67 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 68 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 69 | 70 | x = self.conv1(x) 71 | x = self.bn1(x) 72 | x = self.relu(x) 73 | 74 | return self.dropout(x) 75 | 76 | def _init_weight(self): 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | torch.nn.init.kaiming_normal_(m.weight) 80 | elif isinstance(m, nn.BatchNorm2d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | 84 | 85 | def build_aspp(backbone, output_stride, BatchNorm): 86 | return ASPP(backbone, output_stride, BatchNorm) 87 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.deeplab.backbone import resnet, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'mobilenet': 7 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | def conv_bn(inp, oup, stride, BatchNorm): 8 | return nn.Sequential( 9 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 10 | BatchNorm(oup), 11 | nn.ReLU6(inplace=True) 12 | ) 13 | 14 | 15 | def fixed_padding(inputs, kernel_size, dilation): 16 | 
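# 'SAME'-style symmetric padding for a dilated kernel: the effective kernel size is
# k + (k - 1) * (d - 1), and the total padding (effective size - 1) is split between
# the two sides. The depthwise convolutions in InvertedResidual are built with
# padding=0 and call this helper in forward() instead.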
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 17 | pad_total = kernel_size_effective - 1 18 | pad_beg = pad_total // 2 19 | pad_end = pad_total - pad_beg 20 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 21 | return padded_inputs 22 | 23 | 24 | class InvertedResidual(nn.Module): 25 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 26 | super(InvertedResidual, self).__init__() 27 | self.stride = stride 28 | assert stride in [1, 2] 29 | 30 | hidden_dim = round(inp * expand_ratio) 31 | self.use_res_connect = self.stride == 1 and inp == oup 32 | self.kernel_size = 3 33 | self.dilation = dilation 34 | 35 | if expand_ratio == 1: 36 | self.conv = nn.Sequential( 37 | # dw 38 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 39 | BatchNorm(hidden_dim), 40 | nn.ReLU6(inplace=True), 41 | # pw-linear 42 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 43 | BatchNorm(oup), 44 | ) 45 | else: 46 | self.conv = nn.Sequential( 47 | # pw 48 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 49 | BatchNorm(hidden_dim), 50 | nn.ReLU6(inplace=True), 51 | # dw 52 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 53 | BatchNorm(hidden_dim), 54 | nn.ReLU6(inplace=True), 55 | # pw-linear 56 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 57 | BatchNorm(oup), 58 | ) 59 | 60 | def forward(self, x): 61 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 62 | if self.use_res_connect: 63 | x = x + self.conv(x_pad) 64 | else: 65 | x = self.conv(x_pad) 66 | return x 67 | 68 | 69 | class MobileNetV2(nn.Module): 70 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=False): 71 | super(MobileNetV2, self).__init__() 72 | block = InvertedResidual 73 | input_channel = 32 74 | current_stride = 1 75 | rate = 1 76 | interverted_residual_setting = [ 77 | # t, c, n, s 78 | [1, 16, 1, 1], 79 | [6, 24, 2, 2], 80 | [6, 32, 3, 2], 81 | [6, 64, 4, 2], 82 | [6, 96, 3, 1], 83 | [6, 160, 3, 2], 84 | [6, 320, 1, 1], 85 | ] 86 | 87 | # building first layer 88 | input_channel = int(input_channel * width_mult) 89 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 90 | current_stride *= 2 91 | # building inverted residual blocks 92 | for t, c, n, s in interverted_residual_setting: 93 | if current_stride == output_stride: 94 | stride = 1 95 | dilation = rate 96 | rate *= s 97 | else: 98 | stride = s 99 | dilation = 1 100 | current_stride *= s 101 | output_channel = int(c * width_mult) 102 | for i in range(n): 103 | if i == 0: 104 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 105 | else: 106 | self.features.append(block(input_channel, output_channel, 1, rate, t, BatchNorm)) 107 | input_channel = output_channel 108 | self.features = nn.Sequential(*self.features) 109 | self._initialize_weights() 110 | 111 | if pretrained: 112 | self._load_pretrained_model() 113 | 114 | self.low_level_features = self.features[0:4] 115 | self.high_level_features = self.features[4:] 116 | 117 | self.feautre_8x = self.features[4:7] 118 | self.feature_16x = self.features[7:14] 119 | self.feature_32x = self.features[14:] 120 | 121 | def forward(self, x, return_mid_level=False): 122 | if return_mid_level: 123 | low_level_feat = self.low_level_features(x) 124 | mid_level_feat = self.feautre_8x(low_level_feat) 125 | x = self.feature_16x(mid_level_feat) 126 | x = self.feature_32x(x) 127 | return x, 
low_level_feat, mid_level_feat 128 | else: 129 | low_level_feat = self.low_level_features(x) 130 | x = self.high_level_features(low_level_feat) 131 | return x, low_level_feat 132 | 133 | def _load_pretrained_model(self): 134 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 135 | model_dict = {} 136 | state_dict = self.state_dict() 137 | for k, v in pretrain_dict.items(): 138 | if k in state_dict: 139 | model_dict[k] = v 140 | state_dict.update(model_dict) 141 | self.load_state_dict(state_dict) 142 | 143 | def _initialize_weights(self): 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | torch.nn.init.kaiming_normal_(m.weight) 147 | elif isinstance(m, nn.BatchNorm2d): 148 | m.weight.data.fill_(1) 149 | m.bias.data.zero_() 150 | 151 | if __name__ == "__main__": 152 | input = torch.rand(1, 3, 512, 512) 153 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 154 | output, low_level_feat = model(input) 155 | print(output.size()) 156 | print(low_level_feat.size()) 157 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | class Bottleneck(nn.Module): 6 | expansion = 4 7 | 8 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 9 | super(Bottleneck, self).__init__() 10 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 11 | self.bn1 = BatchNorm(planes) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 13 | dilation=dilation, padding=dilation, bias=False) 14 | self.bn2 = BatchNorm(planes) 15 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 16 | self.bn3 = BatchNorm(planes * 4) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.downsample = downsample 19 | self.stride = stride 20 | self.dilation = dilation 21 | 22 | def forward(self, x): 23 | residual = x 24 | 25 | out = self.conv1(x) 26 | out = self.bn1(out) 27 | out = self.relu(out) 28 | 29 | out = self.conv2(out) 30 | out = self.bn2(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv3(out) 34 | out = self.bn3(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | class ResNet(nn.Module): 45 | 46 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=False): 47 | self.inplanes = 64 48 | super(ResNet, self).__init__() 49 | blocks = [1, 2, 4] 50 | if output_stride == 16: 51 | strides = [1, 2, 2, 1] 52 | dilations = [1, 1, 1, 2] 53 | elif output_stride == 8: 54 | strides = [1, 2, 1, 1] 55 | dilations = [1, 1, 2, 4] 56 | else: 57 | raise NotImplementedError 58 | 59 | # Modules 60 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 61 | bias=False) 62 | self.bn1 = BatchNorm(64) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | 66 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 67 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 68 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 69 | 
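# layer4 is a multi-grid unit: with blocks = [1, 2, 4], its three bottlenecks multiply
# the base dilation by 1x, 2x and 4x (see _make_MG_unit below), enlarging the receptive
# field without any further downsampling.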
self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 70 | self._init_weight() 71 | 72 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 73 | downsample = None 74 | if stride != 1 or self.inplanes != planes * block.expansion: 75 | downsample = nn.Sequential( 76 | nn.Conv2d(self.inplanes, planes * block.expansion, 77 | kernel_size=1, stride=stride, bias=False), 78 | BatchNorm(planes * block.expansion), 79 | ) 80 | 81 | layers = [] 82 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 83 | self.inplanes = planes * block.expansion 84 | for i in range(1, blocks): 85 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 86 | 87 | return nn.Sequential(*layers) 88 | 89 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 90 | downsample = None 91 | if stride != 1 or self.inplanes != planes * block.expansion: 92 | downsample = nn.Sequential( 93 | nn.Conv2d(self.inplanes, planes * block.expansion, 94 | kernel_size=1, stride=stride, bias=False), 95 | BatchNorm(planes * block.expansion), 96 | ) 97 | 98 | layers = [] 99 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 100 | downsample=downsample, BatchNorm=BatchNorm)) 101 | self.inplanes = planes * block.expansion 102 | for i in range(1, len(blocks)): 103 | layers.append(block(self.inplanes, planes, stride=1, 104 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 105 | 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, input, return_mid_level=False): 109 | x = self.conv1(input) 110 | x = self.bn1(x) 111 | x = self.relu(x) 112 | x = self.maxpool(x) 113 | 114 | x = self.layer1(x) 115 | low_level_feat = x 116 | x = self.layer2(x) 117 | mid_level_feat = x 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | if return_mid_level: 121 | return x, low_level_feat, mid_level_feat 122 | else: 123 | return x, low_level_feat 124 | def _init_weight(self): 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 128 | m.weight.data.normal_(0, math.sqrt(2. / n)) 129 | elif isinstance(m, nn.BatchNorm2d): 130 | m.weight.data.fill_(1) 131 | m.bias.data.zero_() 132 | 133 | def _load_pretrained_model(self): 134 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 135 | model_dict = {} 136 | state_dict = self.state_dict() 137 | for k, v in pretrain_dict.items(): 138 | if k in state_dict: 139 | model_dict[k] = v 140 | state_dict.update(model_dict) 141 | self.load_state_dict(state_dict) 142 | 143 | def ResNet101(output_stride, BatchNorm, pretrained=True): 144 | """Constructs a ResNet-101 model. 
145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | """ 148 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 149 | return model 150 | 151 | if __name__ == "__main__": 152 | import torch 153 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 154 | input = torch.rand(1, 3, 512, 512) 155 | output, low_level_feat = model(input) 156 | print(output.size()) 157 | print(low_level_feat.size()) 158 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Decoder(nn.Module): 7 | def __init__(self, backbone, BatchNorm): 8 | super(Decoder, self).__init__() 9 | if backbone == 'resnet': 10 | low_level_inplanes = 256 11 | elif backbone == 'mobilenet': 12 | low_level_inplanes = 24 13 | else: 14 | raise NotImplementedError 15 | 16 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 17 | self.bn1 = BatchNorm(48) 18 | self.relu = nn.ReLU(inplace=True) 19 | 20 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 21 | BatchNorm(256), 22 | nn.ReLU(inplace=True), 23 | nn.Sequential(), 24 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 25 | BatchNorm(256), 26 | nn.ReLU(inplace=True), 27 | nn.Sequential()) 28 | 29 | self._init_weight() 30 | 31 | 32 | def forward(self, x, low_level_feat): 33 | low_level_feat = self.conv1(low_level_feat) 34 | low_level_feat = self.bn1(low_level_feat) 35 | low_level_feat = self.relu(low_level_feat) 36 | 37 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 38 | x = torch.cat((x, low_level_feat), dim=1) 39 | x = self.last_conv(x) 40 | 41 | return x 42 | 43 | def _init_weight(self): 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | torch.nn.init.kaiming_normal_(m.weight) 47 | elif isinstance(m, nn.BatchNorm2d): 48 | m.weight.data.fill_(1) 49 | m.bias.data.zero_() 50 | 51 | def build_decoder(backbone, BatchNorm): 52 | return Decoder(backbone, BatchNorm) 53 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/deeplab/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.deeplab.aspp import build_aspp 5 | from networks.deeplab.decoder import build_decoder 6 | from networks.deeplab.backbone import build_backbone 7 | from networks.layers.normalization import FrozenBatchNorm2d 8 | 9 | class DeepLab(nn.Module): 10 | def __init__(self, 11 | backbone='resnet', 12 | output_stride=16, 13 | freeze_bn=True): 14 | super(DeepLab, self).__init__() 15 | 16 | if freeze_bn == True: 17 | print("Use frozen BN in DeepLab!") 18 | BatchNorm = FrozenBatchNorm2d 19 | else: 20 | BatchNorm = nn.BatchNorm2d 21 | 22 | self.backbone = build_backbone(backbone, output_stride, BatchNorm) 23 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 24 | self.decoder = build_decoder(backbone, BatchNorm) 25 | 26 | 27 | def forward(self, input, return_aspp=False): 28 | if return_aspp: 29 | x, low_level_feat, mid_level_feat = self.backbone(input, True) 30 | else: 31 | x, low_level_feat = 
self.backbone(input) 32 | aspp_x = self.aspp(x) 33 | x = self.decoder(aspp_x, low_level_feat) 34 | 35 | if return_aspp: 36 | return x, aspp_x, low_level_feat, mid_level_feat 37 | else: 38 | return x, low_level_feat 39 | 40 | 41 | def get_1x_lr_params(self): 42 | modules = [self.backbone] 43 | for i in range(len(modules)): 44 | for m in modules[i].named_modules(): 45 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d): 46 | for p in m[1].parameters(): 47 | if p.requires_grad: 48 | yield p 49 | 50 | def get_10x_lr_params(self): 51 | modules = [self.aspp, self.decoder] 52 | for i in range(len(modules)): 53 | for m in modules[i].named_modules(): 54 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d): 55 | for p in m[1].parameters(): 56 | if p.requires_grad: 57 | yield p 58 | 59 | 60 | if __name__ == "__main__": 61 | model = DeepLab(backbone='resnet', output_stride=16) 62 | model.eval() 63 | input = torch.rand(2, 3, 513, 513) 64 | output = model(input) 65 | print(output.size()) 66 | 67 | 68 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/engine/__init__.py -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/engine/eval_manager_mm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import time 4 | import datetime as datetime 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | import numpy as np 11 | from dataloaders.datasets_m import YOUTUBE_VOS_Test,DAVIS_Test 12 | import dataloaders.custom_transforms as tr 13 | from networks.deeplab.deeplab import DeepLab 14 | from utils.meters import AverageMeter 15 | from utils.image import flip_tensor, save_mask, save_matching_result 16 | from utils.checkpoint import load_network 17 | from utils.eval import zip_folder 18 | from networks.layers.shannon_entropy import cal_shannon_entropy 19 | import math 20 | 21 | class Evaluator(object): 22 | def __init__(self, cfg): 23 | 24 | self.mem_every = cfg.MEM_EVERY 25 | self.unc_ratio = cfg.UNC_RATIO 26 | 27 | self.gpu = cfg.TEST_GPU_ID 28 | self.cfg = cfg 29 | self.print_log(cfg.__dict__) 30 | print("Use GPU {} for evaluating".format(self.gpu)) 31 | torch.cuda.set_device(self.gpu) 32 | 33 | self.print_log('Build backbone.') 34 | self.feature_extracter = DeepLab( 35 | backbone=cfg.MODEL_BACKBONE, 36 | freeze_bn=cfg.MODEL_FREEZE_BN).cuda(self.gpu) 37 | 38 | self.print_log('Build VOS model.') 39 | CFBI = importlib.import_module(cfg.MODEL_MODULE) 40 | self.model = CFBI.get_module()( 41 | cfg, 42 | self.feature_extracter).cuda(self.gpu) 43 | 44 | self.process_pretrained_model() 45 | 46 | self.prepare_dataset() 47 | 48 | def process_pretrained_model(self): 49 | cfg = self.cfg 50 | if cfg.TEST_CKPT_PATH == 'test': 51 | self.ckpt = 'test' 52 | self.print_log('Test evaluation.') 53 | return 54 | if cfg.TEST_CKPT_PATH is None: 55 | if cfg.TEST_CKPT_STEP is not None: 56 | ckpt = str(cfg.TEST_CKPT_STEP) 57 | else: 58 | ckpts = os.listdir(cfg.DIR_CKPT) 59 | if len(ckpts) > 0: 60 | ckpts = list(map(lambda x: 
int(x.split('_')[-1].split('.')[0]), ckpts)) 61 | ckpt = np.sort(ckpts)[-1] 62 | else: 63 | self.print_log('No checkpoint in {}.'.format(cfg.DIR_CKPT)) 64 | exit() 65 | self.ckpt = ckpt 66 | cfg.TEST_CKPT_PATH = os.path.join(cfg.DIR_CKPT, 'save_step_%s.pth' % ckpt) 67 | self.model, removed_dict = load_network(self.model, cfg.TEST_CKPT_PATH, self.gpu) 68 | if len(removed_dict) > 0: 69 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 70 | self.print_log('Load latest checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) 71 | else: 72 | self.ckpt = 'unknown' 73 | self.model, removed_dict = load_network(self.model, cfg.TEST_CKPT_PATH, self.gpu) 74 | if len(removed_dict) > 0: 75 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 76 | self.print_log('Load checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) 77 | 78 | def prepare_dataset(self): 79 | cfg = self.cfg 80 | self.print_log('Process dataset...') 81 | eval_transforms = transforms.Compose([ 82 | tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE, cfg.TEST_FLIP, cfg.TEST_MULTISCALE), 83 | tr.MultiToTensor()]) 84 | 85 | eval_name = '{}_{}_ckpt_{}'.format(cfg.TEST_DATASET, cfg.EXP_NAME, self.ckpt) 86 | if cfg.TEST_FLIP: 87 | eval_name += '_flip' 88 | if len(cfg.TEST_MULTISCALE) > 1: 89 | eval_name += '_ms' 90 | for ss in cfg.TEST_MULTISCALE: 91 | eval_name +="_" 92 | eval_name +=str(ss) 93 | 94 | eval_name+="_m_"+str(self.mem_every)+"_u_"+str(self.unc_ratio)+"_r_"+str(cfg.TEST_MAX_SIZE)+"_RPA" 95 | 96 | if cfg.TEST_DATASET == 'youtubevos19': 97 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 98 | self.dataset = YOUTUBE_VOS_Test( 99 | root=cfg.DIR_YTB_EVAL, 100 | transform=eval_transforms, 101 | result_root=self.result_root) 102 | 103 | elif cfg.TEST_DATASET == 'youtubevos18': 104 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 105 | self.dataset = YOUTUBE_VOS_Test( 106 | root=cfg.DIR_YTB_EVAL18, 107 | transform=eval_transforms, 108 | result_root=self.result_root) 109 | elif cfg.TEST_DATASET == 'youtubevos': ## predict all frame for youtubevos 2019 110 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 111 | self.dataset = YOUTUBE_VOS_Test( 112 | root=cfg.DIR_YTB_EVAL, 113 | transform=eval_transforms, 114 | result_root=self.result_root, 115 | use_all=True) 116 | elif cfg.TEST_DATASET == 'davis2017': 117 | resolution = 'Full-Resolution' if cfg.TEST_DATASET_FULL_RESOLUTION else '480p' 118 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations', resolution) 119 | self.dataset = DAVIS_Test( 120 | split=cfg.TEST_DATASET_SPLIT, 121 | root=cfg.DIR_DAVIS, 122 | year=2017, 123 | transform=eval_transforms, 124 | full_resolution=cfg.TEST_DATASET_FULL_RESOLUTION, 125 | result_root=self.result_root) 126 | elif cfg.TEST_DATASET == 'davis2017-label': 127 | resolution = 'Full-Resolution' if cfg.TEST_DATASET_FULL_RESOLUTION else '480p' 128 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations', resolution) 129 | self.dataset = DAVIS_Test_w_label( 130 | split=cfg.TEST_DATASET_SPLIT, 131 | root=cfg.DIR_DAVIS, 132 | year=2017, 133 | transform=eval_transforms, 134 | full_resolution=cfg.TEST_DATASET_FULL_RESOLUTION, 135 | result_root=self.result_root) 136 | elif cfg.TEST_DATASET == 'davis2016': 137 | resolution = 'Full-Resolution' if cfg.TEST_DATASET_FULL_RESOLUTION else '480p' 138 | 
self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations', resolution) 139 | self.dataset = DAVIS_Test( 140 | split=cfg.TEST_DATASET_SPLIT, 141 | root=cfg.DIR_DAVIS, 142 | year=2016, 143 | transform=eval_transforms, 144 | full_resolution=cfg.TEST_DATASET_FULL_RESOLUTION, 145 | result_root=self.result_root) 146 | elif cfg.TEST_DATASET == 'test': 147 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 148 | self.dataset = EVAL_TEST(eval_transforms, self.result_root) 149 | else: 150 | print('Unknown dataset!') 151 | exit() 152 | 153 | print('Eval {} on {}:'.format(cfg.EXP_NAME, cfg.TEST_DATASET)) 154 | self.source_folder = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 155 | self.zip_dir = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, '{}.zip'.format(eval_name)) 156 | if not os.path.exists(self.result_root): 157 | os.makedirs(self.result_root) 158 | self.print_log('Done!') 159 | 160 | def evaluating(self): 161 | cfg = self.cfg 162 | self.model.eval() 163 | video_num = 0 164 | total_time = 0 165 | total_frame = 0 166 | total_sfps = 0 167 | total_video_num = len(self.dataset) 168 | PlaceHolder=[] 169 | for i in range(cfg.BLOCK_NUM): 170 | PlaceHolder.append(None) 171 | 172 | for seq_idx, seq_dataset in enumerate(self.dataset): 173 | video_num += 1 174 | seq_name = seq_dataset.seq_name 175 | 176 | print('Prcessing Seq {} [{}/{}]:'.format(seq_name, video_num, total_video_num)) 177 | 178 | torch.cuda.empty_cache() 179 | 180 | seq_dataloader=DataLoader(seq_dataset, batch_size=1, shuffle=False, num_workers=cfg.TEST_WORKERS, pin_memory=True) 181 | 182 | seq_total_time = 0 183 | seq_total_frame = 0 184 | ref_embeddings = [] 185 | ref_masks = [] 186 | prev_embedding = [] 187 | prev_mask = [] 188 | ref_mask_confident = [] 189 | memory_prev_all_list=[] 190 | memory_cur_all_list=[] 191 | memory_prev_list=[] 192 | memory_cur_list=[] 193 | label_all_list=[] 194 | 195 | with torch.no_grad(): 196 | for frame_idx, samples in enumerate(seq_dataloader): 197 | 198 | time_start = time.time() 199 | all_preds = [] 200 | 201 | join_label = None 202 | UPDATE=False 203 | 204 | 205 | if frame_idx==0: 206 | for aug_idx in range(len(samples)): 207 | memory_prev_all_list.append([PlaceHolder]) 208 | else: 209 | memory_prev_all_list=memory_cur_all_list 210 | 211 | memory_cur_all_list=[] 212 | for aug_idx in range(len(samples)): 213 | if len(ref_embeddings) <= aug_idx: 214 | ref_embeddings.append([]) 215 | ref_masks.append([]) 216 | prev_embedding.append(None) 217 | prev_mask.append(None) 218 | ref_mask_confident.append([]) 219 | 220 | sample = samples[aug_idx] 221 | ref_emb = ref_embeddings[aug_idx] 222 | #ref_m = ref_masks[aug_idx] 223 | 224 | ## use confident mask for correlation 225 | ref_m = ref_mask_confident[aug_idx] 226 | 227 | prev_emb = prev_embedding[aug_idx] 228 | prev_m = prev_mask[aug_idx] 229 | 230 | 231 | current_img = sample['current_img'] 232 | if 'current_label' in sample.keys(): 233 | current_label = sample['current_label'].cuda(self.gpu) 234 | else: 235 | current_label = None 236 | 237 | obj_list = sample['meta']['obj_list'] 238 | obj_num = sample['meta']['obj_num'] 239 | imgname = sample['meta']['current_name'] 240 | ori_height = sample['meta']['height'] 241 | ori_width = sample['meta']['width'] 242 | current_img = current_img.cuda(self.gpu) 243 | obj_num = obj_num.cuda(self.gpu) 244 | bs, _, h, w = current_img.size() 245 | 246 | all_pred, current_embedding,memory_cur_list = 
self.model.forward_for_eval(memory_prev_all_list[aug_idx], ref_emb, 247 | ref_m, prev_emb, prev_m, 248 | current_img, gt_ids=obj_num, 249 | pred_size=[ori_height,ori_width]) 250 | memory_cur_all_list.append(memory_cur_list) 251 | 252 | # delete the label that hasn't existed in the GT label 253 | all_pred_remake = [] 254 | all_pred_exist = [] 255 | if all_pred!=None: 256 | all_pred_split = all_pred.split(all_pred.size()[1],dim=1)[0] 257 | 258 | for i in range(all_pred.size()[1]): 259 | if i not in label_all_list: 260 | all_pred_remake.append(torch.zeros_like(all_pred_split[0][i]).unsqueeze(0)) 261 | else: 262 | all_pred_remake.append(all_pred_split[0][i].unsqueeze(0)) 263 | all_pred_exist.append(all_pred_split[0][i].unsqueeze(0)) 264 | all_pred = torch.cat(all_pred_remake,dim=0).unsqueeze(0) 265 | all_pred_exist = torch.cat(all_pred_exist,dim=0).unsqueeze(0) 266 | 267 | 268 | if 'current_label' in sample.keys(): 269 | label_cur_list = np.unique(sample['current_label'].cpu().detach().numpy()).tolist() 270 | for i in label_cur_list: 271 | if i not in label_all_list: 272 | label_all_list.append(i) 273 | 274 | if frame_idx == 0: 275 | if current_label is None: 276 | print("No first frame label in Seq {}.".format(seq_name)) 277 | ref_embeddings[aug_idx].append(current_embedding) 278 | ref_masks[aug_idx].append(current_label) 279 | ref_mask_confident[aug_idx].append(current_label) 280 | 281 | prev_embedding[aug_idx] = current_embedding 282 | prev_mask[aug_idx] = current_label 283 | 284 | else: 285 | if sample['meta']['flip']: 286 | all_pred = flip_tensor(all_pred, 3) 287 | # In YouTube-VOS, not all the objects appear in the first frame for the first time. Thus, we 288 | # have to introduce new labels for new objects, if necessary. 289 | if not sample['meta']['flip'] and not(current_label is None) and join_label is None: # gt exists here 290 | join_label = current_label 291 | all_preds.append(all_pred) 292 | 293 | all_pred_org = all_pred 294 | current_label_0 = None 295 | 296 | if current_label is not None: 297 | ref_embeddings[aug_idx].append(current_embedding) 298 | 299 | else: 300 | all_preds_0 = torch.cat(all_preds, dim=0) 301 | all_preds_0 = torch.mean(all_preds_0, dim=0) 302 | pred_label_0 = torch.argmax(all_preds_0, dim=0) 303 | current_label_0 = pred_label_0.view(1, 1, ori_height, ori_width) 304 | 305 | # uncertainty region filter 306 | uncertainty_org,uncertainty_norm = cal_shannon_entropy(all_pred_exist) 307 | 308 | # we set mem_every==-1 to indicate we don't use extra confident candidate pool 309 | if self.mem_every>-1 and frame_idx%self.mem_every==0 and frame_idx!=0 and current_embedding!=None and current_label_0!=None: 310 | ref_embeddings[aug_idx].append(current_embedding) 311 | ref_masks[aug_idx].append(current_label_0) 312 | UPDATE=True 313 | 314 | 315 | prev_embedding[aug_idx] = current_embedding 316 | 317 | if frame_idx > 0: 318 | all_preds = torch.cat(all_preds, dim=0) 319 | all_preds = torch.mean(all_preds, dim=0) 320 | pred_label = torch.argmax(all_preds, dim=0) 321 | if join_label is not None: 322 | join_label = join_label.squeeze(0).squeeze(0) 323 | keep = (join_label == 0).long() 324 | pred_label = pred_label * keep + join_label * (1 - keep) 325 | pred_label = pred_label 326 | current_label = pred_label.view(1, 1, ori_height, ori_width) 327 | if samples[aug_idx]['meta']['flip']: 328 | flip_pred_label = flip_tensor(pred_label, 1) 329 | flip_current_label = flip_pred_label.view(1, 1, ori_height, ori_width) 330 | 331 | for aug_idx in range(len(samples)): 332 | if join_label is 
not None: 333 | if samples[aug_idx]['meta']['flip']: 334 | ref_masks[aug_idx].append(flip_current_label) 335 | ref_mask_confident[aug_idx].append(flip_current_label) 336 | else: 337 | ref_masks[aug_idx].append(current_label) 338 | 339 | uncertainty_org,uncertainty_norm = cal_shannon_entropy(all_pred_exist) 340 | join_label = join_label.squeeze(0).squeeze(0) 341 | keep = (join_label == 0).long() 342 | join_uncertainty_map = (join_label <0).long() 343 | uncertainty_org = uncertainty_org * keep + join_uncertainty_map * (1 - keep) 344 | 345 | uncertainty_region = (uncertainty_org>self.unc_ratio ).long() 346 | pred_label_c = pred_label* (1 - uncertainty_region) + (125)* uncertainty_region 347 | pred_label_c = pred_label_c.view(1, 1, ori_height, ori_width) 348 | 349 | ref_mask_confident[aug_idx].append(pred_label_c) 350 | 351 | if samples[aug_idx]['meta']['flip']: 352 | prev_mask[aug_idx] = flip_current_label 353 | else: 354 | prev_mask[aug_idx] = current_label 355 | 356 | if UPDATE: 357 | if self.mem_every>-1 and frame_idx%self.mem_every==0 and frame_idx!=0 and current_embedding!=None and current_label_0!=None : 358 | uncertainty_region = (uncertainty_org>self.unc_ratio ).long() 359 | pred_label_c = pred_label* (1 - uncertainty_region) + (125)* uncertainty_region 360 | pred_label_c = pred_label_c.view(1, 1, ori_height, ori_width) 361 | ref_mask_confident[aug_idx].append(pred_label_c) 362 | 363 | one_frametime = time.time() - time_start 364 | seq_total_time += one_frametime 365 | seq_total_frame += 1 366 | obj_num = obj_num[0].item() 367 | print('Frame: {}, Obj Num: {}, Time: {}'.format(imgname[0], obj_num, one_frametime)) 368 | # Save result 369 | save_mask(pred_label, os.path.join(self.result_root, seq_name, imgname[0].split('.')[0]+'.png')) 370 | 371 | else: 372 | one_frametime = time.time() - time_start 373 | seq_total_time += one_frametime 374 | print('Ref Frame: {}, Time: {}'.format(imgname[0], one_frametime)) 375 | 376 | del(ref_embeddings) 377 | del(ref_masks) 378 | del(prev_embedding) 379 | del(prev_mask) 380 | del(seq_dataset) 381 | del(seq_dataloader) 382 | del(memory_cur_all_list) 383 | 384 | 385 | seq_avg_time_per_frame = seq_total_time / seq_total_frame 386 | total_time += seq_total_time 387 | total_frame += seq_total_frame 388 | total_avg_time_per_frame = total_time / total_frame 389 | total_sfps += seq_avg_time_per_frame 390 | avg_sfps = total_sfps / (seq_idx + 1) 391 | print("Seq {} FPS: {}, Total FPS: {}, FPS per Seq: {}".format(seq_name, 1./seq_avg_time_per_frame, 1./total_avg_time_per_frame, 1./avg_sfps)) 392 | 393 | zip_folder(self.source_folder, self.zip_dir) 394 | self.print_log('Save result to {}.'.format(self.zip_dir)) 395 | 396 | 397 | def print_log(self, string): 398 | print(string) 399 | 400 | 401 | 402 | 403 | 404 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/engine/train_manager_mm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import time 4 | import datetime as datetime 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.distributed as dist 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | import numpy as np 12 | from dataloaders.datasets_m import DAVIS2017_Train, YOUTUBE_VOS_Train, TEST 13 | import dataloaders.custom_transforms as tr 14 | from networks.deeplab.deeplab import DeepLab 15 | from utils.meters import AverageMeter 16 | 
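# Trainer below implements the sequential multi-frame training stage: it optionally wraps
# the model in DistributedDataParallel (cfg.DIST_ENABLE), unrolls cfg.DATA_CURR_SEQ_LEN
# frames per iteration in sequential_training(), and writes TensorBoard summaries when
# cfg.TRAIN_TBLOG is enabled.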
from utils.image import label2colormap, masked_image, save_image 17 | from utils.checkpoint import load_network_and_optimizer, load_network, save_network 18 | from utils.learning import adjust_learning_rate, get_trainable_params 19 | from utils.metric import pytorch_iou 20 | #torch.backends.cudnn.enabled = True 21 | #torch.backends.cudnn.benchmark = True 22 | class Trainer(object): 23 | def __init__(self , rank, cfg): 24 | self.gpu = rank + cfg.DIST_START_GPU 25 | self.rank = rank 26 | self.cfg = cfg 27 | self.print_log(cfg.__dict__) 28 | print("Use GPU {} for training".format(self.gpu)) 29 | torch.cuda.set_device(self.gpu) 30 | 31 | self.print_log('Build backbone.') 32 | self.feature_extracter = DeepLab( 33 | backbone=cfg.MODEL_BACKBONE, 34 | freeze_bn=cfg.MODEL_FREEZE_BN).cuda(self.gpu) 35 | 36 | if cfg.MODEL_FREEZE_BACKBONE: 37 | for param in self.feature_extracter.parameters(): 38 | param.requires_grad = False 39 | 40 | self.print_log('Build VOS model.') 41 | CFBI = importlib.import_module(cfg.MODEL_MODULE) 42 | 43 | self.model = CFBI.get_module()( 44 | cfg, 45 | self.feature_extracter).cuda(self.gpu) 46 | 47 | if cfg.DIST_ENABLE: 48 | dist.init_process_group( 49 | backend=cfg.DIST_BACKEND, 50 | init_method=cfg.DIST_URL, 51 | world_size=cfg.TRAIN_GPUS, 52 | rank=rank, 53 | timeout=datetime.timedelta(seconds=300)) 54 | self.dist_model = torch.nn.parallel.DistributedDataParallel( 55 | self.model, 56 | device_ids=[self.gpu], 57 | find_unused_parameters=True) 58 | else: 59 | self.dist_model = self.model 60 | 61 | self.print_log('Build optimizer.') 62 | trainable_params = get_trainable_params( 63 | model=self.dist_model, 64 | base_lr=cfg.TRAIN_LR, 65 | weight_decay=cfg.TRAIN_WEIGHT_DECAY, 66 | beta_wd=cfg.MODEL_GCT_BETA_WD) 67 | 68 | self.optimizer = optim.SGD( 69 | trainable_params, 70 | lr=cfg.TRAIN_LR, 71 | momentum=cfg.TRAIN_MOMENTUM, 72 | nesterov=True) 73 | 74 | self.prepare_dataset() 75 | self.process_pretrained_model() 76 | 77 | if cfg.TRAIN_TBLOG and self.rank == 0: 78 | from tensorboardX import SummaryWriter 79 | self.tblogger = SummaryWriter(cfg.DIR_TB_LOG) 80 | 81 | def process_pretrained_model(self): 82 | cfg = self.cfg 83 | 84 | self.step = cfg.TRAIN_START_STEP 85 | self.epoch = 0 86 | 87 | if cfg.TRAIN_AUTO_RESUME: 88 | ckpts = os.listdir(cfg.DIR_CKPT) 89 | if len(ckpts) > 0: 90 | ckpts = list(map(lambda x: int(x.split('_')[-1].split('.')[0]), ckpts)) 91 | ckpt = np.sort(ckpts)[-1] 92 | cfg.TRAIN_RESUME = True 93 | cfg.TRAIN_RESUME_CKPT = ckpt 94 | cfg.TRAIN_RESUME_STEP = ckpt + 1 95 | else: 96 | cfg.TRAIN_RESUME = False 97 | 98 | if cfg.TRAIN_RESUME: 99 | resume_ckpt = os.path.join(cfg.DIR_CKPT, 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT)) 100 | 101 | self.model, self.optimizer, removed_dict = load_network_and_optimizer(self.model, self.optimizer, resume_ckpt, self.gpu) 102 | 103 | if len(removed_dict) > 0: 104 | self.print_log('Remove {} from checkpoint.'.format(removed_dict)) 105 | 106 | self.step = cfg.TRAIN_RESUME_STEP 107 | if cfg.TRAIN_TOTAL_STEPS <= self.step: 108 | self.print_log("Your training has finished!") 109 | exit() 110 | self.epoch = int(np.ceil(self.step / len(self.trainloader))) 111 | 112 | self.print_log('Resume from step {}'.format(self.step)) 113 | 114 | elif cfg.PRETRAIN: 115 | if cfg.PRETRAIN_FULL: 116 | self.model, removed_dict = load_network(self.model, cfg.PRETRAIN_MODEL, self.gpu) 117 | if len(removed_dict) > 0: 118 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 119 | self.print_log('Load pretrained VOS model from 
{}.'.format(cfg.PRETRAIN_MODEL)) 120 | else: 121 | feature_extracter, removed_dict = load_network(self.feature_extracter, cfg.PRETRAIN_MODEL, self.gpu) 122 | if len(removed_dict) > 0: 123 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 124 | self.print_log('Load pretrained backbone model from {}.'.format(cfg.PRETRAIN_MODEL)) 125 | 126 | def prepare_dataset(self): 127 | cfg = self.cfg 128 | self.print_log('Process dataset...') 129 | composed_transforms = transforms.Compose([ 130 | tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR, cfg.DATA_MAX_SCALE_FACTOR, cfg.DATA_SHORT_EDGE_LEN), 131 | tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP), 132 | tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP), 133 | tr.Resize(cfg.DATA_RANDOMCROP), 134 | tr.ToTensor()]) 135 | 136 | train_datasets = [] 137 | if 'davis2017' in cfg.DATASETS: 138 | train_davis_dataset = DAVIS2017_Train( 139 | root=cfg.DIR_DAVIS, 140 | full_resolution=cfg.TRAIN_DATASET_FULL_RESOLUTION, 141 | transform=composed_transforms, 142 | repeat_time=cfg.DATA_DAVIS_REPEAT, 143 | curr_len=cfg.DATA_CURR_SEQ_LEN, 144 | rand_gap=cfg.DATA_RANDOM_GAP_DAVIS, 145 | rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ) 146 | train_datasets.append(train_davis_dataset) 147 | 148 | if 'youtubevos' in cfg.DATASETS: 149 | train_ytb_dataset = YOUTUBE_VOS_Train( 150 | root=cfg.DIR_YTB, 151 | transform=composed_transforms, 152 | curr_len=cfg.DATA_CURR_SEQ_LEN, 153 | rand_gap=cfg.DATA_RANDOM_GAP_YTB, 154 | rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ) 155 | train_datasets.append(train_ytb_dataset) 156 | 157 | if 'test' in cfg.DATASETS: 158 | test_dataset = TEST( 159 | transform=composed_transforms, 160 | curr_len=cfg.DATA_CURR_SEQ_LEN) 161 | train_datasets.append(test_dataset) 162 | 163 | if len(train_datasets) > 1: 164 | train_dataset = torch.utils.data.ConcatDataset(train_datasets) 165 | elif len(train_datasets) == 1: 166 | train_dataset = train_datasets[0] 167 | else: 168 | self.print_log('No dataset!') 169 | exit(0) 170 | 171 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 172 | self.trainloader = DataLoader( 173 | train_dataset, 174 | batch_size=int(cfg.TRAIN_BATCH_SIZE / cfg.TRAIN_GPUS), 175 | shuffle=False, 176 | num_workers=cfg.DATA_WORKERS, 177 | pin_memory=True, 178 | sampler=self.train_sampler) 179 | 180 | self.print_log('Done!') 181 | 182 | def sequential_training(self): 183 | 184 | cfg = self.cfg 185 | 186 | running_losses = [] 187 | running_ious = [] 188 | for _ in range(cfg.DATA_CURR_SEQ_LEN): 189 | running_losses.append(AverageMeter()) 190 | running_ious.append(AverageMeter()) 191 | batch_time = AverageMeter() 192 | avg_obj = AverageMeter() 193 | 194 | optimizer = self.optimizer 195 | model = self.dist_model 196 | train_sampler = self.train_sampler 197 | trainloader = self.trainloader 198 | step = self.step 199 | epoch = self.epoch 200 | max_itr = cfg.TRAIN_TOTAL_STEPS 201 | 202 | PlaceHolder=[] 203 | for i in range(cfg.BLOCK_NUM): 204 | PlaceHolder.append(None) 205 | 206 | self.print_log('Start training.') 207 | model.train() 208 | while step < cfg.TRAIN_TOTAL_STEPS: 209 | train_sampler.set_epoch(epoch) 210 | epoch += 1 211 | last_time = time.time() 212 | for frame_idx, sample in enumerate(trainloader): 213 | now_lr = adjust_learning_rate( 214 | optimizer=optimizer, 215 | base_lr=cfg.TRAIN_LR, 216 | p=cfg.TRAIN_POWER, 217 | itr=step, 218 | max_itr=max_itr, 219 | warm_up_steps=cfg.TRAIN_WARM_UP_STEPS, 220 | is_cosine_decay=cfg.TRAIN_COSINE_DECAY) 221 | 222 | ref_imgs = sample['ref_img'] # batch_size * 3 * h * w 223 
| prev_imgs = sample['prev_img'] 224 | curr_imgs = sample['curr_img'][0] 225 | ref_labels = sample['ref_label'] # batch_size * 1 * h * w 226 | prev_labels = sample['prev_label'] 227 | curr_labels = sample['curr_label'][0] 228 | obj_nums = sample['meta']['obj_num'] 229 | bs, _, h, w = curr_imgs.size() 230 | 231 | ref_labels = ref_labels.cuda(self.gpu) 232 | prev_labels = prev_labels.cuda(self.gpu) 233 | curr_labels = curr_labels.cuda(self.gpu) 234 | obj_nums = obj_nums.cuda(self.gpu) 235 | 236 | if step % cfg.TRAIN_TBLOG_STEP == 0 and self.rank == 0 and cfg.TRAIN_TBLOG: 237 | tf_board = True 238 | else: 239 | tf_board = False 240 | 241 | # Sequential training 242 | all_boards = [] 243 | curr_imgs = prev_imgs 244 | curr_labels = prev_labels 245 | all_pred = prev_labels.squeeze(1) 246 | optimizer.zero_grad() 247 | memory_cur_list=[] 248 | memory_prev_list=[] 249 | for iii in range(int(cfg.TRAIN_BATCH_SIZE//cfg.TRAIN_GPUS)): 250 | memory_cur_list.append(PlaceHolder) 251 | memory_prev_list.append(PlaceHolder) 252 | 253 | for idx in range(cfg.DATA_CURR_SEQ_LEN): 254 | prev_imgs = curr_imgs 255 | curr_imgs = sample['curr_img'][idx] 256 | inputs = torch.cat((ref_imgs, prev_imgs, curr_imgs), 0).cuda(self.gpu) 257 | if step > cfg.TRAIN_START_SEQ_TRAINING_STEPS: 258 | # Use previous prediction instead of ground-truth mask 259 | prev_labels = all_pred.unsqueeze(1) 260 | else: 261 | # Use previous ground-truth mask 262 | prev_labels = curr_labels 263 | curr_labels = sample['curr_label'][idx].cuda(self.gpu) 264 | 265 | loss, all_pred, boards,memory_cur_list = model( 266 | inputs, 267 | memory_prev_list, 268 | ref_labels, 269 | prev_labels, 270 | curr_labels, 271 | gt_ids=obj_nums, 272 | step=step, 273 | tf_board=tf_board) 274 | 275 | memory_prev_list = memory_cur_list 276 | 277 | iou = pytorch_iou(all_pred.unsqueeze(1), curr_labels, obj_nums) 278 | loss = torch.mean(loss) / cfg.DATA_CURR_SEQ_LEN 279 | loss.backward() 280 | all_boards.append(boards) 281 | running_losses[idx].update(loss.item() * cfg.DATA_CURR_SEQ_LEN) 282 | running_ious[idx].update(iou.item()) 283 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.TRAIN_CLIP_GRAD_NORM) 284 | optimizer.step() 285 | batch_time.update(time.time() - last_time) 286 | avg_obj.update(obj_nums.float().mean().item()) 287 | last_time = time.time() 288 | 289 | if step % cfg.TRAIN_TBLOG_STEP == 0 and self.rank == 0: 290 | self.process_log( 291 | ref_imgs, prev_imgs, curr_imgs, 292 | ref_labels, prev_labels, curr_labels, 293 | all_pred, all_boards, running_losses, running_ious, now_lr, step) 294 | 295 | if step % cfg.TRAIN_LOG_STEP == 0 and self.rank == 0: 296 | strs = 'Itr:{}, LR:{:.7f}, Time:{:.3f}, Obj:{:.1f}'.format(step, now_lr, batch_time.avg, avg_obj.avg) 297 | batch_time.reset() 298 | avg_obj.reset() 299 | for idx in range(cfg.DATA_CURR_SEQ_LEN): 300 | strs += ', S{}: L {:.3f}({:.3f}) IoU {:.3f}({:.3f})'.format(idx, running_losses[idx].val, running_losses[idx].avg, 301 | running_ious[idx].val, running_ious[idx].avg) 302 | running_losses[idx].reset() 303 | running_ious[idx].reset() 304 | 305 | self.print_log(strs) 306 | 307 | if step % cfg.TRAIN_SAVE_STEP == 0 and step != 0 and self.rank == 0: 308 | self.print_log('Save CKPT (Step {}).'.format(step)) 309 | save_network(self.model, optimizer, step, cfg.DIR_CKPT, cfg.TRAIN_MAX_KEEP_CKPT) 310 | 311 | step += 1 312 | if step > cfg.TRAIN_TOTAL_STEPS: 313 | break 314 | 315 | if self.rank == 0: 316 | self.print_log('Save final CKPT (Step {}).'.format(step - 1)) 317 | save_network(self.model, optimizer, step - 
1, cfg.DIR_CKPT, cfg.TRAIN_MAX_KEEP_CKPT) 318 | 319 | def print_log(self, string): 320 | if self.rank == 0: 321 | print(string) 322 | 323 | 324 | def process_log(self, 325 | ref_imgs, prev_imgs, curr_imgs, 326 | ref_labels, prev_labels, curr_labels, 327 | curr_pred, all_boards, running_losses, running_ious, now_lr, step): 328 | cfg = self.cfg 329 | 330 | mean = np.array([[[0.485]], [[0.456]], [[0.406]]]) 331 | sigma = np.array([[[0.229]], [[0.224]], [[0.225]]]) 332 | 333 | show_ref_img, show_prev_img, show_curr_img = [img.cpu().numpy()[0] * sigma + mean for img in [ref_imgs, prev_imgs, curr_imgs]] 334 | 335 | show_gt, show_prev_gt, show_ref_gt, show_preds_s = [label.cpu()[0].squeeze(0).numpy() for label in [curr_labels, prev_labels, ref_labels, curr_pred]] 336 | 337 | show_gtf, show_prev_gtf, show_ref_gtf, show_preds_sf = [label2colormap(label).transpose((2,0,1)) for label in [show_gt, show_prev_gt, show_ref_gt, show_preds_s]] 338 | 339 | if cfg.TRAIN_IMG_LOG or cfg.TRAIN_TBLOG: 340 | 341 | show_ref_img = masked_image(show_ref_img, show_ref_gtf, show_ref_gt) 342 | if cfg.TRAIN_IMG_LOG: 343 | save_image(show_ref_img, os.path.join(cfg.DIR_IMG_LOG, '%06d_ref_img.jpeg' % (step))) 344 | 345 | show_prev_img = masked_image(show_prev_img, show_prev_gtf, show_prev_gt) 346 | if cfg.TRAIN_IMG_LOG: 347 | save_image(show_prev_img, os.path.join(cfg.DIR_IMG_LOG, '%06d_prev_img.jpeg' % (step))) 348 | 349 | show_img_pred = masked_image(show_curr_img, show_preds_sf, show_preds_s) 350 | if cfg.TRAIN_IMG_LOG: 351 | save_image(show_img_pred, os.path.join(cfg.DIR_IMG_LOG, '%06d_prediction.jpeg' % (step))) 352 | 353 | show_curr_img = masked_image(show_curr_img, show_gtf, show_gt) 354 | if cfg.TRAIN_IMG_LOG: 355 | save_image(show_curr_img, os.path.join(cfg.DIR_IMG_LOG, '%06d_groundtruth.jpeg' % (step))) 356 | 357 | if cfg.TRAIN_TBLOG: 358 | for seq_step, running_loss, running_iou in zip(range(len(running_losses)), running_losses, running_ious): 359 | self.tblogger.add_scalar('S{}/Loss'.format(seq_step), running_loss.avg, step) 360 | self.tblogger.add_scalar('S{}/IoU'.format(seq_step), running_iou.avg, step) 361 | 362 | self.tblogger.add_scalar('LR', now_lr, step) 363 | self.tblogger.add_image('Ref/Image', show_ref_img, step) 364 | self.tblogger.add_image('Ref/GT', show_ref_gtf, step) 365 | 366 | self.tblogger.add_image('Prev/Image', show_prev_img, step) 367 | self.tblogger.add_image('Prev/GT', show_prev_gtf, step) 368 | 369 | self.tblogger.add_image('Curr/Image_GT', show_curr_img, step) 370 | self.tblogger.add_image('Curr/Image_Pred', show_img_pred, step) 371 | 372 | self.tblogger.add_image('Curr/Mask_GT', show_gtf, step) 373 | self.tblogger.add_image('Curr/Mask_Pred', show_preds_sf, step) 374 | 375 | for seq_step, boards in enumerate(all_boards): 376 | for key in boards['image'].keys(): 377 | tmp = boards['image'][key].cpu().numpy() 378 | self.tblogger.add_image('S{}/' + key, tmp, step) 379 | for key in boards['scalar'].keys(): 380 | tmp = boards['scalar'][key].cpu().numpy() 381 | self.tblogger.add_scalar('S{}/' + key, tmp, step) 382 | 383 | self.tblogger.flush() 384 | 385 | del(all_boards) 386 | 387 | 388 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/networks/layers/__init__.py 
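# Note on the sequential training loop in train_manager_mm.py above: before
# cfg.TRAIN_START_SEQ_TRAINING_STEPS each unrolled step is conditioned on the
# ground-truth previous-frame mask (teacher forcing); after that point it is
# conditioned on the model's own prediction from the previous step. A minimal
# sketch of that switch is given below; previous_mask_for_step and its argument
# names are illustrative only and are not part of the repo's API.

def previous_mask_for_step(step, start_seq_training_step, gt_prev_label, last_pred):
    """Pick the previous-frame mask used to condition the current unrolled step.

    step                    -- global training step
    start_seq_training_step -- e.g. cfg.TRAIN_START_SEQ_TRAINING_STEPS
    gt_prev_label           -- ground-truth previous-frame mask, shape [B, 1, H, W]
    last_pred               -- predicted label map from the previous step, shape [B, H, W]
    """
    if step > start_seq_training_step:
        # self-conditioning, mirroring `prev_labels = all_pred.unsqueeze(1)` above
        return last_pred.unsqueeze(1)
    # teacher forcing with the ground-truth mask
    return gt_prev_label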
-------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | from torch import nn 5 | from networks.layers.gct import GCT 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation): 9 | super(_ASPPModule, self).__init__() 10 | self.GCT = GCT(inplanes) 11 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 12 | stride=1, padding=padding, dilation=dilation, bias=False) 13 | self.bn = nn.GroupNorm(int(planes / 4), planes) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | self._init_weight() 17 | 18 | def forward(self, x): 19 | x = self.GCT(x) 20 | x = self.atrous_conv(x) 21 | x = self.bn(x) 22 | 23 | return self.relu(x) 24 | 25 | def _init_weight(self): 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | torch.nn.init.kaiming_normal_(m.weight) 29 | elif isinstance(m, nn.BatchNorm2d): 30 | m.weight.data.fill_(1) 31 | m.bias.data.zero_() 32 | 33 | class ASPP(nn.Module): 34 | def __init__(self): 35 | super(ASPP, self).__init__() 36 | 37 | inplanes = 512 38 | dilations = [1, 6, 12, 18] 39 | 40 | 41 | self.aspp1 = _ASPPModule(inplanes, 128, 1, padding=0, dilation=dilations[0]) 42 | self.aspp2 = _ASPPModule(inplanes, 128, 3, padding=dilations[1], dilation=dilations[1]) 43 | self.aspp3 = _ASPPModule(inplanes, 128, 3, padding=dilations[2], dilation=dilations[2]) 44 | self.aspp4 = _ASPPModule(inplanes, 128, 3, padding=dilations[3], dilation=dilations[3]) 45 | 46 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 47 | nn.Conv2d(inplanes, 128, 1, stride=1, bias=False), 48 | nn.ReLU(inplace=True)) 49 | 50 | self.GCT = GCT(640) 51 | self.conv1 = nn.Conv2d(640, 256, 1, bias=False) 52 | self.bn1 = nn.GroupNorm(32, 256) 53 | self.relu = nn.ReLU(inplace=True) 54 | self._init_weight() 55 | 56 | def forward(self, x): 57 | x1 = self.aspp1(x) 58 | x2 = self.aspp2(x) 59 | x3 = self.aspp3(x) 60 | x4 = self.aspp4(x) 61 | x5 = self.global_avg_pool(x) 62 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 63 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 64 | 65 | x = self.GCT(x) 66 | x = self.conv1(x) 67 | x = self.bn1(x) 68 | x = self.relu(x) 69 | 70 | return x 71 | 72 | def _init_weight(self): 73 | for m in self.modules(): 74 | if isinstance(m, nn.Conv2d): 75 | torch.nn.init.kaiming_normal_(m.weight) 76 | elif isinstance(m, nn.BatchNorm2d): 77 | m.weight.data.fill_(1) 78 | m.bias.data.zero_() -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/conv_gru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torch.nn import init 6 | 7 | 8 | class ConvGRUCell(nn.Module): 9 | """ 10 | Generate a convolutional GRU cell 11 | """ 12 | 13 | def __init__(self, input_size, hidden_size, kernel_size): 14 | super().__init__() 15 | padding = kernel_size // 2 16 | self.input_size = input_size 17 | self.hidden_size = hidden_size 18 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 19 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 20 | self.out_gate = 
nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 21 | 22 | init.orthogonal(self.reset_gate.weight) 23 | init.orthogonal(self.update_gate.weight) 24 | init.orthogonal(self.out_gate.weight) 25 | init.constant(self.reset_gate.bias, 0.) 26 | init.constant(self.update_gate.bias, 0.) 27 | init.constant(self.out_gate.bias, 0.) 28 | 29 | 30 | def forward(self, input_, prev_state): 31 | 32 | # get batch and spatial sizes 33 | batch_size = input_.data.size()[0] 34 | spatial_size = input_.data.size()[2:] 35 | 36 | # generate empty prev_state, if None is provided 37 | if prev_state is None: 38 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 39 | if torch.cuda.is_available(): 40 | prev_state = torch.zeros(state_size).cuda() 41 | else: 42 | prev_state = torch.zeros(state_size) 43 | print("input_.size()",input_.size()) 44 | print("prev_state.size()",prev_state.size()) 45 | if input_.size()[0]!=prev_state.size()[0]: 46 | 47 | state_size = [input_.size()[0]-prev_state.size()[0], self.hidden_size] + list(spatial_size) 48 | if torch.cuda.is_available(): 49 | prev_state_tmp = torch.zeros(state_size).cuda() 50 | else: 51 | prev_state_tmp = torch.zeros(state_size) 52 | 53 | prev_state = torch.cat([prev_state, prev_state_tmp],dim=0) 54 | print("add_on prev_state.size()", prev_state.size()) 55 | # data size is [batch, channel, height, width] 56 | stacked_inputs = torch.cat([input_, prev_state], dim=1) 57 | update = F.sigmoid(self.update_gate(stacked_inputs)) 58 | reset = F.sigmoid(self.reset_gate(stacked_inputs)) 59 | out_inputs = F.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1))) 60 | new_state = prev_state * (1 - update) + out_inputs * update 61 | 62 | return new_state 63 | 64 | 65 | class ConvGRU(nn.Module): 66 | 67 | def __init__(self, input_size, hidden_sizes, kernel_sizes, n_layers): 68 | ''' 69 | Generates a multi-layer convolutional GRU. 70 | Preserves spatial dimensions across cells, only altering depth. 71 | 72 | Parameters 73 | ---------- 74 | input_size : integer. depth dimension of input tensors. 75 | hidden_sizes : integer or list. depth dimensions of hidden state. 76 | if integer, the same hidden size is used for all cells. 77 | kernel_sizes : integer or list. sizes of Conv2d gate kernels. 78 | if integer, the same kernel size is used for all cells. 79 | n_layers : integer. number of chained `ConvGRUCell`. 80 | ''' 81 | 82 | super(ConvGRU, self).__init__() 83 | 84 | self.input_size = input_size 85 | 86 | if type(hidden_sizes) != list: 87 | self.hidden_sizes = [hidden_sizes]*n_layers 88 | else: 89 | assert len(hidden_sizes) == n_layers, '`hidden_sizes` must have the same length as n_layers' 90 | self.hidden_sizes = hidden_sizes 91 | if type(kernel_sizes) != list: 92 | self.kernel_sizes = [kernel_sizes]*n_layers 93 | else: 94 | assert len(kernel_sizes) == n_layers, '`kernel_sizes` must have the same length as n_layers' 95 | self.kernel_sizes = kernel_sizes 96 | 97 | self.n_layers = n_layers 98 | 99 | cells = [] 100 | for i in range(self.n_layers): 101 | if i == 0: 102 | input_dim = self.input_size 103 | else: 104 | input_dim = self.hidden_sizes[i-1] 105 | 106 | cell = ConvGRUCell(input_dim, self.hidden_sizes[i], self.kernel_sizes[i]) 107 | name = 'ConvGRUCell_' + str(i).zfill(2) 108 | 109 | setattr(self, name, cell) 110 | cells.append(getattr(self, name)) 111 | 112 | self.cells = cells 113 | 114 | 115 | def forward(self, x, hidden=None): 116 | ''' 117 | Parameters 118 | ---------- 119 | x : 4D input tensor. 
(batch, channels, height, width). 120 | hidden : list of 4D hidden state representations. (batch, channels, height, width). 121 | 122 | Returns 123 | ------- 124 | upd_hidden : 5D hidden representation. (layer, batch, channels, height, width). 125 | ''' 126 | if not hidden: 127 | hidden = [None]*self.n_layers 128 | 129 | input_ = x 130 | 131 | upd_hidden = [] 132 | 133 | for layer_idx in range(self.n_layers): 134 | cell = self.cells[layer_idx] 135 | cell_hidden = hidden[layer_idx] 136 | 137 | # pass through layer 138 | upd_cell_hidden = cell(input_, cell_hidden) 139 | upd_hidden.append(upd_cell_hidden.detach()) 140 | # update input_ to the last updated hidden layer for next pass 141 | input_ = upd_cell_hidden 142 | 143 | # retain tensors in list to allow different hidden sizes 144 | return input_, upd_hidden -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/gct.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | from torch import nn 5 | from networks.p2t.center_module import SpatialProp 6 | 7 | class GCT(nn.Module): 8 | def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False): 9 | super(GCT, self).__init__() 10 | self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1)) 11 | self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) 12 | self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) 13 | self.epsilon = epsilon 14 | self.mode = mode 15 | self.after_relu = after_relu 16 | 17 | def forward(self, x): 18 | 19 | if self.mode == 'l2': 20 | embedding = (x.pow(2).sum((2,3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha 21 | norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5) 22 | 23 | elif self.mode == 'l1': 24 | if not self.after_relu: 25 | _x = torch.abs(x) 26 | else: 27 | _x = x 28 | embedding = _x.sum((2,3), keepdim=True) * self.alpha 29 | norm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon) 30 | else: 31 | print('Unknown mode!') 32 | exit() 33 | 34 | gate = 1. 
+ torch.tanh(embedding * norm + self.beta) 35 | 36 | return x * gate 37 | 38 | class Bottleneck(nn.Module): 39 | def __init__(self, inplanes, outplanes, stride=1, dilation=1): 40 | super(Bottleneck, self).__init__() 41 | expansion = 4 42 | planes = int(outplanes / expansion) 43 | self.GCT1 = GCT(inplanes) 44 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.GroupNorm(32, planes) 46 | 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 48 | dilation=dilation, padding=dilation, bias=False) 49 | self.bn2 = nn.GroupNorm(32, planes) 50 | 51 | self.conv3 = nn.Conv2d(planes, planes * expansion, kernel_size=1, bias=False) 52 | self.bn3 = nn.GroupNorm(32, planes * expansion) 53 | self.relu = nn.ReLU(inplace=True) 54 | if stride != 1 or inplanes != planes * expansion: 55 | downsample = nn.Sequential( 56 | nn.Conv2d(inplanes, planes * expansion, 57 | kernel_size=1, stride=stride, bias=False), 58 | nn.GroupNorm(32, planes * expansion), 59 | ) 60 | else: 61 | downsample = None 62 | self.downsample = downsample 63 | self.stride = stride 64 | self.dilation = dilation 65 | 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | nn.init.kaiming_normal_(m.weight,mode='fan_out', nonlinearity='relu') 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.GCT1(x) 74 | out = self.conv1(out) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | class GCT_Spatial(nn.Module): 94 | def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False): 95 | super(GCT_Spatial, self).__init__() 96 | 97 | self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1)) 98 | self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) 99 | self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) 100 | self.epsilon = epsilon 101 | self.mode = mode 102 | self.after_relu = after_relu 103 | 104 | def forward(self, x,spatial_embedding): 105 | 106 | if self.mode == 'l2': 107 | embedding = (x.pow(2).sum((2,3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha 108 | norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5) 109 | 110 | elif self.mode == 'l1': 111 | if not self.after_relu: 112 | _x = torch.abs(x) 113 | else: 114 | _x = x 115 | embedding = _x.sum((2,3), keepdim=True) * self.alpha 116 | norm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon) 117 | else: 118 | print('Unknown mode!') 119 | exit() 120 | 121 | gate = 1. + torch.tanh(embedding * norm + self.beta) 122 | 123 | return x * gate -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/gru_conv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | 6 | 7 | class ConvGRUCell(nn.Module): 8 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias, dtype): 9 | """ 10 | Initialize the ConvLSTM cell 11 | :param input_size: (int, int) 12 | Height and width of input tensor as (height, width). 13 | :param input_dim: int 14 | Number of channels of input tensor. 
15 | :param hidden_dim: int 16 | Number of channels of hidden state. 17 | :param kernel_size: (int, int) 18 | Size of the convolutional kernel. 19 | :param bias: bool 20 | Whether or not to add the bias. 21 | :param dtype: torch.cuda.FloatTensor or torch.FloatTensor 22 | Whether or not to use cuda. 23 | """ 24 | super(ConvGRUCell, self).__init__() 25 | self.height, self.width = input_size 26 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 27 | self.hidden_dim = hidden_dim 28 | self.bias = bias 29 | self.dtype = dtype 30 | 31 | self.conv_gates = nn.Conv2d(in_channels=input_dim + hidden_dim, 32 | out_channels=2*self.hidden_dim, # for update_gate,reset_gate respectively 33 | kernel_size=kernel_size, 34 | padding=self.padding, 35 | bias=self.bias) 36 | 37 | self.conv_can = nn.Conv2d(in_channels=input_dim+hidden_dim, 38 | out_channels=self.hidden_dim, # for candidate neural memory 39 | kernel_size=kernel_size, 40 | padding=self.padding, 41 | bias=self.bias) 42 | 43 | def init_hidden(self, batch_size): 44 | return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).type(self.dtype)) 45 | 46 | def forward(self, input_tensor, h_cur): 47 | """ 48 | 49 | :param self: 50 | :param input_tensor: (b, c, h, w) 51 | input is actually the target_model 52 | :param h_cur: (b, c_hidden, h, w) 53 | current hidden and cell states respectively 54 | :return: h_next, 55 | next hidden state 56 | """ 57 | combined = torch.cat([input_tensor, h_cur], dim=1) 58 | combined_conv = self.conv_gates(combined) 59 | 60 | gamma, beta = torch.split(combined_conv, self.hidden_dim, dim=1) 61 | reset_gate = torch.sigmoid(gamma) 62 | update_gate = torch.sigmoid(beta) 63 | 64 | combined = torch.cat([input_tensor, reset_gate*h_cur], dim=1) 65 | cc_cnm = self.conv_can(combined) 66 | cnm = torch.tanh(cc_cnm) 67 | 68 | h_next = (1 - update_gate) * h_cur + update_gate * cnm 69 | return h_next 70 | 71 | 72 | class ConvGRU(nn.Module): 73 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers, 74 | dtype, batch_first=False, bias=True, return_all_layers=False): 75 | """ 76 | 77 | :param input_size: (int, int) 78 | Height and width of input tensor as (height, width). 79 | :param input_dim: int e.g. 256 80 | Number of channels of input tensor. 81 | :param hidden_dim: int e.g. 1024 82 | Number of channels of hidden state. 83 | :param kernel_size: (int, int) 84 | Size of the convolutional kernel. 85 | :param num_layers: int 86 | Number of ConvLSTM layers 87 | :param dtype: torch.cuda.FloatTensor or torch.FloatTensor 88 | Whether or not to use cuda. 89 | :param alexnet_path: str 90 | pretrained alexnet parameters 91 | :param batch_first: bool 92 | if the first position of array is batch or not 93 | :param bias: bool 94 | Whether or not to add the bias. 
95 | :param return_all_layers: bool 96 | if return hidden and cell states for all layers 97 | """ 98 | super(ConvGRU, self).__init__() 99 | 100 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 101 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 102 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 103 | if not len(kernel_size) == len(hidden_dim) == num_layers: 104 | raise ValueError('Inconsistent list length.') 105 | 106 | self.height, self.width = input_size 107 | self.input_dim = input_dim 108 | self.hidden_dim = hidden_dim 109 | self.kernel_size = kernel_size 110 | self.dtype = dtype 111 | self.num_layers = num_layers 112 | self.batch_first = batch_first 113 | self.bias = bias 114 | self.return_all_layers = return_all_layers 115 | 116 | cell_list = [] 117 | for i in range(0, self.num_layers): 118 | cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1] 119 | cell_list.append(ConvGRUCell(input_size=(self.height, self.width), 120 | input_dim=cur_input_dim, 121 | hidden_dim=self.hidden_dim[i], 122 | kernel_size=self.kernel_size[i], 123 | bias=self.bias, 124 | dtype=self.dtype)) 125 | 126 | # convert python list to pytorch module 127 | self.cell_list = nn.ModuleList(cell_list) 128 | 129 | def forward(self, input_tensor, hidden_state=None): 130 | """ 131 | 132 | :param input_tensor: (b, t, c, h, w) or (t,b,c,h,w) depends on if batch first or not 133 | extracted features from alexnet 134 | :param hidden_state: 135 | :return: layer_output_list, last_state_list 136 | """ 137 | print("input_tensor.device",input_tensor.device) 138 | if not self.batch_first: 139 | # (t, b, c, h, w) -> (b, t, c, h, w) 140 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 141 | 142 | # Implement stateful ConvLSTM 143 | if hidden_state is not None: 144 | raise NotImplementedError() 145 | else: 146 | hidden_state = self._init_hidden(batch_size=input_tensor.size(0)) 147 | 148 | layer_output_list = [] 149 | last_state_list = [] 150 | 151 | seq_len = input_tensor.size(1) 152 | cur_layer_input = input_tensor 153 | 154 | for layer_idx in range(self.num_layers): 155 | h = hidden_state[layer_idx] 156 | output_inner = [] 157 | for t in range(seq_len): 158 | # input current hidden and cell state then compute the next hidden and cell state through ConvLSTMCell forward function 159 | h = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], # (b,t,c,h,w) 160 | h_cur=h) 161 | output_inner.append(h) 162 | 163 | layer_output = torch.stack(output_inner, dim=1) 164 | cur_layer_input = layer_output 165 | 166 | layer_output_list.append(layer_output) 167 | last_state_list.append([h]) 168 | 169 | if not self.return_all_layers: 170 | layer_output_list = layer_output_list[-1:] 171 | last_state_list = last_state_list[-1:] 172 | 173 | return layer_output_list, last_state_list 174 | 175 | def _init_hidden(self, batch_size): 176 | init_states = [] 177 | for i in range(self.num_layers): 178 | init_states.append(self.cell_list[i].init_hidden(batch_size)) 179 | return init_states 180 | 181 | @staticmethod 182 | def _check_kernel_size_consistency(kernel_size): 183 | if not (isinstance(kernel_size, tuple) or 184 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 185 | raise ValueError('`kernel_size` must be tuple or list of tuples') 186 | 187 | @staticmethod 188 | def _extend_for_multilayer(param, num_layers): 189 | if not isinstance(param, list): 190 | param = [param] * num_layers 191 | return param 192 
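# --------------------------------------------------------------------------
# The block below is a quick smoke test for this ConvGRU: it stacks two cells
# (hidden dims 32 and 64, 3x3 kernels) on a random (1, 1, 256, 6, 6) clip and
# prints the per-layer output and final-state shapes. Note that the very last
# call feeds `last_state_list` back in as `hidden_state`, which hits the
# `raise NotImplementedError()` branch in `forward`, since stateful reuse of
# hidden states is not implemented here.
# --------------------------------------------------------------------------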
| 193 | 194 | if __name__ == '__main__': 195 | # set CUDA device 196 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 197 | 198 | # detect if CUDA is available or not 199 | use_gpu = torch.cuda.is_available() 200 | if use_gpu: 201 | dtype = torch.cuda.FloatTensor # computation in GPU 202 | else: 203 | dtype = torch.FloatTensor 204 | dtype = torch.FloatTensor 205 | height = width = 6 206 | channels = 256 207 | hidden_dim = [32, 64] 208 | kernel_size = (3,3) # kernel size for two stacked hidden layer 209 | num_layers = 2 # number of stacked hidden layer 210 | model = ConvGRU(input_size=(height, width), 211 | input_dim=channels, 212 | hidden_dim=hidden_dim, 213 | kernel_size=kernel_size, 214 | num_layers=num_layers, 215 | dtype=dtype, 216 | batch_first=True, 217 | bias = True, 218 | return_all_layers = False) 219 | 220 | batch_size = 1 221 | time_steps = 1 222 | input_tensor = torch.rand(batch_size, time_steps, channels, height, width) # (b,t,c,h,w) 223 | layer_output_list, last_state_list = model(input_tensor) 224 | for i in layer_output_list: 225 | print("i.size()",i.size()) 226 | for i in last_state_list: 227 | for j in i: 228 | print("j.size()",j.size()) 229 | layer_output_list, last_state_list2 = model(input_tensor,last_state_list) -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | class Concat_BCEWithLogitsLoss(nn.Module): 6 | def __init__(self, top_k_percent_pixels=None, 7 | hard_example_mining_step=100000): 8 | super(Concat_BCEWithLogitsLoss, self).__init__() 9 | self.top_k_percent_pixels = top_k_percent_pixels 10 | if top_k_percent_pixels is not None: 11 | assert(top_k_percent_pixels > 0 and top_k_percent_pixels < 1) 12 | self.hard_example_mining_step = hard_example_mining_step 13 | if self.top_k_percent_pixels == None: 14 | self.bceloss = nn.BCEWithLogitsLoss(reduction='mean') 15 | else: 16 | self.bceloss = nn.BCEWithLogitsLoss(reduction='none') 17 | 18 | def forward(self, dic_tmp, y, step): 19 | total_loss = [] 20 | for i in range(len(dic_tmp)): 21 | pred_logits = dic_tmp[i] 22 | gts = y[i] 23 | if self.top_k_percent_pixels == None: 24 | final_loss = self.bceloss(pred_logits, gts) 25 | else: 26 | # Only compute the loss for top k percent pixels. 27 | # First, compute the loss for all pixels. Note we do not put the loss 28 | # to loss_collection and set reduction = None to keep the shape. 
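# Hard-example-mining schedule used below: with
#   ratio = min(1, step / hard_example_mining_step),
# the number of kept pixels is
#   top_k_pixels = (ratio * top_k_percent_pixels + (1 - ratio)) * num_pixels,
# i.e. every pixel contributes at step 0, and the loss anneals linearly to the
# hardest top_k_percent_pixels fraction once step >= hard_example_mining_step.
# (The same schedule is reused in Concat_CrossEntropyLoss further down.)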
29 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3)) 30 | pred_logits = pred_logits.view(-1, pred_logits.size( 31 | 1), pred_logits.size(2) * pred_logits.size(3)) 32 | gts = gts.view(-1, gts.size(1), gts.size(2) * gts.size(3)) 33 | pixel_losses = self.bceloss(pred_logits, gts) 34 | if self.hard_example_mining_step == 0: 35 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels) 36 | else: 37 | ratio = min( 38 | 1.0, step / float(self.hard_example_mining_step)) 39 | top_k_pixels = int( 40 | (ratio * self.top_k_percent_pixels + (1.0 - ratio)) * num_pixels) 41 | _, top_k_indices = torch.topk( 42 | pixel_losses, k=top_k_pixels, dim=2) 43 | 44 | final_loss = nn.BCEWithLogitsLoss( 45 | weight=top_k_indices, reduction='mean')(pred_logits, gts) 46 | final_loss = final_loss.unsqueeze(0) 47 | total_loss.append(final_loss) 48 | total_loss = torch.cat(total_loss, dim=0) 49 | return total_loss 50 | 51 | 52 | class Concat_CrossEntropyLoss(nn.Module): 53 | def __init__(self, top_k_percent_pixels=None, 54 | hard_example_mining_step=100000): 55 | super(Concat_CrossEntropyLoss, self).__init__() 56 | self.top_k_percent_pixels = top_k_percent_pixels 57 | if top_k_percent_pixels is not None: 58 | assert(top_k_percent_pixels > 0 and top_k_percent_pixels < 1) 59 | self.hard_example_mining_step = hard_example_mining_step 60 | if self.top_k_percent_pixels == None: 61 | self.celoss = nn.CrossEntropyLoss( 62 | ignore_index=255, reduction='mean') 63 | else: 64 | self.celoss = nn.CrossEntropyLoss( 65 | ignore_index=255, reduction='none') 66 | 67 | def forward(self, dic_tmp, y, step): 68 | total_loss = [] 69 | for i in range(len(dic_tmp)): 70 | pred_logits = dic_tmp[i] 71 | gts = y[i] 72 | if self.top_k_percent_pixels == None: 73 | final_loss = self.celoss(pred_logits, gts) 74 | else: 75 | # Only compute the loss for top k percent pixels. 76 | # First, compute the loss for all pixels. Note we do not put the loss 77 | # to loss_collection and set reduction = None to keep the shape. 
78 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3)) 79 | pred_logits = pred_logits.view(-1, pred_logits.size( 80 | 1), pred_logits.size(2) * pred_logits.size(3)) 81 | gts = gts.view(-1, gts.size(1) * gts.size(2)) 82 | pixel_losses = self.celoss(pred_logits, gts) 83 | if self.hard_example_mining_step == 0: 84 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels) 85 | else: 86 | ratio = min( 87 | 1.0, step / float(self.hard_example_mining_step)) 88 | top_k_pixels = int( 89 | (ratio * self.top_k_percent_pixels + (1.0 - ratio)) * num_pixels) 90 | top_k_loss, top_k_indices = torch.topk( 91 | pixel_losses, k=top_k_pixels, dim=1) 92 | 93 | final_loss = torch.mean(top_k_loss) 94 | final_loss = final_loss.unsqueeze(0) 95 | total_loss.append(final_loss) 96 | total_loss = torch.cat(total_loss, dim=0) 97 | return total_loss 98 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FrozenBatchNorm2d(nn.Module): 7 | """ 8 | BatchNorm2d where the batch statistics and the affine parameters 9 | are fixed 10 | """ 11 | def __init__(self, n, epsilon=1e-5): 12 | super(FrozenBatchNorm2d, self).__init__() 13 | self.register_buffer("weight", torch.ones(n)) 14 | self.register_buffer("bias", torch.zeros(n)) 15 | self.register_buffer("running_mean", torch.zeros(n)) 16 | self.register_buffer("running_var", torch.ones(n)) 17 | self.epsilon = epsilon 18 | 19 | def forward(self, x): 20 | scale = self.weight * (self.running_var + self.epsilon).rsqrt() 21 | bias = self.bias - self.running_mean * scale 22 | scale = scale.reshape(1, -1, 1, 1) 23 | bias = bias.reshape(1, -1, 1, 1) 24 | return x * scale + bias -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/refiner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from typing import Tuple 6 | 7 | 8 | class Refiner(nn.Module): 9 | """ 10 | Refiner refines the coarse output to full resolution. 11 | 12 | Args: 13 | mode: area selection mode. Options: 14 | "full" - No area selection, refine everywhere using regular Conv2d. 15 | "sampling" - Refine fixed amount of pixels ranked by the top most errors. 16 | "thresholding" - Refine varying amount of pixels that have greater error than the threshold. 17 | sample_pixels: number of pixels to refine. Only used when mode == "sampling". 18 | threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding". 19 | kernel_size: The convolution kernel_size. Options: [1, 3] 20 | prevent_oversampling: True for regular cases, False for speedtest. 21 | 22 | Compatibility Args: 23 | patch_crop_method: the method for cropping patches. Options: 24 | "unfold" - Best performance for PyTorch and TorchScript. 25 | "roi_align" - Another way for croping patches. 26 | "gather" - Another way for croping patches. 27 | patch_replace_method: the method for replacing patches. Options: 28 | "scatter_nd" - Best performance for PyTorch and TorchScript. 29 | "scatter_element" - Another way for replacing patches. 30 | 31 | Input: 32 | src: (B, 3, H, W) full resolution source image. 
33 | bgr: (B, 3, H, W) full resolution background image. 34 | pha: (B, 1, Hc, Wc) coarse alpha prediction. 35 | fgr: (B, 3, Hc, Wc) coarse foreground residual prediction. 36 | err: (B, 1, Hc, Hc) coarse error prediction. 37 | hid: (B, 32, Hc, Hc) coarse hidden encoding. 38 | 39 | Output: 40 | pha: (B, 1, H, W) full resolution alpha prediction. 41 | fgr: (B, 3, H, W) full resolution foreground residual prediction. 42 | ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations. 43 | """ 44 | 45 | # For TorchScript export optimization. 46 | __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method'] 47 | 48 | def __init__(self, 49 | mode: str, 50 | sample_pixels: int, 51 | threshold: float, 52 | kernel_size: int = 3, 53 | prevent_oversampling: bool = True, 54 | patch_crop_method: str = 'unfold', 55 | patch_replace_method: str = 'scatter_nd'): 56 | super().__init__() 57 | assert mode in ['full', 'sampling', 'thresholding'] 58 | assert kernel_size in [1, 3] 59 | assert patch_crop_method in ['unfold', 'roi_align', 'gather'] 60 | assert patch_replace_method in ['scatter_nd', 'scatter_element'] 61 | 62 | self.mode = mode 63 | self.sample_pixels = sample_pixels 64 | self.threshold = threshold 65 | self.kernel_size = kernel_size 66 | self.prevent_oversampling = prevent_oversampling 67 | self.patch_crop_method = patch_crop_method 68 | self.patch_replace_method = patch_replace_method 69 | 70 | channels = [128, 64, 32, 16, 1] #4->1 71 | self.conv1 = nn.Conv2d(channels[0] + 4, channels[1], kernel_size, bias=False) # 6->3 72 | self.bn1 = nn.BatchNorm2d(channels[1]) 73 | self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False) 74 | self.bn2 = nn.BatchNorm2d(channels[2]) 75 | self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False) 76 | self.bn3 = nn.BatchNorm2d(channels[3]) 77 | self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True) 78 | self.relu = nn.ReLU(True) 79 | 80 | def forward(self, 81 | src: torch.Tensor, #image 82 | #bgr: torch.Tensor, 83 | pha: torch.Tensor, #probability_map (B, 1, H, W) 84 | #fgr: torch.Tensor, #warp_result 85 | err: torch.Tensor, #uncertainty map (B, 1, H, W) 86 | hid: torch.Tensor): #feature 87 | 88 | # 1.get size 89 | H_full, W_full = src.shape[2:] 90 | H_half, W_half = H_full // 2, W_full // 2 91 | H_quat, W_quat = H_full // 4, W_full // 4 92 | 93 | src_bgr = src #torch.cat([src, bgr], dim=1) 94 | 95 | if self.mode != 'full': 96 | err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False) # downsample error map to 1/4 97 | ref = self.select_refinement_regions(err) #select top error regions 98 | idx = torch.nonzero(ref.squeeze(1)) 99 | idx = idx[:, 0], idx[:, 1], idx[:, 2] 100 | 101 | if idx[0].size(0) > 0: 102 | x = torch.cat([hid, pha], dim=1) #32 + 1 103 | x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False) 104 | x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0) 105 | 106 | y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False) # 6->3 107 | y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0) 108 | 109 | x = self.conv1(torch.cat([x, y], dim=1)) 110 | x = self.bn1(x) 111 | x = self.relu(x) 112 | x = self.conv2(x) 113 | x = self.bn2(x) 114 | x = self.relu(x) 115 | 116 | x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest') 117 | y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0) 118 | 119 | x = 
self.conv3(torch.cat([x, y], dim=1)) 120 | x = self.bn3(x) 121 | x = self.relu(x) 122 | x = self.conv4(x) 123 | 124 | out = pha # torch.cat([pha, fgr], dim=1) 125 | out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False) 126 | out = self.replace_patch(out, x, idx) 127 | pha = out #[:, :1] 128 | #fgr = out[:, 1:] 129 | else: 130 | pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False) 131 | #fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False) 132 | else: 133 | x = torch.cat([hid, pha], dim=1) 134 | x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False) 135 | y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False) 136 | if self.kernel_size == 3: 137 | x = F.pad(x, (3, 3, 3, 3)) 138 | y = F.pad(y, (3, 3, 3, 3)) 139 | 140 | x = self.conv1(torch.cat([x, y], dim=1)) 141 | x = self.bn1(x) 142 | x = self.relu(x) 143 | x = self.conv2(x) 144 | x = self.bn2(x) 145 | x = self.relu(x) 146 | 147 | if self.kernel_size == 3: 148 | x = F.interpolate(x, (H_full + 4, W_full + 4)) 149 | y = F.pad(src_bgr, (2, 2, 2, 2)) 150 | else: 151 | x = F.interpolate(x, (H_full, W_full), mode='nearest') 152 | y = src_bgr 153 | 154 | x = self.conv3(torch.cat([x, y], dim=1)) 155 | x = self.bn3(x) 156 | x = self.relu(x) 157 | x = self.conv4(x) 158 | 159 | pha = x[:, :1] 160 | #fgr = x[:, 1:] 161 | ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype) 162 | 163 | return pha, ref #fgr, ref 164 | 165 | def select_refinement_regions(self, err: torch.Tensor): 166 | """ 167 | Select refinement regions. 168 | Input: 169 | err: error map (B, 1, H, W) 170 | Output: 171 | ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not. 172 | """ 173 | if self.mode == 'sampling': 174 | # Sampling mode. 175 | b, _, h, w = err.shape 176 | err = err.view(b, -1) 177 | idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices 178 | ref = torch.zeros_like(err) 179 | ref.scatter_(1, idx, 1.) 180 | if self.prevent_oversampling: 181 | ref.mul_(err.gt(0).float()) 182 | ref = ref.view(b, 1, h, w) 183 | else: 184 | # Thresholding mode. 185 | ref = err.gt(self.threshold).float() 186 | return ref 187 | 188 | def crop_patch(self, 189 | x: torch.Tensor, 190 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 191 | size: int, 192 | padding: int): 193 | """ 194 | Crops selected patches from image given indices. 195 | 196 | Inputs: 197 | x: image (B, C, H, W). 198 | idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index. 199 | size: center size of the patch, also stride of the crop. 200 | padding: expansion size of the patch. 201 | Output: 202 | patch: (P, C, h, w), where h = w = size + 2 * padding. 203 | """ 204 | if padding != 0: 205 | x = F.pad(x, (padding,) * 4) 206 | 207 | if self.patch_crop_method == 'unfold': 208 | # Use unfold. Best performance for PyTorch and TorchScript. 209 | return x.permute(0, 2, 3, 1) \ 210 | .unfold(1, size + 2 * padding, size) \ 211 | .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]] 212 | elif self.patch_crop_method == 'roi_align': 213 | # Use roi_align. Best compatibility for ONNX. 
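# The boxes below are expressed in the coordinates of the already-padded x:
# each patch spans [idx * size, idx * size + size + 2 * padding) per axis.
# The -0.5 offsets compensate for roi_align's default (aligned=False)
# half-pixel convention, so that with sampling_ratio=1 and an output size
# equal to the box size each sample lands on a pixel centre and the crop is
# intended to match the 'unfold' path.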
214 | idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x) 215 | b = idx[0] 216 | x1 = idx[2] * size - 0.5 217 | y1 = idx[1] * size - 0.5 218 | x2 = idx[2] * size + size + 2 * padding - 0.5 219 | y2 = idx[1] * size + size + 2 * padding - 0.5 220 | boxes = torch.stack([b, x1, y1, x2, y2], dim=1) 221 | return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1) 222 | else: 223 | # Use gather. Crops out patches pixel by pixel. 224 | idx_pix = self.compute_pixel_indices(x, idx, size, padding) 225 | pat = torch.gather(x.view(-1), 0, idx_pix.view(-1)) 226 | pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding) 227 | return pat 228 | 229 | def replace_patch(self, 230 | x: torch.Tensor, 231 | y: torch.Tensor, 232 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): 233 | """ 234 | Replaces patches back into image given index. 235 | 236 | Inputs: 237 | x: image (B, C, H, W) 238 | y: patches (P, C, h, w) 239 | idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index. 240 | 241 | Output: 242 | image: (B, C, H, W), where patches at idx locations are replaced with y. 243 | """ 244 | xB, xC, xH, xW = x.shape 245 | yB, yC, yH, yW = y.shape 246 | if self.patch_replace_method == 'scatter_nd': 247 | # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch. 248 | x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5) 249 | x[idx[0], idx[1], idx[2]] = y 250 | x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW) 251 | return x 252 | else: 253 | # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel. 254 | idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0) 255 | return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape) 256 | 257 | def compute_pixel_indices(self, 258 | x: torch.Tensor, 259 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 260 | size: int, 261 | padding: int): 262 | """ 263 | Compute selected pixel indices in the tensor. 264 | Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel. 265 | Input: 266 | x: image: (B, C, H, W) 267 | idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index. 268 | size: center size of the patch, also stride of the crop. 269 | padding: expansion size of the patch. 270 | Output: 271 | idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches. 272 | the element are indices pointing to the input x.view(-1). 
273 | """ 274 | B, C, H, W = x.shape 275 | S, P = size, padding 276 | O = S + 2 * P 277 | b, y, x = idx 278 | n = b.size(0) 279 | c = torch.arange(C) 280 | o = torch.arange(O) 281 | idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O]) 282 | idx_loc = b * W * H + y * W * S + x * S 283 | idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O]) 284 | return idx_pix 285 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/shannon_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | def normalize(image, MIN_BOUND, MAX_BOUND): 6 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND) 7 | reverse_image = 1 - image 8 | return reverse_image 9 | 10 | def cal_shannon_entropy(preds): # (batch, obj_num, 128, 128) 11 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) # (batch, 1, 128, 128) 12 | uncertainty_norm = normalize(uncertainty, 0, np.log(2)) * 7 13 | return uncertainty,uncertainty_norm 14 | 15 | 16 | def normalize_train(image, MIN_BOUND, MAX_BOUND): 17 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND) 18 | #reverse_image = 1 - image 19 | return image #reverse_image 20 | 21 | def cal_shannon_entropy_train(preds): # (batch, obj_num, 128, 128) 22 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) # (batch, 1, 128, 128) 23 | uncertainty_norm = normalize_train(uncertainty, 0, np.log(2)) * 7 24 | return uncertainty,uncertainty_norm 25 | 26 | 27 | 28 | def show_shannon_entropy(uncertainty,uncertainty_norm,unc_rate=0.5): #torch.tensor, torch.tensor 29 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32') 30 | uncertainty = uncertainty * (uncertainty > unc_rate) 31 | uncertainty_norm = uncertainty_norm.cpu().numpy().squeeze().astype('float32') 32 | uncertainty_norm = uncertainty_norm 33 | 34 | plt.figure() 35 | plt.subplot(1, 6, 1) 36 | plt.imshow(save_pre_den) 37 | plt.subplot(1, 6, 2) 38 | plt.imshow(density_gt) 39 | plt.subplot(1, 6, 3) 40 | plt.imshow(save_pre_dmp_to_att) 41 | plt.subplot(1, 6, 4) 42 | plt.imshow(save_pre_att_2) 43 | plt.subplot(1, 6, 5) 44 | plt.imshow(uncertainty, cmap='inferno') 45 | plt.subplot(1, 6, 6) 46 | plt.imshow(uncertainty_norm, cmap='inferno') 47 | plt.show() 48 | 49 | def save_shannon_entropy(uncertainty,uncertainty_norm,save_path,unc_rate=0.5): #torch.tensor, torch.tensor 50 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32') 51 | uncertainty = uncertainty * (uncertainty > unc_rate) 52 | uncertainty_norm = uncertainty_norm.cpu().numpy().squeeze().astype('float32') 53 | uncertainty_norm = uncertainty_norm 54 | 55 | plt.figure() 56 | plt.subplot(1, 2, 1) 57 | plt.imshow(uncertainty, cmap='inferno') 58 | plt.subplot(1, 2, 2) 59 | plt.imshow(uncertainty_norm, cmap='inferno') 60 | #plt.show() 61 | plt.savefig(save_path) 62 | plt.close() 63 | 64 | def save_shannon_entropy_calculated(uncertainty,save_path,unc_rate=1): #torch.tensor, torch.tensor 65 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32') 66 | 67 | 68 | 69 | plt.figure() 70 | unc_rate=0 71 | uncertainty_org1 = uncertainty 72 | uncertainty_thresh1 = (uncertainty > unc_rate) 73 | uncertainty1 = uncertainty * (uncertainty > unc_rate) 74 | plt.subplot(3,3, 1) 75 | 
plt.imshow(uncertainty_org1, cmap='inferno') 76 | plt.subplot(3,3, 2) 77 | plt.imshow(uncertainty_thresh1, cmap='inferno') 78 | plt.subplot(3,3, 3) 79 | plt.imshow(uncertainty1, cmap='inferno') 80 | 81 | unc_rate=1 82 | uncertainty_org2 = uncertainty 83 | uncertainty_thresh2 = (uncertainty > unc_rate) 84 | uncertainty2 = uncertainty * (uncertainty > unc_rate) 85 | plt.subplot(3,3, 4) 86 | plt.imshow(uncertainty_org2, cmap='inferno') 87 | plt.subplot(3,3, 5) 88 | plt.imshow(uncertainty_thresh2, cmap='inferno') 89 | plt.subplot(3,3, 6) 90 | plt.imshow(uncertainty2, cmap='inferno') 91 | 92 | unc_rate=10 93 | uncertainty_org3 = uncertainty 94 | uncertainty_thresh3 = (uncertainty > unc_rate) 95 | uncertainty3 = uncertainty * (uncertainty > unc_rate) 96 | plt.subplot(3,3, 7) 97 | plt.imshow(uncertainty_org2, cmap='inferno') 98 | plt.subplot(3,3, 8) 99 | plt.imshow(uncertainty_thresh2, cmap='inferno') 100 | plt.subplot(3,3, 9) 101 | plt.imshow(uncertainty2, cmap='inferno') 102 | 103 | 104 | 105 | #plt.show() 106 | plt.savefig(save_path) 107 | plt.close() -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/networks/layers/shanoon_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | def normalize(image, MIN_BOUND, MAX_BOUND): 6 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND) 7 | reverse_image = 1 - image 8 | return reverse_image 9 | 10 | def cal_shanoon_entropy(preds): # (batch, obj_num, 128, 128) 11 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) # (batch, 1, 128, 128) 12 | uncertainty_norm = normalize(uncertainty, 0, np.log(2)) * 7 13 | return uncertainty,uncertainty_norm 14 | 15 | 16 | def normalize_train(image, MIN_BOUND, MAX_BOUND): 17 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND) 18 | #reverse_image = 1 - image 19 | return image #reverse_image 20 | 21 | def cal_shanoon_entropy_train(preds): # (batch, obj_num, 128, 128) 22 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) # (batch, 1, 128, 128) 23 | uncertainty_norm = normalize_train(uncertainty, 0, np.log(2)) * 7 24 | return uncertainty,uncertainty_norm 25 | 26 | 27 | 28 | def show_shanoon_entropy(uncertainty,uncertainty_norm,unc_rate=0.5): #torch.tensor, torch.tensor 29 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32') 30 | uncertainty = uncertainty * (uncertainty > unc_rate) 31 | uncertainty_norm = uncertainty_norm.cpu().numpy().squeeze().astype('float32') 32 | uncertainty_norm = uncertainty_norm 33 | 34 | plt.figure() 35 | plt.subplot(1, 6, 1) 36 | plt.imshow(save_pre_den) 37 | plt.subplot(1, 6, 2) 38 | plt.imshow(density_gt) 39 | plt.subplot(1, 6, 3) 40 | plt.imshow(save_pre_dmp_to_att) 41 | plt.subplot(1, 6, 4) 42 | plt.imshow(save_pre_att_2) 43 | plt.subplot(1, 6, 5) 44 | plt.imshow(uncertainty, cmap='inferno') 45 | plt.subplot(1, 6, 6) 46 | plt.imshow(uncertainty_norm, cmap='inferno') 47 | plt.show() 48 | 49 | def save_shanoon_entropy(uncertainty,uncertainty_norm,save_path,unc_rate=0.5): #torch.tensor, torch.tensor 50 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32') 51 | uncertainty = uncertainty * (uncertainty > unc_rate) 52 | uncertainty_norm = uncertainty_norm.cpu().numpy().squeeze().astype('float32') 53 | uncertainty_norm = uncertainty_norm 54 | 55 | plt.figure() 56 | plt.subplot(1, 2, 1) 57 | plt.imshow(uncertainty, 
cmap='inferno') 58 | plt.subplot(1, 2, 2) 59 | plt.imshow(uncertainty_norm, cmap='inferno') 60 | #plt.show() 61 | plt.savefig(save_path) 62 | plt.close() 63 | 64 | def save_shanoon_entropy_calculated(uncertainty,save_path,unc_rate=1): #torch.tensor, torch.tensor 65 | uncertainty = uncertainty.cpu().numpy().squeeze().astype('float32') 66 | 67 | 68 | 69 | plt.figure() 70 | unc_rate=0 71 | uncertainty_org1 = uncertainty 72 | uncertainty_thresh1 = (uncertainty > unc_rate) 73 | uncertainty1 = uncertainty * (uncertainty > unc_rate) 74 | plt.subplot(3,3, 1) 75 | plt.imshow(uncertainty_org1, cmap='inferno') 76 | plt.subplot(3,3, 2) 77 | plt.imshow(uncertainty_thresh1, cmap='inferno') 78 | plt.subplot(3,3, 3) 79 | plt.imshow(uncertainty1, cmap='inferno') 80 | 81 | unc_rate=1 82 | uncertainty_org2 = uncertainty 83 | uncertainty_thresh2 = (uncertainty > unc_rate) 84 | uncertainty2 = uncertainty * (uncertainty > unc_rate) 85 | plt.subplot(3,3, 4) 86 | plt.imshow(uncertainty_org2, cmap='inferno') 87 | plt.subplot(3,3, 5) 88 | plt.imshow(uncertainty_thresh2, cmap='inferno') 89 | plt.subplot(3,3, 6) 90 | plt.imshow(uncertainty2, cmap='inferno') 91 | 92 | unc_rate=10 93 | uncertainty_org3 = uncertainty 94 | uncertainty_thresh3 = (uncertainty > unc_rate) 95 | uncertainty3 = uncertainty * (uncertainty > unc_rate) 96 | plt.subplot(3,3, 7) 97 | plt.imshow(uncertainty_org2, cmap='inferno') 98 | plt.subplot(3,3, 8) 99 | plt.imshow(uncertainty_thresh2, cmap='inferno') 100 | plt.subplot(3,3, 9) 101 | plt.imshow(uncertainty2, cmap='inferno') 102 | 103 | 104 | 105 | #plt.show() 106 | plt.savefig(save_path) 107 | plt.close() 108 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/scripts/eval.sh: -------------------------------------------------------------------------------- 1 | datasets="youtubevos" 2 | config="configs.resnet101_aocnet_2" 3 | python ../tools/eval_net.py --config ${config} --dataset ${datasets} --ckpt_step 400000 --global_chunks 16 --gpu_id 0 --mem_every 5 4 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/scripts/train.sh: -------------------------------------------------------------------------------- 1 | datasets="youtubevos" 2 | config="configs.resnet101_aoc" 3 | # training for 200k with lr=0.2 4 | python ../tools/train_net_mm.py --config ${config} --datasets ${datasets} --global_chunks 1 5 | 6 | # go on training for 200k with lr=0.1 7 | config="configs.resnet101_aoc_2" 8 | python ../tools/train_net_mm.py --config ${config} --datasets ${datasets} --global_chunks 1 9 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/tools/eval_net_mm_rpa.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | sys.path.append('..') 4 | from networks.engine.eval_manager_mm_rpa import Evaluator 5 | import importlib 6 | 7 | def main(): 8 | import argparse 9 | parser = argparse.ArgumentParser(description="Eval") 10 | parser.add_argument('--exp_name', type=str, default='') 11 | 12 | parser.add_argument('--config', type=str, default='configs.resnet101') 13 | 14 | parser.add_argument('--gpu_id', type=int, default=0) 15 | 16 | parser.add_argument('--ckpt_path', type=str, default='') 17 | parser.add_argument('--ckpt_step', type=int, default=-1) 18 | 19 | parser.add_argument('--dataset', type=str, default='') 20 | 21 | parser.add_argument('--flip', 
action='store_true') 22 | parser.set_defaults(flip=False) 23 | parser.add_argument('--ms', nargs='+', type=float, default=[1.]) 24 | parser.add_argument('--max_long_edge', type=int, default=-1) 25 | parser.add_argument('--mem_every', type=int, default=-1) 26 | parser.add_argument('--ucr', type=float, default=1.0) 27 | parser.add_argument('--float16', action='store_true') 28 | parser.add_argument('--vis', action='store_true') 29 | parser.set_defaults(float16=False) 30 | parser.add_argument('--global_atrous_rate', type=int, default=1) 31 | parser.add_argument('--global_chunks', type=int, default=4) 32 | parser.add_argument('--min_matching_pixels', type=int, default=0) 33 | parser.add_argument('--no_local_parallel', dest='local_parallel', action='store_false') 34 | parser.set_defaults(local_parallel=True) 35 | args = parser.parse_args() 36 | 37 | config = importlib.import_module(args.config) 38 | cfg = config.cfg 39 | 40 | cfg.TEST_GPU_ID = args.gpu_id 41 | if args.exp_name != '': 42 | cfg.EXP_NAME = args.exp_name 43 | if args.mem_every != '': 44 | cfg.MEM_EVERY = args.mem_every 45 | if args.ucr != '': 46 | cfg.UNC_RATIO = args.ucr 47 | if args.ckpt_path != '': 48 | cfg.TEST_CKPT_PATH = args.ckpt_path 49 | if args.ckpt_step > 0: 50 | cfg.TEST_CKPT_STEP = args.ckpt_step 51 | if args.dataset != '': 52 | cfg.TEST_DATASET = args.dataset 53 | 54 | cfg.UNC_VIS = args.vis 55 | 56 | cfg.TEST_FLIP = args.flip 57 | cfg.TEST_MULTISCALE = args.ms 58 | if args.max_long_edge > 0: 59 | cfg.TEST_MAX_SIZE = args.max_long_edge 60 | else: 61 | cfg.TEST_MAX_SIZE = 800 * 1.3 if cfg.TEST_MULTISCALE == [1.] else 800 62 | 63 | cfg.MODEL_FLOAT16_MATCHING = args.float16 64 | if 'cfbip' in cfg.MODEL_MODULE: 65 | cfg.TEST_GLOBAL_ATROUS_RATE = [args.global_atrous_rate, 1, 1] 66 | else: 67 | cfg.TEST_GLOBAL_ATROUS_RATE = args.global_atrous_rate 68 | cfg.TEST_GLOBAL_CHUNKS = args.global_chunks 69 | cfg.TEST_LOCAL_PARALLEL = args.local_parallel 70 | 71 | if args.min_matching_pixels > 0: 72 | cfg.TEST_GLOBAL_MATCHING_MIN_PIXEL = args.min_matching_pixels 73 | 74 | evaluator = Evaluator(cfg=cfg) 75 | evaluator.evaluating() 76 | 77 | if __name__ == '__main__': 78 | main() 79 | 80 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/tools/train_net_mm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | sys.path.append('..') 4 | from networks.engine.train_manager_mm import Trainer 5 | import torch.multiprocessing as mp 6 | import importlib 7 | 8 | def main_worker(gpu, cfg): 9 | # Initiate a training manager 10 | trainer = Trainer(rank=gpu, cfg=cfg) 11 | # Start Training 12 | trainer.sequential_training() 13 | 14 | def main(): 15 | import argparse 16 | parser = argparse.ArgumentParser(description="Train") 17 | parser.add_argument('--exp_name', type=str, default='') 18 | parser.add_argument('--config', type=str, default='configs.resnet101') 19 | 20 | parser.add_argument('--start_gpu', type=int, default=0) 21 | parser.add_argument('--gpu_num', type=int, default=-1) 22 | parser.add_argument('--batch_size', type=int, default=-1) 23 | 24 | parser.add_argument('--pretrained_path', type=str, default='') 25 | 26 | parser.add_argument('--datasets', nargs='+', type=str, default=['youtubevos']) 27 | parser.add_argument('--lr', type=float, default=-1.) 28 | parser.add_argument('--total_step', type=int, default=-1.) 29 | parser.add_argument('--start_step', type=int, default=-1.) 
30 | 31 | parser.add_argument('--float16', action='store_true') 32 | parser.set_defaults(float16=False) 33 | parser.add_argument('--global_atrous_rate', type=int, default=1) 34 | parser.add_argument('--global_chunks', type=int, default=20) 35 | parser.add_argument('--no_local_parallel', dest='local_parallel', action='store_false') 36 | parser.set_defaults(local_parallel=True) 37 | args = parser.parse_args() 38 | 39 | config = importlib.import_module(args.config) 40 | cfg = config.cfg 41 | 42 | if args.exp_name != '': 43 | cfg.EXP_NAME = args.exp_name 44 | 45 | cfg.DIST_START_GPU = args.start_gpu 46 | if args.gpu_num > 0: 47 | cfg.TRAIN_GPUS = args.gpu_num 48 | if args.batch_size > 0: 49 | cfg.TRAIN_BATCH_SIZE = args.batch_size 50 | 51 | if args.pretrained_path != '': 52 | cfg.PRETRAIN_MODEL = args.pretrained_path 53 | 54 | if args.lr > 0: 55 | cfg.TRAIN_LR = args.lr 56 | if args.total_step > 0: 57 | cfg.TRAIN_TOTAL_STEPS = args.total_step 58 | cfg.TRAIN_START_SEQ_TRAINING_STEPS = int(args.total_step / 2) 59 | cfg.TRAIN_HARD_MINING_STEP = int(args.total_step / 2) 60 | if args.start_step > 0: 61 | cfg.TRAIN_START_STEP = args.start_step 62 | 63 | cfg.MODEL_FLOAT16_MATCHING = args.float16 64 | if 'cfbip' in cfg.MODEL_MODULE: 65 | cfg.TRAIN_GLOBAL_ATROUS_RATE = [args.global_atrous_rate, 1, 1] 66 | else: 67 | cfg.TRAIN_GLOBAL_ATROUS_RATE = args.global_atrous_rate 68 | cfg.TRAIN_GLOBAL_CHUNKS = args.global_chunks 69 | cfg.TRAIN_LOCAL_PARALLEL = args.local_parallel 70 | 71 | # Use torch.multiprocessing.spawn to launch distributed processes 72 | mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg,)) 73 | 74 | if __name__ == '__main__': 75 | main() 76 | 77 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/AOC-Net/complete_project/AOCNet/utils/__init__.py -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | 5 | def load_network_and_optimizer(net, opt, pretrained_dir, gpu): 6 | pretrained = torch.load( 7 | pretrained_dir, 8 | map_location=torch.device("cuda:"+str(gpu))) 9 | pretrained_dict = pretrained['state_dict'] 10 | model_dict = net.state_dict() 11 | pretrained_dict_update = {} 12 | pretrained_dict_remove = [] 13 | for k, v in pretrained_dict.items(): 14 | if k in model_dict: 15 | pretrained_dict_update[k] = v 16 | elif k[:7] == 'module.': 17 | if k[7:] in model_dict: 18 | pretrained_dict_update[k[7:]] = v 19 | else: 20 | pretrained_dict_remove.append(k) 21 | model_dict.update(pretrained_dict_update) 22 | net.load_state_dict(model_dict) 23 | opt.load_state_dict(pretrained['optimizer']) 24 | del(pretrained) 25 | return net.cuda(gpu), opt, pretrained_dict_remove 26 | 27 | def load_network_and_not_optimizer(net, opt, pretrained_dir, gpu): 28 | pretrained = torch.load( 29 | pretrained_dir, 30 | map_location=torch.device("cuda:"+str(gpu))) 31 | pretrained_dict = pretrained['state_dict'] 32 | model_dict = net.state_dict() 33 | pretrained_dict_update = {} 34 | pretrained_dict_remove = [] 35 | for k, v in pretrained_dict.items(): 36 | if k in model_dict: 37 | pretrained_dict_update[k] = v 38 | elif k[:7] 
== 'module.': 39 | if k[7:] in model_dict: 40 | pretrained_dict_update[k[7:]] = v 41 | else: 42 | pretrained_dict_remove.append(k) 43 | model_dict.update(pretrained_dict_update) 44 | net.load_state_dict(model_dict) 45 | #opt.load_state_dict(pretrained['optimizer']) 46 | del(pretrained) 47 | return net.cuda(gpu), opt, pretrained_dict_remove 48 | 49 | def load_network(net, pretrained_dir, gpu): 50 | pretrained = torch.load( 51 | pretrained_dir, 52 | map_location=torch.device("cuda:"+str(gpu))) 53 | pretrained_dict = pretrained['state_dict'] 54 | model_dict = net.state_dict() 55 | pretrained_dict_update = {} 56 | pretrained_dict_remove = [] 57 | for k, v in pretrained_dict.items(): 58 | if k in model_dict and model_dict[k].size()==pretrained_dict[k].size(): 59 | pretrained_dict_update[k] = v 60 | #if model_dict[k].size()!=pretrained_dict[k].size(): 61 | # print("ERROR->",k) 62 | elif k[:7] == 'module.': 63 | if k[7:] in model_dict: 64 | pretrained_dict_update[k[7:]] = v 65 | else: 66 | pretrained_dict_remove.append(k) 67 | model_dict.update(pretrained_dict_update) 68 | net.load_state_dict(model_dict) 69 | del(pretrained) 70 | return net.cuda(gpu), pretrained_dict_remove 71 | 72 | 73 | def load_network_P2T(net, pretrained_dir, gpu): 74 | pretrained = torch.load( 75 | pretrained_dir, 76 | map_location=torch.device("cuda:"+str(gpu))) 77 | pretrained_dict = pretrained['state_dict'] 78 | model_dict = net.state_dict() 79 | pretrained_dict_update = {} 80 | pretrained_dict_remove = [] 81 | for k, v in pretrained_dict.items(): 82 | print("pretrained_dict_k",k) 83 | if k in model_dict: 84 | print("k in new model") 85 | pretrained_dict_update[k] = v 86 | #for para in pretrained_dict_update[k].parameters(): 87 | # para.requires_grad = False 88 | v.requires_grad = False 89 | elif k[:7] == 'module.': 90 | print("k in new model (module)") 91 | if k[7:] in model_dict: 92 | pretrained_dict_update[k[7:]] = v 93 | #for para in pretrained_dict_update[k].parameters(): 94 | v.requires_grad = False 95 | else: 96 | print("k not in new model") 97 | pretrained_dict_remove.append(k) 98 | 99 | model_dict.update(pretrained_dict_update) 100 | net.load_state_dict(model_dict) 101 | del(pretrained) 102 | return net.cuda(gpu), pretrained_dict_remove 103 | 104 | 105 | def save_network(net, opt, step, save_path, max_keep=8): 106 | try: 107 | if not os.path.exists(save_path): 108 | os.makedirs(save_path) 109 | save_file = 'save_step_%s.pth' % (step) 110 | save_dir = os.path.join(save_path, save_file) 111 | torch.save({'state_dict': net.state_dict(), 'optimizer': opt.state_dict()}, save_dir) 112 | except: 113 | save_path = './saved_models' 114 | if not os.path.exists(save_path): 115 | os.makedirs(save_path) 116 | save_file = 'save_step_%s.pth' % (step) 117 | save_dir = os.path.join(save_path, save_file) 118 | torch.save({'state_dict': net.state_dict(), 'optimizer': opt.state_dict()}, save_dir) 119 | 120 | all_ckpt = os.listdir(save_path) 121 | if len(all_ckpt) > max_keep: 122 | all_step = [] 123 | for ckpt_name in all_ckpt: 124 | step = int(ckpt_name.split('_')[-1].split('.')[0]) 125 | all_step.append(step) 126 | all_step = list(np.sort(all_step))[:-max_keep] 127 | for step in all_step: 128 | ckpt_path = os.path.join(save_path, 'save_step_%s.pth' % (step)) 129 | os.system('rm {}'.format(ckpt_path)) 130 | -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/utils/eval.py: -------------------------------------------------------------------------------- 1 | import shutil 2 
| import zipfile 3 | import os 4 | 5 | def zip_folder(source_folder, zip_dir): 6 | f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED) 7 | pre_len = len(os.path.dirname(source_folder)) 8 | for dirpath, dirnames, filenames in os.walk(source_folder): 9 | for filename in filenames: 10 | pathfile = os.path.join(dirpath, filename) 11 | arcname = pathfile[pre_len:].strip(os.path.sep) 12 | f.write(pathfile, arcname) 13 | f.close() -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/utils/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | ## for visulization 6 | import cv2 7 | import time 8 | import gc 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | 12 | plt.rcParams['font.sans-serif']=['SimHei'] 13 | plt.rcParams['axes.unicode_minus'] = False 14 | ## 15 | 16 | _palette = [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, 190, 190, 190, 191, 191, 
191, 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, 255, 255, 255] 17 | 18 | cluster_map_palette = [0, 0, 0, 200, 200, 200, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 
189, 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, 255, 255, 255] 19 | 20 | def label2colormap(label): 21 | 22 | m = label.astype(np.uint8) 23 | r,c = m.shape 24 | cmap = np.zeros((r,c,3), dtype=np.uint8) 25 | cmap[:,:,0] = (m&1)<<7 | (m&8)<<3 | (m&64)>>1 26 | cmap[:,:,1] = (m&2)<<6 | (m&16)<<2 | (m&128)>>2 27 | cmap[:,:,2] = (m&4)<<5 | (m&32)<<1 28 | return cmap 29 | 30 | def masked_image(image, colored_mask, mask, alpha = 0.7): 31 | mask = np.expand_dims(mask > 0, axis=0) 32 | mask = np.repeat(mask, 3, axis=0) 33 | show_img = (image * alpha + colored_mask * (1 - alpha)) * mask + image * (1 - mask) 34 | return show_img 35 | 36 | def save_image(image, path): 37 | im = Image.fromarray(np.uint8(image * 255.).transpose((1, 2, 0))) 38 | im.save(path) 39 | 40 | def save_mask(mask_tensor, path): 41 | mask = mask_tensor.cpu().numpy().astype('uint8') 42 | mask = Image.fromarray(mask).convert('P') 43 | mask.putpalette(_palette) 44 | mask.save(path) 45 | 46 | def save_cluster_map(mask_tensor, path): 47 | mask = mask_tensor.cpu().numpy().astype('uint8') 48 | mask = Image.fromarray(mask).convert('P') 49 | mask.putpalette(cluster_map_palette) 50 | mask.save(path) 51 | 52 | def flip_tensor(tensor, dim=0): 53 | inv_idx = torch.arange(tensor.size(dim) - 1, -1, -1, device=tensor.device).long() 54 | tensor = tensor.index_select(dim, inv_idx) 55 | return tensor 56 | 57 | def save_matching_result(img,path): 58 | 59 | cv2.imwrite(path,img) 60 | 61 | def save_uncertain_mask(mask, path): 62 | #mask = mask_tensor.cpu().numpy().astype('uint8') 63 | mask = Image.fromarray(mask).convert('P') 64 | mask.putpalette(_palette) 65 | mask.save(path) -------------------------------------------------------------------------------- /AOC-Net/complete_project/AOCNet/utils/learning.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, base_lr, p, itr, max_itr, warm_up_steps=1000, is_cosine_decay=False, min_lr=1e-5): 5 | 6 | if itr < warm_up_steps: 7 | now_lr = base_lr * itr / warm_up_steps 8 | else: 9 | itr = itr - warm_up_steps 10 | max_itr = max_itr - warm_up_steps 11 | if is_cosine_decay: 12 | now_lr = base_lr * (math.cos(math.pi * itr / (max_itr + 1)) + 1.) 
* 0.5
13 |         else:
14 |             now_lr = base_lr * (1 - itr / (max_itr + 1)) ** p
15 | 
16 |     if now_lr < min_lr:
17 |         now_lr = min_lr
18 | 
19 |     for param_group in optimizer.param_groups:
20 |         param_group['lr'] = now_lr
21 |     return now_lr
22 | 
23 | 
24 | def get_trainable_params(model, base_lr, weight_decay, beta_wd=True):
25 |     params = []
26 |     for key, value in model.named_parameters():
27 |         if not value.requires_grad:
28 |             continue
29 |         wd = weight_decay
30 |         if 'beta' in key:
31 |             if not beta_wd:
32 |                 wd = 0.
33 |         params += [{"params": [value], "lr": base_lr, "weight_decay": wd}]
34 |     return params
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | 
4 | class AverageMeter(object):
5 |     """Computes and stores the average and current value"""
6 | 
7 |     def __init__(self):
8 |         self.val = 0
9 |         self.avg = 0
10 |         self.sum = 0
11 |         self.count = 0
12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 | 
--------------------------------------------------------------------------------
/AOC-Net/complete_project/AOCNet/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def pytorch_iou(pred, target, obj_num, epsilon=1e-6):
4 |     '''
5 |     pred: [bs, h, w]
6 |     target: [bs, h, w]
7 |     obj_num: [bs]
8 |     '''
9 |     bs = obj_num.size(0)
10 |     all_iou = []
11 |     for idx in range(bs):
12 |         now_pred = pred[idx].unsqueeze(0)
13 |         now_target = target[idx].unsqueeze(0)
14 |         now_obj_num = obj_num[idx]
15 | 
16 |         obj_ids = torch.arange(0, now_obj_num + 1, device=now_pred.device).int().view(-1, 1, 1)
17 |         if obj_ids.size(0) == 1:  # only contain background
18 |             continue
19 |         else:
20 |             obj_ids = obj_ids[1:]
21 |         now_pred = (now_pred == obj_ids).float()
22 |         now_target = (now_target == obj_ids).float()
23 | 
24 |         intersection = (now_pred * now_target).sum((1, 2))
25 |         union = ((now_pred + now_target) > 0).float().sum((1, 2))
26 | 
27 |         now_iou = (intersection + epsilon) / (union + epsilon)
28 | 
29 |         all_iou.append(now_iou.mean())
30 |     if len(all_iou) > 0:
31 |         all_iou = torch.stack(all_iou).mean()
32 |     else:
33 |         all_iou = torch.ones((1), device=pred.device)
34 |     return all_iou
35 | 
--------------------------------------------------------------------------------
/AOC-Net/complete_project/README.md:
--------------------------------------------------------------------------------
1 | The code has been cleaned up to reduce redundancy. Since I currently do not have GPU resources for testing, it may still contain bugs.
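A quick way to sanity-check an individual module without a GPU is to feed it random tensors and check the output shapes. Below is a minimal smoke test for the conditioning layer, given as a sketch only: it assumes the class name and signature match the standalone `AOC-Net/conditioning_layer.py`, and the import path is a placeholder that needs adjusting to wherever the module lives in your checkout.

```python
# Hypothetical CPU-only smoke test; the import path below is an assumption.
import torch

from conditioning_layer import conditioning_layer

layer = conditioning_layer(in_dim=256, beta_percentage=0.3)
z_in = torch.randn(2, 256, 30, 30)  # dummy feature map: [batch, channels, H, W]
out = layer(z_in)                   # pooled, MLP-projected object descriptor
print(out.shape)                    # expected: torch.Size([2, 256])
```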
2 | 
--------------------------------------------------------------------------------
/AOC-Net/conditioning_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class conditioning_layer(nn.Module):
7 |     def __init__(self,
8 |                  in_dim=256,
9 |                  beta_percentage=0.3):
10 |         super(conditioning_layer, self).__init__()
11 | 
12 |         self.beta_percentage = beta_percentage
13 | 
14 |         kernel_size = 1
15 |         self.phi_layer = nn.Conv2d(in_dim, 1, kernel_size=kernel_size, stride=1, padding=int((kernel_size - 1) / 2))
16 |         self.mlp_layer = nn.Linear(in_dim, in_dim)
17 | 
18 |         nn.init.kaiming_normal_(self.phi_layer.weight, mode='fan_out', nonlinearity='relu')
19 | 
20 | 
21 |     def forward(self, z_in):
22 | 
23 |         # Step 1: phi(z_in)
24 |         x = self.phi_layer(z_in)
25 | 
26 |         # Step 2: beta
27 |         x = x.reshape(x.size()[0], x.size()[1], -1)
28 |         z_in_reshape = z_in.reshape(z_in.size()[0], z_in.size()[1], -1)
29 |         beta_rank = int(self.beta_percentage * z_in.size()[-1] * z_in.size()[-2])
30 |         beta_val, _ = torch.topk(x, k=beta_rank, dim=-1, sorted=True)
31 | 
32 |         # Step 3: pi_beta(phi(z_in)), cast to float so it can directly mask z_in_reshape
33 |         x = (x > beta_val[..., -1, None]).float()
34 | 
35 |         # Step 4: z_in \odot pi_beta(phi(z_in))
36 |         z_in_masked = z_in_reshape * x
37 | 
38 |         # Step 5: GAP (z_in \odot pi_beta(phi(z_in)))
39 |         z_in_masked_gap = torch.nn.functional.avg_pool1d(z_in_masked,
40 |                                                          kernel_size=z_in_masked.size()[-1]).squeeze(-1)
41 | 
42 |         # Step 6: MLP ( GAP (z_in \odot pi_beta(phi(z_in))) )
43 |         out = self.mlp_layer(z_in_masked_gap)
44 | 
45 |         return out
46 | 
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # main image
2 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
3 | 
4 | # tweaked azureml pytorch image
5 | # as of Aug 31, 2020 pt 1.6 doesn't seem to work with horovod on native mixed precision
6 | 
7 | LABEL maintainer="Albert"
8 | LABEL maintainer_email="alsadovn@microsoft.com"
9 | LABEL version="0.1"
10 | 
11 | USER root:root
12 | 
13 | ENV com.nvidia.cuda.version $CUDA_VERSION
14 | ENV com.nvidia.volumes.needed nvidia_driver
15 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
16 | ENV DEBIAN_FRONTEND noninteractive
17 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64
18 | ENV NCCL_DEBUG=INFO
19 | ENV HOROVOD_GPU_ALLREDUCE=NCCL
20 | 
21 | # Install Common Dependencies
22 | RUN apt-get update && \
23 |     apt-get install -y --no-install-recommends \
24 |     # SSH and RDMA
25 |     libmlx4-1 \
26 |     libmlx5-1 \
27 |     librdmacm1 \
28 |     libibverbs1 \
29 |     libmthca1 \
30 |     libdapl2 \
31 |     dapl2-utils \
32 |     openssh-client \
33 |     openssh-server \
34 |     iproute2 && \
35 |     # Others
36 |     apt-get install -y --no-install-recommends \
37 |     build-essential \
38 |     bzip2=1.0.6-8.1ubuntu0.2 \
39 |     libbz2-1.0=1.0.6-8.1ubuntu0.2 \
40 |     systemd \
41 |     git=1:2.17.1-1ubuntu0.7 \
42 |     wget \
43 |     cpio \
44 |     libsm6 \
45 |     libxext6 \
46 |     libxrender-dev \
47 |     fuse && \
48 |     apt-get clean -y && \
49 |     rm -rf /var/lib/apt/lists/*
50 | 
51 | # Conda Environment
52 | ENV MINICONDA_VERSION 4.7.12.1
53 | ENV PATH /opt/miniconda/bin:$PATH
54 | RUN wget -qO /tmp/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-${MINICONDA_VERSION}-Linux-x86_64.sh && \
55 |     bash /tmp/miniconda.sh -bf -p /opt/miniconda && \
56 |     conda clean -ay && \
57 |     rm -rf /opt/miniconda/pkgs && \
58 |     rm /tmp/miniconda.sh && \
59 |     find / -type d -name __pycache__ | xargs rm -rf
60 | 
61 | # To resolve horovod hangs due to a known NCCL issue in version 2.4.
62 | # Can remove it once we upgrade NCCL to 2.5+.
63 | # https://github.com/horovod/horovod/issues/893
64 | # ENV NCCL_TREE_THRESHOLD=0
65 | ENV PIP="pip install --no-cache-dir"
66 | 
67 | RUN conda install -y conda=4.8.5 python=3.6.2 && conda clean -ay && \
68 |     conda install -y mkl=2020.1 && \
69 |     conda install -y numpy scipy scikit-learn scikit-image imageio protobuf && \
70 |     conda install -y ruamel.yaml==0.16.10 && \
71 |     # ruamel_yaml is a copy of ruamel.yaml package
72 |     # conda installs version ruamel_yaml v0.15.87 which is vulnerable
73 |     # force uninstall it leaving other packages intact
74 |     conda remove --force -y ruamel_yaml && \
75 |     conda clean -ay && \
76 |     # Install AzureML SDK
77 |     ${PIP} azureml-defaults && \
78 |     # Install PyTorch
79 |     ${PIP} torch==1.4.0 && \
80 |     ${PIP} torchvision==0.2.1 && \
81 |     ${PIP} wandb && \
82 |     # # Install Horovod
83 |     # HOROVOD_WITH_PYTORCH=1 ${PIP} horovod[pytorch]==0.19.5 && \
84 |     # ldconfig && \
85 |     ${PIP} tensorboard==1.15.0 && \
86 |     ${PIP} future==0.17.1 && \
87 |     ${PIP} onnxruntime==1.4.0 && \
88 |     ${PIP} pytorch-lightning && \
89 |     ${PIP} opencv-python-headless~=4.4.0 && \
90 |     ${PIP} imgaug==0.4.0 --no-deps && \
91 |     # hydra
92 |     ${PIP} hydra-core --upgrade && \
93 |     ${PIP} lmdb pyarrow
94 | 
95 | RUN pip3 install --upgrade pip
96 | RUN pip3 install pipreqs
97 | 
98 | RUN apt-get update
99 | RUN apt-get install -y --no-install-recommends libglib2.0-dev
100 | RUN apt-get install -y --no-install-recommends vim
101 | 
102 | WORKDIR /
103 | RUN apt-get install -y --no-install-recommends libunwind8
104 | RUN apt-get install -y --no-install-recommends libicu-dev
105 | RUN apt-get install -y --no-install-recommends htop
106 | RUN apt-get install -y --no-install-recommends net-tools
107 | RUN apt-get install -y --no-install-recommends rsync
108 | RUN apt-get install -y --no-install-recommends tree
109 | 
110 | # put the requirements file for your own repo under /app for pip-based installation!!!
111 | WORKDIR /app
112 | RUN pip3 install -r requirements.txt
113 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Towards Robust Video Object Segmentation with Adaptive Object Calibration (ACM Multimedia 2022)
2 | 
3 | 
4 | A preview version of the **paper** is available at [arXiv](https://arxiv.org/abs/2207.00887).
5 | 
6 | The conference **poster** is available at [this github repo](https://github.com/JerryX1110/Robust-Video-Object-Segmentation/blob/main/figs/mm22_345_poster_a0.pdf).
7 | 
8 | The long paper **presentation video** is available at [GoogleDrive](https://drive.google.com/drive/folders/18q0eWTB5lXSkZUHOuBfyQO2LbIdTJof5?usp=sharing) and [YouTube](https://youtu.be/tFMfZ4NATrw).
9 | 
10 | **Qualitative results and comparisons** with previous SOTAs are available at [YouTube](https://www.youtube.com/watch?v=3F6n7tcwWkA).
11 | 
12 | **Welcome to stars ⭐ & comments 💹 & collaboration 😀 !!**
13 | 
14 | ```diff
15 | - 2022.11.16: All the codes are cleaned and released ~
16 | - 2022.10.21: Add the robustness evaluation dataloader for other models, e.g., AOT~
17 | - 2022.10.1: Add the code of key implementations of this work~
18 | - 2022.9.25: Add the poster of this work~
19 | - 2022.8.27: Add presentation video and PPT for this work~
20 | - 2022.7.10: Add future works towards robust VOS!
21 | - 2022.7.5: Our ArXiv-version paper is available.
22 | - 2022.7.1: Repo init. Please stay tuned~
23 | ```
24 | ---
25 | 
26 | ## Motivation for Robust Video Object Segmentation
27 | 
28 | (figure: motivation for robust VOS; original screenshot 2022-07-11 13 01 56)
29 | 
30 | 
31 | ## Pipeline
32 | 
33 | (figure: overall pipeline; original screenshot 2022-07-11 13 00 17)
34 | 
35 | ### Adaptive Object Proxy Representation (Component1)
36 | 
37 | 
38 | (figure: adaptive object proxy representation; original screenshot 2022-07-11 13 05 05)
39 | 
40 | ### Object Mask Calibration (Component2)
41 | 
42 | 
43 | (figure: object mask calibration; original screenshot 2022-07-11 13 04 23)
44 | 
45 | ## Abstract
46 | In the booming video era, video segmentation attracts increasing research attention in the multimedia community.
47 | 
48 | Semi-supervised video object segmentation (VOS) aims at segmenting objects in all target frames of a video, given annotated object masks of reference frames. **Most existing methods build pixel-wise reference-target correlations and then perform pixel-wise tracking to obtain target masks. Due to neglecting object-level cues, pixel-level approaches make the tracking vulnerable to perturbations, and even indiscriminate among similar objects.**
49 | 
50 | Towards **robust VOS**, the key insight is to calibrate the representation and mask of each specific object to be expressive and discriminative. Accordingly, we propose a new deep network, which can adaptively construct object representations and calibrate object masks to achieve stronger robustness.
51 | 
52 | First, we construct the object representations by applying an **adaptive object proxy (AOP) aggregation** method, where the proxies represent arbitrary-shaped segments **via clustering** at multiple levels for reference.
53 | 
54 | Then, prototype masks are initially generated from the reference-target correlations based on AOP.
55 | Afterwards, such proto-masks are further calibrated through network modulation, conditioning on the object proxy representations.
56 | We consolidate this **conditional mask calibration** process in a progressive manner, where the object representations and proto-masks evolve to be discriminative iteratively.
57 | 
58 | Extensive experiments are conducted on the standard VOS benchmarks, YouTube-VOS-18/19 and DAVIS-17.
59 | Our model achieves state-of-the-art performance among existing published works, and also exhibits significantly superior robustness against perturbations.
60 | 
61 | ## Requirements
62 | * Python3
63 | * pytorch >= 1.4.0
64 | * torchvision
65 | * opencv-python
66 | * Pillow
67 | 
68 | You can also use the docker image below to set up your env directly. However, this docker image may contain some redundant packages.
69 | 
70 | ```latex
71 | docker image: xxiaoh/vos:10.1-cudnn7-torch1.4_v3
72 | ```
73 | 
74 | A more lightweight version can be created by modifying the [Dockerfile](https://github.com/JerryX1110/RPCMVOS/blob/main/Dockerfile) provided.
75 | 
76 | ## Preparation
77 | * Datasets
78 | 
79 |   * **YouTube-VOS**
80 | 
81 |     A commonly-used large-scale VOS dataset.
82 | 
83 |     [datasets/YTB/2019](datasets/YTB/2019): version 2019, download [link](https://drive.google.com/drive/folders/1BWzrCWyPEmBEKm0lOHe5KLuBuQxUSwqz?usp=sharing). `train` is required for training. `valid` (6fps) and `valid_all_frames` (30fps, optional) are used for evaluation.
84 | 
85 |     [datasets/YTB/2018](datasets/YTB/2018): version 2018, download [link](https://drive.google.com/drive/folders/1bI5J1H3mxsIGo7Kp-pPZU8i6rnykOw7f?usp=sharing). Only `valid` (6fps) and `valid_all_frames` (30fps, optional) are required for this project and used for evaluation.
86 | 
87 |   * **DAVIS**
88 | 
89 |     A commonly-used small-scale VOS dataset.
90 | 
91 |     [datasets/DAVIS](datasets/DAVIS): [TrainVal](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) (480p) contains both the training and validation split. [Test-Dev](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) (480p) contains the Test-dev split. The [full-resolution version](https://davischallenge.org/davis2017/code.html) is also supported for training and evaluation but not required.
92 | 
93 | * pretrained weights for the backbone
94 | 
95 |   [resnet101-deeplabv3p](https://drive.google.com/file/d/101jYpeGGG58Kywk03331PKKzv1wN5DL0/view?usp=sharing)
96 | 
97 | 
98 | ## Implementation
99 | 
100 | The key implementation of matching with the adaptive-proxy-based representation is provided in [THIS FILE](https://github.com/JerryX1110/Robust-Video-Object-Segmentation/blob/main/AOC-Net/adaptive_embedding_for_matching.py). For other implementation and training/evaluation details, please refer to [RPCMVOS](https://github.com/JerryX1110/RPCMVOS) or [CFBI](https://github.com/z-x-yang/CFBI).
101 | 
102 | The key implementation of the preliminary robust VOS benchmark evaluation is provided in [THIS FILE](https://github.com/JerryX1110/Robust-Video-Object-Segmentation/tree/main/Robust-VOS-Benchmark).
103 | 
104 | The whole project code is provided in [THIS FOLDER](https://github.com/JerryX1110/Robust-Video-Object-Segmentation/tree/main/AOC-Net/complete_project).
105 | 
106 | Feel free to contact me if you have any problems with the implementation~
107 | 
108 | 
109 | 
110 | * For evaluation, please use the official YouTube-VOS servers ([2018 server](https://competitions.codalab.org/competitions/19544) and [2019 server](https://competitions.codalab.org/competitions/20127)), the official [DAVIS toolkit](https://github.com/davisvideochallenge/davis-2017) (for Val), and the official [DAVIS server](https://competitions.codalab.org/competitions/20516#learn_the_details) (for Test-dev).
111 | 
112 | 
113 | 
114 | ## Limitation & Directions for further exploration towards Robust VOS!
115 | * Extension of the proposed clustering-based adaptive proxy representation to other dense-tracking tasks in a more efficient and robust way
116 | * Leverage the robust layered representation, i.e., intermediate masks, for robust mask calibration in other segmentation tasks
117 | * More diverse perturbation/corruption types can be studied for video segmentation tasks like VOS and VIS (a toy example of such a frame-level corruption is sketched after this list)
118 | * Adversarial attack and defence for VOS models is still an open question for further exploration
119 | * VOS model robustness verification and theoretical analysis
120 | * Model enhancement from the perspective of data management
121 | 
122 | (to be continued...)
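To make the frame-level corruptions mentioned above concrete, here is a toy sketch that perturbs video frames with Gaussian noise before they are fed to a VOS model. This is only an illustration under an assumed severity-to-noise mapping: the actual perturbation types and strengths used in the preliminary robustness benchmark are defined by the dataloaders in [Robust-VOS-Benchmark](https://github.com/JerryX1110/Robust-Video-Object-Segmentation/tree/main/Robust-VOS-Benchmark), and the function below is not part of the released code.

```python
import numpy as np


def corrupt_frame(frame, severity=3):
    """Toy Gaussian-noise corruption for a uint8 RGB frame (illustrative only)."""
    sigma = [4, 8, 12, 18, 26][severity - 1]  # hypothetical severity-to-noise mapping
    noisy = frame.astype(np.float32) + np.random.normal(0.0, sigma, frame.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)


# Example: perturb every frame of a (dummy) sequence before running inference.
frames = [np.random.randint(0, 256, (480, 854, 3), dtype=np.uint8) for _ in range(5)]
perturbed = [corrupt_frame(f, severity=3) for f in frames]
```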
123 | 
124 | ## Citation
125 | If you find this work useful for your research, please consider citing:
126 | 
127 | ```latex
128 | @inproceedings{xu2022towards,
129 |   title={Towards Robust Video Object Segmentation with Adaptive Object Calibration},
130 |   author={Xu, Xiaohao and Wang, Jinglu and Ming, Xiang and Lu, Yan},
131 |   booktitle={Proceedings of the 30th ACM International Conference on Multimedia},
132 |   pages={2709--2718},
133 |   year={2022}
134 | }
135 | ```
136 | 
137 | 
138 | ## Credit
139 | 
140 | **CFBI**: 
141 | 
142 | **Deeplab**: 
143 | 
144 | **GCT**: 
145 | 
146 | ## Related Works in VOS
147 | **Semi-supervised video object segmentation repo/paper link:**
148 | 
149 | **ARKitTrack [CVPR 2023]**: 
150 | 
151 | **TarVis [Arxiv 2023]**: 
152 | 
153 | **LBLVOS [AAAI 2023]**: 
154 | 
155 | **DeAOT [NeurIPS 2022]**: 
156 | 
157 | **BATMAN [ECCV 2022 Oral]**: 
158 | 
159 | **XMEM [ECCV 2022]**: 
160 | 
161 | **TBD [ECCV 2022]**: 
162 | 
163 | **QDMN [ECCV 2022]**: 
164 | 
165 | **GSFM [ECCV 2022]**: 
166 | 
167 | **SWEM [CVPR 2022]**: 
168 | 
169 | **RDE [CVPR 2022]**: 
170 | 
171 | **COVOS [CVPR 2022]**: 
172 | 
173 | **RPCM [AAAI 2022 Oral]**: 
174 | 
175 | **AOT [NeurIPS 2021]**: 
176 | 
177 | **STCN [NeurIPS 2021]**: 
178 | 
179 | **JOINT [ICCV 2021]**: 
180 | 
181 | **HMMN [ICCV 2021]**: 
182 | 
183 | **DMN-AOA [ICCV 2021]**: 
184 | 
185 | **MiVOS [CVPR 2021]**: 
186 | 
187 | **SSTVOS [CVPR 2021 Oral]**: 
188 | 
189 | **GraphMemVOS [ECCV 2020]**: 
190 | 
191 | **AFB-URR [NeurIPS 2020]**: 
192 | 
193 | **CFBI [ECCV 2020]**: 
194 | 
195 | **FRTM-VOS [CVPR 2020]**: 
196 | 
197 | **STM [ICCV 2019]**: 
198 | 
199 | **FEELVOS [CVPR 2019]**: 
200 | 
201 | (The list may be incomplete; feel free to contact me by opening an issue and I'll add them!)
202 | 
203 | ## Useful websites for VOS
204 | **The 1st Large-scale Video Object Segmentation Challenge**: 
205 | 
206 | **The 2nd Large-scale Video Object Segmentation Challenge - Track 1: Video Object Segmentation**: 
207 | 
208 | **The Semi-Supervised DAVIS Challenge on Video Object Segmentation @ CVPR 2020**: 
209 | 
210 | **DAVIS**: 
211 | 
212 | **YouTube-VOS**: 
213 | 
214 | **Papers with code for Semi-VOS**: 
215 | 
216 | ## Acknowledgement ❤️
217 | This work is heavily built upon CFBI and RPCMVOS. Thanks to the authors of CFBI for releasing such a wonderful code repo for further work to build upon!
218 | 
219 | ## Welcome to comments and discussions!!
220 | Xiaohao Xu: 
221 | 
222 | ## License
223 | This project is released under the MIT License. See [LICENSE](LICENSE) for additional details.
224 | -------------------------------------------------------------------------------- /Robust-VOS-Benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figs/FIGS.MD: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figs/mm22_345_poster_a0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/figs/mm22_345_poster_a0.pdf -------------------------------------------------------------------------------- /figs/mm22_345_poster_a0.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/Robust-Video-Object-Segmentation/f562afe70ac0f960d27980d3e82b237f8dea45ec/figs/mm22_345_poster_a0.pptx --------------------------------------------------------------------------------