├── Img
│   ├── motivation.png
│   ├── network.png
│   └── networkWSSS.png
├── README.md
├── WSOL
│   ├── Data
│   │   ├── CUB_200_2011
│   │   │   ├── images_list.txt
│   │   │   ├── test_bounding_box.txt
│   │   │   ├── test_gt.txt
│   │   │   └── test_list.txt
│   │   └── ILSVRC
│   │       ├── val_gt.txt
│   │       └── val_list.txt
│   ├── DataLoader
│   │   ├── CUB_200_2011.py
│   │   ├── ILSVRC.py
│   │   ├── OpenImage.py
│   │   ├── Standford_Car.py
│   │   ├── Stanford_Dogs.py
│   │   ├── __init__.py
│   │   └── fgvc_aircraft_2013b.py
│   ├── Model
│   │   ├── __init__.py
│   │   ├── inception.py
│   │   ├── mobilenet.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── count_pxap.py
│   ├── evaluator.py
│   ├── logs
│   │   ├── CUB_200_2011
│   │   │   ├── inception
│   │   │   │   └── save_weight.txt
│   │   │   ├── mobilenet
│   │   │   │   └── save_weight.txt
│   │   │   ├── resnet
│   │   │   │   └── save_weight.txt
│   │   │   └── vgg
│   │   │       └── save_weight.txt
│   │   ├── ILSVRC
│   │   │   ├── inception
│   │   │   │   └── save_weight.txt
│   │   │   ├── mobilenet
│   │   │   │   └── save_weight.txt
│   │   │   ├── resnet
│   │   │   │   └── save_weight.txt
│   │   │   └── vgg
│   │   │       └── save_weight.txt
│   │   └── OpenImage
│   │       └── resnet
│   │           └── save_weight.txt
│   ├── train.py
│   ├── train_ILSVRC.py
│   └── utils
│       ├── IoU.py
│       ├── accuracy.py
│       ├── augment.py
│       ├── func.py
│       ├── hyperparameters.py
│       ├── lr.py
│       ├── optimizer.py
│       └── util.py
└── WSSS
    ├── data
    │   ├── coco14
    │   │   ├── gen_lst.py
    │   │   ├── split_list.py
    │   │   ├── train.txt
    │   │   ├── train_cls.txt
    │   │   ├── val.txt
    │   │   └── val_cls.txt
    │   └── voc12
    │       ├── train.txt
    │       ├── train_aug.txt
    │       ├── train_cls.txt
    │       ├── val_aug.txt
    │       └── val_cls.txt
    ├── misc
    │   ├── imutils.py
    │   ├── indexing.py
    │   ├── pyutils.py
    │   └── torchutils.py
    ├── net
    │   ├── __init__.py
    │   ├── resnet50.py
    │   ├── resnet50_bas.py
    │   ├── resnet50_cam.py
    │   └── resnet50_irn.py
    ├── run_sample.py
    ├── step
    │   ├── cam_to_ir_label.py
    │   ├── eval_cam.py
    │   ├── eval_sem_seg.py
    │   ├── make_cam.py
    │   ├── make_sem_seg_labels.py
    │   ├── train_bas.py
    │   └── train_irn.py
    ├── tool
    │   ├── accuracy.py
    │   ├── func.py
    │   ├── imutils.py
    │   ├── optimizer.py
    │   ├── pyutils.py
    │   ├── torchutils.py
    │   ├── util.py
    │   └── visualization.py
    ├── train_bas.py
    └── voc12
        ├── cls_labels.npy
        ├── dataloader.py
        ├── make_cls_labels.py
        ├── test.txt
        ├── train.txt
        ├── train_aug.txt
        └── val.txt
/Img/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/Img/motivation.png -------------------------------------------------------------------------------- /Img/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/Img/network.png -------------------------------------------------------------------------------- /Img/networkWSSS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/Img/networkWSSS.png -------------------------------------------------------------------------------- /WSOL/DataLoader/CUB_200_2011.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | import numpy as np 8 | import torch.utils.data as data 9 | from torchvision.datasets import ImageFolder 10 | 11 | def load_test_bbox(root, test_txt_path, test_gt_path,resize_size, crop_size): 12 | test_gt = [] 13 | test_txt = [] 14 | shift_size = (resize_size - crop_size) // 2.
15 | with open(test_txt_path, 'r') as f: 16 | for line in f: 17 | img_path = line.strip('\n').split(';')[0] 18 | test_txt.append(img_path) 19 | with open(test_gt_path, 'r') as f: 20 | for line in f: 21 | x0, y0, x1, y1, h, w = line.strip('\n').split(' ') 22 | x0, y0, x1, y1, h, w = float(x0), float(y0), float(x1), float(y1), float(h), float(w) 23 | x0 = int(max(x0 / w * resize_size - shift_size, 0)) 24 | y0 = int(max(y0 / h * resize_size - shift_size, 0)) 25 | x1 = int(min(x1 / w * resize_size - shift_size, crop_size - 1)) 26 | y1 = int(min(y1 / h * resize_size - shift_size, crop_size - 1)) 27 | test_gt.append(np.array([x0, y0, x1, y1]).reshape(-1)) 28 | final_dict = {} 29 | for k, v in zip(test_txt, test_gt): 30 | k = os.path.join(root, 'test', k) 31 | k = k.replace('/', '\\') 32 | final_dict[k] = v 33 | return final_dict 34 | 35 | class ImageDataset(data.Dataset): 36 | def __init__(self, args ,phase=None): 37 | args.num_classes = 200 38 | self.args =args 39 | self.root = 'Data/' + args.root 40 | self.test_txt_path = self.root + '/' + 'test_list.txt' 41 | self.test_gt_path = self.root + '/' + 'test_bounding_box.txt' 42 | self.crop_size = args.crop_size 43 | self.resize_size = args.resize_size 44 | self.phase = args.phase if phase == None else phase 45 | self.num_classes = args.num_classes 46 | self.tencrop = args.tencrop 47 | if self.phase == 'train': 48 | self.img_dataset = ImageFolder(os.path.join(self.root, 'train')) 49 | self.transform = transforms.Compose([ 50 | transforms.Resize((self.resize_size, self.resize_size)), 51 | transforms.RandomCrop(self.crop_size), 52 | transforms.RandomHorizontalFlip(), 53 | transforms.ToTensor(), 54 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 55 | ]) 56 | elif self.phase == 'test': 57 | self.img_dataset = ImageFolder(os.path.join(self.root, 'test')) 58 | self.transform = transforms.Compose([ 59 | transforms.Resize((self.resize_size, self.resize_size)), 60 | transforms.CenterCrop(self.crop_size), 61 | transforms.ToTensor(), 62 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 63 | ]) 64 | self.transform_tencrop = transforms.Compose([ 65 | transforms.Resize((self.resize_size, self.resize_size)), 66 | transforms.TenCrop(args.crop_size), 67 | transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), 68 | transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(crop) for crop in crops])), 69 | ]) 70 | 71 | self.label_classes = [] 72 | for k, v in self.img_dataset.class_to_idx.items(): 73 | self.label_classes.append(k) 74 | self.img_dataset = self.img_dataset.imgs 75 | self.test_bbox = load_test_bbox(self.root, self.test_txt_path, self.test_gt_path,self.resize_size, self.crop_size) 76 | 77 | def __getitem__(self, index): 78 | path, img_class = self.img_dataset[index] 79 | label = torch.zeros(self.num_classes) 80 | label[img_class] = 1 81 | img = Image.open(path).convert('RGB') 82 | img_trans = self.transform(img) 83 | if self.phase == 'train': 84 | return path, img_trans, img_class 85 | else: 86 | path = path.replace('/', '\\') 87 | bbox = self.test_bbox[path] 88 | if self.tencrop: 89 | img_tencrop = self.transform_tencrop(img) 90 | return img_trans, img_tencrop, img_class, bbox, path 91 | return img_trans, img_trans, img_class, bbox, path 92 | 93 | def __len__(self): 94 | return len(self.img_dataset) 95 | -------------------------------------------------------------------------------- 
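load_test_bbox above maps each ground-truth box from the original h x w image into the coordinate frame of the resized, center-cropped test image: coordinates are scaled to resize_size, shifted by the center-crop offset, and clamped to the crop. A minimal sketch of that mapping (rescale_box is an illustrative helper, not a function from this repository; the 256/224 defaults follow the training scripts):

# Illustrative helper (not in the repository): map a box annotated on the original
# h x w image into the frame of a resize_size-resized, crop_size-center-cropped image.
def rescale_box(x0, y0, x1, y1, h, w, resize_size=256, crop_size=224):
    shift = (resize_size - crop_size) // 2   # center-crop offset, (256 - 224) // 2 = 16
    x0 = int(max(x0 / w * resize_size - shift, 0))
    y0 = int(max(y0 / h * resize_size - shift, 0))
    x1 = int(min(x1 / w * resize_size - shift, crop_size - 1))
    y1 = int(min(y1 / h * resize_size - shift, crop_size - 1))
    return x0, y0, x1, y1

# Example: a box (50, 60, 350, 260) on a 400 x 500 (h x w) image
# ends up at (9, 22, 163, 150) in 224 x 224 crop coordinates.
print(rescale_box(50, 60, 350, 260, h=400, w=500))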
/WSOL/DataLoader/ILSVRC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | import numpy as np 8 | import torch.utils.data as data 9 | from torchvision.datasets import ImageFolder 10 | 11 | def load_test_bbox(root, test_gt_path,crop_size,resize_size): 12 | test_gt = [] 13 | test_txt = [] 14 | shift_size = (resize_size - crop_size) // 2 15 | with open(test_gt_path, 'r') as f: 16 | 17 | for line in f: 18 | temp_gt = [] 19 | part_1, part_2 = line.strip('\n').split(';') 20 | img_path, w, h, _ = part_1.split(' ') 21 | part_2 = part_2[1:] 22 | bbox = part_2.split(' ') 23 | bbox = np.array(bbox, dtype=np.float32) 24 | box_num = len(bbox) // 4 25 | w, h = np.float32(w),np.float32(h) 26 | for i in range(box_num): 27 | bbox[4*i] = int(max(bbox[4*i] / w * resize_size - shift_size, 0)) 28 | bbox[4*i+1] = int(max(bbox[4*i+1] / h * resize_size - shift_size, 0)) 29 | bbox[4*i+2] = int(min(bbox[4*i+2] / w * resize_size - shift_size, crop_size - 1)) 30 | bbox[4*i+3] = int(min(bbox[4*i+3] / h * resize_size - shift_size, crop_size - 1)) 31 | temp_gt.append([bbox[4*i], bbox[4*i+1], bbox[4*i+2], bbox[4*i+3]]) 32 | test_gt.append(temp_gt) 33 | img_path = img_path.replace("\\\\","\\") 34 | test_txt.append(img_path) 35 | final_dict = {} 36 | for k, v in zip(test_txt, test_gt): 37 | k = os.path.join(root, 'val', k) 38 | k = k.replace('/', '\\') 39 | final_dict[k] = v 40 | return final_dict 41 | 42 | class ImageDataset(data.Dataset): 43 | def __init__(self, args, phase=None): 44 | args.num_classes = 1000 45 | self.args =args 46 | self.root = '/media/data/imagenet-1k' 47 | self.test_txt_path = 'Data/ILSVRC/val_list.txt' 48 | self.test_gt_path = 'Data/ILSVRC/val_gt.txt' 49 | self.crop_size = args.crop_size 50 | self.resize_size = args.resize_size 51 | self.phase = args.phase if phase == None else phase 52 | self.num_classes = args.num_classes 53 | self.tencrop = args.tencrop 54 | if self.phase == 'train': 55 | self.img_dataset = ImageFolder(os.path.join(self.root, 'train')) 56 | self.transform = transforms.Compose([ 57 | transforms.Resize((self.resize_size, self.resize_size)), 58 | transforms.RandomCrop(self.crop_size), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ToTensor(), 61 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 62 | ]) 63 | elif self.phase == 'test': 64 | self.img_dataset = ImageFolder(os.path.join(self.root, 'val')) 65 | self.transform = transforms.Compose([ 66 | transforms.Resize((self.resize_size, self.resize_size)), 67 | transforms.CenterCrop(self.crop_size), 68 | transforms.ToTensor(), 69 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 70 | ]) 71 | self.transform_tencrop = transforms.Compose([ 72 | transforms.Resize((self.resize_size, self.resize_size)), 73 | transforms.TenCrop(args.crop_size), 74 | transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), 75 | transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(crop) for crop in crops])), 76 | ]) 77 | self.label_classes = [] 78 | for k, v in self.img_dataset.class_to_idx.items(): 79 | self.label_classes.append(k) 80 | self.img_dataset = self.img_dataset.imgs 81 | self.test_bbox = load_test_bbox(self.root, self.test_gt_path,self.crop_size,self.resize_size) 82 | 83 | def 
__getitem__(self, index): 84 | path, img_class = self.img_dataset[index] 85 | 86 | label = torch.zeros(self.num_classes) 87 | label[img_class] = 1 88 | img = Image.open(path).convert('RGB') 89 | img_trans = self.transform(img) 90 | if self.phase == 'train': 91 | return path, img_trans, img_class 92 | else: 93 | path = path.replace('/', '\\') 94 | bbox = self.test_bbox[path] 95 | bbox = np.array(bbox).reshape(-1) 96 | bbox = " ".join(list(map(str, bbox))) 97 | if self.tencrop: 98 | img_tencrop = self.transform_tencrop(img) 99 | return img_trans, img_tencrop, img_class, bbox, path 100 | return img_trans, img_trans, img_class, bbox, path 101 | 102 | def __len__(self): 103 | return len(self.img_dataset) 104 | -------------------------------------------------------------------------------- /WSOL/DataLoader/OpenImage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | import numpy as np 8 | import torch.utils.data as data 9 | from torchvision.datasets import ImageFolder 10 | 11 | def load_img_list(image_list_path): 12 | img_dict = [] 13 | with open(image_list_path, 'r') as f: 14 | for line in f: 15 | img_name, classes = line[:-1].split(',') 16 | img_dict.append(img_name + ';' + str(classes)) 17 | return img_dict 18 | 19 | class ImageDataset(data.Dataset): 20 | def __init__(self, args, phase=None): 21 | args.num_classes = 100 22 | self.args =args 23 | self.root = '/home2/Datasets/' + args.root 24 | self.classes_list = self.root + '/' + 'classes.txt' 25 | self.phase = args.phase if phase == None else phase 26 | if self.phase == 'train': 27 | self.image_list_path = self.root + '/' + 'train_images.txt' 28 | elif self.phase == 'test': 29 | self.image_list_path = self.root + '/' + 'test_images.txt' 30 | self.crop_size = args.crop_size 31 | self.resize_size = args.resize_size 32 | self.num_classes = args.num_classes 33 | if self.phase == 'train': 34 | self.transform = transforms.Compose([ 35 | transforms.Resize((self.resize_size, self.resize_size)), 36 | transforms.RandomCrop(self.crop_size), 37 | transforms.RandomHorizontalFlip(), 38 | transforms.ToTensor(), 39 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 40 | ]) 41 | elif self.phase == 'test': 42 | self.transform = transforms.Compose([ 43 | transforms.Resize((self.crop_size, self.crop_size)), 44 | transforms.ToTensor(), 45 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 46 | ]) 47 | self.label_classes = {} 48 | with open(self.classes_list, 'r') as f: 49 | for i, line in enumerate(f): 50 | self.label_classes[line[:-1]] = i 51 | self.img_list = load_img_list(self.image_list_path) 52 | 53 | def __getitem__(self, index): 54 | if self.phase == 'train': 55 | img_name, img_class = self.img_list[index].split(';') 56 | path = self.root + '/' + img_name 57 | img = Image.open(path).convert('RGB') 58 | img = self.transform(img) 59 | img_class = int(img_class) 60 | return path, img, img_class 61 | else: 62 | img_name, img_class = self.img_list[index].split(';') 63 | path = self.root + '/' + img_name 64 | img = Image.open(path).convert('RGB') 65 | img = self.transform(img) 66 | img_class = int(img_class) 67 | return img, img, img_class, img, path ## without bbox 68 | 69 | def __len__(self): 70 | return len(self.img_list) -------------------------------------------------------------------------------- 
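Each loader above exposes the same ImageDataset(args, phase) interface and is selected by name with eval(args.root) in the training scripts. Below is a minimal usage sketch, assuming the CUB images and annotation lists sit under Data/CUB_200_2011 as the loader expects; the argparse fields mirror the parsers in train.py / train_ILSVRC.py and the values are illustrative only:

# Minimal usage sketch (values illustrative; assumes Data/CUB_200_2011 is in place).
import argparse
import torch
from DataLoader import CUB_200_2011

args = argparse.Namespace(root='CUB_200_2011', phase='train', crop_size=224,
                          resize_size=256, num_classes=200, tencrop=False)

train_set = CUB_200_2011.ImageDataset(args, phase='train')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True,
                                           num_workers=2, pin_memory=True)

# In the 'train' phase __getitem__ returns (path, transformed image, class index).
for path, imgs, label in train_loader:
    print(imgs.shape, label.shape)   # torch.Size([32, 3, 224, 224]) torch.Size([32])
    break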
/WSOL/DataLoader/Standford_Car.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | import numpy as np 8 | import torch.utils.data as data 9 | from torchvision.datasets import ImageFolder 10 | 11 | def bbox_crop(bbox, resize_size, crop_size): 12 | #resize_size = crop_size 13 | shift_size = (resize_size - crop_size) // 2. 14 | x0, y0, x1, y1, h, w = bbox.split(' ') 15 | x0, y0, x1, y1, h, w = float(x0), float(y0), float(x1), float(y1), float(h), float(w) 16 | x0 = int(max(x0 / w * resize_size - shift_size, 0)) 17 | y0 = int(max(y0 / h * resize_size - shift_size, 0)) 18 | x1 = int(min(x1 / w * resize_size - shift_size, crop_size - 1)) 19 | y1 = int(min(y1 / h * resize_size - shift_size, crop_size - 1)) 20 | return np.array([x0, y0, x1, y1]).reshape(-1) 21 | 22 | def load_test_bbox(image_list, label_classes, phase): 23 | img_dict = [] 24 | with open(image_list, 'r') as f: 25 | for line in f: 26 | img_name, classes, img_type, bbox = line[:-1].split(';') 27 | if img_type == '0' and phase == 'train': 28 | img_dict.append(img_name + ';' + str(label_classes[classes])) 29 | elif img_type == '1' and phase == 'test': ## test 30 | img_dict.append(img_name + ';' + str(label_classes[classes]) + ';' + bbox) 31 | return img_dict 32 | class ImageDataset(data.Dataset): 33 | def __init__(self, args, phase): 34 | args.num_classes = 196 35 | self.args =args 36 | self.root = 'Data/' + args.root 37 | self.image_list = self.root + '/' + 'images_box.txt' 38 | self.classes_list = self.root + '/' + 'classes.txt' 39 | self.crop_size = args.crop_size 40 | self.resize_size = args.resize_size 41 | self.phase = phase 42 | self.num_classes = args.num_classes 43 | if self.phase == 'train': 44 | self.transform = transforms.Compose([ 45 | transforms.Resize((self.resize_size, self.resize_size)), 46 | transforms.RandomCrop(self.crop_size), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 50 | ]) 51 | elif self.phase == 'test': 52 | self.transform = transforms.Compose([ 53 | transforms.Resize((self.resize_size, self.resize_size)), 54 | transforms.CenterCrop(self.crop_size), 55 | transforms.ToTensor(), 56 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 57 | ]) 58 | self.label_classes = {} 59 | with open(self.classes_list, 'r') as f: 60 | for i, line in enumerate(f): 61 | self.label_classes[line[:-1]] = i 62 | self.index_list = load_test_bbox(self.image_list, self.label_classes, self.phase) 63 | 64 | def __getitem__(self, index): 65 | if self.phase == 'train': 66 | img_name, img_class = self.index_list[index].split(';') 67 | path = self.root + '/images/' + img_name + '.jpg' 68 | img = Image.open(path).convert('RGB') 69 | img = self.transform(img) 70 | img_class = int(img_class) 71 | return path, img, img_class 72 | else: 73 | img_name, img_class, bbox = self.index_list[index].split(';') 74 | path = self.root + '/images/' + img_name + '.jpg' 75 | img = Image.open(path).convert('RGB') 76 | img = self.transform(img) 77 | bbox = bbox_crop(bbox, self.resize_size, self.crop_size) 78 | img_class = int(img_class) 79 | return img, img_class, bbox, path 80 | 81 | def __len__(self): 82 | return len(self.index_list) -------------------------------------------------------------------------------- 
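Standford_Car.py above (and fgvc_aircraft_2013b.py below) read a single images_box.txt index in which every line packs the image name, class name, split flag, and box into one ';'-separated record; the flag '0' marks training images and '1' test images (the aircraft loader also treats '2' as training). The file itself is not part of this listing, so the record below is a hypothetical example of the format implied by load_test_bbox and bbox_crop:

# Hypothetical example record (images_box.txt is not included in this listing):
#   <image_name>;<class_name>;<split_flag>;<x0 y0 x1 y1 h w>
line = "000123;Audi TT Coupe 2012;1;30.0 52.0 610.0 380.0 400.0 640.0\n"

img_name, classes, img_type, bbox = line[:-1].split(';')
x0, y0, x1, y1, h, w = map(float, bbox.split(' '))
print(img_name, img_type, (x0, y0, x1, y1), (h, w))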
/WSOL/DataLoader/Stanford_Dogs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | import numpy as np 8 | import torch.utils.data as data 9 | from torchvision.datasets import ImageFolder 10 | 11 | def load_test_bbox(root, test_txt_path, test_gt_path,resize_size, crop_size): 12 | test_gt = [] 13 | test_txt = [] 14 | #resize_size = crop_size 15 | shift_size = (resize_size - crop_size) // 2. 16 | with open(test_txt_path, 'r') as f: 17 | for line in f: 18 | img_path = line.strip('\n').split(';')[0] 19 | test_txt.append(img_path) 20 | with open(test_gt_path, 'r') as f: 21 | for line in f: 22 | cur_box = [] 23 | box_num = (len(line.strip('\n').split(' ')) - 2) // 4 24 | h, w = float(line.strip('\n').split(' ')[-2]), float(line.strip('\n').split(' ')[-1]) 25 | for i in range(box_num): 26 | x0, y0, x1, y1 = line.strip('\n').split(' ')[0+4*i:4*(1+i)] 27 | x0, y0, x1, y1 = float(x0), float(y0), float(x1), float(y1) 28 | x0 = int(max(x0 / w * resize_size - shift_size, 0)) 29 | y0 = int(max(y0 / h * resize_size - shift_size, 0)) 30 | x1 = int(min(x1 / w * resize_size - shift_size, crop_size - 1)) 31 | y1 = int(min(y1 / h * resize_size - shift_size, crop_size - 1)) 32 | cur_box.append([x0, y0, x1, y1]) 33 | test_gt.append(np.array(cur_box)) 34 | final_dict = {} 35 | for k, v in zip(test_txt, test_gt): 36 | k = os.path.join(root, 'test', k) 37 | k = k.replace('/', '\\') 38 | final_dict[k] = v 39 | return final_dict 40 | 41 | class ImageDataset(data.Dataset): 42 | def __init__(self, args, phase): 43 | args.num_classes = 120 44 | self.args =args 45 | self.root = 'Data/' + args.root 46 | self.test_txt_path = self.root + '/' + 'test_list.txt' 47 | self.test_gt_path = self.root + '/' + 'test_gt.txt' 48 | self.crop_size = args.crop_size 49 | self.resize_size = args.resize_size 50 | self.phase = phase 51 | self.num_classes = args.num_classes 52 | if self.phase == 'train': 53 | self.img_dataset = ImageFolder(os.path.join(self.root, 'train')) 54 | self.transform = transforms.Compose([ 55 | transforms.Resize((self.resize_size, self.resize_size)), 56 | transforms.RandomCrop(self.crop_size), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 60 | ]) 61 | elif self.phase == 'test': 62 | self.img_dataset = ImageFolder(os.path.join(self.root, 'test')) 63 | self.transform = transforms.Compose([ 64 | transforms.Resize((self.resize_size, self.resize_size)), 65 | transforms.CenterCrop(self.crop_size), 66 | transforms.ToTensor(), 67 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 68 | ]) 69 | 70 | self.label_classes = [] 71 | for k, v in self.img_dataset.class_to_idx.items(): 72 | self.label_classes.append(k) 73 | self.img_dataset = self.img_dataset.imgs 74 | self.test_bbox = load_test_bbox(self.root, self.test_txt_path, self.test_gt_path,self.resize_size, self.crop_size) 75 | 76 | def __getitem__(self, index): 77 | path, img_class = self.img_dataset[index] 78 | label = torch.zeros(self.num_classes) 79 | label[img_class] = 1 80 | img = Image.open(path).convert('RGB') 81 | img = self.transform(img) 82 | if self.phase == 'train': 83 | return path, img, img_class 84 | else: 85 | path = path.replace('/', '\\') 86 | bbox = self.test_bbox[path] 87 | bbox = np.array(bbox).reshape(-1) 88 | bbox = " 
".join(list(map(str, bbox))) 89 | return img, img_class, bbox, path 90 | 91 | def __len__(self): 92 | return len(self.img_dataset) -------------------------------------------------------------------------------- /WSOL/DataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | __all__ = ['CUB_200_2011', 4 | 'OpenImage', 5 | 'fgvc_aircraft_2013b', 6 | 'Stanford_Dogs', 7 | 'Standford_Car', 8 | 'ILSVRC'] -------------------------------------------------------------------------------- /WSOL/DataLoader/fgvc_aircraft_2013b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | import numpy as np 8 | import torch.utils.data as data 9 | from torchvision.datasets import ImageFolder 10 | 11 | def bbox_crop(bbox, resize_size, crop_size): 12 | shift_size = (resize_size - crop_size) // 2. 13 | x0, y0, x1, y1, h, w = bbox.split(' ') 14 | x0, y0, x1, y1, h, w = float(x0), float(y0), float(x1), float(y1), float(h), float(w) 15 | x0 = int(max(x0 / w * resize_size - shift_size, 0)) 16 | y0 = int(max(y0 / h * resize_size - shift_size, 0)) 17 | x1 = int(min(x1 / w * resize_size - shift_size, crop_size - 1)) 18 | y1 = int(min(y1 / h * resize_size - shift_size, crop_size - 1)) 19 | return np.array([x0, y0, x1, y1]).reshape(-1) 20 | 21 | def load_test_bbox(image_list, label_classes, phase): 22 | img_dict = [] 23 | with open(image_list, 'r') as f: 24 | for line in f: 25 | img_name, classes, img_type, bbox = line[:-1].split(';') 26 | if (img_type == '0' or img_type == '2') and phase == 'train': 27 | img_dict.append(img_name + ';' + str(label_classes[classes])) 28 | elif img_type == '1' and phase == 'test': ## test 29 | img_dict.append(img_name + ';' + str(label_classes[classes]) + ';' + bbox) 30 | return img_dict 31 | 32 | class ImageDataset(data.Dataset): 33 | def __init__(self, args): 34 | args.num_classes = 100 35 | self.args =args 36 | self.root = 'Data/' + args.root 37 | self.image_list = self.root + '/' + 'images_box.txt' 38 | self.classes_list = self.root + '/' + 'classes.txt' 39 | self.crop_size = args.crop_size 40 | self.resize_size = args.resize_size 41 | self.phase = args.phase 42 | self.num_classes = args.num_classes 43 | if self.phase == 'train': 44 | self.transform = transforms.Compose([ 45 | transforms.Resize((self.resize_size, self.resize_size)), 46 | transforms.RandomCrop(self.crop_size), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 50 | ]) 51 | elif self.phase == 'test': 52 | self.transform = transforms.Compose([ 53 | transforms.Resize((self.resize_size, self.resize_size)), 54 | transforms.CenterCrop(self.crop_size), 55 | transforms.ToTensor(), 56 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 57 | ]) 58 | self.label_classes = {} 59 | with open(self.classes_list, 'r') as f: 60 | for i, line in enumerate(f): 61 | self.label_classes[line[:-1]] = i 62 | self.index_list = load_test_bbox(self.image_list, self.label_classes, self.phase) 63 | 64 | def __getitem__(self, index): 65 | if self.phase == 'train': 66 | img_name, img_class = self.index_list[index].split(';') 67 | path = self.root + '/images/' + img_name + '.jpg' 68 | img = Image.open(path).convert('RGB') 69 | img = self.transform(img) 70 | img_class = 
int(img_class) 71 | return path, img, img_class 72 | else: 73 | img_name, img_class, bbox = self.index_list[index].split(';') 74 | path = self.root + '/images/' + img_name + '.jpg' 75 | img = Image.open(path).convert('RGB') 76 | img = self.transform(img) 77 | bbox = bbox_crop(bbox, self.resize_size, self.crop_size) 78 | img_class = int(img_class) 79 | return img, img_class, bbox, path 80 | 81 | def __len__(self): 82 | return len(self.index_list) -------------------------------------------------------------------------------- /WSOL/Model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | __all__ = ['vgg', 4 | 'resnet', 5 | 'inception', 6 | 'mobilenet',] -------------------------------------------------------------------------------- /WSOL/Model/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import cv2 7 | from skimage import measure 8 | from utils.func import * 9 | import torchvision.models as models 10 | from torch.nn import init 11 | 12 | class Model(nn.Module): 13 | def __init__(self, args): 14 | super(Model, self).__init__() 15 | def conv_bn(inp, oup, stride): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 18 | nn.BatchNorm2d(oup), 19 | nn.ReLU(inplace=True) 20 | ) 21 | def conv_dw(inp, oup, stride): 22 | return nn.Sequential( 23 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 24 | nn.BatchNorm2d(inp), 25 | nn.ReLU(inplace=True), 26 | 27 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 28 | nn.BatchNorm2d(oup), 29 | nn.ReLU(inplace=True), 30 | ) 31 | self.model = nn.Sequential( 32 | conv_bn( 3, 32, 2), ## ->112 33 | conv_dw( 32, 64, 1), 34 | conv_dw( 64, 128, 2), ## ->56 35 | conv_dw(128, 128, 1), 36 | conv_dw(128, 256, 2), ## ->28 37 | conv_dw(256, 256, 1), 38 | conv_dw(256, 512, 1), 39 | conv_dw(512, 512, 1), 40 | conv_dw(512, 512, 1), 41 | conv_dw(512, 512, 1), 42 | conv_dw(512, 512, 1), 43 | conv_dw(512, 512, 1), 44 | conv_dw(512, 1024, 2), ## ->14 45 | conv_dw(1024, 1024, 1), 46 | ) 47 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 48 | self.classifier_cls = nn.Sequential( 49 | nn.Conv2d(1024, 1024, kernel_size=3, padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(1024, 1024, kernel_size=3, padding=1), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(1024, args.num_classes, kernel_size=1, padding=0), 54 | ) 55 | self.classifier_loc = nn.Sequential( 56 | nn.Conv2d(512, args.num_classes, kernel_size=3, padding=1), 57 | nn.Sigmoid(), 58 | ) 59 | self._initialize_weights() 60 | 61 | def forward(self, x, label=None, topk=1, evaluate=False): 62 | batch = x.size(0) 63 | 64 | x = self.model[:-2](x) 65 | x_4 = x.clone() 66 | x = self.model[-2:](x) 67 | x = self.classifier_cls(x) 68 | self.feature_map = x 69 | 70 | ## score 71 | x = self.avg_pool(x).view(x.size(0), -1) 72 | self.score_1 = x 73 | 74 | ## p_label 75 | if topk == 1: 76 | p_label = label.unsqueeze(-1) 77 | else: 78 | _, p_label = self.score_1.topk(topk, 1, True, True) 79 | 80 | ## S 81 | self.S = torch.zeros(batch).cuda() 82 | for i in range(batch): 83 | self.S[i] = self.score_1[i][label[i]] 84 | 85 | ## M_fg 86 | M = self.classifier_loc(x_4) 87 | M_fg = torch.zeros(batch, 1, 28, 28).cuda() 88 | for i in range(batch): 89 | M_fg[i][0] = M[i][p_label[i]].mean(0) 90 | self.M_fg = M_fg 91 | 92 | if evaluate: 93 | return self.score_1, 
None, None, self.M_fg 94 | 95 | ## weight_copy 96 | erase_branch = copy.deepcopy(self.model[-2:]) 97 | classifier_cls_copy = copy.deepcopy(self.classifier_cls) 98 | 99 | ## erase 100 | x_erase = x_4.clone().detach() * (1-M_fg) 101 | x_erase = erase_branch(x_erase) 102 | x_erase = classifier_cls_copy(x_erase) 103 | x_erase = self.avg_pool(x_erase).view(x_erase.size(0), -1) 104 | 105 | ## S_bg 106 | self.S_bg = torch.zeros(batch).cuda() 107 | for i in range(batch): 108 | self.S_bg[i] = x_erase[i][label[i]] 109 | 110 | ## score_2 111 | x = self.feature_map * nn.AvgPool2d(2)(self.M_fg) 112 | self.score_2 = self.avg_pool(x).squeeze(-1).squeeze(-1) 113 | 114 | ## bas 115 | S = nn.ReLU()(self.S.clone().detach()) 116 | S_bg = nn.ReLU()(self.S_bg).clone() 117 | bas = S_bg / (S+1e-8) 118 | bas[S_bg>S] = 1 119 | 120 | ## area 121 | area = self.M_fg.clone().view(batch, -1).mean(1) 122 | 123 | return self.score_1, self.score_2, bas, area, self.M_fg 124 | 125 | def _initialize_weights(self): 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | nn.init.xavier_uniform_(m.weight.data) 129 | if m.bias is not None: 130 | m.bias.data.zero_() 131 | elif isinstance(m, nn.BatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.Linear): 135 | m.weight.data.normal_(0, 0.01) 136 | m.bias.data.zero_() 137 | 138 | 139 | def model(args, pretrained=True): 140 | 141 | model = Model(args) 142 | if pretrained: 143 | pretrained_dict = torch.load('pretrained_weight/mobilenet_v1_with_relu_69_5.pth') 144 | model_dict = model.state_dict() 145 | model_conv_name = [] 146 | 147 | for i, (k, v) in enumerate(model_dict.items()): 148 | if 'tracked' in k[-7:]: 149 | continue 150 | model_conv_name.append(k) 151 | for i, (k, v) in enumerate(pretrained_dict.items()): 152 | model_dict[model_conv_name[i]] = v 153 | model.load_state_dict(model_dict) 154 | print("pretrained weight load complete..") 155 | return model -------------------------------------------------------------------------------- /WSOL/Model/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.model_zoo import load_url 5 | from torch.autograd import Variable 6 | from utils.util import * 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import cv2 10 | import copy 11 | import matplotlib.pyplot as plt 12 | 13 | class Bottleneck(nn.Module): 14 | expansion = 4 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None, 17 | base_width=64): 18 | super(Bottleneck, self).__init__() 19 | width = int(planes * (base_width / 64.)) 20 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(width) 22 | self.conv2 = nn.Conv2d(width, width, 3, 23 | stride=stride, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(width) 25 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | identity = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv3(out) 43 | out = self.bn3(out) 44 | 45 | if self.downsample is not None: 46 | identity = self.downsample(x) 47 | 48 | out += identity 49 | out = self.relu(out) 50 | 51 | return 
out 52 | 53 | 54 | class ResNetCam(nn.Module): 55 | def __init__(self, block, layers, args, num_classes=1000, 56 | large_feature_map=True): 57 | super(ResNetCam, self).__init__() 58 | self.args = args 59 | stride_l3 = 1 if large_feature_map else 2 60 | self.inplanes = 64 61 | 62 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, 63 | padding=3, bias=False) 64 | self.bn1 = nn.BatchNorm2d(self.inplanes) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 67 | 68 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 69 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 70 | self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_l3) 71 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 72 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 73 | 74 | self.classifier_cls = nn.Sequential( 75 | nn.Conv2d(2048, 1024, kernel_size=3, padding=1), 76 | nn.ReLU(inplace=True), 77 | nn.Conv2d(1024, 1024, kernel_size=3, padding=1), 78 | nn.ReLU(inplace=True), 79 | nn.Conv2d(1024, args.num_classes, kernel_size=1, padding=0), 80 | ) 81 | 82 | self.classifier_loc = nn.Sequential( 83 | nn.Conv2d(1024, args.num_classes, kernel_size=3, padding=1), 84 | nn.Sigmoid(), 85 | ) 86 | initialize_weights(self.modules(), init_mode='xavier') 87 | 88 | def forward(self, x, label=None, topk=1, evaluate=False ): 89 | batch = x.size(0) 90 | x = self.conv1(x) 91 | x = self.bn1(x) 92 | x = self.relu(x) 93 | x = self.maxpool(x) 94 | 95 | x = self.layer1(x) 96 | x = self.layer2(x) 97 | 98 | x = self.layer3(x) 99 | x_3 = x.clone() 100 | 101 | x = F.max_pool2d(x,kernel_size=2) 102 | x = self.layer4(x) 103 | 104 | x = self.classifier_cls(x) 105 | self.feature_map = x 106 | 107 | self.score_1 = self.avg_pool(x).squeeze(-1).squeeze(-1) 108 | 109 | # p_label 110 | if topk == 1: 111 | p_label = label.unsqueeze(-1) 112 | else: 113 | _, p_label = self.score_1.topk(topk, 1, True, True) 114 | # x_sum 115 | self.S = torch.zeros(batch).cuda() 116 | for i in range(batch): 117 | self.S[i] = self.score_1[i][label[i]] 118 | 119 | ## M_fg 120 | M = self.classifier_loc(x_3) 121 | M_fg = torch.zeros(batch, 1, 28, 28).cuda() 122 | for i in range(batch): 123 | M_fg[i][0] = M[i][p_label[i]].mean(0) 124 | self.M_fg = M_fg 125 | 126 | if evaluate: 127 | return self.score_1, None, None, self.M_fg 128 | 129 | ## weight_copy 130 | classifier_cls_copy = copy.deepcopy(self.classifier_cls) 131 | layer4_copy = copy.deepcopy(self.layer4) 132 | 133 | ## erase 134 | x_erase = x_3.clone().detach() * (1-M_fg) 135 | x_erase = F.max_pool2d(x_erase, kernel_size=2 ) 136 | x_erase = layer4_copy(x_erase) 137 | x_erase = classifier_cls_copy(x_erase) 138 | x_erase = self.avg_pool(x_erase).view(x_erase.size(0), -1) 139 | 140 | ## S_bg 141 | self.S_bg = torch.zeros(batch).cuda() 142 | for i in range(batch): 143 | self.S_bg[i] = x_erase[i][label[i]] 144 | 145 | ## score_2 146 | x = self.feature_map * nn.AvgPool2d(2)(self.M_fg) 147 | self.score_2 = self.avg_pool(x).squeeze(-1).squeeze(-1) 148 | 149 | ## bas 150 | batch = self.S.size(0) 151 | S = (nn.ReLU()(self.S.clone().detach())).view(batch,-1).mean(1) 152 | S_bg = (nn.ReLU()(self.S_bg)).view(batch,-1).mean(1) 153 | bas = S_bg / (S+1e-8) 154 | bas[S_bg>S] = 1 155 | 156 | ## area 157 | area = self.M_fg.clone().view(batch, -1).mean(1) 158 | 159 | return self.score_1, self.score_2, bas, area, self.M_fg 160 | 161 | 162 | def _make_layer(self, block, planes, blocks, stride): 163 | layers = self._layer(block, 
planes, blocks, stride) 164 | return nn.Sequential(*layers) 165 | 166 | def _layer(self, block, planes, blocks, stride): 167 | downsample = get_downsampling_layer(self.inplanes, block, planes, 168 | stride) 169 | 170 | layers = [block(self.inplanes, planes, stride, downsample)] 171 | self.inplanes = planes * block.expansion 172 | for _ in range(1, blocks): 173 | layers.append(block(self.inplanes, planes)) 174 | 175 | return layers 176 | 177 | 178 | def normalize_atten_maps(self, atten_maps): 179 | atten_shape = atten_maps.size() 180 | 181 | #-------------------------- 182 | batch_mins, _ = torch.min(atten_maps.view(atten_shape[0:-2] + (-1,)), dim=-1, keepdim=True) 183 | batch_maxs, _ = torch.max(atten_maps.view(atten_shape[0:-2] + (-1,)), dim=-1, keepdim=True) 184 | atten_normed = torch.div(atten_maps.view(atten_shape[0:-2] + (-1,))-batch_mins, 185 | batch_maxs - batch_mins + 1e-10) 186 | atten_normed = atten_normed.view(atten_shape) 187 | 188 | return atten_normed 189 | 190 | def get_downsampling_layer(inplanes, block, planes, stride): 191 | outplanes = planes * block.expansion 192 | if stride == 1 and inplanes == outplanes: 193 | return 194 | else: 195 | return nn.Sequential( 196 | nn.Conv2d(inplanes, outplanes, 1, stride, bias=False), 197 | nn.BatchNorm2d(outplanes), 198 | ) 199 | 200 | def load_pretrained_model(model): 201 | strict_rule = True 202 | 203 | state_dict = torch.load('pretrained_weight/resnet50-19c8e357.pth') 204 | 205 | state_dict = remove_layer(state_dict, 'fc') 206 | strict_rule = False 207 | 208 | model.load_state_dict(state_dict, strict=strict_rule) 209 | return model 210 | 211 | 212 | def model(args, pretrained=True): 213 | model = ResNetCam(Bottleneck, [3, 4, 6, 3], args) 214 | if pretrained: 215 | model = load_pretrained_model(model) 216 | return model 217 | -------------------------------------------------------------------------------- /WSOL/Model/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import random 7 | import cv2 8 | from skimage import measure 9 | from torch.nn.parameter import Parameter 10 | from utils.func import * 11 | import torch.distributed as dist 12 | class Model(nn.Module): 13 | def __init__(self, args): 14 | super(Model, self).__init__() 15 | self.args = args 16 | self.num_classes = args.num_classes 17 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) ## -> 64x224x224 18 | self.relu1_1 = nn.ReLU(inplace=True) 19 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) ## -> 64x224x224 20 | self.relu1_2 = nn.ReLU(inplace=True) 21 | self.pool1 = nn.MaxPool2d(2) ## -> 64x112x112 22 | 23 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) ## -> 128x112x112 24 | self.relu2_1 = nn.ReLU(inplace=True) 25 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) ## -> 128x112x112 26 | self.relu2_2 = nn.ReLU(inplace=True) 27 | self.pool2 = nn.MaxPool2d(2) ## -> 128x56x56 28 | 29 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) ## -> 256x56x56 30 | self.relu3_1 = nn.ReLU(inplace=True) 31 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) ## -> 256x56x56 32 | self.relu3_2 = nn.ReLU(inplace=True) 33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) ## -> 256x56x56 34 | self.relu3_3 = nn.ReLU(inplace=True) 35 | self.pool3 = nn.MaxPool2d(2) ## -> 256x28x28 36 | 37 | self.conv4_1 = nn.Conv2d(256, 512, 
kernel_size=3, padding=1) ## -> 512x28x28 38 | self.relu4_1 = nn.ReLU(inplace=True) 39 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) ## -> 512x28x28 40 | self.relu4_2 = nn.ReLU(inplace=True) 41 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) ## -> 512x28x28 42 | self.relu4_3 = nn.ReLU(inplace=True) 43 | self.pool4 = nn.MaxPool2d(2) ## -> 512x14x14 44 | 45 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) ## -> 512x14x14 46 | self.relu5_1 = nn.ReLU(inplace=True) 47 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) ## -> 512x14x14 48 | self.relu5_2 = nn.ReLU(inplace=True) 49 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) ## -> 512x14x14 50 | self.relu5_3 = nn.ReLU(inplace=True) 51 | self.pool5 = nn.MaxPool2d(2) 52 | 53 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 54 | 55 | self.classifier_cls = nn.Sequential( 56 | nn.Conv2d(512, 1024, kernel_size=3, padding=1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(1024, 1024, kernel_size=3, padding=1), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(1024, args.num_classes, kernel_size=1, padding=0), 61 | ) 62 | 63 | self.classifier_loc = nn.Sequential( 64 | nn.Conv2d(512, args.num_classes, kernel_size=3, padding=1), 65 | nn.Sigmoid(), 66 | ) 67 | 68 | def forward(self, x, label=None, topk=1, evaluate=False): 69 | batch = x.size(0) 70 | 71 | x = self.conv1_1(x) 72 | x = self.relu1_1(x) 73 | x = self.conv1_2(x) 74 | x = self.relu1_2(x) 75 | x = self.pool1(x) 76 | 77 | x = self.conv2_1(x) 78 | x = self.relu2_1(x) 79 | x = self.conv2_2(x) 80 | x = self.relu2_2(x) 81 | x = self.pool2(x) 82 | 83 | x = self.conv3_1(x) 84 | x = self.relu3_1(x) 85 | x = self.conv3_2(x) 86 | x = self.relu3_2(x) 87 | x = self.conv3_3(x) 88 | x = self.relu3_3(x) 89 | 90 | x = self.pool3(x) 91 | x = self.conv4_1(x) 92 | x = self.relu4_1(x) 93 | x = self.conv4_2(x) 94 | x = self.relu4_2(x) 95 | x = self.conv4_3(x) 96 | x = self.relu4_3(x) 97 | x_4 = x.clone() 98 | 99 | x = self.pool4(x) 100 | 101 | x = self.conv5_1(x) 102 | x = self.relu5_1(x) 103 | x = self.conv5_2(x) 104 | x = self.relu5_2(x) 105 | x = self.conv5_3(x) 106 | x = self.relu5_3(x) 107 | 108 | x = self.classifier_cls(x) 109 | self.feature_map = x 110 | 111 | ## score 112 | x = self.avg_pool(x).view(x.size(0), -1) 113 | self.score_1 = x 114 | 115 | ## p_label 116 | if topk == 1: 117 | p_label = label.unsqueeze(-1) 118 | else: 119 | _, p_label = self.score_1.topk(topk, 1, True, True) 120 | 121 | ## S 122 | self.S = torch.zeros(batch).cuda() 123 | for i in range(batch): 124 | self.S[i] = self.score_1[i][label[i]] 125 | 126 | ## M_fg 127 | M = self.classifier_loc(x_4) 128 | M_fg = torch.zeros(batch, 1, 28, 28).cuda() 129 | for i in range(batch): 130 | M_fg[i][0] = M[i][p_label[i]].mean(0) 131 | self.M_fg = M_fg 132 | 133 | if evaluate: 134 | return self.score_1, None, None, self.M_fg 135 | 136 | ## weight_copy 137 | conv_copy_5_1 = copy.deepcopy(self.conv5_1) 138 | conv_copy_5_2 = copy.deepcopy(self.conv5_2) 139 | conv_copy_5_3 = copy.deepcopy(self.conv5_3) 140 | classifier_cls_copy = copy.deepcopy(self.classifier_cls) 141 | 142 | ## erase 143 | x_erase = x_4.clone().detach() * (1 - M_fg) 144 | x_erase = self.pool4(x_erase) 145 | x_erase = conv_copy_5_1(x_erase) 146 | x_erase = self.relu5_1(x_erase) 147 | x_erase = conv_copy_5_2(x_erase) 148 | x_erase = self.relu5_2(x_erase) 149 | x_erase = conv_copy_5_3(x_erase) 150 | x_erase = self.relu5_3(x_erase) 151 | x_erase = classifier_cls_copy(x_erase) 152 | x_erase = 
self.avg_pool(x_erase).view(x_erase.size(0), -1) 153 | 154 | ## S_bg 155 | self.S_bg = torch.zeros(batch).cuda() 156 | for i in range(batch): 157 | self.S_bg[i] = x_erase[i][label[i]] 158 | 159 | ## score_2 160 | x = self.feature_map * nn.AvgPool2d(2)(self.M_fg) 161 | self.score_2 = self.avg_pool(x).squeeze(-1).squeeze(-1) 162 | 163 | ## bas 164 | S = nn.ReLU()(self.S.clone().detach()) 165 | S_bg = nn.ReLU()(self.S_bg).clone() 166 | bas = S_bg / (S + 1e-8) 167 | bas[S_bg>S] = 0 168 | 169 | ## area 170 | area = self.M_fg.clone().view(batch, -1).mean(1) 171 | 172 | return self.score_1, self.score_2, bas, area, self.M_fg 173 | 174 | 175 | def normalize_atten_maps(self, atten_maps): 176 | atten_shape = atten_maps.size() 177 | 178 | #-------------------------- 179 | batch_mins, _ = torch.min(atten_maps.view(atten_shape[0:-2] + (-1,)), dim=-1, keepdim=True) 180 | batch_maxs, _ = torch.max(atten_maps.view(atten_shape[0:-2] + (-1,)), dim=-1, keepdim=True) 181 | atten_normed = torch.div(atten_maps.view(atten_shape[0:-2] + (-1,))-batch_mins, 182 | batch_maxs - batch_mins + 1e-10) 183 | atten_normed = atten_normed.view(atten_shape) 184 | 185 | return atten_normed 186 | 187 | def weight_init(m): 188 | classname = m.__class__.__name__ ## __class__将实例指向类,然后调用类的__name__属性 189 | if classname.find('Conv') != -1: 190 | m.weight.data.normal_(0.0, 0.02) 191 | elif classname.find('BatchNorm') != -1: 192 | m.weight.data.normal_(1.0, 0.02) 193 | m.bias.data.fill_(0) 194 | elif classname.find('Linear') != -1: 195 | m.weight.data.normal_(0, 0.01) 196 | m.weight.data.fill_(0) 197 | 198 | def model(args, pretrained=True): 199 | 200 | model = Model(args) 201 | if pretrained: 202 | model.apply(weight_init) 203 | pretrained_dict = torch.load('pretrained_weight/vgg16.pth') 204 | model_dict = model.state_dict() 205 | model_conv_name = [] 206 | 207 | for i, (k, v) in enumerate(model_dict.items()): 208 | model_conv_name.append(k) 209 | 210 | for i, (k, v) in enumerate(pretrained_dict.items()): 211 | if k.split('.')[0] != 'features': 212 | break 213 | if np.shape(model_dict[model_conv_name[i]]) == np.shape(v): 214 | model_dict[model_conv_name[i]] = v 215 | model.load_state_dict(model_dict) 216 | return model -------------------------------------------------------------------------------- /WSOL/count_pxap.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import numpy as np 3 | import cv2 4 | import os 5 | from DataLoader import * 6 | import argparse 7 | import torch 8 | from torch.autograd import Variable 9 | from Model import * 10 | from utils.accuracy import * 11 | import matplotlib.pyplot as plt 12 | _RESIZE_LENGTH = 224 13 | 14 | def load_mask_image(file_path, resize_size): 15 | mask = np.float32(cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)) 16 | mask = cv2.resize(mask, resize_size, interpolation=cv2.INTER_NEAREST) 17 | return mask 18 | 19 | 20 | def get_mask(mask_root, mask_paths, ignore_path): 21 | mask_all_instances = [] 22 | for mask_path in mask_paths: 23 | mask_file = os.path.join(mask_root, mask_path) 24 | mask = load_mask_image(mask_file, (_RESIZE_LENGTH, _RESIZE_LENGTH)) 25 | mask_all_instances.append(mask > 0.5) 26 | mask_all_instances = np.stack(mask_all_instances, axis=0).any(axis=0) 27 | 28 | ignore_file = os.path.join(mask_root, ignore_path) 29 | ignore_box_mask = load_mask_image(ignore_file, 30 | (_RESIZE_LENGTH, _RESIZE_LENGTH)) 31 | ignore_box_mask = ignore_box_mask > 0.5 32 | 33 | ignore_mask = np.logical_and(ignore_box_mask, 34 | 
np.logical_not(mask_all_instances)) 35 | 36 | if np.logical_and(ignore_mask, mask_all_instances).any(): 37 | raise RuntimeError("Ignore and foreground masks intersect.") 38 | 39 | return (mask_all_instances.astype(np.uint8) + 40 | 255 * ignore_mask.astype(np.uint8)) 41 | 42 | def get_mask_paths(metadata): 43 | mask_paths = {} 44 | ignore_paths = {} 45 | with open(metadata) as f: 46 | for line in f.readlines(): 47 | image_id, mask_path, ignore_path = line.strip('\n').split(',') 48 | if image_id in mask_paths: 49 | mask_paths[image_id].append(mask_path) 50 | assert (len(ignore_path) == 0) 51 | else: 52 | mask_paths[image_id] = [mask_path] 53 | ignore_paths[image_id] = ignore_path 54 | return mask_paths, ignore_paths 55 | 56 | def get_mask_paths_cubmask(metadata): 57 | mask_paths = {} 58 | ignore_paths = {} 59 | with open(metadata) as f: 60 | for line in f.readlines(): 61 | image_id, mask_path, ignore_path = line.strip('\n').split(',') 62 | if image_id in mask_paths: 63 | mask_paths[image_id].append(mask_path) 64 | assert (len(ignore_path) == 0) 65 | else: 66 | mask_paths['/'.join(image_id.split('/')[1:])] = [mask_path] 67 | ignore_paths['/'.join(image_id.split('/')[1:])] = ignore_path 68 | return mask_paths, ignore_paths 69 | 70 | def normalize_map(atten_map,shape): 71 | min_val = np.min(atten_map) 72 | max_val = np.max(atten_map) 73 | atten_norm = (atten_map - min_val)/(max_val - min_val + 1e-5) 74 | atten_norm = cv2.resize(atten_norm, dsize=(shape[1],shape[0])) 75 | return atten_norm 76 | 77 | 78 | def val_pxap_one_epoch(args, Val_Loader, model, epoch=0): 79 | if args.evaluate == False: 80 | return None, None, None 81 | cam_threshold_list = list(np.arange(0, 1, 0.001)) 82 | threshold_list_right_edge = np.append(cam_threshold_list, [1.0, 2.0, 3.0]) 83 | num_bins = len(cam_threshold_list) + 2 84 | gt_true_score_hist = np.zeros(num_bins, dtype=np.float) 85 | gt_false_score_hist = np.zeros(num_bins, dtype=np.float) 86 | if args.root == 'OpenImage': 87 | mask_paths, ignore_paths = get_mask_paths('/home2/Datasets/OpenImage/test_mask.txt') 88 | elif args.root == 'CUB_200_2011': 89 | mask_paths, ignore_paths = get_mask_paths_cubmask('/home2/Datasets/CUB_200_2011/CUBMask/test/localization.txt') 90 | for step, (img_batch, img_tencrop, img_class, box, path_batch) in enumerate(Val_Loader): ## batch_size=1 91 | with torch.no_grad(): 92 | batch = img_batch.size(0) 93 | img_batch = Variable(img_batch).cuda() 94 | output1, _, _, fmap_batch = model(img_batch, img_class, args.topk, evaluate=True) 95 | for i in range(batch): 96 | shape = (224, 224) 97 | path = path_batch[i].replace('\\','/') 98 | img_name = path.split('/')[-1] 99 | img_path = path.split('/')[-2] + '/' + path.split('/')[-1] 100 | if args.root == 'OpenImage': 101 | gt_mask = get_mask('/home2/Datasets/OpenImage/mask/', 102 | mask_paths['test/' + img_path], 103 | ignore_paths['test/' + img_path]) 104 | elif args.root == 'CUB_200_2011': 105 | gt_mask = get_mask('/home2/Datasets/CUB_200_2011/CUBMask/', 106 | mask_paths[img_path], 107 | ignore_paths[img_path]) 108 | fmap = np.array(fmap_batch.data.cpu())[i][0] 109 | fmap = normalize_map(fmap, shape) 110 | 111 | gt_true_scores = fmap[gt_mask == 1] 112 | gt_false_scores = fmap[gt_mask == 0] 113 | # histograms in ascending order 114 | gt_true_hist, _ = np.histogram(gt_true_scores, 115 | bins=threshold_list_right_edge) 116 | gt_true_score_hist += gt_true_hist.astype(np.float) 117 | 118 | gt_false_hist, _ = np.histogram(gt_false_scores, 119 | bins=threshold_list_right_edge) 120 | gt_false_score_hist += 
gt_false_hist.astype(np.float) 121 | 122 | if args.save_img_flag: 123 | ori_img = cv2.imread(path) 124 | ori_img = cv2.resize(ori_img, (224,224)) 125 | heatmap = np.uint8(255 * fmap) 126 | heatmap = heatmap.astype(np.uint8) 127 | heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) 128 | img_add = cv2.addWeighted(ori_img.astype(np.uint8), 0.5, heatmap.astype(np.uint8), 0.5, 0) 129 | if os.path.exists('logs/' + args.root + '/' + args.arch + '/save_pxap_img') == 0: 130 | os.mkdir('logs/' + args.root + '/' + args.arch + '/save_pxap_img') 131 | cv2.imwrite('logs/' + args.root + '/' + args.arch + '/save_pxap_img/' + img_name, img_add) 132 | 133 | 134 | num_gt_true = gt_true_score_hist.sum() 135 | tp = gt_true_score_hist[::-1].cumsum() 136 | fn = num_gt_true - tp 137 | 138 | num_gt_false = gt_false_score_hist.sum() 139 | fp = gt_false_score_hist[::-1].cumsum() 140 | tn = num_gt_false - fp 141 | 142 | if ((tp + fn) <= 0).all(): 143 | raise RuntimeError("No positive ground truth in the eval set.") 144 | if ((tp + fp) <= 0).all(): 145 | raise RuntimeError("No positive prediction in the eval set.") 146 | 147 | non_zero_indices = (tp + fp) != 0 148 | 149 | precision = tp / (tp + fp) 150 | recall = tp / (tp + fn) 151 | 152 | IoU = tp / (fp + fn + tp) 153 | IoU*= 100 154 | 155 | PxAP = (precision[1:] * np.diff(recall))[non_zero_indices[1:]].sum() 156 | PxAP *= 100 157 | 158 | pIoU = max(IoU) 159 | print('Val Epoch : [{}] \tPxAP: {:.2f} \t pIoU : {:.2f} \n'.format(epoch, PxAP, pIoU)) 160 | 161 | return PxAP, pIoU, 0 162 | 163 | if __name__ == '__main__': 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('--crop_size', type=int, default=224) 166 | parser.add_argument('--resize_size', type=int, default=224) 167 | parser.add_argument('--root', type=str, help="[CUB_200_2011, OpenImage]", 168 | default='OpenImage') 169 | parser.add_argument('--num_classes', type=int, default=200) 170 | parser.add_argument('--batch_size', type=int, default=64) 171 | parser.add_argument('--num_workers', type=int, default=2) 172 | parser.add_argument('--topk', type=int, default=1) 173 | parser.add_argument('--arch', type=str, default='resnet') ## choose [ vgg, resnet, inception, mobilenet ] 174 | parser.add_argument('--evaluate', type=bool, default=True) 175 | parser.add_argument('--save_img_flag', type=bool, default=False) 176 | parser.add_argument('--tencrop', type=bool, default=False) 177 | parser.add_argument('--gpu', type=str, default='3') 178 | args = parser.parse_args() 179 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 180 | 181 | MyData = eval(args.root).ImageDataset(args, phase='test') 182 | Val_Loader = torch.utils.data.DataLoader(dataset=MyData, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 183 | ## model 184 | model = eval(args.arch).model(args, pretrained=False) 185 | model.cuda(device=0) 186 | checkpoint = torch.load('logs/' + args.root + '/' + args.arch + '/' + 'best_loc.pth.tar') ## best_top1 best_gt best_cls epoch_99 187 | checkpoint['model'] = {k[7:]: v for k, v in checkpoint['model'].items()} 188 | model.load_state_dict(checkpoint['model']) 189 | model.eval() 190 | 191 | val_pxap_one_epoch(args, Val_Loader, model) 192 | 193 | -------------------------------------------------------------------------------- /WSOL/logs/CUB_200_2011/inception/save_weight.txt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSOL/logs/CUB_200_2011/inception/save_weight.txt -------------------------------------------------------------------------------- /WSOL/logs/CUB_200_2011/mobilenet/save_weight.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSOL/logs/CUB_200_2011/mobilenet/save_weight.txt -------------------------------------------------------------------------------- /WSOL/logs/CUB_200_2011/resnet/save_weight.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSOL/logs/CUB_200_2011/resnet/save_weight.txt -------------------------------------------------------------------------------- /WSOL/logs/CUB_200_2011/vgg/save_weight.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSOL/logs/CUB_200_2011/vgg/save_weight.txt -------------------------------------------------------------------------------- /WSOL/logs/ILSVRC/inception/save_weight.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSOL/logs/ILSVRC/inception/save_weight.txt -------------------------------------------------------------------------------- /WSOL/logs/ILSVRC/mobilenet/save_weight.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSOL/logs/ILSVRC/mobilenet/save_weight.txt -------------------------------------------------------------------------------- /WSOL/logs/ILSVRC/resnet/save_weight.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSOL/logs/ILSVRC/resnet/save_weight.txt -------------------------------------------------------------------------------- /WSOL/logs/ILSVRC/vgg/save_weight.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSOL/logs/ILSVRC/vgg/save_weight.txt -------------------------------------------------------------------------------- /WSOL/logs/OpenImage/resnet/save_weight.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSOL/logs/OpenImage/resnet/save_weight.txt -------------------------------------------------------------------------------- /WSOL/train_ILSVRC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | from Model import * 6 | from DataLoader import * 7 | from torch.autograd import Variable 8 | #from evaluator import val_loc_one_epoch 9 | from skimage import measure 10 | import torch.distributed as dist 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | from utils.hyperparameters import get_hyperparameters 13 | from utils.accuracy import * 14 
| from utils.lr import * 15 | from utils.optimizer import * 16 | from utils.func import * 17 | from utils.util import * 18 | import pprint 19 | import os 20 | import random 21 | import shutil 22 | def set_seed(seed): 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | random.seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | 32 | def makedirs(dirs): 33 | if not os.path.exists(dirs): 34 | os.makedirs(dirs) 35 | 36 | class opts(object): 37 | def __init__(self): 38 | self.parser = argparse.ArgumentParser() 39 | ## path 40 | self.parser.add_argument('--root', type=str, help="[CUB_200_2011, ILSVRC, OpenImage, Fgvc_aircraft_2013b, Standford_Car, Stanford_Dogs]", 41 | default='ILSVRC') 42 | self.parser.add_argument('--num_classes', type=int, default=1000) 43 | self.parser.add_argument('--seed', type=int, default=0) 44 | ## save 45 | self.parser.add_argument('--save_path', type=str, default='logs') 46 | self.parser.add_argument('--log_file', type=str, default='log.txt') 47 | self.parser.add_argument('--log_code_dir', type=str, default='save_code') 48 | self.parser.add_argument('--threshold', type=float, default=[0.5]) 49 | self.parser.add_argument('--evaluate', type=bool, default=False) 50 | self.parser.add_argument('--save_img_flag', type=bool, default=False) 51 | self.parser.add_argument('--tencrop', type=bool, default=False) 52 | ## dataloader 53 | self.parser.add_argument('--crop_size', type=int, default=224) 54 | self.parser.add_argument('--resize_size', type=int, default=256) 55 | self.parser.add_argument('--num_workers', type=int, default=2) 56 | self.parser.add_argument("--local_rank", type=int,default=-1) 57 | ## train 58 | self.parser.add_argument('--batch_size', type=int, default=64) 59 | self.parser.add_argument('--epochs', type=int, default=9) 60 | self.parser.add_argument('--phase', type=str, default='train') ## train / test 61 | self.parser.add_argument('--lr', type=float, default=0.001) 62 | self.parser.add_argument('--weight_decay', type=float, default=5e-4) 63 | self.parser.add_argument('--power', type=float, default=0.9) 64 | self.parser.add_argument('--momentum', type=float, default=0.9) 65 | ## model 66 | self.parser.add_argument('--arch', type=str, help="[vgg, resnet, inception, mobilenet]", 67 | default='resnet') 68 | ## GPU' 69 | 70 | def parse(self): 71 | opt = self.parser.parse_args() 72 | opt.arch = opt.arch 73 | return opt 74 | 75 | args = opts().parse() 76 | 77 | ## default setting (topk, threshold) 78 | args = get_hyperparameters(args) 79 | args.batch_size = args.batch_size // int(torch.cuda.device_count()) 80 | print(args.batch_size) 81 | 82 | ## save_log_txt 83 | makedirs(args.save_path + '/' + args.root + '/' + args.arch + '/' + args.log_code_dir) 84 | sys.stdout = Logger(args.save_path + '/' + args.root + '/' + args.arch + '/' + args.log_code_dir + '/' + args.log_file) 85 | sys.stdout.log.flush() 86 | 87 | ## save_code 88 | save_file = ['train.py', 'train_ILSVRC.py', 'evaluator.py', 'count_pxap.py'] 89 | for file_name in save_file: 90 | shutil.copyfile(file_name, args.save_path + '/' + args.root + '/' + args.arch + '/' + args.log_code_dir + '/' + file_name) 91 | save_dir = ['Model', 'utils', 'DataLoader'] 92 | for dir_name in save_dir: 93 | copy_dir(dir_name, args.save_path + '/' + args.root + '/' + args.arch + '/' + args.log_code_dir + '/' + dir_name) 94 | 95 | if __name__ == '__main__': 96 | local_rank = args.local_rank 97 
| torch.cuda.set_device(local_rank) 98 | dist.init_process_group(backend='nccl') 99 | device = torch.device("cuda", local_rank) 100 | set_seed(args.seed + dist.get_rank()) 101 | TrainData = eval(args.root).ImageDataset(args) 102 | train_sampler = torch.utils.data.distributed.DistributedSampler(TrainData) 103 | Train_Loader = torch.utils.data.DataLoader(dataset=TrainData, batch_size=args.batch_size , sampler=train_sampler, num_workers=args.num_workers, pin_memory=True) 104 | 105 | ## model 106 | model = eval(args.arch).model(args, pretrained=True) 107 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank) 108 | model = DDP(model, find_unused_parameters=False , device_ids=[local_rank], output_device=local_rank) 109 | 110 | ## optimizer 111 | optimizer = get_optimizer(model, args) 112 | loss_func = nn.CrossEntropyLoss().to(local_rank) 113 | best_gt = 0 114 | best_top1 = 0 115 | best_cls = 0 116 | if dist.get_rank() == 0: 117 | print('Train begining!') 118 | for epoch in range(0, args.epochs): 119 | Train_Loader.sampler.set_epoch(epoch) 120 | ## accuracy 121 | cls_acc_1 = AverageMeter() 122 | cls_acc_2 = AverageMeter() 123 | loss_epoch_1 = AverageMeter() 124 | loss_epoch_2 = AverageMeter() 125 | loss_epoch_3 = AverageMeter() 126 | loss_epoch_4 = AverageMeter() 127 | poly_lr_scheduler(optimizer, epoch, decay_epoch=3) 128 | model.train() 129 | for step, (path, imgs, label) in enumerate(Train_Loader): 130 | imgs, label = Variable(imgs).to(local_rank), label.to(local_rank) 131 | ## backward 132 | optimizer.zero_grad() 133 | output1, output2, bas, area, map = model(imgs, label, 1) 134 | label = label.long() 135 | pred = torch.max(output1, 1)[1] 136 | loss_1 = loss_func(output1, label).cuda() 137 | loss_2 = loss_func(output2, label).cuda() 138 | bas, area = bas.mean(0), area.mean(0) 139 | loss = loss_1 + loss_2 * args.alpha + bas + area * args.beta 140 | loss.backward() 141 | optimizer.step() 142 | ## count_accuracy 143 | cur_batch = label.size(0) 144 | cur_cls_acc_1 = 100. * compute_cls_acc(output1, label) 145 | cls_acc_1.updata(cur_cls_acc_1, cur_batch) 146 | cur_cls_acc_2 = 100. 
* compute_cls_acc(output2, label) 147 | cls_acc_2.updata(cur_cls_acc_2, cur_batch) 148 | loss_epoch_1.updata(loss_1.data, 1) 149 | loss_epoch_2.updata(loss_2.data, 1) 150 | loss_epoch_3.updata(bas.data, 1) 151 | loss_epoch_4.updata(area.data, 1) 152 | if (step+1) % (len(Train_Loader) // 5) == 0 and args.root=='ILSVRC' and dist.get_rank() == 0: 153 | iter = int((step+1) // (len(Train_Loader) // 5)) 154 | print('Epoch:[{}/{}]\tstep:[{}/{}]\tCLS:{:.3f}\tFRG:{:.3f}\tBAS:{:.2f}\tArea:{:.2f}\tepoch_acc:{:.2f}%\tepoch_acc2:{:.2f}%'.format( 155 | epoch+1, args.epochs, step+1, len(Train_Loader), loss_epoch_1.avg,loss_epoch_2.avg,loss_epoch_3.avg,loss_epoch_4.avg,cls_acc_1.avg,cls_acc_2.avg 156 | )) 157 | sys.stdout.log.flush() 158 | if dist.get_rank() == 0: 159 | print('Epoch:[{}/{}]\tstep:[{}/{}]\tCLS:{:.3f}\tFRG:{:.3f}\tBAS:{:.2f}\tArea:{:.2f}\tepoch_acc:{:.2f}%\tepoch_acc2:{:.2f}%'.format( 160 | epoch+1, args.epochs, step+1, len(Train_Loader), loss_epoch_1.avg,loss_epoch_2.avg,loss_epoch_3.avg,loss_epoch_4.avg,cls_acc_1.avg,cls_acc_2.avg 161 | )) 162 | sys.stdout.log.flush() 163 | torch.save({'model':model.state_dict(), 164 | 'best_thr':0, 165 | 'epoch':epoch+1, 166 | }, os.path.join(args.save_path, args.root + '/' +args.arch + '/' + 'epoch_'+ str(epoch+1) +'.pth.tar'), _use_new_zipfile_serialization=False) 167 | 168 | 169 | -------------------------------------------------------------------------------- /WSOL/utils/IoU.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def IoU(boxA, boxB): 4 | xA = max(boxA[0], boxB[0]) 5 | yA = max(boxA[1], boxB[1]) 6 | xB = min(boxA[2], boxB[2]) 7 | yB = min(boxA[3], boxB[3]) 8 | 9 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 10 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 11 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 12 | 13 | iou = interArea / float(boxAArea + boxBArea - interArea) 14 | return iou -------------------------------------------------------------------------------- /WSOL/utils/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.IoU import IoU 4 | 5 | def compute_cls_acc(preds, label): 6 | pred = torch.max(preds, 1)[1] 7 | #label = torch.max(labels, 1)[1] 8 | num_correct = (pred == label).sum() 9 | return float(num_correct)/ float(preds.size(0)) 10 | 11 | def compute_loc_acc(loc_preds, loc_gt, theta=0.5): 12 | correct = 0 13 | for i in range(loc_preds.size(0)): 14 | iou = IoU(loc_preds[i], loc_gt[i]) 15 | correct += (iou >= theta).sum() 16 | return float(correct) / float(loc_preds.size(0)) 17 | 18 | class AverageMeter(object): 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): 23 | self.val = 0.0 24 | self.avg = 0.0 25 | self.sum = 0.0 26 | self.cnt = 0.0 27 | 28 | def updata(self, val, n=1.0): 29 | self.val = val 30 | self.sum += val * n 31 | self.cnt += n 32 | if self.cnt == 0: 33 | self.avg = 1 34 | else: 35 | self.avg = self.sum / self.cnt 36 | 37 | def accuracy(output, target, topk=(1,)): 38 | maxk = max(topk) 39 | batch_size = target.size(0) 40 | 41 | _, pred = output.topk(maxk, 1, True, True) 42 | pred = pred.t() 43 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 44 | 45 | res = [] 46 | for k in topk: 47 | correct_k = correct[:k].reshape(-1).float().sum(0) 48 | res.append(correct_k.mul_(100.0 / batch_size)) 49 | return res -------------------------------------------------------------------------------- 
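Note: a minimal usage sketch, not part of the repository — the logits, labels, and boxes below are made-up placeholders. It shows how the two helper files above (utils/IoU.py and utils/accuracy.py) can be combined at evaluation time: top-1 classification accuracy from the logits, and GT-known localization accuracy by counting predicted boxes whose IoU with the annotation reaches 0.5, the default theta in compute_loc_acc.

import torch
from utils.IoU import IoU
from utils.accuracy import compute_cls_acc, AverageMeter

top1_meter, gt_loc_meter = AverageMeter(), AverageMeter()

# pretend batch: logits for 4 images over 200 classes, plus boxes in (x0, y0, x1, y1)
scores = torch.randn(4, 200)
labels = torch.randint(0, 200, (4,))
pred_boxes = [[30, 40, 180, 200]] * 4
gt_boxes = [[25, 35, 175, 190]] * 4

top1_meter.updata(100. * compute_cls_acc(scores, labels), labels.size(0))
# a box counts as correctly localized when its IoU with the annotation is >= 0.5
hit = sum(IoU(p, g) >= 0.5 for p, g in zip(pred_boxes, gt_boxes))
gt_loc_meter.updata(100. * hit / len(gt_boxes), len(gt_boxes))
print('top-1 cls: {:.2f}%  GT-known loc: {:.2f}%'.format(top1_meter.avg, gt_loc_meter.avg))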
/WSOL/utils/augment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import copy 3 | import numbers 4 | from torchvision.transforms import functional as F 5 | from .func import * 6 | import random 7 | import warnings 8 | import math 9 | 10 | class RandomHorizontalFlipBBox(object): 11 | def __init__(self, p=0.5): 12 | self.p = p 13 | 14 | def __call__(self, img, bbox): 15 | if random.random() < self.p: 16 | flipbox = copy.deepcopy(bbox) 17 | flipbox[0] = 1-bbox[2] 18 | flipbox[2] = 1-bbox[0] 19 | return F.hflip(img), flipbox 20 | 21 | return img, bbox 22 | class RandomResizedBBoxCrop(object): 23 | """Crop the given PIL Image to random size and aspect ratio. 24 | 25 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 26 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 27 | is finally resized to given size. 28 | This is popularly used to train the Inception networks. 29 | 30 | Args: 31 | size: expected output size of each edge 32 | scale: range of size of the origin size cropped 33 | ratio: range of aspect ratio of the origin aspect ratio cropped 34 | interpolation: Default: PIL.Image.BILINEAR 35 | """ 36 | 37 | def __init__(self, size, scale=(0.2, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 38 | if isinstance(size, tuple): 39 | self.size = size 40 | else: 41 | self.size = (size, size) 42 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 43 | warnings.warn("range should be of kind (min, max)") 44 | 45 | self.interpolation = interpolation 46 | self.scale = scale 47 | self.ratio = ratio 48 | 49 | @staticmethod 50 | def get_params(img, bbox, scale, ratio): 51 | """Get parameters for ``crop`` for a random sized crop. 52 | 53 | Args: 54 | img (PIL Image): Image to be cropped. 55 | scale (tuple): range of size of the origin size cropped 56 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 57 | 58 | Returns: 59 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 60 | sized crop. 
61 | """ 62 | 63 | area = img.size[0] * img.size[1] 64 | 65 | for attempt in range(30): 66 | target_area = random.uniform(*scale) * area 67 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 68 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 69 | 70 | w = int(round(math.sqrt(target_area * aspect_ratio))) 71 | h = int(round(math.sqrt(target_area / aspect_ratio))) 72 | 73 | if w <= img.size[0] and h <= img.size[1]: 74 | 75 | i = random.randint(0, img.size[1] - h) # i is y actually 76 | j = random.randint(0, img.size[0] - w) # j is x 77 | 78 | #compute intersection between crop image and bbox 79 | intersec = compute_intersec(i, j, h, w, bbox) 80 | 81 | if intersec[2]-intersec[0]>0 and intersec[3]-intersec[1]>0: 82 | intersec = normalize_intersec(i, j, h, w, intersec) 83 | return i, j, h, w, intersec 84 | 85 | # Fallback to central crop 86 | in_ratio = img.size[0] / img.size[1] 87 | if (in_ratio < min(ratio)): 88 | w = img.size[0] 89 | h = int(round(w / min(ratio))) 90 | elif (in_ratio > max(ratio)): 91 | h = img.size[1] 92 | w = int(round(h * max(ratio))) 93 | else: # whole image 94 | w = img.size[0] 95 | h = img.size[1] 96 | 97 | i = (img.size[1] - h) // 2 98 | j = (img.size[0] - w) // 2 99 | 100 | intersec = compute_intersec(i, j, h, w, bbox) 101 | intersec = normalize_intersec(i, j, h, w, intersec) 102 | return i, j, h, w, intersec 103 | 104 | def __call__(self, img, bbox): 105 | """ 106 | Args: 107 | img (PIL Image): Image to be cropped and resized. 108 | 109 | Returns: 110 | PIL Image: Randomly cropped and resized image. 111 | """ 112 | i, j, h, w, crop_bbox = self.get_params(img, bbox, self.scale, self.ratio) 113 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), crop_bbox 114 | class RandomBBoxCrop(object): 115 | """Crop the given PIL Image at a random location. 116 | 117 | Args: 118 | size (sequence or int): Desired output size of the crop. If size is an 119 | int instead of sequence like (h, w), a square crop (size, size) is 120 | made. 121 | padding (int or sequence, optional): Optional padding on each border 122 | of the image. Default is None, i.e no padding. If a sequence of length 123 | 4 is provided, it is used to pad left, top, right, bottom borders 124 | respectively. If a sequence of length 2 is provided, it is used to 125 | pad left/right, top/bottom borders, respectively. 126 | pad_if_needed (boolean): It will pad the image if smaller than the 127 | desired size to avoid raising an exception. Since cropping is done 128 | after padding, the padding seems to be done at a random offset. 129 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 130 | length 3, it is used to fill R, G, B channels respectively. 131 | This value is only used when the padding_mode is constant 132 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 
133 | 134 | - constant: pads with a constant value, this value is specified with fill 135 | 136 | - edge: pads with the last value on the edge of the image 137 | 138 | - reflect: pads with reflection of image (without repeating the last value on the edge) 139 | 140 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 141 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 142 | 143 | - symmetric: pads with reflection of image (repeating the last value on the edge) 144 | 145 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 146 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 147 | 148 | """ 149 | 150 | def __init__(self, size): 151 | if isinstance(size, numbers.Number): 152 | self.size = (int(size), int(size)) 153 | else: 154 | self.size = size 155 | 156 | @staticmethod 157 | def get_params(img, bbox, output_size): 158 | """Get parameters for ``crop`` for a random crop. 159 | 160 | Args: 161 | img (PIL Image): Image to be cropped. 162 | output_size (tuple): Expected output size of the crop. 163 | 164 | Returns: 165 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 166 | """ 167 | w, h = img.size 168 | th, tw = output_size 169 | if w == tw and h == th: 170 | return 0, 0, h, w 171 | 172 | i = random.randint(0, h - th) 173 | j = random.randint(0, w - tw) 174 | intersec = compute_intersec(i, j, h, w, bbox) 175 | intersec = normalize_intersec(i, j, h, w, intersec) 176 | return i, j, th, tw, intersec 177 | 178 | def __call__(self, img, bbox): 179 | """ 180 | Args: 181 | img (PIL Image): Image to be cropped. 182 | 183 | Returns: 184 | PIL Image: Cropped image. 185 | """ 186 | 187 | i, j, h, w,crop_bbox = self.get_params(img, bbox, self.size) 188 | 189 | return F.crop(img, i, j, h, w),crop_bbox 190 | 191 | def __repr__(self): 192 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 193 | 194 | class ResizedBBoxCrop(object): 195 | 196 | def __init__(self, size, interpolation=Image.BILINEAR): 197 | self.size = size 198 | 199 | self.interpolation = interpolation 200 | 201 | @staticmethod 202 | def get_params(img, bbox, size): 203 | #resize to 256 204 | if isinstance(size, int): 205 | w, h = img.size 206 | if (w <= h and w == size) or (h <= w and h == size): 207 | img = copy.deepcopy(img) 208 | ow, oh = w, h 209 | if w < h: 210 | ow = size 211 | oh = int(size*h/w) 212 | else: 213 | oh = size 214 | ow = int(size*w/h) 215 | else: 216 | ow, oh = size[::-1] 217 | w, h = img.size 218 | 219 | 220 | intersec = copy.deepcopy(bbox) 221 | ratew = ow / w 222 | rateh = oh / h 223 | intersec[0] = bbox[0]*ratew 224 | intersec[2] = bbox[2]*ratew 225 | intersec[1] = bbox[1]*rateh 226 | intersec[3] = bbox[3]*rateh 227 | 228 | #intersec = normalize_intersec(i, j, h, w, intersec) 229 | return (oh, ow), intersec 230 | 231 | def __call__(self, img, bbox): 232 | """ 233 | Args: 234 | img (PIL Image): Image to be cropped and resized. 235 | 236 | Returns: 237 | PIL Image: Randomly cropped and resized image. 
238 | """ 239 | size, crop_bbox = self.get_params(img, bbox, self.size) 240 | return F.resize(img, self.size, self.interpolation), crop_bbox 241 | 242 | 243 | class CenterBBoxCrop(object): 244 | 245 | def __init__(self, size, interpolation=Image.BILINEAR): 246 | self.size = size 247 | 248 | self.interpolation = interpolation 249 | 250 | @staticmethod 251 | def get_params(img, bbox, size): 252 | #center crop 253 | if isinstance(size, numbers.Number): 254 | output_size = (int(size), int(size)) 255 | 256 | w, h = img.size 257 | th, tw = output_size 258 | 259 | i = int(round((h - th) / 2.)) 260 | j = int(round((w - tw) / 2.)) 261 | 262 | intersec = compute_intersec(i, j, th, tw, bbox) 263 | intersec = normalize_intersec(i, j, th, tw, intersec) 264 | 265 | #intersec = normalize_intersec(i, j, h, w, intersec) 266 | return i, j, th, tw, intersec 267 | 268 | def __call__(self, img, bbox): 269 | """ 270 | Args: 271 | img (PIL Image): Image to be cropped and resized. 272 | 273 | Returns: 274 | PIL Image: Randomly cropped and resized image. 275 | """ 276 | i, j, th, tw, crop_bbox = self.get_params(img, bbox, self.size) 277 | return F.center_crop(img, self.size), crop_bbox 278 | 279 | 280 | -------------------------------------------------------------------------------- /WSOL/utils/func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import copy 5 | import cv2 6 | import sys 7 | def count_max(x): 8 | count_dict = {} 9 | for xlist in x: 10 | for item in xlist: 11 | if item==0: 12 | continue 13 | if item not in count_dict.keys(): 14 | count_dict[item] = 0 15 | count_dict[item] += 1 16 | if count_dict == {}: 17 | return -1 18 | count_dict = sorted(count_dict.items(), key=lambda d:d[1], reverse=True) 19 | return count_dict[0][0] 20 | 21 | def compute_intersec(i, j, h, w, bbox): 22 | ''' 23 | intersection box between croped box and GT BBox 24 | ''' 25 | intersec = copy.deepcopy(bbox) 26 | 27 | intersec[0] = max(j, bbox[0]) 28 | intersec[1] = max(i, bbox[1]) 29 | intersec[2] = min(j + w, bbox[2]) 30 | intersec[3] = min(i + h, bbox[3]) 31 | return intersec 32 | 33 | 34 | def normalize_intersec(i, j, h, w, intersec): 35 | ''' 36 | return: normalize into [0, 1] 37 | ''' 38 | 39 | intersec[0] = (intersec[0] - j) / w 40 | intersec[2] = (intersec[2] - j) / w 41 | intersec[1] = (intersec[1] - i) / h 42 | intersec[3] = (intersec[3] - i) / h 43 | return intersec 44 | 45 | def normalize_map(atten_map, crop_size): 46 | min_val = np.min(atten_map) 47 | max_val = np.max(atten_map) 48 | atten_norm = (atten_map - min_val)/(max_val - min_val + 1e-5) 49 | atten_norm = cv2.resize(atten_norm, dsize=(crop_size, crop_size)) 50 | return atten_norm 51 | 52 | class Logger(object): 53 | def __init__(self,filename="Default.log"): 54 | self.terminal = sys.stdout 55 | self.log = open(filename,'w+') 56 | 57 | def write(self,message): 58 | self.terminal.write(message) 59 | self.log.write(message) 60 | 61 | def flush(self): 62 | pass 63 | -------------------------------------------------------------------------------- /WSOL/utils/hyperparameters.py: -------------------------------------------------------------------------------- 1 | def get_hyperparameters(args, batch_flag=True): 2 | try: 3 | args.alpha = {'CUB_200_2011':{'vgg':0, 'resnet':0.5, 'inception':0.5, 'mobilenet':0.5}, 4 | 'ILSVRC':{'vgg':0.05, 'resnet':1, 'inception':1, 'mobilenet':0.5}, 5 | 'OpenImage':{'resnet':0.5}}[args.root][args.arch] 6 | except: print('no 
alpha') 7 | try: 8 | args.beta = {'CUB_200_2011':{'vgg':0.9, 'resnet':1.5, 'inception':1, 'mobilenet':1.9}, 9 | 'ILSVRC':{'vgg':1, 'resnet':2, 'inception':2.5, 'mobilenet':1.5}, 10 | 'OpenImage':{'resnet':1.5}}[args.root][args.arch] 11 | except: print('no beta') 12 | try: 13 | args.topk = {'CUB_200_2011':{'vgg':80, 'resnet':200, 'inception':200, 'mobilenet':80}, 14 | 'ILSVRC':{'vgg':1, 'resnet':1, 'inception':1, 'mobilenet':1}, 15 | 'OpenImage':{'resnet':1}}[args.root][args.arch] 16 | except: print('no topk') 17 | try: 18 | args.threshold = {'CUB_200_2011':{'vgg':[0.1], 'resnet':[0.1], 'inception':[0.1], 'mobilenet':[0.1]}, 19 | 'ILSVRC':{'vgg':[0.45], 'resnet':[0.25], 'inception':[0.25], 'mobilenet':[0.5]}}[args.root][args.arch] 20 | except: print('no threshold') 21 | try: 22 | args.epoch = {'CUB_200_2011': 120, 23 | 'ILSVRC': 9, 24 | 'OpenImage': 50}[args.root] 25 | except: print('no epoch') 26 | try: 27 | args.decay_epoch = {'CUB_200_2011': 80, 28 | 'ILSVRC': 3, 29 | 'OpenImage': 20}[args.root] 30 | except: print('no decay_epoch') 31 | try: 32 | if batch_flag: 33 | args.batch_size = {'CUB_200_2011':{'vgg':32, 'resnet':32, 'inception':32, 'mobilenet':32}, 34 | 'ILSVRC':{'vgg':256, 'resnet':128, 'inception':128, 'mobilenet':128}, 35 | 'OpenImage':{'resnet':64}}[args.root][args.arch] 36 | except: print('no batch_size') 37 | 38 | return args 39 | -------------------------------------------------------------------------------- /WSOL/utils/lr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | def poly_lr_scheduler(optimizer, epoch, lr_decay_rate=0.1, decay_epoch=80): 4 | if epoch!=0 and epoch % decay_epoch == 0: 5 | for param_group in optimizer.param_groups: 6 | param_group['lr'] *= lr_decay_rate 7 | 8 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 9 | start_warmup_value=0, warmup_steps=-1): 10 | warmup_schedule = np.array([]) 11 | warmup_iters = warmup_epochs * niter_per_ep 12 | if warmup_steps > 0: 13 | warmup_iters = warmup_steps 14 | print("Set warmup steps = %d" % warmup_iters) 15 | if warmup_epochs > 0: 16 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 17 | 18 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 19 | schedule = np.array( 20 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 21 | 22 | schedule = np.concatenate((warmup_schedule, schedule)) 23 | 24 | assert len(schedule) == epochs * niter_per_ep 25 | return schedule -------------------------------------------------------------------------------- /WSOL/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def get_optimizer(model, args): 3 | lr =args.lr 4 | weight_list = [] 5 | bias_list = [] 6 | last_weight_list = [] 7 | last_bias_list =[] 8 | for name, value in model.named_parameters(): 9 | if 'classifier' in name : 10 | if 'weight' in name: 11 | last_weight_list.append(value) 12 | elif 'bias' in name: 13 | last_bias_list.append(value) 14 | else: 15 | if 'weight' in name: 16 | weight_list.append(value) 17 | elif 'bias' in name: 18 | bias_list.append(value) 19 | optmizer = torch.optim.SGD([{'params': weight_list, 20 | 'lr': lr}, 21 | {'params': bias_list, 22 | 'lr': lr*2}, 23 | {'params': last_weight_list, 24 | 'lr': lr*10}, 25 | {'params': last_bias_list, 26 | 'lr': lr*20}], momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 
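    # The four parameter groups above apply a layer-wise learning-rate scheme:
    # pretrained backbone weights train at the base lr and backbone biases at 2x lr,
    # while the newly added 'classifier' layers get 10x lr (weights) and 20x lr (biases);
    # all groups share Nesterov momentum and the global weight decay.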
27 | return optmizer -------------------------------------------------------------------------------- /WSOL/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import shutil 4 | 5 | 6 | def remove_layer(state_dict, keyword): 7 | keys = [key for key in state_dict.keys()] 8 | for key in keys: 9 | if keyword in key: 10 | state_dict.pop(key) 11 | return state_dict 12 | 13 | 14 | def replace_layer(state_dict, keyword1, keyword2): 15 | keys = [key for key in state_dict.keys()] 16 | for key in keys: 17 | if keyword1 in key: 18 | new_key = key.replace(keyword1, keyword2) 19 | state_dict[new_key] = state_dict.pop(key) 20 | return state_dict 21 | 22 | 23 | def initialize_weights(modules, init_mode): 24 | for m in modules: 25 | if isinstance(m, nn.Conv2d): 26 | if init_mode == 'he': 27 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 28 | nonlinearity='relu') 29 | elif init_mode == 'xavier': 30 | nn.init.xavier_uniform_(m.weight.data) 31 | else: 32 | raise ValueError('Invalid init_mode {}'.format(init_mode)) 33 | if m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | elif isinstance(m, nn.BatchNorm2d): 36 | nn.init.constant_(m.weight, 1) 37 | nn.init.constant_(m.bias, 0) 38 | elif isinstance(m, nn.Linear): 39 | nn.init.normal_(m.weight, 0, 0.01) 40 | nn.init.constant_(m.bias, 0) 41 | 42 | def copy_dir(dir1, dir2): 43 | dlist = os.listdir(dir1) 44 | if not os.path.exists(dir2): 45 | os.mkdir(dir2) 46 | for f in dlist: 47 | file1 = os.path.join(dir1, f) # 源文件 48 | file2 = os.path.join(dir2, f) # 目标文件 49 | if os.path.isfile(file1): 50 | shutil.copyfile(file1, file2) 51 | 52 | def seek_class(args, path): 53 | if args.root == 'CUB_200_2011': 54 | val_root = 'Data/'+args.root + '/test' 55 | path_split = path.split("_")[:-2] 56 | part_path = '' 57 | for i in path_split: 58 | part_path += i + '_' 59 | part_path = part_path[:-1] 60 | part_path = part_path.lower() 61 | classes_name = os.listdir(val_root) 62 | for cur_class in classes_name: 63 | cur_class_low = cur_class.lower() 64 | if cur_class_low[4:] == part_path: 65 | path = 'Data/' + args.root + '/test/' + cur_class + '/' + path 66 | return path, int(cur_class[:3]) - 1, cur_class 67 | 68 | elif args.root == 'fgvc_aircraft_2013b': 69 | label_classes = {} 70 | with open('Data/fgvc_aircraft_2013b/classes.txt', 'r') as f: 71 | for i, line in enumerate(f): 72 | label_classes[line[:-1]] = i 73 | with open('Data/fgvc_aircraft_2013b/images_box.txt', 'r') as f: 74 | for line in f: 75 | img_name = line.split(';')[0] 76 | if img_name == path[:-4]: 77 | img_label = line.split(';')[1] 78 | path = 'Data/fgvc_aircraft_2013b/images' + '/' + path 79 | return path, label_classes[img_label], img_label 80 | 81 | -------------------------------------------------------------------------------- /WSSS/data/coco14/gen_lst.py: -------------------------------------------------------------------------------- 1 | with open('val_cls.txt') as f: 2 | lines = f.readlines() 3 | 4 | img_names = [line[:-1].split()[0] for line in lines] 5 | 6 | f1 = open('val.txt', 'w') 7 | for i in range(len(img_names)): 8 | f1.write('/JPEGImages/' + img_names[i] + '.jpg ') 9 | f1.write('/SegmentationClass/' + img_names[i] + '.png' + '\n') 10 | -------------------------------------------------------------------------------- /WSSS/data/coco14/split_list.py: -------------------------------------------------------------------------------- 1 | with open('train_cls.txt') as f: 2 | lines = f.readlines() 3 | 4 | L = 
len(lines) // 20000 5 | print(L) 6 | for i in range(L): 7 | f1 = open('train_cls_part{}.txt'.format(i+1), 'w') 8 | for j in range(20000): 9 | line = lines[20000*i + j] 10 | f1.write(line) 11 | f1.close() 12 | 13 | f1 = open('train_cls_part{}.txt'.format(L+1), 'w') 14 | for j in range(len(lines)-L*20000): 15 | line = lines[20000*L + j] 16 | f1.write(line) 17 | f1.close() 18 | -------------------------------------------------------------------------------- /WSSS/misc/imutils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | import pydensecrf.densecrf as dcrf 5 | from pydensecrf.utils import unary_from_labels 6 | from PIL import Image 7 | 8 | def pil_resize(img, size, order): 9 | if size[0] == img.shape[0] and size[1] == img.shape[1]: 10 | return img 11 | 12 | if order == 3: 13 | resample = Image.BICUBIC 14 | elif order == 0: 15 | resample = Image.NEAREST 16 | 17 | return np.asarray(Image.fromarray(img).resize(size[::-1], resample)) 18 | 19 | def pil_rescale(img, scale, order): 20 | height, width = img.shape[:2] 21 | target_size = (int(np.round(height*scale)), int(np.round(width*scale))) 22 | return pil_resize(img, target_size, order) 23 | 24 | 25 | def random_resize_long(img, min_long, max_long): 26 | target_long = random.randint(min_long, max_long) 27 | h, w = img.shape[:2] 28 | 29 | if w < h: 30 | scale = target_long / h 31 | else: 32 | scale = target_long / w 33 | 34 | return pil_rescale(img, scale, 3) 35 | 36 | def resize_scale(img, resize_scale): 37 | 38 | return pil_resize(img, (resize_scale,resize_scale), 3) 39 | 40 | def random_scale(img, scale_range, order): 41 | 42 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0]) 43 | 44 | if isinstance(img, tuple): 45 | return (pil_rescale(img[0], target_scale, order[0]), pil_rescale(img[1], target_scale, order[1])) 46 | else: 47 | return pil_rescale(img[0], target_scale, order) 48 | 49 | def random_lr_flip(img): 50 | 51 | if bool(random.getrandbits(1)): 52 | if isinstance(img, tuple): 53 | return [np.fliplr(m) for m in img] 54 | else: 55 | return np.fliplr(img) 56 | else: 57 | return img 58 | 59 | def get_random_crop_box(imgsize, cropsize): 60 | h, w = imgsize 61 | 62 | ch = min(cropsize, h) 63 | cw = min(cropsize, w) 64 | 65 | w_space = w - cropsize 66 | h_space = h - cropsize 67 | 68 | if w_space > 0: 69 | cont_left = 0 70 | img_left = random.randrange(w_space + 1) 71 | else: 72 | cont_left = random.randrange(-w_space + 1) 73 | img_left = 0 74 | 75 | if h_space > 0: 76 | cont_top = 0 77 | img_top = random.randrange(h_space + 1) 78 | else: 79 | cont_top = random.randrange(-h_space + 1) 80 | img_top = 0 81 | 82 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw 83 | 84 | def random_crop(images, cropsize, default_values): 85 | 86 | if isinstance(images, np.ndarray): images = (images,) 87 | if isinstance(default_values, int): default_values = (default_values,) 88 | 89 | imgsize = images[0].shape[:2] 90 | box = get_random_crop_box(imgsize, cropsize) 91 | 92 | new_images = [] 93 | for img, f in zip(images, default_values): 94 | 95 | if len(img.shape) == 3: 96 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f 97 | else: 98 | cont = np.ones((cropsize, cropsize), img.dtype)*f 99 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 100 | new_images.append(cont) 101 | 102 | if len(new_images) == 1: 103 | new_images = new_images[0] 104 | 105 | return 
new_images 106 | 107 | def top_left_crop(img, cropsize, default_value): 108 | 109 | h, w = img.shape[:2] 110 | 111 | ch = min(cropsize, h) 112 | cw = min(cropsize, w) 113 | 114 | if len(img.shape) == 2: 115 | container = np.ones((cropsize, cropsize), img.dtype)*default_value 116 | else: 117 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value 118 | 119 | container[:ch, :cw] = img[:ch, :cw] 120 | 121 | return container 122 | 123 | def center_crop(img, cropsize, default_value=0): 124 | 125 | h, w = img.shape[:2] 126 | 127 | ch = min(cropsize, h) 128 | cw = min(cropsize, w) 129 | 130 | sh = h - cropsize 131 | sw = w - cropsize 132 | 133 | if sw > 0: 134 | cont_left = 0 135 | img_left = int(round(sw / 2)) 136 | else: 137 | cont_left = int(round(-sw / 2)) 138 | img_left = 0 139 | 140 | if sh > 0: 141 | cont_top = 0 142 | img_top = int(round(sh / 2)) 143 | else: 144 | cont_top = int(round(-sh / 2)) 145 | img_top = 0 146 | 147 | if len(img.shape) == 2: 148 | container = np.ones((cropsize, cropsize), img.dtype)*default_value 149 | else: 150 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value 151 | 152 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 153 | img[img_top:img_top+ch, img_left:img_left+cw] 154 | 155 | return container 156 | 157 | def HWC_to_CHW(img): 158 | return np.transpose(img, (2, 0, 1)) 159 | 160 | def crf_inference_label(img, labels, t=10, n_labels=21, gt_prob=0.7): 161 | 162 | h, w = img.shape[:2] 163 | 164 | d = dcrf.DenseCRF2D(w, h, n_labels) 165 | 166 | unary = unary_from_labels(labels, n_labels, gt_prob=gt_prob, zero_unsure=False) 167 | 168 | d.setUnaryEnergy(unary) 169 | d.addPairwiseGaussian(sxy=3, compat=3) 170 | d.addPairwiseBilateral(sxy=50, srgb=5, rgbim=np.ascontiguousarray(np.copy(img)), compat=10) 171 | 172 | q = d.inference(t) 173 | 174 | return np.argmax(np.array(q).reshape((n_labels, h, w)), axis=0) 175 | 176 | 177 | def get_strided_size(orig_size, stride): 178 | return ((orig_size[0]-1)//stride+1, (orig_size[1]-1)//stride+1) 179 | 180 | 181 | def get_strided_up_size(orig_size, stride): 182 | strided_size = get_strided_size(orig_size, stride) 183 | return strided_size[0]*stride, strided_size[1]*stride 184 | 185 | 186 | def compress_range(arr): 187 | uniques = np.unique(arr) 188 | maximum = np.max(uniques) 189 | 190 | d = np.zeros(maximum+1, np.int32) 191 | d[uniques] = np.arange(uniques.shape[0]) 192 | 193 | out = d[arr] 194 | return out - np.min(out) 195 | 196 | 197 | def colorize_score(score_map, exclude_zero=False, normalize=True, by_hue=False): 198 | import matplotlib.colors 199 | if by_hue: 200 | aranged = np.arange(score_map.shape[0]) / (score_map.shape[0]) 201 | hsv_color = np.stack((aranged, np.ones_like(aranged), np.ones_like(aranged)), axis=-1) 202 | rgb_color = matplotlib.colors.hsv_to_rgb(hsv_color) 203 | 204 | test = rgb_color[np.argmax(score_map, axis=0)] 205 | test = np.expand_dims(np.max(score_map, axis=0), axis=-1) * test 206 | 207 | if normalize: 208 | return test / (np.max(test) + 1e-5) 209 | else: 210 | return test 211 | 212 | else: 213 | VOC_color = np.array([(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), 214 | (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), 215 | (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), 216 | (0, 192, 0), (128, 192, 0), (0, 64, 128), (255, 255, 255)], np.float32) 217 | 218 | if exclude_zero: 219 | VOC_color = VOC_color[1:] 220 | 221 | 
test = VOC_color[np.argmax(score_map, axis=0)%22] 222 | test = np.expand_dims(np.max(score_map, axis=0), axis=-1) * test 223 | if normalize: 224 | test /= np.max(test) + 1e-5 225 | 226 | return test 227 | 228 | 229 | def colorize_displacement(disp): 230 | 231 | import matplotlib.colors 232 | import math 233 | 234 | a = (np.arctan2(-disp[0], -disp[1]) / math.pi + 1) / 2 235 | 236 | r = np.sqrt(disp[0] ** 2 + disp[1] ** 2) 237 | s = r / np.max(r) 238 | hsv_color = np.stack((a, s, np.ones_like(a)), axis=-1) 239 | rgb_color = matplotlib.colors.hsv_to_rgb(hsv_color) 240 | 241 | return rgb_color 242 | 243 | 244 | def colorize_label(label_map, normalize=True, by_hue=True, exclude_zero=False, outline=False): 245 | 246 | label_map = label_map.astype(np.uint8) 247 | 248 | if by_hue: 249 | import matplotlib.colors 250 | sz = np.max(label_map) 251 | aranged = np.arange(sz) / sz 252 | hsv_color = np.stack((aranged, np.ones_like(aranged), np.ones_like(aranged)), axis=-1) 253 | rgb_color = matplotlib.colors.hsv_to_rgb(hsv_color) 254 | rgb_color = np.concatenate([np.zeros((1, 3)), rgb_color], axis=0) 255 | 256 | test = rgb_color[label_map] 257 | else: 258 | VOC_color = np.array([(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), 259 | (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), 260 | (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), 261 | (0, 192, 0), (128, 192, 0), (0, 64, 128), (255, 255, 255)], np.float32) 262 | 263 | if exclude_zero: 264 | VOC_color = VOC_color[1:] 265 | test = VOC_color[label_map] 266 | if normalize: 267 | test /= np.max(test) 268 | 269 | if outline: 270 | edge = np.greater(np.sum(np.abs(test[:-1, :-1] - test[1:, :-1]), axis=-1) + np.sum(np.abs(test[:-1, :-1] - test[:-1, 1:]), axis=-1), 0) 271 | edge1 = np.pad(edge, ((0, 1), (0, 1)), mode='constant', constant_values=0) 272 | edge2 = np.pad(edge, ((1, 0), (1, 0)), mode='constant', constant_values=0) 273 | edge = np.repeat(np.expand_dims(np.maximum(edge1, edge2), -1), 3, axis=-1) 274 | 275 | test = np.maximum(test, edge) 276 | return test 277 | -------------------------------------------------------------------------------- /WSSS/misc/indexing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | class PathIndex: 7 | 8 | def __init__(self, radius, default_size): 9 | self.radius = radius 10 | self.radius_floor = int(np.ceil(radius) - 1) 11 | 12 | self.search_paths, self.search_dst = self.get_search_paths_dst(self.radius) 13 | 14 | self.path_indices, self.src_indices, self.dst_indices = self.get_path_indices(default_size) 15 | 16 | return 17 | 18 | def get_search_paths_dst(self, max_radius=5): 19 | 20 | coord_indices_by_length = [[] for _ in range(max_radius * 4)] 21 | 22 | search_dirs = [] 23 | 24 | for x in range(1, max_radius): 25 | search_dirs.append((0, x)) 26 | 27 | for y in range(1, max_radius): 28 | for x in range(-max_radius + 1, max_radius): 29 | if x * x + y * y < max_radius ** 2: 30 | search_dirs.append((y, x)) 31 | 32 | for dir in search_dirs: 33 | 34 | length_sq = dir[0] ** 2 + dir[1] ** 2 35 | path_coords = [] 36 | 37 | min_y, max_y = sorted((0, dir[0])) 38 | min_x, max_x = sorted((0, dir[1])) 39 | 40 | for y in range(min_y, max_y + 1): 41 | for x in range(min_x, max_x + 1): 42 | 43 | dist_sq = (dir[0] * x - dir[1] * y) ** 2 / length_sq 44 | 45 | if dist_sq < 1: 46 | path_coords.append([y, x]) 47 | 
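                    # Each path's coordinates are ordered from the farthest offset back toward the
                    # origin (descending Manhattan distance), and paths are bucketed by their length
                    # so that equally long paths can later be stacked into a single array.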
48 | path_coords.sort(key=lambda x: -abs(x[0]) - abs(x[1])) 49 | path_length = len(path_coords) 50 | 51 | coord_indices_by_length[path_length].append(path_coords) 52 | 53 | path_list_by_length = [np.asarray(v) for v in coord_indices_by_length if v] 54 | path_destinations = np.concatenate([p[:, 0] for p in path_list_by_length], axis=0) 55 | 56 | return path_list_by_length, path_destinations 57 | 58 | def get_path_indices(self, size): 59 | 60 | full_indices = np.reshape(np.arange(0, size[0] * size[1], dtype=np.int64), (size[0], size[1])) 61 | 62 | cropped_height = size[0] - self.radius_floor 63 | cropped_width = size[1] - 2 * self.radius_floor 64 | 65 | path_indices = [] 66 | 67 | for paths in self.search_paths: 68 | 69 | path_indices_list = [] 70 | for p in paths: 71 | 72 | coord_indices_list = [] 73 | 74 | for dy, dx in p: 75 | coord_indices = full_indices[dy:dy + cropped_height, 76 | self.radius_floor + dx:self.radius_floor + dx + cropped_width] 77 | coord_indices = np.reshape(coord_indices, [-1]) 78 | 79 | coord_indices_list.append(coord_indices) 80 | 81 | path_indices_list.append(coord_indices_list) 82 | 83 | path_indices.append(np.array(path_indices_list)) 84 | 85 | src_indices = np.reshape(full_indices[:cropped_height, self.radius_floor:self.radius_floor + cropped_width], -1) 86 | dst_indices = np.concatenate([p[:,0] for p in path_indices], axis=0) 87 | 88 | return path_indices, src_indices, dst_indices 89 | 90 | 91 | def edge_to_affinity(edge, paths_indices): 92 | 93 | aff_list = [] 94 | edge = edge.view(edge.size(0), -1) 95 | 96 | for i in range(len(paths_indices)): 97 | if isinstance(paths_indices[i], np.ndarray): 98 | paths_indices[i] = torch.from_numpy(paths_indices[i]) 99 | paths_indices[i] = paths_indices[i].cuda(non_blocking=True) 100 | 101 | for ind in paths_indices: 102 | ind_flat = ind.view(-1) 103 | dist = torch.index_select(edge, dim=-1, index=ind_flat) 104 | dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2)) 105 | aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2) 106 | aff_list.append(aff) 107 | aff_cat = torch.cat(aff_list, dim=1) 108 | 109 | return aff_cat 110 | 111 | 112 | def affinity_sparse2dense(affinity_sparse, ind_from, ind_to, n_vertices): 113 | 114 | ind_from = torch.from_numpy(ind_from) 115 | ind_to = torch.from_numpy(ind_to) 116 | 117 | affinity_sparse = affinity_sparse.view(-1).cpu() 118 | ind_from = ind_from.repeat(ind_to.size(0)).view(-1) 119 | ind_to = ind_to.view(-1) 120 | 121 | indices = torch.stack([ind_from, ind_to]) 122 | indices_tp = torch.stack([ind_to, ind_from]) 123 | 124 | indices_id = torch.stack([torch.arange(0, n_vertices).long(), torch.arange(0, n_vertices).long()]) 125 | 126 | affinity_dense = torch.sparse.FloatTensor(torch.cat([indices, indices_id, indices_tp], dim=1), 127 | torch.cat([affinity_sparse, torch.ones([n_vertices]), affinity_sparse])).to_dense().cuda() 128 | 129 | return affinity_dense 130 | 131 | 132 | def to_transition_matrix(affinity_dense, beta, times): 133 | scaled_affinity = torch.pow(affinity_dense, beta) 134 | 135 | trans_mat = scaled_affinity / torch.sum(scaled_affinity, dim=0, keepdim=True) 136 | for _ in range(times): 137 | trans_mat = torch.matmul(trans_mat, trans_mat) 138 | 139 | return trans_mat 140 | 141 | def propagate_to_edge(x, edge, radius=5, beta=10, exp_times=8): 142 | 143 | height, width = x.shape[-2:] 144 | 145 | hor_padded = width+radius*2 146 | ver_padded = height+radius 147 | 148 | path_index = PathIndex(radius=radius, default_size=(ver_padded, hor_padded)) 149 | 
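    # The lines below carry out the random-walk propagation: pad the boundary map,
    # convert it into sparse affinities along the pre-computed pixel paths, densify
    # them into an (H*W) x (H*W) affinity matrix, turn that into a transition matrix
    # (elementwise power beta, normalized, then squared exp_times times), and finally
    # matrix-multiply the edge-suppressed score maps with it, so that class scores
    # spread within regions delimited by the predicted boundaries.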
150 | edge_padded = F.pad(edge, (radius, radius, 0, radius), mode='constant', value=1.0) 151 | sparse_aff = edge_to_affinity(torch.unsqueeze(edge_padded, 0), 152 | path_index.path_indices) 153 | 154 | dense_aff = affinity_sparse2dense(sparse_aff, path_index.src_indices, 155 | path_index.dst_indices, ver_padded * hor_padded) 156 | dense_aff = dense_aff.view(ver_padded, hor_padded, ver_padded, hor_padded) 157 | dense_aff = dense_aff[:-radius, radius:-radius, :-radius, radius:-radius] 158 | dense_aff = dense_aff.reshape(height * width, height * width) 159 | 160 | trans_mat = to_transition_matrix(dense_aff, beta=beta, times=exp_times) 161 | 162 | x = x.view(-1, height, width) * (1 - edge) 163 | 164 | rw = torch.matmul(x.view(-1, height * width), trans_mat) 165 | rw = rw.view(rw.size(0), 1, height, width) 166 | 167 | return rw -------------------------------------------------------------------------------- /WSSS/misc/pyutils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import time 4 | import sys 5 | 6 | class Logger(object): 7 | def __init__(self, outfile): 8 | self.terminal = sys.stdout 9 | self.log = open(outfile, "a") 10 | sys.stdout = self 11 | 12 | def write(self, message): 13 | self.terminal.write(message) 14 | self.log.write(message) 15 | 16 | def flush(self): 17 | self.terminal.flush() 18 | 19 | 20 | class AverageMeter: 21 | def __init__(self, *keys): 22 | self.__data = dict() 23 | for k in keys: 24 | self.__data[k] = [0.0, 0] 25 | 26 | def add(self, dict): 27 | for k, v in dict.items(): 28 | if k not in self.__data: 29 | self.__data[k] = [0.0, 0] 30 | self.__data[k][0] += v 31 | self.__data[k][1] += 1 32 | 33 | def get(self, *keys): 34 | if len(keys) == 1: 35 | return self.__data[keys[0]][0] / self.__data[keys[0]][1] 36 | else: 37 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] 38 | return tuple(v_list) 39 | 40 | def pop(self, key=None): 41 | if key is None: 42 | for k in self.__data.keys(): 43 | self.__data[k] = [0.0, 0] 44 | else: 45 | v = self.get(key) 46 | self.__data[key] = [0.0, 0] 47 | return v 48 | 49 | 50 | class Timer: 51 | def __init__(self, starting_msg = None): 52 | self.start = time.time() 53 | self.stage_start = self.start 54 | 55 | if starting_msg is not None: 56 | print(starting_msg, time.ctime(time.time())) 57 | 58 | def __enter__(self): 59 | return self 60 | 61 | def __exit__(self, exc_type, exc_val, exc_tb): 62 | return 63 | 64 | def update_progress(self, progress): 65 | self.elapsed = time.time() - self.start 66 | self.est_total = self.elapsed / progress 67 | self.est_remaining = self.est_total - self.elapsed 68 | self.est_finish = int(self.start + self.est_total) 69 | 70 | 71 | def str_estimated_complete(self): 72 | return str(time.ctime(self.est_finish)) 73 | 74 | def get_stage_elapsed(self): 75 | return time.time() - self.stage_start 76 | 77 | def reset_stage(self): 78 | self.stage_start = time.time() 79 | 80 | def lapse(self): 81 | out = time.time() - self.stage_start 82 | self.stage_start = time.time() 83 | return out 84 | 85 | 86 | def to_one_hot(sparse_integers, maximum_val=None, dtype=np.bool): 87 | 88 | if maximum_val is None: 89 | maximum_val = np.max(sparse_integers) + 1 90 | 91 | src_shape = sparse_integers.shape 92 | 93 | flat_src = np.reshape(sparse_integers, [-1]) 94 | src_size = flat_src.shape[0] 95 | 96 | one_hot = np.zeros((maximum_val, src_size), dtype) 97 | one_hot[flat_src, np.arange(src_size)] = 1 98 | 99 | one_hot = np.reshape(one_hot, [maximum_val] + 
list(src_shape)) 100 | 101 | return one_hot 102 | -------------------------------------------------------------------------------- /WSSS/misc/torchutils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from torch.utils.data import Subset 5 | import numpy as np 6 | import math 7 | 8 | 9 | class PolyOptimizer(torch.optim.SGD): 10 | 11 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9): 12 | super().__init__(params, lr, weight_decay) 13 | 14 | self.global_step = 0 15 | self.max_step = max_step 16 | self.momentum = momentum 17 | 18 | self.__initial_lr = [group['lr'] for group in self.param_groups] 19 | 20 | 21 | def step(self, closure=None): 22 | 23 | if self.global_step < self.max_step: 24 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 25 | 26 | for i in range(len(self.param_groups)): 27 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 28 | 29 | super().step(closure) 30 | 31 | self.global_step += 1 32 | 33 | class SGDROptimizer(torch.optim.SGD): 34 | 35 | def __init__(self, params, steps_per_epoch, lr=0, weight_decay=0, epoch_start=1, restart_mult=2): 36 | super().__init__(params, lr, weight_decay) 37 | 38 | self.global_step = 0 39 | self.local_step = 0 40 | self.total_restart = 0 41 | 42 | self.max_step = steps_per_epoch * epoch_start 43 | self.restart_mult = restart_mult 44 | 45 | self.__initial_lr = [group['lr'] for group in self.param_groups] 46 | 47 | 48 | def step(self, closure=None): 49 | 50 | if self.local_step >= self.max_step: 51 | self.local_step = 0 52 | self.max_step *= self.restart_mult 53 | self.total_restart += 1 54 | 55 | lr_mult = (1 + math.cos(math.pi * self.local_step / self.max_step))/2 / (self.total_restart + 1) 56 | 57 | for i in range(len(self.param_groups)): 58 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 59 | 60 | super().step(closure) 61 | 62 | self.local_step += 1 63 | self.global_step += 1 64 | 65 | 66 | def split_dataset(dataset, n_splits): 67 | 68 | return [Subset(dataset, np.arange(i, len(dataset), n_splits)) for i in range(n_splits)] 69 | 70 | 71 | def gap2d(x, keepdims=False): 72 | out = torch.mean(x.view(x.size(0), x.size(1), -1), -1) 73 | if keepdims: 74 | out = out.view(out.size(0), out.size(1), 1, 1) 75 | 76 | return out 77 | 78 | def gap2d_pos(x, keepdims=False): 79 | out = torch.sum(x.view(x.size(0), x.size(1), -1), -1) / (torch.sum(x>0)+1e-12) 80 | if keepdims: 81 | out = out.view(out.size(0), out.size(1), 1, 1) 82 | 83 | return out 84 | 85 | def gsp2d(x, keepdims=False): 86 | out = torch.sum(x.view(x.size(0), x.size(1), -1), -1) 87 | if keepdims: 88 | out = out.view(out.size(0), out.size(1), 1, 1) 89 | 90 | return out 91 | 92 | -------------------------------------------------------------------------------- /WSSS/net/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | 3 | __all__ = ['resnet50_bas'] -------------------------------------------------------------------------------- /WSSS/net/resnet50.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | model_urls = { 6 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 7 | } 8 | 9 | 10 | class FixedBatchNorm(nn.BatchNorm2d): 11 | def forward(self, input): 12 | return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, 13 | training=False, eps=self.eps) 14 | 15 | 16 | class Bottleneck(nn.Module): 17 | expansion = 4 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 20 | super(Bottleneck, self).__init__() 21 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 22 | self.bn1 = FixedBatchNorm(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 24 | padding=dilation, bias=False, dilation=dilation) 25 | self.bn2 = FixedBatchNorm(planes) 26 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 27 | self.bn3 = FixedBatchNorm(planes * 4) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = downsample 30 | self.stride = stride 31 | self.dilation = dilation 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv3(out) 45 | out = self.bn3(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class ResNet(nn.Module): 57 | 58 | def __init__(self, block, layers, strides=(2, 2, 2, 2), dilations=(1, 1, 1, 1)): 59 | self.inplanes = 64 60 | super(ResNet, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=strides[0], padding=3, 62 | bias=False) 63 | self.bn1 = FixedBatchNorm(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1, dilation=dilations[0]) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1]) 69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2]) 70 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3]) 71 | 72 | self.inplanes = 1024 73 | 74 | #self.avgpool = nn.AvgPool2d(7, stride=1) 75 | #self.fc = nn.Linear(512 * block.expansion, 1000) 76 | 77 | 78 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 79 | downsample = None 80 | if stride != 1 or self.inplanes != planes * block.expansion: 81 | downsample = nn.Sequential( 82 | nn.Conv2d(self.inplanes, planes * block.expansion, 83 | kernel_size=1, stride=stride, bias=False), 84 | FixedBatchNorm(planes * block.expansion), 85 | ) 86 | 87 | layers = [block(self.inplanes, planes, stride, downsample, dilation=1)] 88 | self.inplanes = planes * block.expansion 89 | for i in range(1, blocks): 90 | layers.append(block(self.inplanes, planes, dilation=dilation)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def forward(self, x): 95 | x = self.conv1(x) 96 | x = self.bn1(x) 97 | x = self.relu(x) 98 | x = self.maxpool(x) 99 | 100 | x = self.layer1(x) 101 | x = 
self.layer2(x) 102 | x = self.layer3(x) 103 | x = self.layer4(x) 104 | 105 | x = self.avgpool(x) 106 | x = x.view(x.size(0), -1) 107 | x = self.fc(x) 108 | 109 | return x 110 | 111 | 112 | def resnet50(pretrained=True, **kwargs): 113 | 114 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 115 | if pretrained: 116 | state_dict = model_zoo.load_url(model_urls['resnet50']) 117 | state_dict.pop('fc.weight') 118 | state_dict.pop('fc.bias') 119 | model.load_state_dict(state_dict, strict=False) 120 | print("model pretrained initialized") 121 | 122 | return model -------------------------------------------------------------------------------- /WSSS/net/resnet50_bas.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.model_zoo import load_url 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import cv2 9 | import copy 10 | import matplotlib.pyplot as plt 11 | import random 12 | 13 | def initialize_weights(modules, init_mode): 14 | for m in modules: 15 | if isinstance(m, nn.Conv2d): 16 | if init_mode == 'he': 17 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 18 | nonlinearity='relu') 19 | elif init_mode == 'xavier': 20 | nn.init.xavier_uniform_(m.weight.data) 21 | else: 22 | raise ValueError('Invalid init_mode {}'.format(init_mode)) 23 | if m.bias is not None: 24 | nn.init.constant_(m.bias, 0) 25 | elif isinstance(m, nn.BatchNorm2d): 26 | nn.init.constant_(m.weight, 1) 27 | nn.init.constant_(m.bias, 0) 28 | elif isinstance(m, nn.Linear): 29 | nn.init.normal_(m.weight, 0, 0.01) 30 | nn.init.constant_(m.bias, 0) 31 | 32 | class Bottleneck(nn.Module): 33 | expansion = 4 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None, 36 | base_width=64): 37 | super(Bottleneck, self).__init__() 38 | width = int(planes * (base_width / 64.)) 39 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(width) 41 | self.conv2 = nn.Conv2d(width, width, 3, 42 | stride=stride, padding=1, bias=False) 43 | self.bn2 = nn.BatchNorm2d(width) 44 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv3(out) 62 | out = self.bn3(out) 63 | 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | 67 | out += identity 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class ResNetCam(nn.Module): 74 | def __init__(self, block, layers, num_classes=1000, 75 | large_feature_map=True): 76 | super(ResNetCam, self).__init__() 77 | self.num_classes = 20 78 | stride_l3 = 1 if large_feature_map else 2 79 | self.inplanes = 64 80 | 81 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, 82 | padding=3, bias=False) 83 | self.bn1 = nn.BatchNorm2d(self.inplanes) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 86 | 87 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 88 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 89 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 90 | 
self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 91 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 92 | 93 | self.classifier_cls = nn.Conv2d(2048, self.num_classes, kernel_size=1, bias=False) 94 | 95 | self.classifier_loc = nn.Sequential( 96 | nn.Conv2d(1024, self.num_classes, kernel_size=3, padding=1, bias=False), 97 | nn.Sigmoid(), 98 | ) 99 | initialize_weights(self.modules(), init_mode='xavier') 100 | 101 | def forward(self, x, label=None): 102 | batch = x.size(0) 103 | 104 | x = self.conv1(x) 105 | x = self.bn1(x) 106 | x = self.relu(x) 107 | x = self.maxpool(x) 108 | 109 | x = self.layer1(x) 110 | x = self.layer2(x) 111 | 112 | x = self.layer3(x) 113 | x_3 = x.clone() 114 | x = self.layer4(x) 115 | 116 | x = self.classifier_cls(x) 117 | self.feature_map = x 118 | 119 | ## score 120 | x = self.avg_pool(x).view(x.size(0), -1) 121 | self.score_1 = x 122 | 123 | ## M 124 | M = self.classifier_loc(x_3) 125 | if label == None: 126 | M = M[0] + M[1].flip(-1) 127 | return M 128 | 129 | # p_label (random select one ground-truth label) 130 | p_label = self.get_p_label(label) 131 | 132 | ## S 133 | self.S = torch.zeros(batch).cuda() 134 | for i in range(batch): 135 | self.S[i] = self.score_1[i][p_label[i]] 136 | ## M_fg 137 | M_fg = torch.zeros(batch, 1, M.size(2), M.size(3)).cuda() 138 | for i in range(batch): 139 | M_fg[i][0] = M[i][p_label[i]] 140 | self.M_fg = M_fg 141 | 142 | ## weight_copy 143 | classifier_cls_copy = copy.deepcopy(self.classifier_cls) 144 | layer4_copy = copy.deepcopy(self.layer4) 145 | 146 | ## erase 147 | x_erase = x_3.detach() * (1-M_fg) 148 | x_erase = layer4_copy(x_erase) 149 | x_erase = classifier_cls_copy(x_erase) 150 | x_erase = self.avg_pool(x_erase).view(x_erase.size(0), -1) 151 | 152 | ## S_bg 153 | self.S_bg = torch.zeros(batch).cuda() 154 | for i in range(batch): 155 | self.S_bg[i] = x_erase[i][p_label[i]] 156 | 157 | ## score_2 158 | x = self.feature_map.clone().detach() * self.M_fg 159 | self.score_2 = self.avg_pool(x).squeeze(-1).squeeze(-1) 160 | 161 | ## bas 162 | S = nn.ReLU()(self.S.clone().detach()).view(batch,-1).mean(1) 163 | S_bg = nn.ReLU()(self.S_bg).view(batch,-1).mean(1) 164 | 165 | bas = S_bg / (S+1e-8) 166 | bas[S_bg>S] = 1 167 | area = M_fg.clone().view(batch, -1).mean(1) 168 | 169 | bas, area = bas.mean(0), area.mean(0) 170 | 171 | ## score_2_loss 172 | cls_loss_2 = nn.CrossEntropyLoss()(self.score_2 , p_label).cuda() 173 | 174 | return self.score_1 , cls_loss_2, bas, area 175 | 176 | 177 | def _make_layer(self, block, planes, blocks, stride): 178 | layers = self._layer(block, planes, blocks, stride) 179 | return nn.Sequential(*layers) 180 | 181 | def _layer(self, block, planes, blocks, stride): 182 | downsample = get_downsampling_layer(self.inplanes, block, planes, 183 | stride) 184 | 185 | layers = [block(self.inplanes, planes, stride, downsample)] 186 | self.inplanes = planes * block.expansion 187 | for _ in range(1, blocks): 188 | layers.append(block(self.inplanes, planes)) 189 | 190 | return layers 191 | 192 | def get_p_label(self, label): 193 | batch = label.size(0) 194 | p_label = torch.zeros(batch).cuda() 195 | p_label_one_hot = torch.zeros(batch,self.num_classes).cuda() 196 | for i in range(batch): 197 | p_batch_label = torch.nonzero(label[i]).squeeze(-1) 198 | p_batch_label_random = random.randint(0,len(p_batch_label)-1) 199 | p_batch_label_random = p_batch_label[p_batch_label_random] 200 | p_label[i] = p_batch_label_random 201 | p_label_one_hot[i][p_batch_label_random] = 1 202 | p_label = p_label.long() 203 | 
p_label_one_hot = p_label_one_hot.long() 204 | return p_label 205 | 206 | def normalize_atten_maps(self, atten_maps): 207 | atten_shape = atten_maps.size() 208 | 209 | #-------------------------- 210 | batch_mins, _ = torch.min(atten_maps.view(atten_shape[0:-2] + (-1,)), dim=-1, keepdim=True) 211 | batch_maxs, _ = torch.max(atten_maps.view(atten_shape[0:-2] + (-1,)), dim=-1, keepdim=True) 212 | atten_normed = torch.div(atten_maps.view(atten_shape[0:-2] + (-1,))-batch_mins, 213 | batch_maxs - batch_mins + 1e-10) 214 | atten_normed = atten_normed.view(atten_shape) 215 | 216 | return atten_normed 217 | 218 | def get_downsampling_layer(inplanes, block, planes, stride): 219 | outplanes = planes * block.expansion 220 | if stride == 1 and inplanes == outplanes: 221 | return 222 | else: 223 | return nn.Sequential( 224 | nn.Conv2d(inplanes, outplanes, 1, stride, bias=False), 225 | nn.BatchNorm2d(outplanes), 226 | ) 227 | 228 | def remove_layer(state_dict, keyword): 229 | keys = [key for key in state_dict.keys()] 230 | for key in keys: 231 | if keyword in key: 232 | state_dict.pop(key) 233 | return state_dict 234 | 235 | def load_pretrained_model(model): 236 | strict_rule = True 237 | 238 | state_dict = torch.load('sess/resnet50-19c8e357.pth') 239 | 240 | state_dict = remove_layer(state_dict, 'fc') 241 | strict_rule = False 242 | 243 | model.load_state_dict(state_dict, strict=strict_rule) 244 | return model 245 | 246 | 247 | def model(pretrained=True): 248 | model = ResNetCam(Bottleneck, [3, 4, 6, 3]) 249 | if pretrained: 250 | model = load_pretrained_model(model) 251 | return model 252 | -------------------------------------------------------------------------------- /WSSS/net/resnet50_cam.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from misc import torchutils 4 | from net import resnet50 5 | 6 | 7 | class Net(nn.Module): 8 | 9 | def __init__(self, stride=16): 10 | super(Net, self).__init__() 11 | if stride == 16: 12 | self.resnet50 = resnet50.resnet50(pretrained=True, strides=(2, 2, 2, 1)) 13 | self.stage1 = nn.Sequential(self.resnet50.conv1, self.resnet50.bn1, self.resnet50.relu, self.resnet50.maxpool, 14 | self.resnet50.layer1) 15 | else: 16 | self.resnet50 = resnet50.resnet50(pretrained=True, strides=(2, 2, 1, 1), dilations=(1, 1, 2, 2)) 17 | self.stage1 = nn.Sequential(self.resnet50.conv1, self.resnet50.bn1, self.resnet50.relu, self.resnet50.maxpool, 18 | self.resnet50.layer1) 19 | self.stage2 = nn.Sequential(self.resnet50.layer2) 20 | self.stage3 = nn.Sequential(self.resnet50.layer3) 21 | self.stage4 = nn.Sequential(self.resnet50.layer4) 22 | 23 | self.classifier = nn.Conv2d(2048, 20, 1, bias=False) 24 | 25 | self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4]) 26 | self.newly_added = nn.ModuleList([self.classifier]) 27 | 28 | def forward(self, x): 29 | 30 | x = self.stage1(x) 31 | x = self.stage2(x) 32 | 33 | x = self.stage3(x) 34 | x = self.stage4(x) 35 | 36 | 37 | x = torchutils.gap2d(x, keepdims=True) 38 | x = self.classifier(x) 39 | x = x.view(-1, 20) 40 | 41 | return x 42 | 43 | def train(self, mode=True): 44 | for p in self.resnet50.conv1.parameters(): 45 | p.requires_grad = False 46 | for p in self.resnet50.bn1.parameters(): 47 | p.requires_grad = False 48 | 49 | def trainable_parameters(self): 50 | 51 | return (list(self.backbone.parameters()), list(self.newly_added.parameters())) 52 | 53 | 54 | class CAM(Net): 55 | 56 | def __init__(self, 
stride=16): 57 | super(CAM, self).__init__(stride=stride) 58 | 59 | def forward(self, x, separate=False): 60 | x = self.stage1(x) 61 | 62 | x = self.stage2(x) 63 | 64 | x = self.stage3(x) 65 | 66 | x = self.stage4(x) 67 | 68 | x = F.conv2d(x, self.classifier.weight) 69 | if separate: 70 | return x 71 | x = F.relu(x) 72 | x = x[0] + x[1].flip(-1) 73 | 74 | return x 75 | -------------------------------------------------------------------------------- /WSSS/net/resnet50_irn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from net import resnet50 5 | 6 | 7 | class Net(nn.Module): 8 | 9 | def __init__(self): 10 | super(Net, self).__init__() 11 | 12 | # backbone 13 | self.resnet50 = resnet50.resnet50(pretrained=True, strides=[2, 2, 2, 1]) 14 | 15 | self.stage1 = nn.Sequential(self.resnet50.conv1, self.resnet50.bn1, self.resnet50.relu, self.resnet50.maxpool) 16 | self.stage2 = nn.Sequential(self.resnet50.layer1) 17 | self.stage3 = nn.Sequential(self.resnet50.layer2) 18 | self.stage4 = nn.Sequential(self.resnet50.layer3) 19 | self.stage5 = nn.Sequential(self.resnet50.layer4) 20 | self.mean_shift = Net.MeanShift(2) 21 | 22 | # branch: class boundary detection 23 | self.fc_edge1 = nn.Sequential( 24 | nn.Conv2d(64, 32, 1, bias=False), 25 | nn.GroupNorm(4, 32), 26 | nn.ReLU(inplace=True), 27 | ) 28 | self.fc_edge2 = nn.Sequential( 29 | nn.Conv2d(256, 32, 1, bias=False), 30 | nn.GroupNorm(4, 32), 31 | nn.ReLU(inplace=True), 32 | ) 33 | self.fc_edge3 = nn.Sequential( 34 | nn.Conv2d(512, 32, 1, bias=False), 35 | nn.GroupNorm(4, 32), 36 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 37 | nn.ReLU(inplace=True), 38 | ) 39 | self.fc_edge4 = nn.Sequential( 40 | nn.Conv2d(1024, 32, 1, bias=False), 41 | nn.GroupNorm(4, 32), 42 | nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), 43 | nn.ReLU(inplace=True), 44 | ) 45 | self.fc_edge5 = nn.Sequential( 46 | nn.Conv2d(2048, 32, 1, bias=False), 47 | nn.GroupNorm(4, 32), 48 | nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), 49 | nn.ReLU(inplace=True), 50 | ) 51 | self.fc_edge6 = nn.Conv2d(160, 1, 1, bias=True) 52 | 53 | # branch: displacement field 54 | self.fc_dp1 = nn.Sequential( 55 | nn.Conv2d(64, 64, 1, bias=False), 56 | nn.GroupNorm(8, 64), 57 | nn.ReLU(inplace=True), 58 | ) 59 | self.fc_dp2 = nn.Sequential( 60 | nn.Conv2d(256, 128, 1, bias=False), 61 | nn.GroupNorm(16, 128), 62 | nn.ReLU(inplace=True), 63 | ) 64 | self.fc_dp3 = nn.Sequential( 65 | nn.Conv2d(512, 256, 1, bias=False), 66 | nn.GroupNorm(16, 256), 67 | nn.ReLU(inplace=True), 68 | ) 69 | self.fc_dp4 = nn.Sequential( 70 | nn.Conv2d(1024, 256, 1, bias=False), 71 | nn.GroupNorm(16, 256), 72 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 73 | nn.ReLU(inplace=True), 74 | ) 75 | self.fc_dp5 = nn.Sequential( 76 | nn.Conv2d(2048, 256, 1, bias=False), 77 | nn.GroupNorm(16, 256), 78 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 79 | nn.ReLU(inplace=True), 80 | ) 81 | self.fc_dp6 = nn.Sequential( 82 | nn.Conv2d(768, 256, 1, bias=False), 83 | nn.GroupNorm(16, 256), 84 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 85 | nn.ReLU(inplace=True), 86 | ) 87 | self.fc_dp7 = nn.Sequential( 88 | nn.Conv2d(448, 256, 1, bias=False), 89 | nn.GroupNorm(16, 256), 90 | nn.ReLU(inplace=True), 91 | nn.Conv2d(256, 2, 1, bias=False), 92 | self.mean_shift 93 | ) 94 | 95 | self.backbone = 
nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4, self.stage5]) 96 | self.edge_layers = nn.ModuleList([self.fc_edge1, self.fc_edge2, self.fc_edge3, self.fc_edge4, self.fc_edge5, self.fc_edge6]) 97 | self.dp_layers = nn.ModuleList([self.fc_dp1, self.fc_dp2, self.fc_dp3, self.fc_dp4, self.fc_dp5, self.fc_dp6, self.fc_dp7]) 98 | 99 | class MeanShift(nn.Module): 100 | 101 | def __init__(self, num_features): 102 | super(Net.MeanShift, self).__init__() 103 | self.register_buffer('running_mean', torch.zeros(num_features)) 104 | 105 | def forward(self, input): 106 | if self.training: 107 | return input 108 | return input - self.running_mean.view(1, 2, 1, 1) 109 | 110 | def forward(self, x): 111 | x1 = self.stage1(x).detach() 112 | x2 = self.stage2(x1).detach() 113 | x3 = self.stage3(x2).detach() 114 | x4 = self.stage4(x3).detach() 115 | x5 = self.stage5(x4).detach() 116 | 117 | edge1 = self.fc_edge1(x1) 118 | edge2 = self.fc_edge2(x2) 119 | edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)] 120 | edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)] 121 | edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)] 122 | edge_out = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1)) 123 | 124 | dp1 = self.fc_dp1(x1) 125 | dp2 = self.fc_dp2(x2) 126 | dp3 = self.fc_dp3(x3) 127 | dp4 = self.fc_dp4(x4)[..., :dp3.size(2), :dp3.size(3)] 128 | dp5 = self.fc_dp5(x5)[..., :dp3.size(2), :dp3.size(3)] 129 | 130 | dp_up3 = self.fc_dp6(torch.cat([dp3, dp4, dp5], dim=1))[..., :dp2.size(2), :dp2.size(3)] 131 | dp_out = self.fc_dp7(torch.cat([dp1, dp2, dp_up3], dim=1)) 132 | 133 | return edge_out, dp_out 134 | 135 | def trainable_parameters(self): 136 | return (tuple(self.edge_layers.parameters()), 137 | tuple(self.dp_layers.parameters())) 138 | 139 | def train(self, mode=True): 140 | super().train(mode) 141 | self.backbone.eval() 142 | 143 | 144 | class AffinityDisplacementLoss(Net): 145 | 146 | path_indices_prefix = "path_indices" 147 | 148 | def __init__(self, path_index): 149 | 150 | super(AffinityDisplacementLoss, self).__init__() 151 | 152 | self.path_index = path_index 153 | 154 | self.n_path_lengths = len(path_index.path_indices) 155 | for i, pi in enumerate(path_index.path_indices): 156 | self.register_buffer(AffinityDisplacementLoss.path_indices_prefix + str(i), torch.from_numpy(pi)) 157 | 158 | self.register_buffer( 159 | 'disp_target', 160 | torch.unsqueeze(torch.unsqueeze(torch.from_numpy(path_index.search_dst).transpose(1, 0), 0), -1).float()) 161 | 162 | def to_affinity(self, edge): 163 | aff_list = [] 164 | edge = edge.view(edge.size(0), -1) 165 | 166 | for i in range(self.n_path_lengths): 167 | ind = self._buffers[AffinityDisplacementLoss.path_indices_prefix + str(i)] 168 | ind_flat = ind.view(-1) 169 | dist = torch.index_select(edge, dim=-1, index=ind_flat) 170 | dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2)) 171 | aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2) 172 | aff_list.append(aff) 173 | aff_cat = torch.cat(aff_list, dim=1) 174 | 175 | return aff_cat 176 | 177 | def to_pair_displacement(self, disp): 178 | height, width = disp.size(2), disp.size(3) 179 | radius_floor = self.path_index.radius_floor 180 | 181 | cropped_height = height - radius_floor 182 | cropped_width = width - 2 * radius_floor 183 | 184 | disp_src = disp[:, :, :cropped_height, radius_floor:radius_floor + cropped_width] 185 | 186 | disp_dst = [disp[:, :, dy:dy + cropped_height, radius_floor + dx:radius_floor + dx + 
cropped_width] 187 | for dy, dx in self.path_index.search_dst] 188 | disp_dst = torch.stack(disp_dst, 2) 189 | 190 | pair_disp = torch.unsqueeze(disp_src, 2) - disp_dst 191 | pair_disp = pair_disp.view(pair_disp.size(0), pair_disp.size(1), pair_disp.size(2), -1) 192 | 193 | return pair_disp 194 | 195 | def to_displacement_loss(self, pair_disp): 196 | return torch.abs(pair_disp - self.disp_target) 197 | 198 | def forward(self, *inputs): 199 | x, return_loss = inputs 200 | edge_out, dp_out = super().forward(x) 201 | 202 | if return_loss is False: 203 | return edge_out, dp_out 204 | 205 | aff = self.to_affinity(torch.sigmoid(edge_out)) 206 | pos_aff_loss = (-1) * torch.log(aff + 1e-5) 207 | neg_aff_loss = (-1) * torch.log(1. + 1e-5 - aff) 208 | 209 | pair_disp = self.to_pair_displacement(dp_out) 210 | dp_fg_loss = self.to_displacement_loss(pair_disp) 211 | dp_bg_loss = torch.abs(pair_disp) 212 | 213 | return pos_aff_loss, neg_aff_loss, dp_fg_loss, dp_bg_loss 214 | 215 | 216 | class EdgeDisplacement(Net): 217 | 218 | def __init__(self, crop_size=512, stride=4): 219 | super(EdgeDisplacement, self).__init__() 220 | self.crop_size = crop_size 221 | self.stride = stride 222 | 223 | def forward(self, x): 224 | feat_size = (x.size(2)-1)//self.stride+1, (x.size(3)-1)//self.stride+1 225 | 226 | x = F.pad(x, [0, self.crop_size-x.size(3), 0, self.crop_size-x.size(2)]) 227 | edge_out, dp_out = super().forward(x) 228 | edge_out = edge_out[..., :feat_size[0], :feat_size[1]] 229 | dp_out = dp_out[..., :feat_size[0], :feat_size[1]] 230 | 231 | edge_out = torch.sigmoid(edge_out[0]/2 + edge_out[1].flip(-1)/2) 232 | dp_out = dp_out[0] 233 | 234 | return edge_out, dp_out 235 | 236 | 237 | -------------------------------------------------------------------------------- /WSSS/run_sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | from misc import pyutils 5 | 6 | if __name__ == '__main__': 7 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' 8 | parser = argparse.ArgumentParser() 9 | 10 | # Environment 11 | parser.add_argument("--num_workers", default=os.cpu_count()//2, type=int) 12 | parser.add_argument("--voc12_root", default='/home2/wpy/WSSS/VOC2012/', type=str, 13 | help="Path to VOC 2012 Devkit, must contain ./JPEGImages as subdirectory.") 14 | 15 | # Dataset 16 | parser.add_argument("--train_list", default="voc12/train_aug.txt", type=str) 17 | parser.add_argument("--val_list", default="voc12/val.txt", type=str) 18 | parser.add_argument("--infer_list", default="voc12/train.txt", type=str, 19 | help="voc12/train_aug.txt to train a fully supervised model, " 20 | "voc12/train.txt or voc12/val.txt to quickly check the quality of the labels.") 21 | parser.add_argument("--chainer_eval_set", default="train", type=str) 22 | 23 | # Class Activation Map 24 | parser.add_argument("--cam_network", default="net.resnet50_bas", type=str) 25 | parser.add_argument("--cam_crop_size", default=512, type=int) 26 | parser.add_argument("--cam_batch_size", default=16, type=int) 27 | parser.add_argument("--cam_num_epoches", default=5, type=int) 28 | parser.add_argument("--cam_learning_rate", default=0.1, type=float) 29 | parser.add_argument("--cam_weight_decay", default=1e-4, type=float) 30 | parser.add_argument("--cam_eval_thres", default=0.3, type=float) 31 | parser.add_argument("--cam_scales", default=(1.0, 0.5, 1.5, 2.0), 32 | help="Multi-scale inferences") 33 | 34 | # Mining Inter-pixel Relations 35 | 
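# --- Editor's note: illustrative sketch, not part of the original repository ---
# The two confidence thresholds registered just below (--conf_fg_thres, --conf_bg_thres)
# are consumed in step/cam_to_ir_label.py: the CAM stack is padded with a constant
# background plane at the threshold value, so an argmax over classes sends a pixel to
# background exactly when every class score falls below the threshold (a CRF refinement
# follows in the original code). A toy example of that padding trick:
import numpy as np

cams = np.array([[[0.9, 0.2],
                  [0.4, 0.1]]])                        # one foreground class, 2x2 map
thr = 0.30                                             # e.g. the foreground threshold
padded = np.pad(cams, ((1, 0), (0, 0), (0, 0)),
                mode='constant', constant_values=thr)  # plane 0 = constant background
labels = np.argmax(padded, axis=0)
# labels == [[1, 0], [1, 0]]: scores above 0.30 keep their class, the rest become background
# --- end of editor's note ---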
parser.add_argument("--conf_fg_thres", default=0.30, type=float) 36 | parser.add_argument("--conf_bg_thres", default=0.17, type=float) 37 | 38 | # Inter-pixel Relation Network (IRNet) 39 | parser.add_argument("--irn_network", default="net.resnet50_irn", type=str) 40 | parser.add_argument("--irn_crop_size", default=512, type=int) 41 | parser.add_argument("--irn_batch_size", default=32, type=int) 42 | parser.add_argument("--irn_num_epoches", default=3, type=int) 43 | parser.add_argument("--irn_learning_rate", default=0.1, type=float) 44 | parser.add_argument("--irn_weight_decay", default=1e-4, type=float) 45 | 46 | # Random Walk Params 47 | parser.add_argument("--beta", default=10) 48 | parser.add_argument("--exp_times", default=8, 49 | help="Hyper-parameter that controls the number of random walk iterations," 50 | "The random walk is performed 2^{exp_times}.") 51 | parser.add_argument("--sem_seg_bg_thres", default=0.22) 52 | 53 | # Output Path 54 | parser.add_argument("--log_name", default="sample_train_eval", type=str) 55 | parser.add_argument("--cam_weights_name", default="sess/res50_cam.pth", type=str) 56 | parser.add_argument("--irn_weights_name", default="sess/res50_irn.pth", type=str) 57 | parser.add_argument("--cam_out_dir", default="result/cam", type=str) 58 | parser.add_argument("--ir_label_out_dir", default="result/ir_label", type=str) 59 | parser.add_argument("--sem_seg_out_dir", default="result/sem_seg", type=str) 60 | 61 | # Step 62 | parser.add_argument("--make_cam_pass", default=True) 63 | parser.add_argument("--eval_cam_pass", default=True) 64 | parser.add_argument("--cam_to_ir_label_pass", default=False) 65 | parser.add_argument("--train_irn_pass", default=False) 66 | parser.add_argument("--make_sem_seg_pass", default=True) 67 | parser.add_argument("--eval_sem_seg_pass", default=True) 68 | 69 | args = parser.parse_args() 70 | 71 | os.makedirs("sess", exist_ok=True) 72 | os.makedirs(args.cam_out_dir, exist_ok=True) 73 | os.makedirs(args.ir_label_out_dir, exist_ok=True) 74 | os.makedirs(args.sem_seg_out_dir, exist_ok=True) 75 | 76 | pyutils.Logger(args.log_name + '.log') 77 | 78 | print(vars(args)) 79 | 80 | if args.make_cam_pass is True: 81 | import step.make_cam 82 | 83 | timer = pyutils.Timer('step.make_cam:') 84 | step.make_cam.run(args) 85 | 86 | if args.eval_cam_pass is True: 87 | import step.eval_cam 88 | 89 | timer = pyutils.Timer('step.eval_cam:') 90 | final_miou = [] 91 | for i in range(20, 45): 92 | t = i/100.0 93 | args.cam_eval_thres = t 94 | miou = step.eval_cam.run(args) 95 | final_miou.append(miou) 96 | print(final_miou) 97 | print(np.max(np.array(final_miou))) 98 | 99 | if args.cam_to_ir_label_pass is True: 100 | import step.cam_to_ir_label 101 | 102 | timer = pyutils.Timer('step.cam_to_ir_label:') 103 | step.cam_to_ir_label.run(args) 104 | 105 | if args.train_irn_pass is True: 106 | import step.train_irn 107 | 108 | timer = pyutils.Timer('step.train_irn:') 109 | step.train_irn.run(args) 110 | 111 | if args.make_sem_seg_pass is True: 112 | import step.make_sem_seg_labels 113 | 114 | timer = pyutils.Timer('step.make_sem_seg_labels:') 115 | step.make_sem_seg_labels.run(args) 116 | 117 | if args.eval_sem_seg_pass is True: 118 | import step.eval_sem_seg 119 | 120 | timer = pyutils.Timer('step.eval_sem_seg:') 121 | step.eval_sem_seg.run(args) 122 | 123 | -------------------------------------------------------------------------------- /WSSS/step/cam_to_ir_label.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 
| import numpy as np 4 | import imageio 5 | 6 | from torch import multiprocessing 7 | from torch.utils.data import DataLoader 8 | from PIL import Image 9 | import voc12.dataloader 10 | from misc import torchutils, imutils 11 | 12 | palette = [0,0,0, 128,0,0, 0,128,0, 128,128,0, 0,0,128, 128,0,128, 0,128,128, 128,128,128, 13 | 64,0,0, 192,0,0, 64,128,0, 192,128,0, 64,0,128, 192,0,128, 64,128,128, 192,128,128, 14 | 0,64,0, 128,64,0, 0,192,0, 128,192,0, 0,64,128, 128,64,128, 0,192,128, 128,192,128, 15 | 64,64,0, 192,64,0, 64,192,0, 192,192,0] 16 | 17 | 18 | def _work(process_id, infer_dataset, args): 19 | 20 | databin = infer_dataset[process_id] 21 | infer_data_loader = DataLoader(databin, shuffle=False, num_workers=0, pin_memory=False) 22 | 23 | for iter, pack in enumerate(infer_data_loader): 24 | img_name = voc12.dataloader.decode_int_filename(pack['name'][0]) 25 | img = pack['img'][0].numpy() 26 | cam_dict = np.load(os.path.join(args.cam_out_dir, img_name + '.npy'), allow_pickle=True).item() 27 | 28 | cams = cam_dict['high_res'] 29 | keys = np.pad(cam_dict['keys'] + 1, (1, 0), mode='constant') 30 | 31 | # 1. find confident fg & bg 32 | fg_conf_cam = np.pad(cams, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=args.conf_fg_thres) 33 | fg_conf_cam = np.argmax(fg_conf_cam, axis=0) 34 | pred = imutils.crf_inference_label(img, fg_conf_cam, n_labels=keys.shape[0]) 35 | 36 | fg_conf = keys[pred] 37 | 38 | bg_conf_cam = np.pad(cams, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=args.conf_bg_thres) 39 | bg_conf_cam = np.argmax(bg_conf_cam, axis=0) 40 | pred = imutils.crf_inference_label(img, bg_conf_cam, n_labels=keys.shape[0]) 41 | bg_conf = keys[pred] 42 | 43 | # 2. combine confident fg & bg 44 | conf = fg_conf.copy() 45 | conf[fg_conf == 0] = 255 46 | conf[bg_conf + fg_conf == 0] = 0 47 | 48 | imageio.imwrite(os.path.join(args.ir_label_out_dir, img_name + '.png'), 49 | conf.astype(np.uint8)) 50 | 51 | 52 | if process_id == args.num_workers - 1 and iter % (len(databin) // 20) == 0: 53 | print("%d " % ((5 * iter + 1) // (len(databin) // 20)), end='') 54 | 55 | def run(args): 56 | dataset = voc12.dataloader.VOC12ImageDataset(args.train_list, voc12_root=args.voc12_root, img_normal=None, to_torch=False) 57 | dataset = torchutils.split_dataset(dataset, args.num_workers) 58 | 59 | print('[ ', end='') 60 | multiprocessing.spawn(_work, nprocs=args.num_workers, args=(dataset, args), join=True) 61 | print(']') 62 | -------------------------------------------------------------------------------- /WSSS/step/eval_cam.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | from chainercv.datasets import VOCSemanticSegmentationDataset 5 | from chainercv.evaluations import calc_semantic_segmentation_confusion 6 | 7 | def run(args): 8 | dataset = VOCSemanticSegmentationDataset(split=args.chainer_eval_set, data_dir=args.voc12_root) 9 | # labels = [dataset.get_example_by_keys(i, (1,))[0] for i in range(len(dataset))] 10 | 11 | preds = [] 12 | labels = [] 13 | n_images = 0 14 | for i, id in enumerate(dataset.ids): 15 | n_images += 1 16 | cam_dict = np.load(os.path.join(args.cam_out_dir, id + '.npy'), allow_pickle=True).item() 17 | cams = cam_dict['high_res'] 18 | cams = np.pad(cams, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=args.cam_eval_thres) 19 | keys = np.pad(cam_dict['keys'] + 1, (1, 0), mode='constant') 20 | cls_labels = np.argmax(cams, axis=0) 21 | cls_labels = keys[cls_labels] 22 | 
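# --- Editor's note: illustrative sketch, not part of the original repository ---
# The preceding lines translate local CAM channel indices back to VOC class ids: 'keys'
# holds the 0-based classes present in the image, it is shifted by +1 and left-padded
# with 0 so that index 0 stays background, and fancy indexing remaps the per-pixel
# argmax result. A toy example (0-based classes 4 = bottle and 11 = dog):
import numpy as np

keys = np.array([4, 11])                          # classes present in this image
keys = np.pad(keys + 1, (1, 0), mode='constant')  # -> [0, 5, 12]
cls_map = np.array([[0, 1],
                    [2, 2]])                      # 0 = background, 1/2 = local channels
remapped = keys[cls_map]                          # [[0, 5], [12, 12]]
# --- end of editor's note ---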
preds.append(cls_labels.copy()) 23 | labels.append(dataset.get_example_by_keys(i, (1,))[0]) 24 | 25 | confusion = calc_semantic_segmentation_confusion(preds, labels) 26 | 27 | gtj = confusion.sum(axis=1) 28 | resj = confusion.sum(axis=0) 29 | gtjresj = np.diag(confusion) 30 | denominator = gtj + resj - gtjresj 31 | iou = gtjresj / denominator 32 | 33 | 34 | print("threshold:", args.cam_eval_thres, 'miou:', np.nanmean(iou), "i_imgs", n_images) 35 | print('among_predfg_bg', float((resj[1:].sum()-confusion[1:,1:].sum())/(resj[1:].sum()))) 36 | 37 | return np.nanmean(iou) -------------------------------------------------------------------------------- /WSSS/step/eval_sem_seg.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | from chainercv.datasets import VOCSemanticSegmentationDataset 5 | from chainercv.evaluations import calc_semantic_segmentation_confusion 6 | import imageio 7 | 8 | def run(args): 9 | dataset = VOCSemanticSegmentationDataset(split=args.chainer_eval_set, data_dir=args.voc12_root) 10 | # labels = [dataset.get_example_by_keys(i, (1,))[0] for i in range(len(dataset))] 11 | 12 | preds = [] 13 | labels = [] 14 | n_img = 0 15 | for i, id in enumerate(dataset.ids): 16 | cls_labels = imageio.imread(os.path.join(args.sem_seg_out_dir, id + '.png')).astype(np.uint8) 17 | cls_labels[cls_labels == 255] = 0 18 | preds.append(cls_labels.copy()) 19 | labels.append(dataset.get_example_by_keys(i, (1,))[0]) 20 | n_img += 1 21 | 22 | confusion = calc_semantic_segmentation_confusion(preds, labels)[:21, :21] 23 | 24 | gtj = confusion.sum(axis=1) 25 | resj = confusion.sum(axis=0) 26 | gtjresj = np.diag(confusion) 27 | denominator = gtj + resj - gtjresj 28 | fp = 1. - gtj / denominator 29 | fn = 1. 
- resj / denominator 30 | iou = gtjresj / denominator 31 | print("total images", n_img) 32 | print(fp[0], fn[0]) 33 | print(np.mean(fp[1:]), np.mean(fn[1:])) 34 | 35 | print({'iou': iou, 'miou': np.nanmean(iou)}) 36 | -------------------------------------------------------------------------------- /WSSS/step/make_cam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import multiprocessing, cuda 3 | from torch.utils.data import DataLoader 4 | import torch.nn.functional as F 5 | from torch.backends import cudnn 6 | 7 | import numpy as np 8 | import importlib 9 | import os 10 | 11 | import voc12.dataloader 12 | from misc import torchutils, imutils 13 | 14 | cudnn.enabled = True 15 | 16 | def _work(process_id, model,model_ori, dataset, args): 17 | 18 | databin = dataset[process_id] 19 | n_gpus = torch.cuda.device_count() 20 | data_loader = DataLoader(databin, shuffle=False, num_workers=args.num_workers // n_gpus, pin_memory=False) 21 | 22 | with torch.no_grad(), cuda.device(process_id): 23 | 24 | model.cuda() 25 | model_ori.cuda() 26 | 27 | for iter, pack in enumerate(data_loader): 28 | 29 | img_name = pack['name'][0] 30 | label = pack['label'][0] 31 | size = pack['size'] 32 | 33 | strided_size = imutils.get_strided_size(size, 4) 34 | strided_up_size = imutils.get_strided_up_size(size, 16) 35 | 36 | BAS_outputs = [model(img[0].cuda(non_blocking=True)) 37 | for img in pack['img']] 38 | 39 | strided_BAS = torch.sum(torch.stack( 40 | [F.interpolate(torch.unsqueeze(o, 0), strided_size, mode='bilinear', align_corners=False)[0] for o 41 | in BAS_outputs]), 0) 42 | 43 | highres_BAS = [F.interpolate(torch.unsqueeze(o, 1), strided_up_size, 44 | mode='bilinear', align_corners=False) for o in BAS_outputs] 45 | highres_BAS = torch.sum(torch.stack(highres_BAS, 0), 0)[:, 0, :size[0], :size[1]] 46 | 47 | valid_cat = torch.nonzero(label)[:, 0] 48 | 49 | strided_BAS = strided_BAS[valid_cat] 50 | strided_BAS /= F.adaptive_max_pool2d(strided_BAS, (1, 1)) + 1e-5 51 | 52 | highres_BAS = highres_BAS[valid_cat] 53 | highres_BAS /= F.adaptive_max_pool2d(highres_BAS, (1, 1)) + 1e-5 54 | 55 | if True: ## combine with CAM 56 | cam = [model_ori(img[0].cuda(non_blocking=True)) 57 | for img in pack['img']] 58 | 59 | strided_cam = torch.sum(torch.stack( 60 | [F.interpolate(torch.unsqueeze(o, 0), strided_size, mode='bilinear', align_corners=False)[0] for o 61 | in cam]), 0) 62 | 63 | highre_cam = [F.interpolate(torch.unsqueeze(o, 1), strided_up_size, 64 | mode='bilinear', align_corners=False) for o in cam] 65 | highre_cam = torch.sum(torch.stack(highre_cam, 0), 0)[:, 0, :size[0], :size[1]] 66 | 67 | strided_cam = strided_cam[valid_cat] 68 | strided_cam /= F.adaptive_max_pool2d(strided_cam, (1, 1)) + 1e-5 69 | 70 | strided_BAS = strided_cam + strided_BAS 71 | strided_BAS /= F.adaptive_max_pool2d(strided_BAS, (1, 1)) + 1e-5 72 | 73 | highre_cam = highre_cam[valid_cat] 74 | highre_cam /= F.adaptive_max_pool2d(highre_cam, (1, 1)) + 1e-5 75 | 76 | highres_BAS = highre_cam + highres_BAS 77 | highres_BAS /= F.adaptive_max_pool2d(highres_BAS, (1, 1)) + 1e-5 78 | 79 | # save cams 80 | np.save(os.path.join(args.cam_out_dir, img_name + '.npy'), 81 | {"keys": valid_cat, "cam": strided_BAS.cpu(), "high_res": highres_BAS.cpu().numpy()}) 82 | 83 | if process_id == n_gpus - 1 and iter % (len(databin) // 20) == 0: 84 | print("%d " % ((5*iter+1)//(len(databin) // 20)), end='') 85 | 86 | def run(args): 87 | model_ori = getattr(importlib.import_module("net.resnet50_cam"), 
'CAM')() 88 | model_ori.load_state_dict(torch.load(args.cam_weights_name + '.pth'), strict=True) 89 | model_ori.eval() 90 | 91 | model = getattr(importlib.import_module(args.cam_network), 'model')(args) 92 | checkpoint = torch.load('sess/epoch_10.pth.tar') 93 | pretrained_dict = {k[7:]: v for k, v in checkpoint.items()} 94 | model.load_state_dict(pretrained_dict) 95 | model.eval() 96 | 97 | n_gpus = torch.cuda.device_count() 98 | 99 | dataset = voc12.dataloader.VOC12ClassificationDatasetMSF(args.train_list, 100 | voc12_root=args.voc12_root, scales=args.cam_scales) 101 | dataset = torchutils.split_dataset(dataset, n_gpus) 102 | 103 | print('[ ', end='') 104 | multiprocessing.spawn(_work, nprocs=n_gpus, args=(model,model_ori, dataset, args), join=True) 105 | print(']') 106 | 107 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /WSSS/step/make_sem_seg_labels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import multiprocessing, cuda 3 | from torch.utils.data import DataLoader 4 | import torch.nn.functional as F 5 | from torch.backends import cudnn 6 | import numpy as np 7 | import importlib 8 | import os 9 | import imageio 10 | 11 | import voc12.dataloader 12 | from misc import torchutils, indexing, imutils 13 | 14 | cudnn.enabled = True 15 | 16 | def _work(process_id, model, dataset, args): 17 | 18 | n_gpus = torch.cuda.device_count() 19 | databin = dataset[process_id] 20 | data_loader = DataLoader(databin, 21 | shuffle=False, num_workers=args.num_workers // n_gpus, pin_memory=False) 22 | 23 | with torch.no_grad(), cuda.device(process_id): 24 | 25 | model.cuda() 26 | 27 | for iter, pack in enumerate(data_loader): 28 | img_name = voc12.dataloader.decode_int_filename(pack['name'][0]) 29 | orig_img_size = np.asarray(pack['size']) 30 | ori_image = pack['img'][0].numpy() 31 | 32 | edge, dp = model(pack['img'][0].cuda(non_blocking=True)) 33 | 34 | cam_dict = np.load(args.cam_out_dir + '/' + img_name + '.npy', allow_pickle=True).item() 35 | 36 | cams = cam_dict['cam'] 37 | keys = np.pad(cam_dict['keys'] + 1, (1, 0), mode='constant') 38 | cams = np.power(cams, 1.5) 39 | cam_downsized_values = cams.cuda() 40 | 41 | rw = indexing.propagate_to_edge(cam_downsized_values, edge, beta=args.beta, exp_times=args.exp_times, radius=5) 42 | 43 | rw_up = F.interpolate(rw, scale_factor=4, mode='bilinear', align_corners=False)[..., 0, :orig_img_size[0], :orig_img_size[1]] 44 | 45 | rw_up = rw_up / torch.max(rw_up) 46 | 47 | rw_up_bg = F.pad(rw_up, (0, 0, 0, 0, 1, 0), value=args.sem_seg_bg_thres) 48 | rw_pred = torch.argmax(rw_up_bg, dim=0).cpu().numpy() 49 | 50 | 51 | rw_pred = keys[rw_pred] 52 | 53 | imageio.imsave(os.path.join(args.sem_seg_out_dir, img_name + '.png'), rw_pred.astype(np.uint8)) 54 | 55 | if process_id == n_gpus - 1 and iter % (len(databin) // 20) == 0: 56 | print("%d " % ((5*iter+1)//(len(databin) // 20)), end='') 57 | 58 | 59 | 60 | def run(args): 61 | model = getattr(importlib.import_module(args.irn_network), 'EdgeDisplacement')() 62 | model.load_state_dict(torch.load(args.irn_weights_name), strict=False) 63 | model.eval() 64 | 65 | n_gpus = torch.cuda.device_count() 66 | 67 | dataset = voc12.dataloader.VOC12ClassificationDatasetMSF(args.infer_list, 68 | voc12_root=args.voc12_root, 69 | scales=(1.0,)) 70 | dataset = torchutils.split_dataset(dataset, n_gpus) 71 | 72 | print("[", end='') 73 | multiprocessing.spawn(_work, nprocs=n_gpus, args=(model, dataset, args), 
join=True) 74 | print("]") 75 | 76 | torch.cuda.empty_cache() 77 | -------------------------------------------------------------------------------- /WSSS/step/train_bas.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | import torch 4 | from torch.backends import cudnn 5 | cudnn.enabled = True 6 | from torch.utils.data import DataLoader 7 | import torch.nn.functional as F 8 | 9 | import importlib 10 | 11 | import voc12.dataloader 12 | from misc import pyutils, torchutils 13 | from torch import autograd 14 | import os 15 | 16 | 17 | 18 | def run(args): 19 | 20 | model = getattr(importlib.import_module(args.cam_network), 'Net')() 21 | 22 | 23 | train_dataset = voc12.dataloader.VOC12ClassificationDataset(args.train_list, voc12_root=args.voc12_root, 24 | resize_long=(320, 640), hor_flip=True, 25 | crop_size=512, crop_method="random") 26 | train_data_loader = DataLoader(train_dataset, batch_size=args.cam_batch_size, 27 | shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True) 28 | max_step = (len(train_dataset) // args.cam_batch_size) * args.cam_num_epoches 29 | 30 | param_groups = model.trainable_parameters() 31 | optimizer = torchutils.PolyOptimizer([ 32 | {'params': param_groups[0], 'lr': args.cam_learning_rate, 'weight_decay': args.cam_weight_decay}, 33 | {'params': param_groups[1], 'lr': 10*args.cam_learning_rate, 'weight_decay': args.cam_weight_decay}, 34 | ], lr=args.cam_learning_rate, weight_decay=args.cam_weight_decay, max_step=max_step) 35 | 36 | model = torch.nn.DataParallel(model).cuda() 37 | model.train() 38 | 39 | avg_meter = pyutils.AverageMeter() 40 | 41 | timer = pyutils.Timer() 42 | 43 | for ep in range(args.cam_num_epoches): 44 | 45 | print('Epoch %d/%d' % (ep+1, args.cam_num_epoches)) 46 | 47 | for step, pack in enumerate(train_data_loader): 48 | 49 | img = pack['img'] 50 | img = img.cuda() 51 | label = pack['label'].cuda(non_blocking=True) 52 | model.zero_grad() 53 | x = model(img) 54 | 55 | optimizer.zero_grad() 56 | 57 | loss = F.multilabel_soft_margin_loss(x, label) 58 | 59 | loss.backward() 60 | avg_meter.add({'loss1': loss.item()}) 61 | 62 | 63 | optimizer.step() 64 | if (optimizer.global_step-1)%100 == 0: 65 | timer.update_progress(optimizer.global_step / max_step) 66 | 67 | print('step:%5d/%5d' % (optimizer.global_step - 1, max_step), 68 | 'loss:%.4f' % (avg_meter.pop('loss1')), 69 | 'imps:%.1f' % ((step + 1) * args.cam_batch_size / timer.get_stage_elapsed()), 70 | 'lr: %.4f' % (optimizer.param_groups[0]['lr']), 71 | 'etc:%s' % (timer.str_estimated_complete()), flush=True) 72 | 73 | 74 | torch.save(model.module.state_dict(), args.cam_weights_name + '.pth') 75 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /WSSS/step/train_irn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.backends import cudnn 4 | cudnn.enabled = True 5 | from torch.utils.data import DataLoader 6 | import voc12.dataloader 7 | from misc import pyutils, torchutils, indexing 8 | import importlib 9 | from PIL import ImageFile 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | def run(args): 12 | 13 | path_index = indexing.PathIndex(radius=10, default_size=(args.irn_crop_size // 4, args.irn_crop_size // 4)) 14 | 15 | model = getattr(importlib.import_module(args.irn_network), 'AffinityDisplacementLoss')( 16 | path_index) 17 | 18 | train_dataset = 
voc12.dataloader.VOC12AffinityDataset(args.train_list, 19 | label_dir=args.ir_label_out_dir, 20 | voc12_root=args.voc12_root, 21 | indices_from=path_index.src_indices, 22 | indices_to=path_index.dst_indices, 23 | hor_flip=True, 24 | crop_size=args.irn_crop_size, 25 | crop_method="random", 26 | rescale=(0.5, 1.5) 27 | ) 28 | train_data_loader = DataLoader(train_dataset, batch_size=args.irn_batch_size, 29 | shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True) 30 | 31 | max_step = (len(train_dataset) // args.irn_batch_size) * args.irn_num_epoches 32 | 33 | param_groups = model.trainable_parameters() 34 | optimizer = torchutils.PolyOptimizer([ 35 | {'params': param_groups[0], 'lr': 1*args.irn_learning_rate, 'weight_decay': args.irn_weight_decay}, 36 | {'params': param_groups[1], 'lr': 10*args.irn_learning_rate, 'weight_decay': args.irn_weight_decay} 37 | ], lr=args.irn_learning_rate, weight_decay=args.irn_weight_decay, max_step=max_step) 38 | 39 | model = torch.nn.DataParallel(model).cuda() 40 | model.train() 41 | 42 | avg_meter = pyutils.AverageMeter() 43 | 44 | timer = pyutils.Timer() 45 | 46 | for ep in range(args.irn_num_epoches): 47 | 48 | print('Epoch %d/%d' % (ep+1, args.irn_num_epoches)) 49 | 50 | for iter, pack in enumerate(train_data_loader): 51 | 52 | img = pack['img'].cuda(non_blocking=True) 53 | bg_pos_label = pack['aff_bg_pos_label'].cuda(non_blocking=True) 54 | fg_pos_label = pack['aff_fg_pos_label'].cuda(non_blocking=True) 55 | neg_label = pack['aff_neg_label'].cuda(non_blocking=True) 56 | 57 | pos_aff_loss, neg_aff_loss, dp_fg_loss, dp_bg_loss = model(img, True) 58 | 59 | bg_pos_aff_loss = torch.sum(bg_pos_label * pos_aff_loss) / (torch.sum(bg_pos_label) + 1e-5) 60 | fg_pos_aff_loss = torch.sum(fg_pos_label * pos_aff_loss) / (torch.sum(fg_pos_label) + 1e-5) 61 | pos_aff_loss = bg_pos_aff_loss / 2 + fg_pos_aff_loss / 2 62 | neg_aff_loss = torch.sum(neg_label * neg_aff_loss) / (torch.sum(neg_label) + 1e-5) 63 | 64 | dp_fg_loss = torch.sum(dp_fg_loss * torch.unsqueeze(fg_pos_label, 1)) / (2 * torch.sum(fg_pos_label) + 1e-5) 65 | dp_bg_loss = torch.sum(dp_bg_loss * torch.unsqueeze(bg_pos_label, 1)) / (2 * torch.sum(bg_pos_label) + 1e-5) 66 | 67 | avg_meter.add({'loss1': pos_aff_loss.item(), 'loss2': neg_aff_loss.item(), 68 | 'loss3': dp_fg_loss.item(), 'loss4': dp_bg_loss.item()}) 69 | 70 | total_loss = (pos_aff_loss + neg_aff_loss) / 2 + (dp_fg_loss + dp_bg_loss) / 2 71 | 72 | optimizer.zero_grad() 73 | total_loss.backward() 74 | optimizer.step() 75 | 76 | if (optimizer.global_step - 1) % 50 == 0: 77 | timer.update_progress(optimizer.global_step / max_step) 78 | 79 | print('step:%5d/%5d' % (optimizer.global_step - 1, max_step), 80 | 'loss:%.4f %.4f %.4f %.4f' % ( 81 | avg_meter.pop('loss1'), avg_meter.pop('loss2'), avg_meter.pop('loss3'), avg_meter.pop('loss4')), 82 | 'imps:%.1f' % ((iter + 1) * args.irn_batch_size / timer.get_stage_elapsed()), 83 | 'lr: %.4f' % (optimizer.param_groups[0]['lr']), 84 | 'etc:%s' % (timer.str_estimated_complete()), flush=True) 85 | else: 86 | timer.reset_stage() 87 | 88 | infer_dataset = voc12.dataloader.VOC12ImageDataset(args.infer_list, 89 | voc12_root=args.voc12_root, 90 | crop_size=args.irn_crop_size, 91 | crop_method="top_left") 92 | infer_data_loader = DataLoader(infer_dataset, batch_size=args.irn_batch_size, 93 | shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=True) 94 | 95 | model.eval() 96 | print('Analyzing displacements mean ... 
', end='') 97 | 98 | dp_mean_list = [] 99 | 100 | with torch.no_grad(): 101 | for iter, pack in enumerate(infer_data_loader): 102 | img = pack['img'].cuda(non_blocking=True) 103 | 104 | aff, dp = model(img, False) 105 | 106 | dp_mean_list.append(torch.mean(dp, dim=(0, 2, 3)).cpu()) 107 | 108 | model.module.mean_shift.running_mean = torch.mean(torch.stack(dp_mean_list), dim=0) 109 | print('done.') 110 | 111 | torch.save(model.module.state_dict(), args.irn_weights_name) 112 | torch.cuda.empty_cache() 113 | -------------------------------------------------------------------------------- /WSSS/tool/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def IoU(boxA, boxB): 5 | xA = max(boxA[0], boxB[0]) 6 | yA = max(boxA[1], boxB[1]) 7 | xB = min(boxA[2], boxB[2]) 8 | yB = min(boxA[3], boxB[3]) 9 | 10 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 11 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 12 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 13 | 14 | iou = interArea / float(boxAArea + boxBArea - interArea) 15 | return iou 16 | 17 | def compute_cls_acc(preds, label): 18 | pred = torch.max(preds, 1)[1] 19 | #label = torch.max(labels, 1)[1] 20 | num_correct = (pred == label).sum() 21 | return float(num_correct)/ float(preds.size(0)) 22 | 23 | def compute_loc_acc(loc_preds, loc_gt, theta=0.5): 24 | correct = 0 25 | for i in range(loc_preds.size(0)): 26 | iou = IoU(loc_preds[i], loc_gt[i]) 27 | correct += (iou >= theta).sum() 28 | return float(correct) / float(loc_preds.size(0)) 29 | 30 | class AverageMeter(object): 31 | def __init__(self): 32 | self.reset() 33 | 34 | def reset(self): 35 | self.val = 0.0 36 | self.avg = 0.0 37 | self.sum = 0.0 38 | self.cnt = 0.0 39 | 40 | def updata(self, val, n=1.0): 41 | self.val = val 42 | self.sum += val * n 43 | self.cnt += n 44 | if self.cnt == 0: 45 | self.avg = 1 46 | else: 47 | self.avg = self.sum / self.cnt 48 | 49 | def accuracy(output, target, topk=(1,)): 50 | maxk = max(topk) 51 | batch_size = target.size(0) 52 | 53 | _, pred = output.topk(maxk, 1, True, True) 54 | pred = pred.t() 55 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 56 | 57 | res = [] 58 | for k in topk: 59 | correct_k = correct[:k].reshape(-1).float().sum(0) 60 | res.append(correct_k.mul_(100.0 / batch_size)) 61 | return res 62 | -------------------------------------------------------------------------------- /WSSS/tool/func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import copy 5 | import cv2 6 | import sys 7 | def count_max(x): 8 | count_dict = {} 9 | for xlist in x: 10 | for item in xlist: 11 | if item==0: 12 | continue 13 | if item not in count_dict.keys(): 14 | count_dict[item] = 0 15 | count_dict[item] += 1 16 | if count_dict == {}: 17 | return -1 18 | count_dict = sorted(count_dict.items(), key=lambda d:d[1], reverse=True) 19 | return count_dict[0][0] 20 | 21 | def compute_intersec(i, j, h, w, bbox): 22 | ''' 23 | intersection box between croped box and GT BBox 24 | ''' 25 | intersec = copy.deepcopy(bbox) 26 | 27 | intersec[0] = max(j, bbox[0]) 28 | intersec[1] = max(i, bbox[1]) 29 | intersec[2] = min(j + w, bbox[2]) 30 | intersec[3] = min(i + h, bbox[3]) 31 | return intersec 32 | 33 | 34 | def normalize_intersec(i, j, h, w, intersec): 35 | ''' 36 | return: normalize into [0, 1] 37 | ''' 38 | 39 | intersec[0] = 
(intersec[0] - j) / w 40 | intersec[2] = (intersec[2] - j) / w 41 | intersec[1] = (intersec[1] - i) / h 42 | intersec[3] = (intersec[3] - i) / h 43 | return intersec 44 | 45 | def normalize_map(atten_map, crop_size): 46 | min_val = np.min(atten_map) 47 | max_val = np.max(atten_map) 48 | atten_norm = (atten_map - min_val)/(max_val - min_val + 1e-5) 49 | atten_norm = cv2.resize(atten_norm, dsize=(crop_size, crop_size)) 50 | return atten_norm 51 | 52 | class Logger(object): 53 | def __init__(self,filename="Default.log"): 54 | self.terminal = sys.stdout 55 | self.log = open(filename,'w+') 56 | 57 | def write(self,message): 58 | self.terminal.write(message) 59 | self.log.write(message) 60 | 61 | def flush(self): 62 | pass 63 | -------------------------------------------------------------------------------- /WSSS/tool/imutils.py: -------------------------------------------------------------------------------- 1 | 2 | import PIL.Image 3 | import random 4 | import numpy as np 5 | 6 | class RandomResizeLong(): 7 | 8 | def __init__(self, min_long, max_long): 9 | self.min_long = min_long 10 | self.max_long = max_long 11 | 12 | def __call__(self, img, sal=None): 13 | 14 | target_long = random.randint(self.min_long, self.max_long) 15 | w, h = img.size 16 | 17 | if w < h: 18 | target_shape = (int(round(w * target_long / h)), target_long) 19 | else: 20 | target_shape = (target_long, int(round(h * target_long / w))) 21 | 22 | img = img.resize(target_shape, resample=PIL.Image.CUBIC) 23 | if sal: 24 | sal = sal.resize(target_shape, resample=PIL.Image.CUBIC) 25 | return img, sal 26 | return img 27 | 28 | 29 | class RandomCrop(): 30 | 31 | def __init__(self, cropsize): 32 | self.cropsize = cropsize 33 | 34 | def __call__(self, imgarr, sal=None): 35 | 36 | h, w, c = imgarr.shape 37 | 38 | ch = min(self.cropsize, h) 39 | cw = min(self.cropsize, w) 40 | 41 | w_space = w - self.cropsize 42 | h_space = h - self.cropsize 43 | 44 | if w_space > 0: 45 | cont_left = 0 46 | img_left = random.randrange(w_space+1) 47 | else: 48 | cont_left = random.randrange(-w_space+1) 49 | img_left = 0 50 | 51 | if h_space > 0: 52 | cont_top = 0 53 | img_top = random.randrange(h_space+1) 54 | else: 55 | cont_top = random.randrange(-h_space+1) 56 | img_top = 0 57 | 58 | container = np.zeros((self.cropsize, self.cropsize, imgarr.shape[-1]), np.float32) 59 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 60 | imgarr[img_top:img_top+ch, img_left:img_left+cw] 61 | if sal is not None: 62 | container_sal = np.zeros((self.cropsize, self.cropsize,1), np.float32) 63 | container_sal[cont_top:cont_top+ch, cont_left:cont_left+cw,0] = \ 64 | sal[img_top:img_top+ch, img_left:img_left+cw] 65 | return container, container_sal 66 | 67 | return container 68 | 69 | def get_random_crop_box(imgsize, cropsize): 70 | h, w = imgsize 71 | 72 | ch = min(cropsize, h) 73 | cw = min(cropsize, w) 74 | 75 | w_space = w - cropsize 76 | h_space = h - cropsize 77 | 78 | if w_space > 0: 79 | cont_left = 0 80 | img_left = random.randrange(w_space + 1) 81 | else: 82 | cont_left = random.randrange(-w_space + 1) 83 | img_left = 0 84 | 85 | if h_space > 0: 86 | cont_top = 0 87 | img_top = random.randrange(h_space + 1) 88 | else: 89 | cont_top = random.randrange(-h_space + 1) 90 | img_top = 0 91 | 92 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw 93 | 94 | def crop_with_box(img, box): 95 | if len(img.shape) == 3: 96 | img_cont = np.zeros((max(box[1]-box[0], box[4]-box[5]), max(box[3]-box[2], box[7]-box[6]), 
img.shape[-1]), dtype=img.dtype) 97 | else: 98 | img_cont = np.zeros((max(box[1] - box[0], box[4] - box[5]), max(box[3] - box[2], box[7] - box[6])), dtype=img.dtype) 99 | img_cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 100 | return img_cont 101 | 102 | 103 | def random_crop(images, cropsize, fills): 104 | if isinstance(images[0], PIL.Image.Image): 105 | imgsize = images[0].size[::-1] 106 | else: 107 | imgsize = images[0].shape[:2] 108 | box = get_random_crop_box(imgsize, cropsize) 109 | 110 | new_images = [] 111 | for img, f in zip(images, fills): 112 | 113 | if isinstance(img, PIL.Image.Image): 114 | img = img.crop((box[6], box[4], box[7], box[5])) 115 | cont = PIL.Image.new(img.mode, (cropsize, cropsize)) 116 | cont.paste(img, (box[2], box[0])) 117 | new_images.append(cont) 118 | 119 | else: 120 | if len(img.shape) == 3: 121 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f 122 | else: 123 | cont = np.ones((cropsize, cropsize), img.dtype)*f 124 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 125 | new_images.append(cont) 126 | 127 | return new_images 128 | 129 | 130 | class AvgPool2d(): 131 | 132 | def __init__(self, ksize): 133 | self.ksize = ksize 134 | 135 | def __call__(self, img): 136 | import skimage.measure 137 | 138 | return skimage.measure.block_reduce(img, (self.ksize, self.ksize, 1), np.mean) 139 | 140 | 141 | class RandomHorizontalFlip(): 142 | def __init__(self): 143 | return 144 | 145 | def __call__(self, img, sal=None): 146 | if bool(random.getrandbits(1)): 147 | #img = img.transpose(PIL.Image.FLIP_LEFT_RIGHT) 148 | img = np.fliplr(img).copy() 149 | if sal: 150 | #sal = sal.transpose(PIL.Image.FLIP_LEFT_RIGHT) 151 | sal = np.fliplr(sal).copy() 152 | return img, sal 153 | return img 154 | else: 155 | if sal: 156 | return img, sal 157 | return img 158 | 159 | 160 | class CenterCrop(): 161 | 162 | def __init__(self, cropsize, default_value=0): 163 | self.cropsize = cropsize 164 | self.default_value = default_value 165 | 166 | def __call__(self, npimg): 167 | 168 | h, w = npimg.shape[:2] 169 | 170 | ch = min(self.cropsize, h) 171 | cw = min(self.cropsize, w) 172 | 173 | sh = h - self.cropsize 174 | sw = w - self.cropsize 175 | 176 | if sw > 0: 177 | cont_left = 0 178 | img_left = int(round(sw / 2)) 179 | else: 180 | cont_left = int(round(-sw / 2)) 181 | img_left = 0 182 | 183 | if sh > 0: 184 | cont_top = 0 185 | img_top = int(round(sh / 2)) 186 | else: 187 | cont_top = int(round(-sh / 2)) 188 | img_top = 0 189 | 190 | if len(npimg.shape) == 2: 191 | container = np.ones((self.cropsize, self.cropsize), npimg.dtype)*self.default_value 192 | else: 193 | container = np.ones((self.cropsize, self.cropsize, npimg.shape[2]), npimg.dtype)*self.default_value 194 | 195 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 196 | npimg[img_top:img_top+ch, img_left:img_left+cw] 197 | 198 | return container 199 | 200 | 201 | def HWC_to_CHW(tensor, sal=False): 202 | if sal: 203 | tensor = np.expand_dims(tensor, axis=0) 204 | else: 205 | tensor = np.transpose(tensor, (2, 0, 1)) 206 | return tensor 207 | 208 | 209 | class RescaleNearest(): 210 | def __init__(self, scale): 211 | self.scale = scale 212 | 213 | def __call__(self, npimg): 214 | import cv2 215 | return cv2.resize(npimg, None, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_NEAREST) 216 | 217 | 218 | 219 | 220 | def crf_inference(img, probs, t=10, scale_factor=1, labels=21): 221 | import pydensecrf.densecrf as dcrf 222 | from pydensecrf.utils import 
unary_from_softmax 223 | 224 | h, w = img.shape[:2] 225 | n_labels = labels 226 | 227 | d = dcrf.DenseCRF2D(w, h, n_labels) 228 | 229 | unary = unary_from_softmax(probs) 230 | unary = np.ascontiguousarray(unary) 231 | 232 | d.setUnaryEnergy(unary) 233 | d.addPairwiseGaussian(sxy=3/scale_factor, compat=3) 234 | d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10) 235 | Q = d.inference(t) 236 | 237 | return np.array(Q).reshape((n_labels, h, w)) 238 | -------------------------------------------------------------------------------- /WSSS/tool/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tool import pyutils, imutils, torchutils, visualization 3 | 4 | def get_optimizer(model, args, max_step): 5 | param_groups = ([], [], [], []) 6 | for name, value in model.named_parameters(): 7 | if 'classifier' in name : 8 | if 'weight' in name: 9 | param_groups[2].append(value) 10 | elif 'bias' in name: 11 | param_groups[3].append(value) 12 | else: 13 | if 'weight' in name: 14 | param_groups[0].append(value) 15 | elif 'bias' in name: 16 | param_groups[1].append(value) 17 | 18 | optimizer = torchutils.PolyOptimizer([ 19 | {'params': param_groups[0], 'lr': args.lr, 'weight_decay': args.wt_dec}, 20 | {'params': param_groups[2], 'lr': 10*args.lr, 'weight_decay': args.wt_dec}, 21 | ], lr=args.lr, weight_decay=args.wt_dec, max_step=max_step) 22 | return optimizer -------------------------------------------------------------------------------- /WSSS/tool/pyutils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import time 4 | import sys 5 | 6 | class Logger(object): 7 | def __init__(self, outfile): 8 | self.terminal = sys.stdout 9 | self.log = open(outfile, "w") 10 | sys.stdout = self 11 | 12 | def write(self, message): 13 | self.terminal.write(message) 14 | self.log.write(message) 15 | 16 | def flush(self): 17 | self.terminal.flush() 18 | 19 | 20 | class AverageMeter: 21 | def __init__(self, *keys): 22 | self.__data = dict() 23 | for k in keys: 24 | self.__data[k] = [0.0, 0] 25 | 26 | def add(self, dict): 27 | for k, v in dict.items(): 28 | self.__data[k][0] += v 29 | self.__data[k][1] += 1 30 | 31 | def get(self, *keys): 32 | if len(keys) == 1: 33 | return self.__data[keys[0]][0] / self.__data[keys[0]][1] 34 | else: 35 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] 36 | return tuple(v_list) 37 | 38 | def pop(self, key=None): 39 | if key is None: 40 | for k in self.__data.keys(): 41 | self.__data[k] = [0.0, 0] 42 | else: 43 | v = self.get(key) 44 | self.__data[key] = [0.0, 0] 45 | return v 46 | 47 | 48 | class Timer: 49 | def __init__(self, starting_msg = None): 50 | self.start = time.time() 51 | self.stage_start = self.start 52 | 53 | if starting_msg is not None: 54 | print(starting_msg, time.ctime(time.time())) 55 | 56 | 57 | def update_progress(self, progress): 58 | self.elapsed = time.time() - self.start 59 | self.est_total = self.elapsed / progress 60 | self.est_remaining = self.est_total - self.elapsed 61 | self.est_finish = int(self.start + self.est_total) 62 | 63 | 64 | def str_est_finish(self): 65 | return str(time.ctime(self.est_finish)) 66 | 67 | def get_stage_elapsed(self): 68 | return time.time() - self.stage_start 69 | 70 | def reset_stage(self): 71 | self.stage_start = time.time() 72 | 73 | 74 | from multiprocessing.pool import ThreadPool 75 | 76 | class BatchThreader: 77 | 78 | def __init__(self, func, 
args_list, batch_size, prefetch_size=4, processes=12): 79 | self.batch_size = batch_size 80 | self.prefetch_size = prefetch_size 81 | 82 | self.pool = ThreadPool(processes=processes) 83 | self.async_result = [] 84 | 85 | self.func = func 86 | self.left_args_list = args_list 87 | self.n_tasks = len(args_list) 88 | 89 | # initial work 90 | self.__start_works(self.__get_n_pending_works()) 91 | 92 | 93 | def __start_works(self, times): 94 | for _ in range(times): 95 | args = self.left_args_list.pop(0) 96 | self.async_result.append( 97 | self.pool.apply_async(self.func, args)) 98 | 99 | 100 | def __get_n_pending_works(self): 101 | return min((self.prefetch_size + 1) * self.batch_size - len(self.async_result) 102 | , len(self.left_args_list)) 103 | 104 | 105 | 106 | def pop_results(self): 107 | 108 | n_inwork = len(self.async_result) 109 | 110 | n_fetch = min(n_inwork, self.batch_size) 111 | rtn = [self.async_result.pop(0).get() 112 | for _ in range(n_fetch)] 113 | 114 | to_fill = self.__get_n_pending_works() 115 | if to_fill == 0: 116 | self.pool.close() 117 | else: 118 | self.__start_works(to_fill) 119 | 120 | return rtn 121 | 122 | 123 | 124 | 125 | def get_indices_of_pairs(radius, size): 126 | 127 | search_dist = [] 128 | 129 | for x in range(1, radius): 130 | search_dist.append((0, x)) 131 | 132 | for y in range(1, radius): 133 | for x in range(-radius + 1, radius): 134 | if x * x + y * y < radius * radius: 135 | search_dist.append((y, x)) 136 | 137 | radius_floor = radius - 1 138 | 139 | full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64), 140 | (size[0], size[1])) 141 | 142 | cropped_height = size[0] - radius_floor 143 | cropped_width = size[1] - 2 * radius_floor 144 | 145 | indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor], 146 | [-1]) 147 | 148 | indices_to_list = [] 149 | 150 | for dy, dx in search_dist: 151 | indices_to = full_indices[dy:dy + cropped_height, 152 | radius_floor + dx:radius_floor + dx + cropped_width] 153 | indices_to = np.reshape(indices_to, [-1]) 154 | 155 | indices_to_list.append(indices_to) 156 | 157 | concat_indices_to = np.concatenate(indices_to_list, axis=0) 158 | 159 | return indices_from, concat_indices_to 160 | 161 | def get_indices_of_pairs_circle(radius, size): 162 | 163 | search_dist = [] 164 | 165 | for y in range(-radius + 1, radius): 166 | for x in range(-radius + 1, radius): 167 | if x * x + y * y < radius * radius and x*x+y*y!=0: 168 | search_dist.append((y, x)) 169 | 170 | radius_floor = radius - 1 171 | 172 | full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64), 173 | (size[0], size[1])) 174 | 175 | cropped_height = size[0] - 2 * radius_floor 176 | cropped_width = size[1] - 2 * radius_floor 177 | 178 | indices_from = np.reshape(full_indices[radius_floor:-radius_floor, radius_floor:-radius_floor], 179 | [-1]) 180 | 181 | indices_to_list = [] 182 | 183 | for dy, dx in search_dist: 184 | indices_to = full_indices[radius_floor + dy : radius_floor + dy + cropped_height, 185 | radius_floor + dx : radius_floor + dx + cropped_width] 186 | indices_to = np.reshape(indices_to, [-1]) 187 | 188 | indices_to_list.append(indices_to) 189 | 190 | concat_indices_to = np.concatenate(indices_to_list, axis=0) 191 | 192 | return indices_from, concat_indices_to 193 | -------------------------------------------------------------------------------- /WSSS/tool/torchutils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data 
import Dataset 4 | from PIL import Image 5 | import os.path 6 | import random 7 | import numpy as np 8 | from tool import imutils 9 | 10 | class PolyOptimizer(torch.optim.SGD): 11 | 12 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9): 13 | super().__init__(params, lr, weight_decay) 14 | 15 | self.global_step = 0 16 | self.max_step = max_step 17 | self.momentum = momentum 18 | 19 | self.__initial_lr = [group['lr'] for group in self.param_groups] 20 | 21 | 22 | def step(self, closure=None): 23 | 24 | if self.global_step < self.max_step: 25 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 26 | 27 | for i in range(len(self.param_groups)): 28 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 29 | 30 | super().step(closure) 31 | 32 | self.global_step += 1 33 | 34 | 35 | class BatchNorm2dFixed(torch.nn.Module): 36 | 37 | def __init__(self, num_features, eps=1e-5): 38 | super(BatchNorm2dFixed, self).__init__() 39 | self.num_features = num_features 40 | self.eps = eps 41 | self.weight = torch.nn.Parameter(torch.Tensor(num_features)) 42 | self.bias = torch.nn.Parameter(torch.Tensor(num_features)) 43 | self.register_buffer('running_mean', torch.zeros(num_features)) 44 | self.register_buffer('running_var', torch.ones(num_features)) 45 | 46 | 47 | def forward(self, input): 48 | 49 | return F.batch_norm( 50 | input, self.running_mean, self.running_var, self.weight, self.bias, 51 | False, eps=self.eps) 52 | 53 | def __call__(self, x): 54 | return self.forward(x) 55 | 56 | 57 | class SegmentationDataset(Dataset): 58 | def __init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, cropsize=None, 59 | img_transform=None, mask_transform=None): 60 | self.img_name_list_path = img_name_list_path 61 | self.img_dir = img_dir 62 | self.label_dir = label_dir 63 | 64 | self.img_transform = img_transform 65 | self.mask_transform = mask_transform 66 | 67 | self.img_name_list = open(self.img_name_list_path).read().splitlines() 68 | 69 | self.rescale = rescale 70 | self.flip = flip 71 | self.cropsize = cropsize 72 | 73 | def __len__(self): 74 | return len(self.img_name_list) 75 | 76 | def __getitem__(self, idx): 77 | 78 | name = self.img_name_list[idx] 79 | 80 | img = Image.open(os.path.join(self.img_dir, name + '.jpg')).convert("RGB") 81 | mask = Image.open(os.path.join(self.label_dir, name + '.png')) 82 | 83 | if self.rescale is not None: 84 | s = self.rescale[0] + random.random() * (self.rescale[1] - self.rescale[0]) 85 | adj_size = (round(img.size[0]*s/8)*8, round(img.size[1]*s/8)*8) 86 | img = img.resize(adj_size, resample=Image.CUBIC) 87 | mask = img.resize(adj_size, resample=Image.NEAREST) 88 | 89 | if self.img_transform is not None: 90 | img = self.img_transform(img) 91 | if self.mask_transform is not None: 92 | mask = self.mask_transform(mask) 93 | 94 | if self.cropsize is not None: 95 | img, mask = imutils.random_crop([img, mask], self.cropsize, (0, 255)) 96 | 97 | mask = imutils.RescaleNearest(0.125)(mask) 98 | 99 | if self.flip is True and bool(random.getrandbits(1)): 100 | img = np.flip(img, 1).copy() 101 | mask = np.flip(mask, 1).copy() 102 | 103 | img = np.transpose(img, (2, 0, 1)) 104 | 105 | return name, img, mask 106 | 107 | 108 | class ExtractAffinityLabelInRadius(): 109 | 110 | def __init__(self, cropsize, radius=5): 111 | self.radius = radius 112 | 113 | self.search_dist = [] 114 | 115 | for x in range(1, radius): 116 | self.search_dist.append((0, x)) 117 | 118 | for y in range(1, radius): 119 | for x in range(-radius+1, 
radius): 120 | if x*x + y*y < radius*radius: 121 | self.search_dist.append((y, x)) 122 | 123 | self.radius_floor = radius-1 124 | 125 | self.crop_height = cropsize - self.radius_floor 126 | self.crop_width = cropsize - 2 * self.radius_floor 127 | return 128 | 129 | def __call__(self, label): 130 | 131 | labels_from = label[:-self.radius_floor, self.radius_floor:-self.radius_floor] 132 | labels_from = np.reshape(labels_from, [-1]) 133 | 134 | labels_to_list = [] 135 | valid_pair_list = [] 136 | 137 | for dy, dx in self.search_dist: 138 | labels_to = label[dy:dy+self.crop_height, self.radius_floor+dx:self.radius_floor+dx+self.crop_width] 139 | labels_to = np.reshape(labels_to, [-1]) 140 | 141 | valid_pair = np.logical_and(np.less(labels_to, 255), np.less(labels_from, 255)) 142 | 143 | labels_to_list.append(labels_to) 144 | valid_pair_list.append(valid_pair) 145 | 146 | bc_labels_from = np.expand_dims(labels_from, 0) 147 | concat_labels_to = np.stack(labels_to_list) 148 | concat_valid_pair = np.stack(valid_pair_list) 149 | 150 | pos_affinity_label = np.equal(bc_labels_from, concat_labels_to) 151 | 152 | bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(bc_labels_from, 0)).astype(np.float32) 153 | 154 | fg_pos_affinity_label = np.logical_and(np.logical_and(pos_affinity_label, np.not_equal(bc_labels_from, 0)), concat_valid_pair).astype(np.float32) 155 | 156 | neg_affinity_label = np.logical_and(np.logical_not(pos_affinity_label), concat_valid_pair).astype(np.float32) 157 | 158 | return bg_pos_affinity_label, fg_pos_affinity_label, neg_affinity_label 159 | 160 | class AffinityFromMaskDataset(SegmentationDataset): 161 | def __init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, cropsize=None, 162 | img_transform=None, mask_transform=None, radius=5): 163 | super().__init__(img_name_list_path, img_dir, label_dir, rescale, flip, cropsize, img_transform, mask_transform) 164 | 165 | self.radius = radius 166 | 167 | self.extract_aff_lab_func = ExtractAffinityLabelInRadius(cropsize=cropsize//8, radius=radius) 168 | 169 | def __getitem__(self, idx): 170 | name, img, mask = super().__getitem__(idx) 171 | 172 | aff_label = self.extract_aff_lab_func(mask) 173 | 174 | return name, img, aff_label 175 | -------------------------------------------------------------------------------- /WSSS/tool/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import shutil 4 | import numpy as np 5 | 6 | def remove_layer(state_dict, keyword): 7 | keys = [key for key in state_dict.keys()] 8 | for key in keys: 9 | if keyword in key: 10 | state_dict.pop(key) 11 | return state_dict 12 | 13 | 14 | def replace_layer(state_dict, keyword1, keyword2): 15 | keys = [key for key in state_dict.keys()] 16 | for key in keys: 17 | if keyword1 in key: 18 | new_key = key.replace(keyword1, keyword2) 19 | state_dict[new_key] = state_dict.pop(key) 20 | return state_dict 21 | 22 | 23 | def initialize_weights(modules, init_mode): 24 | for m in modules: 25 | if isinstance(m, nn.Conv2d): 26 | if init_mode == 'he': 27 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 28 | nonlinearity='relu') 29 | elif init_mode == 'xavier': 30 | nn.init.xavier_uniform_(m.weight.data) 31 | else: 32 | raise ValueError('Invalid init_mode {}'.format(init_mode)) 33 | if m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | elif isinstance(m, nn.BatchNorm2d): 36 | nn.init.constant_(m.weight, 1) 37 | nn.init.constant_(m.bias, 0) 38 | elif 
isinstance(m, nn.Linear):
39 |             nn.init.normal_(m.weight, 0, 0.01)
40 |             nn.init.constant_(m.bias, 0)
41 | 
42 | def copy_dir(dir1, dir2):
43 |     dlist = os.listdir(dir1)
44 |     if not os.path.exists(dir2):
45 |         os.mkdir(dir2)
46 |     for f in dlist:
47 |         file1 = os.path.join(dir1, f)  # source file
48 |         file2 = os.path.join(dir2, f)  # destination file
49 |         if os.path.isfile(file1):
50 |             shutil.copyfile(file1, file2)
51 | 
52 | def seek_class(args, path):
53 |     if args.root == 'CUB_200_2011':
54 |         val_root = 'Data/'+args.root + '/test'
55 |         path_split = path.split("_")[:-2]
56 |         part_path = ''
57 |         for i in path_split:
58 |             part_path += i + '_'
59 |         part_path = part_path[:-1]
60 |         part_path = part_path.lower()
61 |         classes_name = os.listdir(val_root)
62 |         for cur_class in classes_name:
63 |             cur_class_low = cur_class.lower()
64 |             if cur_class_low[4:] == part_path:
65 |                 path = 'Data/' + args.root + '/test/' + cur_class + '/' + path
66 |                 return path, int(cur_class[:3]) - 1, cur_class
67 | 
68 |     elif args.root == 'fgvc_aircraft_2013b':
69 |         label_classes = {}
70 |         with open('Data/fgvc_aircraft_2013b/classes.txt', 'r') as f:
71 |             for i, line in enumerate(f):
72 |                 label_classes[line[:-1]] = i
73 |         with open('Data/fgvc_aircraft_2013b/images_box.txt', 'r') as f:
74 |             for line in f:
75 |                 img_name = line.split(';')[0]
76 |                 if img_name == path[:-4]:
77 |                     img_label = line.split(';')[1]
78 |                     path = 'Data/fgvc_aircraft_2013b/images' + '/' + path
79 |                     return path, label_classes[img_label], img_label
80 | class Normalize():
81 |     def __init__(self, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)):
82 | 
83 |         self.mean = mean
84 |         self.std = std
85 | 
86 |     def __call__(self, img):
87 |         imgarr = np.asarray(img)
88 |         proc_img = np.empty_like(imgarr, np.float32)
89 | 
90 |         proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0]
91 |         proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1]
92 |         proc_img[..., 2] = (imgarr[..., 2] / 255.
- self.mean[2]) / self.std[2]
93 | 
94 |         return proc_img
95 | 
96 | def decode_from_colormap(data, colors):
97 |     ignore = (data == 255).astype(np.int32)
98 | 
99 |     mask = 1 - ignore
100 |     data *= mask
101 | 
102 |     h, w = data.shape
103 |     image = colors[data.reshape((h * w))].reshape((h, w, 3))
104 | 
105 |     ignore = np.concatenate([ignore[..., np.newaxis], ignore[..., np.newaxis], ignore[..., np.newaxis]], axis=-1)
106 |     image[ignore.astype(bool)] = 255  # np.bool was removed from NumPy; the builtin bool behaves identically here
107 |     return image
108 | 
109 | def get_color_map_dic():
110 |     labels = ['background',
111 |               'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
112 |               'bus', 'car', 'cat', 'chair', 'cow',
113 |               'diningtable', 'dog', 'horse', 'motorbike', 'person',
114 |               'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', 'void']
--------------------------------------------------------------------------------
/WSSS/tool/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import cv2
5 | import pydensecrf.densecrf as dcrf
6 | from pydensecrf.utils import unary_from_softmax
7 | 
8 | def color_pro(pro, img=None, mode='hwc'):
9 |     H, W = pro.shape
10 |     pro_255 = (pro*255).astype(np.uint8)
11 |     pro_255 = np.expand_dims(pro_255,axis=2)
12 |     color = cv2.applyColorMap(pro_255,cv2.COLORMAP_JET)
13 |     color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
14 |     if img is not None:
15 |         rate = 0.5
16 |         if mode == 'hwc':
17 |             assert img.shape[0] == H and img.shape[1] == W
18 |             color = cv2.addWeighted(img,rate,color,1-rate,0)
19 |         elif mode == 'chw':
20 |             assert img.shape[1] == H and img.shape[2] == W
21 |             img = np.transpose(img,(1,2,0))
22 |             color = cv2.addWeighted(img,rate,color,1-rate,0)
23 |             color = np.transpose(color,(2,0,1))
24 |     else:
25 |         if mode == 'chw':
26 |             color = np.transpose(color,(2,0,1))
27 |     return color
28 | 
29 | def generate_vis(p, gt, img, func_label2color, threshold=0.1, norm=True):
30 |     # All the input should be numpy.array
31 |     # img should be 0-255 uint8
32 |     C, H, W = p.shape
33 | 
34 |     if norm:
35 |         prob = max_norm(p, 'numpy')
36 |     else:
37 |         prob = p
38 |     if gt is not None:
39 |         prob = prob * gt
40 |     prob[prob<=0] = 1e-7
41 |     if threshold is not None:
42 |         prob[0,:,:] = np.power(1-np.max(prob[1:,:,:],axis=0,keepdims=True), 4)
43 | 
44 |     CLS = ColorCLS(prob, func_label2color)
45 |     CAM = ColorCAM(prob, img)
46 | 
47 |     prob_crf = dense_crf(prob, img, n_classes=C, n_iters=1)
48 | 
49 |     CLS_crf = ColorCLS(prob_crf, func_label2color)
50 |     CAM_crf = ColorCAM(prob_crf, img)
51 | 
52 |     return CLS, CAM, CLS_crf, CAM_crf
53 | 
54 | def max_norm(p, version='torch', e=1e-5):
55 |     if version == 'torch':  # compare strings with ==, not with identity
56 |         if p.dim() == 3:
57 |             C, H, W = p.size()
58 |             p = F.relu(p)
59 |             max_v = torch.max(p.view(C,-1),dim=-1)[0].view(C,1,1)
60 |             min_v = torch.min(p.view(C,-1),dim=-1)[0].view(C,1,1)
61 |             p = F.relu(p-min_v-e)/(max_v-min_v+e)
62 |         elif p.dim() == 4:
63 |             N, C, H, W = p.size()
64 |             p = F.relu(p)
65 |             max_v = torch.max(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1)
66 |             min_v = torch.min(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1)
67 |             p = F.relu(p-min_v-e)/(max_v-min_v+e)
68 |     elif version == 'numpy' or version == 'np':
69 |         if p.ndim == 3:
70 |             C, H, W = p.shape
71 |             p[p<0] = 0
72 |             max_v = np.max(p,(1,2),keepdims=True)
73 |             min_v = np.min(p,(1,2),keepdims=True)
74 |             p[p0.5] 55 | batch_pos_logits = batch_pos_logits.view(class_num, 1) 56 | batch_neg_logits = batch_neg_logits.view(1, 20 - class_num).expand((class_num, 20 - class_num)) 57 | batch_logits = 
torch.cat([batch_pos_logits, batch_neg_logits], 1) 58 | c += loss_func(batch_logits, p_batch_label) * class_num/ batch 59 | return c 60 | 61 | class opts(object): 62 | def __init__(self): 63 | self.parser = argparse.ArgumentParser() 64 | ## save 65 | self.parser.add_argument('--save_path', type=str, default='sess') 66 | self.parser.add_argument('--log_file', type=str, default='log.txt') 67 | self.parser.add_argument('--log_code_dir', type=str, default='save_code') 68 | 69 | self.parser.add_argument("--train_list", default="voc12/train_aug.txt", type=str) 70 | self.parser.add_argument("--val_list", default="voc12/val.txt", type=str) 71 | self.parser.add_argument("--voc12_root", default='/home2/wpy/WSSS/VOC2012/', type=str) 72 | ## dataloader 73 | self.parser.add_argument('--crop_size', default=448, type=int) 74 | self.parser.add_argument('--num_workers', type=int, default=2) 75 | ## train 76 | self.parser.add_argument('--num_classes', type=int, default=20) 77 | self.parser.add_argument('--batch_size', type=int, default=16) 78 | self.parser.add_argument('--epochs', type=int, default=10) 79 | self.parser.add_argument('--phase', type=str, default='train') ## train / test 80 | self.parser.add_argument('--lr', type=float, default=0.005) 81 | self.parser.add_argument('--wt_dec', type=float, default=1e-4) 82 | self.parser.add_argument('--power', type=float, default=0.9) 83 | self.parser.add_argument('--momentum', type=float, default=0.9) 84 | self.parser.add_argument("--local_rank", type=int,default=-1) 85 | self.parser.add_argument("--seed", type=int,default=0) 86 | ## model 87 | self.parser.add_argument('--arch', type=str, default='resnet50_bas') ## choose [ vgg, resnet, inception, mobilenet ] 88 | 89 | def parse(self): 90 | opt = self.parser.parse_args() 91 | opt.arch = opt.arch 92 | return opt 93 | 94 | args = opts().parse() 95 | lr = args.lr 96 | 97 | args.batch_size = args.batch_size // int(torch.cuda.device_count()) 98 | print(args.batch_size) 99 | 100 | ## save_log_txt 101 | makedirs(args.save_path + '/' + args.arch + '/' + args.log_code_dir) 102 | sys.stdout = Logger(args.save_path + '/' + args.arch + '/' + args.log_code_dir + '/' + args.log_file) 103 | sys.stdout.log.flush() 104 | 105 | ## save_code 106 | save_file = ['train_bas.py', 'run_sample.py'] 107 | for file_name in save_file: 108 | shutil.copyfile(file_name, args.save_path + '/' + args.arch + '/' + args.log_code_dir + '/' + file_name) 109 | save_dir = ['net', 'step', 'misc'] 110 | for file_name in save_dir: 111 | copy_dir(file_name, args.save_path + '/' + args.arch + '/' + args.log_code_dir + '/' + file_name) 112 | 113 | ## DDP 114 | local_rank = args.local_rank 115 | torch.cuda.set_device(local_rank) 116 | dist.init_process_group(backend='nccl') 117 | device = torch.device("cuda", local_rank) 118 | set_seed(args.seed + dist.get_rank()) 119 | 120 | train_dataset = voc12.dataloader.VOC12ClassificationDataset(args.train_list, voc12_root=args.voc12_root, 121 | resize_scale=512, hor_flip=True, 122 | crop_size=448, crop_method="random") 123 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 124 | train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, 125 | num_workers=args.num_workers, pin_memory=True, drop_last=True) 126 | max_step = len(train_dataset) // args.batch_size * args.epochs 127 | 128 | model = eval(args.arch).model() 129 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda().to(local_rank) 130 | model = DDP(model, find_unused_parameters=False 
, device_ids=[local_rank], output_device=local_rank) 131 | model.train() 132 | 133 | optimizer = optimizer.get_optimizer(model, args, max_step) 134 | 135 | for epoch in range(args.epochs): 136 | train_data_loader.sampler.set_epoch(epoch) 137 | loss_epoch_1 = AverageMeter() 138 | loss_epoch_2 = AverageMeter() 139 | loss_epoch_3 = AverageMeter() 140 | loss_epoch_4 = AverageMeter() 141 | for step, pack in enumerate(train_data_loader): 142 | imgs, label = pack['img'].to(local_rank), pack['label'].to(local_rank) 143 | optimizer.zero_grad() 144 | output1, loss2, loss3, loss4= model(imgs, label) 145 | 146 | loss1 = cross_entropy(output1, label).to(local_rank) #F.multilabel_soft_margin_loss(output1, label).to(local_rank) 147 | loss = loss1 + loss2 * 0.2 + loss3 + loss4 * 1.2 148 | loss.backward() 149 | optimizer.step() 150 | loss_epoch_1.updata(loss1.data, 1) 151 | loss_epoch_2.updata(loss2.data, 1) 152 | loss_epoch_3.updata(loss3.data, 1) 153 | loss_epoch_4.updata(loss4.data, 1) 154 | if dist.get_rank() == 0: 155 | print(' Epoch:[{}/{}]\t step:[{}/{}]\tcls_loss_1:{:.3f}\tcls_loss_2:{:.3f}\tbas_loss:{:.3f}\tarea_loss:{:.3f} '.format( 156 | epoch+1, args.epochs, step+1, len(train_data_loader), loss_epoch_1.avg, loss_epoch_2.avg, loss_epoch_3.avg, loss_epoch_4.avg 157 | )) 158 | sys.stdout.log.flush() 159 | torch.save(model.state_dict(), os.path.join(args.save_path + '/' + args.arch +'/' + 'epoch_'+ str(epoch+1) +'.pth.tar'),_use_new_zipfile_serialization=False) 160 | -------------------------------------------------------------------------------- /WSSS/voc12/cls_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wpy1999/BAS-Extension/bd1bfa89efd8688dd185a8f34051b2786ac31eba/WSSS/voc12/cls_labels.npy -------------------------------------------------------------------------------- /WSSS/voc12/dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | import os.path 6 | import imageio 7 | from misc import imutils 8 | 9 | IMG_FOLDER_NAME = "JPEGImages" 10 | ANNOT_FOLDER_NAME = "Annotations" 11 | IGNORE = 255 12 | 13 | CAT_LIST = ['aeroplane', 'bicycle', 'bird', 'boat', 14 | 'bottle', 'bus', 'car', 'cat', 'chair', 15 | 'cow', 'diningtable', 'dog', 'horse', 16 | 'motorbike', 'person', 'pottedplant', 17 | 'sheep', 'sofa', 'train', 18 | 'tvmonitor'] 19 | 20 | N_CAT = len(CAT_LIST) 21 | 22 | CAT_NAME_TO_NUM = dict(zip(CAT_LIST,range(len(CAT_LIST)))) 23 | 24 | cls_labels_dict = np.load('voc12/cls_labels.npy', allow_pickle=True).item() 25 | 26 | def decode_int_filename(int_filename): 27 | s = str(int(int_filename)) 28 | return s[:4] + '_' + s[4:] 29 | 30 | def load_image_label_from_xml(img_name, voc12_root): 31 | from xml.dom import minidom 32 | 33 | elem_list = minidom.parse(os.path.join(voc12_root, ANNOT_FOLDER_NAME, decode_int_filename(img_name) + '.xml')).getElementsByTagName('name') 34 | 35 | multi_cls_lab = np.zeros((N_CAT), np.float32) 36 | 37 | for elem in elem_list: 38 | cat_name = elem.firstChild.data 39 | if cat_name in CAT_LIST: 40 | cat_num = CAT_NAME_TO_NUM[cat_name] 41 | multi_cls_lab[cat_num] = 1.0 42 | 43 | return multi_cls_lab 44 | 45 | def load_image_label_list_from_xml(img_name_list, voc12_root): 46 | 47 | return [load_image_label_from_xml(img_name, voc12_root) for img_name in img_name_list] 48 | 49 | def load_image_label_list_from_npy(img_name_list): 50 | 51 | return 
np.array([cls_labels_dict[img_name] for img_name in img_name_list]) 52 | 53 | def get_img_path(img_name, voc12_root): 54 | if not isinstance(img_name, str): 55 | img_name = decode_int_filename(img_name) 56 | return os.path.join(voc12_root, IMG_FOLDER_NAME, img_name + '.jpg') 57 | 58 | def load_img_name_list(dataset_path): 59 | 60 | img_name_list = np.loadtxt(dataset_path, dtype=np.int32) 61 | 62 | return img_name_list 63 | 64 | 65 | class TorchvisionNormalize(): 66 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 67 | self.mean = mean 68 | self.std = std 69 | 70 | def __call__(self, img): 71 | imgarr = np.asarray(img) 72 | proc_img = np.empty_like(imgarr, np.float32) 73 | 74 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0] 75 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1] 76 | proc_img[..., 2] = (imgarr[..., 2] / 255. - self.mean[2]) / self.std[2] 77 | 78 | return proc_img 79 | 80 | class GetAffinityLabelFromIndices(): 81 | 82 | def __init__(self, indices_from, indices_to): 83 | 84 | self.indices_from = indices_from 85 | self.indices_to = indices_to 86 | 87 | def __call__(self, segm_map): 88 | 89 | segm_map_flat = np.reshape(segm_map, -1) 90 | 91 | segm_label_from = np.expand_dims(segm_map_flat[self.indices_from], axis=0) 92 | segm_label_to = segm_map_flat[self.indices_to] 93 | 94 | valid_label = np.logical_and(np.less(segm_label_from, 21), np.less(segm_label_to, 21)) 95 | 96 | equal_label = np.equal(segm_label_from, segm_label_to) 97 | 98 | pos_affinity_label = np.logical_and(equal_label, valid_label) 99 | 100 | bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(segm_label_from, 0)).astype(np.float32) 101 | fg_pos_affinity_label = np.logical_and(pos_affinity_label, np.greater(segm_label_from, 0)).astype(np.float32) 102 | 103 | neg_affinity_label = np.logical_and(np.logical_not(equal_label), valid_label).astype(np.float32) 104 | 105 | return torch.from_numpy(bg_pos_affinity_label), torch.from_numpy(fg_pos_affinity_label), \ 106 | torch.from_numpy(neg_affinity_label) 107 | 108 | 109 | class VOC12ImageDataset(Dataset): 110 | 111 | def __init__(self, img_name_list_path, voc12_root, 112 | resize_long=None, resize_scale=None, rescale=None, img_normal=TorchvisionNormalize(), hor_flip=False, 113 | crop_size=None, crop_method=None, to_torch=True): 114 | 115 | self.img_name_list = load_img_name_list(img_name_list_path) 116 | self.voc12_root = voc12_root 117 | 118 | self.resize_long = resize_long 119 | self.rescale = rescale 120 | self.resize_scale = resize_scale 121 | self.crop_size = crop_size 122 | self.img_normal = img_normal 123 | self.hor_flip = hor_flip 124 | self.crop_method = crop_method 125 | self.to_torch = to_torch 126 | 127 | def __len__(self): 128 | return len(self.img_name_list) 129 | 130 | def __getitem__(self, idx): 131 | name = self.img_name_list[idx] 132 | name_str = decode_int_filename(name) 133 | 134 | img = np.asarray(imageio.imread(get_img_path(name_str, self.voc12_root))) 135 | 136 | if self.resize_long: 137 | img = imutils.random_resize_long(img, self.resize_long[0], self.resize_long[1]) 138 | 139 | if self.resize_scale: 140 | img = imutils.resize_scale(img, resize_scale=self.resize_scale) 141 | 142 | if self.rescale: 143 | img = imutils.random_scale(img, scale_range=self.rescale, order=3) 144 | 145 | if self.img_normal: 146 | img = self.img_normal(img) 147 | 148 | if self.hor_flip: 149 | img = imutils.random_lr_flip(img) 150 | 151 | if self.crop_size: 152 | if self.crop_method 
== "random": 153 | img = imutils.random_crop(img, self.crop_size, 0) 154 | else: 155 | img = imutils.top_left_crop(img, self.crop_size, 0) 156 | 157 | if self.to_torch: 158 | img = imutils.HWC_to_CHW(img) 159 | 160 | return {'name': name_str, 'img': img} 161 | 162 | class VOC12ClassificationDataset(VOC12ImageDataset): 163 | 164 | def __init__(self, img_name_list_path, voc12_root, 165 | resize_long=None, resize_scale=None, rescale=None, img_normal=TorchvisionNormalize(), hor_flip=False, 166 | crop_size=None, crop_method=None): 167 | super().__init__(img_name_list_path, voc12_root, 168 | resize_long, resize_scale, rescale, img_normal, hor_flip, 169 | crop_size, crop_method) 170 | self.label_list = load_image_label_list_from_npy(self.img_name_list) 171 | 172 | def __getitem__(self, idx): 173 | out = super().__getitem__(idx) 174 | 175 | out['label'] = torch.from_numpy(self.label_list[idx]) 176 | 177 | return out 178 | 179 | class VOC12ClassificationDatasetMSF(VOC12ClassificationDataset): 180 | 181 | def __init__(self, img_name_list_path, voc12_root, 182 | img_normal=TorchvisionNormalize(), 183 | scales=(1.0,)): 184 | self.scales = scales 185 | 186 | super().__init__(img_name_list_path, voc12_root, img_normal=img_normal) 187 | self.scales = scales 188 | 189 | def __getitem__(self, idx): 190 | name = self.img_name_list[idx] 191 | name_str = decode_int_filename(name) 192 | 193 | img = imageio.imread(get_img_path(name_str, self.voc12_root)) 194 | 195 | ms_img_list = [] 196 | for s in self.scales: 197 | s_img = imutils.pil_rescale(img, s, order=3) 198 | s_img = self.img_normal(s_img) 199 | s_img = imutils.HWC_to_CHW(s_img) 200 | ms_img_list.append(np.stack([s_img, np.flip(s_img, -1)], axis=0)) 201 | if len(self.scales) == 1: 202 | ms_img_list = ms_img_list[0] 203 | 204 | out = {"name": name_str, "img": ms_img_list, "size": (img.shape[0], img.shape[1]), 205 | "label": torch.from_numpy(self.label_list[idx])} 206 | return out 207 | 208 | class VOC12SegmentationDataset(Dataset): 209 | 210 | def __init__(self, img_name_list_path, label_dir, crop_size, voc12_root, 211 | rescale=None, img_normal=TorchvisionNormalize(), hor_flip=False, 212 | crop_method = 'random'): 213 | 214 | self.img_name_list = load_img_name_list(img_name_list_path) 215 | self.voc12_root = voc12_root 216 | 217 | self.label_dir = label_dir 218 | 219 | self.rescale = rescale 220 | self.crop_size = crop_size 221 | self.img_normal = img_normal 222 | self.hor_flip = hor_flip 223 | self.crop_method = crop_method 224 | 225 | def __len__(self): 226 | return len(self.img_name_list) 227 | 228 | def __getitem__(self, idx): 229 | name = self.img_name_list[idx] 230 | name_str = decode_int_filename(name) 231 | 232 | img = imageio.imread(get_img_path(name_str, self.voc12_root)) 233 | # print(os.path.join(self.label_dir, name_str + '.png')) 234 | label = imageio.imread(os.path.join(self.label_dir, name_str + '.png')) 235 | 236 | img = np.asarray(img) 237 | 238 | if self.rescale: 239 | img, label = imutils.random_scale((img, label), scale_range=self.rescale, order=(3, 0)) 240 | 241 | if self.img_normal: 242 | img = self.img_normal(img) 243 | 244 | if self.hor_flip: 245 | img, label = imutils.random_lr_flip((img, label)) 246 | 247 | if self.crop_method == "random": 248 | img, label = imutils.random_crop((img, label), self.crop_size, (0, 255)) 249 | else: 250 | img = imutils.top_left_crop(img, self.crop_size, 0) 251 | label = imutils.top_left_crop(label, self.crop_size, 255) 252 | 253 | img = imutils.HWC_to_CHW(img) 254 | 255 | return {'name': name, 
'img': img, 'label': label} 256 | 257 | class VOC12AffinityDataset(VOC12SegmentationDataset): 258 | def __init__(self, img_name_list_path, label_dir, crop_size, voc12_root, 259 | indices_from, indices_to, 260 | rescale=None, img_normal=TorchvisionNormalize(), hor_flip=False, crop_method=None): 261 | super().__init__(img_name_list_path, label_dir, crop_size, voc12_root, rescale, img_normal, hor_flip, crop_method=crop_method) 262 | 263 | self.extract_aff_lab_func = GetAffinityLabelFromIndices(indices_from, indices_to) 264 | 265 | def __len__(self): 266 | return len(self.img_name_list) 267 | 268 | def __getitem__(self, idx): 269 | out = super().__getitem__(idx) 270 | 271 | reduced_label = imutils.pil_rescale(out['label'], 0.25, 0) 272 | 273 | out['aff_bg_pos_label'], out['aff_fg_pos_label'], out['aff_neg_label'] = self.extract_aff_lab_func(reduced_label) 274 | 275 | return out 276 | 277 | -------------------------------------------------------------------------------- /WSSS/voc12/make_cls_labels.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import voc12.dataloader 3 | import numpy as np 4 | 5 | if __name__ == '__main__': 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--train_list", default='train_aug.txt', type=str) 9 | parser.add_argument("--val_list", default='val.txt', type=str) 10 | parser.add_argument("--out", default="cls_labels.npy", type=str) 11 | parser.add_argument("--voc12_root", default="../../../Dataset/VOC2012", type=str) 12 | args = parser.parse_args() 13 | 14 | train_name_list = voc12.dataloader.load_img_name_list(args.train_list) 15 | val_name_list = voc12.dataloader.load_img_name_list(args.val_list) 16 | 17 | train_val_name_list = np.concatenate([train_name_list, val_name_list], axis=0) 18 | label_list = voc12.dataloader.load_image_label_list_from_xml(train_val_name_list, args.voc12_root) 19 | 20 | total_label = np.zeros(20) 21 | 22 | d = dict() 23 | for img_name, label in zip(train_val_name_list, label_list): 24 | d[img_name] = label 25 | total_label += label 26 | 27 | print(total_label) 28 | np.save(args.out, d) --------------------------------------------------------------------------------