├── README.md ├── data ├── LoadDataSeg.py ├── coco_train.py ├── coco_val.py ├── transforms │ ├── functional.py │ ├── transform1.py │ └── transforms.py ├── voc_train.py └── voc_val.py ├── data_list ├── train │ ├── split0_train.txt │ ├── split1_train.txt │ ├── split2_train.txt │ └── split3_train.txt ├── train_list │ ├── split0_train.txt │ ├── split1_train.txt │ ├── split2_train.txt │ └── split3_train.txt ├── val │ ├── split0_val.txt │ ├── split1_val.txt │ ├── split2_val.txt │ └── split3_val.txt └── val_list │ ├── split0_val.txt │ ├── split1_val.txt │ ├── split2_val.txt │ └── split3_val.txt ├── data_parallel.py ├── img ├── chain10.png ├── graph4.png ├── prior1.png └── result7.png ├── models ├── PMMs.py ├── PMMs_single.py └── backbone │ ├── AlexNet.py │ ├── NetworkInNetwork.py │ ├── __pycache__ │ ├── NetworkInNetwork.cpython-36.pyc │ ├── NetworkInNetwork.cpython-37.pyc │ ├── resnet.cpython-36.pyc │ ├── resnet.cpython-37.pyc │ ├── resnet_dialated.cpython-36.pyc │ ├── resnet_dialated.cpython-37.pyc │ ├── resnet_dialated4.cpython-36.pyc │ ├── resnet_dialated4.cpython-37.pyc │ ├── resnet_dialated_fuse.cpython-37.pyc │ └── vgg.cpython-37.pyc │ ├── resnet.py │ └── vgg.py ├── networks ├── FPMMs.py ├── FRPMMs.py ├── VGG16based.py ├── __init__.py ├── __pycache__ │ ├── FPMMs.cpython-36.pyc │ ├── FPMMs.cpython-37.pyc │ ├── FRPMMs.cpython-36.pyc │ ├── FRPMMs.cpython-37.pyc │ ├── VGG16based.cpython-36.pyc │ ├── VGG16based.cpython-37.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── resnet50_34.cpython-36.pyc │ ├── resnet50_34.cpython-37.pyc │ ├── resnet50_34_all.cpython-36.pyc │ ├── resnet50_34_all.cpython-37.pyc │ ├── resnet50_34_s.cpython-36.pyc │ ├── resnet50_34_s.cpython-37.pyc │ ├── resnet50based.cpython-36.pyc │ ├── resnet50based.cpython-37.pyc │ ├── resnet50based4.cpython-36.pyc │ ├── resnet50based4.cpython-37.pyc │ ├── resnet50basedf.cpython-36.pyc │ ├── resnet50basedf.cpython-37.pyc │ └── resnet50basedfused.cpython-37.pyc └── resnet50_34.py ├── test.py 
├── test_5shot.py ├── test_all_frame.py ├── test_frame.py └── utils ├── NoteEvaluation.py ├── NoteLoss.py ├── Restore.py ├── Visualize.py ├── __pycache__ ├── NoteEvaluation.cpython-36.pyc ├── NoteEvaluation.cpython-37.pyc ├── NoteLoss.cpython-37.pyc ├── Restore.cpython-36.pyc ├── Restore.cpython-37.pyc ├── Visualize.cpython-36.pyc ├── Visualize.cpython-37.pyc ├── my_optim.cpython-36.pyc └── my_optim.cpython-37.pyc ├── my_optim.py └── test.py /README.md: -------------------------------------------------------------------------------- 1 | # Progressively-Dual-Prior-Guided-Few-shot-Semantic-Segmentation 2 | Code for progressive dual-prior-guided few-shot semantic segmentation. 3 | 4 | The overall network: 5 |

6 | the overall network 7 |

8 | Some visualization results: 9 | The results: 10 |

11 | the results 12 |

13 | 14 | 15 | 16 | 17 | ### Datasets and Data Preparation 18 | 19 | We follow the dataset setting in **PMMs**: https://github.com/Yang-Bob/PMMs and utilize the Dependencies in **PMMs** to run the code. 20 | 21 | The pretrained [**model**](https://pan.baidu.com/s/1qn_AhDbV5Q5XM-PpuKrNqQ) in **VOC_split0_1-shot** is provided as a sample. (Password:yd11) 22 | -------------------------------------------------------------------------------- /data/LoadDataSeg.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | from config import settings 5 | from data.transforms import transforms 6 | from data.transforms import transform1 7 | from torch.utils.data import DataLoader 8 | from data.voc_train import voc_train 9 | from data.voc_val import voc_val 10 | from data.coco_train import coco_train 11 | from data.coco_val import coco_val 12 | 13 | 14 | def data_loader(args,k_shot=1): 15 | 16 | batch = args.batch_size 17 | mean_vals = settings.mean_vals 18 | std_vals = settings.std_vals 19 | size = settings.size 20 | 21 | tsfm_train = transforms.Compose([transforms.ToPILImage(), 22 | transforms.RandomHorizontalFlip(), 23 | transforms.RandomResizedCrop(size=size, scale=(0.5, 1.0)), 24 | transforms.ToTensor(), 25 | transforms.Normalize(mean_vals, std_vals) 26 | ]) 27 | # value_scale = 255 28 | # mean = [0.485, 0.456, 0.406] 29 | # mean = [item * value_scale for item in mean] 30 | # std = [0.229, 0.224, 0.225] 31 | # std = [item * value_scale for item in std] 32 | # tsfm_train = transform1.Compose([transform1.RandScale([0.9, 1.1]), 33 | # transform1.RandRotate([-10, 10], padding=mean, ignore_label=255), 34 | # transform1.RandomGaussianBlur(), 35 | # transform1.RandomHorizontalFlip(), 36 | # transform1.Crop([size, size], crop_type='rand', padding=mean, ignore_label=255), 37 | # transform1.ToTensor(), 38 | # transform1.Normalize(mean=mean, std=std) 39 | # ]) 40 | 41 
| if args.dataset == 'coco': 42 | img_train = coco_train(args, transform=tsfm_train,k_shot=k_shot) 43 | if args.dataset == 'voc': 44 | img_train = voc_train(args, transform=tsfm_train,k_shot=k_shot) 45 | 46 | train_loader = DataLoader(img_train, batch_size=batch, shuffle=True, num_workers=1) 47 | 48 | return train_loader 49 | 50 | def val_loader(args, k_shot=1): 51 | mean_vals = settings.mean_vals 52 | std_vals = settings.std_vals 53 | 54 | tsfm_val = transforms.Compose([transforms.ToPILImage(), 55 | transforms.Resize(size=(321,321)), 56 | transforms.ToTensor(), 57 | transforms.Normalize(mean_vals, std_vals) 58 | ]) 59 | 60 | # value_scale = 255 61 | # mean = [0.485, 0.456, 0.406] 62 | # mean = [item * value_scale for item in mean] 63 | # std = [0.229, 0.224, 0.225] 64 | # std = [item * value_scale for item in std] 65 | # tsfm_val = transform1.Compose([ 66 | # transform1.Resize(size=321), 67 | # transform1.ToTensor(), 68 | # transform1.Normalize(mean=mean, std=std) 69 | # ]) 70 | 71 | if args.dataset == 'coco': 72 | img_val = coco_val(args, transform=tsfm_val, k_shot=k_shot) 73 | if args.dataset == 'voc': 74 | img_val = voc_val(args, transform=tsfm_val, k_shot=k_shot) 75 | 76 | 77 | val_loader = DataLoader(img_val, batch_size=args.batch_size, shuffle=False, num_workers=1) 78 | 79 | return val_loader 80 | -------------------------------------------------------------------------------- /data/coco_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | # random.seed(1234) 5 | # from .transforms import functional 6 | import os 7 | import cv2 8 | import random 9 | import PIL.Image as Image 10 | import numpy as np 11 | from config import settings 12 | from pycocotools.coco import COCO 13 | #random.seed(1385) 14 | import torch 15 | class coco_train(): 16 | 17 | """Face Landmarks dataset.""" 18 | 19 | def __init__(self, args, 
transform=None,k_shot=1): 20 | """ 21 | Args: 22 | csv_file (string): Path to the csv file with annotations. 23 | root_dir (string): Directory with all the images. 24 | transform (callable, optional): Optional transform to be applied 25 | on a sample. 26 | """ 27 | self.num_classes = 80 28 | self.group = args.group 29 | self.num_folds = args.num_folds 30 | 31 | self.dataDir='/disk2/caoqinglong/coco' 32 | self.dataType='train2017' # train2017 val2017 33 | self.annFile='{}/annotations/instances_{}.json'.format(self.dataDir, self.dataType) 34 | self.coco=COCO(self.annFile) 35 | 36 | #self.nms = self.get_nms() 37 | self.train_id_list = self.get_train_id_list() 38 | self.coco_all_id = self.coco.getCatIds() 39 | self.train_coco_id_list = self.get_train_coco_id_list() 40 | self.list_splite = self.get_total_list() 41 | self.list_splite_len = len(self.list_splite) 42 | self.list_class = self.get_class_list() 43 | 44 | self.transform = transform 45 | self.count = 0 46 | self.random_generator = random.Random() 47 | self.len = args.max_steps *args.batch_size *2 48 | #self.random_generator.shuffle(self.list_splite) 49 | self.k_shot = k_shot 50 | 51 | #self.random_generator.seed(1385) 52 | #self.split = args.split 53 | 54 | def get_nms(self): 55 | cats = self.coco.loadCats(self.coco.getCatIds()) 56 | nms = [cat['name'] for cat in cats] 57 | return nms 58 | 59 | def get_train_coco_id_list(self): 60 | train_coco_id_list = [] 61 | for i in self.train_id_list: 62 | cls = self.coco_all_id[i] 63 | train_coco_id_list.append(cls) 64 | 65 | return train_coco_id_list 66 | 67 | def get_train_id_list(self): 68 | num = int(self.num_classes/ self.num_folds) 69 | #val_set = [self.group * num + v for v in range(num)] 70 | val_set = [self.group + self.num_folds * v for v in range(num)] 71 | 72 | train_set = [x for x in range(self.num_classes) if x not in val_set] 73 | 74 | return train_set 75 | 76 | def get_category(self, annotations): 77 | category_id_list = [] 78 | for ann in annotations: 79 | 
category_id_list.append(ann['category_id']) 80 | category = np.array(category_id_list) 81 | category = np.unique(category) 82 | return category 83 | 84 | def get_total_list(self): 85 | new_exist_class_list = [] 86 | for coco_id in self.train_coco_id_list: 87 | imgIds = self.coco.getImgIds(catIds=coco_id); 88 | for i in range(len(imgIds)): 89 | img = self.coco.loadImgs(imgIds[i])[0] 90 | annIds = self.coco.getAnnIds(imgIds=img['id'], iscrowd=None) # catIds=catIds, 91 | anns = self.coco.loadAnns(annIds) 92 | label = self.get_category(anns) 93 | ##filt the img not in train set 94 | #if set(label.tolist()).issubset(self.train_coco_id_list): 95 | new_exist_class_list.append(img['id']) 96 | 97 | new_exist_class_list_unique = list(set(new_exist_class_list)) 98 | print("Total images after filted are : ", len(new_exist_class_list_unique)) 99 | return new_exist_class_list_unique 100 | 101 | 102 | def get_class_list(self): 103 | list_class = {} 104 | for i in range(self.num_classes): 105 | list_class[i] = [] 106 | for name in self.list_splite: 107 | annIds = self.coco.getAnnIds(imgIds=name, iscrowd=None) # catIds=catIds, 108 | anns = self.coco.loadAnns(annIds) 109 | labels = self.get_category(anns) 110 | for class_ in labels: 111 | if class_ in self.train_coco_id_list: 112 | # decode coco label to our label 113 | class_us = self.coco_all_id.index(class_) 114 | list_class[class_us].append(name) 115 | 116 | return list_class 117 | 118 | def read_img(self, name): 119 | img = self.coco.loadImgs(name)[0] 120 | path = self.dataDir + '/train2017/' + img['file_name'] 121 | img = Image.open(path) 122 | 123 | return img 124 | 125 | def read_mask(self, name, category): 126 | 127 | img = self.coco.loadImgs(name)[0] 128 | 129 | annIds = self.coco.getAnnIds(imgIds=name, catIds=category, iscrowd=None) # catIds=catIds, 130 | anns = self.coco.loadAnns(annIds) 131 | 132 | mask = self.get_mask(img, anns, category) 133 | 134 | return mask.astype(np.float32) 135 | 136 | def 
polygons_to_mask2(self, img_shape, polygons): 137 | 138 | mask = np.zeros(img_shape, dtype=np.uint8) 139 | polygons = np.asarray([polygons], np.int32) # 这里必须是int32,其他类型使用fillPoly会报错 140 | # cv2.fillPoly(mask, polygons, 1) # 非int32 会报错 141 | cv2.fillConvexPoly(mask, polygons, 1) # 非int32 会报错 142 | return mask 143 | 144 | def get_mask(self, img, annotations, category_id): 145 | len_ann = len(annotations) 146 | 147 | half_mask = [] 148 | final_mask = [] 149 | 150 | for ann in annotations: 151 | if ann['category_id'] == category_id: 152 | if ann['iscrowd'] == 1: 153 | continue 154 | seg1 = ann['segmentation'] 155 | seg = seg1[0] 156 | for j in range(0, len(seg), 2): 157 | x = seg[j] 158 | y = seg[j + 1] 159 | mas = [x, y] 160 | half_mask.append(mas) 161 | final_mask.append(half_mask) 162 | half_mask = [] 163 | 164 | mask0 = self.polygons_to_mask2([img['height'],img['width']], final_mask[0]) 165 | for i in range(1, len(final_mask)): 166 | maskany = self.polygons_to_mask2([img['height'],img['width']], final_mask[i]) 167 | mask0 += maskany 168 | 169 | mask0[mask0 > 1] = 1 170 | 171 | return mask0 172 | 173 | def load_frame(self, support_name, query_name, class_): 174 | support_img = self.read_img(support_name) 175 | query_img = self.read_img(query_name) 176 | class_coco = self.coco_all_id[class_] 177 | support_mask = self.read_mask(support_name, class_coco) 178 | query_mask = self.read_mask(query_name, class_coco) 179 | 180 | #support_mask = self.read_binary_mask(support_name, class_) 181 | #query_mask = self.read_binary_mask(query_name, class_) 182 | 183 | return query_img.convert('RGB'), query_mask, support_img.convert('RGB'), support_mask, class_ 184 | 185 | def load_frame_k_shot(self, support_name_list, query_name, class_): 186 | class_coco = self.coco_all_id[class_] 187 | 188 | query_img = self.read_img(query_name) 189 | query_mask = self.read_mask(query_name, class_coco) 190 | 191 | support_img_list = [] 192 | support_mask_list = [] 193 | 194 | for support_name in 
support_name_list: 195 | support_img = self.read_img(support_name) 196 | support_mask = self.read_mask(support_name, class_coco) 197 | support_img_list.append(support_img.convert('RGB')) 198 | support_mask_list.append(support_mask) 199 | 200 | return query_img.convert('RGB'), query_mask, support_img_list, support_mask_list 201 | 202 | def random_choose(self): 203 | class_ = np.random.choice(self.train_id_list, 1, replace=False)[0] 204 | cat_list = self.list_class[class_] 205 | sample_img_ids_1 = np.random.choice(len(cat_list), 2, replace=False) 206 | 207 | query_name = cat_list[sample_img_ids_1[0]] 208 | support_name = cat_list[sample_img_ids_1[1]] 209 | 210 | return support_name, query_name, class_ 211 | 212 | def random_choose_k(self): 213 | class_ = np.random.choice(self.train_id_list, 1, replace=False)[0] 214 | cat_list = self.list_class[class_] 215 | sample_img_ids_1 = np.random.choice(len(cat_list), self.k_shot+1, replace=False) 216 | 217 | query_name = cat_list[sample_img_ids_1[0]] 218 | 219 | support_name_list = [] 220 | for i in range(self.k_shot): 221 | support_name = cat_list[sample_img_ids_1[i+1]] 222 | support_name_list.append(support_name) 223 | 224 | return support_name_list, query_name, class_ 225 | 226 | def get_1_shot(self, idx): 227 | support_name, query_name, class_ = self.random_choose() 228 | 229 | img = self.coco.loadImgs(query_name)[0] 230 | size = [img['height'], img['width']] 231 | 232 | while True: 233 | # support_name, query_name, class_ = self.random_choose() 234 | query_img, query_mask, support_img, support_mask, class_ = self.load_frame(support_name, query_name, class_) 235 | sum1 = query_mask.sum() 236 | sum2 = support_mask.sum() 237 | if sum1 >= 2 * 32 * 32 and sum2 >= 2 * 32 * 32: 238 | break 239 | else: 240 | support_name, query_name, class_ = self.random_choose() 241 | # query_img, query_mask, support_img, support_mask, class_ = self.load_frame(support_name, query_name, 242 | # class_) # class_ is cooc laebl 243 | 244 | if 
self.transform is not None: 245 | query_img, query_mask = self.transform(query_img, query_mask) 246 | support_img, support_mask = self.transform(support_img, support_mask) 247 | 248 | self.count = self.count + 1 249 | 250 | return query_img, query_mask, support_img, support_mask, class_, size 251 | 252 | def get_k_shot(self, idx): 253 | support_name_list, query_name, class_ = self.random_choose_k() 254 | 255 | img = self.coco.loadImgs(query_name)[0] 256 | size = [img['height'], img['width']] 257 | 258 | while True: 259 | query_img, query_mask, support_img_list, support_mask_list = self.load_frame_k_shot(support_name_list, query_name, class_) 260 | sum1 = query_mask.sum() 261 | sum2_0 = support_mask_list[0].sum() 262 | sum2_1 = support_mask_list[1].sum() 263 | sum2_2 = support_mask_list[2].sum() 264 | sum2_3 = support_mask_list[3].sum() 265 | sum2_4 = support_mask_list[4].sum() 266 | k =0 267 | if sum1 >= 2 * 32 * 32 : 268 | k=k+1 269 | if sum2_0 >= 2 * 32 * 32: 270 | k = k+1 271 | if sum2_1 >= 2 * 32 * 32: 272 | k = k+1 273 | if sum2_2 >= 2 * 32 * 32: 274 | k = k+1 275 | if sum2_3 >= 2 * 32 * 32: 276 | k = k+1 277 | if sum2_4 >= 2 * 32 * 32: 278 | k = k+1 279 | 280 | if k==6: 281 | break 282 | else: 283 | support_name_list, query_name, class_ = self.random_choose_k() 284 | 285 | # query_img, query_mask, support_img_list, support_mask_list = self.load_frame_k_shot(support_name_list, query_name, class_) # class_ is cooc laebl 286 | 287 | if self.transform is not None: 288 | query_img, query_mask = self.transform(query_img, query_mask) 289 | for i in range(len(support_mask_list)): 290 | support_temp_img = support_img_list[i] 291 | support_temp_mask = support_mask_list[i] 292 | support_temp_img, support_temp_mask = self.transform(support_temp_img, support_temp_mask) 293 | support_temp_img = support_temp_img.unsqueeze(dim=0) 294 | support_temp_mask = support_temp_mask.unsqueeze(dim=0) 295 | 296 | #print(support_temp_img.shape) 297 | if i ==0: 298 | support_img = 
support_temp_img 299 | support_mask = support_temp_mask 300 | else: 301 | support_img = torch.cat([support_img, support_temp_img], dim=0) 302 | support_mask = torch.cat([support_mask, support_temp_mask], dim=0) 303 | 304 | self.count = self.count + 1 305 | 306 | return query_img, query_mask, support_img, support_mask, class_, size 307 | def __len__(self): 308 | # return len(self.image_list) 309 | return self.len 310 | 311 | 312 | def __getitem__(self, idx): 313 | # support_name, query_name, class_ = self.random_choose() 314 | # 315 | # 316 | # query_img, query_mask, support_img, support_mask, class_ = self.load_frame(support_name, query_name, class_) # class_ is cooc laebl 317 | # 318 | # if self.transform is not None: 319 | # query_img, query_mask = self.transform(query_img, query_mask) 320 | # support_img, support_mask = self.transform(support_img, support_mask) 321 | # 322 | # self.count = self.count + 1 323 | 324 | 325 | if self.k_shot==1: 326 | query_img, query_mask, support_img, support_mask, class_, size = self.get_1_shot(idx)# , size 327 | else: 328 | 329 | query_img, query_mask, support_img, support_mask, class_, size = self.get_k_shot(idx) # , size 330 | 331 | return query_img, query_mask, support_img, support_mask, class_ 332 | 333 | 334 | -------------------------------------------------------------------------------- /data/coco_val.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | # random.seed(1234) 5 | # from .transforms import functional 6 | import os 7 | import cv2 8 | import random 9 | import torch 10 | import PIL.Image as Image 11 | import numpy as np 12 | from config import settings 13 | from pycocotools.coco import COCO 14 | #random.seed(1385) 15 | 16 | ''' 17 | 1. change seed 18 | 2. 
filt small object 19 | ''' 20 | 21 | class coco_val(): 22 | 23 | """Face Landmarks dataset.""" 24 | 25 | def __init__(self, args, transform=None, k_shot=1): 26 | """ 27 | Args: 28 | csv_file (string): Path to the csv file with annotations. 29 | root_dir (string): Directory with all the images. 30 | transform (callable, optional): Optional transform to be applied 31 | on a sample. 32 | """ 33 | self.num_classes = 80 34 | self.group = args.group 35 | self.num_folds = args.num_folds 36 | 37 | self.dataDir='/disk2/caoqinglong/coco' 38 | self.dataType='val2017' # train2017 val2017 39 | self.annFile='{}/annotations/instances_{}.json'.format(self.dataDir, self.dataType) 40 | self.coco=COCO(self.annFile) 41 | 42 | #self.nms = self.get_nms() 43 | self.val_id_list = self.get_val_id_list() 44 | self.coco_all_id = self.coco.getCatIds() 45 | self.val_coco_id_list = self.get_val_coco_id_list() 46 | self.list_splite = self.get_total_list() 47 | self.list_splite_len = len(self.list_splite) 48 | self.list_class = self.get_class_list() 49 | 50 | self.transform = transform 51 | self.count = 0 52 | self.random_generator = random.Random() 53 | self.k_shot = k_shot 54 | #self.len = args.max_steps *args.batch_size *2 55 | #self.random_generator.shuffle(self.list_splite) 56 | 57 | 58 | self.random_generator.seed(1385) 59 | #self.split = args.split 60 | 61 | def get_nms(self): 62 | cats = self.coco.loadCats(self.coco.getCatIds()) 63 | nms = [cat['name'] for cat in cats] 64 | return nms 65 | 66 | def get_val_coco_id_list(self): 67 | val_coco_id_list = [] 68 | for i in self.val_id_list: 69 | cls = self.coco_all_id[i] 70 | val_coco_id_list.append(cls) 71 | 72 | return val_coco_id_list 73 | 74 | def get_val_id_list(self): 75 | num = int(self.num_classes/ self.num_folds) 76 | #val_set = [self.group * num + v for v in range(num)] 77 | val_set = [self.group + self.num_folds * v for v in range(num)] 78 | 79 | #train_set = [x for x in range(self.num_classes) if x not in val_set] 80 | 81 | return 
val_set 82 | 83 | def get_category(self, annotations): 84 | category_id_list = [] 85 | for ann in annotations: 86 | category_id_list.append(ann['category_id']) 87 | category = np.array(category_id_list) 88 | category = np.unique(category) 89 | return category 90 | 91 | def get_total_list(self): 92 | new_exist_class_list = [] 93 | for coco_id in self.val_coco_id_list: 94 | imgIds = self.coco.getImgIds(catIds=coco_id); 95 | for i in range(len(imgIds)): 96 | img = self.coco.loadImgs(imgIds[i])[0] 97 | annIds = self.coco.getAnnIds(imgIds=img['id'], iscrowd=None) # catIds=catIds, 98 | anns = self.coco.loadAnns(annIds) 99 | label = self.get_category(anns) 100 | ##filt the img not in train set 101 | #if set(label.tolist()).issubset(self.train_coco_id_list): 102 | new_exist_class_list.append(img['id']) 103 | 104 | new_exist_class_list_unique = list(set(new_exist_class_list)) 105 | print("Total images are : ", len(new_exist_class_list_unique)) 106 | return new_exist_class_list_unique 107 | 108 | 109 | def get_class_list(self): 110 | list_class = {} 111 | for i in range(self.num_classes): 112 | list_class[i] = [] 113 | for name in self.list_splite: 114 | annIds = self.coco.getAnnIds(imgIds=name, iscrowd=None) # catIds=catIds, 115 | anns = self.coco.loadAnns(annIds) 116 | labels = self.get_category(anns) 117 | for class_ in labels: 118 | # decode coco label to our label 119 | class_us = self.coco_all_id.index(class_) 120 | list_class[class_us].append(name) 121 | 122 | return list_class 123 | 124 | def read_img(self, name): 125 | img = self.coco.loadImgs(name)[0] 126 | path = self.dataDir + '/val2017/' + img['file_name'] 127 | img = Image.open(path) 128 | 129 | return img 130 | 131 | def read_mask(self, name, category): 132 | 133 | img = self.coco.loadImgs(name)[0] 134 | 135 | annIds = self.coco.getAnnIds(imgIds=name, catIds=category, iscrowd=None) # catIds=catIds, 136 | anns = self.coco.loadAnns(annIds) 137 | 138 | mask = self.get_mask(img, anns, category) 139 | 140 | return 
mask.astype(np.float32) 141 | 142 | def polygons_to_mask2(self, img_shape, polygons): 143 | 144 | mask = np.zeros(img_shape, dtype=np.uint8) 145 | polygons = np.asarray([polygons], np.int32) # 这里必须是int32,其他类型使用fillPoly会报错 146 | # cv2.fillPoly(mask, polygons, 1) # 非int32 会报错 147 | cv2.fillConvexPoly(mask, polygons, 1) # 非int32 会报错 148 | return mask 149 | 150 | def get_mask(self, img, annotations, category_id): 151 | len_ann = len(annotations) 152 | 153 | half_mask = [] 154 | final_mask = [] 155 | 156 | for ann in annotations: 157 | if ann['category_id'] == category_id: 158 | if ann['iscrowd'] == 1: 159 | continue 160 | seg1 = ann['segmentation'] 161 | seg = seg1[0] 162 | for j in range(0, len(seg), 2): 163 | x = seg[j] 164 | y = seg[j + 1] 165 | mas = [x, y] 166 | half_mask.append(mas) 167 | final_mask.append(half_mask) 168 | half_mask = [] 169 | 170 | mask0 = self.polygons_to_mask2([img['height'],img['width']], final_mask[0]) 171 | for i in range(1, len(final_mask)): 172 | maskany = self.polygons_to_mask2([img['height'],img['width']], final_mask[i]) 173 | mask0 += maskany 174 | 175 | mask0[mask0 > 1] = 1 176 | 177 | return mask0 178 | 179 | def load_frame(self, support_name, query_name, class_): 180 | support_img = self.read_img(support_name) 181 | query_img = self.read_img(query_name) 182 | class_coco = self.coco_all_id[class_] 183 | support_mask = self.read_mask(support_name, class_coco) 184 | query_mask = self.read_mask(query_name, class_coco) 185 | 186 | #support_mask = self.read_binary_mask(support_name, class_) 187 | #query_mask = self.read_binary_mask(query_name, class_) 188 | 189 | 190 | 191 | return query_img.convert('RGB'), query_mask, support_img.convert('RGB'), support_mask, class_ 192 | 193 | def load_frame_k_shot(self, support_name_list, query_name, class_): 194 | class_coco = self.coco_all_id[class_] 195 | 196 | query_img = self.read_img(query_name) 197 | query_mask = self.read_mask(query_name, class_coco) 198 | 199 | support_img_list = [] 200 | 
support_mask_list = [] 201 | 202 | for support_name in support_name_list: 203 | support_img = self.read_img(support_name) 204 | support_mask = self.read_mask(support_name, class_coco) 205 | support_img_list.append(support_img.convert('RGB')) 206 | support_mask_list.append(support_mask) 207 | 208 | return query_img.convert('RGB'), query_mask, support_img_list, support_mask_list 209 | 210 | def random_choose(self): 211 | class_ = np.random.choice(self.val_id_list, 1, replace=False)[0] 212 | cat_list = self.list_class[class_] 213 | sample_img_ids_1 = np.random.choice(len(cat_list), 2, replace=False) 214 | 215 | query_name = cat_list[sample_img_ids_1[0]] 216 | support_name = cat_list[sample_img_ids_1[1]] 217 | 218 | return support_name, query_name, class_ 219 | 220 | def random_choose_k(self): 221 | class_ = np.random.choice(self.val_id_list, 1, replace=False)[0] 222 | cat_list = self.list_class[class_] 223 | sample_img_ids_1 = np.random.choice(len(cat_list), self.k_shot+1, replace=False) 224 | 225 | query_name = cat_list[sample_img_ids_1[0]] 226 | 227 | support_name_list = [] 228 | for i in range(self.k_shot): 229 | support_name = cat_list[sample_img_ids_1[i+1]] 230 | support_name_list.append(support_name) 231 | 232 | return support_name_list, query_name, class_ 233 | 234 | def get_1_shot(self, idx): 235 | support_name, query_name, class_ = self.random_choose() 236 | 237 | img = self.coco.loadImgs(query_name)[0] 238 | size = [img['height'], img['width']] 239 | 240 | # query_img, query_mask, support_img, support_mask, class_ = self.load_frame(support_name, query_name, 241 | # class_) # class_ is cooc laebl 242 | while True: 243 | # support_name, query_name, class_ = self.random_choose() 244 | query_img, query_mask, support_img, support_mask, class_ = self.load_frame(support_name, query_name, class_) 245 | sum1 = query_mask.sum() 246 | sum2 = support_mask.sum() 247 | if sum1 >= 2 * 32 * 32 and sum2 >= 2 * 32 * 32: 248 | break 249 | else: 250 | support_name, query_name, 
class_ = self.random_choose() 251 | 252 | if self.transform is not None: 253 | query_img, query_mask = self.transform(query_img, query_mask) 254 | support_img, support_mask = self.transform(support_img, support_mask) 255 | 256 | self.count = self.count + 1 257 | 258 | return query_img, query_mask, support_img, support_mask, class_, size 259 | 260 | def get_k_shot(self, idx): 261 | support_name_list, query_name, class_ = self.random_choose_k() 262 | 263 | img = self.coco.loadImgs(query_name)[0] 264 | size = [img['height'], img['width']] 265 | 266 | # query_img, query_mask, support_img_list, support_mask_list = self.load_frame_k_shot(support_name_list, query_name, class_) # class_ is cooc laebl 267 | 268 | while True: 269 | query_img, query_mask, support_img_list, support_mask_list = self.load_frame_k_shot(support_name_list, query_name, class_) 270 | sum1 = query_mask.sum() 271 | sum2_0 = support_mask_list[0].sum() 272 | sum2_1 = support_mask_list[1].sum() 273 | sum2_2 = support_mask_list[2].sum() 274 | sum2_3 = support_mask_list[3].sum() 275 | sum2_4 = support_mask_list[4].sum() 276 | k =0 277 | if sum1 >= 2 * 32 * 32 : 278 | k=k+1 279 | if sum2_0 >= 2 * 32 * 32: 280 | k = k+1 281 | if sum2_1 >= 2 * 32 * 32: 282 | k = k+1 283 | if sum2_2 >= 2 * 32 * 32: 284 | k = k+1 285 | if sum2_3 >= 2 * 32 * 32: 286 | k = k+1 287 | if sum2_4 >= 2 * 32 * 32: 288 | k = k+1 289 | 290 | if k==6: 291 | break 292 | else: 293 | support_name_list, query_name, class_ = self.random_choose_k() 294 | 295 | if self.transform is not None: 296 | query_img, query_mask = self.transform(query_img, query_mask) 297 | for i in range(len(support_mask_list)): 298 | support_temp_img = support_img_list[i] 299 | support_temp_mask = support_mask_list[i] 300 | support_temp_img, support_temp_mask = self.transform(support_temp_img, support_temp_mask) 301 | support_temp_img = support_temp_img.unsqueeze(dim=0) 302 | support_temp_mask = support_temp_mask.unsqueeze(dim=0) 303 | 304 | 
#print(support_temp_img.shape) 305 | if i ==0: 306 | support_img = support_temp_img 307 | support_mask = support_temp_mask 308 | else: 309 | support_img = torch.cat([support_img, support_temp_img], dim=0) 310 | support_mask = torch.cat([support_mask, support_temp_mask], dim=0) 311 | 312 | self.count = self.count + 1 313 | 314 | return query_img, query_mask, support_img, support_mask, class_, size 315 | 316 | def __len__(self): 317 | # return len(self.image_list) 318 | return 1000 319 | 320 | 321 | def __getitem__(self, idx): 322 | 323 | 324 | if self.k_shot==1: 325 | query_img, query_mask, support_img, support_mask, class_, size = self.get_1_shot(idx)# , size 326 | else: 327 | query_img, query_mask, support_img, support_mask, class_, size = self.get_k_shot(idx) # , size 328 | 329 | return query_img, query_mask, support_img, support_mask, class_, size 330 | 331 | 332 | -------------------------------------------------------------------------------- /data/voc_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | import os 5 | import cv2 6 | import random 7 | import PIL.Image as Image 8 | import numpy as np 9 | from config import settings 10 | import torch 11 | #random.seed(1385) 12 | 13 | class voc_train(): 14 | 15 | """Face Landmarks dataset.""" 16 | 17 | def __init__(self, args, transform=None,k_shot=1): 18 | """ 19 | Args: 20 | csv_file (string): Path to the csv file with annotations. 21 | root_dir (string): Directory with all the images. 22 | transform (callable, optional): Optional transform to be applied 23 | on a sample. 
24 | """ 25 | self.num_classes = 20 26 | self.group = args.group 27 | self.num_folds = args.num_folds 28 | #self.binary_map_dir = os.path.join(settings.DATA_DIR, 'VOCdevkit2012/VOC2012/', 'Binary_map_aug/train') #val 29 | self.data_list_dir = os.path.join('data_list/train') 30 | self.img_dir = os.path.join(settings.DATA_DIR, 'VOCdevkit2012/VOC2012AUG/', 'JPEGImages/') 31 | self.mask_dir = os.path.join(settings.DATA_DIR, 'VOCdevkit2012/VOC2012AUG/', 'SegmentationClassAug/') 32 | #self.binary_mask_dir = os.path.join(settings.DATA_DIR, 'VOCdevkit2012/VOC2012/', 'Binary_map_aug/train/') 33 | 34 | self.train_id_list = self.get_train_id_list() 35 | self.list_splite = self.get_total_list() 36 | self.list_splite_len = len(self.list_splite) 37 | self.list_class = self.get_class_list() 38 | 39 | self.transform = transform 40 | self.count = 0 41 | self.random_generator = random.Random() 42 | self.len = args.max_steps *args.batch_size *2 43 | #self.random_generator.shuffle(self.list_splite) 44 | #self.random_generator.seed(1385) 45 | self.k_shot = k_shot 46 | def get_train_id_list(self): 47 | num = int(self.num_classes/ self.num_folds) 48 | val_set = [self.group * num + v for v in range(num)] 49 | train_set = [x for x in range(self.num_classes) if x not in val_set] 50 | 51 | return train_set 52 | 53 | def get_total_list(self): 54 | new_exist_class_list = [] 55 | 56 | fold_list = [0, 1, 2, 3] 57 | fold_list.remove(self.group) 58 | 59 | for fold in fold_list: 60 | f = open(os.path.join(self.data_list_dir, 'split%1d_train.txt' % (fold))) 61 | while True: 62 | item = f.readline() 63 | if item == '': 64 | break 65 | img_name = item[:11] 66 | cat = int(item[13:15]) -1 67 | new_exist_class_list.append([img_name, cat]) 68 | print("Total images are : ", len(new_exist_class_list)) 69 | # if need filter 70 | new_exist_class_list = self.filte_multi_class(new_exist_class_list) 71 | return new_exist_class_list 72 | 73 | def filte_multi_class(self, exist_class_list): 74 | 75 | 
new_exist_class_list = [] 76 | for name, class_ in exist_class_list: 77 | 78 | mask_path = self.mask_dir + name + '.png' 79 | mask = cv2.imread(mask_path) 80 | labels = np.unique(mask[:,:,0]) 81 | 82 | labels = [label - 1 for label in labels if label != 255 and label != 0] 83 | if set(labels).issubset(self.train_id_list): 84 | new_exist_class_list.append([name, class_]) 85 | print("Total images after filted are : ", len(new_exist_class_list)) 86 | return new_exist_class_list 87 | 88 | 89 | def get_class_list(self): 90 | list_class = {} 91 | for i in range(self.num_classes): 92 | list_class[i] = [] 93 | for name, class_ in self.list_splite: 94 | list_class[class_].append(name) 95 | 96 | return list_class 97 | 98 | def read_img(self, name): 99 | path = self.img_dir + name + '.jpg' 100 | img = Image.open(path) 101 | # image = cv2.imread(path, cv2.IMREAD_COLOR) 102 | # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 103 | # img = np.float32(image) 104 | return img 105 | 106 | def read_mask(self, name, category): 107 | path = self.mask_dir + name + '.png' 108 | mask = cv2.imread(path) 109 | 110 | # mask[mask==255] = 255 111 | mask[mask!=category+1] = 0 112 | mask[mask==category+1] = 1 113 | return mask[:,:,0].astype(np.float32) 114 | 115 | ''' 116 | def read_binary_mask(self, name, category): 117 | path = self.binary_mask_dir +str(category+1)+'/'+ name + '.png' 118 | mask = cv2.imread(path)/255 119 | 120 | return mask[:,:,0].astype(np.float32) 121 | ''' 122 | def load_frame(self, support_name, query_name, class_): 123 | support_img = self.read_img(support_name) 124 | query_img = self.read_img(query_name) 125 | support_mask = self.read_mask(support_name, class_) 126 | query_mask = self.read_mask(query_name, class_) 127 | 128 | #support_mask = self.read_binary_mask(support_name, class_) 129 | #query_mask = self.read_binary_mask(query_name, class_) 130 | 131 | return query_img, query_mask, support_img, support_mask, class_ 132 | 133 | def load_frame_k_shot(self, 
support_name_list, query_name, class_): 134 | query_img = self.read_img(query_name) 135 | query_mask = self.read_mask(query_name, class_) 136 | 137 | support_img_list = [] 138 | support_mask_list = [] 139 | 140 | for support_name in support_name_list: 141 | support_img = self.read_img(support_name) 142 | support_mask = self.read_mask(support_name, class_) 143 | support_img_list.append(support_img) 144 | support_mask_list.append(support_mask) 145 | 146 | return query_img, query_mask, support_img_list, support_mask_list 147 | 148 | def random_choose(self): 149 | class_ = np.random.choice(self.train_id_list, 1, replace=False)[0] 150 | cat_list = self.list_class[class_] 151 | sample_img_ids_1 = np.random.choice(len(cat_list), 2, replace=False) 152 | 153 | query_name = cat_list[sample_img_ids_1[0]] 154 | support_name = cat_list[sample_img_ids_1[1]] 155 | 156 | return support_name, query_name, class_ 157 | 158 | def get_1_shot(self, idx): 159 | if self.count >= self.list_splite_len: 160 | self.random_generator.shuffle(self.list_splite) 161 | self.count = 0 162 | query_name, class_ = self.list_splite[self.count] 163 | 164 | while True: 165 | query_name, class_ = self.list_splite[self.count] 166 | while True: # random sample a support data 167 | support_img_list = self.list_class[class_] 168 | support_name = support_img_list[self.random_generator.randint(0, len(support_img_list) - 1)] 169 | if support_name != query_name: 170 | break 171 | query_img, query_mask, support_img, support_mask, class_ = self.load_frame(support_name, query_name, class_) 172 | sum1 = query_mask.sum() 173 | sum2 = support_mask.sum() 174 | if sum1 >= 2 * 32 * 32 and sum2 >= 2 * 32 * 32: 175 | break 176 | else: 177 | self.count = self.count + 1 178 | if self.count >= self.list_splite_len: 179 | self.random_generator.shuffle(self.list_splite) 180 | self.count = 0 181 | 182 | size = query_mask.shape 183 | 184 | if self.transform is not None: 185 | query_img, query_mask = self.transform(query_img, 
query_mask) 186 | support_img, support_mask = self.transform(support_img, support_mask) 187 | 188 | self.count = self.count + 1 189 | 190 | return query_img, query_mask, support_img, support_mask, class_, size 191 | 192 | def get_k_shot(self, idx): 193 | 194 | if self.count >= self.list_splite_len: 195 | self.random_generator.shuffle(self.list_splite) 196 | self.count = 0 197 | 198 | while True: 199 | if self.count >= self.list_splite_len: 200 | self.random_generator.shuffle(self.list_splite) 201 | self.count = 0 202 | query_name, class_ = self.list_splite[self.count] 203 | # random sample 5 support data 204 | support_set_list = self.list_class[class_] 205 | support_choice_list = support_set_list.copy() 206 | support_choice_list.remove(query_name) 207 | support_name_list = self.random_generator.sample(support_choice_list, self.k_shot) 208 | query_img, query_mask, support_img_list, support_mask_list = self.load_frame_k_shot(support_name_list, query_name, class_) 209 | sum1 = query_mask.sum() 210 | sum2_0 = support_mask_list[0].sum() 211 | sum2_1 = support_mask_list[1].sum() 212 | sum2_2 = support_mask_list[2].sum() 213 | sum2_3 = support_mask_list[3].sum() 214 | sum2_4 = support_mask_list[4].sum() 215 | k =0 216 | if sum1 >= 2 * 32 * 32 : 217 | k=k+1 218 | if sum2_0 >= 2 * 32 * 32: 219 | k = k+1 220 | if sum2_1 >= 2 * 32 * 32: 221 | k = k+1 222 | if sum2_2 >= 2 * 32 * 32: 223 | k = k+1 224 | if sum2_3 >= 2 * 32 * 32: 225 | k = k+1 226 | if sum2_4 >= 2 * 32 * 32: 227 | k = k+1 228 | 229 | if k==6: 230 | break 231 | else: 232 | self.count = self.count + 1 233 | 234 | size = query_mask.shape 235 | 236 | if self.transform is not None: 237 | query_img, query_mask = self.transform(query_img, query_mask) 238 | for i in range(len(support_mask_list)): 239 | support_temp_img = support_img_list[i] 240 | support_temp_mask = support_mask_list[i] 241 | support_temp_img, support_temp_mask = self.transform(support_temp_img, support_temp_mask) 242 | support_temp_img = 
support_temp_img.unsqueeze(dim=0) 243 | support_temp_mask = support_temp_mask.unsqueeze(dim=0) 244 | if i ==0: 245 | support_img = support_temp_img 246 | support_mask = support_temp_mask 247 | else: 248 | support_img = torch.cat([support_img, support_temp_img], dim=0) 249 | support_mask = torch.cat([support_mask, support_temp_mask], dim=0) 250 | 251 | 252 | self.count = self.count + 1 253 | 254 | return query_img, query_mask, support_img, support_mask, class_, size 255 | 256 | def __len__(self): 257 | # return len(self.image_list) 258 | return self.len 259 | 260 | def __getitem__(self, idx): 261 | # support_name, query_name, class_ = self.random_choose() 262 | 263 | # query_img, query_mask, support_img, support_mask, class_ = self.load_frame(support_name, query_name, class_) 264 | 265 | # while True: 266 | # support_name, query_name, class_ = self.random_choose() 267 | # query_img, query_mask, support_img, support_mask, class_ = self.load_frame(support_name, query_name, class_) 268 | # sum1 = query_mask.sum() 269 | # sum2 = support_mask.sum() 270 | # # print(sum2) 271 | # # print(sum1) 272 | # if sum1 >0 and sum2 >0 : 273 | # # print('ok') 274 | # break 275 | # 276 | # if self.transform is not None: 277 | # query_img, query_mask = self.transform(query_img, query_mask) 278 | # support_img, support_mask = self.transform(support_img, support_mask) 279 | # 280 | # self.count = self.count + 1 281 | 282 | if self.k_shot==1: 283 | query_img, query_mask, support_img, support_mask, class_, size = self.get_1_shot(idx)# , size 284 | else: 285 | query_img, query_mask, support_img, support_mask, class_, size = self.get_k_shot(idx) # , size 286 | 287 | 288 | return query_img, query_mask, support_img, support_mask, class_ 289 | -------------------------------------------------------------------------------- /data/voc_val.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import 
absolute_import 3 | 4 | import os 5 | import cv2 6 | import random 7 | import PIL.Image as Image 8 | import numpy as np 9 | from config import settings 10 | import torch 11 | 12 | 13 | class voc_val(): 14 | 15 | """voc dataset.""" 16 | 17 | def __init__(self, args, transform=None, k_shot=1): 18 | self.num_classes = 20 19 | self.group = args.group 20 | self.num_folds = args.num_folds 21 | #self.binary_map_dir = os.path.join(settings.DATA_DIR, 'VOCdevkit2012/VOC2012/', 'list/val') #val 22 | self.data_list_dir = os.path.join('data_list/val') 23 | self.img_dir = os.path.join(settings.DATA_DIR, 'VOCdevkit2012/VOC2012AUG/', 'JPEGImages/') 24 | self.mask_dir = os.path.join(settings.DATA_DIR, 'VOCdevkit2012/VOC2012AUG/', 'SegmentationClassAug/') 25 | #self.binary_mask_dir = os.path.join(settings.DATA_DIR, 'VOCdevkit2012/VOC2012/', 'Binary_map_aug/val/') 26 | self.list_splite = self.get_total_list() 27 | self.list_splite_len = len(self.list_splite) 28 | self.list_class = self.get_class_list() 29 | self.transform = transform 30 | self.count = 0 31 | self.random_generator = random.Random() 32 | self.random_generator.seed(1385) #1385 33 | self.k_shot = k_shot 34 | 35 | def get_total_list(self): 36 | new_exist_class_list = [] 37 | f = open(os.path.join(self.data_list_dir, 'split%1d_val.txt' % (self.group))) 38 | while True: 39 | item = f.readline() 40 | if item == '': 41 | break 42 | img_name = item[:11] 43 | cat = int(item[13:15]) -1 44 | new_exist_class_list.append([img_name, cat]) 45 | print("Total images are : ", len(new_exist_class_list)) 46 | return new_exist_class_list 47 | 48 | def get_class_list(self): 49 | list_class = {} 50 | for i in range(self.num_classes): 51 | list_class[i] = [] 52 | for name, class_ in self.list_splite: 53 | if class_ < 0: 54 | print(name) 55 | list_class[class_].append(name) 56 | 57 | return list_class 58 | 59 | def read_img(self, name): 60 | path = self.img_dir + name + '.jpg' 61 | img = Image.open(path) 62 | # image = cv2.imread(path, 
cv2.IMREAD_COLOR) 63 | # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 64 | # img = np.float32(image) 65 | return img 66 | 67 | def read_mask(self, name, category): 68 | path = self.mask_dir + name + '.png' 69 | mask = cv2.imread(path) 70 | # mask[mask==255] = 255 71 | mask[mask!=category+1] = 0 72 | mask[mask==category+1] = 1 73 | return mask[:,:,0].astype(np.float32) 74 | 75 | ''' 76 | def read_binary_mask(self, name, category): 77 | path = self.binary_mask_dir +str(category+1)+'/'+ name + '.png' 78 | mask = cv2.imread(path)/255 79 | 80 | #mask[mask!=category+1] = 0 81 | #mask[mask==category+1] = 1 82 | 83 | return mask[:,:,0].astype(np.float32) 84 | ''' 85 | def load_frame(self, support_name, query_name, class_): 86 | support_img = self.read_img(support_name) 87 | query_img = self.read_img(query_name) 88 | support_mask = self.read_mask(support_name, class_) 89 | query_mask = self.read_mask(query_name, class_) 90 | 91 | 92 | 93 | return query_img, query_mask, support_img, support_mask, class_ 94 | 95 | def load_frame_k_shot(self, support_name_list, query_name, class_): 96 | # query_name = '2007_005951' 97 | # class_ = 13 98 | query_img = self.read_img(query_name) 99 | query_mask = self.read_mask(query_name, class_) 100 | 101 | support_img_list = [] 102 | support_mask_list = [] 103 | 104 | for support_name in support_name_list: 105 | support_img = self.read_img(support_name) 106 | support_mask = self.read_mask(support_name, class_) 107 | support_img_list.append(support_img) 108 | support_mask_list.append(support_mask) 109 | 110 | return query_img, query_mask, support_img_list, support_mask_list 111 | 112 | 113 | def get_1_shot(self, idx): 114 | if self.count >= self.list_splite_len: 115 | self.random_generator.shuffle(self.list_splite) 116 | self.count = 0 117 | query_name, class_ = self.list_splite[self.count] 118 | 119 | while True: 120 | if self.count >= self.list_splite_len: 121 | self.random_generator.shuffle(self.list_splite) 122 | self.count = 0 123 | 
query_name, class_ = self.list_splite[self.count] 124 | # query_name = '2008_006641' 125 | # class_ = 4 126 | while True: # random sample a support data 127 | support_img_list = self.list_class[class_] 128 | support_name = support_img_list[self.random_generator.randint(0, len(support_img_list) - 1)] 129 | if support_name != query_name: 130 | break 131 | query_img, query_mask, support_img, support_mask, class_ = self.load_frame(support_name, query_name, class_) 132 | sum1 = query_mask.sum() 133 | sum2 = support_mask.sum() 134 | if sum1 >= 2*32*32 and sum2 >= 2*32*32: 135 | break 136 | else: 137 | self.count = self.count + 1 138 | 139 | size = query_mask.shape 140 | 141 | if self.transform is not None: 142 | query_img, query_mask = self.transform(query_img, query_mask) 143 | support_img, support_mask = self.transform(support_img, support_mask) 144 | 145 | self.count = self.count + 1 146 | 147 | return query_img, query_mask, support_img, support_mask, class_, size 148 | 149 | # return query_img, query_mask, support_img, support_mask, support_name, query_name, class_, size 150 | 151 | def get_k_shot(self, idx): 152 | 153 | if self.count >= self.list_splite_len: 154 | self.random_generator.shuffle(self.list_splite) 155 | self.count = 0 156 | 157 | while True: 158 | if self.count >= self.list_splite_len: 159 | self.random_generator.shuffle(self.list_splite) 160 | self.count = 0 161 | query_name, class_ = self.list_splite[self.count] 162 | # query_name = '2007_001288' 163 | # class_ = 0 164 | # random sample 5 support data 165 | support_set_list = self.list_class[class_] 166 | support_choice_list = support_set_list.copy() 167 | support_choice_list.remove(query_name) 168 | support_name_list = self.random_generator.sample(support_choice_list, self.k_shot) 169 | query_img, query_mask, support_img_list, support_mask_list = self.load_frame_k_shot(support_name_list, query_name, class_) 170 | sum1 = query_mask.sum() 171 | sum2_0 = support_mask_list[0].sum() 172 | sum2_1 = 
support_mask_list[1].sum() 173 | sum2_2 = support_mask_list[2].sum() 174 | sum2_3 = support_mask_list[3].sum() 175 | sum2_4 = support_mask_list[4].sum() 176 | 177 | k =0 178 | if sum1 >= 2 * 32 * 32 : 179 | k=k+1 180 | if sum2_0 >= 2 * 32 * 32: 181 | k = k+1 182 | if sum2_1 >= 2 * 32 * 32: 183 | k = k+1 184 | if sum2_2 >= 2 * 32 * 32: 185 | k = k+1 186 | if sum2_3 >= 2 * 32 * 32: 187 | k = k+1 188 | if sum2_4 >= 2 * 32 * 32: 189 | k = k+1 190 | 191 | if k==6: 192 | break 193 | else: 194 | self.count = self.count + 1 195 | 196 | size = query_mask.shape 197 | 198 | if self.transform is not None: 199 | query_img, query_mask = self.transform(query_img, query_mask) 200 | for i in range(len(support_mask_list)): 201 | support_temp_img = support_img_list[i] 202 | support_temp_mask = support_mask_list[i] 203 | support_temp_img, support_temp_mask = self.transform(support_temp_img, support_temp_mask) 204 | support_temp_img = support_temp_img.unsqueeze(dim=0) 205 | support_temp_mask = support_temp_mask.unsqueeze(dim=0) 206 | if i ==0: 207 | support_img = support_temp_img 208 | support_mask = support_temp_mask 209 | else: 210 | support_img = torch.cat([support_img, support_temp_img], dim=0) 211 | support_mask = torch.cat([support_mask, support_temp_mask], dim=0) 212 | 213 | 214 | self.count = self.count + 1 215 | 216 | return query_img, query_mask, support_img, support_mask, class_, size 217 | 218 | def __len__(self): 219 | # return len(self.image_list) 220 | return 1000 221 | 222 | 223 | def __getitem__(self, idx): 224 | if self.k_shot==1: 225 | query_img, query_mask, support_img, support_mask, class_, size = self.get_1_shot(idx)# , size 226 | else: 227 | query_img, query_mask, support_img, support_mask, class_, size = self.get_k_shot(idx) # , size 228 | 229 | return query_img, query_mask, support_img, support_mask, class_, size 230 | 231 | 232 | -------------------------------------------------------------------------------- /data_list/val/split0_val.txt: 
-------------------------------------------------------------------------------- 1 | 2007_000033__01 2 | 2007_000061__04 3 | 2007_000129__02 4 | 2007_000346__05 5 | 2007_000529__04 6 | 2007_000559__05 7 | 2007_000572__02 8 | 2007_000762__05 9 | 2007_001288__01 10 | 2007_001289__03 11 | 2007_001311__02 12 | 2007_001408__05 13 | 2007_001568__01 14 | 2007_001630__02 15 | 2007_001761__01 16 | 2007_001884__01 17 | 2007_002094__03 18 | 2007_002266__01 19 | 2007_002376__01 20 | 2007_002400__03 21 | 2007_002619__01 22 | 2007_002719__04 23 | 2007_003088__05 24 | 2007_003131__04 25 | 2007_003188__02 26 | 2007_003349__03 27 | 2007_003571__04 28 | 2007_003621__02 29 | 2007_003682__03 30 | 2007_003861__04 31 | 2007_004052__01 32 | 2007_004143__03 33 | 2007_004241__04 34 | 2007_004468__05 35 | 2007_005074__04 36 | 2007_005107__02 37 | 2007_005294__05 38 | 2007_005304__05 39 | 2007_005428__05 40 | 2007_005509__01 41 | 2007_005600__01 42 | 2007_005705__04 43 | 2007_005828__01 44 | 2007_006076__03 45 | 2007_006086__05 46 | 2007_006449__02 47 | 2007_006946__01 48 | 2007_007084__03 49 | 2007_007235__02 50 | 2007_007341__01 51 | 2007_007470__01 52 | 2007_007477__04 53 | 2007_007836__02 54 | 2007_008051__03 55 | 2007_008084__03 56 | 2007_008204__05 57 | 2007_008670__03 58 | 2007_009088__03 59 | 2007_009258__02 60 | 2007_009323__03 61 | 2007_009458__05 62 | 2007_009687__05 63 | 2007_009817__03 64 | 2007_009911__01 65 | 2008_000120__04 66 | 2008_000123__03 67 | 2008_000533__03 68 | 2008_000725__02 69 | 2008_000911__05 70 | 2008_001013__04 71 | 2008_001040__04 72 | 2008_001135__04 73 | 2008_001260__04 74 | 2008_001404__02 75 | 2008_001514__03 76 | 2008_001531__02 77 | 2008_001546__01 78 | 2008_001580__04 79 | 2008_001966__03 80 | 2008_001971__01 81 | 2008_002043__03 82 | 2008_002269__02 83 | 2008_002358__01 84 | 2008_002429__03 85 | 2008_002467__05 86 | 2008_002504__04 87 | 2008_002775__05 88 | 2008_002864__05 89 | 2008_003034__04 90 | 2008_003076__05 91 | 2008_003108__02 92 | 
2008_003110__03 93 | 2008_003155__01 94 | 2008_003270__02 95 | 2008_003369__01 96 | 2008_003858__04 97 | 2008_003876__01 98 | 2008_003886__04 99 | 2008_003926__01 100 | 2008_003976__01 101 | 2008_004363__02 102 | 2008_004654__02 103 | 2008_004659__05 104 | 2008_004704__01 105 | 2008_004758__02 106 | 2008_004995__02 107 | 2008_005262__05 108 | 2008_005338__01 109 | 2008_005628__04 110 | 2008_005727__02 111 | 2008_005812__05 112 | 2008_005904__05 113 | 2008_006216__01 114 | 2008_006229__04 115 | 2008_006254__02 116 | 2008_006703__01 117 | 2008_007120__03 118 | 2008_007143__04 119 | 2008_007219__05 120 | 2008_007350__01 121 | 2008_007498__03 122 | 2008_007811__05 123 | 2008_007994__03 124 | 2008_008268__03 125 | 2008_008629__02 126 | 2008_008711__02 127 | 2008_008746__03 128 | 2009_000032__01 129 | 2009_000037__03 130 | 2009_000121__05 131 | 2009_000149__02 132 | 2009_000201__05 133 | 2009_000205__01 134 | 2009_000318__03 135 | 2009_000354__02 136 | 2009_000387__01 137 | 2009_000421__04 138 | 2009_000440__01 139 | 2009_000446__04 140 | 2009_000457__02 141 | 2009_000469__04 142 | 2009_000573__02 143 | 2009_000619__03 144 | 2009_000664__03 145 | 2009_000723__04 146 | 2009_000828__04 147 | 2009_000840__05 148 | 2009_000879__03 149 | 2009_000991__03 150 | 2009_000998__03 151 | 2009_001108__03 152 | 2009_001160__03 153 | 2009_001255__02 154 | 2009_001278__05 155 | 2009_001314__03 156 | 2009_001332__01 157 | 2009_001565__03 158 | 2009_001607__03 159 | 2009_001683__03 160 | 2009_001718__02 161 | 2009_001765__03 162 | 2009_001818__05 163 | 2009_001850__01 164 | 2009_001851__01 165 | 2009_001941__04 166 | 2009_002185__05 167 | 2009_002295__02 168 | 2009_002320__01 169 | 2009_002372__05 170 | 2009_002521__05 171 | 2009_002594__05 172 | 2009_002604__03 173 | 2009_002649__05 174 | 2009_002727__04 175 | 2009_002732__05 176 | 2009_002749__05 177 | 2009_002808__01 178 | 2009_002856__05 179 | 2009_002888__01 180 | 2009_002928__02 181 | 2009_003003__05 182 | 2009_003005__01 183 | 
2009_003043__04 184 | 2009_003080__04 185 | 2009_003193__02 186 | 2009_003224__02 187 | 2009_003269__05 188 | 2009_003273__03 189 | 2009_003343__02 190 | 2009_003378__03 191 | 2009_003450__03 192 | 2009_003498__03 193 | 2009_003504__04 194 | 2009_003517__05 195 | 2009_003640__03 196 | 2009_003696__01 197 | 2009_003707__04 198 | 2009_003806__01 199 | 2009_003858__03 200 | 2009_003971__02 201 | 2009_004021__03 202 | 2009_004084__03 203 | 2009_004125__04 204 | 2009_004247__05 205 | 2009_004324__05 206 | 2009_004509__03 207 | 2009_004540__03 208 | 2009_004568__03 209 | 2009_004579__05 210 | 2009_004635__04 211 | 2009_004653__01 212 | 2009_004848__02 213 | 2009_004882__02 214 | 2009_004886__03 215 | 2009_004895__03 216 | 2009_004969__01 217 | 2009_005038__05 218 | 2009_005137__03 219 | 2009_005156__02 220 | 2009_005189__01 221 | 2009_005190__05 222 | 2009_005260__03 223 | 2009_005262__03 224 | 2009_005302__05 225 | 2010_000065__02 226 | 2010_000083__02 227 | 2010_000084__04 228 | 2010_000238__01 229 | 2010_000241__03 230 | 2010_000272__04 231 | 2010_000342__02 232 | 2010_000426__05 233 | 2010_000572__01 234 | 2010_000622__01 235 | 2010_000814__03 236 | 2010_000906__04 237 | 2010_000961__03 238 | 2010_001016__03 239 | 2010_001017__01 240 | 2010_001024__01 241 | 2010_001036__04 242 | 2010_001061__03 243 | 2010_001069__03 244 | 2010_001174__01 245 | 2010_001367__02 246 | 2010_001367__05 247 | 2010_001448__01 248 | 2010_001830__05 249 | 2010_001995__03 250 | 2010_002017__05 251 | 2010_002030__02 252 | 2010_002142__03 253 | 2010_002147__01 254 | 2010_002150__04 255 | 2010_002200__01 256 | 2010_002310__01 257 | 2010_002536__02 258 | 2010_002546__04 259 | 2010_002693__02 260 | 2010_002939__01 261 | 2010_003127__01 262 | 2010_003132__01 263 | 2010_003168__03 264 | 2010_003362__03 265 | 2010_003365__01 266 | 2010_003418__03 267 | 2010_003468__05 268 | 2010_003473__03 269 | 2010_003495__01 270 | 2010_003547__04 271 | 2010_003716__01 272 | 2010_003771__03 273 | 2010_003781__05 274 
| 2010_003820__03 275 | 2010_003912__02 276 | 2010_003915__01 277 | 2010_004041__04 278 | 2010_004056__05 279 | 2010_004208__04 280 | 2010_004314__01 281 | 2010_004419__01 282 | 2010_004520__05 283 | 2010_004529__05 284 | 2010_004551__05 285 | 2010_004556__03 286 | 2010_004559__03 287 | 2010_004662__04 288 | 2010_004772__04 289 | 2010_004828__05 290 | 2010_004994__03 291 | 2010_005252__04 292 | 2010_005401__04 293 | 2010_005428__03 294 | 2010_005496__05 295 | 2010_005531__03 296 | 2010_005534__01 297 | 2010_005582__05 298 | 2010_005664__02 299 | 2010_005705__04 300 | 2010_005718__01 301 | 2010_005762__05 302 | 2010_005877__01 303 | 2010_005888__01 304 | 2010_006034__01 305 | 2010_006070__02 306 | 2011_000066__05 307 | 2011_000112__03 308 | 2011_000185__03 309 | 2011_000234__04 310 | 2011_000238__04 311 | 2011_000412__02 312 | 2011_000435__04 313 | 2011_000456__03 314 | 2011_000482__03 315 | 2011_000585__02 316 | 2011_000669__03 317 | 2011_000747__05 318 | 2011_000874__01 319 | 2011_001114__01 320 | 2011_001161__04 321 | 2011_001263__01 322 | 2011_001287__03 323 | 2011_001407__01 324 | 2011_001421__03 325 | 2011_001434__01 326 | 2011_001589__04 327 | 2011_001624__01 328 | 2011_001793__04 329 | 2011_001880__01 330 | 2011_001988__02 331 | 2011_002064__02 332 | 2011_002098__05 333 | 2011_002223__02 334 | 2011_002295__03 335 | 2011_002327__01 336 | 2011_002515__01 337 | 2011_002675__01 338 | 2011_002713__02 339 | 2011_002754__04 340 | 2011_002863__05 341 | 2011_002929__01 342 | 2011_002975__04 343 | 2011_003003__02 344 | 2011_003030__03 345 | 2011_003145__03 346 | 2011_003271__05 347 | -------------------------------------------------------------------------------- /data_list/val/split1_val.txt: -------------------------------------------------------------------------------- 1 | 2007_000452__09 2 | 2007_000464__10 3 | 2007_000491__10 4 | 2007_000663__06 5 | 2007_000663__07 6 | 2007_000727__06 7 | 2007_000727__07 8 | 2007_000804__09 9 | 2007_000830__09 10 | 
2007_001299__10 11 | 2007_001321__07 12 | 2007_001457__09 13 | 2007_001677__09 14 | 2007_001717__09 15 | 2007_001763__08 16 | 2007_001774__08 17 | 2007_001884__06 18 | 2007_002268__08 19 | 2007_002387__10 20 | 2007_002445__08 21 | 2007_002470__08 22 | 2007_002539__06 23 | 2007_002597__08 24 | 2007_002643__07 25 | 2007_002903__10 26 | 2007_003011__09 27 | 2007_003051__07 28 | 2007_003101__06 29 | 2007_003106__08 30 | 2007_003137__06 31 | 2007_003143__07 32 | 2007_003169__08 33 | 2007_003195__06 34 | 2007_003201__10 35 | 2007_003503__06 36 | 2007_003503__07 37 | 2007_003621__06 38 | 2007_003711__06 39 | 2007_003786__06 40 | 2007_003841__10 41 | 2007_003917__07 42 | 2007_003991__08 43 | 2007_004193__09 44 | 2007_004392__09 45 | 2007_004405__09 46 | 2007_004510__09 47 | 2007_004712__09 48 | 2007_004856__08 49 | 2007_004866__08 50 | 2007_005074__07 51 | 2007_005114__10 52 | 2007_005296__07 53 | 2007_005331__07 54 | 2007_005460__08 55 | 2007_005547__07 56 | 2007_005547__10 57 | 2007_005844__09 58 | 2007_005845__08 59 | 2007_005911__06 60 | 2007_005978__06 61 | 2007_006035__07 62 | 2007_006086__09 63 | 2007_006241__09 64 | 2007_006260__08 65 | 2007_006277__07 66 | 2007_006348__09 67 | 2007_006553__09 68 | 2007_006761__10 69 | 2007_006841__10 70 | 2007_007414__07 71 | 2007_007417__08 72 | 2007_007524__08 73 | 2007_007815__07 74 | 2007_007818__07 75 | 2007_007996__09 76 | 2007_008106__09 77 | 2007_008110__09 78 | 2007_008543__09 79 | 2007_008722__10 80 | 2007_008747__06 81 | 2007_008815__08 82 | 2007_008897__09 83 | 2007_008973__10 84 | 2007_009015__06 85 | 2007_009015__07 86 | 2007_009068__09 87 | 2007_009084__09 88 | 2007_009096__07 89 | 2007_009221__08 90 | 2007_009245__10 91 | 2007_009346__08 92 | 2007_009392__06 93 | 2007_009392__07 94 | 2007_009413__09 95 | 2007_009521__09 96 | 2007_009764__06 97 | 2007_009794__08 98 | 2007_009897__10 99 | 2007_009923__08 100 | 2007_009938__07 101 | 2008_000009__10 102 | 2008_000073__10 103 | 2008_000075__06 104 | 2008_000107__09 105 
| 2008_000149__09 106 | 2008_000182__08 107 | 2008_000345__08 108 | 2008_000401__08 109 | 2008_000464__08 110 | 2008_000501__07 111 | 2008_000673__09 112 | 2008_000853__08 113 | 2008_000919__10 114 | 2008_001078__08 115 | 2008_001433__08 116 | 2008_001439__09 117 | 2008_001513__08 118 | 2008_001640__08 119 | 2008_001715__09 120 | 2008_001885__08 121 | 2008_002152__08 122 | 2008_002205__06 123 | 2008_002212__07 124 | 2008_002379__09 125 | 2008_002521__09 126 | 2008_002623__08 127 | 2008_002681__08 128 | 2008_002778__10 129 | 2008_002958__07 130 | 2008_003141__06 131 | 2008_003141__07 132 | 2008_003333__07 133 | 2008_003477__09 134 | 2008_003499__08 135 | 2008_003577__07 136 | 2008_003777__06 137 | 2008_003821__09 138 | 2008_003846__07 139 | 2008_004069__07 140 | 2008_004339__07 141 | 2008_004552__07 142 | 2008_004612__09 143 | 2008_004701__10 144 | 2008_005097__10 145 | 2008_005105__10 146 | 2008_005245__07 147 | 2008_005676__06 148 | 2008_006008__09 149 | 2008_006063__10 150 | 2008_006254__07 151 | 2008_006325__08 152 | 2008_006341__08 153 | 2008_006480__08 154 | 2008_006528__10 155 | 2008_006554__06 156 | 2008_006986__07 157 | 2008_007025__10 158 | 2008_007031__10 159 | 2008_007048__09 160 | 2008_007123__10 161 | 2008_007194__09 162 | 2008_007273__10 163 | 2008_007378__09 164 | 2008_007402__09 165 | 2008_007527__09 166 | 2008_007548__08 167 | 2008_007596__10 168 | 2008_007737__09 169 | 2008_007797__06 170 | 2008_007804__07 171 | 2008_007828__09 172 | 2008_008252__06 173 | 2008_008301__06 174 | 2008_008469__06 175 | 2008_008682__06 176 | 2009_000013__08 177 | 2009_000080__08 178 | 2009_000219__10 179 | 2009_000309__10 180 | 2009_000335__06 181 | 2009_000335__07 182 | 2009_000426__06 183 | 2009_000455__06 184 | 2009_000457__07 185 | 2009_000523__07 186 | 2009_000641__10 187 | 2009_000716__08 188 | 2009_000731__10 189 | 2009_000771__10 190 | 2009_000825__07 191 | 2009_000964__08 192 | 2009_001008__08 193 | 2009_001082__06 194 | 2009_001240__07 195 | 2009_001255__07 
196 | 2009_001299__09 197 | 2009_001391__08 198 | 2009_001411__08 199 | 2009_001536__07 200 | 2009_001775__09 201 | 2009_001804__06 202 | 2009_001816__06 203 | 2009_001854__06 204 | 2009_002035__10 205 | 2009_002122__10 206 | 2009_002150__10 207 | 2009_002164__07 208 | 2009_002171__10 209 | 2009_002221__10 210 | 2009_002238__06 211 | 2009_002238__07 212 | 2009_002239__07 213 | 2009_002268__08 214 | 2009_002346__09 215 | 2009_002415__09 216 | 2009_002487__09 217 | 2009_002527__08 218 | 2009_002535__06 219 | 2009_002549__10 220 | 2009_002571__09 221 | 2009_002618__07 222 | 2009_002635__10 223 | 2009_002753__08 224 | 2009_002936__08 225 | 2009_002990__07 226 | 2009_003003__07 227 | 2009_003059__10 228 | 2009_003071__09 229 | 2009_003269__07 230 | 2009_003304__06 231 | 2009_003387__07 232 | 2009_003406__07 233 | 2009_003494__09 234 | 2009_003507__09 235 | 2009_003542__10 236 | 2009_003549__07 237 | 2009_003569__10 238 | 2009_003589__07 239 | 2009_003703__06 240 | 2009_003771__08 241 | 2009_003773__10 242 | 2009_003849__09 243 | 2009_003895__09 244 | 2009_003904__08 245 | 2009_004072__06 246 | 2009_004140__09 247 | 2009_004217__09 248 | 2009_004248__08 249 | 2009_004455__07 250 | 2009_004504__08 251 | 2009_004590__06 252 | 2009_004594__07 253 | 2009_004687__09 254 | 2009_004721__08 255 | 2009_004732__06 256 | 2009_004748__07 257 | 2009_004789__06 258 | 2009_004859__09 259 | 2009_004867__06 260 | 2009_005158__08 261 | 2009_005219__08 262 | 2009_005231__06 263 | 2010_000003__09 264 | 2010_000160__07 265 | 2010_000163__08 266 | 2010_000372__07 267 | 2010_000427__10 268 | 2010_000530__07 269 | 2010_000552__08 270 | 2010_000573__06 271 | 2010_000628__07 272 | 2010_000639__09 273 | 2010_000682__06 274 | 2010_000683__08 275 | 2010_000724__08 276 | 2010_000907__10 277 | 2010_000941__08 278 | 2010_000952__07 279 | 2010_001000__10 280 | 2010_001010__10 281 | 2010_001070__08 282 | 2010_001206__06 283 | 2010_001292__08 284 | 2010_001331__08 285 | 2010_001351__08 286 | 
2010_001403__06 287 | 2010_001403__07 288 | 2010_001534__08 289 | 2010_001553__07 290 | 2010_001579__09 291 | 2010_001646__06 292 | 2010_001656__08 293 | 2010_001692__10 294 | 2010_001699__09 295 | 2010_001767__07 296 | 2010_001851__09 297 | 2010_001913__08 298 | 2010_002017__07 299 | 2010_002017__09 300 | 2010_002025__08 301 | 2010_002137__08 302 | 2010_002146__08 303 | 2010_002305__08 304 | 2010_002336__09 305 | 2010_002348__08 306 | 2010_002361__07 307 | 2010_002390__10 308 | 2010_002422__08 309 | 2010_002512__08 310 | 2010_002531__08 311 | 2010_002546__06 312 | 2010_002623__09 313 | 2010_002693__08 314 | 2010_002693__09 315 | 2010_002763__08 316 | 2010_002763__10 317 | 2010_002868__06 318 | 2010_002900__08 319 | 2010_002902__07 320 | 2010_002921__09 321 | 2010_002929__07 322 | 2010_002988__07 323 | 2010_003123__07 324 | 2010_003183__10 325 | 2010_003231__07 326 | 2010_003239__10 327 | 2010_003275__08 328 | 2010_003276__07 329 | 2010_003293__06 330 | 2010_003302__09 331 | 2010_003325__09 332 | 2010_003381__07 333 | 2010_003402__08 334 | 2010_003409__09 335 | 2010_003446__07 336 | 2010_003453__07 337 | 2010_003468__08 338 | 2010_003531__09 339 | 2010_003675__08 340 | 2010_003746__07 341 | 2010_003758__08 342 | 2010_003764__08 343 | 2010_003768__07 344 | 2010_003772__06 345 | 2010_003781__08 346 | 2010_003813__07 347 | 2010_003854__07 348 | 2010_003971__08 349 | 2010_003971__09 350 | 2010_004104__08 351 | 2010_004120__08 352 | 2010_004320__08 353 | 2010_004322__10 354 | 2010_004348__06 355 | 2010_004369__08 356 | 2010_004472__07 357 | 2010_004479__08 358 | 2010_004635__10 359 | 2010_004763__09 360 | 2010_004783__09 361 | 2010_004789__10 362 | 2010_004815__08 363 | 2010_004825__09 364 | 2010_004861__08 365 | 2010_004946__07 366 | 2010_005013__07 367 | 2010_005021__08 368 | 2010_005021__09 369 | 2010_005063__06 370 | 2010_005108__08 371 | 2010_005118__06 372 | 2010_005160__06 373 | 2010_005166__10 374 | 2010_005284__06 375 | 2010_005344__08 376 | 2010_005421__08 377 
| 2010_005432__07 378 | 2010_005501__07 379 | 2010_005508__08 380 | 2010_005606__08 381 | 2010_005709__08 382 | 2010_005718__07 383 | 2010_005860__07 384 | 2010_005899__08 385 | 2010_006070__07 386 | 2011_000178__06 387 | 2011_000226__09 388 | 2011_000239__06 389 | 2011_000248__06 390 | 2011_000312__06 391 | 2011_000338__09 392 | 2011_000419__08 393 | 2011_000503__07 394 | 2011_000548__10 395 | 2011_000566__10 396 | 2011_000607__09 397 | 2011_000661__08 398 | 2011_000661__09 399 | 2011_000780__08 400 | 2011_000789__08 401 | 2011_000809__09 402 | 2011_000813__08 403 | 2011_000813__09 404 | 2011_000830__06 405 | 2011_000843__09 406 | 2011_000888__06 407 | 2011_000900__07 408 | 2011_000969__06 409 | 2011_001047__10 410 | 2011_001064__06 411 | 2011_001071__09 412 | 2011_001110__07 413 | 2011_001159__10 414 | 2011_001232__10 415 | 2011_001292__08 416 | 2011_001341__06 417 | 2011_001346__09 418 | 2011_001447__09 419 | 2011_001530__10 420 | 2011_001534__08 421 | 2011_001546__10 422 | 2011_001567__09 423 | 2011_001597__08 424 | 2011_001601__08 425 | 2011_001607__08 426 | 2011_001665__09 427 | 2011_001708__10 428 | 2011_001775__08 429 | 2011_001782__10 430 | 2011_001812__09 431 | 2011_002041__09 432 | 2011_002064__07 433 | 2011_002124__09 434 | 2011_002200__09 435 | 2011_002298__09 436 | 2011_002322__07 437 | 2011_002343__09 438 | 2011_002358__09 439 | 2011_002391__09 440 | 2011_002509__09 441 | 2011_002592__07 442 | 2011_002644__09 443 | 2011_002685__08 444 | 2011_002812__07 445 | 2011_002885__10 446 | 2011_003011__09 447 | 2011_003019__07 448 | 2011_003019__10 449 | 2011_003055__07 450 | 2011_003103__09 451 | 2011_003114__06 452 | -------------------------------------------------------------------------------- /data_list/val/split2_val.txt: -------------------------------------------------------------------------------- 1 | 2007_000129__15 2 | 2007_000323__15 3 | 2007_000332__13 4 | 2007_000346__15 5 | 2007_000762__11 6 | 2007_000762__15 7 | 2007_000783__13 8 | 
2007_000783__15 9 | 2007_000799__13 10 | 2007_000799__15 11 | 2007_000830__11 12 | 2007_000847__11 13 | 2007_000847__15 14 | 2007_000999__15 15 | 2007_001175__15 16 | 2007_001239__12 17 | 2007_001284__15 18 | 2007_001311__15 19 | 2007_001408__15 20 | 2007_001423__15 21 | 2007_001430__11 22 | 2007_001430__15 23 | 2007_001526__15 24 | 2007_001585__15 25 | 2007_001586__13 26 | 2007_001586__15 27 | 2007_001594__15 28 | 2007_001630__15 29 | 2007_001677__11 30 | 2007_001678__15 31 | 2007_001717__15 32 | 2007_001763__12 33 | 2007_001955__13 34 | 2007_002046__13 35 | 2007_002119__15 36 | 2007_002260__14 37 | 2007_002268__12 38 | 2007_002378__15 39 | 2007_002426__15 40 | 2007_002539__15 41 | 2007_002565__15 42 | 2007_002597__12 43 | 2007_002624__11 44 | 2007_002624__15 45 | 2007_002643__15 46 | 2007_002728__15 47 | 2007_002823__14 48 | 2007_002823__15 49 | 2007_002824__15 50 | 2007_002852__12 51 | 2007_003011__11 52 | 2007_003020__15 53 | 2007_003022__13 54 | 2007_003022__15 55 | 2007_003088__15 56 | 2007_003106__15 57 | 2007_003110__12 58 | 2007_003134__15 59 | 2007_003188__15 60 | 2007_003194__12 61 | 2007_003367__14 62 | 2007_003367__15 63 | 2007_003373__12 64 | 2007_003373__15 65 | 2007_003530__15 66 | 2007_003621__15 67 | 2007_003742__11 68 | 2007_003742__15 69 | 2007_003872__12 70 | 2007_004033__14 71 | 2007_004033__15 72 | 2007_004112__12 73 | 2007_004112__15 74 | 2007_004121__15 75 | 2007_004189__12 76 | 2007_004275__14 77 | 2007_004275__15 78 | 2007_004281__15 79 | 2007_004380__14 80 | 2007_004380__15 81 | 2007_004392__15 82 | 2007_004405__11 83 | 2007_004538__13 84 | 2007_004538__15 85 | 2007_004644__12 86 | 2007_004712__11 87 | 2007_004712__15 88 | 2007_004722__13 89 | 2007_004722__15 90 | 2007_004902__13 91 | 2007_004902__15 92 | 2007_005114__13 93 | 2007_005114__15 94 | 2007_005149__12 95 | 2007_005173__14 96 | 2007_005173__15 97 | 2007_005281__15 98 | 2007_005304__15 99 | 2007_005331__13 100 | 2007_005331__15 101 | 2007_005354__14 102 | 2007_005354__15 103 | 
2007_005509__15 104 | 2007_005547__15 105 | 2007_005608__14 106 | 2007_005608__15 107 | 2007_005696__12 108 | 2007_005759__14 109 | 2007_005803__11 110 | 2007_005844__11 111 | 2007_005845__15 112 | 2007_006028__15 113 | 2007_006076__15 114 | 2007_006086__11 115 | 2007_006117__15 116 | 2007_006171__12 117 | 2007_006171__15 118 | 2007_006241__11 119 | 2007_006364__13 120 | 2007_006364__15 121 | 2007_006373__15 122 | 2007_006444__12 123 | 2007_006444__15 124 | 2007_006560__15 125 | 2007_006647__14 126 | 2007_006647__15 127 | 2007_006698__15 128 | 2007_006802__15 129 | 2007_006841__15 130 | 2007_006864__15 131 | 2007_006866__13 132 | 2007_006866__15 133 | 2007_007007__11 134 | 2007_007007__15 135 | 2007_007109__13 136 | 2007_007109__15 137 | 2007_007195__15 138 | 2007_007203__15 139 | 2007_007211__14 140 | 2007_007235__15 141 | 2007_007417__12 142 | 2007_007493__15 143 | 2007_007498__11 144 | 2007_007498__15 145 | 2007_007651__11 146 | 2007_007651__15 147 | 2007_007688__14 148 | 2007_007748__13 149 | 2007_007748__15 150 | 2007_007795__15 151 | 2007_007810__11 152 | 2007_007810__15 153 | 2007_007815__15 154 | 2007_007836__15 155 | 2007_007849__15 156 | 2007_007996__15 157 | 2007_008110__15 158 | 2007_008204__15 159 | 2007_008222__12 160 | 2007_008256__13 161 | 2007_008256__15 162 | 2007_008260__12 163 | 2007_008374__15 164 | 2007_008415__12 165 | 2007_008430__15 166 | 2007_008596__13 167 | 2007_008596__15 168 | 2007_008708__15 169 | 2007_008802__13 170 | 2007_008897__15 171 | 2007_008944__15 172 | 2007_008964__12 173 | 2007_008964__15 174 | 2007_008980__12 175 | 2007_009068__15 176 | 2007_009084__12 177 | 2007_009084__14 178 | 2007_009251__13 179 | 2007_009251__15 180 | 2007_009258__15 181 | 2007_009320__15 182 | 2007_009331__12 183 | 2007_009331__13 184 | 2007_009331__15 185 | 2007_009413__11 186 | 2007_009413__15 187 | 2007_009521__11 188 | 2007_009562__12 189 | 2007_009592__12 190 | 2007_009654__15 191 | 2007_009655__15 192 | 2007_009684__15 193 | 2007_009687__15 194 
| 2007_009691__14 195 | 2007_009691__15 196 | 2007_009706__11 197 | 2007_009750__15 198 | 2007_009756__14 199 | 2007_009756__15 200 | 2007_009841__13 201 | 2007_009938__14 202 | 2008_000080__12 203 | 2008_000213__15 204 | 2008_000215__15 205 | 2008_000223__15 206 | 2008_000233__15 207 | 2008_000234__15 208 | 2008_000239__12 209 | 2008_000270__12 210 | 2008_000270__15 211 | 2008_000271__15 212 | 2008_000359__15 213 | 2008_000474__15 214 | 2008_000510__15 215 | 2008_000573__11 216 | 2008_000573__15 217 | 2008_000602__13 218 | 2008_000630__15 219 | 2008_000661__12 220 | 2008_000661__15 221 | 2008_000662__15 222 | 2008_000666__15 223 | 2008_000673__15 224 | 2008_000700__15 225 | 2008_000725__15 226 | 2008_000731__15 227 | 2008_000763__11 228 | 2008_000763__15 229 | 2008_000765__13 230 | 2008_000782__14 231 | 2008_000795__15 232 | 2008_000811__14 233 | 2008_000811__15 234 | 2008_000863__12 235 | 2008_000943__12 236 | 2008_000992__15 237 | 2008_001013__15 238 | 2008_001028__15 239 | 2008_001070__12 240 | 2008_001074__15 241 | 2008_001076__15 242 | 2008_001150__14 243 | 2008_001170__15 244 | 2008_001231__15 245 | 2008_001249__15 246 | 2008_001283__15 247 | 2008_001308__15 248 | 2008_001379__12 249 | 2008_001404__15 250 | 2008_001478__12 251 | 2008_001491__15 252 | 2008_001504__15 253 | 2008_001531__15 254 | 2008_001547__15 255 | 2008_001629__15 256 | 2008_001682__13 257 | 2008_001821__15 258 | 2008_001874__15 259 | 2008_001895__12 260 | 2008_001895__15 261 | 2008_001992__13 262 | 2008_001992__15 263 | 2008_002212__15 264 | 2008_002239__12 265 | 2008_002240__14 266 | 2008_002241__15 267 | 2008_002379__11 268 | 2008_002383__14 269 | 2008_002495__15 270 | 2008_002536__12 271 | 2008_002588__15 272 | 2008_002775__11 273 | 2008_002775__15 274 | 2008_002835__13 275 | 2008_002835__15 276 | 2008_002859__12 277 | 2008_002864__11 278 | 2008_002864__15 279 | 2008_002904__12 280 | 2008_002929__15 281 | 2008_002936__12 282 | 2008_002942__15 283 | 2008_002958__12 284 | 2008_003034__15 
285 | 2008_003076__15 286 | 2008_003108__15 287 | 2008_003141__15 288 | 2008_003210__15 289 | 2008_003238__12 290 | 2008_003238__15 291 | 2008_003330__15 292 | 2008_003333__14 293 | 2008_003333__15 294 | 2008_003379__13 295 | 2008_003451__14 296 | 2008_003451__15 297 | 2008_003461__13 298 | 2008_003461__15 299 | 2008_003477__11 300 | 2008_003492__15 301 | 2008_003511__12 302 | 2008_003511__15 303 | 2008_003546__15 304 | 2008_003576__12 305 | 2008_003676__15 306 | 2008_003733__15 307 | 2008_003782__13 308 | 2008_003856__15 309 | 2008_003874__15 310 | 2008_004101__15 311 | 2008_004140__11 312 | 2008_004140__15 313 | 2008_004175__13 314 | 2008_004345__14 315 | 2008_004396__13 316 | 2008_004399__14 317 | 2008_004399__15 318 | 2008_004575__11 319 | 2008_004575__15 320 | 2008_004624__13 321 | 2008_004654__15 322 | 2008_004687__13 323 | 2008_004705__13 324 | 2008_005049__14 325 | 2008_005089__15 326 | 2008_005145__11 327 | 2008_005197__12 328 | 2008_005197__15 329 | 2008_005245__14 330 | 2008_005245__15 331 | 2008_005399__15 332 | 2008_005422__14 333 | 2008_005445__15 334 | 2008_005525__13 335 | 2008_005637__14 336 | 2008_005642__13 337 | 2008_005691__13 338 | 2008_005738__15 339 | 2008_005812__15 340 | 2008_005915__14 341 | 2008_006008__11 342 | 2008_006036__13 343 | 2008_006108__11 344 | 2008_006108__15 345 | 2008_006130__12 346 | 2008_006216__15 347 | 2008_006219__13 348 | 2008_006254__15 349 | 2008_006275__15 350 | 2008_006341__15 351 | 2008_006408__11 352 | 2008_006408__15 353 | 2008_006526__14 354 | 2008_006526__15 355 | 2008_006554__15 356 | 2008_006722__12 357 | 2008_006722__15 358 | 2008_006874__14 359 | 2008_006874__15 360 | 2008_006981__12 361 | 2008_007048__11 362 | 2008_007219__15 363 | 2008_007378__11 364 | 2008_007378__12 365 | 2008_007392__13 366 | 2008_007392__15 367 | 2008_007402__11 368 | 2008_007402__15 369 | 2008_007513__12 370 | 2008_007737__15 371 | 2008_007828__15 372 | 2008_007945__13 373 | 2008_007994__15 374 | 2008_008051__11 375 | 
2008_008127__14 376 | 2008_008127__15 377 | 2008_008221__15 378 | 2008_008335__11 379 | 2008_008335__15 380 | 2008_008362__11 381 | 2008_008362__15 382 | 2008_008392__13 383 | 2008_008393__13 384 | 2008_008421__13 385 | 2008_008469__15 386 | 2009_000012__13 387 | 2009_000074__14 388 | 2009_000074__15 389 | 2009_000156__12 390 | 2009_000219__15 391 | 2009_000309__15 392 | 2009_000412__13 393 | 2009_000418__15 394 | 2009_000421__15 395 | 2009_000457__15 396 | 2009_000704__15 397 | 2009_000705__13 398 | 2009_000727__13 399 | 2009_000730__14 400 | 2009_000730__15 401 | 2009_000825__14 402 | 2009_000825__15 403 | 2009_000839__12 404 | 2009_000892__12 405 | 2009_000931__13 406 | 2009_000935__12 407 | 2009_001215__11 408 | 2009_001215__15 409 | 2009_001299__15 410 | 2009_001433__13 411 | 2009_001433__15 412 | 2009_001535__12 413 | 2009_001663__15 414 | 2009_001687__12 415 | 2009_001687__15 416 | 2009_001718__15 417 | 2009_001768__15 418 | 2009_001854__15 419 | 2009_002012__12 420 | 2009_002042__15 421 | 2009_002097__13 422 | 2009_002155__12 423 | 2009_002165__13 424 | 2009_002185__15 425 | 2009_002239__14 426 | 2009_002239__15 427 | 2009_002317__14 428 | 2009_002317__15 429 | 2009_002346__12 430 | 2009_002346__15 431 | 2009_002372__15 432 | 2009_002382__14 433 | 2009_002382__15 434 | 2009_002415__11 435 | 2009_002445__12 436 | 2009_002487__11 437 | 2009_002539__12 438 | 2009_002571__11 439 | 2009_002584__15 440 | 2009_002649__15 441 | 2009_002651__14 442 | 2009_002651__15 443 | 2009_002732__15 444 | 2009_002975__13 445 | 2009_003003__11 446 | 2009_003003__15 447 | 2009_003063__12 448 | 2009_003065__15 449 | 2009_003071__11 450 | 2009_003071__15 451 | 2009_003123__11 452 | 2009_003196__14 453 | 2009_003217__12 454 | 2009_003241__12 455 | 2009_003269__15 456 | 2009_003323__13 457 | 2009_003323__15 458 | 2009_003466__12 459 | 2009_003481__13 460 | 2009_003494__15 461 | 2009_003507__11 462 | 2009_003576__14 463 | 2009_003576__15 464 | 2009_003756__12 465 | 2009_003804__13 466 
| 2009_003810__12 467 | 2009_003849__11 468 | 2009_003849__15 469 | 2009_003903__13 470 | 2009_003928__12 471 | 2009_003991__11 472 | 2009_003991__15 473 | 2009_004033__12 474 | 2009_004043__14 475 | 2009_004043__15 476 | 2009_004140__11 477 | 2009_004221__15 478 | 2009_004455__14 479 | 2009_004497__13 480 | 2009_004507__12 481 | 2009_004507__15 482 | 2009_004581__12 483 | 2009_004592__12 484 | 2009_004738__14 485 | 2009_004738__15 486 | 2009_004848__15 487 | 2009_004859__11 488 | 2009_004859__15 489 | 2009_004942__13 490 | 2009_004987__14 491 | 2009_004987__15 492 | 2009_004994__12 493 | 2009_004994__15 494 | 2009_005038__11 495 | 2009_005038__15 496 | 2009_005078__14 497 | 2009_005087__15 498 | 2009_005217__13 499 | 2009_005217__15 500 | 2010_000003__12 501 | 2010_000038__13 502 | 2010_000038__15 503 | 2010_000087__14 504 | 2010_000087__15 505 | 2010_000110__12 506 | 2010_000110__15 507 | 2010_000159__12 508 | 2010_000174__11 509 | 2010_000174__15 510 | 2010_000216__12 511 | 2010_000238__15 512 | 2010_000256__15 513 | 2010_000422__12 514 | 2010_000530__15 515 | 2010_000559__15 516 | 2010_000639__12 517 | 2010_000666__13 518 | 2010_000666__15 519 | 2010_000738__15 520 | 2010_000788__12 521 | 2010_000874__13 522 | 2010_000904__12 523 | 2010_001024__15 524 | 2010_001124__12 525 | 2010_001251__14 526 | 2010_001264__12 527 | 2010_001313__14 528 | 2010_001313__15 529 | 2010_001367__15 530 | 2010_001376__12 531 | 2010_001451__13 532 | 2010_001553__14 533 | 2010_001563__12 534 | 2010_001563__15 535 | 2010_001579__11 536 | 2010_001579__15 537 | 2010_001692__15 538 | 2010_001699__15 539 | 2010_001734__15 540 | 2010_001767__15 541 | 2010_001851__11 542 | 2010_001908__12 543 | 2010_001956__12 544 | 2010_002017__15 545 | 2010_002137__15 546 | 2010_002161__13 547 | 2010_002161__15 548 | 2010_002228__12 549 | 2010_002251__14 550 | 2010_002251__15 551 | 2010_002271__14 552 | 2010_002336__11 553 | 2010_002396__14 554 | 2010_002396__15 555 | 2010_002480__12 556 | 2010_002623__15 
557 | 2010_002691__13 558 | 2010_002763__15 559 | 2010_002792__15 560 | 2010_002902__15 561 | 2010_002929__15 562 | 2010_003014__15 563 | 2010_003060__12 564 | 2010_003187__12 565 | 2010_003207__14 566 | 2010_003239__15 567 | 2010_003325__11 568 | 2010_003325__15 569 | 2010_003381__15 570 | 2010_003409__15 571 | 2010_003446__15 572 | 2010_003506__12 573 | 2010_003531__11 574 | 2010_003532__13 575 | 2010_003597__11 576 | 2010_003597__15 577 | 2010_003746__12 578 | 2010_003746__15 579 | 2010_003947__14 580 | 2010_003971__11 581 | 2010_004042__14 582 | 2010_004165__12 583 | 2010_004165__15 584 | 2010_004219__14 585 | 2010_004219__15 586 | 2010_004337__15 587 | 2010_004355__14 588 | 2010_004432__15 589 | 2010_004472__15 590 | 2010_004479__15 591 | 2010_004519__13 592 | 2010_004550__12 593 | 2010_004559__15 594 | 2010_004628__12 595 | 2010_004697__14 596 | 2010_004697__15 597 | 2010_004795__12 598 | 2010_004815__15 599 | 2010_004825__11 600 | 2010_004828__15 601 | 2010_004856__13 602 | 2010_004941__14 603 | 2010_004951__15 604 | 2010_005046__11 605 | 2010_005046__15 606 | 2010_005118__15 607 | 2010_005159__12 608 | 2010_005160__14 609 | 2010_005166__15 610 | 2010_005174__13 611 | 2010_005206__12 612 | 2010_005245__12 613 | 2010_005245__15 614 | 2010_005252__14 615 | 2010_005252__15 616 | 2010_005284__15 617 | 2010_005366__14 618 | 2010_005433__14 619 | 2010_005501__14 620 | 2010_005575__12 621 | 2010_005582__15 622 | 2010_005606__15 623 | 2010_005626__11 624 | 2010_005626__15 625 | 2010_005644__12 626 | 2010_005709__15 627 | 2010_005871__15 628 | 2010_005991__12 629 | 2010_005991__15 630 | 2010_005992__12 631 | 2011_000045__12 632 | 2011_000051__15 633 | 2011_000054__15 634 | 2011_000178__15 635 | 2011_000226__11 636 | 2011_000248__15 637 | 2011_000338__11 638 | 2011_000396__13 639 | 2011_000435__15 640 | 2011_000438__15 641 | 2011_000455__14 642 | 2011_000455__15 643 | 2011_000479__15 644 | 2011_000512__14 645 | 2011_000526__13 646 | 2011_000536__12 647 | 
2011_000566__15 648 | 2011_000585__15 649 | 2011_000598__11 650 | 2011_000618__14 651 | 2011_000618__15 652 | 2011_000638__15 653 | 2011_000780__15 654 | 2011_000809__11 655 | 2011_000809__15 656 | 2011_000843__15 657 | 2011_000953__11 658 | 2011_000953__15 659 | 2011_001014__12 660 | 2011_001060__15 661 | 2011_001069__15 662 | 2011_001071__15 663 | 2011_001159__15 664 | 2011_001276__11 665 | 2011_001276__12 666 | 2011_001276__15 667 | 2011_001346__15 668 | 2011_001416__15 669 | 2011_001447__15 670 | 2011_001530__15 671 | 2011_001567__15 672 | 2011_001619__15 673 | 2011_001642__12 674 | 2011_001665__11 675 | 2011_001674__15 676 | 2011_001714__12 677 | 2011_001714__15 678 | 2011_001722__13 679 | 2011_001745__12 680 | 2011_001794__15 681 | 2011_001862__11 682 | 2011_001862__12 683 | 2011_001868__12 684 | 2011_001984__12 685 | 2011_001988__15 686 | 2011_002002__15 687 | 2011_002040__12 688 | 2011_002075__11 689 | 2011_002075__15 690 | 2011_002098__12 691 | 2011_002110__12 692 | 2011_002110__15 693 | 2011_002121__12 694 | 2011_002124__15 695 | 2011_002156__12 696 | 2011_002200__11 697 | 2011_002200__15 698 | 2011_002247__15 699 | 2011_002279__12 700 | 2011_002298__12 701 | 2011_002308__15 702 | 2011_002317__15 703 | 2011_002322__14 704 | 2011_002322__15 705 | 2011_002343__15 706 | 2011_002358__11 707 | 2011_002358__15 708 | 2011_002371__12 709 | 2011_002498__15 710 | 2011_002509__15 711 | 2011_002532__15 712 | 2011_002575__15 713 | 2011_002578__15 714 | 2011_002589__12 715 | 2011_002623__15 716 | 2011_002641__15 717 | 2011_002675__15 718 | 2011_002951__13 719 | 2011_002997__15 720 | 2011_003019__14 721 | 2011_003019__15 722 | 2011_003085__13 723 | 2011_003114__15 724 | 2011_003240__15 725 | 2011_003256__12 726 | -------------------------------------------------------------------------------- /data_list/val/split3_val.txt: -------------------------------------------------------------------------------- 1 | 2007_000042__19 2 | 2007_000123__19 3 | 2007_000175__17 4 | 
2007_000187__20 5 | 2007_000452__18 6 | 2007_000559__20 7 | 2007_000629__19 8 | 2007_000636__19 9 | 2007_000661__18 10 | 2007_000676__17 11 | 2007_000804__18 12 | 2007_000925__17 13 | 2007_001154__18 14 | 2007_001175__20 15 | 2007_001408__16 16 | 2007_001430__16 17 | 2007_001430__20 18 | 2007_001457__18 19 | 2007_001458__18 20 | 2007_001585__18 21 | 2007_001594__17 22 | 2007_001678__20 23 | 2007_001717__20 24 | 2007_001733__17 25 | 2007_001763__18 26 | 2007_001763__20 27 | 2007_002119__20 28 | 2007_002132__20 29 | 2007_002268__18 30 | 2007_002284__16 31 | 2007_002378__16 32 | 2007_002426__18 33 | 2007_002427__18 34 | 2007_002565__19 35 | 2007_002618__17 36 | 2007_002648__17 37 | 2007_002728__19 38 | 2007_003011__18 39 | 2007_003011__20 40 | 2007_003169__18 41 | 2007_003367__16 42 | 2007_003499__19 43 | 2007_003506__16 44 | 2007_003530__18 45 | 2007_003587__19 46 | 2007_003714__17 47 | 2007_003848__19 48 | 2007_003957__19 49 | 2007_004190__20 50 | 2007_004193__20 51 | 2007_004275__16 52 | 2007_004281__19 53 | 2007_004483__19 54 | 2007_004510__20 55 | 2007_004558__16 56 | 2007_004649__19 57 | 2007_004712__16 58 | 2007_004969__17 59 | 2007_005469__17 60 | 2007_005626__19 61 | 2007_005689__19 62 | 2007_005813__16 63 | 2007_005857__16 64 | 2007_005915__17 65 | 2007_006171__18 66 | 2007_006348__20 67 | 2007_006373__18 68 | 2007_006678__17 69 | 2007_006680__19 70 | 2007_006802__19 71 | 2007_007130__20 72 | 2007_007165__17 73 | 2007_007168__19 74 | 2007_007195__19 75 | 2007_007196__20 76 | 2007_007203__20 77 | 2007_007417__18 78 | 2007_007534__17 79 | 2007_007624__16 80 | 2007_007795__16 81 | 2007_007881__19 82 | 2007_007996__18 83 | 2007_008204__20 84 | 2007_008260__18 85 | 2007_008339__19 86 | 2007_008374__20 87 | 2007_008543__18 88 | 2007_008547__16 89 | 2007_009068__18 90 | 2007_009252__18 91 | 2007_009320__17 92 | 2007_009419__16 93 | 2007_009446__20 94 | 2007_009521__18 95 | 2007_009521__20 96 | 2007_009592__18 97 | 2007_009655__18 98 | 2007_009684__18 99 | 
2007_009750__16 100 | 2008_000016__20 101 | 2008_000149__18 102 | 2008_000270__18 103 | 2008_000391__16 104 | 2008_000589__18 105 | 2008_000657__19 106 | 2008_001078__16 107 | 2008_001283__16 108 | 2008_001688__16 109 | 2008_001688__20 110 | 2008_001966__16 111 | 2008_002273__16 112 | 2008_002379__16 113 | 2008_002464__20 114 | 2008_002536__17 115 | 2008_002680__20 116 | 2008_002900__19 117 | 2008_002929__18 118 | 2008_003003__20 119 | 2008_003026__20 120 | 2008_003105__19 121 | 2008_003135__16 122 | 2008_003676__16 123 | 2008_003709__18 124 | 2008_003733__18 125 | 2008_003885__20 126 | 2008_004172__18 127 | 2008_004212__19 128 | 2008_004279__20 129 | 2008_004367__19 130 | 2008_004453__17 131 | 2008_004477__16 132 | 2008_004562__18 133 | 2008_004610__19 134 | 2008_004621__17 135 | 2008_004754__20 136 | 2008_004854__17 137 | 2008_004910__20 138 | 2008_005089__20 139 | 2008_005217__16 140 | 2008_005242__16 141 | 2008_005254__20 142 | 2008_005439__20 143 | 2008_005445__20 144 | 2008_005544__19 145 | 2008_005633__17 146 | 2008_005680__16 147 | 2008_006055__19 148 | 2008_006159__20 149 | 2008_006327__17 150 | 2008_006523__19 151 | 2008_006553__19 152 | 2008_006752__19 153 | 2008_006784__18 154 | 2008_006835__17 155 | 2008_007497__17 156 | 2008_007527__20 157 | 2008_007677__17 158 | 2008_007814__17 159 | 2008_007828__20 160 | 2008_008103__18 161 | 2008_008221__19 162 | 2008_008434__16 163 | 2009_000022__19 164 | 2009_000039__17 165 | 2009_000087__18 166 | 2009_000096__18 167 | 2009_000136__20 168 | 2009_000242__18 169 | 2009_000391__20 170 | 2009_000418__16 171 | 2009_000418__18 172 | 2009_000487__18 173 | 2009_000488__16 174 | 2009_000488__20 175 | 2009_000628__19 176 | 2009_000675__17 177 | 2009_000704__20 178 | 2009_000712__19 179 | 2009_000732__18 180 | 2009_000845__19 181 | 2009_000924__17 182 | 2009_001300__19 183 | 2009_001333__19 184 | 2009_001363__20 185 | 2009_001505__17 186 | 2009_001644__16 187 | 2009_001644__18 188 | 2009_001644__20 189 | 2009_001684__16 190 
| 2009_001731__18 191 | 2009_001768__17 192 | 2009_001775__16 193 | 2009_001775__18 194 | 2009_001991__17 195 | 2009_002082__17 196 | 2009_002094__20 197 | 2009_002202__19 198 | 2009_002265__19 199 | 2009_002291__19 200 | 2009_002346__18 201 | 2009_002366__20 202 | 2009_002390__18 203 | 2009_002487__16 204 | 2009_002562__20 205 | 2009_002568__19 206 | 2009_002571__16 207 | 2009_002571__18 208 | 2009_002573__20 209 | 2009_002584__16 210 | 2009_002638__19 211 | 2009_002732__18 212 | 2009_002887__19 213 | 2009_002982__19 214 | 2009_003105__19 215 | 2009_003123__18 216 | 2009_003299__19 217 | 2009_003311__19 218 | 2009_003433__19 219 | 2009_003523__20 220 | 2009_003551__20 221 | 2009_003564__16 222 | 2009_003564__18 223 | 2009_003607__18 224 | 2009_003666__17 225 | 2009_003857__20 226 | 2009_003895__18 227 | 2009_003895__20 228 | 2009_003938__19 229 | 2009_004099__18 230 | 2009_004140__18 231 | 2009_004255__19 232 | 2009_004298__18 233 | 2009_004687__18 234 | 2009_004730__19 235 | 2009_004799__19 236 | 2009_004993__18 237 | 2009_004993__20 238 | 2009_005148__19 239 | 2009_005220__19 240 | 2010_000256__18 241 | 2010_000284__18 242 | 2010_000309__17 243 | 2010_000318__20 244 | 2010_000330__16 245 | 2010_000639__16 246 | 2010_000738__20 247 | 2010_000764__19 248 | 2010_001011__17 249 | 2010_001079__17 250 | 2010_001104__19 251 | 2010_001149__18 252 | 2010_001151__19 253 | 2010_001246__16 254 | 2010_001256__17 255 | 2010_001327__18 256 | 2010_001367__20 257 | 2010_001522__17 258 | 2010_001557__17 259 | 2010_001577__17 260 | 2010_001699__16 261 | 2010_001734__19 262 | 2010_001752__20 263 | 2010_001767__18 264 | 2010_001773__16 265 | 2010_001851__16 266 | 2010_001951__19 267 | 2010_001962__18 268 | 2010_002106__17 269 | 2010_002137__16 270 | 2010_002137__18 271 | 2010_002232__17 272 | 2010_002531__18 273 | 2010_002682__19 274 | 2010_002921__20 275 | 2010_003014__18 276 | 2010_003123__16 277 | 2010_003302__16 278 | 2010_003514__19 279 | 2010_003541__17 280 | 2010_003597__18 
281 | 2010_003781__16 282 | 2010_003956__19 283 | 2010_004149__19 284 | 2010_004226__17 285 | 2010_004382__16 286 | 2010_004479__20 287 | 2010_004757__16 288 | 2010_004757__18 289 | 2010_004783__18 290 | 2010_004825__16 291 | 2010_004857__20 292 | 2010_004951__19 293 | 2010_004980__19 294 | 2010_005180__18 295 | 2010_005187__16 296 | 2010_005305__20 297 | 2010_005606__18 298 | 2010_005706__19 299 | 2010_005719__17 300 | 2010_005727__19 301 | 2010_005788__17 302 | 2010_005860__16 303 | 2010_005871__19 304 | 2010_005991__18 305 | 2010_006054__19 306 | 2011_000070__18 307 | 2011_000173__18 308 | 2011_000283__19 309 | 2011_000291__19 310 | 2011_000310__18 311 | 2011_000436__17 312 | 2011_000521__19 313 | 2011_000747__16 314 | 2011_001005__18 315 | 2011_001060__19 316 | 2011_001281__19 317 | 2011_001350__17 318 | 2011_001567__18 319 | 2011_001601__18 320 | 2011_001614__19 321 | 2011_001674__18 322 | 2011_001713__16 323 | 2011_001713__18 324 | 2011_001726__20 325 | 2011_001794__18 326 | 2011_001862__18 327 | 2011_001863__16 328 | 2011_001910__20 329 | 2011_002124__18 330 | 2011_002156__20 331 | 2011_002178__17 332 | 2011_002247__19 333 | 2011_002379__19 334 | 2011_002391__18 335 | 2011_002532__20 336 | 2011_002535__19 337 | 2011_002644__18 338 | 2011_002644__20 339 | 2011_002879__18 340 | 2011_002879__20 341 | 2011_003103__16 342 | 2011_003103__18 343 | 2011_003146__19 344 | 2011_003182__18 345 | 2011_003197__19 346 | 2011_003256__18 347 | -------------------------------------------------------------------------------- /data_list/val_list/split0_val.txt: -------------------------------------------------------------------------------- 1 | 2007_000033__01 2 | 2007_000061__04 3 | 2007_000129__02 4 | 2007_000346__05 5 | 2007_000529__04 6 | 2007_000559__05 7 | 2007_000572__02 8 | 2007_000762__05 9 | 2007_000837__04 10 | 2007_000999__05 11 | 2007_001288__01 12 | 2007_001289__03 13 | 2007_001311__02 14 | 2007_001408__05 15 | 2007_001568__01 16 | 2007_001587__02 17 | 
2007_001630__02 18 | 2007_001761__01 19 | 2007_001884__01 20 | 2007_002094__03 21 | 2007_002266__01 22 | 2007_002376__01 23 | 2007_002400__03 24 | 2007_002412__03 25 | 2007_002619__01 26 | 2007_002719__04 27 | 2007_003088__05 28 | 2007_003131__04 29 | 2007_003188__02 30 | 2007_003349__03 31 | 2007_003571__04 32 | 2007_003621__02 33 | 2007_003682__03 34 | 2007_003861__04 35 | 2007_004052__01 36 | 2007_004143__03 37 | 2007_004241__04 38 | 2007_004468__05 39 | 2007_004712__05 40 | 2007_005074__04 41 | 2007_005107__02 42 | 2007_005281__02 43 | 2007_005294__05 44 | 2007_005304__05 45 | 2007_005428__05 46 | 2007_005509__01 47 | 2007_005600__01 48 | 2007_005705__04 49 | 2007_005828__01 50 | 2007_006028__02 51 | 2007_006046__04 52 | 2007_006076__03 53 | 2007_006086__05 54 | 2007_006241__05 55 | 2007_006449__02 56 | 2007_006549__04 57 | 2007_006946__01 58 | 2007_007084__03 59 | 2007_007235__02 60 | 2007_007341__01 61 | 2007_007470__01 62 | 2007_007477__04 63 | 2007_007493__05 64 | 2007_007836__02 65 | 2007_008051__03 66 | 2007_008084__03 67 | 2007_008204__05 68 | 2007_008645__04 69 | 2007_008670__03 70 | 2007_009088__03 71 | 2007_009258__02 72 | 2007_009323__03 73 | 2007_009458__05 74 | 2007_009654__05 75 | 2007_009687__05 76 | 2007_009817__03 77 | 2007_009911__01 78 | 2008_000120__04 79 | 2008_000123__03 80 | 2008_000533__03 81 | 2008_000725__02 82 | 2008_000795__05 83 | 2008_000848__03 84 | 2008_000911__05 85 | 2008_001013__04 86 | 2008_001040__04 87 | 2008_001135__04 88 | 2008_001231__02 89 | 2008_001260__04 90 | 2008_001404__02 91 | 2008_001514__03 92 | 2008_001531__02 93 | 2008_001546__01 94 | 2008_001580__04 95 | 2008_001966__03 96 | 2008_001971__01 97 | 2008_002043__03 98 | 2008_002269__02 99 | 2008_002358__01 100 | 2008_002429__03 101 | 2008_002467__05 102 | 2008_002492__04 103 | 2008_002504__04 104 | 2008_002775__05 105 | 2008_002864__05 106 | 2008_003076__05 107 | 2008_003108__02 108 | 2008_003110__03 109 | 2008_003155__01 110 | 2008_003270__02 111 | 
2008_003369__01 112 | 2008_003858__04 113 | 2008_003876__01 114 | 2008_003886__04 115 | 2008_003926__01 116 | 2008_003976__01 117 | 2008_004101__02 118 | 2008_004363__02 119 | 2008_004654__02 120 | 2008_004659__05 121 | 2008_004704__01 122 | 2008_004758__02 123 | 2008_004995__02 124 | 2008_005089__05 125 | 2008_005262__05 126 | 2008_005338__01 127 | 2008_005398__04 128 | 2008_005399__05 129 | 2008_005628__04 130 | 2008_005727__02 131 | 2008_005738__02 132 | 2008_005812__05 133 | 2008_005904__05 134 | 2008_006143__03 135 | 2008_006229__04 136 | 2008_006254__02 137 | 2008_006703__01 138 | 2008_007120__03 139 | 2008_007143__04 140 | 2008_007219__05 141 | 2008_007350__01 142 | 2008_007498__03 143 | 2008_007507__01 144 | 2008_007737__05 145 | 2008_007811__05 146 | 2008_007836__01 147 | 2008_007994__03 148 | 2008_008268__03 149 | 2008_008362__05 150 | 2008_008629__02 151 | 2008_008711__02 152 | 2008_008746__03 153 | 2009_000032__01 154 | 2009_000037__03 155 | 2009_000121__05 156 | 2009_000149__02 157 | 2009_000201__05 158 | 2009_000205__01 159 | 2009_000318__03 160 | 2009_000351__02 161 | 2009_000354__02 162 | 2009_000387__01 163 | 2009_000421__04 164 | 2009_000440__01 165 | 2009_000446__04 166 | 2009_000457__02 167 | 2009_000469__04 168 | 2009_000573__02 169 | 2009_000573__05 170 | 2009_000619__03 171 | 2009_000664__03 172 | 2009_000723__04 173 | 2009_000828__04 174 | 2009_000840__05 175 | 2009_000879__03 176 | 2009_000919__04 177 | 2009_000991__03 178 | 2009_000998__03 179 | 2009_001108__03 180 | 2009_001160__03 181 | 2009_001255__02 182 | 2009_001278__05 183 | 2009_001314__03 184 | 2009_001332__01 185 | 2009_001565__03 186 | 2009_001607__03 187 | 2009_001663__02 188 | 2009_001683__03 189 | 2009_001718__02 190 | 2009_001765__03 191 | 2009_001818__05 192 | 2009_001850__01 193 | 2009_001851__01 194 | 2009_001941__04 195 | 2009_002185__05 196 | 2009_002295__02 197 | 2009_002320__01 198 | 2009_002372__05 199 | 2009_002521__05 200 | 2009_002591__04 201 | 2009_002594__05 202 
| 2009_002604__03 203 | 2009_002649__05 204 | 2009_002727__04 205 | 2009_002732__05 206 | 2009_002749__05 207 | 2009_002808__01 208 | 2009_002856__05 209 | 2009_002888__01 210 | 2009_002928__02 211 | 2009_003003__05 212 | 2009_003005__01 213 | 2009_003043__04 214 | 2009_003065__02 215 | 2009_003080__04 216 | 2009_003193__02 217 | 2009_003224__02 218 | 2009_003269__05 219 | 2009_003273__03 220 | 2009_003343__02 221 | 2009_003378__03 222 | 2009_003450__03 223 | 2009_003498__03 224 | 2009_003504__04 225 | 2009_003517__05 226 | 2009_003640__03 227 | 2009_003696__01 228 | 2009_003707__04 229 | 2009_003806__01 230 | 2009_003858__03 231 | 2009_003971__02 232 | 2009_004021__03 233 | 2009_004084__03 234 | 2009_004125__04 235 | 2009_004247__05 236 | 2009_004324__05 237 | 2009_004509__03 238 | 2009_004540__03 239 | 2009_004568__03 240 | 2009_004579__05 241 | 2009_004635__04 242 | 2009_004653__01 243 | 2009_004848__02 244 | 2009_004882__02 245 | 2009_004886__03 246 | 2009_004895__03 247 | 2009_004969__01 248 | 2009_005038__05 249 | 2009_005087__02 250 | 2009_005137__03 251 | 2009_005156__02 252 | 2009_005189__01 253 | 2009_005190__05 254 | 2009_005260__03 255 | 2009_005262__03 256 | 2009_005302__05 257 | 2010_000065__02 258 | 2010_000083__02 259 | 2010_000084__04 260 | 2010_000238__01 261 | 2010_000241__03 262 | 2010_000256__05 263 | 2010_000272__04 264 | 2010_000335__03 265 | 2010_000342__02 266 | 2010_000426__05 267 | 2010_000559__04 268 | 2010_000572__01 269 | 2010_000573__05 270 | 2010_000622__01 271 | 2010_000679__04 272 | 2010_000814__03 273 | 2010_000836__03 274 | 2010_000906__04 275 | 2010_000918__03 276 | 2010_000961__03 277 | 2010_001016__03 278 | 2010_001017__01 279 | 2010_001024__01 280 | 2010_001036__04 281 | 2010_001061__03 282 | 2010_001069__03 283 | 2010_001174__01 284 | 2010_001246__02 285 | 2010_001367__02 286 | 2010_001367__05 287 | 2010_001448__01 288 | 2010_001820__03 289 | 2010_001830__05 290 | 2010_001995__03 291 | 2010_002017__05 292 | 2010_002030__02 
293 | 2010_002142__03 294 | 2010_002147__01 295 | 2010_002150__04 296 | 2010_002200__01 297 | 2010_002310__01 298 | 2010_002450__03 299 | 2010_002536__02 300 | 2010_002546__04 301 | 2010_002693__02 302 | 2010_002792__02 303 | 2010_002939__01 304 | 2010_003127__01 305 | 2010_003132__01 306 | 2010_003168__03 307 | 2010_003293__02 308 | 2010_003362__03 309 | 2010_003365__01 310 | 2010_003418__03 311 | 2010_003446__02 312 | 2010_003468__05 313 | 2010_003473__03 314 | 2010_003495__01 315 | 2010_003547__04 316 | 2010_003597__05 317 | 2010_003716__01 318 | 2010_003771__03 319 | 2010_003781__05 320 | 2010_003820__03 321 | 2010_003912__02 322 | 2010_003912__05 323 | 2010_003915__01 324 | 2010_004041__04 325 | 2010_004056__05 326 | 2010_004063__01 327 | 2010_004208__04 328 | 2010_004314__01 329 | 2010_004419__01 330 | 2010_004520__05 331 | 2010_004529__05 332 | 2010_004551__05 333 | 2010_004556__03 334 | 2010_004559__03 335 | 2010_004662__04 336 | 2010_004772__04 337 | 2010_004828__05 338 | 2010_004994__03 339 | 2010_005252__04 340 | 2010_005353__03 341 | 2010_005401__04 342 | 2010_005428__03 343 | 2010_005496__05 344 | 2010_005531__03 345 | 2010_005534__01 346 | 2010_005582__05 347 | 2010_005664__02 348 | 2010_005705__04 349 | 2010_005718__01 350 | 2010_005762__05 351 | 2010_005860__02 352 | 2010_005877__01 353 | 2010_005888__01 354 | 2010_006034__01 355 | 2010_006070__02 356 | 2011_000066__05 357 | 2011_000112__03 358 | 2011_000185__03 359 | 2011_000234__04 360 | 2011_000238__04 361 | 2011_000412__02 362 | 2011_000435__04 363 | 2011_000438__05 364 | 2011_000456__03 365 | 2011_000481__01 366 | 2011_000482__03 367 | 2011_000548__03 368 | 2011_000585__02 369 | 2011_000669__03 370 | 2011_000747__05 371 | 2011_000807__01 372 | 2011_000843__05 373 | 2011_000874__01 374 | 2011_000912__03 375 | 2011_000953__05 376 | 2011_001114__01 377 | 2011_001161__04 378 | 2011_001263__01 379 | 2011_001287__03 380 | 2011_001407__01 381 | 2011_001421__03 382 | 2011_001434__01 383 | 
2011_001529__02 384 | 2011_001589__04 385 | 2011_001613__04 386 | 2011_001624__01 387 | 2011_001793__04 388 | 2011_001880__01 389 | 2011_001988__02 390 | 2011_002002__02 391 | 2011_002064__02 392 | 2011_002098__05 393 | 2011_002150__02 394 | 2011_002223__02 395 | 2011_002295__03 396 | 2011_002327__01 397 | 2011_002358__05 398 | 2011_002498__05 399 | 2011_002515__01 400 | 2011_002548__03 401 | 2011_002578__03 402 | 2011_002578__04 403 | 2011_002675__01 404 | 2011_002713__02 405 | 2011_002754__04 406 | 2011_002863__05 407 | 2011_002929__01 408 | 2011_002975__04 409 | 2011_003003__02 410 | 2011_003030__03 411 | 2011_003145__03 412 | 2011_003240__05 413 | 2011_003271__05 414 | -------------------------------------------------------------------------------- /data_list/val_list/split1_val.txt: -------------------------------------------------------------------------------- 1 | 2007_000452__09 2 | 2007_000464__10 3 | 2007_000491__10 4 | 2007_000663__06 5 | 2007_000663__07 6 | 2007_000727__06 7 | 2007_000727__07 8 | 2007_000804__09 9 | 2007_000830__09 10 | 2007_001299__10 11 | 2007_001321__07 12 | 2007_001457__09 13 | 2007_001677__09 14 | 2007_001717__09 15 | 2007_001763__08 16 | 2007_001774__08 17 | 2007_001884__06 18 | 2007_002268__08 19 | 2007_002387__10 20 | 2007_002445__08 21 | 2007_002470__08 22 | 2007_002539__06 23 | 2007_002597__08 24 | 2007_002643__07 25 | 2007_002903__10 26 | 2007_003011__09 27 | 2007_003051__07 28 | 2007_003101__06 29 | 2007_003106__08 30 | 2007_003137__06 31 | 2007_003143__07 32 | 2007_003169__08 33 | 2007_003195__06 34 | 2007_003195__07 35 | 2007_003201__10 36 | 2007_003503__06 37 | 2007_003503__07 38 | 2007_003621__06 39 | 2007_003711__06 40 | 2007_003711__07 41 | 2007_003786__06 42 | 2007_003841__10 43 | 2007_003917__07 44 | 2007_003991__08 45 | 2007_004193__09 46 | 2007_004392__09 47 | 2007_004405__09 48 | 2007_004510__09 49 | 2007_004712__09 50 | 2007_004856__08 51 | 2007_004866__08 52 | 2007_005074__07 53 | 2007_005114__10 54 | 
2007_005296__07 55 | 2007_005331__07 56 | 2007_005460__08 57 | 2007_005509__07 58 | 2007_005547__07 59 | 2007_005547__10 60 | 2007_005844__09 61 | 2007_005845__08 62 | 2007_005911__06 63 | 2007_005978__06 64 | 2007_006035__07 65 | 2007_006086__09 66 | 2007_006241__09 67 | 2007_006260__08 68 | 2007_006277__07 69 | 2007_006348__09 70 | 2007_006553__09 71 | 2007_006761__10 72 | 2007_006841__10 73 | 2007_007203__09 74 | 2007_007414__07 75 | 2007_007417__08 76 | 2007_007493__09 77 | 2007_007524__08 78 | 2007_007795__09 79 | 2007_007815__07 80 | 2007_007818__06 81 | 2007_007818__07 82 | 2007_007996__09 83 | 2007_008106__09 84 | 2007_008110__08 85 | 2007_008110__09 86 | 2007_008543__09 87 | 2007_008722__10 88 | 2007_008747__06 89 | 2007_008747__07 90 | 2007_008815__08 91 | 2007_008897__09 92 | 2007_008973__10 93 | 2007_009015__06 94 | 2007_009015__07 95 | 2007_009068__09 96 | 2007_009084__09 97 | 2007_009096__07 98 | 2007_009221__08 99 | 2007_009245__10 100 | 2007_009346__08 101 | 2007_009392__06 102 | 2007_009392__07 103 | 2007_009413__09 104 | 2007_009521__09 105 | 2007_009764__06 106 | 2007_009794__08 107 | 2007_009897__10 108 | 2007_009923__08 109 | 2007_009938__07 110 | 2008_000009__10 111 | 2008_000073__10 112 | 2008_000075__06 113 | 2008_000107__09 114 | 2008_000149__09 115 | 2008_000182__08 116 | 2008_000345__08 117 | 2008_000401__08 118 | 2008_000464__08 119 | 2008_000501__07 120 | 2008_000673__09 121 | 2008_000853__08 122 | 2008_000919__10 123 | 2008_001078__08 124 | 2008_001433__08 125 | 2008_001439__09 126 | 2008_001513__08 127 | 2008_001640__08 128 | 2008_001715__09 129 | 2008_001885__08 130 | 2008_002152__08 131 | 2008_002205__06 132 | 2008_002212__07 133 | 2008_002379__09 134 | 2008_002521__09 135 | 2008_002623__08 136 | 2008_002681__08 137 | 2008_002778__10 138 | 2008_002958__07 139 | 2008_003141__06 140 | 2008_003141__07 141 | 2008_003333__07 142 | 2008_003499__08 143 | 2008_003577__07 144 | 2008_003777__06 145 | 2008_003821__09 146 | 2008_003846__07 147 
| 2008_004069__07 148 | 2008_004339__07 149 | 2008_004552__07 150 | 2008_004612__09 151 | 2008_004701__10 152 | 2008_005097__10 153 | 2008_005105__10 154 | 2008_005245__07 155 | 2008_005676__06 156 | 2008_006008__09 157 | 2008_006063__10 158 | 2008_006254__07 159 | 2008_006275__07 160 | 2008_006325__08 161 | 2008_006341__08 162 | 2008_006480__08 163 | 2008_006526__07 164 | 2008_006528__10 165 | 2008_006554__06 166 | 2008_006986__07 167 | 2008_007025__10 168 | 2008_007031__10 169 | 2008_007048__09 170 | 2008_007123__10 171 | 2008_007194__09 172 | 2008_007273__10 173 | 2008_007378__09 174 | 2008_007402__09 175 | 2008_007527__09 176 | 2008_007548__08 177 | 2008_007596__10 178 | 2008_007737__09 179 | 2008_007797__06 180 | 2008_007804__06 181 | 2008_007804__07 182 | 2008_007828__09 183 | 2008_008127__07 184 | 2008_008252__06 185 | 2008_008301__06 186 | 2008_008362__09 187 | 2008_008469__06 188 | 2008_008682__06 189 | 2009_000013__08 190 | 2009_000080__08 191 | 2009_000096__09 192 | 2009_000121__07 193 | 2009_000219__10 194 | 2009_000309__10 195 | 2009_000335__06 196 | 2009_000335__07 197 | 2009_000391__09 198 | 2009_000426__06 199 | 2009_000455__06 200 | 2009_000457__07 201 | 2009_000488__08 202 | 2009_000523__07 203 | 2009_000641__10 204 | 2009_000716__08 205 | 2009_000731__10 206 | 2009_000771__10 207 | 2009_000825__07 208 | 2009_000964__08 209 | 2009_001008__08 210 | 2009_001082__06 211 | 2009_001240__07 212 | 2009_001255__07 213 | 2009_001299__09 214 | 2009_001391__08 215 | 2009_001411__08 216 | 2009_001536__07 217 | 2009_001775__09 218 | 2009_001804__06 219 | 2009_001816__06 220 | 2009_001854__06 221 | 2009_002035__10 222 | 2009_002122__10 223 | 2009_002150__10 224 | 2009_002164__07 225 | 2009_002171__10 226 | 2009_002221__10 227 | 2009_002238__06 228 | 2009_002238__07 229 | 2009_002239__07 230 | 2009_002268__08 231 | 2009_002346__09 232 | 2009_002415__09 233 | 2009_002487__09 234 | 2009_002527__08 235 | 2009_002535__06 236 | 2009_002549__10 237 | 2009_002571__09 
238 | 2009_002618__07 239 | 2009_002635__10 240 | 2009_002753__08 241 | 2009_002936__08 242 | 2009_002990__07 243 | 2009_003003__07 244 | 2009_003059__10 245 | 2009_003071__09 246 | 2009_003123__09 247 | 2009_003269__07 248 | 2009_003304__06 249 | 2009_003387__07 250 | 2009_003406__07 251 | 2009_003494__09 252 | 2009_003507__09 253 | 2009_003542__10 254 | 2009_003549__07 255 | 2009_003569__10 256 | 2009_003589__07 257 | 2009_003703__06 258 | 2009_003771__08 259 | 2009_003773__10 260 | 2009_003849__09 261 | 2009_003895__09 262 | 2009_003904__08 263 | 2009_003991__09 264 | 2009_004072__06 265 | 2009_004099__08 266 | 2009_004140__09 267 | 2009_004217__09 268 | 2009_004248__08 269 | 2009_004455__07 270 | 2009_004494__07 271 | 2009_004504__08 272 | 2009_004590__06 273 | 2009_004594__07 274 | 2009_004687__09 275 | 2009_004721__08 276 | 2009_004732__06 277 | 2009_004748__07 278 | 2009_004789__06 279 | 2009_004859__09 280 | 2009_004867__06 281 | 2009_005158__08 282 | 2009_005219__08 283 | 2009_005231__06 284 | 2010_000003__09 285 | 2010_000160__07 286 | 2010_000163__08 287 | 2010_000372__07 288 | 2010_000427__10 289 | 2010_000530__07 290 | 2010_000552__08 291 | 2010_000573__06 292 | 2010_000628__07 293 | 2010_000639__09 294 | 2010_000682__06 295 | 2010_000683__08 296 | 2010_000724__08 297 | 2010_000907__10 298 | 2010_000941__08 299 | 2010_000952__07 300 | 2010_001000__10 301 | 2010_001010__10 302 | 2010_001070__08 303 | 2010_001149__08 304 | 2010_001206__06 305 | 2010_001292__08 306 | 2010_001331__08 307 | 2010_001351__08 308 | 2010_001403__06 309 | 2010_001403__07 310 | 2010_001534__08 311 | 2010_001553__07 312 | 2010_001579__09 313 | 2010_001646__06 314 | 2010_001656__08 315 | 2010_001692__10 316 | 2010_001699__09 317 | 2010_001767__07 318 | 2010_001851__09 319 | 2010_001913__08 320 | 2010_001966__07 321 | 2010_002017__07 322 | 2010_002017__09 323 | 2010_002025__08 324 | 2010_002137__08 325 | 2010_002146__08 326 | 2010_002305__08 327 | 2010_002336__09 328 | 
2010_002348__08 329 | 2010_002361__07 330 | 2010_002390__10 331 | 2010_002422__08 332 | 2010_002512__08 333 | 2010_002531__08 334 | 2010_002538__09 335 | 2010_002546__06 336 | 2010_002546__07 337 | 2010_002623__09 338 | 2010_002693__08 339 | 2010_002693__09 340 | 2010_002701__10 341 | 2010_002763__08 342 | 2010_002763__10 343 | 2010_002868__06 344 | 2010_002900__08 345 | 2010_002902__07 346 | 2010_002921__09 347 | 2010_002929__07 348 | 2010_002988__07 349 | 2010_003123__07 350 | 2010_003183__10 351 | 2010_003231__07 352 | 2010_003239__10 353 | 2010_003275__08 354 | 2010_003276__07 355 | 2010_003293__06 356 | 2010_003302__09 357 | 2010_003325__09 358 | 2010_003381__07 359 | 2010_003402__08 360 | 2010_003409__09 361 | 2010_003446__07 362 | 2010_003453__07 363 | 2010_003468__08 364 | 2010_003531__09 365 | 2010_003675__08 366 | 2010_003746__07 367 | 2010_003758__08 368 | 2010_003764__08 369 | 2010_003768__07 370 | 2010_003772__06 371 | 2010_003781__08 372 | 2010_003813__07 373 | 2010_003854__07 374 | 2010_003971__08 375 | 2010_003971__09 376 | 2010_004104__08 377 | 2010_004120__08 378 | 2010_004320__08 379 | 2010_004322__10 380 | 2010_004337__07 381 | 2010_004348__06 382 | 2010_004369__08 383 | 2010_004479__08 384 | 2010_004543__09 385 | 2010_004635__10 386 | 2010_004763__09 387 | 2010_004783__09 388 | 2010_004789__10 389 | 2010_004815__08 390 | 2010_004825__09 391 | 2010_004861__08 392 | 2010_004946__07 393 | 2010_005013__07 394 | 2010_005021__08 395 | 2010_005021__09 396 | 2010_005063__06 397 | 2010_005108__08 398 | 2010_005118__06 399 | 2010_005160__06 400 | 2010_005166__10 401 | 2010_005284__06 402 | 2010_005284__07 403 | 2010_005344__08 404 | 2010_005421__08 405 | 2010_005432__07 406 | 2010_005501__07 407 | 2010_005508__08 408 | 2010_005606__08 409 | 2010_005626__09 410 | 2010_005709__08 411 | 2010_005718__07 412 | 2010_005860__07 413 | 2010_005899__08 414 | 2010_005922__10 415 | 2010_006070__07 416 | 2011_000178__06 417 | 2011_000226__09 418 | 2011_000239__06 419 
| 2011_000248__06 420 | 2011_000312__06 421 | 2011_000338__09 422 | 2011_000419__08 423 | 2011_000503__07 424 | 2011_000548__10 425 | 2011_000566__10 426 | 2011_000607__09 427 | 2011_000658__07 428 | 2011_000661__08 429 | 2011_000661__09 430 | 2011_000780__08 431 | 2011_000789__08 432 | 2011_000809__09 433 | 2011_000813__08 434 | 2011_000813__09 435 | 2011_000830__06 436 | 2011_000888__06 437 | 2011_000900__07 438 | 2011_000969__06 439 | 2011_001047__10 440 | 2011_001064__06 441 | 2011_001071__09 442 | 2011_001110__06 443 | 2011_001110__07 444 | 2011_001159__10 445 | 2011_001232__10 446 | 2011_001292__08 447 | 2011_001313__07 448 | 2011_001341__06 449 | 2011_001346__09 450 | 2011_001447__09 451 | 2011_001530__10 452 | 2011_001534__08 453 | 2011_001546__10 454 | 2011_001567__09 455 | 2011_001597__08 456 | 2011_001601__08 457 | 2011_001607__08 458 | 2011_001665__09 459 | 2011_001708__10 460 | 2011_001748__06 461 | 2011_001775__08 462 | 2011_001782__10 463 | 2011_001812__09 464 | 2011_001862__09 465 | 2011_002041__09 466 | 2011_002064__07 467 | 2011_002124__09 468 | 2011_002200__09 469 | 2011_002298__09 470 | 2011_002322__07 471 | 2011_002358__09 472 | 2011_002391__09 473 | 2011_002509__09 474 | 2011_002575__09 475 | 2011_002592__07 476 | 2011_002644__09 477 | 2011_002685__08 478 | 2011_002812__07 479 | 2011_002885__10 480 | 2011_003011__09 481 | 2011_003019__07 482 | 2011_003019__10 483 | 2011_003055__07 484 | 2011_003103__09 485 | 2011_003114__06 486 | -------------------------------------------------------------------------------- /data_list/val_list/split3_val.txt: -------------------------------------------------------------------------------- 1 | 2007_000042__19 2 | 2007_000123__19 3 | 2007_000175__17 4 | 2007_000187__20 5 | 2007_000452__18 6 | 2007_000559__20 7 | 2007_000629__19 8 | 2007_000636__19 9 | 2007_000661__16 10 | 2007_000661__18 11 | 2007_000676__17 12 | 2007_000804__18 13 | 2007_000925__17 14 | 2007_001154__18 15 | 2007_001175__20 16 | 
2007_001408__16 17 | 2007_001430__16 18 | 2007_001430__20 19 | 2007_001457__18 20 | 2007_001458__18 21 | 2007_001585__18 22 | 2007_001594__17 23 | 2007_001678__20 24 | 2007_001717__20 25 | 2007_001733__17 26 | 2007_001763__18 27 | 2007_001763__20 28 | 2007_002119__20 29 | 2007_002132__20 30 | 2007_002268__18 31 | 2007_002284__16 32 | 2007_002378__16 33 | 2007_002426__18 34 | 2007_002427__16 35 | 2007_002427__18 36 | 2007_002565__19 37 | 2007_002618__17 38 | 2007_002648__17 39 | 2007_002728__19 40 | 2007_003011__18 41 | 2007_003011__20 42 | 2007_003020__16 43 | 2007_003169__18 44 | 2007_003367__16 45 | 2007_003499__19 46 | 2007_003506__16 47 | 2007_003530__18 48 | 2007_003587__19 49 | 2007_003714__17 50 | 2007_003848__19 51 | 2007_003957__19 52 | 2007_004121__16 53 | 2007_004190__20 54 | 2007_004193__20 55 | 2007_004275__16 56 | 2007_004281__19 57 | 2007_004392__16 58 | 2007_004483__19 59 | 2007_004510__20 60 | 2007_004558__16 61 | 2007_004649__19 62 | 2007_004712__16 63 | 2007_004969__17 64 | 2007_005058__17 65 | 2007_005469__17 66 | 2007_005626__19 67 | 2007_005689__19 68 | 2007_005813__16 69 | 2007_005857__16 70 | 2007_005915__17 71 | 2007_006171__18 72 | 2007_006348__20 73 | 2007_006373__18 74 | 2007_006678__17 75 | 2007_006680__19 76 | 2007_006802__19 77 | 2007_006837__17 78 | 2007_007130__16 79 | 2007_007130__20 80 | 2007_007165__17 81 | 2007_007168__19 82 | 2007_007195__19 83 | 2007_007196__20 84 | 2007_007203__20 85 | 2007_007417__18 86 | 2007_007534__17 87 | 2007_007624__16 88 | 2007_007795__16 89 | 2007_007849__20 90 | 2007_007881__19 91 | 2007_007996__18 92 | 2007_008204__20 93 | 2007_008260__18 94 | 2007_008339__19 95 | 2007_008374__20 96 | 2007_008543__18 97 | 2007_008547__16 98 | 2007_008897__20 99 | 2007_009068__16 100 | 2007_009068__18 101 | 2007_009252__16 102 | 2007_009252__18 103 | 2007_009320__17 104 | 2007_009419__16 105 | 2007_009446__20 106 | 2007_009521__18 107 | 2007_009521__20 108 | 2007_009592__18 109 | 2007_009655__18 110 | 
2007_009684__18 111 | 2007_009750__16 112 | 2008_000016__20 113 | 2008_000149__18 114 | 2008_000270__18 115 | 2008_000391__16 116 | 2008_000589__18 117 | 2008_000657__19 118 | 2008_001078__16 119 | 2008_001283__16 120 | 2008_001688__16 121 | 2008_001688__20 122 | 2008_001966__16 123 | 2008_002273__16 124 | 2008_002379__16 125 | 2008_002464__20 126 | 2008_002536__17 127 | 2008_002680__20 128 | 2008_002900__19 129 | 2008_002929__18 130 | 2008_003003__16 131 | 2008_003003__20 132 | 2008_003026__20 133 | 2008_003105__19 134 | 2008_003135__16 135 | 2008_003676__16 136 | 2008_003709__18 137 | 2008_003885__20 138 | 2008_004172__18 139 | 2008_004212__19 140 | 2008_004279__20 141 | 2008_004367__19 142 | 2008_004453__17 143 | 2008_004477__16 144 | 2008_004562__18 145 | 2008_004610__19 146 | 2008_004621__17 147 | 2008_004754__20 148 | 2008_004854__17 149 | 2008_004910__20 150 | 2008_005089__20 151 | 2008_005217__16 152 | 2008_005242__16 153 | 2008_005254__20 154 | 2008_005439__20 155 | 2008_005445__20 156 | 2008_005544__19 157 | 2008_005633__17 158 | 2008_005680__16 159 | 2008_006055__19 160 | 2008_006159__20 161 | 2008_006327__17 162 | 2008_006523__19 163 | 2008_006553__19 164 | 2008_006752__19 165 | 2008_006784__18 166 | 2008_006835__17 167 | 2008_007497__17 168 | 2008_007527__20 169 | 2008_007677__17 170 | 2008_007814__17 171 | 2008_007828__20 172 | 2008_008103__18 173 | 2008_008221__19 174 | 2008_008434__16 175 | 2009_000022__19 176 | 2009_000039__17 177 | 2009_000087__18 178 | 2009_000096__18 179 | 2009_000136__20 180 | 2009_000242__18 181 | 2009_000391__20 182 | 2009_000418__16 183 | 2009_000418__18 184 | 2009_000487__18 185 | 2009_000488__16 186 | 2009_000488__20 187 | 2009_000628__19 188 | 2009_000675__17 189 | 2009_000704__20 190 | 2009_000712__19 191 | 2009_000732__18 192 | 2009_000845__19 193 | 2009_000924__17 194 | 2009_001300__19 195 | 2009_001333__19 196 | 2009_001363__20 197 | 2009_001505__17 198 | 2009_001644__16 199 | 2009_001644__18 200 | 2009_001644__20 201 
| 2009_001684__16 202 | 2009_001731__18 203 | 2009_001768__17 204 | 2009_001775__16 205 | 2009_001775__18 206 | 2009_001991__17 207 | 2009_002082__17 208 | 2009_002094__20 209 | 2009_002202__19 210 | 2009_002265__19 211 | 2009_002291__19 212 | 2009_002346__18 213 | 2009_002366__20 214 | 2009_002390__18 215 | 2009_002487__16 216 | 2009_002562__20 217 | 2009_002568__19 218 | 2009_002571__16 219 | 2009_002571__18 220 | 2009_002573__20 221 | 2009_002584__16 222 | 2009_002638__19 223 | 2009_002732__18 224 | 2009_002771__16 225 | 2009_002887__19 226 | 2009_002928__16 227 | 2009_002982__19 228 | 2009_003105__19 229 | 2009_003123__18 230 | 2009_003299__19 231 | 2009_003311__19 232 | 2009_003387__20 233 | 2009_003433__19 234 | 2009_003523__20 235 | 2009_003551__20 236 | 2009_003564__16 237 | 2009_003564__18 238 | 2009_003607__18 239 | 2009_003666__17 240 | 2009_003857__20 241 | 2009_003895__18 242 | 2009_003895__20 243 | 2009_003938__19 244 | 2009_004070__17 245 | 2009_004099__18 246 | 2009_004140__18 247 | 2009_004255__19 248 | 2009_004298__18 249 | 2009_004687__18 250 | 2009_004730__19 251 | 2009_004799__19 252 | 2009_004993__18 253 | 2009_004993__20 254 | 2009_005038__16 255 | 2009_005089__17 256 | 2009_005148__19 257 | 2009_005220__19 258 | 2010_000256__18 259 | 2010_000284__18 260 | 2010_000309__17 261 | 2010_000318__20 262 | 2010_000330__16 263 | 2010_000639__16 264 | 2010_000738__20 265 | 2010_000764__19 266 | 2010_001011__17 267 | 2010_001079__17 268 | 2010_001104__19 269 | 2010_001149__18 270 | 2010_001151__19 271 | 2010_001246__16 272 | 2010_001256__17 273 | 2010_001327__18 274 | 2010_001367__20 275 | 2010_001522__17 276 | 2010_001557__17 277 | 2010_001577__17 278 | 2010_001699__16 279 | 2010_001734__19 280 | 2010_001752__20 281 | 2010_001767__18 282 | 2010_001773__16 283 | 2010_001851__16 284 | 2010_001951__19 285 | 2010_001962__18 286 | 2010_002106__17 287 | 2010_002137__16 288 | 2010_002232__17 289 | 2010_002336__16 290 | 2010_002531__18 291 | 2010_002682__19 
292 | 2010_002921__20 293 | 2010_003123__16 294 | 2010_003302__16 295 | 2010_003514__19 296 | 2010_003541__17 297 | 2010_003597__16 298 | 2010_003597__18 299 | 2010_003708__19 300 | 2010_003781__16 301 | 2010_003956__19 302 | 2010_004149__19 303 | 2010_004226__17 304 | 2010_004382__16 305 | 2010_004479__20 306 | 2010_004757__16 307 | 2010_004757__18 308 | 2010_004783__18 309 | 2010_004825__16 310 | 2010_004857__20 311 | 2010_004980__19 312 | 2010_005160__17 313 | 2010_005180__18 314 | 2010_005180__20 315 | 2010_005187__16 316 | 2010_005305__16 317 | 2010_005305__20 318 | 2010_005606__18 319 | 2010_005706__19 320 | 2010_005719__17 321 | 2010_005727__19 322 | 2010_005788__17 323 | 2010_005860__16 324 | 2010_005871__19 325 | 2010_006026__17 326 | 2010_006054__19 327 | 2011_000070__18 328 | 2011_000173__18 329 | 2011_000283__19 330 | 2011_000291__19 331 | 2011_000310__18 332 | 2011_000436__17 333 | 2011_000521__19 334 | 2011_000607__20 335 | 2011_000747__16 336 | 2011_001005__18 337 | 2011_001020__17 338 | 2011_001060__19 339 | 2011_001082__19 340 | 2011_001281__19 341 | 2011_001350__17 342 | 2011_001567__18 343 | 2011_001614__19 344 | 2011_001713__16 345 | 2011_001713__18 346 | 2011_001726__20 347 | 2011_001794__18 348 | 2011_001862__16 349 | 2011_001863__16 350 | 2011_001910__20 351 | 2011_002178__17 352 | 2011_002247__19 353 | 2011_002379__19 354 | 2011_002391__16 355 | 2011_002391__18 356 | 2011_002509__20 357 | 2011_002532__20 358 | 2011_002535__19 359 | 2011_002644__18 360 | 2011_002644__20 361 | 2011_002662__17 362 | 2011_002812__16 363 | 2011_002879__20 364 | 2011_003103__16 365 | 2011_003146__19 366 | 2011_003182__18 367 | 2011_003197__19 368 | 2011_003256__18 369 | -------------------------------------------------------------------------------- /data_parallel.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parallel import DataParallel 2 | import torch 3 | from torch.nn.parallel._functions import Scatter 4 
| from torch.nn.parallel.parallel_apply import parallel_apply 5 | 6 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 7 | r""" 8 | Slices tensors into approximately equal chunks and 9 | distributes them across given GPUs. Duplicates 10 | references to objects that are not tensors. 11 | """ 12 | def scatter_map(obj): 13 | if isinstance(obj, torch.Tensor): 14 | try: 15 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 16 | except: 17 | print('obj', obj.size()) 18 | print('dim', dim) 19 | print('chunk_sizes', chunk_sizes) 20 | quit() 21 | if isinstance(obj, tuple) and len(obj) > 0: 22 | return list(zip(*map(scatter_map, obj))) 23 | if isinstance(obj, list) and len(obj) > 0: 24 | return list(map(list, zip(*map(scatter_map, obj)))) 25 | if isinstance(obj, dict) and len(obj) > 0: 26 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 27 | return [obj for targets in target_gpus] 28 | 29 | # After scatter_map is called, a scatter_map cell will exist. This cell 30 | # has a reference to the actual function scatter_map, which has references 31 | # to a closure that has a reference to the scatter_map cell (because the 32 | # fn is recursive). 
To avoid this reference cycle, we set the function to 33 | # None, clearing the cell 34 | try: 35 | return scatter_map(inputs) 36 | finally: 37 | scatter_map = None 38 | 39 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 40 | r"""Scatter with support for kwargs dictionary""" 41 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 42 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 43 | if len(inputs) < len(kwargs): 44 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 45 | elif len(kwargs) < len(inputs): 46 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 47 | inputs = tuple(inputs) 48 | kwargs = tuple(kwargs) 49 | return inputs, kwargs 50 | 51 | class BalancedDataParallel(DataParallel): 52 | def __init__(self, gpu0_bsz, *args, **kwargs): 53 | self.gpu0_bsz = gpu0_bsz 54 | super().__init__(*args, **kwargs) 55 | 56 | def forward(self, *inputs, **kwargs): 57 | if not self.device_ids: 58 | return self.module(*inputs, **kwargs) 59 | if self.gpu0_bsz == 0: 60 | device_ids = self.device_ids[1:] 61 | else: 62 | device_ids = self.device_ids 63 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 64 | if len(self.device_ids) == 1: 65 | return self.module(*inputs[0], **kwargs[0]) 66 | replicas = self.replicate(self.module, self.device_ids) 67 | if self.gpu0_bsz == 0: 68 | replicas = replicas[1:] 69 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 70 | return self.gather(outputs, self.output_device) 71 | 72 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 73 | return parallel_apply(replicas, inputs, kwargs, device_ids) 74 | 75 | def scatter(self, inputs, kwargs, device_ids): 76 | bsz = inputs[0].size(self.dim) 77 | num_dev = len(self.device_ids) 78 | gpu0_bsz = self.gpu0_bsz 79 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 80 | if gpu0_bsz < bsz_unit: 81 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 82 | delta = bsz - 
sum(chunk_sizes) 83 | for i in range(delta): 84 | chunk_sizes[i + 1] += 1 85 | if gpu0_bsz == 0: 86 | chunk_sizes = chunk_sizes[1:] 87 | else: 88 | return super().scatter(inputs, kwargs, device_ids) 89 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 90 | -------------------------------------------------------------------------------- /img/chain10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/img/chain10.png -------------------------------------------------------------------------------- /img/graph4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/img/graph4.png -------------------------------------------------------------------------------- /img/prior1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/img/prior1.png -------------------------------------------------------------------------------- /img/result7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/img/result7.png -------------------------------------------------------------------------------- /models/PMMs.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class PMMs(nn.Module): 8 | '''Prototype Mixture Models 9 | 
Arguments: 10 | c (int): The input and output channel number. 11 | k (int): The number of the bases. 12 | stage_num (int): The iteration number for EM. 13 | ''' 14 | 15 | def __init__(self, c, k=3, stage_num=10): 16 | super(PMMs, self).__init__() 17 | self.stage_num = stage_num 18 | self.num_pro = k 19 | mu = torch.Tensor(1, c, k).cuda() 20 | mu.normal_(0, math.sqrt(2. / k)) # Init mu 21 | self.mu = self._l2norm(mu, dim=1) 22 | self.kappa = 20 23 | #self.register_buffer('mu', mu) 24 | 25 | 26 | def forward(self, support_feature, support_mask, query_feature): 27 | prototypes, mu_f, mu_b = self.generate_prototype(support_feature, support_mask) 28 | Prob_map, P = self.discriminative_model(query_feature, mu_f, mu_b) 29 | 30 | return prototypes, Prob_map 31 | 32 | def _l2norm(self, inp, dim): 33 | '''Normlize the inp tensor with l2-norm. 34 | Returns a tensor where each sub-tensor of input along the given dim is 35 | normalized such that the 2-norm of the sub-tensor is equal to 1. 36 | Arguments: 37 | inp (tensor): The input tensor. 38 | dim (int): The dimension to slice over to get the ssub-tensors. 39 | Returns: 40 | (tensor) The normalized tensor. 41 | ''' 42 | return inp / (1e-6 + inp.norm(dim=dim, keepdim=True)) 43 | 44 | def EM(self,x): 45 | ''' 46 | EM method 47 | :param x: feauture b * c * n 48 | :return: mu 49 | ''' 50 | b = x.shape[0] 51 | 52 | # k= 3 53 | # c = x.shape[1] 54 | # mu = torch.Tensor(1, c, k).cuda() 55 | # mu.normal_(0, math.sqrt(2. 
/ k)) # Init mu 56 | # mu = self._l2norm(mu, dim=1) 57 | 58 | mu = self.mu.repeat(b, 1, 1) # b * c * k 59 | with torch.no_grad(): 60 | for i in range(self.stage_num): 61 | # E STEP: 62 | z = self.Kernel(x, mu) 63 | z = F.softmax(z, dim=2) # b * n * k 64 | # M STEP: 65 | z_ = z / (1e-6 + z.sum(dim=1, keepdim=True)) 66 | mu = torch.bmm(x, z_) # b * c * k 67 | 68 | mu = self._l2norm(mu, dim=1) 69 | 70 | mu = mu.permute(0, 2, 1) # b * k * c 71 | 72 | return mu 73 | 74 | def Kernel(self, x, mu): 75 | x_t = x.permute(0, 2, 1) # b * n * c 76 | z = self.kappa * torch.bmm(x_t, mu) # b * n * k 77 | 78 | return z 79 | 80 | def get_prototype(self,x): 81 | b, c, h, w = x.size() 82 | x = x.view(b, c, h * w) # b * c * n 83 | mu = self.EM(x) # b * k * c 84 | 85 | return mu 86 | 87 | def generate_prototype(self, feature, mask): 88 | # b,h,w = mask.shape 89 | # mask = mask.view(b,1,h,w).float() 90 | mask = F.interpolate(mask, feature.shape[-2:], mode='bilinear', align_corners=True) 91 | 92 | # foreground 93 | z = mask * feature 94 | mu_f = self.get_prototype(z) 95 | mu_ = [] 96 | for i in range(self.num_pro): 97 | mu_.append(mu_f[:, i, :].unsqueeze(dim=2).unsqueeze(dim=3)) 98 | 99 | # background 100 | # ignore_pix = torch.where(mask == 255) 101 | # mask[ignore_pix[0],ignore_pix[1]] = 0 102 | # mask[mask>1] = 0 103 | mask_bg = 1-mask 104 | 105 | z_bg = mask_bg * feature 106 | mu_b = self.get_prototype(z_bg) 107 | 108 | return mu_, mu_f, mu_b 109 | 110 | def discriminative_model(self, query_feature, mu_f, mu_b): 111 | 112 | mu = torch.cat([mu_f, mu_b], dim=1) 113 | mu = mu.permute(0, 2, 1) 114 | 115 | b, c, h, w = query_feature.size() 116 | x = query_feature.view(b, c, h * w) # b * c * n 117 | with torch.no_grad(): 118 | 119 | x_t = x.permute(0, 2, 1) # b * n * c 120 | z = torch.bmm(x_t, mu) # b * n * k 121 | 122 | z = F.softmax(z, dim=2) # b * n * k 123 | 124 | P = z.permute(0, 2, 1) 125 | 126 | P = P.view(b, self.num_pro * 2, h, w) # b * k * w * h probability map 127 | P_f = 
torch.sum(P[:, 0:self.num_pro], dim=1).unsqueeze(dim=1) # foreground 128 | P_b = torch.sum(P[:, self.num_pro:], dim=1).unsqueeze(dim=1) # background 129 | 130 | Prob_map = torch.cat([P_b, P_f], dim=1) 131 | 132 | return Prob_map, P 133 | -------------------------------------------------------------------------------- /models/PMMs_single.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class PMMs(nn.Module): 8 | '''Prototype Mixture Models 9 | Arguments: 10 | c (int): The input and output channel number. 11 | k (int): The number of the bases. 12 | stage_num (int): The iteration number for EM. 13 | ''' 14 | 15 | def __init__(self, c, k=3, stage_num=10): 16 | super(PMMs, self).__init__() 17 | self.stage_num = stage_num 18 | self.num_pro = k 19 | mu = torch.Tensor(1, c, k).cuda() 20 | mu.normal_(0, math.sqrt(2. / k)) # Init mu 21 | self.mu = self._l2norm(mu, dim=1) 22 | self.kappa = 20 23 | #self.register_buffer('mu', mu) 24 | 25 | 26 | def forward(self, support_feature, support_mask, query_feature): 27 | prototypes, mu_f = self.generate_prototype(support_feature, support_mask) 28 | # Prob_map, P = self.discriminative_model(query_feature, mu_f) 29 | 30 | return prototypes 31 | 32 | def _l2norm(self, inp, dim): 33 | '''Normlize the inp tensor with l2-norm. 34 | Returns a tensor where each sub-tensor of input along the given dim is 35 | normalized such that the 2-norm of the sub-tensor is equal to 1. 36 | Arguments: 37 | inp (tensor): The input tensor. 38 | dim (int): The dimension to slice over to get the ssub-tensors. 39 | Returns: 40 | (tensor) The normalized tensor. 
41 | ''' 42 | return inp / (1e-6 + inp.norm(dim=dim, keepdim=True)) 43 | 44 | def EM(self,x): 45 | ''' 46 | EM method 47 | :param x: feauture b * c * n 48 | :return: mu 49 | ''' 50 | b = x.shape[0] 51 | mu = self.mu.repeat(b, 1, 1) # b * c * k 52 | with torch.no_grad(): 53 | for i in range(self.stage_num): 54 | # E STEP: 55 | z = self.Kernel(x, mu) 56 | z = F.softmax(z, dim=2) # b * n * k 57 | # M STEP: 58 | z_ = z / (1e-6 + z.sum(dim=1, keepdim=True)) 59 | mu = torch.bmm(x, z_) # b * c * k 60 | 61 | mu = self._l2norm(mu, dim=1) 62 | 63 | mu = mu.permute(0, 2, 1) # b * k * c 64 | 65 | return mu 66 | 67 | def Kernel(self, x, mu): 68 | x_t = x.permute(0, 2, 1) # b * n * c 69 | z = self.kappa * torch.bmm(x_t, mu) # b * n * k 70 | 71 | return z 72 | 73 | def get_prototype(self,x): 74 | b, c, h, w = x.size() 75 | x = x.view(b, c, h * w) # b * c * n 76 | mu = self.EM(x) # b * k * c 77 | 78 | return mu 79 | 80 | def generate_prototype(self, feature, mask): 81 | mask = F.interpolate(mask, feature.shape[-2:], mode='bilinear', align_corners=True) 82 | # foreground 83 | z = mask * feature 84 | mu_f = self.get_prototype(z) 85 | mu_ = [] 86 | for i in range(self.num_pro): 87 | mu_.append(mu_f[:, i, :].unsqueeze(dim=2).unsqueeze(dim=3)) 88 | 89 | return mu_, mu_f 90 | 91 | -------------------------------------------------------------------------------- /models/backbone/AlexNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | from pdb import set_trace as breakpoint 7 | 8 | class Flatten(nn.Module): 9 | def __init__(self): 10 | super(Flatten, self).__init__() 11 | 12 | def forward(self, feat): 13 | return feat.view(feat.size(0), -1) 14 | 15 | class AlexNet(nn.Module): 16 | def __init__(self, opt): 17 | super(AlexNet, self).__init__() 18 | num_classes = opt['num_classes'] 19 | 20 | conv1 = nn.Sequential( 21 | 
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 22 | nn.BatchNorm2d(64), 23 | nn.ReLU(inplace=True), 24 | ) 25 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 26 | conv2 = nn.Sequential( 27 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 28 | nn.BatchNorm2d(192), 29 | nn.ReLU(inplace=True), 30 | ) 31 | pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 32 | conv3 = nn.Sequential( 33 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(384), 35 | nn.ReLU(inplace=True), 36 | ) 37 | conv4 = nn.Sequential( 38 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 39 | nn.BatchNorm2d(256), 40 | nn.ReLU(inplace=True), 41 | ) 42 | conv5 = nn.Sequential( 43 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 44 | nn.BatchNorm2d(256), 45 | nn.ReLU(inplace=True), 46 | ) 47 | pool5 = nn.MaxPool2d(kernel_size=3, stride=2) 48 | 49 | num_pool5_feats = 6 * 6 * 256 50 | fc_block = nn.Sequential( 51 | Flatten(), 52 | nn.Linear(num_pool5_feats, 4096, bias=False), 53 | nn.BatchNorm1d(4096), 54 | nn.ReLU(inplace=True), 55 | nn.Linear(4096, 4096, bias=False), 56 | nn.BatchNorm1d(4096), 57 | nn.ReLU(inplace=True), 58 | ) 59 | classifier = nn.Sequential( 60 | nn.Linear(4096, num_classes), 61 | ) 62 | 63 | self._feature_blocks = nn.ModuleList([ 64 | conv1, 65 | pool1, 66 | conv2, 67 | pool2, 68 | conv3, 69 | conv4, 70 | conv5, 71 | pool5, 72 | fc_block, 73 | classifier, 74 | ]) 75 | self.all_feat_names = [ 76 | 'conv1', 77 | 'pool1', 78 | 'conv2', 79 | 'pool2', 80 | 'conv3', 81 | 'conv4', 82 | 'conv5', 83 | 'pool5', 84 | 'fc_block', 85 | 'classifier', 86 | ] 87 | assert(len(self.all_feat_names) == len(self._feature_blocks)) 88 | 89 | def _parse_out_keys_arg(self, out_feat_keys): 90 | 91 | # By default return the features of the last layer / module. 
92 | out_feat_keys = [self.all_feat_names[-1],] if out_feat_keys is None else out_feat_keys 93 | 94 | if len(out_feat_keys) == 0: 95 | raise ValueError('Empty list of output feature keys.') 96 | for f, key in enumerate(out_feat_keys): 97 | if key not in self.all_feat_names: 98 | raise ValueError('Feature with name {0} does not exist. Existing features: {1}.'.format(key, self.all_feat_names)) 99 | elif key in out_feat_keys[:f]: 100 | raise ValueError('Duplicate output feature key: {0}.'.format(key)) 101 | 102 | # Find the highest output feature in `out_feat_keys 103 | max_out_feat = max([self.all_feat_names.index(key) for key in out_feat_keys]) 104 | 105 | return out_feat_keys, max_out_feat 106 | 107 | def forward(self, x, out_feat_keys=None): 108 | """Forward an image `x` through the network and return the asked output features. 109 | 110 | Args: 111 | x: input image. 112 | out_feat_keys: a list/tuple with the feature names of the features 113 | that the function should return. By default the last feature of 114 | the network is returned. 115 | 116 | Return: 117 | out_feats: If multiple output features were asked then `out_feats` 118 | is a list with the asked output features placed in the same 119 | order as in `out_feat_keys`. If a single output feature was 120 | asked then `out_feats` is that output feature (and not a list). 
121 | """ 122 | out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys) 123 | out_feats = [None] * len(out_feat_keys) 124 | 125 | feat = x 126 | for f in range(max_out_feat+1): 127 | feat = self._feature_blocks[f](feat) 128 | key = self.all_feat_names[f] 129 | if key in out_feat_keys: 130 | out_feats[out_feat_keys.index(key)] = feat 131 | 132 | out_feats = out_feats[0] if len(out_feats)==1 else out_feats 133 | return out_feats 134 | 135 | def get_L1filters(self): 136 | convlayer = self._feature_blocks[0][0] 137 | batchnorm = self._feature_blocks[0][1] 138 | filters = convlayer.weight.data 139 | scalars = (batchnorm.weight.data / torch.sqrt(batchnorm.running_var + 1e-05)) 140 | filters = (filters * scalars.view(-1, 1, 1, 1).expand_as(filters)).cpu().clone() 141 | 142 | return filters 143 | 144 | def create_model(opt): 145 | return AlexNet(opt) 146 | 147 | def Alet(size=321, pretrained=True): 148 | opt = {'num_classes':4} 149 | net = create_model(opt) 150 | if pretrained: 151 | snapshot = '/disk2/caoqinglong/model_net_epoch50.pth' 152 | checkpoint = torch.load(snapshot) 153 | a,b = checkpoint.items() 154 | net.load_state_dict(b[1]) 155 | net.train() 156 | return net 157 | 158 | if __name__ == '__main__': 159 | size = 224 160 | opt = {'num_classes':4} 161 | 162 | net = create_model(opt) 163 | 164 | snapshot = '/disk2/caoqinglong/model_net_epoch50.pth' 165 | checkpoint = torch.load(snapshot) 166 | a,b = checkpoint.items() 167 | # model = nn.DataParallel(model) 168 | new_state_dict = OrderedDict() 169 | for k , v in b[1].items(): 170 | new_state_dict[k] = v 171 | # load params 172 | net.load_state_dict(new_state_dict) 173 | 174 | x = torch.autograd.Variable(torch.FloatTensor(2,3,size,size).uniform_(-1,1)) 175 | 176 | out = net(x, out_feat_keys=net.all_feat_names) 177 | for f in range(len(out)): 178 | print('Output feature {0} - size {1}'.format( 179 | net.all_feat_names[f], out[f].size())) 180 | 181 | filters = net.get_L1filters() 182 | 183 | print('First 
layer filter shape: {0}'.format(filters.size())) 184 | -------------------------------------------------------------------------------- /models/backbone/NetworkInNetwork.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size): 9 | super(BasicBlock, self).__init__() 10 | padding = (kernel_size-1)//2 11 | self.layers = nn.Sequential() 12 | self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes, \ 13 | kernel_size=kernel_size, stride=1, padding=padding, bias=False)) 14 | self.layers.add_module('BatchNorm', nn.BatchNorm2d(out_planes)) 15 | self.layers.add_module('ReLU', nn.ReLU(inplace=True)) 16 | 17 | def forward(self, x): 18 | return self.layers(x) 19 | 20 | # feat = F.avg_pool2d(feat, feat.size(3)).view(-1, self.nChannels) 21 | 22 | class GlobalAveragePooling(nn.Module): 23 | def __init__(self): 24 | super(GlobalAveragePooling, self).__init__() 25 | 26 | def forward(self, feat): 27 | num_channels = feat.size(1) 28 | return F.avg_pool2d(feat, (feat.size(2), feat.size(3))).view(-1, num_channels) 29 | 30 | class NetworkInNetwork(nn.Module): 31 | def __init__(self, opt): 32 | super(NetworkInNetwork, self).__init__() 33 | 34 | num_classes = opt['num_classes'] 35 | num_inchannels = opt['num_inchannels'] if ('num_inchannels' in opt) else 3 36 | num_stages = opt['num_stages'] if ('num_stages' in opt) else 3 37 | use_avg_on_conv3 = opt['use_avg_on_conv3'] if ('use_avg_on_conv3' in opt) else True 38 | 39 | 40 | assert(num_stages >= 3) 41 | nChannels = 192 42 | nChannels2 = 160 43 | nChannels3 = 96 44 | 45 | blocks = [nn.Sequential() for i in range(num_stages)] 46 | # 1st block 47 | blocks[0].add_module('Block1_ConvB1', BasicBlock(num_inchannels, nChannels, 5)) 48 | blocks[0].add_module('Block1_ConvB2', 
BasicBlock(nChannels, nChannels2, 1)) 49 | blocks[0].add_module('Block1_ConvB3', BasicBlock(nChannels2, nChannels3, 1)) 50 | blocks[0].add_module('Block1_MaxPool', nn.MaxPool2d(kernel_size=3,stride=2,padding=1)) 51 | 52 | # 2nd block 53 | blocks[1].add_module('Block2_ConvB1', BasicBlock(nChannels3, nChannels, 5)) 54 | blocks[1].add_module('Block2_ConvB2', BasicBlock(nChannels, nChannels, 1)) 55 | blocks[1].add_module('Block2_ConvB3', BasicBlock(nChannels, nChannels, 1)) 56 | blocks[1].add_module('Block2_AvgPool', nn.AvgPool2d(kernel_size=3,stride=2,padding=1)) 57 | 58 | # 3rd block 59 | blocks[2].add_module('Block3_ConvB1', BasicBlock(nChannels, nChannels, 3)) 60 | blocks[2].add_module('Block3_ConvB2', BasicBlock(nChannels, nChannels, 1)) 61 | blocks[2].add_module('Block3_ConvB3', BasicBlock(nChannels, nChannels, 1)) 62 | 63 | if num_stages > 3 and use_avg_on_conv3: 64 | blocks[2].add_module('Block3_AvgPool', nn.AvgPool2d(kernel_size=3,stride=2,padding=1)) 65 | for s in range(3, num_stages): 66 | blocks[s].add_module('Block'+str(s+1)+'_ConvB1', BasicBlock(nChannels, nChannels, 3)) 67 | blocks[s].add_module('Block'+str(s+1)+'_ConvB2', BasicBlock(nChannels, nChannels, 1)) 68 | blocks[s].add_module('Block'+str(s+1)+'_ConvB3', BasicBlock(nChannels, nChannels, 1)) 69 | 70 | # global average pooling and classifier 71 | blocks.append(nn.Sequential()) 72 | blocks[-1].add_module('GlobalAveragePooling', GlobalAveragePooling()) 73 | blocks[-1].add_module('Classifier', nn.Linear(nChannels, num_classes)) 74 | 75 | self._feature_blocks = nn.ModuleList(blocks) 76 | self.all_feat_names = ['conv'+str(s+1) for s in range(num_stages)] + ['classifier',] 77 | assert(len(self.all_feat_names) == len(self._feature_blocks)) 78 | 79 | def _parse_out_keys_arg(self, out_feat_keys): 80 | 81 | # By default return the features of the last layer / module. 
82 | out_feat_keys = [self.all_feat_names[-1],] if out_feat_keys is None else out_feat_keys 83 | 84 | if len(out_feat_keys) == 0: 85 | raise ValueError('Empty list of output feature keys.') 86 | for f, key in enumerate(out_feat_keys): 87 | if key not in self.all_feat_names: 88 | raise ValueError('Feature with name {0} does not exist. Existing features: {1}.'.format(key, self.all_feat_names)) 89 | elif key in out_feat_keys[:f]: 90 | raise ValueError('Duplicate output feature key: {0}.'.format(key)) 91 | 92 | # Find the highest output feature in `out_feat_keys 93 | max_out_feat = max([self.all_feat_names.index(key) for key in out_feat_keys]) 94 | 95 | return out_feat_keys, max_out_feat 96 | 97 | def forward(self, x, out_feat_keys=None): 98 | """Forward an image `x` through the network and return the asked output features. 99 | 100 | Args: 101 | x: input image. 102 | out_feat_keys: a list/tuple with the feature names of the features 103 | that the function should return. By default the last feature of 104 | the network is returned. 105 | 106 | Return: 107 | out_feats: If multiple output features were asked then `out_feats` 108 | is a list with the asked output features placed in the same 109 | order as in `out_feat_keys`. If a single output feature was 110 | asked then `out_feats` is that output feature (and not a list). 
111 | """ 112 | out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys) 113 | out_feats = [None] * len(out_feat_keys) 114 | 115 | feat = x 116 | for f in range(max_out_feat+1): 117 | feat = self._feature_blocks[f](feat) 118 | key = self.all_feat_names[f] 119 | if key in out_feat_keys: 120 | out_feats[out_feat_keys.index(key)] = feat 121 | 122 | out_feats = out_feats[0] if len(out_feats)==1 else out_feats 123 | return out_feats 124 | 125 | 126 | def weight_initialization(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | if m.weight.requires_grad: 130 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 131 | m.weight.data.normal_(0, math.sqrt(2. / n)) 132 | elif isinstance(m, nn.BatchNorm2d): 133 | if m.weight.requires_grad: 134 | m.weight.data.fill_(1) 135 | if m.bias.requires_grad: 136 | m.bias.data.zero_() 137 | elif isinstance(m, nn.Linear): 138 | if m.bias.requires_grad: 139 | m.bias.data.zero_() 140 | 141 | def create_model(opt): 142 | return NetworkInNetwork(opt) 143 | 144 | def NNet(size=321, pretrained=True): 145 | opt = {'num_classes':4, 'num_stages': 4} 146 | net = create_model(opt) 147 | if pretrained: 148 | snapshot = '/disk2/caoqinglong/CIFAR10_RotNet_NIN4blocks/CIFAR10_RotNet_NIN4blocks/model_net_epoch200.pth' 149 | checkpoint = torch.load(snapshot) 150 | a,b = checkpoint.items() 151 | net.load_state_dict(b[1]) 152 | net.train() 153 | return net 154 | if __name__ == '__main__': 155 | size = 321 156 | opt = {'num_classes':4, 'num_stages': 4} 157 | 158 | net = create_model(opt) 159 | 160 | snapshot = '/disk2/caoqinglong/CIFAR10_RotNet_NIN4blocks/CIFAR10_RotNet_NIN4blocks/model_net_epoch200.pth' 161 | checkpoint = torch.load(snapshot) 162 | a,b = checkpoint.items() 163 | # model = nn.DataParallel(model) 164 | new_state_dict = OrderedDict() 165 | for k , v in b[1].items(): 166 | new_state_dict[k] = v 167 | # load params 168 | net.load_state_dict(new_state_dict) 169 | 170 | x = 
torch.autograd.Variable(torch.FloatTensor(1,3,size,size).uniform_(-1,1)) 171 | 172 | out = net(x, out_feat_keys=net.all_feat_names) 173 | for f in range(len(out)): 174 | print('Output feature {0} - size {1}'.format( 175 | net.all_feat_names[f], out[f].size())) 176 | 177 | 178 | out = net(x) 179 | print('Final output: {0}'.format(out.size())) 180 | -------------------------------------------------------------------------------- /models/backbone/__pycache__/NetworkInNetwork.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/NetworkInNetwork.cpython-36.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/NetworkInNetwork.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/NetworkInNetwork.cpython-37.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/resnet_dialated.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/resnet_dialated.cpython-36.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/resnet_dialated.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/resnet_dialated.cpython-37.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/resnet_dialated4.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/resnet_dialated4.cpython-36.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/resnet_dialated4.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/resnet_dialated4.cpython-37.pyc -------------------------------------------------------------------------------- 
/models/backbone/__pycache__/resnet_dialated_fuse.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/resnet_dialated_fuse.cpython-37.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/vgg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/models/backbone/__pycache__/vgg.cpython-37.pyc -------------------------------------------------------------------------------- /models/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torchvision 6 | BatchNorm = nn.BatchNorm2d 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | 
super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = BatchNorm(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = BatchNorm(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = BatchNorm(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = BatchNorm(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 70 | self.bn3 = BatchNorm(planes * self.expansion) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class ResNet(nn.Module): 99 | 100 | def __init__(self, block, layers, num_classes=1000, deep_base=True): 101 | super(ResNet, self).__init__() 102 | self.deep_base = deep_base 103 | if not self.deep_base: 104 | self.inplanes = 64 105 | self.conv1 = nn.Conv2d(3, 64, 
kernel_size=7, stride=2, padding=3, bias=False) 106 | self.bn1 = BatchNorm(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | else: 109 | self.inplanes = 128 110 | self.conv1 = conv3x3(3, 64, stride=2) 111 | self.bn1 = BatchNorm(64) 112 | self.relu1 = nn.ReLU(inplace=True) 113 | self.conv2 = conv3x3(64, 64) 114 | self.bn2 = BatchNorm(64) 115 | self.relu2 = nn.ReLU(inplace=True) 116 | self.conv3 = conv3x3(64, 128) 117 | self.bn3 = BatchNorm(128) 118 | self.relu3 = nn.ReLU(inplace=True) 119 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 120 | self.layer1 = self._make_layer(block, 64, layers[0]) 121 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 122 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 123 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 124 | self.avgpool = nn.AvgPool2d(7, stride=1) 125 | self.fc = nn.Linear(512 * block.expansion, num_classes) 126 | 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 130 | elif isinstance(m, BatchNorm): 131 | nn.init.constant_(m.weight, 1) 132 | nn.init.constant_(m.bias, 0) 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, planes * block.expansion, 139 | kernel_size=1, stride=stride, bias=False), 140 | BatchNorm(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | x = self.relu1(self.bn1(self.conv1(x))) 153 | if self.deep_base: 154 | x = self.relu2(self.bn2(self.conv2(x))) 155 | x = 
self.relu3(self.bn3(self.conv3(x))) 156 | x = self.maxpool(x) 157 | 158 | x = self.layer1(x) 159 | x = self.layer2(x) 160 | x = self.layer3(x) 161 | x = self.layer4(x) 162 | 163 | x = self.avgpool(x) 164 | x = x.view(x.size(0), -1) 165 | x = self.fc(x) 166 | 167 | return x 168 | 169 | 170 | def resnet18(pretrained=False, **kwargs): 171 | """Constructs a ResNet-18 model. 172 | 173 | Args: 174 | pretrained (bool): If True, returns a model pre-trained on ImageNet 175 | """ 176 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 177 | if pretrained: 178 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 179 | return model 180 | 181 | 182 | def resnet34(pretrained=False, **kwargs): 183 | """Constructs a ResNet-34 model. 184 | 185 | Args: 186 | pretrained (bool): If True, returns a model pre-trained on ImageNet 187 | """ 188 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 189 | if pretrained: 190 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 191 | return model 192 | 193 | 194 | def resnet50(pretrained=True, **kwargs): 195 | """Constructs a ResNet-50 model. 196 | 197 | Args: 198 | pretrained (bool): If True, returns a model pre-trained on ImageNet 199 | """ 200 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 201 | if pretrained: 202 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 203 | model_path = '/disk2/caoqinglong/initial_models/resnet50_v2.pth' 204 | # resnet50 = torchvision.models.resnet50(pretrained=True) 205 | # model.load_state_dict(resnet50) 206 | model.load_state_dict(torch.load(model_path), strict=False) 207 | return model 208 | 209 | 210 | def resnet101(pretrained=False, **kwargs): 211 | """Constructs a ResNet-101 model. 
212 | 213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on ImageNet 215 | """ 216 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 217 | if pretrained: 218 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 219 | model_path = './initmodel/resnet101_v2.pth' 220 | model.load_state_dict(torch.load(model_path), strict=False) 221 | return model 222 | 223 | 224 | def resnet152(pretrained=False, **kwargs): 225 | """Constructs a ResNet-152 model. 226 | 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | """ 230 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 231 | if pretrained: 232 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 233 | model_path = './initmodel/resnet152_v2.pth' 234 | model.load_state_dict(torch.load(model_path), strict=False) 235 | return model 236 | -------------------------------------------------------------------------------- /models/backbone/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | BatchNorm = nn.BatchNorm2d 6 | 7 | __all__ = [ 8 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 9 | 'vgg19_bn', 'vgg19', 10 | ] 11 | 12 | 13 | model_urls = { 14 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 15 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 16 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 17 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 18 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 19 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 20 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 21 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 22 | } 23 | 24 | 25 | class VGG(nn.Module): 
26 | 27 | def __init__(self, features, num_classes=1000, init_weights=True): 28 | super(VGG, self).__init__() 29 | self.features = features 30 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 31 | self.classifier = nn.Sequential( 32 | nn.Linear(512 * 7 * 7, 4096), 33 | nn.ReLU(True), 34 | nn.Dropout(), 35 | nn.Linear(4096, 4096), 36 | nn.ReLU(True), 37 | nn.Dropout(), 38 | nn.Linear(4096, num_classes), 39 | ) 40 | if init_weights: 41 | self._initialize_weights() 42 | 43 | def forward(self, x): 44 | x = self.features(x) 45 | x = self.avgpool(x) 46 | x = x.view(x.size(0), -1) 47 | x = self.classifier(x) 48 | return x 49 | 50 | def _initialize_weights(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 54 | if m.bias is not None: 55 | nn.init.constant_(m.bias, 0) 56 | elif isinstance(m, BatchNorm): 57 | nn.init.constant_(m.weight, 1) 58 | nn.init.constant_(m.bias, 0) 59 | elif isinstance(m, nn.Linear): 60 | nn.init.normal_(m.weight, 0, 0.01) 61 | nn.init.constant_(m.bias, 0) 62 | 63 | 64 | def make_layers(cfg, batch_norm=False): 65 | layers = [] 66 | in_channels = 3 67 | for v in cfg: 68 | if v == 'M': 69 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 70 | else: 71 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 72 | if batch_norm: 73 | layers += [conv2d, BatchNorm(v), nn.ReLU(inplace=True)] 74 | else: 75 | layers += [conv2d, nn.ReLU(inplace=True)] 76 | in_channels = v 77 | return nn.Sequential(*layers) 78 | 79 | 80 | cfg = { 81 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 82 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 83 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 84 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 85 | } 86 | 87 | 88 | def 
vgg11(pretrained=False, **kwargs): 89 | """VGG 11-layer model (configuration "A") 90 | Args: 91 | pretrained (bool): If True, returns a model pre-trained on ImageNet 92 | """ 93 | if pretrained: 94 | kwargs['init_weights'] = False 95 | model = VGG(make_layers(cfg['A']), **kwargs) 96 | if pretrained: 97 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 98 | return model 99 | 100 | 101 | def vgg11_bn(pretrained=False, **kwargs): 102 | """VGG 11-layer model (configuration "A") with batch normalization 103 | Args: 104 | pretrained (bool): If True, returns a model pre-trained on ImageNet 105 | """ 106 | if pretrained: 107 | kwargs['init_weights'] = False 108 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 109 | if pretrained: 110 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 111 | return model 112 | 113 | 114 | def vgg13(pretrained=False, **kwargs): 115 | """VGG 13-layer model (configuration "B") 116 | Args: 117 | pretrained (bool): If True, returns a model pre-trained on ImageNet 118 | """ 119 | if pretrained: 120 | kwargs['init_weights'] = False 121 | model = VGG(make_layers(cfg['B']), **kwargs) 122 | if pretrained: 123 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 124 | return model 125 | 126 | 127 | def vgg13_bn(pretrained=False, **kwargs): 128 | """VGG 13-layer model (configuration "B") with batch normalization 129 | Args: 130 | pretrained (bool): If True, returns a model pre-trained on ImageNet 131 | """ 132 | if pretrained: 133 | kwargs['init_weights'] = False 134 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 135 | if pretrained: 136 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 137 | return model 138 | 139 | 140 | def vgg16(pretrained=False, **kwargs): 141 | """VGG 16-layer model (configuration "D") 142 | Args: 143 | pretrained (bool): If True, returns a model pre-trained on ImageNet 144 | """ 145 | if pretrained: 146 | kwargs['init_weights'] = False 
147 | model = VGG(make_layers(cfg['D']), **kwargs) 148 | if pretrained: 149 | #model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 150 | model_path = './initmodel/vgg16.pth' 151 | model.load_state_dict(torch.load(model_path), strict=False) 152 | return model 153 | 154 | 155 | def vgg16_bn(pretrained=False, **kwargs): 156 | """VGG 16-layer model (configuration "D") with batch normalization 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | """ 160 | if pretrained: 161 | kwargs['init_weights'] = False 162 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 163 | if pretrained: 164 | #model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 165 | model_path = '/disk2/caoqinglong/initial_models/vgg16_bn.pth' 166 | model.load_state_dict(torch.load(model_path), strict=False) 167 | return model 168 | 169 | 170 | def vgg19(pretrained=False, **kwargs): 171 | """VGG 19-layer model (configuration "E") 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | if pretrained: 176 | kwargs['init_weights'] = False 177 | model = VGG(make_layers(cfg['E']), **kwargs) 178 | if pretrained: 179 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 180 | return model 181 | 182 | 183 | def vgg19_bn(pretrained=False, **kwargs): 184 | """VGG 19-layer model (configuration 'E') with batch normalization 185 | Args: 186 | pretrained (bool): If True, returns a model pre-trained on ImageNet 187 | """ 188 | if pretrained: 189 | kwargs['init_weights'] = False 190 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 191 | if pretrained: 192 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 193 | return model 194 | 195 | if __name__ =='__main__': 196 | import os 197 | os.environ["CUDA_VISIBLE_DEVICES"] = '7' 198 | input = torch.rand(4, 3, 473, 473).cuda() 199 | target = torch.rand(4, 473, 473).cuda()*1.0 200 | model = 
vgg16_bn(pretrained=False).cuda() 201 | model.train() 202 | layer0_idx = range(0,6) 203 | layer1_idx = range(6,13) 204 | layer2_idx = range(13,23) 205 | layer3_idx = range(23,33) 206 | layer4_idx = range(34,43) 207 | #layer4_idx = range(34,43) 208 | print(model.features) 209 | layers_0 = [] 210 | layers_1 = [] 211 | layers_2 = [] 212 | layers_3 = [] 213 | layers_4 = [] 214 | for idx in layer0_idx: 215 | layers_0 += [model.features[idx]] 216 | for idx in layer1_idx: 217 | layers_1 += [model.features[idx]] 218 | for idx in layer2_idx: 219 | layers_2 += [model.features[idx]] 220 | for idx in layer3_idx: 221 | layers_3 += [model.features[idx]] 222 | for idx in layer4_idx: 223 | layers_4 += [model.features[idx]] 224 | 225 | layer0 = nn.Sequential(*layers_0) 226 | layer1 = nn.Sequential(*layers_1) 227 | layer2 = nn.Sequential(*layers_2) 228 | layer3 = nn.Sequential(*layers_3) 229 | layer4 = nn.Sequential(*layers_4) 230 | 231 | output = layer0(input) 232 | print(layer0) 233 | print('layer 0: {}'.format(output.size())) 234 | output = layer1(output) 235 | print(layer1) 236 | print('layer 1: {}'.format(output.size())) 237 | output = layer2(output) 238 | print(layer2) 239 | print('layer 2: {}'.format(output.size())) 240 | output = layer3(output) 241 | print(layer3) 242 | print('layer 3: {}'.format(output.size())) 243 | output = layer4(output) 244 | print(layer4) 245 | print('layer 4: {}'.format(output.size())) 246 | -------------------------------------------------------------------------------- /networks/FPMMs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.backbone import resnet_dialated as resnet 5 | from models import ASPP 6 | from models.PMMs import PMMs 7 | from models.mRN import mRN 8 | # The Code of baseline network is referenced from https://github.com/icoz69/CaNet 9 | # The code of training & testing is referenced from 
https://github.com/xiaomengyc/SG-One 10 | 11 | class OneModel(nn.Module): 12 | def __init__(self, args): 13 | 14 | self.inplanes = 64 15 | self.num_pro = 3 16 | super(OneModel, self).__init__() 17 | 18 | self.model_res = resnet.Res50_Deeplab(pretrained=True) 19 | self.layer5 = nn.Sequential( 20 | nn.Conv2d(in_channels=1536, out_channels=256, kernel_size=3, stride=1, padding=2, dilation=2, bias=True), 21 | nn.BatchNorm2d(256), 22 | nn.ReLU()) 23 | 24 | self.layer55 = nn.Sequential( 25 | nn.Conv2d(in_channels=256 * 2, out_channels=256, kernel_size=3, stride=1, padding=2, dilation=2, 26 | bias=True), 27 | nn.BatchNorm2d(256), 28 | nn.ReLU() 29 | ) 30 | 31 | self.layer56 = nn.Sequential( 32 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, dilation=1, 33 | bias=True), 34 | nn.BatchNorm2d(256), 35 | nn.ReLU() 36 | ) 37 | 38 | self.layer6 = ASPP.PSPnet() 39 | 40 | self.layer7 = nn.Sequential( 41 | nn.Conv2d(1280, 256, kernel_size=1, stride=1, padding=0, bias=True), 42 | nn.BatchNorm2d(256), 43 | nn.ReLU() 44 | 45 | ) 46 | 47 | self.layer9 = nn.Conv2d(256, 2, kernel_size=1, stride=1, bias=True) # numclass = 2 48 | 49 | self.residule1 = nn.Sequential( 50 | nn.ReLU(), 51 | nn.Conv2d(256+2, 256, kernel_size=3, stride=1, padding=1, bias=True), 52 | nn.ReLU(), 53 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True) 54 | ) 55 | 56 | self.residule2 = nn.Sequential( 57 | nn.ReLU(), 58 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 59 | nn.ReLU(), 60 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True) 61 | ) 62 | 63 | self.residule3 = nn.Sequential( 64 | nn.ReLU(), 65 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 66 | nn.ReLU(), 67 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True) 68 | ) 69 | self.PMMs = PMMs(256, self.num_pro).cuda() 70 | 71 | self.batch_size = args.batch_size 72 | 73 | def forward(self, query_rgb, support_rgb, support_mask): 74 | # 
extract support_ feature 75 | support_feature = self.extract_feature_res(support_rgb) 76 | 77 | # extract query feature 78 | query_feature = self.extract_feature_res(query_rgb) 79 | b,c,h,w = query_feature.shape 80 | mRN_model = mRN(c).cuda() 81 | exit_feat_in, Prob_map = mRN_model(support_feature, support_mask, query_feature) 82 | vec_pos = Prob_map 83 | # # PMMs 84 | # vec_pos, Prob_map = self.PMMs(support_feature, support_mask, query_feature) 85 | 86 | # # feature concate 87 | # feature_size = query_feature.shape[-2:] 88 | 89 | 90 | # for i in range(self.num_pro): 91 | 92 | # vec = vec_pos[i] 93 | # exit_feat_in_ = self.f_v_concate(query_feature, vec, feature_size) 94 | # exit_feat_in_ = self.layer55(exit_feat_in_) 95 | # if i == 0: 96 | # exit_feat_in = exit_feat_in_ 97 | # else: 98 | # exit_feat_in = exit_feat_in + exit_feat_in_ 99 | # exit_feat_in = self.layer56(exit_feat_in) 100 | 101 | 102 | # segmentation 103 | out, _ = self.Segmentation(exit_feat_in, Prob_map) 104 | 105 | return support_feature, query_feature, vec_pos, out 106 | 107 | def forward_5shot(self, query_rgb, support_rgb_batch, support_mask_batch): 108 | # extract query feature 109 | query_feature = self.extract_feature_res(query_rgb) 110 | # feature concate 111 | feature_size = query_feature.shape[-2:] 112 | 113 | for i in range(support_rgb_batch.shape[1]): 114 | support_rgb = support_rgb_batch[:, i] 115 | support_mask = support_mask_batch[:, i] 116 | # extract support feature 117 | support_feature = self.extract_feature_res(support_rgb) 118 | support_mask_temp = F.interpolate(support_mask, support_feature.shape[-2:], mode='bilinear', 119 | align_corners=True) 120 | if i == 0: 121 | support_feature_all = support_feature 122 | support_mask_all = support_mask_temp 123 | else: 124 | support_feature_all = torch.cat([support_feature_all, support_feature], dim=2) 125 | support_mask_all = torch.cat([support_mask_all, support_mask_temp], dim=2) 126 | 127 | vec_pos, Prob_map = 
self.PMMs(support_feature_all, support_mask_all, query_feature) 128 | 129 | for i in range(self.num_pro): 130 | vec = vec_pos[i] 131 | exit_feat_in_ = self.f_v_concate(query_feature, vec, feature_size) 132 | exit_feat_in_ = self.layer55(exit_feat_in_) 133 | if i == 0: 134 | exit_feat_in = exit_feat_in_ 135 | else: 136 | exit_feat_in = exit_feat_in + exit_feat_in_ 137 | 138 | exit_feat_in = self.layer56(exit_feat_in) 139 | 140 | out, _ = self.Segmentation(exit_feat_in, Prob_map) 141 | 142 | return out, out, out, out 143 | 144 | def extract_feature_res(self, rgb): 145 | out_resnet = self.model_res(rgb) 146 | stage2_out = out_resnet[1] 147 | stage3_out = out_resnet[2] 148 | out_23 = torch.cat([stage2_out, stage3_out], dim=1) 149 | feature = self.layer5(out_23) 150 | 151 | return feature 152 | 153 | def f_v_concate(self, feature, vec_pos, feature_size): 154 | fea_pos = vec_pos.expand(-1, -1, feature_size[0], feature_size[1]) # tile for cat 155 | exit_feat_in = torch.cat([feature, fea_pos], dim=1) 156 | 157 | return exit_feat_in 158 | 159 | def Segmentation(self, feature, history_mask): 160 | feature_size = feature.shape[-2:] 161 | 162 | history_mask = F.interpolate(history_mask, feature_size, mode='bilinear', align_corners=True) 163 | out = feature 164 | out_plus_history = torch.cat([out, history_mask], dim=1) 165 | out = out + self.residule1(out_plus_history) 166 | out = out + self.residule2(out) 167 | # out = out + self.residule3(out) 168 | 169 | # out = self.layer6(out) 170 | # out = self.layer7(out) 171 | out = self.layer9(out) 172 | 173 | out_softmax = F.softmax(out, dim=1) 174 | 175 | return out, out_softmax 176 | 177 | def get_loss(self, logits, query_label, idx): 178 | bce_logits_func = nn.CrossEntropyLoss() 179 | outB, outA_pos, vec, outB_side = logits 180 | 181 | b, c, w, h = query_label.size() 182 | outB_side = F.upsample(outB_side, size=(w, h), mode='bilinear') 183 | 184 | # add 185 | query_label = query_label.view(b, -1) 186 | bb, cc, _, _ = 
outB_side.size() 187 | outB_side = outB_side.view(b, cc, w * h) 188 | # 189 | 190 | loss_bce_seg = bce_logits_func(outB_side, query_label.long()) 191 | 192 | loss = loss_bce_seg 193 | 194 | return loss, 0, 0 195 | 196 | def get_pred(self, logits, query_image): 197 | outB, outA_pos, outB_side1, outB_side = logits 198 | w, h = query_image.size()[-2:] 199 | # print(w) 200 | # print(outB_side.shape) 201 | outB_side = F.upsample(outB_side, size=(w, h), mode='bilinear') 202 | out_softmax = F.softmax(outB_side, dim=1) 203 | values, pred = torch.max(out_softmax, dim=1) 204 | # print(pred.sum(dim=[1,2])) 205 | return out_softmax, pred 206 | -------------------------------------------------------------------------------- /networks/FRPMMs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.backbone import resnet_dialated as resnet 5 | from models import ASPP 6 | from models.PMMs import PMMs 7 | 8 | # The Code of baseline network is referenced from https://github.com/icoz69/CaNet 9 | # The code of training & testing is referenced from https://github.com/xiaomengyc/SG-One 10 | 11 | class OneModel(nn.Module): 12 | def __init__(self, args): 13 | 14 | self.inplanes = 64 15 | self.num_pro_list = [1,3,6] 16 | self.num_pro = self.num_pro_list[0] 17 | super(OneModel, self).__init__() 18 | 19 | self.model_res = resnet.Res50_Deeplab(pretrained=True) 20 | self.layer5 = nn.Sequential( 21 | nn.Conv2d(in_channels=1536, out_channels=256, kernel_size=3, stride=1, padding=2, dilation=2, bias=True), 22 | nn.BatchNorm2d(256), 23 | nn.ReLU()) 24 | 25 | self.layer55 = nn.Sequential( 26 | nn.Conv2d(in_channels=256 * 2, out_channels=256, kernel_size=3, stride=1, padding=2, dilation=2, 27 | bias=True), 28 | nn.ReLU(), 29 | nn.Dropout2d(p=0.5), 30 | ) 31 | 32 | self.layer56 = nn.Sequential( 33 | nn.Conv2d(in_channels=256+2, out_channels=256, kernel_size=3, stride=1, 
padding=1, dilation=1, 34 | bias=True), 35 | nn.ReLU(), 36 | nn.Dropout2d(p=0.5), 37 | ) 38 | 39 | self.layer6 = ASPP.PSPnet() 40 | 41 | self.layer7 = nn.Sequential( 42 | nn.Conv2d(1280, 256, kernel_size=1, stride=1, padding=0, bias=True), 43 | nn.ReLU(), 44 | nn.Dropout2d(p=0.5), 45 | 46 | ) 47 | 48 | self.layer9 = nn.Conv2d(256, 2, kernel_size=1, stride=1, bias=True) # numclass = 2 49 | 50 | self.residule1 = nn.Sequential( 51 | nn.ReLU(), 52 | nn.Conv2d(256 + 2, 256, kernel_size=3, stride=1, padding=1, bias=True), 53 | nn.ReLU(), 54 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True) 55 | ) 56 | 57 | self.residule2 = nn.Sequential( 58 | nn.ReLU(), 59 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 60 | nn.ReLU(), 61 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True) 62 | ) 63 | 64 | self.residule3 = nn.Sequential( 65 | nn.ReLU(), 66 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 67 | nn.ReLU(), 68 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True) 69 | ) 70 | 71 | self.batch_size = args.batch_size 72 | 73 | def forward(self, query_rgb, support_rgb, support_mask): 74 | # extract support feature 75 | support_feature = self.extract_feature_res(support_rgb) 76 | 77 | # extract query feature 78 | query_feature = self.extract_feature_res(query_rgb) 79 | 80 | feature_size = query_feature.shape[-2:] 81 | 82 | # feature concate 83 | Pseudo_mask = (torch.zeros(self.batch_size, 2, 50, 50)).cuda() 84 | out_list = [] 85 | for num in self.num_pro_list: 86 | self.num_pro = num 87 | self.PMMs = PMMs(256, num).cuda() 88 | vec_pos, Prob_map = self.PMMs(support_feature, support_mask, query_feature) 89 | 90 | for i in range(num): 91 | vec = vec_pos[i] 92 | exit_feat_in_ = self.f_v_concate(query_feature, vec, feature_size) 93 | exit_feat_in_ = self.layer55(exit_feat_in_) 94 | if i == 0: 95 | exit_feat_in = exit_feat_in_ 96 | else: 97 | exit_feat_in = exit_feat_in + exit_feat_in_ 98 | 99 
| 100 | exit_feat_in = torch.cat([exit_feat_in, Prob_map], dim=1) 101 | exit_feat_in = self.layer56(exit_feat_in) 102 | 103 | # segmentation 104 | out, out_softmax = self.Segmentation(exit_feat_in, Pseudo_mask) 105 | Pseudo_mask = out_softmax 106 | out_list.append(out) 107 | 108 | return support_feature, out_list[0], out_list[1], out 109 | 110 | def forward_5shot(self, query_rgb, support_rgb_batch, support_mask_batch): 111 | # extract query feature 112 | query_feature = self.extract_feature_res(query_rgb) 113 | 114 | feature_size = query_feature.shape[-2:] 115 | 116 | out5 = 0 117 | 118 | for i in range(support_rgb_batch.shape[1]): 119 | support_rgb = support_rgb_batch[:, i] 120 | support_mask = support_mask_batch[:, i] 121 | # extract support feature 122 | support_feature = self.extract_feature_res(support_rgb) 123 | support_mask_temp = F.interpolate(support_mask, support_feature.shape[-2:], mode='bilinear', 124 | align_corners=True) 125 | if i == 0: 126 | support_feature_all = support_feature 127 | support_mask_all = support_mask_temp 128 | else: 129 | support_feature_all = torch.cat([support_feature_all, support_feature], dim=2) 130 | support_mask_all = torch.cat([support_mask_all, support_mask_temp], dim=2) 131 | 132 | Pseudo_mask = (torch.zeros(self.batch_size, 2, 50, 50)).cuda() 133 | for num in self.num_pro_list: 134 | self.num_pro = num 135 | self.PMMs = PMMs(256, num).cuda() 136 | vec_pos, Prob_map = self.PMMs(support_feature_all, support_mask_all, query_feature) 137 | # vector conduct feature 138 | for i in range(num): 139 | vec = vec_pos[i] 140 | exit_feat_in_ = self.f_v_concate(query_feature, vec, feature_size) 141 | exit_feat_in_ = self.layer55(exit_feat_in_) 142 | if i == 0: 143 | exit_feat_in = exit_feat_in_ 144 | else: 145 | exit_feat_in = exit_feat_in + exit_feat_in_ 146 | 147 | exit_feat_in = torch.cat([exit_feat_in, Prob_map], dim=1) 148 | exit_feat_in = self.layer56(exit_feat_in) 149 | 150 | # segmentation 151 | out, out_softmax = 
self.Segmentation(exit_feat_in, Pseudo_mask) 152 | Pseudo_mask = out_softmax 153 | 154 | out5 = out5 + out_softmax 155 | out5 = out5 / 5 156 | return out5, out5, out5, out5 157 | 158 | return logits 159 | 160 | 161 | def extract_feature_res(self, rgb): 162 | out_resnet = self.model_res(rgb) 163 | stage2_out = out_resnet[1] 164 | stage3_out = out_resnet[2] 165 | out_23 = torch.cat([stage2_out, stage3_out], dim=1) 166 | feature = self.layer5(out_23) 167 | 168 | return feature 169 | 170 | def f_v_concate(self, feature, vec_pos, feature_size): 171 | fea_pos = vec_pos.expand(-1, -1, feature_size[0], feature_size[1]) # tile for cat 172 | exit_feat_in = torch.cat([feature, fea_pos], dim=1) 173 | 174 | return exit_feat_in 175 | 176 | def Segmentation(self, feature, history_mask): 177 | feature_size = feature.shape[-2:] 178 | 179 | history_mask = F.interpolate(history_mask, feature_size, mode='bilinear', align_corners=True) 180 | out = feature 181 | out_plus_history = torch.cat([feature, history_mask], dim=1) 182 | out = out + self.residule1(out_plus_history) 183 | out = out + self.residule2(out) 184 | out = out + self.residule3(out) 185 | 186 | out = self.layer6(out) 187 | out = self.layer7(out) 188 | out = self.layer9(out) 189 | 190 | out_softmax = F.softmax(out, dim=1) 191 | 192 | return out , out_softmax 193 | 194 | def get_loss(self, logits, query_label, idx): 195 | bce_logits_func = nn.CrossEntropyLoss() 196 | support_feature, out0, out1, outB_side = logits 197 | 198 | b, c, w, h = query_label.size() 199 | out0 = F.upsample(out0, size=(w, h), mode='bilinear') 200 | out1 = F.upsample(out1, size=(w, h), mode='bilinear') 201 | outB_side = F.upsample(outB_side, size=(w, h), mode='bilinear') 202 | 203 | bb, cc, _, _ = outB_side.size() 204 | 205 | out0 = out0.view(b, cc, w * h) 206 | out1 = out1.view(b, cc, w * h) 207 | outB_side = outB_side.view(b, cc, w * h) 208 | query_label = query_label.view(b, -1) 209 | 210 | loss_bce_seg0 = bce_logits_func(out0, query_label.long()) 
211 | loss_bce_seg1 = bce_logits_func(out1, query_label.long()) 212 | loss_bce_seg2 = bce_logits_func(outB_side, query_label.long()) 213 | 214 | loss = loss_bce_seg0+loss_bce_seg1+loss_bce_seg2 215 | 216 | return loss, loss_bce_seg2, loss_bce_seg1 217 | 218 | def get_pred(self, logits, query_image): 219 | outB, outA_pos, outB_side1, outB_side = logits 220 | w, h = query_image.size()[-2:] 221 | outB_side = F.upsample(outB_side, size=(w, h), mode='bilinear') 222 | out_softmax = F.softmax(outB_side, dim=1) 223 | values, pred = torch.max(out_softmax, dim=1) 224 | return out_softmax, pred 225 | 226 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | __all__=[ 4 | 'FPMMs','FRPMMs','VGG16based','resnet50based','resnet50based4','resnet50basedf', 5 | 'resnet50_34','resnet50_34_all','resnet50_34_s' 6 | ] -------------------------------------------------------------------------------- /networks/__pycache__/FPMMs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/FPMMs.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/FPMMs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/FPMMs.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/FRPMMs.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/FRPMMs.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/FRPMMs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/FRPMMs.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/VGG16based.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/VGG16based.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/VGG16based.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/VGG16based.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50_34.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50_34.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50_34.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50_34.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50_34_all.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50_34_all.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50_34_all.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50_34_all.cpython-37.pyc -------------------------------------------------------------------------------- 
/networks/__pycache__/resnet50_34_s.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50_34_s.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50_34_s.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50_34_s.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50based.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50based.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50based.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50based.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50based4.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50based4.cpython-36.pyc 
-------------------------------------------------------------------------------- /networks/__pycache__/resnet50based4.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50based4.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50basedf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50basedf.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50basedf.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50basedf.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet50basedfused.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/networks/__pycache__/resnet50basedfused.cpython-37.pyc -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import torch 5 | import json 6 | import numpy as np 7 | import argparse 8 | import time 9 | import torch.nn.functional as F 10 | from 
data.LoadDataSeg import val_loader 11 | from utils import NoteEvaluation 12 | from networks import * 13 | from utils.Restore import restore 14 | 15 | from config import settings 16 | 17 | 18 | K_SHOT = 1 19 | DATASET = 'voc' 20 | SNAPSHOT_DIR =settings.SNAPSHOT_DIR 21 | if DATASET =='coco': 22 | SNAPSHOT_DIR = SNAPSHOT_DIR+'/coco' 23 | 24 | 25 | GPU_ID = '0' 26 | os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID 27 | 28 | def get_arguments(): 29 | parser = argparse.ArgumentParser(description='OneShot') 30 | parser.add_argument("--arch", type=str,default='FRPMMs') 31 | parser.add_argument("--disp_interval", type=int, default=100) 32 | parser.add_argument("--snapshot_dir", type=str, default=SNAPSHOT_DIR) 33 | 34 | parser.add_argument("--group", type=int, default=0) 35 | parser.add_argument('--num_folds', type=int, default=4) 36 | parser.add_argument('--restore_step', type=int, default=100000) 37 | parser.add_argument('--batch_size', type=int, default=1) 38 | parser.add_argument('--mode', type=str, default='val') 39 | parser.add_argument('--dataset', type=str, default=DATASET) 40 | 41 | return parser.parse_args() 42 | 43 | def get_model(args): 44 | 45 | model = eval(args.arch).OneModel(args) 46 | 47 | model = model.cuda() 48 | 49 | return model 50 | 51 | def val(args): 52 | model = get_model(args) 53 | model.eval() 54 | 55 | evaluations = NoteEvaluation.Evaluation(args) 56 | 57 | for group in range(0,1): 58 | 59 | print("-------------GROUP %d-------------" % (group)) 60 | 61 | args.group = group 62 | evaluations.group =args.group 63 | val_dataloader = val_loader(args,k_shot = K_SHOT) 64 | restore(args, model) 65 | it = 0 66 | 67 | for data in val_dataloader: 68 | begin_time = time.time() 69 | it = it+1 70 | query_img, query_mask, support_img, support_mask, idx, size = data 71 | 72 | query_img, query_mask, support_img, support_mask, idx \ 73 | = query_img.cuda(), query_mask.cuda(), support_img.cuda(), support_mask.cuda(), idx.cuda() 74 | 75 | with torch.no_grad(): 76 | 77 | 
logits = model(query_img,query_mask,support_img, support_mask) 78 | query_img = F.upsample(query_img, size=(size[0], size[1]), mode='bilinear') 79 | query_mask = F.upsample(query_mask, size=(size[0], size[1]), mode='nearest') 80 | values, pred = model.get_pred(logits, query_img) 81 | 82 | evaluations.update_evl(idx, query_mask, pred, 0) 83 | end_time = time.time() 84 | ImgPerSec = 1/(end_time-begin_time) 85 | 86 | print("It has tested %d, %.2f images/s" %(it,ImgPerSec), end="\r") 87 | print("Group %d: %.4f " %(args.group, evaluations.group_mean_iou[args.group])) 88 | 89 | iou = evaluations.iou_list 90 | # print('IOU:', iou) 91 | mIoU = np.mean(iou) 92 | # print('mIoU: ', mIoU) 93 | print("group0_mIou", evaluations.group_mean_iou[0]) 94 | print("group1_mIou", evaluations.group_mean_iou[1]) 95 | print("group2_mIou", evaluations.group_mean_iou[2]) 96 | print("group3_mIou", evaluations.group_mean_iou[3]) 97 | print(evaluations.group_mean_iou) 98 | #print(evaluations.iou_list) 99 | 100 | return mIoU, iou, evaluations 101 | 102 | 103 | 104 | if __name__ == '__main__': 105 | args = get_arguments() 106 | print('Running parameters:\n') 107 | print(json.dumps(vars(args), indent=4, separators=(',', ':'))) 108 | if not os.path.exists(args.snapshot_dir): 109 | os.mkdir(args.snapshot_dir) 110 | val(args) 111 | -------------------------------------------------------------------------------- /test_5shot.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import torch 5 | import json 6 | import numpy as np 7 | import argparse 8 | import time 9 | import torch.nn.functional as F 10 | from data.LoadDataSeg import val_loader 11 | from utils import NoteEvaluation 12 | from networks import * 13 | from utils.Restore import restore 14 | 15 | from config import settings 16 | 17 | DATASET = 'voc' 18 | SNAPSHOT_DIR =settings.SNAPSHOT_DIR 19 | if DATASET =='coco': 20 | SNAPSHOT_DIR = SNAPSHOT_DIR+'/coco' 21 | 22 | GPU_ID = 
'0' 23 | os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID 24 | K_SHOT=5 25 | 26 | Restore_Step_list = ['best', 'best', 'best', 'best'] 27 | 28 | def get_arguments(): 29 | parser = argparse.ArgumentParser(description='OneShot') 30 | parser.add_argument("--arch", type=str,default='FRPMMs') 31 | parser.add_argument("--disp_interval", type=int, default=100) 32 | parser.add_argument("--snapshot_dir", type=str, default=SNAPSHOT_DIR) 33 | 34 | parser.add_argument("--group", type=int, default=0) 35 | parser.add_argument('--num_folds', type=int, default=4) 36 | parser.add_argument('--restore_step', type=int, default=100000) 37 | parser.add_argument('--batch_size', type=int, default=1) 38 | parser.add_argument('--mode', type=str, default='val') 39 | parser.add_argument('--dataset', type=str, default=DATASET) 40 | 41 | return parser.parse_args() 42 | 43 | def get_model(args): 44 | 45 | model = eval(args.arch).OneModel(args) 46 | 47 | model = model.cuda() 48 | 49 | return model 50 | 51 | def val(args): 52 | model = get_model(args) 53 | model.eval() 54 | 55 | evaluations = NoteEvaluation.Evaluation(args) 56 | 57 | for group in range(4): 58 | args.restore_step = Restore_Step_list[group] 59 | print("-------------GROUP %d-------------" % (group)) 60 | 61 | args.group = group 62 | evaluations.group =args.group 63 | val_dataloader = val_loader(args, k_shot=K_SHOT) 64 | restore(args, model) 65 | it = 0 66 | 67 | for data in val_dataloader: 68 | begin_time = time.time() 69 | it = it+1 70 | query_img, query_mask, support_img, support_mask, idx, size = data 71 | 72 | query_img, query_mask, support_img, support_mask, idx \ 73 | = query_img.cuda(), query_mask.cuda(), support_img.cuda(), support_mask.cuda(), idx.cuda() 74 | 75 | with torch.no_grad(): 76 | logits = model.forward_5shot(query_img, support_img, support_mask) 77 | query_img = F.upsample(query_img, size=(size[0], size[1]), mode='bilinear') 78 | query_mask = F.upsample(query_mask, size=(size[0], size[1]), mode='nearest') 79 | 80 | 
values, pred = model.get_pred(logits, query_img) 81 | evaluations.update_evl(idx, query_mask, pred, 0) 82 | end_time = time.time() 83 | ImgPerSec = 1/(end_time-begin_time) 84 | print("It has tested %d, %.2f images/s" %(it,ImgPerSec), end="\r") 85 | print("Group %d: %.4f " %(args.group, evaluations.group_mean_iou[args.group])) 86 | iou = evaluations.iou_list 87 | print('IOU:', iou) 88 | mIoU = np.mean(iou) 89 | print('mIoU: ', mIoU) 90 | print("group0_iou", evaluations.group_mean_iou[0]) 91 | print("group1_iou", evaluations.group_mean_iou[1]) 92 | print("group2_iou", evaluations.group_mean_iou[2]) 93 | print("group3_iou", evaluations.group_mean_iou[3]) 94 | print(evaluations.group_mean_iou) 95 | 96 | return mIoU, iou 97 | 98 | 99 | 100 | if __name__ == '__main__': 101 | args = get_arguments() 102 | print('Running parameters:\n') 103 | print(json.dumps(vars(args), indent=4, separators=(',', ':'))) 104 | if not os.path.exists(args.snapshot_dir): 105 | os.mkdir(args.snapshot_dir) 106 | val(args) 107 | -------------------------------------------------------------------------------- /test_all_frame.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import torch 5 | import json 6 | import numpy as np 7 | import argparse 8 | from torch.autograd import Variable 9 | from data.LoadDataSeg import val_loader 10 | from utils import NoteEvaluation 11 | from networks import * 12 | from utils.Restore import restore 13 | 14 | from config import settings 15 | from test import val as VAL 16 | #from test_5shot import val as VAL 17 | 18 | from utils.Restore import Save_Evaluations 19 | from utils.Visualize import print_best 20 | 21 | DATASET = 'voc' 22 | # DATASET = 'coco' 23 | SNAPSHOT_DIR =settings.SNAPSHOT_DIR 24 | #SNAPSHOT_DIR = SNAPSHOT_DIR+'/coco' 25 | if DATASET =='coco': 26 | SNAPSHOT_DIR = SNAPSHOT_DIR+'/coco' 27 | 28 | START = 16000 29 | END = 205000 30 | # END = 45000 31 | GPU_ID = '0' 32 | 
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID


def get_arguments():
    """Parse options for sweeping checkpoints ``--start`` .. ``--end`` (exclusive)."""
    parser = argparse.ArgumentParser(description='OneShot')
    parser.add_argument("--arch", type=str, default='resnet50_34')
    parser.add_argument("--disp_interval", type=int, default=100)
    parser.add_argument("--snapshot_dir", type=str, default=SNAPSHOT_DIR)

    parser.add_argument("--group", type=int, default=0)
    parser.add_argument('--num_folds', type=int, default=4)
    parser.add_argument('--interval', type=int, default=250)
    parser.add_argument('--start', type=int, default=START)
    parser.add_argument('--end', type=int, default=END)
    parser.add_argument('--restore_step', type=int, default=100000)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--mode', type=str, default='val')
    parser.add_argument('--dataset', type=str, default=DATASET)

    return parser.parse_args()


def get_model(args):
    """Build ``<args.arch>.OneModel(args)`` on the GPU (``args.arch`` goes through ``eval``)."""
    model = eval(args.arch).OneModel(args)
    model = model.cuda()
    return model


if __name__ == '__main__':
    args = get_arguments()
    print('Running parameters:\n')
    print(json.dumps(vars(args), indent=4, separators=(',', ':')))
    Best_Note = NoteEvaluation.note_best()
    File_Evaluations = Save_Evaluations(args)
    os.makedirs(args.snapshot_dir, exist_ok=True)
    # The CSV is written under a relative directory that may not exist yet.
    csv_dir = os.path.dirname(File_Evaluations.savedir)
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)
    for i in range(args.start, args.end, args.interval):
        print("---------------------------------EVALUATE STEP %d---------------------------------" % (i))
        args.restore_step = i

        mIoU, iou, evaluations = VAL(args)

        Best_Note.update(mIoU, args.restore_step, iou, evaluations)
        File_Evaluations.update_date(args.restore_step, mIoU, evaluations)
        # Persist after every step: a failure later in a long sweep (e.g. a
        # missing step checkpoint) must not discard the steps already done.
        File_Evaluations.save_file()
        print("-------------")
        print("best_BMVC_IOU ", Best_Note.best_mean)
        print("best_group0_iou", Best_Note.best0)
        print("best_group1_iou", Best_Note.best1)
        print("best_group2_iou", Best_Note.best2)
        print("best_group3_iou", Best_Note.best3)

    print_best(Best_Note)
    File_Evaluations.update_best(Best_Note)
    File_Evaluations.update_best_eachgroup(Best_Note)
    File_Evaluations.save_file()

# ------------------------------------------------------------------------------
# /test_frame.py:
# ------------------------------------------------------------------------------

import os
import torch
import json
import numpy as np
import argparse
import time
import torch.nn.functional as F

from data.LoadDataSeg import val_loader
from networks import *
from utils import NoteEvaluation
from utils.Restore import restore
from config import settings

DATASET = 'voc'
SNAPSHOT_DIR = settings.SNAPSHOT_DIR
if DATASET == 'coco':
    SNAPSHOT_DIR = SNAPSHOT_DIR + '/coco'


GPU_ID = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID


def _parse_restore_step(value):
    """argparse type for ``--restore_step``: keep ``'best'``, else an int step.

    ``restore`` formats any value other than ``'best'`` with ``%d``, so a
    numeric command-line string such as ``'5000'`` must become an int first.
    """
    return value if value == 'best' else int(value)


def get_arguments():
    """Parse the command-line options of the 1-shot evaluation script."""
    parser = argparse.ArgumentParser(description='OneShot')
    parser.add_argument("--arch", type=str, default='FRPMMs')
    parser.add_argument("--disp_interval", type=int, default=100)
    parser.add_argument("--snapshot_dir", type=str, default=SNAPSHOT_DIR)

    parser.add_argument("--group", type=int, default=0)
    parser.add_argument('--num_folds', type=int, default=4)
    parser.add_argument('--restore_step', type=_parse_restore_step, default='best')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--mode', type=str, default='val')
    parser.add_argument('--dataset', type=str, default=DATASET)

    return parser.parse_args()


def get_model(args):
    """Build ``<args.arch>.OneModel(args)`` on the GPU (``args.arch`` goes through ``eval``)."""
    model = eval(args.arch).OneModel(args)
    model = model.cuda()
    return model


def val(args):
    """Evaluate the 1-shot model on all four folds and print IoU summaries.

    Every fold restores ``args.restore_step`` and accumulates per-class IoU
    in a ``NoteEvaluation.Evaluation`` instance.

    Returns:
        (mIoU, iou): the mean of ``evaluations.iou_list`` and that list.
    """
    model = get_model(args)
    model.eval()

    evaluations = NoteEvaluation.Evaluation(args)

    for group in range(4):
        print("-------------GROUP %d-------------" % (group))

        args.group = group
        evaluations.group = args.group
        val_dataloader = val_loader(args, k_shot=1)
        restore(args, model)

        for it, data in enumerate(val_dataloader, start=1):
            begin_time = time.time()
            query_img, query_mask, support_img, support_mask, idx, size = data

            query_img, query_mask, support_img, support_mask, idx \
                = query_img.cuda(), query_mask.cuda(), support_img.cuda(), support_mask.cuda(), idx.cuda()

            with torch.no_grad():
                logits = model(query_img, support_img, support_mask)

                # F.upsample is deprecated; F.interpolate with an explicit
                # align_corners=False keeps the same bilinear behaviour.
                query_img = F.interpolate(query_img, size=(size[0], size[1]),
                                          mode='bilinear', align_corners=False)
                query_mask = F.interpolate(query_mask, size=(size[0], size[1]), mode='nearest')

                values, pred = model.get_pred(logits, query_img)
                evaluations.update_evl(idx, query_mask, pred, 0)

            # Guard against a zero-length interval on coarse clocks.
            ImgPerSec = 1 / max(time.time() - begin_time, 1e-12)
            print("It has tested %d, %.2f images/s" % (it * args.batch_size, ImgPerSec * args.batch_size), end="\r")
        print("Group %d: %.4f " % (args.group, evaluations.group_mean_iou[args.group]))

    iou = evaluations.iou_list
    print('IOU:', iou)
    mIoU = np.mean(iou)
    print('mIoU: ', mIoU)
    print("group0_iou", evaluations.group_mean_iou[0])
    print("group1_iou", evaluations.group_mean_iou[1])
    print("group2_iou", evaluations.group_mean_iou[2])
    print("group3_iou", evaluations.group_mean_iou[3])
    print(evaluations.group_mean_iou)

    return mIoU, iou


if __name__ == '__main__':
    args = get_arguments()
    print('Running parameters:\n')
    print(json.dumps(vars(args), indent=4, separators=(',', ':')))
    os.makedirs(args.snapshot_dir, exist_ok=True)
    val(args)

# ------------------------------------------------------------------------------
# /utils/NoteEvaluation.py:
# ------------------------------------------------------------------------------

import numpy as np


def measure(y_in, pred_in):
    """Count binary confusion-matrix entries after thresholding both inputs at 0.5.

    Args:
        y_in: ground-truth array (compared elementwise against 0.5).
        pred_in: prediction array of the same shape (compared against 0.5).

    Returns:
        (tp, tn, fp, fn) as NumPy integer counts.
    """
    thresh = .5
    y = y_in > thresh
    pred = pred_in > thresh
    tp = np.logical_and(y, pred).sum()
    tn = np.logical_and(np.logical_not(y), np.logical_not(pred)).sum()
    fp = np.logical_and(np.logical_not(y), pred).sum()
    fn = np.logical_and(y, np.logical_not(pred)).sum()
    return tp, tn, fp, fn


class Evaluation():
    """Accumulate per-class foreground IoU for the few-shot folds.

    ``iou_list[c]`` is the running IoU of class ``c`` (sum tp / sum(tp+fp+fn))
    and ``group_mean_iou[g]`` is the mean IoU over the classes of fold ``g``
    that have been observed so far.
    """

    def __init__(self, args):
        if args.dataset == 'coco':
            self.num_classes = 80
        elif args.dataset == 'voc':
            self.num_classes = 20
        else:
            # Fail here instead of with an AttributeError further down.
            raise ValueError("Unsupported dataset %r (expected 'voc' or 'coco')" % (args.dataset,))
        self.num_folds = 4
        self.group_class_num = self.num_classes / 4
        self.batch_size = args.batch_size
        self.disp_interval = args.disp_interval
        self.clear_num = 200
        self.group = args.group
        self.group_mean_iou = [0] * 4
        self.setup()

    def get_val_id_list(self):
        """Return the interleaved class ids of fold ``self.group`` (COCO split)."""
        num = int(self.num_classes / self.num_folds)
        val_set = [self.group + self.num_folds * v for v in range(num)]
        return val_set

    def setup(self):
        """Reset the per-class accumulators."""
        self.tp_list = [0] * self.num_classes
        self.total_list = [0] * self.num_classes
        self.iou_list = [0] * self.num_classes

    def update_class_index(self):
        """Select the class ids belonging to fold ``self.group``."""
        if self.num_classes == 80:
            self.class_indexes = self.get_val_id_list()
        if self.num_classes == 20:
            self.class_indexes = range(self.group * 5, (self.group + 1) * 5)

    def update_evl(self, idx, query_mask, pred, count):
        """Add one batch of predictions and refresh ``iou_list``/``group_mean_iou``.

        Args:
            idx: per-sample class ids (indexable, elements support ``.item()``).
            query_mask: ground-truth masks, indexed per sample.
            pred: predicted masks, indexed per sample.
            count: caller's step counter; the accumulators are reset when it
                equals ``self.clear_num``.
        """
        self.update_class_index()
        if count == self.clear_num:
            self.setup()

        # Iterate over the real batch: a loader's final batch can be smaller
        # than ``batch_size``, which used to raise IndexError.
        for i in range(len(idx)):
            class_id = idx[i].item()
            tp, total = self.test_in_train(query_mask[i], pred[i])
            self.tp_list[class_id] += tp
            self.total_list[class_id] += total

        self.iou_list = [self.tp_list[ic] / float(max(self.total_list[ic], 1))
                         for ic in range(self.num_classes)]

        # Average only over this fold's classes that have accumulated pixels.
        # The previous `l1.remove(0)` dropped a single zero, so every other
        # not-yet-seen class still pulled the fold mean towards 0.
        observed = [self.iou_list[c] for c in self.class_indexes if self.total_list[c] > 0]
        self.group_mean_iou[self.group] = np.mean(observed) if observed else 0.0

    def test_in_train(self, query_label, pred):
        """Return (tp, tp + fp + fn) for one sample.

        NOTE(review): labels are cast to int32 and thresholded at 0.5, so an
        ignore value such as 255 in the mask would count as foreground —
        confirm the masks only contain {0, 1}.
        """
        pred = pred.data.cpu().numpy()
        query_label = query_label.cpu().numpy().astype(np.int32)

        tp, tn, fp, fn = measure(query_label, pred)
        total = tp + fp + fn

        return tp, total


class note_best(object):
    """Track the best checkpoints seen while sweeping restore steps.

    * Per fold: ``best0``..``best3`` with ``best0_step``..``best3_step``,
      chosen independently; ``best_mean`` is their average.
    * Overall: the single step (``restore_step``) with the highest mean of
      the four fold IoUs, stored in ``BMVC_IOU`` together with that step's
      ``group0_iou``..``group3_iou``.
    """

    def __init__(self):
        self.init_independent()
        self.init_overall()

    def init_independent(self):
        self.best0 = 0
        self.best1 = 0
        self.best2 = 0
        self.best3 = 0
        self.best0_step = 0
        self.best1_step = 0
        self.best2_step = 0
        self.best3_step = 0
        self.best_mean = 0

    def init_overall(self):
        # Read by Save_Evaluations.update_best; previously never defined,
        # which made that call raise AttributeError.
        self.restore_step = 0
        self.BMVC_IOU = 0
        self.group0_iou = 0
        self.group1_iou = 0
        self.group2_iou = 0
        self.group3_iou = 0

    def update(self, mIou, restore_step, iou_list, evaluations):
        """Record the fold IoUs of ``restore_step`` (``mIou`` is not used)."""
        self.update_independent_fold(restore_step, iou_list, evaluations)
        self.update_overall(restore_step, evaluations)

    def update_overall(self, restore_step, evaluations):
        """Keep the step whose mean fold IoU is the highest so far."""
        groups = list(evaluations.group_mean_iou[:4])
        step_mean = sum(groups) / 4
        if step_mean > self.BMVC_IOU:
            self.BMVC_IOU = step_mean
            self.restore_step = restore_step
            self.group0_iou, self.group1_iou, self.group2_iou, self.group3_iou = groups

    def update_independent_fold(self, restore_step, iou_list, evaluations):
        g0 = evaluations.group_mean_iou[0]
        g1 = evaluations.group_mean_iou[1]
        g2 = evaluations.group_mean_iou[2]
        g3 = evaluations.group_mean_iou[3]

        if g0 > self.best0:
            self.best0 = g0
            self.best0_step = restore_step
        if g1 > self.best1:
            self.best1 = g1
            self.best1_step = restore_step
        if g2 > self.best2:
            self.best2 = g2
            self.best2_step = restore_step
        if g3 > self.best3:
            self.best3 = g3
            self.best3_step = restore_step
        self.best_mean = (self.best0 + self.best1 + self.best2 + self.best3) / 4

# ------------------------------------------------------------------------------
# /utils/NoteLoss.py:
# ------------------------------------------------------------------------------

import torch
import numpy as np


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.data = []

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the n=0 first call, which used to divide by zero.
        self.avg = self.sum / self.count if self.count else 0
        self.data.append(val)


class Loss_total():
    """Running meters for the total loss and two optional loss parts."""

    def __init__(self, args):
        self.total = AverageMeter()
        self.part1 = AverageMeter()
        self.part2 = AverageMeter()
        self.disp_interval = args.disp_interval

    @staticmethod
    def _to_number(value):
        """Convert a loss term to a Python number; unknown types log as 0."""
        if isinstance(value, torch.Tensor):
            return value.data.item()
        if isinstance(value, (int, float, np.integer, np.floating)):
            return float(value)
        return 0

    def updateloss(self, loss_val, loss_part1=0, loss_part2=0):
        """Record one step's losses (tensors or plain numbers for the parts)."""
        self.total.update(loss_val.data.item(), 1)
        # Plain numeric parts used to be logged as 0; they are now kept.
        self.part1.update(self._to_number(loss_part1), 1)
        self.part2.update(self._to_number(loss_part2), 1)

    def logloss(self, log_file):
        """Append ``'<count>, <latest total loss>'`` to ``log_file``."""
        count = self.total.count
        loss_val_float = self.total.val
        out_str = '%d, %.4f\n' % (count, loss_val_float)
        log_file.write(out_str)

# ------------------------------------------------------------------------------
# /utils/Restore.py:
# ------------------------------------------------------------------------------

import os
import shutil
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
from collections import OrderedDict


def _snapshot_filename(restore_step):
    """Map a restore step to its checkpoint file name.

    ``'best'`` -> ``best.pth.tar``; anything else is a step number (an int
    or a numeric string such as ``'5000'``) -> ``step_<n>.pth.tar``.
    """
    if restore_step == 'best':
        return '%s.pth.tar' % (restore_step,)
    return 'step_%d.pth.tar' % int(restore_step)


def restore(args, model):
    """Load fold ``args.group``'s checkpoint into ``model``.

    Reads ``<snapshot_dir>/<arch>/group_<g>_of_<num_folds>/<file>``. Every
    mapping stored in the checkpoint (e.g. the ``'state_dict'`` written by
    ``save_model``) is merged into one state dict. A leading ``'module.'``,
    as added by ``nn.DataParallel``, is removed from each key.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    savedir = os.path.join(args.snapshot_dir, args.arch, 'group_%d_of_%d' % (args.group, args.num_folds))
    snapshot = os.path.join(savedir, _snapshot_filename(args.restore_step))
    if not os.path.exists(snapshot):
        raise FileNotFoundError("Snapshot file %s does not exist." % (snapshot))
    checkpoint = torch.load(snapshot)

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for entry in checkpoint.values():
        if not hasattr(entry, 'items'):
            continue  # not a parameter mapping
        for key, value in entry.items():
            # Strip the DataParallel prefix only when present; the old
            # unconditional key[7:] mangled single-GPU checkpoints.
            if key.startswith(prefix):
                key = key[len(prefix):]
            new_state_dict[key] = value
    model.load_state_dict(new_state_dict)
    print('Loaded weights from %s' % (snapshot))


def get_model_para_number(model):
    """Return the total number of elements across ``model.parameters()``."""
    return sum(torch.numel(para) for para in model.parameters())


def get_save_dir(args):
    """Return ``<snapshot_dir>/<arch>/group_<g>_of_<num_folds>``."""
    snapshot_dir = os.path.join(args.snapshot_dir, args.arch, 'group_%d_of_%d' % (args.group, args.num_folds))
    return snapshot_dir


def save_checkpoint(args, state, is_best, filename='checkpoint.pth.tar'):
    """Save ``state`` in the fold directory; copy it to ``model_best.pth.tar`` if ``is_best``."""
    savedir = os.path.join(args.snapshot_dir, args.arch, 'group_%d_of_%d' % (args.group, args.num_folds))
    os.makedirs(savedir, exist_ok=True)

    savepath = os.path.join(savedir, filename)
    torch.save(state, savepath)
    if is_best:
        shutil.copyfile(savepath, os.path.join(args.snapshot_dir, 'model_best.pth.tar'))


def save_model(args, count, model, optimizer):
    """Every ``args.save_interval`` steps (``count > 0``), save ``{'state_dict': ...}``."""
    if count % args.save_interval == 0 and count > 0:
        save_checkpoint(args,
                        {
                            'state_dict': model.state_dict()
                        }, is_best=False,
                        filename='step_%d.pth.tar'
                        % (count))


class Save_Evaluations():
    """Collect per-step fold IoUs (rows Group0..Group3, Mean) and write a CSV."""

    def __init__(self, args):
        self.savedir = 'evaluation/' + args.arch + '_' + args.dataset + '_' + 'val.csv'
        self.Note_Iou = []
        self.col = []
        self.ind = ['Group0', 'Group1', 'Group2', 'Group3', 'Mean']
        self.dataset = args.dataset
        self.update_class_index()
        for i in range(5):
            self.Note_Iou.append([])

    def get_val_id_list(self, group):
        """Return the interleaved COCO class ids of fold ``group``."""
        num_classes = 80
        num_folds = 4
        num = int(num_classes / num_folds)
        val_set = [group + num_folds * v for v in range(num)]
        return val_set

    def update_class_index(self):
        """Build ``class_indexes``: one entry of class ids per fold."""
        self.class_indexes = []
        for group in range(4):
            if self.dataset == 'coco':
                self.class_indexes.append(self.get_val_id_list(group))
            if self.dataset == 'voc':
                # Append like the COCO branch; this used to overwrite the
                # list, leaving only the last fold's range.
                self.class_indexes.append(range(group * 5, (group + 1) * 5))

    def save_file(self):
        """Write the table to ``self.savedir``, creating its directory if needed."""
        directory = os.path.dirname(self.savedir)
        if directory:
            os.makedirs(directory, exist_ok=True)
        test = pd.DataFrame(data=self.Note_Iou, index=self.ind, columns=self.col)
        test.to_csv(self.savedir)

    def update_date(self, restore_step, mIoU, evaluations):
        """Append a column with the fold IoUs and ``mIoU`` of ``restore_step``."""
        self.col.append(restore_step)

        self.Note_Iou[0].append(evaluations.group_mean_iou[0])
        self.Note_Iou[1].append(evaluations.group_mean_iou[1])
        self.Note_Iou[2].append(evaluations.group_mean_iou[2])
        self.Note_Iou[3].append(evaluations.group_mean_iou[3])
        self.Note_Iou[4].append(mIoU)

    def update_best(self, Best_Note):
        """Append a column for the overall-best step tracked by ``Best_Note``."""
        self.col.append(Best_Note.restore_step)
        self.Note_Iou[0].append(Best_Note.group0_iou)
        self.Note_Iou[1].append(Best_Note.group1_iou)
        self.Note_Iou[2].append(Best_Note.group2_iou)
        self.Note_Iou[3].append(Best_Note.group3_iou)
        self.Note_Iou[4].append(Best_Note.BMVC_IOU)

    def update_best_eachgroup(self, Best_Note):
        """Append 'Best' (per-fold best IoU) and 'Best_Step' columns."""
        self.col.append('Best')
        self.Note_Iou[0].append(Best_Note.best0)
        self.Note_Iou[1].append(Best_Note.best1)
        self.Note_Iou[2].append(Best_Note.best2)
        self.Note_Iou[3].append(Best_Note.best3)
        self.Note_Iou[4].append(Best_Note.best_mean)

        self.col.append('Best_Step')
        self.Note_Iou[0].append(Best_Note.best0_step)
        self.Note_Iou[1].append(Best_Note.best1_step)
        self.Note_Iou[2].append(Best_Note.best2_step)
        self.Note_Iou[3].append(Best_Note.best3_step)
        self.Note_Iou[4].append(Best_Note.best_mean)

# ------------------------------------------------------------------------------
# /utils/Visualize.py:
# ------------------------------------------------------------------------------
import math
import cv2
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import skimage.morphology as skm
from config import settings


class visualize_loss_evl_train():
    """Console logger printing training loss, IoU and timing every N steps."""

    def __init__(self, args):
        self.disp_interval = args.disp_interval

    def visualize(self, args, count, Loss, Evaluation, begin_time):
        """Print one status line when ``count`` is a multiple of ``disp_interval``.

        Loss values are averaged over the last 100 entries of each meter's
        ``data`` list. The IoU shown is the mean of the strictly positive
        entries of ``Evaluation.iou_list``. The remaining time is the elapsed
        time since ``begin_time`` scaled by the display intervals left until
        ``args.max_steps``, in hours.
        """
        if count % self.disp_interval != 0:
            return

        def as_number(value):
            # Tensors are moved to the CPU as NumPy values; everything else is
            # passed through unchanged.
            return value.cpu().data.numpy() if isinstance(value, torch.Tensor) else value

        recent = slice(-100, None)
        avg_total = np.mean(Loss.total.data[recent])
        avg_part1 = np.mean(Loss.part1.data[recent])
        avg_part2 = np.mean(Loss.part2.data[recent])
        positive_ids = np.where(np.array(Evaluation.iou_list) > 0)
        mean_IOU = np.mean(np.take(Evaluation.iou_list, positive_ids))

        elapsed = time.time() - begin_time
        hours_left = elapsed * (args.max_steps - count) / self.disp_interval / 3600
        template = ('Group:%d \t Step:%d \t Loss:%.3f \t '
                    'Part1: %.3f \t Part2: %.3f \t mean_IOU: %.4f \t'
                    'Step time: %.4f s \t Remaining time: %.4f h')
        values = (args.group, count, avg_total,
                  as_number(avg_part1), as_number(avg_part2),
                  mean_IOU, elapsed, hours_left)
        print(args.arch, template % values)


def print_best(Best_Note):
    """Print the per-fold best IoUs of ``Best_Note`` and the steps they came from."""
    print("---------------------------------FINAL BEST RESULT ---------------------------------")
    print("best_ BMVC_IOU ", Best_Note.best_mean)
    labelled_attrs = (
        ("group0_iou", "best0"),
        ("group1_iou", "best1"),
        ("group2_iou", "best2"),
        ("group3_iou", "best3"),
        ("group0_step ", "best0_step"),
        ("group1_step ", "best1_step"),
        ("group2_step ", "best2_step"),
        ("group3_step ", "best3_step"),
    )
    for label, attr in labelled_attrs:
        print(label, getattr(Best_Note, attr))

# ------------------------------------------------------------------------------
# /utils/__pycache__/NoteEvaluation.cpython-36.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/utils/__pycache__/NoteEvaluation.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/NoteEvaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/utils/__pycache__/NoteEvaluation.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/NoteLoss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/utils/__pycache__/NoteLoss.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Restore.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/utils/__pycache__/Restore.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Restore.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/utils/__pycache__/Restore.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Visualize.cpython-36.pyc: 
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/utils/__pycache__/Visualize.cpython-36.pyc
# ------------------------------------------------------------------------------
# /utils/__pycache__/Visualize.cpython-37.pyc:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/utils/__pycache__/Visualize.cpython-37.pyc
# ------------------------------------------------------------------------------
# /utils/__pycache__/my_optim.cpython-36.pyc:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/utils/__pycache__/my_optim.cpython-36.pyc
# ------------------------------------------------------------------------------
# /utils/__pycache__/my_optim.cpython-37.pyc:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/caoql98/progressively-dual-prior-guided-few-shot-semantic-segmentation/aaf2e7d2dca4a52972c8bff46e3a4bc6d775a946/utils/__pycache__/my_optim.cpython-37.pyc
# ------------------------------------------------------------------------------
# /utils/my_optim.py:
# ------------------------------------------------------------------------------

import torch.optim as optim
import numpy as np

# Learning-rate multipliers of the four SGD parameter groups, in order:
# pretrained weights, pretrained biases, new weights, new biases.
_LR_MULTIPLIERS = (1, 2, 10, 20)


def get_finetune_optimizer(args, model):
    """Build SGD with per-group learning rates for fine-tuning.

    Parameters whose name contains 'model_res' or 'model_backbone' form the
    pretrained groups; all others form the new groups. Each side is split
    into weights and biases by name, and the groups get ``args.lr`` times
    ``_LR_MULTIPLIERS``. Parameters whose name contains neither 'weight' nor
    'bias' are not passed to the optimizer.
    """
    lr = args.lr
    pretrain_weight_list, pretrain_bias_list = [], []
    weight_list, bias_list = [], []
    for name, value in model.named_parameters():
        pretrained = 'model_res' in name or 'model_backbone' in name
        if 'weight' in name:
            (pretrain_weight_list if pretrained else weight_list).append(value)
        elif 'bias' in name:
            (pretrain_bias_list if pretrained else bias_list).append(value)

    param_lists = (pretrain_weight_list, pretrain_bias_list, weight_list, bias_list)
    groups = [{'params': params, 'lr': lr * mult}
              for params, mult in zip(param_lists, _LR_MULTIPLIERS)]
    opt = optim.SGD(groups, momentum=0.90, weight_decay=0.0005)  # momentum = 0.99

    return opt


def adjust_learning_rate_poly(args, optimizer, iter, power=0.9):
    """Set the poly-decayed lr: ``args.lr * (1 - iter / args.max_steps) ** power``.

    The first four param groups receive that value times ``_LR_MULTIPLIERS``.
    """
    base_lr = args.lr
    max_iter = args.max_steps
    # Clamp at 0: past max_steps, a negative base raised to a fractional power
    # evaluates to a complex number in Python, which is not a valid lr.
    reduce = max(0.0, 1 - float(iter) / max_iter) ** power
    lr = base_lr * reduce
    for group_id, mult in enumerate(_LR_MULTIPLIERS):
        optimizer.param_groups[group_id]['lr'] = lr * mult

# ------------------------------------------------------------------------------
# /utils/test.py:
# ------------------------------------------------------------------------------

import os
import cv2
import torch
import json
import numpy as np
import argparse
import time
import torch.nn.functional as F
from data.LoadDataSeg import val_loader
from utils import NoteEvaluation
from networks import *
from utils.Restore import restore

from config import settings


K_SHOT = 1
DATASET = 'voc'
SNAPSHOT_DIR = settings.SNAPSHOT_DIR
if DATASET == 'coco':
    SNAPSHOT_DIR = SNAPSHOT_DIR + '/coco'


GPU_ID = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID


def get_arguments():
    """Parse the command-line options of this evaluation script."""
    parser = argparse.ArgumentParser(description='OneShot')
    parser.add_argument("--arch", type=str, default='FRPMMs')
    parser.add_argument("--disp_interval", type=int, default=100)
    parser.add_argument("--snapshot_dir", type=str, default=SNAPSHOT_DIR)

    parser.add_argument("--group", type=int, default=0)
    parser.add_argument('--num_folds', type=int, default=4)
    parser.add_argument('--restore_step', type=int, default=100000)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--mode', type=str, default='val')
    parser.add_argument('--dataset', type=str, default=DATASET)

    return parser.parse_args()


def get_model(args):
    """Build ``<args.arch>.OneModel(args)`` on the GPU (``args.arch`` goes through ``eval``)."""
    model = eval(args.arch).OneModel(args)
    model = model.cuda()
    return model


def val(args):
    """Evaluate the model and print IoU summaries.

    Returns:
        (mIoU, iou, evaluations): the mean of ``evaluations.iou_list``, that
        list, and the ``NoteEvaluation.Evaluation`` instance itself.
    """
    model = get_model(args)
    model.eval()

    evaluations = NoteEvaluation.Evaluation(args)

    # NOTE(review): only fold 0 is evaluated here; the group1..group3 values
    # printed below therefore stay at their initial 0 — confirm intended.
    for group in range(0, 1):

        print("-------------GROUP %d-------------" % (group))

        args.group = group
        evaluations.group = args.group
        val_dataloader = val_loader(args, k_shot=K_SHOT)
        restore(args, model)

        for it, data in enumerate(val_dataloader, start=1):
            begin_time = time.time()
            query_img, query_mask, support_img, support_mask, idx, size = data

            query_img, query_mask, support_img, support_mask, idx \
                = query_img.cuda(), query_mask.cuda(), support_img.cuda(), support_mask.cuda(), idx.cuda()

            with torch.no_grad():
                # NOTE(review): the query mask is passed to the model at test
                # time, unlike the 1-shot script's model(query, support, mask)
                # call — confirm the model does not use it for the prediction.
                logits = model(query_img, query_mask, support_img, support_mask)
                # F.upsample is deprecated; F.interpolate with an explicit
                # align_corners=False keeps the same bilinear behaviour.
                query_img = F.interpolate(query_img, size=(size[0], size[1]),
                                          mode='bilinear', align_corners=False)
                query_mask = F.interpolate(query_mask, size=(size[0], size[1]), mode='nearest')
                values, pred = model.get_pred(logits, query_img)

                evaluations.update_evl(idx, query_mask, pred, 0)

            # Report images (not batches) per second; guard a zero interval.
            elapsed = max(time.time() - begin_time, 1e-12)
            print("It has tested %d, %.2f images/s"
                  % (it * args.batch_size, args.batch_size / elapsed), end="\r")
        print("Group %d: %.4f " % (args.group, evaluations.group_mean_iou[args.group]))

    iou = evaluations.iou_list
    print('IOU:', iou)
    mIoU = np.mean(iou)
    print('mIoU: ', mIoU)
    print("group0_iou", evaluations.group_mean_iou[0])
    print("group1_iou", evaluations.group_mean_iou[1])
    print("group2_iou", evaluations.group_mean_iou[2])
    print("group3_iou", evaluations.group_mean_iou[3])
    print(evaluations.group_mean_iou)

    return mIoU, iou, evaluations


if __name__ == '__main__':
    args = get_arguments()
    print('Running parameters:\n')
    print(json.dumps(vars(args), indent=4, separators=(',', ':')))
    os.makedirs(args.snapshot_dir, exist_ok=True)
    val(args)

# ------------------------------------------------------------------------------