├── scripts ├── __init__.py ├── SEAM │ ├── voc12 │ │ ├── cls_labels.npy │ │ └── make_cls_labels.py │ ├── requirements.txt │ ├── LICENSE │ ├── network │ │ ├── resnet38_SEAM.py │ │ └── resnet38_aff.py │ ├── tool │ │ ├── visualization.py │ │ ├── pyutils.py │ │ └── torchutils.py │ ├── README.md │ ├── evaluation.py │ ├── infer_SEAM.py │ ├── train_aff.py │ ├── infer_aff.py │ └── infer_SEAM_good.py ├── vis_habitats.py ├── save_latex.py ├── game_results.py ├── qualitative.py ├── val_pseudo.py ├── save_preds.py ├── across_habitat.py ├── across_fish.py └── test_affinity.py ├── exp_configs ├── ablation_exps.py ├── lironne_exps.py └── __init__.py ├── src ├── models │ ├── networks │ │ ├── detr │ │ │ ├── __init__.py │ │ │ ├── util │ │ │ │ ├── __init__.py │ │ │ │ └── box_ops.py │ │ │ ├── position_encoding.py │ │ │ ├── matcher.py │ │ │ └── backbone.py │ │ ├── deeplab.py │ │ ├── lanenet.py │ │ ├── resnet50_cam.py │ │ ├── __init__.py │ │ ├── unet_resnet.py │ │ ├── resnet_seam.py │ │ ├── unet2d.py │ │ ├── fcn8_resnet.py │ │ ├── resnet50.py │ │ └── fcn8_vgg16_multiscale.py │ ├── optimizers │ │ ├── __init__.py │ │ └── sps.py │ ├── __init__.py │ └── metrics │ │ └── __init__.py ├── datasets │ ├── voc12 │ │ ├── cls_labels.npy │ │ └── make_cls_labels.py │ └── transformers │ │ └── __init__.py ├── modules │ ├── imantics │ │ ├── __init__.py │ │ ├── point.py │ │ ├── utils.py │ │ ├── styles.py │ │ ├── category.py │ │ ├── basic.py │ │ └── color.py │ ├── sstransforms.py │ ├── eprop │ │ └── eprop.py │ └── lcfcn │ │ └── lcfcn_loss.py ├── misc │ ├── torchutils.py │ ├── pyutils.py │ └── indexing.py └── utils.py ├── new.png ├── old.png ├── tmp ├── tmp0.png ├── tmp1.png ├── tmp2.png ├── tmp3.png └── tmp4.png ├── results ├── images └── plots ├── .gitignore ├── requirements.txt ├── job_configs.py └── README.md /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /exp_configs/ablation_exps.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/networks/detr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/new.png -------------------------------------------------------------------------------- /old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/old.png -------------------------------------------------------------------------------- /tmp/tmp0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/tmp/tmp0.png -------------------------------------------------------------------------------- /tmp/tmp1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/tmp/tmp1.png -------------------------------------------------------------------------------- /tmp/tmp2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/tmp/tmp2.png -------------------------------------------------------------------------------- /tmp/tmp3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/tmp/tmp3.png -------------------------------------------------------------------------------- /tmp/tmp4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/tmp/tmp4.png -------------------------------------------------------------------------------- /results/images: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/results/images -------------------------------------------------------------------------------- /results/plots: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/results/plots -------------------------------------------------------------------------------- /exp_configs/lironne_exps.py: -------------------------------------------------------------------------------- 1 | from haven import haven_utils as hu 2 | import itertools, copy 3 | EXP_GROUPS = {} -------------------------------------------------------------------------------- /src/models/networks/detr/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /scripts/SEAM/voc12/cls_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/scripts/SEAM/voc12/cls_labels.npy -------------------------------------------------------------------------------- /src/datasets/voc12/cls_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IssamLaradji/affinity_lcfcn/HEAD/src/datasets/voc12/cls_labels.npy -------------------------------------------------------------------------------- /src/modules/imantics/__init__.py: -------------------------------------------------------------------------------- 1 | from .annotation import * 2 | from .category import * 3 | from .dataset import * 4 | from .styles import * 5 | from .image import * 6 | from .color import * 7 | 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *.pt 4 | *.out 5 | 6 | __pycache__/ 7 | .vscode/ 8 | *.ckpt 9 | .tmp/ 10 | results/*.ipynb 11 | results/*.ipynb* 12 | src/datasets/synbols 13 | usr_configs.py 14 | tmp.png 15 | .ipynb_checkpoints/ 16 | /Alz_temp/ 17 | /.idea/ 18 | -------------------------------------------------------------------------------- /exp_configs/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import alzayat_exps, issam_exps, lironne_exps, shortlist_exps 2 | EXP_GROUPS = {} 3 | 4 | EXP_GROUPS.update(lironne_exps.EXP_GROUPS) 5 | EXP_GROUPS.update(issam_exps.EXP_GROUPS) 6 | EXP_GROUPS.update(alzayat_exps.EXP_GROUPS) 7 | EXP_GROUPS.update(shortlist_exps.EXP_GROUPS) -------------------------------------------------------------------------------- /scripts/SEAM/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.2 2 | cython>=0.28.5 3 | tensorboard>=1.8.0 4 | tensorboardX>=1.6 5 | imageio>=2.6.1 6 | scikit-image>=0.14.0 7 | pydensecrf>=1.0rc2 8 | torch>=0.4.1 9 | torchvision 10 | scipy==1.1.0 11 | opencv-python>=3.4.2.17 12 | pandas>=0.23.4 13 | Pillow>=5.2.0 14 | mxnet 15 | -------------------------------------------------------------------------------- /src/modules/imantics/point.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Point(np.ndarray): 5 | 6 | def __abs__(self): 7 | return np.linalg.norm(self) 8 | 9 | def dist(self,other): 10 | return np.linalg.norm(self-other) 11 | 12 | def dot(self, other): 13 | return np.dot(self, other) 14 | 15 | -------------------------------------------------------------------------------- /src/modules/imantics/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def json_default(o): 5 | print(o) 6 | if isinstance(o, np.int64): 7 | return int(o) 8 | if isinstance(o, np.ndarray): 9 | return o.tolist() 10 | 11 | type_name = o.__class__.__name__ 12 | raise TypeError("Object of type {} is not JSON serializable".format(type_name)) 13 | 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | torchvision==0.5.0 3 | pydicom>=1.4.2 4 | pylidc>=0.2.1 5 | SimpleITK>=1.2.4 6 | torchnet>=0.0.4 7 | h5py>=2.10.0 8 | tensorboard>=1.14.0 9 | ninja>=1.9.0.post1 10 | medpy>=0.4.0 11 | mdai>=0.4.1 12 | timm>=0.1.20 13 | pretrainedmodels>=0.7.4 14 | efficientnet_pytorch>=0.6.3 15 | matplotlib>=3.1.2 16 | seaborn>=0.9.0 17 | batchgenerators>=0.20.1 18 | scikit-image>=0.14.2 19 | kornia==0.2.0 20 | haven-ai 21 | -------------------------------------------------------------------------------- /src/modules/imantics/styles.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | #: COCO Style value (http://cocodataset.org/#format-data) 4 | COCO = "coco" 5 | 6 | #: PaperJs Style value (http://paperjs.org/reference/compoundpath/) 7 | PAPERJS = "paperjs" 8 | 9 | #: VGG Style value () 10 | VGG = "vgg" 11 | 12 | #: VOC Style value 13 | VOC = "voc" 14 | 15 | #: YOLO Style value (https://github.com/AlexeyAB/darknet#how-to-train-to-detect-your-custom-objects) 16 | YOLO = "yolo" 17 | -------------------------------------------------------------------------------- /src/models/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . 
import sps 3 | 4 | 5 | def get_optimizer(name, model, exp_dict): 6 | if name == "adam": 7 | opt = torch.optim.Adam( 8 | model.parameters(), lr=exp_dict["lr"], betas=(0.99, 0.999)) 9 | 10 | elif name == "sgd": 11 | opt = torch.optim.SGD( 12 | model.parameters(), lr=exp_dict["lr"]) 13 | 14 | elif name == "sps": 15 | opt = sps.Sps( 16 | model.parameters(), c=1, momentum=0.6) 17 | return opt -------------------------------------------------------------------------------- /job_configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | ACCOUNT_ID = os.environ['EAI_ACCOUNT_ID'] 5 | 6 | JOB_CONFIG = { 7 | 'image': 'registry.console.elementai.com/%s/ssh' % os.environ['EAI_ACCOUNT_ID'] , 8 | 'data': [ 9 | 'eai.colab.public:/mnt/public', 10 | ], 11 | 'restartable':True, 12 | 'resources': { 13 | 'cpu': 4, 14 | 'mem': 8, 15 | 'gpu': 1 16 | }, 17 | 'interactive': False, 18 | 'bid': 5000, 19 | } -------------------------------------------------------------------------------- /scripts/SEAM/voc12/make_cls_labels.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import voc12.data 3 | import numpy as np 4 | 5 | if __name__ == '__main__': 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--train_list", default='train_aug.txt', type=str) 9 | parser.add_argument("--val_list", default='val.txt', type=str) 10 | parser.add_argument("--out", default="cls_labels.npy", type=str) 11 | parser.add_argument("--voc12_root", required=True, type=str) 12 | args = parser.parse_args() 13 | 14 | img_name_list = voc12.data.load_img_name_list(args.train_list) 15 | img_name_list.extend(voc12.data.load_img_name_list(args.val_list)) 16 | label_list = voc12.data.load_image_label_list_from_xml(img_name_list, args.voc12_root) 17 | 18 | d = dict() 19 | for img_name, label in zip(img_name_list, label_list): 20 | d[img_name] = label 21 | 22 | np.save(args.out, d) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Weakly Supervised Fish Segmentation 3 | 4 | ### JCU fish 5 | 6 | - https://www.dropbox.com/sh/b2jlua76ogyr5rk/AABsJVljG7v2BOunE1k4f_XTa?dl=0 7 | 8 | ## semseg Weakly supervised for JCU fish 9 | 10 | ``` 11 | python trainval.py -e weakly_JCUfish -sb -d -r 1 12 | ``` 13 | ## affinity Weakly supervised for JCU fish 14 | 15 | ``` 16 | python trainval.py -e weakly_JCUfish_aff -sb -d -r 1 17 | ``` 18 | 19 | # Citation 20 | 21 | ``` 22 | @misc{laradji2020affinity, 23 | title={Affinity LCFCN: Learning to Segment Fish with Weak Supervision}, 24 | author={Issam Laradji and Alzayat Saleh and Pau Rodriguez and 25 | Derek Nowrouzezahrai and Mostafa Rahimi Azghadi and David Vazquez}, 26 | year={2020}, 27 | eprint={2011.03149}, 28 | archivePrefix={arXiv}, 29 | primaryClass={cs.CV} 30 | } 31 | ``` 32 | -------------------------------------------------------------------------------- /src/datasets/voc12/make_cls_labels.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import voc12.dataloader 3 | import numpy as np 4 | 5 | if __name__ == '__main__': 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--train_list", default='train_aug.txt', type=str) 9 | parser.add_argument("--val_list", default='val.txt', type=str) 10 | parser.add_argument("--out", default="cls_labels.npy", type=str) 11 | 
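    # --voc12_root should point at the local VOC2012 devkit; the relative default below is an assumption about the directory layout and may need adjusting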
parser.add_argument("--voc12_root", default="../../../Dataset/VOC2012", type=str) 12 | args = parser.parse_args() 13 | 14 | train_name_list = voc12.dataloader.load_img_name_list(args.train_list) 15 | val_name_list = voc12.dataloader.load_img_name_list(args.val_list) 16 | 17 | train_val_name_list = np.concatenate([train_name_list, val_name_list], axis=0) 18 | label_list = voc12.dataloader.load_image_label_list_from_xml(train_val_name_list, args.voc12_root) 19 | 20 | total_label = np.zeros(20) 21 | 22 | d = dict() 23 | for img_name, label in zip(train_val_name_list, label_list): 24 | d[img_name] = label 25 | total_label += label 26 | 27 | print(total_label) 28 | np.save(args.out, d) -------------------------------------------------------------------------------- /scripts/SEAM/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hibercraft 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/models/networks/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision 3 | import torch 4 | from skimage import morphology as morph 5 | import numpy as np 6 | from torch import optim 7 | import torch.nn.functional as F 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import numpy as np 12 | from skimage.morphology import watershed 13 | from skimage.segmentation import find_boundaries 14 | from scipy import ndimage 15 | from torchvision import models 16 | 17 | class DeepLab(nn.Module): 18 | def __init__(self, **kwargs): 19 | super().__init__() 20 | self.model = models.segmentation.deeplabv3_resnet50(pretrained_backbone=False) 21 | # base.forward = lambda x: base.forward(x)['out'] 22 | # # FREEZE BATCH NORMS 23 | for m in self.model.modules(): 24 | if isinstance(m, nn.BatchNorm2d): 25 | m.weight.requires_grad = False 26 | m.bias.requires_grad = False 27 | # m.reset_parameters() 28 | m.eval() 29 | # with torch.no_grad(): 30 | # m.weight.fill_(1.0) 31 | # m.bias.zero_() 32 | 33 | 34 | 35 | 36 | def forward(self, x): 37 | return self.model(x)['out'] -------------------------------------------------------------------------------- /src/modules/imantics/category.py: -------------------------------------------------------------------------------- 1 | from .basic import Semantic 2 | from .color import Color 3 | 4 | 5 | class Category(Semantic): 6 | 7 | 8 | @classmethod 9 | def from_coco(cls, coco): 10 | data = { 11 | 'name': coco.get('name'), 12 | 'metadata': coco.get('metadata', {}), 13 | 'id': coco.get('id', 0), 14 | 'parent': coco.get('supercategory'), 15 | 'color': coco.get('color') 16 | } 17 | return cls(**data) 18 | 19 | 20 | def __init__(self, name, parent=None, metadata={}, id=0, color=None): 21 | self.id = id 22 | self.name = name 23 | self.parent = None 24 | self.color = Color.create(color) 25 | 26 | super(Category, self).__init__(id, metadata) 27 | 28 | def coco(self, include=True): 29 | 30 | category = { 31 | 'id': self.id, 32 | 'name': self.name, 33 | 'supercategory': self.parent.name if self.parent else None, 34 | 'metadata': self.metadata, 35 | 'color': self.color.hex 36 | } 37 | 38 | if include: 39 | return { 40 | 'categories': [category] 41 | } 42 | 43 | return category 44 | 45 | 46 | __all__ = ["Category"] 47 | -------------------------------------------------------------------------------- /src/modules/imantics/basic.py: -------------------------------------------------------------------------------- 1 | from .styles import * 2 | 3 | 4 | class Semantic(object): 5 | 6 | def __init__(self, id, metadata={}): 7 | self.id = id 8 | self.metadata = metadata 9 | 10 | def coco(self): 11 | """ 12 | Export object in COCO format 13 | 14 | :returns: object in format 15 | :rtype: dict 16 | """ 17 | return {} 18 | 19 | def vgg(self): 20 | """ 21 | Export object in VGG format 22 | """ 23 | return [] 24 | 25 | def voc(self): 26 | """ 27 | Export object in VOC format 28 | 29 | :returns: object in format 30 | :rtype: lxml.element 31 | """ 32 | return None 33 | 34 | def yolo(self): 35 | """ 36 | Export object in YOLO format 37 | 38 | :returns: object in format 39 | :rtype: list, tuple 40 | """ 41 | return [] 42 | 43 | def paperjs(self): 44 | """ 45 | Export object in PaperJS format 46 | 47 | :returns: object in format 48 | :rtype: dict 49 | """ 50 | return {} 51 | 52 | def export(self, style=COCO): 53 | """ 54 | 
Exports object into specified style 55 | """ 56 | return { 57 | COCO: self.coco(), 58 | VGG: self.vgg(), 59 | YOLO: self.yolo(), 60 | VOC: self.voc(), 61 | PAPERJS: self.paperjs() 62 | }.get(style) 63 | 64 | def save(self, file): 65 | pass 66 | -------------------------------------------------------------------------------- /src/models/networks/lanenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | class VGG(nn.Module): 6 | def __init__(self): 7 | super(VGG, self).__init__() 8 | self.x64 = nn.Conv2d(1,64,kernel_size=1, padding = 0, bias = False) 9 | self.x128 = nn.Conv2d(1,128,kernel_size=1, padding = 0, bias = False) 10 | self.x256 = nn.Conv2d(1,256,kernel_size=1, padding = 0, bias = False) 11 | self.x64.weight = torch.nn.Parameter(torch.ones((64,1,1,1))) 12 | self.x128.weight = torch.nn.Parameter(torch.ones((128,1,1,1))) 13 | self.x256.weight = torch.nn.Parameter(torch.ones((256,1,1,1))) 14 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding = 1, bias=True) 15 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding = 1, bias=True) 16 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding = 1, bias=True) 17 | params = torch.load('/mnt/projects/vision_prototypes/pau/covid/vgg19-dcbb9e9d.pth') 18 | self.conv1_2.weight = torch.nn.Parameter(params['features.2.weight']) 19 | self.conv1_2.bias = torch.nn.Parameter(params['features.2.bias']) 20 | self.conv2_2.weight = torch.nn.Parameter(params['features.7.weight']) 21 | self.conv2_2.bias = torch.nn.Parameter(params['features.7.bias']) 22 | self.conv3_4.weight = torch.nn.Parameter(params['features.16.weight']) 23 | self.conv3_4.bias = torch.nn.Parameter(params['features.16.bias']) 24 | #{k: v for k, v in pretrained_dict.items() if k in model_dict} 25 | def forward(self, x): 26 | x64 = self.x64(x) 27 | x64 = self.conv1_2(x64) 28 | x128 = self.x128(x) 29 | x128 = self.conv2_2(x128) 30 | x256 = self.x256(x) 31 | x256 = self.conv3_4(x256) 32 | x_vgg = torch.cat([x64, x128, x256], dim = 1) 33 | return x_vgg -------------------------------------------------------------------------------- /src/models/networks/resnet50_cam.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from src.misc import torchutils 4 | from . 
import resnet50 5 | 6 | 7 | class Net(nn.Module): 8 | 9 | def __init__(self, n_classes=20): 10 | super(Net, self).__init__() 11 | self.n_classes = n_classes 12 | self.resnet50 = resnet50.resnet50(pretrained=True, strides=(2, 2, 2, 1)) 13 | 14 | self.stage1 = nn.Sequential(self.resnet50.conv1, self.resnet50.bn1, self.resnet50.relu, self.resnet50.maxpool, 15 | self.resnet50.layer1) 16 | self.stage2 = nn.Sequential(self.resnet50.layer2) 17 | self.stage3 = nn.Sequential(self.resnet50.layer3) 18 | self.stage4 = nn.Sequential(self.resnet50.layer4) 19 | 20 | self.classifier = nn.Conv2d(2048, n_classes, 1, bias=False) 21 | 22 | self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4]) 23 | self.newly_added = nn.ModuleList([self.classifier]) 24 | 25 | def forward(self, x): 26 | 27 | x = self.stage1(x) 28 | x = self.stage2(x).detach() 29 | 30 | x = self.stage3(x) 31 | x = self.stage4(x) 32 | 33 | x = torchutils.gap2d(x, keepdims=True) 34 | x = self.classifier(x) 35 | x = x.view(-1, self.n_classes) 36 | 37 | return x 38 | 39 | def train(self, mode=True): 40 | for p in self.resnet50.conv1.parameters(): 41 | p.requires_grad = False 42 | for p in self.resnet50.bn1.parameters(): 43 | p.requires_grad = False 44 | 45 | def trainable_parameters(self): 46 | 47 | return (list(self.backbone.parameters()), list(self.newly_added.parameters())) 48 | 49 | 50 | class CAM(Net): 51 | 52 | def __init__(self): 53 | super(CAM, self).__init__() 54 | 55 | def forward(self, x): 56 | 57 | x = self.stage1(x) 58 | 59 | x = self.stage2(x) 60 | 61 | x = self.stage3(x) 62 | 63 | x = self.stage4(x) 64 | 65 | x = F.conv2d(x, self.classifier.weight) 66 | x = F.relu(x) 67 | 68 | x = x[0] + x[1].flip(-1) 69 | 70 | return x 71 | -------------------------------------------------------------------------------- /src/models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fcn8_vgg16, fcn8_vgg16_multiscale, unet2d, unet_resnet, attu_net, fcn8_resnet, deeplab 2 | from . import resnet_seam, infnet 3 | from . 
import resnet50_cam, resnet50_irn, resnet50, fcn8_resnet 4 | from torchvision import models 5 | import torch, os 6 | import torch.nn as nn 7 | 8 | 9 | def get_network(network_name, n_classes, exp_dict): 10 | if network_name == 'infnet': 11 | model_base = infnet.InfNet(n_classes=1, loss=exp_dict['model']['loss']) 12 | 13 | if network_name == 'fcn8_vgg16_att': 14 | model_base = fcn8_vgg16.FCN8VGG16(n_classes=n_classes, with_attention=True) 15 | 16 | if network_name == 'fcn8_vgg16': 17 | model_base = fcn8_vgg16.FCN8VGG16(n_classes=n_classes, 18 | with_attention=exp_dict['model'].get('with_attention'), 19 | with_affinity=exp_dict['model'].get('with_affinity'), 20 | with_affinity_average=exp_dict['model'].get('with_affinity_average'), 21 | shared=exp_dict['model'].get('shared'), 22 | exp_dict=exp_dict 23 | ) 24 | 25 | if network_name == "fcn8_vgg16_multiscale": 26 | model_base = fcn8_vgg16_multiscale.FCN8VGG16(n_classes=n_classes) 27 | 28 | if network_name == "unet_resnet": 29 | model_base = unet_resnet.ResNetUNet(n_class=n_classes) 30 | 31 | if network_name == "resnet_seam": 32 | model_base = resnet_seam.ResNetSeam() 33 | # path_base = '/mnt/datasets/public/issam/seam' 34 | # model_base.load_state_dict(torch.load(os.path.join(path_base, 'resnet38_SEAM.pth'))) 35 | weights_dict = model_base.resnet38d.convert_mxnet_to_torch(args.weights) 36 | 37 | model.load_state_dict(weights_dict, strict=False) 38 | 39 | if network_name == "attu_net": 40 | model_base = attu_net.AttU_Net() 41 | 42 | if network_name == "resnet50_cam": 43 | return resnet50_cam.Net(n_classes=n_classes) 44 | elif network_name == "resnet50_irn": 45 | return resnet50_irn 46 | elif network_name == "fcn8_resnet": 47 | return fcn8_resnet.FCN8(n_classes) 48 | 49 | return model_base 50 | 51 | -------------------------------------------------------------------------------- /scripts/vis_habitats.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 4 | sys.path.insert(0, path) 5 | 6 | from haven import haven_chk as hc 7 | from haven import haven_results as hr 8 | from haven import haven_utils as hu 9 | import torch 10 | import torchvision 11 | import tqdm 12 | import pandas as pd 13 | import pprint 14 | import itertools 15 | import os 16 | import pylab as plt 17 | import time 18 | import numpy as np 19 | 20 | from src import models 21 | from src import datasets 22 | from src import utils as ut 23 | 24 | import argparse 25 | 26 | from torch.utils.data import sampler 27 | from torch.utils.data.sampler import RandomSampler 28 | from torch.backends import cudnn 29 | from torch.nn import functional as F 30 | from torch.utils.data import DataLoader 31 | 32 | cudnn.benchmark = True 33 | 34 | if __name__ == "__main__": 35 | savedir_base = '/mnt/public/results/toolkit/weak_supervision' 36 | hash_list = ['a55d2c5dda331b1a0e191b104406dd1c'] 37 | # LCFCN 38 | # hash_id = 'bcba046296675e9e3af5cd9f353d217b' 39 | for hash_id in hash_list: 40 | exp_dict = hu.load_json(os.path.join(savedir_base, hash_id, 'exp_dict.json')) 41 | datadir = '/mnt/public/datasets/DeepFish/' 42 | split = 'train' 43 | train_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"], 44 | split=split, 45 | datadir=datadir, 46 | exp_dict=exp_dict, 47 | dataset_size=exp_dict['dataset_size']) 48 | train_loader = DataLoader(train_set, 49 | # sampler=val_sampler, 50 | batch_size=1, 51 | collate_fn=ut.collate_fn, 52 | num_workers=0) 53 | for i, batch in 
enumerate(train_loader): 54 | points = (batch['points'].squeeze() == 1).numpy() 55 | if points.sum() == 0: 56 | continue 57 | savedir_image = os.path.join('.tmp/habitats/%s/%d.png' % (batch['meta'][0]['habitat'], i)) 58 | img = hu.denormalize(batch['images'], mode='rgb') 59 | # img_pred = model.predict_on_batch(batch) 60 | hu.save_image(savedir_image, img, points=points, radius=1) -------------------------------------------------------------------------------- /src/misc/torchutils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from torch.utils.data import Subset 5 | import numpy as np 6 | import math 7 | 8 | 9 | class PolyOptimizer(torch.optim.SGD): 10 | 11 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9): 12 | super().__init__(params, lr, weight_decay) 13 | 14 | self.global_step = 0 15 | self.max_step = max_step 16 | self.momentum = momentum 17 | 18 | self.__initial_lr = [group['lr'] for group in self.param_groups] 19 | 20 | 21 | def step(self, closure=None): 22 | 23 | if self.global_step < self.max_step: 24 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 25 | 26 | for i in range(len(self.param_groups)): 27 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 28 | 29 | super().step(closure) 30 | 31 | self.global_step += 1 32 | 33 | class SGDROptimizer(torch.optim.SGD): 34 | 35 | def __init__(self, params, steps_per_epoch, lr=0, weight_decay=0, epoch_start=1, restart_mult=2): 36 | super().__init__(params, lr, weight_decay) 37 | 38 | self.global_step = 0 39 | self.local_step = 0 40 | self.total_restart = 0 41 | 42 | self.max_step = steps_per_epoch * epoch_start 43 | self.restart_mult = restart_mult 44 | 45 | self.__initial_lr = [group['lr'] for group in self.param_groups] 46 | 47 | 48 | def step(self, closure=None): 49 | 50 | if self.local_step >= self.max_step: 51 | self.local_step = 0 52 | self.max_step *= self.restart_mult 53 | self.total_restart += 1 54 | 55 | lr_mult = (1 + math.cos(math.pi * self.local_step / self.max_step))/2 / (self.total_restart + 1) 56 | 57 | for i in range(len(self.param_groups)): 58 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 59 | 60 | super().step(closure) 61 | 62 | self.local_step += 1 63 | self.global_step += 1 64 | 65 | 66 | def split_dataset(dataset, n_splits): 67 | 68 | return [Subset(dataset, np.arange(i, len(dataset), n_splits)) for i in range(n_splits)] 69 | 70 | 71 | def gap2d(x, keepdims=False): 72 | out = torch.mean(x.view(x.size(0), x.size(1), -1), -1) 73 | if keepdims: 74 | out = out.view(out.size(0), out.size(1), 1, 1) 75 | 76 | return out 77 | -------------------------------------------------------------------------------- /src/modules/imantics/color.py: -------------------------------------------------------------------------------- 1 | import random as rand 2 | import numpy as np 3 | import colorsys 4 | 5 | 6 | class Color: 7 | 8 | @classmethod 9 | def create(cls, color): 10 | """ 11 | Creates color class 12 | 13 | string - generates color from hex 14 | tuple and values between [0, 1] - generates from hls 15 | tuple and values between [0, 255] - generates from rgb 16 | 17 | :param color: tuple, list, str 18 | :returns: color class 19 | """ 20 | if isinstance(color, str): 21 | return cls(hex=color) 22 | 23 | if isinstance(color, (list, tuple)): 24 | if np.any(np.array(color) > 1): 25 | return cls(rgb=color) 26 | return cls(hls=color) 27 | 28 | if isinstance(color, Color): 29 | return color 30 | 31 | return 
cls().random() 32 | 33 | @classmethod 34 | def random(cls, h=(0, 1), l=(0.35,0.70), s=(0.6, 1)): 35 | """ 36 | Generates a random color 37 | 38 | :param l: range for lightness 39 | :type l: tuple 40 | :param h: range for hue 41 | :type h: tuple 42 | :param s: range for saturation 43 | :type s: tuple 44 | :returns: randomly generated color 45 | :rtype: :class:`Color` 46 | """ 47 | h = rand.uniform(h[0], h[1]) 48 | l = rand.uniform(l[0], l[1]) 49 | s = rand.uniform(s[0], s[1]) 50 | return cls(hls=(h, l, s)) 51 | 52 | def __init__(self, hls=None, rgb=None, hex=None): 53 | self._hls = hls 54 | self._rgb = rgb 55 | self._hex = hex 56 | 57 | @property 58 | def hex(self): 59 | """ 60 | Hex representation of color 61 | """ 62 | if not self._hex: 63 | r, g, b = self.rgb 64 | self._hex = '#%02x%02x%02x' % (r, g, b) 65 | 66 | return self._hex 67 | 68 | @property 69 | def hls(self): 70 | """ 71 | HLS representation of color 72 | """ 73 | if not self._hls: 74 | self._hls = colorsys.rgb_to_hls(*[i/255 for i in self.rgb]) 75 | 76 | return self._hls 77 | 78 | @property 79 | def rgb(self): 80 | """ 81 | RGB representation of color 82 | """ 83 | if not self._rgb: 84 | if self._hex: 85 | h = self.hex.lstrip('#') 86 | self._rgb = tuple(int(h[i:i + 2], 16) for i in (0, 2, 4)) 87 | else: 88 | self._rgb = [int(i*255) for i in colorsys.hls_to_rgb(*self.hls)] 89 | 90 | return self._rgb 91 | 92 | 93 | __all__ = ['Color'] 94 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch 3 | import numpy as np 4 | import torch 5 | from torch.autograd import Variable 6 | import os 7 | import argparse 8 | from datetime import datetime 9 | import torch.nn.functional as F 10 | 11 | 12 | 13 | def joint_loss(pred, mask): 14 | weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 15 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 16 | wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 17 | 18 | pred = torch.sigmoid(pred) 19 | inter = ((pred * mask)*weit).sum(dim=(2, 3)) 20 | union = ((pred + mask)*weit).sum(dim=(2, 3)) 21 | wiou = 1 - (inter + 1)/(union - inter+1) 22 | return (wbce + wiou).mean() 23 | 24 | def joint_loss_flat(pred, mask, roi_mask=None): 25 | W = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask).squeeze() 26 | L = F.binary_cross_entropy_with_logits(pred, mask, reduction='none').squeeze() 27 | P = torch.sigmoid(pred).squeeze() 28 | M = mask.squeeze() 29 | 30 | if roi_mask is not None: 31 | W = W[roi_mask] 32 | L = L[roi_mask] 33 | P = P[roi_mask] 34 | M = M[roi_mask] 35 | 36 | # Sum them up 37 | WL = (W*L).sum() / W.sum() 38 | I = ((P * M)*W).sum() 39 | U = ((P + M)*W).sum() 40 | 41 | # Compute Weighted IoU 42 | WIoU = 1 - (I + 1)/(U - I+1) 43 | 44 | return (WL + WIoU).mean() 45 | 46 | 47 | 48 | def clip_gradient(optimizer, grad_clip): 49 | """ 50 | For calibrating mis-alignment gradient via cliping gradient technique 51 | :param optimizer: 52 | :param grad_clip: 53 | :return: 54 | """ 55 | for group in optimizer.param_groups: 56 | for param in group['params']: 57 | if param.grad is not None: 58 | param.grad.data.clamp_(-grad_clip, grad_clip) 59 | 60 | 61 | def adjust_lr(optimizer, epoch, decay_rate=0.1, decay_epoch=30): 62 | decay = decay_rate ** (epoch // decay_epoch) 63 | for param_group in optimizer.param_groups: 64 | param_group['lr'] *= decay 65 | 66 | 67 | 
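# Collate function used by the DataLoader calls in the scripts, e.g.
#   DataLoader(dataset, batch_size=1, collate_fn=ut.collate_fn, num_workers=0)
# Each field is first gathered into a list; images, masks, points and edges are then stacked into batch tensors when present.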
def collate_fn(batch): 68 | batch_dict = {} 69 | for k in batch[0]: 70 | batch_dict[k] = [] 71 | for i in range(len(batch)): 72 | 73 | batch_dict[k] += [batch[i][k]] 74 | # tuple(zip(*batch)) 75 | batch_dict['images'] = torch.stack(batch_dict['images']) 76 | if 'masks' in batch_dict: 77 | batch_dict['masks'] = torch.stack(batch_dict['masks']) 78 | if 'points' in batch_dict: 79 | batch_dict['points'] = torch.stack(batch_dict['points']) 80 | if 'edges' in batch_dict: 81 | batch_dict['edges'] = torch.stack(batch_dict['edges']) 82 | 83 | return batch_dict 84 | -------------------------------------------------------------------------------- /src/models/networks/detr/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | 44 | The boxes should be in [x0, y0, x1, y1] format 45 | 46 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 47 | and M = len(boxes2) 48 | """ 49 | # degenerate boxes gives inf / nan results 50 | # so do an early check 51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 53 | iou, union = box_iou(boxes1, boxes2) 54 | 55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 57 | 58 | wh = (rb - lt).clamp(min=0) # [N,M,2] 59 | area = wh[:, :, 0] * wh[:, :, 1] 60 | 61 | return iou - (area - union) / area 62 | 63 | 64 | def masks_to_boxes(masks): 65 | """Compute the bounding boxes around the provided masks 66 | 67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
68 | 69 | Returns a [N, 4] tensors, with the boxes in xyxy format 70 | """ 71 | if masks.numel() == 0: 72 | return torch.zeros((0, 4), device=masks.device) 73 | 74 | h, w = masks.shape[-2:] 75 | 76 | y = torch.arange(0, h, dtype=torch.float) 77 | x = torch.arange(0, w, dtype=torch.float) 78 | y, x = torch.meshgrid(y, x) 79 | 80 | x_mask = (masks * x.unsqueeze(0)) 81 | x_max = x_mask.flatten(1).max(-1)[0] 82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | y_mask = (masks * y.unsqueeze(0)) 85 | y_max = y_mask.flatten(1).max(-1)[0] 86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 87 | 88 | return torch.stack([x_min, y_min, x_max, y_max], 1) 89 | -------------------------------------------------------------------------------- /scripts/save_latex.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 4 | sys.path.insert(0, path) 5 | 6 | 7 | 8 | hash_dct = {'b04090f27c7c52bcec65f6ba455ed2d8': 'Fully_Supervised', 9 | '6d4af38d64b23586e71a198de2608333': 'LCFCN', 10 | '84ced18cf5c1fb3ad5820cc1b55a38fa': 'LCFCN+Affinity_(ours)', 11 | '63f29eec3dbe1e03364f198ed7d4b414': 'Point-level_Loss ', 12 | '017e7441c2f581b6fee9e3ac6f574edc': 'Cross_entropy_Loss+pseudo-mask'} 13 | 14 | def make_latex_count(csv_path, column_name): 15 | import os 16 | import pandas as pd 17 | from tqdm import tqdm 18 | all_df = [] 19 | paths = (os.listdir(csv_path)) 20 | for file in tqdm(paths): 21 | if file.endswith('_latex.csv'): 22 | continue 23 | if file.endswith('.csv'): 24 | file_path = '{}/{}'.format(csv_path, file) 25 | print(file_path) 26 | org_DF = pd.read_csv(file_path).round(decimals=3) 27 | org_DF = org_DF[column_name] 28 | org_DF['Loss Function'] = hash_dct[file.split("_")[0]] 29 | all_df.append(org_DF) 30 | concat_df = pd.concat(all_df, axis=1) 31 | concat_df = concat_df.transpose() 32 | concat_df = concat_df[['Loss Function', 0, 1, 2]] 33 | concat_df.to_csv(os.path.join(csv_path , "%s_latex.csv"%column_name), index=False) 34 | concat_df.to_latex(os.path.join(csv_path , "%s_latex.tex"%column_name), 35 | index=False, caption=column_name, label=column_name) 36 | print(concat_df) 37 | 38 | 39 | def make_latex_habitat(csv_path, column_name): 40 | import os 41 | import pandas as pd 42 | from tqdm import tqdm 43 | all_df = [] 44 | habitats = [] 45 | paths = (os.listdir(csv_path)) 46 | for file in tqdm(paths): 47 | if file.endswith('_latex.csv'): 48 | continue 49 | if file.endswith('.csv'): 50 | file_path = '{}/{}'.format(csv_path, file) 51 | print(file_path) 52 | org_DF = pd.read_csv(file_path).round(decimals=3) 53 | habitats = org_DF["Habitat"] 54 | org_DF = org_DF[column_name] 55 | org_DF['Loss Function'] = hash_dct[file.split("_")[0]] 56 | all_df.append(org_DF) 57 | concat_df = pd.concat(all_df, axis=1, ignore_index= True ) 58 | concat_df.insert(0, "habitat", habitats, True) 59 | concat_df = concat_df.transpose() 60 | cols = list(concat_df.columns) 61 | cols = [cols[-1]] + cols[:-1] 62 | concat_df = concat_df[cols] 63 | concat_df.to_csv(os.path.join(csv_path , "%s_latex.csv"%column_name), index=False) 64 | concat_df.to_latex(os.path.join(csv_path , "%s_latex.tex"%column_name), 65 | index=False, caption=column_name, label=column_name) 66 | print(concat_df) 67 | 68 | 69 | if __name__ == '__main__': 70 | fish = '/mnt/public/predictions/fish/' 71 | habitat = '/mnt/public/predictions/habitat/' 72 | make_latex_count(fish, "IoU class 1") 73 | 
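    # same IoU table, broken down per habitat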
make_latex_habitat(habitat, "IoU class 1") 74 | -------------------------------------------------------------------------------- /src/modules/sstransforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | def norm_grid(grid): 4 | B, H, W, C = grid.shape 5 | grid -= grid.view(-1, C).min(0).view(1, 1, 1, C) 6 | grid /= grid.view(-1, C).max(0).view(1, 1, 1, C) 7 | grid = (grid - 0.5) * 2 8 | return grid 9 | 10 | def get_grid(shape, normalized=False): 11 | B, C, H, W = shape 12 | grid_x, grid_y = torch.meshgrid(torch.arange(H), torch.arange(W)) 13 | grid_x = grid_x.float().cuda() 14 | grid_y = grid_y.float().cuda() 15 | indices = torch.stack([grid_y, grid_x], -1).view(1, H, W, 2).expand(B, H, W, 2).contiguous() 16 | if normalized: 17 | indices = norm_grid(indices) 18 | return indices 19 | 20 | def get_elastic(grid, sigma, alpha): 21 | B, H, W, C = grid.shape 22 | sigma=self.exp_dict["model"]["sigma"] 23 | alpha=self.exp_dict["model"]["alpha"] 24 | dx = gaussian_filter((np.random.rand(B, H, W) * 2 - 1), sigma, mode="constant", cval=0) * alpha 25 | dy = gaussian_filter((np.random.rand(B, H, W) * 2 - 1), sigma, mode="constant", cval=0) * alpha 26 | dx = torch.from_numpy(dx).cuda().float() 27 | dy = torch.from_numpy(dy).cuda().float() 28 | dgrid_x = grid_x + dx 29 | dgrid_y = grid_y + dy 30 | dindices = torch.stack([dgrid_y, dgrid_x], -1).view(1, H, W, 2).expand(B, H, W, 2).contiguous() 31 | # grid = get_grid(images.shape) 32 | grid += dindices 33 | grid = norm_grid(dindices) 34 | return dindices 35 | 36 | def get_flip(grid, axis=1, random=True): 37 | if random: 38 | flips = torch.randint(low=0, high=2, size=(grid.size(0), 1, 1, 1), device=grid.device) 39 | flips *= 2 40 | flips -= 1 41 | else: 42 | flips = -1 43 | grid[..., axis] *= flips 44 | return grid 45 | 46 | 47 | def batch_rotation(grid, rots): 48 | ret = [] 49 | for i, rot in enumerate(rots): 50 | ret.append(grid[i, ...].rot90(-int(rot // 90), [1,2])) 51 | return torch.stack(ret, 0) 52 | 53 | def get_rotation(images, rot): 54 | if rot == 0: # 0 degrees rotation 55 | return grid 56 | elif rot == 90: # 90 degrees rotation 57 | return get_flip(grid.permute(0, 1, 3, 2), axis=0, random=False).contiguous() 58 | elif rot == 180: # 90 degrees rotation 59 | return get_flip(get_flip(grid, 0, False), 1, False) 60 | elif rot == 270: # 270 degrees rotation / or -90 61 | return get_flip(grid, 0, False).permute(0, 1, 3, 2).contiguous() 62 | else: 63 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees') 64 | # def get_rotation(grid, rot): 65 | # if rot == 0: # 0 degrees rotation 66 | # return grid 67 | # elif rot == 90: # 90 degrees rotation 68 | # return get_flip(grid.permute(0, 1, 3, 2), axis=0, random=False).contiguous() 69 | # elif rot == 180: # 90 degrees rotation 70 | # return get_flip(get_flip(grid, 0, False), 1, False) 71 | # elif rot == 270: # 270 degrees rotation / or -90 72 | # return get_flip(grid, 0, False).permute(0, 1, 3, 2).contiguous() 73 | # else: 74 | # raise ValueError('rotation should be 0, 90, 180, or 270 degrees') -------------------------------------------------------------------------------- /src/misc/pyutils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import time 4 | import sys 5 | 6 | class Logger(object): 7 | def __init__(self, outfile): 8 | self.terminal = sys.stdout 9 | self.log = open(outfile, "w") 10 | sys.stdout = self 11 | 12 | def write(self, message): 13 | 
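        # Logger(outfile) rebinds sys.stdout, so every print is mirrored to both the console and the log file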
self.terminal.write(message) 14 | self.log.write(message) 15 | 16 | def flush(self): 17 | self.terminal.flush() 18 | 19 | 20 | class AverageMeter: 21 | def __init__(self, *keys): 22 | self.__data = dict() 23 | for k in keys: 24 | self.__data[k] = [0.0, 0] 25 | 26 | def add(self, dict): 27 | for k, v in dict.items(): 28 | if k not in self.__data: 29 | self.__data[k] = [0.0, 0] 30 | self.__data[k][0] += v 31 | self.__data[k][1] += 1 32 | 33 | def get(self, *keys): 34 | if len(keys) == 1: 35 | return self.__data[keys[0]][0] / self.__data[keys[0]][1] 36 | else: 37 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] 38 | return tuple(v_list) 39 | 40 | def pop(self, key=None): 41 | if key is None: 42 | for k in self.__data.keys(): 43 | self.__data[k] = [0.0, 0] 44 | else: 45 | v = self.get(key) 46 | self.__data[key] = [0.0, 0] 47 | return v 48 | 49 | 50 | class Timer: 51 | def __init__(self, starting_msg = None): 52 | self.start = time.time() 53 | self.stage_start = self.start 54 | 55 | if starting_msg is not None: 56 | print(starting_msg, time.ctime(time.time())) 57 | 58 | def __enter__(self): 59 | return self 60 | 61 | def __exit__(self, exc_type, exc_val, exc_tb): 62 | return 63 | 64 | def update_progress(self, progress): 65 | self.elapsed = time.time() - self.start 66 | self.est_total = self.elapsed / progress 67 | self.est_remaining = self.est_total - self.elapsed 68 | self.est_finish = int(self.start + self.est_total) 69 | 70 | 71 | def str_estimated_complete(self): 72 | return str(time.ctime(self.est_finish)) 73 | 74 | def get_stage_elapsed(self): 75 | return time.time() - self.stage_start 76 | 77 | def reset_stage(self): 78 | self.stage_start = time.time() 79 | 80 | def lapse(self): 81 | out = time.time() - self.stage_start 82 | self.stage_start = time.time() 83 | return out 84 | 85 | 86 | def to_one_hot(sparse_integers, maximum_val=None, dtype=np.bool): 87 | 88 | if maximum_val is None: 89 | maximum_val = np.max(sparse_integers) + 1 90 | 91 | src_shape = sparse_integers.shape 92 | 93 | flat_src = np.reshape(sparse_integers, [-1]) 94 | src_size = flat_src.shape[0] 95 | 96 | one_hot = np.zeros((maximum_val, src_size), dtype) 97 | one_hot[flat_src, np.arange(src_size)] = 1 98 | 99 | one_hot = np.reshape(one_hot, [maximum_val] + list(src_shape)) 100 | 101 | return one_hot 102 | -------------------------------------------------------------------------------- /src/models/networks/unet_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | def convrelu(in_channels, out_channels, kernel, padding): 6 | return nn.Sequential( 7 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding), 8 | nn.ReLU(inplace=True), 9 | ) 10 | 11 | 12 | class ResNetUNet(nn.Module): 13 | def __init__(self, n_class): 14 | super().__init__() 15 | 16 | self.base_model = models.resnet18(pretrained=True) 17 | self.base_layers = list(self.base_model.children()) 18 | 19 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) 20 | self.layer0_1x1 = convrelu(64, 64, 1, 0) 21 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) 22 | self.layer1_1x1 = convrelu(64, 64, 1, 0) 23 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) 24 | self.layer2_1x1 = convrelu(128, 128, 1, 0) 25 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) 26 | self.layer3_1x1 = convrelu(256, 256, 1, 0) 27 | self.layer4 = 
self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) 28 | self.layer4_1x1 = convrelu(512, 512, 1, 0) 29 | 30 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 31 | 32 | self.conv_up3 = convrelu(256 + 512, 512, 3, 1) 33 | self.conv_up2 = convrelu(128 + 512, 256, 3, 1) 34 | self.conv_up1 = convrelu(64 + 256, 256, 3, 1) 35 | self.conv_up0 = convrelu(64 + 256, 128, 3, 1) 36 | 37 | self.conv_original_size0 = convrelu(3, 64, 3, 1) 38 | self.conv_original_size1 = convrelu(64, 64, 3, 1) 39 | self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1) 40 | 41 | self.conv_last = nn.Conv2d(64, n_class, 1) 42 | 43 | # # FREEZE BATCH NORMS 44 | for m in self.modules(): 45 | if isinstance(m, nn.BatchNorm2d): 46 | m.weight.requires_grad = False 47 | m.bias.requires_grad = False 48 | 49 | def forward(self, input): 50 | x_original = self.conv_original_size0(input) 51 | x_original = self.conv_original_size1(x_original) 52 | 53 | layer0 = self.layer0(input) 54 | layer1 = self.layer1(layer0) 55 | layer2 = self.layer2(layer1) 56 | layer3 = self.layer3(layer2) 57 | layer4 = self.layer4(layer3) 58 | 59 | layer4 = self.layer4_1x1(layer4) 60 | x = self.upsample(layer4) 61 | layer3 = self.layer3_1x1(layer3) 62 | x = torch.cat([x, layer3], dim=1) 63 | x = self.conv_up3(x) 64 | 65 | x = self.upsample(x) 66 | layer2 = self.layer2_1x1(layer2) 67 | x = torch.cat([x, layer2], dim=1) 68 | x = self.conv_up2(x) 69 | 70 | x = self.upsample(x) 71 | layer1 = self.layer1_1x1(layer1) 72 | x = torch.cat([x, layer1], dim=1) 73 | x = self.conv_up1(x) 74 | 75 | x = self.upsample(x) 76 | layer0 = self.layer0_1x1(layer0) 77 | x = torch.cat([x, layer0], dim=1) 78 | x = self.conv_up0(x) 79 | 80 | x = self.upsample(x) 81 | x = torch.cat([x, x_original], dim=1) 82 | x = self.conv_original_size2(x) 83 | 84 | out = self.conv_last(x) 85 | 86 | return out -------------------------------------------------------------------------------- /scripts/game_results.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 4 | sys.path.insert(0, path) 5 | 6 | from haven import haven_chk as hc 7 | from haven import haven_results as hr 8 | from haven import haven_utils as hu 9 | import torch 10 | import torchvision 11 | import tqdm 12 | import pandas as pd 13 | import pprint 14 | import itertools 15 | import os 16 | import pylab as plt 17 | import time 18 | import numpy as np 19 | 20 | from src import models 21 | from src import datasets 22 | from src import utils as ut 23 | 24 | import argparse 25 | 26 | from torch.utils.data import sampler 27 | from torch.utils.data.sampler import RandomSampler 28 | from torch.backends import cudnn 29 | from torch.nn import functional as F 30 | from torch.utils.data import DataLoader 31 | import pandas as pd 32 | cudnn.benchmark = True 33 | 34 | if __name__ == "__main__": 35 | savedir_base = '/mnt/public/results/toolkit/weak_supervision' 36 | 37 | hash_list = ['b04090f27c7c52bcec65f6ba455ed2d8', 38 | '6d4af38d64b23586e71a198de2608333', 39 | '84ced18cf5c1fb3ad5820cc1b55a38fa', 40 | '63f29eec3dbe1e03364f198ed7d4b414', 41 | '017e7441c2f581b6fee9e3ac6f574edc'] 42 | datadir = '/mnt/public/datasets/DeepFish/' 43 | 44 | score_list = [] 45 | for hash_id in hash_list: 46 | fname = os.path.join('/mnt/public/predictions/game/%s.pkl' % hash_id) 47 | exp_dict = hu.load_json(os.path.join(savedir_base, hash_id, 'exp_dict.json')) 48 | if os.path.exists(fname): 49 | print('FOUND:', fname) 
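            # load the cached scores instead of re-running evaluation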
50 | val_dict = hu.load_pkl(fname) 51 | else: 52 | 53 | train_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"], 54 | split='train', 55 | datadir=datadir, 56 | exp_dict=exp_dict, 57 | dataset_size=exp_dict['dataset_size']) 58 | 59 | test_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"], 60 | split='test', 61 | datadir=datadir, 62 | exp_dict=exp_dict, 63 | dataset_size=exp_dict['dataset_size']) 64 | 65 | test_loader = DataLoader(test_set, 66 | batch_size=1, 67 | collate_fn=ut.collate_fn, 68 | num_workers=0) 69 | pprint.pprint(exp_dict) 70 | # Model 71 | # ================== 72 | model = models.get_model(model_dict=exp_dict['model'], 73 | exp_dict=exp_dict, 74 | train_set=train_set).cuda() 75 | 76 | model_path = os.path.join(savedir_base, hash_id, 'model_best.pth') 77 | 78 | # load best model 79 | model.load_state_dict(hu.torch_load(model_path)) 80 | val_dict = model.val_on_loader(test_loader) 81 | 82 | val_dict['hash_id'] = hash_id 83 | pprint.pprint(val_dict) 84 | 85 | hu.save_pkl(fname, val_dict) 86 | 87 | val_dict['model'] = exp_dict['model'] 88 | score_list += [val_dict] 89 | 90 | print(pd.DataFrame(score_list)) 91 | 92 | 93 | -------------------------------------------------------------------------------- /scripts/SEAM/network/resnet38_SEAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.sparse as sparse 4 | import torch.nn.functional as F 5 | import numpy as np 6 | np.set_printoptions(threshold=np.inf) 7 | 8 | from . import resnet38d 9 | 10 | class Net(resnet38d.Net): 11 | def __init__(self): 12 | super(Net, self).__init__() 13 | self.dropout7 = torch.nn.Dropout2d(0.5) 14 | 15 | self.fc8 = nn.Conv2d(4096, 21, 1, bias=False) 16 | 17 | self.f8_3 = torch.nn.Conv2d(512, 64, 1, bias=False) 18 | self.f8_4 = torch.nn.Conv2d(1024, 128, 1, bias=False) 19 | self.f9 = torch.nn.Conv2d(192+3, 192, 1, bias=False) 20 | 21 | torch.nn.init.xavier_uniform_(self.fc8.weight) 22 | torch.nn.init.kaiming_normal_(self.f8_3.weight) 23 | torch.nn.init.kaiming_normal_(self.f8_4.weight) 24 | torch.nn.init.xavier_uniform_(self.f9.weight, gain=4) 25 | self.from_scratch_layers = [self.f8_3, self.f8_4, self.f9, self.fc8] 26 | self.not_training = [self.conv1a, self.b2, self.b2_1, self.b2_2] 27 | 28 | def forward(self, x): 29 | N, C, H, W = x.size() 30 | d = super().forward_as_dict(x) 31 | cam = self.fc8(self.dropout7(d['conv6'])) 32 | n,c,h,w = cam.size() 33 | with torch.no_grad(): 34 | cam_d = F.relu(cam.detach()) 35 | cam_d_max = torch.max(cam_d.view(n,c,-1), dim=-1)[0].view(n,c,1,1)+1e-5 36 | cam_d_norm = F.relu(cam_d-1e-5)/cam_d_max 37 | cam_d_norm[:,0,:,:] = 1-torch.max(cam_d_norm[:,1:,:,:], dim=1)[0] 38 | cam_max = torch.max(cam_d_norm[:,1:,:,:], dim=1, keepdim=True)[0] 39 | cam_d_norm[:,1:,:,:][cam_d_norm[:,1:,:,:] < cam_max] = 0 40 | 41 | f8_3 = F.relu(self.f8_3(d['conv4'].detach()), inplace=True) 42 | f8_4 = F.relu(self.f8_4(d['conv5'].detach()), inplace=True) 43 | x_s = F.interpolate(x,(h,w),mode='bilinear',align_corners=True) 44 | f = torch.cat([x_s, f8_3, f8_4], dim=1) 45 | n,c,h,w = f.size() 46 | 47 | cam_rv = F.interpolate(self.PCM(cam_d_norm, f), (H,W), mode='bilinear', align_corners=True) 48 | cam = F.interpolate(cam, (H,W), mode='bilinear', align_corners=True) 49 | return cam, cam_rv 50 | 51 | def PCM(self, cam, f): 52 | n,c,h,w = f.size() 53 | cam = F.interpolate(cam, (h,w), mode='bilinear', align_corners=True).view(n,-1,h*w) 54 | f = self.f9(f) 55 | f = f.view(n,-1,h*w) 56 | f = 
f/(torch.norm(f,dim=1,keepdim=True)+1e-5) 57 | 58 | aff = F.relu(torch.matmul(f.transpose(1,2), f),inplace=True) 59 | aff = aff/(torch.sum(aff,dim=1,keepdim=True)+1e-5) 60 | cam_rv = torch.matmul(cam, aff).view(n,-1,h,w) 61 | 62 | return cam_rv 63 | 64 | def get_parameter_groups(self): 65 | groups = ([], [], [], []) 66 | print('======================================================') 67 | for m in self.modules(): 68 | 69 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.modules.normalization.GroupNorm)): 70 | 71 | if m.weight.requires_grad: 72 | if m in self.from_scratch_layers: 73 | groups[2].append(m.weight) 74 | else: 75 | groups[0].append(m.weight) 76 | 77 | if m.bias is not None and m.bias.requires_grad: 78 | if m in self.from_scratch_layers: 79 | groups[3].append(m.bias) 80 | else: 81 | groups[1].append(m.bias) 82 | 83 | return groups 84 | 85 | -------------------------------------------------------------------------------- /src/models/networks/resnet_seam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.sparse as sparse 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from . import resnet38d 8 | 9 | class ResNetSeam(resnet38d.Net): 10 | def __init__(self, n_classes=2): 11 | super().__init__() 12 | self.dropout7 = torch.nn.Dropout2d(0.5) 13 | 14 | self.fc8 = nn.Conv2d(4096, n_classes, 1, bias=False) 15 | 16 | self.f8_3 = torch.nn.Conv2d(512, 64, 1, bias=False) 17 | self.f8_4 = torch.nn.Conv2d(1024, 128, 1, bias=False) 18 | self.f9 = torch.nn.Conv2d(192+3, 192, 1, bias=False) 19 | 20 | torch.nn.init.xavier_uniform_(self.fc8.weight) 21 | torch.nn.init.kaiming_normal_(self.f8_3.weight) 22 | torch.nn.init.kaiming_normal_(self.f8_4.weight) 23 | torch.nn.init.xavier_uniform_(self.f9.weight, gain=4) 24 | self.from_scratch_layers = [self.f8_3, self.f8_4, self.f9, self.fc8] 25 | self.not_training = [self.conv1a, self.b2, self.b2_1, self.b2_2] 26 | 27 | def forward(self, x, train=False): 28 | N, C, H, W = x.size() 29 | d = super().forward_as_dict(x) 30 | cam = self.fc8(self.dropout7(d['conv6'])) 31 | n,c,h,w = cam.size() 32 | with torch.no_grad(): 33 | cam_d = F.relu(cam.detach()) 34 | cam_d_max = torch.max(cam_d.view(n,c,-1), dim=-1)[0].view(n,c,1,1)+1e-5 35 | cam_d_norm = F.relu(cam_d-1e-5)/cam_d_max 36 | cam_d_norm[:,0,:,:] = 1-torch.max(cam_d_norm[:,1:,:,:], dim=1)[0] 37 | cam_max = torch.max(cam_d_norm[:,1:,:,:], dim=1, keepdim=True)[0] 38 | cam_d_norm[:,1:,:,:][cam_d_norm[:,1:,:,:] < cam_max] = 0 39 | 40 | f8_3 = F.relu(self.f8_3(d['conv4'].detach()), inplace=True) 41 | f8_4 = F.relu(self.f8_4(d['conv5'].detach()), inplace=True) 42 | x_s = F.interpolate(x,(h,w),mode='bilinear',align_corners=True) 43 | f = torch.cat([x_s, f8_3, f8_4], dim=1) 44 | n,c,h,w = f.size() 45 | 46 | # cam_rv = F.interpolate(self.PCM(cam_d_norm, f), (H,W), mode='bilinear', align_corners=True) 47 | # if train == False: 48 | # return cam_rv 49 | 50 | cam = F.interpolate(cam, (H,W), mode='bilinear', align_corners=True) 51 | return cam 52 | 53 | def PCM(self, cam, f): 54 | n,c,h,w = f.size() 55 | cam = F.interpolate(cam, (h,w), mode='bilinear', align_corners=True).view(n,-1,h*w) 56 | f = self.f9(f) 57 | f = f.view(n,-1,h*w) 58 | f = f/(torch.norm(f,dim=1,keepdim=True)+1e-5) 59 | 60 | aff = F.relu(torch.matmul(f.transpose(1,2), f),inplace=True) 61 | aff = aff/(torch.sum(aff,dim=1,keepdim=True)+1e-5) 62 | cam_rv = torch.matmul(cam, aff).view(n,-1,h,w) 63 | 64 | return cam_rv 65 | 66 | def 
get_parameter_groups(self): 67 | groups = ([], [], [], []) 68 | print('======================================================') 69 | for m in self.modules(): 70 | 71 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.modules.normalization.GroupNorm)): 72 | 73 | if m.weight.requires_grad: 74 | if m in self.from_scratch_layers: 75 | groups[2].append(m.weight) 76 | else: 77 | groups[0].append(m.weight) 78 | 79 | if m.bias is not None and m.bias.requires_grad: 80 | if m in self.from_scratch_layers: 81 | groups[3].append(m.bias) 82 | else: 83 | groups[1].append(m.bias) 84 | 85 | return groups 86 | -------------------------------------------------------------------------------- /src/models/networks/detr/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from .util.misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, tensor_list: NestedTensor): 29 | x = tensor_list.tensors 30 | mask = tensor_list.mask 31 | assert mask is not None 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 47 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 48 | return pos 49 | 50 | 51 | class PositionEmbeddingLearned(nn.Module): 52 | """ 53 | Absolute pos embedding, learned. 
54 | """ 55 | def __init__(self, num_pos_feats=256): 56 | super().__init__() 57 | self.row_embed = nn.Embedding(50, num_pos_feats) 58 | self.col_embed = nn.Embedding(50, num_pos_feats) 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | nn.init.uniform_(self.row_embed.weight) 63 | nn.init.uniform_(self.col_embed.weight) 64 | 65 | def forward(self, tensor_list: NestedTensor): 66 | x = tensor_list.tensors 67 | h, w = x.shape[-2:] 68 | i = torch.arange(w, device=x.device) 69 | j = torch.arange(h, device=x.device) 70 | x_emb = self.col_embed(i) 71 | y_emb = self.row_embed(j) 72 | pos = torch.cat([ 73 | x_emb.unsqueeze(0).repeat(h, 1, 1), 74 | y_emb.unsqueeze(1).repeat(1, w, 1), 75 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 76 | return pos 77 | 78 | 79 | def build_position_encoding(args): 80 | N_steps = args.hidden_dim // 2 81 | if args.position_embedding in ('v2', 'sine'): 82 | # TODO find a better way of exposing other arguments 83 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 84 | elif args.position_embedding in ('v3', 'learned'): 85 | position_embedding = PositionEmbeddingLearned(N_steps) 86 | else: 87 | raise ValueError(f"not supported {args.position_embedding}") 88 | 89 | return position_embedding 90 | -------------------------------------------------------------------------------- /scripts/qualitative.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 4 | sys.path.insert(0, path) 5 | 6 | from haven import haven_chk as hc 7 | from haven import haven_results as hr 8 | from haven import haven_utils as hu 9 | import torch 10 | import torchvision 11 | import tqdm 12 | import pandas as pd 13 | import pprint 14 | import itertools 15 | import os 16 | import pylab as plt 17 | import time 18 | import numpy as np 19 | 20 | from src import models 21 | from src import datasets 22 | from src import utils as ut 23 | 24 | import argparse 25 | 26 | from torch.utils.data import sampler 27 | from torch.utils.data.sampler import RandomSampler 28 | from torch.backends import cudnn 29 | from torch.nn import functional as F 30 | from torch.utils.data import DataLoader 31 | 32 | cudnn.benchmark = True 33 | 34 | if __name__ == "__main__": 35 | savedir_base = '/mnt/public/results/toolkit/weak_supervision' 36 | # hash_list = [] 37 | datadir = '/mnt/public/datasets/DeepFish/' 38 | # on localiization 39 | hash_list = [#Point loss 40 | '63f29eec3dbe1e03364f198ed7d4b414', 41 | # LCFCN 42 | 'a55d2c5dda331b1a0e191b104406dd1c', 43 | #A-LCFCN 44 | '13b0f4e395b6dc5368f7965c20e75612', 45 | # A-LCFCN+PM 46 | 'fcc1acac9ff5c2fa776d65ac76c3892b'] 47 | 48 | main_hash = 'fcc1acac9ff5c2fa776d65ac76c3892b' 49 | exp_dict = hu.load_json(os.path.join(savedir_base, main_hash, 'exp_dict.json')) 50 | exp_dict['count_mode'] = 0 51 | test_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"], 52 | split='test', 53 | datadir=datadir, 54 | exp_dict=exp_dict, 55 | dataset_size=exp_dict['dataset_size']) 56 | test_loader = DataLoader(test_set, 57 | # sampler=val_sampler, 58 | batch_size=1, 59 | collate_fn=ut.collate_fn, 60 | num_workers=0) 61 | 62 | for i, batch in enumerate(test_loader): 63 | points = (batch['points'].squeeze() == 1).numpy() 64 | if points.sum() == 0: 65 | continue 66 | savedir_image = os.path.join('.tmp/qualitative/%d.png' % (i)) 67 | img = hu.denormalize(batch['images'], mode='rgb') 68 | img_org = 
np.array(hu.save_image(savedir_image, img, mask=batch['masks'].numpy(), return_image=True)) 69 | 70 | img_list = [img_org] 71 | with torch.no_grad(): 72 | for hash_id in hash_list: 73 | score_path = os.path.join(savedir_base, hash_id, 'score_list_best.pkl') 74 | score_list = hu.load_pkl(score_path) 75 | 76 | exp_dict = hu.load_json(os.path.join(savedir_base, hash_id, 'exp_dict.json')) 77 | print(i, exp_dict['model']['loss'], exp_dict['model'].get('with_affinity'), 'score:', score_list[-1]['test_class1']) 78 | 79 | model = models.get_model(model_dict=exp_dict['model'], 80 | exp_dict=exp_dict, 81 | train_set=test_set).cuda() 82 | 83 | model_path = os.path.join(savedir_base, hash_id, 'model_best.pth') 84 | model.load_state_dict(hu.torch_load(model_path), with_opt=False) 85 | mask_pred = model.predict_on_batch(batch) 86 | img_pred = np.array(hu.save_image(savedir_image, img, mask=mask_pred, return_image=True)) 87 | img_list += [img_pred] 88 | 89 | img_cat = np.concatenate(img_list, axis=1) 90 | hu.save_image(savedir_image, img_cat) 91 | 92 | -------------------------------------------------------------------------------- /src/models/networks/unet2d.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class DoubleConv(nn.Module): 8 | """(convolution => [BN] => ReLU) * 2""" 9 | 10 | def __init__(self, in_channels, out_channels): 11 | super().__init__() 12 | self.double_conv = nn.Sequential( 13 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 14 | nn.BatchNorm2d(out_channels), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(out_channels), 18 | nn.ReLU(inplace=True) 19 | ) 20 | 21 | def forward(self, x): 22 | return self.double_conv(x) 23 | 24 | 25 | class Down(nn.Module): 26 | """Downscaling with maxpool then double conv""" 27 | 28 | def __init__(self, in_channels, out_channels): 29 | super().__init__() 30 | self.maxpool_conv = nn.Sequential( 31 | nn.MaxPool2d(2), 32 | DoubleConv(in_channels, out_channels) 33 | ) 34 | 35 | def forward(self, x): 36 | return self.maxpool_conv(x) 37 | 38 | 39 | class Up(nn.Module): 40 | """Upscaling then double conv""" 41 | 42 | def __init__(self, in_channels, out_channels, bilinear=True): 43 | super().__init__() 44 | 45 | # if bilinear, use the normal convolutions to reduce the number of channels 46 | if bilinear: 47 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 48 | else: 49 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 50 | 51 | self.conv = DoubleConv(in_channels, out_channels) 52 | 53 | def forward(self, x1, x2): 54 | x1 = self.up(x1) 55 | # input is CHW 56 | diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) 57 | diffX = torch.tensor([x2.size()[3] - x1.size()[3]]) 58 | 59 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 60 | diffY // 2, diffY - diffY // 2]) 61 | # if you have padding issues, see 62 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 63 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 64 | x = torch.cat([x2, x1], dim=1) 65 | return self.conv(x) 66 | 67 | 68 | class OutConv(nn.Module): 69 | def __init__(self, in_channels, out_channels): 70 | super(OutConv, self).__init__() 71 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 72 | 
73 | def forward(self, x): 74 | return self.conv(x) 75 | 76 | """ Full assembly of the parts to form the complete network """ 77 | 78 | class UNet(nn.Module): 79 | def __init__(self, n_channels, n_classes, bilinear=True): 80 | super(UNet, self).__init__() 81 | self.n_channels = n_channels 82 | self.n_classes = n_classes 83 | self.bilinear = bilinear 84 | 85 | self.inc = DoubleConv(n_channels, 64) 86 | self.down1 = Down(64, 128) 87 | self.down2 = Down(128, 256) 88 | self.down3 = Down(256, 512) 89 | self.down4 = Down(512, 512) 90 | self.up1 = Up(1024, 256, bilinear) 91 | self.up2 = Up(512, 128, bilinear) 92 | self.up3 = Up(256, 64, bilinear) 93 | self.up4 = Up(128, 64, bilinear) 94 | self.outc = OutConv(64, n_classes) 95 | 96 | def forward(self, x): 97 | x1 = self.inc(x) 98 | x2 = self.down1(x1) 99 | x3 = self.down2(x2) 100 | x4 = self.down3(x3) 101 | x5 = self.down4(x4) 102 | x = self.up1(x5, x4) 103 | x = self.up2(x, x3) 104 | x = self.up3(x, x2) 105 | x = self.up4(x, x1) 106 | logits = self.outc(x) 107 | return logits -------------------------------------------------------------------------------- /src/modules/eprop/eprop.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | def get_similarity_matrix(x, rbf_scale): 8 | b, c = x.size() 9 | sq_dist = ((x.view(b, 1, c) - x.view(1, b, c))**2).sum(-1) / np.sqrt(c) 10 | mask = sq_dist != 0 11 | sq_dist = sq_dist / sq_dist[mask].std() 12 | weights = torch.exp(-sq_dist * rbf_scale) 13 | mask = torch.eye(weights.size(1), dtype=torch.bool, device=weights.device) 14 | weights = weights * (~mask).float() 15 | return weights 16 | 17 | def embedding_propagation(x, alpha, rbf_scale, norm_prop, propagator=None): 18 | if propagator is None: 19 | weights = get_similarity_matrix(x, rbf_scale) 20 | propagator = global_consistency(weights, alpha=alpha, norm_prop=norm_prop) 21 | return torch.mm(propagator, x) 22 | 23 | def label_propagation(x, labels, nclasses, alpha, rbf_scale, norm_prop, apply_log, propagator=None, epsilon=1e-6): 24 | labels = F.one_hot(labels, nclasses + 1) 25 | labels = labels[:, :nclasses].float() # the max label is unlabeled 26 | if propagator is None: 27 | weights = get_similarity_matrix(x, rbf_scale) 28 | propagator = global_consistency(weights, alpha=alpha, norm_prop=norm_prop) 29 | y_pred = torch.mm(propagator, labels) 30 | if apply_log: 31 | y_pred = torch.log(y_pred + epsilon) 32 | 33 | return y_pred 34 | 35 | 36 | class EmbeddingPropagation(torch.nn.Module): 37 | def __init__(self, alpha=0.5, rbf_scale=1, norm_prop=False): 38 | super().__init__() 39 | self.alpha = alpha 40 | self.rbf_scale = rbf_scale 41 | self.norm_prop = norm_prop 42 | 43 | def forward(self, x, propagator=None): 44 | b, c, h, w = x.size() 45 | x = x.view(b, c * h * w) 46 | return embedding_propagation(x, self.alpha, self.rbf_scale, self.norm_prop, propagator=propagator).view(b, c, h, w) 47 | 48 | class LabelPropagation(torch.nn.Module): 49 | def __init__(self, alpha=0.2, rbf_scale=1, norm_prop=True, apply_log=True): 50 | super().__init__() 51 | self.alpha = alpha 52 | self.rbf_scale = rbf_scale 53 | self.norm_prop = norm_prop 54 | self.apply_log = apply_log 55 | 56 | def forward(self, x, labels, nclasses, propagator=None): 57 | """Applies label propagation given a set of embeddings and labels 58 | 59 | Arguments: 60 | x {Tensor} -- Input embeddings 61 | labels {Tensor} -- Input labels from 0 to nclasses + 1. 
The highest value corresponds to unlabeled samples. 62 | nclasses {int} -- Total number of classes 63 | 64 | Keyword Arguments: 65 | propagator {Tensor} -- A pre-computed propagator (default: {None}) 66 | 67 | Returns: 68 | tuple(Tensor, Tensor) -- Logits and Propagator 69 | """ 70 | return label_propagation(x, labels, nclasses, self.alpha, self.rbf_scale, self.norm_prop, self.apply_log, propagator=propagator) 71 | 72 | def global_consistency(weights, alpha=1, norm_prop=False): 73 | """Implements D. Zhou et al. "Learning with local and global consistency". (Same as in TPN paper but without bug) 74 | 75 | Args: 76 | weights: Tensor of shape (batch, n, n). Expected to be exp( -d^2/s^2 ), where d is the euclidean distance and 77 | s the scale parameter. 78 | labels: Tensor of shape (batch, n, n_classes) 79 | alpha: Scaler, acts as a smoothing factor 80 | Returns: 81 | Tensor of shape (batch, n, n_classes) representing the logits of each classes 82 | """ 83 | n = weights.shape[1] 84 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device) 85 | isqrt_diag = 1. / torch.sqrt(1e-4 + torch.sum(weights, dim=-1)) 86 | # checknan(laplacian=isqrt_diag) 87 | S = weights * isqrt_diag[None, :] * isqrt_diag[:, None] 88 | # checknan(normalizedlaplacian=S) 89 | propagator = identity - alpha * S 90 | propagator = torch.inverse(propagator[None, ...])[0] 91 | # checknan(propagator=propagator) 92 | if norm_prop: 93 | propagator = F.normalize(propagator, p=1, dim=-1) 94 | return propagator -------------------------------------------------------------------------------- /src/models/networks/fcn8_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision 3 | import torch 4 | from skimage import morphology as morph 5 | import numpy as np 6 | from torch import optim 7 | import torch.nn.functional as F 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import numpy as np 12 | from skimage.morphology import watershed 13 | from skimage.segmentation import find_boundaries 14 | from scipy import ndimage 15 | 16 | 17 | class FCN8(nn.Module): 18 | def __init__(self, n_classes): 19 | super().__init__() 20 | self.n_classes = n_classes 21 | 22 | # Load the pretrained weights, remove avg pool 23 | # layer and get the output stride of 8 24 | resnet50_32s = torchvision.models.resnet50(pretrained=True) 25 | resnet_block_expansion_rate = resnet50_32s.layer1[0].expansion 26 | 27 | # Create a linear layer -- we don't need logits in this case 28 | resnet50_32s.fc = nn.Sequential() 29 | 30 | self.resnet50_32s = resnet50_32s 31 | 32 | self.score_32s = nn.Conv2d(512 * resnet_block_expansion_rate, 33 | self.n_classes, 34 | kernel_size=1) 35 | 36 | self.score_16s = nn.Conv2d(256 * resnet_block_expansion_rate, 37 | self.n_classes, 38 | kernel_size=1) 39 | 40 | self.score_8s = nn.Conv2d(128 * resnet_block_expansion_rate, 41 | self.n_classes, 42 | kernel_size=1) 43 | 44 | 45 | # # FREEZE BATCH NORMS 46 | for m in self.modules(): 47 | if isinstance(m, nn.BatchNorm2d): 48 | m.weight.requires_grad = False 49 | m.bias.requires_grad = False 50 | 51 | 52 | def extract_features(self, x_input): 53 | self.resnet50_32s.eval() 54 | x = self.resnet50_32s.conv1(x_input) 55 | x = self.resnet50_32s.bn1(x) 56 | x = self.resnet50_32s.relu(x) 57 | x = self.resnet50_32s.maxpool(x) 58 | 59 | x = self.resnet50_32s.layer1(x) 60 | 61 | x_8s = self.resnet50_32s.layer2(x) 62 | x_16s = self.resnet50_32s.layer3(x_8s) 63 | x_32s = self.resnet50_32s.layer4(x_16s) 64 
| 65 | return x_8s, x_16s, x_32s 66 | 67 | 68 | 69 | def forward(self, x): 70 | self.resnet50_32s.eval() 71 | input_spatial_dim = x.size()[2:] 72 | 73 | x = self.resnet50_32s.conv1(x) 74 | x = self.resnet50_32s.bn1(x) 75 | x = self.resnet50_32s.relu(x) 76 | x = self.resnet50_32s.maxpool(x) 77 | 78 | x = self.resnet50_32s.layer1(x) 79 | 80 | x = self.resnet50_32s.layer2(x) 81 | logits_8s = self.score_8s(x) 82 | 83 | x = self.resnet50_32s.layer3(x) 84 | logits_16s = self.score_16s(x) 85 | 86 | x = self.resnet50_32s.layer4(x) 87 | logits_32s = self.score_32s(x) 88 | 89 | logits_16s_spatial_dim = logits_16s.size()[2:] 90 | logits_8s_spatial_dim = logits_8s.size()[2:] 91 | 92 | logits_16s += nn.functional.interpolate(logits_32s, 93 | size=logits_16s_spatial_dim, 94 | mode="bilinear", 95 | align_corners=True) 96 | 97 | logits_8s += nn.functional.interpolate(logits_16s, 98 | size=logits_8s_spatial_dim, 99 | mode="bilinear", 100 | align_corners=True) 101 | 102 | logits_upsampled = nn.functional.interpolate(logits_8s, 103 | size=input_spatial_dim, 104 | mode="bilinear", 105 | align_corners=True) 106 | 107 | return logits_upsampled 108 | 109 | -------------------------------------------------------------------------------- /scripts/SEAM/tool/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import cv2 5 | import pydensecrf.densecrf as dcrf 6 | from pydensecrf.utils import unary_from_softmax 7 | 8 | def color_pro(pro, img=None, mode='hwc'): 9 | H, W = pro.shape 10 | pro_255 = (pro*255).astype(np.uint8) 11 | pro_255 = np.expand_dims(pro_255,axis=2) 12 | color = cv2.applyColorMap(pro_255,cv2.COLORMAP_JET) 13 | color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB) 14 | if img is not None: 15 | rate = 0.5 16 | if mode == 'hwc': 17 | assert img.shape[0] == H and img.shape[1] == W 18 | color = cv2.addWeighted(img,rate,color,1-rate,0) 19 | elif mode == 'chw': 20 | assert img.shape[1] == H and img.shape[2] == W 21 | img = np.transpose(img,(1,2,0)) 22 | color = cv2.addWeighted(img,rate,color,1-rate,0) 23 | color = np.transpose(color,(2,0,1)) 24 | else: 25 | if mode == 'chw': 26 | color = np.transpose(color,(2,0,1)) 27 | return color 28 | 29 | def generate_vis(p, gt, img, func_label2color, threshold=0.1, norm=True): 30 | # All the input should be numpy.array 31 | # img should be 0-255 uint8 32 | C, H, W = p.shape 33 | 34 | if norm: 35 | prob = max_norm(p, 'numpy') 36 | else: 37 | prob = p 38 | if gt is not None: 39 | prob = prob * gt 40 | prob[prob<=0] = 1e-7 41 | if threshold is not None: 42 | prob[0,:,:] = np.power(1-np.max(prob[1:,:,:],axis=0,keepdims=True), 4) 43 | 44 | CLS = ColorCLS(prob, func_label2color) 45 | CAM = ColorCAM(prob, img) 46 | 47 | prob_crf = dense_crf(prob, img, n_classes=C, n_iters=1) 48 | 49 | CLS_crf = ColorCLS(prob_crf, func_label2color) 50 | CAM_crf = ColorCAM(prob_crf, img) 51 | 52 | return CLS, CAM, CLS_crf, CAM_crf 53 | 54 | def max_norm(p, version='torch', e=1e-5): 55 | if version is 'torch': 56 | if p.dim() == 3: 57 | C, H, W = p.size() 58 | p = F.relu(p) 59 | max_v = torch.max(p.view(C,-1),dim=-1)[0].view(C,1,1) 60 | min_v = torch.min(p.view(C,-1),dim=-1)[0].view(C,1,1) 61 | p = F.relu(p-min_v-e)/(max_v-min_v+e) 62 | elif p.dim() == 4: 63 | N, C, H, W = p.size() 64 | p = F.relu(p) 65 | max_v = torch.max(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1) 66 | min_v = torch.min(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1) 67 | p = F.relu(p-min_v-e)/(max_v-min_v+e) 68 | elif 
version is 'numpy' or version is 'np': 69 | if p.ndim == 3: 70 | C, H, W = p.shape 71 | p[p<0] = 0 72 | max_v = np.max(p,(1,2),keepdims=True) 73 | min_v = np.min(p,(1,2),keepdims=True) 74 | p[p -------------------------------------------------------------------------------- /src/models/networks/detr/backbone.py: -------------------------------------------------------------------------------- 114 | train_backbone = args.lr_backbone > 0 115 | return_interm_layers = args.masks 116 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 117 | model = Joiner(backbone, position_embedding) 118 | model.num_channels = backbone.num_channels 119 | return model 120 | -------------------------------------------------------------------------------- /scripts/SEAM/README.md: -------------------------------------------------------------------------------- 1 | # SEAM 2 | The implementation of [**Self-supervised Equivariant Attention Mechanism for Weakly Supervised Semantic Segmentation**](https://arxiv.org/abs/2004.04581). 3 | ![SEAM network](https://github.com/YudeWang/SEAM/blob/master/network.png) 4 | 5 | ## Abstract 6 | Image-level weakly supervised semantic segmentation is a challenging problem that has been deeply studied in recent years. Most advanced solutions exploit the class activation map (CAM). However, CAMs can hardly serve as the object mask due to the gap between full and weak supervision. In this paper, we propose a self-supervised equivariant attention mechanism (SEAM) to discover additional supervision and narrow the gap. Our method is based on the observation that equivariance is an implicit constraint in fully supervised semantic segmentation, whose pixel-level labels take the same spatial transformation as the input images during data augmentation. However, this constraint is lost on the CAMs trained by image-level supervision. Therefore, we propose consistency regularization on CAMs predicted from various transformed images to provide self-supervision for network learning. Moreover, we propose a pixel correlation module (PCM), which exploits context appearance information and refines the prediction of the current pixel by its similar neighbors, leading to further improvement of CAM consistency. Extensive experiments on the PASCAL VOC 2012 dataset demonstrate that our method outperforms state-of-the-art methods using the same level of supervision. 7 | 8 | Thanks to the work of [jiwoon-ahn](https://github.com/jiwoon-ahn), the code of this repository borrows heavily from his [AffinityNet](https://github.com/jiwoon-ahn/psa) repository, and we follow the same pipeline to verify the effectiveness of our SEAM. 9 | 10 | ## Requirements 11 | - Python 3.6 12 | - pytorch 0.4.1, torchvision 0.2.1 13 | - CUDA 9.0 14 | - 4 x GPUs (12GB) 15 | 16 | ## Usage 17 | ### Installation 18 | - Download the repository. 19 | ``` 20 | git clone https://github.com/YudeWang/SEAM.git 21 | ``` 22 | - Install python dependencies. 23 | ``` 24 | pip install -r requirements.txt 25 | ``` 26 | - **Download model weights from [here](https://drive.google.com/open?id=1jWsV5Yev-PwKgvvtUM3GnY0ogb50-qKa)**, including ImageNet pretrained models and our training results. 27 | 28 | - Download the PASCAL VOC 2012 devkit (follow the instructions in http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#devkit). It is suggested to make a soft link toward the downloaded dataset. 29 | ``` 30 | ln -s $your_dataset_path/VOCdevkit/VOC2012 VOC2012 31 | ``` 32 | 33 | - (Optional) The image-level labels have already been given in `voc12/cls_label.npy`.
If you want to regenerate it (which is unnecessary), please download the annotation of the VOC 2012 SegmentationClassAug training set (containing 10582 images), which can be downloaded [here](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0), and place them all as `VOC2012/SegmentationClassAug/xxxxxx.png`. Then run the code 34 | ``` 35 | cd voc12 36 | python make_cls_labels.py --voc12_root VOC2012 37 | ``` 38 | ### SEAM step 39 | 40 | 1. SEAM training 41 | ``` 42 | python train_SEAM.py --voc12_root VOC2012 --weights $pretrained_model --session_name $your_session_name 43 | ``` 44 | 45 | 2. SEAM inference. 46 | ``` 47 | python infer_SEAM.py --weights $SEAM_weights --infer_list [voc12/val.txt | voc12/train.txt | voc12/train_aug.txt] --out_cam $your_cam_dir --out_crf $your_crf_dir 48 | ``` 49 | 50 | 3. SEAM step evaluation. We provide a Python mIoU evaluation script, `evaluation.py`, or you can use the official development kit. Here we suggest plotting the curve of mIoU over different background scores (a minimal Python sketch of this evaluation is given after the Reference section below). 51 | ``` 52 | python evaluation.py --list VOC2012/ImageSets/Segmentation/[val.txt | train.txt] --predict_dir $your_cam_dir --gt_dir VOC2012/SegmentationClass --comment $your_comments --type npy --curve True 53 | ``` 54 | 55 | ### Random walk step 56 | The random walk step is kept the same as in the AffinityNet repository. 57 | 1. Train AffinityNet. 58 | ``` 59 | python train_aff.py --weights $pretrained_model --voc12_root VOC2012 --la_crf_dir $your_crf_dir_4.0 --ha_crf_dir $your_crf_dir_24.0 --session_name $your_session_name 60 | ``` 61 | 2. Random walk propagation 62 | ``` 63 | python infer_aff.py --weights $aff_weights --infer_list [voc12/val.txt | voc12/train.txt] --cam_dir $your_cam_dir --voc12_root VOC2012 --out_rw $your_rw_dir 64 | ``` 65 | 3. Random walk step evaluation 66 | ``` 67 | python evaluation.py --list VOC2012/ImageSets/Segmentation/[val.txt | train.txt] --predict_dir $your_rw_dir --gt_dir VOC2012/SegmentationClass --comment $your_comments --type png 68 | ``` 69 | 70 | ## Citation 71 | Please cite our paper if the code is helpful to your research. 72 | ``` 73 | @InProceedings{Wang_2020_CVPR_SEAM, 74 | author = {Yude Wang and Jie Zhang and Meina Kan and Shiguang Shan and Xilin Chen}, 75 | title = {Self-supervised Equivariant Attention Mechanism for Weakly Supervised Semantic Segmentation}, 76 | booktitle = {Proc. IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 77 | year = {2020} 78 | } 79 | ``` 80 | ## Reference 81 | [1] J. Ahn and S. Kwak. Learning pixel-level semantic affinity with image-level supervision for weakly supervised semantic segmentation. In Proc. IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018.
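The sketch below is not part of the original SEAM code; it only illustrates, for a single image, what the npy branch of `evaluation.py` does: stack the saved CAMs into a score tensor, fill the background channel with a constant score, take the argmax as the hard prediction, and accumulate per-class IoU. The helper names and the default background score of 0.26 (taken from `infer_SEAM.py`) are illustrative assumptions.
```python
import numpy as np

def cams_to_prediction(cam_dict, shape, bg_score=0.26, num_cls=21):
    # cam_dict maps a 0-based foreground class index to an H x W score map,
    # as saved by infer_SEAM.py; channel 0 is reserved for the background.
    h, w = shape
    scores = np.zeros((num_cls, h, w), np.float32)
    for cls_idx, cam in cam_dict.items():
        scores[cls_idx + 1] = cam
    scores[0, :, :] = bg_score              # constant background threshold
    return np.argmax(scores, axis=0).astype(np.uint8)

def per_class_iou(pred, gt, num_cls=21, ignore=255):
    # Same accumulation as in evaluation.py: IoU_i = TP_i / (T_i + P_i - TP_i).
    valid = gt != ignore
    ious = []
    for i in range(num_cls):
        tp = np.sum((pred == i) & (gt == i) & valid)
        p = np.sum((pred == i) & valid)
        t = np.sum((gt == i) & valid)
        ious.append(tp / (t + p - tp + 1e-10))
    return ious  # mIoU is the mean over classes; sweep bg_score to draw the curve
```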
82 | -------------------------------------------------------------------------------- /src/models/optimizers/sps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | import copy 5 | class Sps(torch.optim.Optimizer): 6 | def __init__(self, 7 | params, 8 | n_batches_per_epoch=500, 9 | init_step_size=1, 10 | c=0.5, 11 | gamma=2.0, 12 | eta_max=None, 13 | adapt_flag='smooth_iter', 14 | eps=1e-8, 15 | centralize_grad_norm=False, 16 | centralize_grad=False, 17 | momentum=0.): 18 | params = list(params) 19 | super().__init__(params, {}) 20 | self.eps = eps 21 | self.params = params 22 | self.c = c 23 | self.centralize_grad_norm = centralize_grad_norm 24 | self.centralize_grad = centralize_grad 25 | if centralize_grad: 26 | assert self.centralize_grad_norm is False 27 | self.eta_max = eta_max 28 | self.gamma = gamma 29 | self.init_step_size = init_step_size 30 | self.adapt_flag = adapt_flag 31 | self.state['step'] = 0 32 | self.state['step_size'] = init_step_size 33 | self.step_size_max = 0. 34 | self.n_batches_per_epoch = n_batches_per_epoch 35 | self.state['n_forwards'] = 0 36 | self.state['n_backwards'] = 0 37 | self.momentum = momentum 38 | self.params_prev = None 39 | if self.momentum != 0: 40 | self.params_prev = copy.deepcopy(params) 41 | 42 | def step(self, closure=None, loss=None, batch=None): 43 | if loss is None and closure is None: 44 | raise ValueError('please specify either closure or loss') 45 | if loss is not None: 46 | if not isinstance(loss, torch.Tensor): 47 | loss = torch.tensor(loss) 48 | # increment step 49 | self.state['step'] += 1 50 | # get fstar 51 | fstar = 0. 52 | # get loss and compute gradients 53 | if loss is None: 54 | loss = closure() 55 | else: 56 | assert closure is None, 'if loss is provided then closure should beNone' 57 | # save the current parameters: 58 | grad_current = get_grad_list(self.params, centralize_grad=self.centralize_grad) 59 | grad_norm = compute_grad_norm(grad_current, centralize_grad_norm=self.centralize_grad_norm) 60 | if grad_norm < 1e-8: 61 | step_size = 0. 62 | else: 63 | # adapt the step size 64 | if self.adapt_flag in ['constant']: 65 | # adjust the step size based on an upper bound and fstar 66 | step_size = (loss - fstar) / \ 67 | (self.c * (grad_norm)**2 + self.eps) 68 | if loss < fstar: 69 | step_size = 0. 
70 | else: 71 | if self.eta_max is None: 72 | step_size = step_size.item() 73 | else: 74 | step_size = min(self.eta_max, step_size.item()) 75 | elif self.adapt_flag in ['smooth_iter']: 76 | # smoothly adjust the step size 77 | step_size = loss / (self.c * (grad_norm)**2 + self.eps) 78 | coeff = self.gamma**(1./self.n_batches_per_epoch) 79 | step_size = min(coeff * self.state['step_size'], 80 | step_size.item()) 81 | else: 82 | raise ValueError('adapt_flag: %s not supported' % 83 | self.adapt_flag) 84 | # update with step size 85 | if self.momentum > 0: 86 | params_tmp = copy.deepcopy(self.params) 87 | for p, g, p_prev in zip(self.params, grad_current, self.params_prev): 88 | p.data = p - step_size * g + self.momentum * (p - p_prev) 89 | self.params_prev = params_tmp 90 | else: 91 | for p, g in zip(self.params, grad_current): 92 | p.data.add_(- float(step_size), g) 93 | # update state with metrics 94 | self.state['n_forwards'] += 1 95 | self.state['n_backwards'] += 1 96 | self.state['step_size'] = step_size 97 | self.state['grad_norm'] = grad_norm.item() 98 | if torch.isnan(self.params[0]).sum() > 0: 99 | raise ValueError('Got NaNs') 100 | return float(loss) 101 | # utils 102 | # ------------------------------ 103 | def compute_grad_norm(grad_list, centralize_grad_norm=False): 104 | grad_norm = 0. 105 | for g in grad_list: 106 | if g is None: 107 | continue 108 | if g.dim() > 1 and centralize_grad_norm: 109 | # centralize grads 110 | g.add_(-g.mean(dim = tuple(range(1,g.dim())), keepdim = True)) 111 | grad_norm += torch.sum(torch.mul(g, g)) 112 | grad_norm = torch.sqrt(grad_norm) 113 | return grad_norm 114 | 115 | def get_grad_list(params, centralize_grad=False): 116 | grad_list = [] 117 | for p in params: 118 | g = p.grad.data 119 | if len(list(g.size()))>1 and centralize_grad: 120 | # centralize grads 121 | g.add_(-g.mean(dim = tuple(range(1,len(list(g.size())))), 122 | keepdim = True)) 123 | grad_list += [g] 124 | return grad_list 125 | -------------------------------------------------------------------------------- /src/datasets/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | from scipy.ndimage import zoom 6 | from torchvision import transforms 7 | from . import trans_utils as tu 8 | from haven import haven_utils as hu 9 | # from batchgenerators.augmentations import crop_and_pad_augmentations 10 | from . 
import micnn_augmentor 11 | 12 | 13 | def apply_transform(split, image, label=None, transform_name='basic', 14 | exp_dict=None): 15 | 16 | if transform_name == 'basic': 17 | image /= 4095 18 | 19 | if exp_dict['dataset']['transform_mode'] is not None: 20 | image, label = image[200:550], label[200:550] 21 | 22 | if split == 'train': 23 | if exp_dict['dataset']['transform_mode'] == 2: 24 | da = micnn_augmentor.Data_Augmentation() 25 | image, label = da.run(image[None,None], label[None,None]) 26 | 27 | if exp_dict['dataset']['transform_mode'] == 3: 28 | if np.random.rand() < 0.5: 29 | da = micnn_augmentor.Data_Augmentation() 30 | image, label = da.run(image[None,None], label[None,None]) 31 | 32 | image = image.squeeze() 33 | label = label.squeeze() 34 | # hu.save_image('tmp_after.png', image) 35 | 36 | image = torch.FloatTensor(image)[None] 37 | # assert image.min() >= 0 38 | # assert image.max() <= 1 39 | 40 | normalize = transforms.Normalize((0.5,), (0.5,)) 41 | image = normalize(image) 42 | return image, label 43 | 44 | if transform_name == 'basic_hu': 45 | image+= 1024 46 | image /= 5024 47 | assert image.min()>=0 and image.max() <= 1 48 | # if exp_dict['dataset']['transform_mode'] is not None: 49 | # image, label = image[200:550], label[200:550] 50 | class_map = tu.get_class_map(exp_dict['n_classes']) 51 | lbl_trans = transforms.Compose([ 52 | tu.GroupLabels(class_map), 53 | tu.PreparePilLabel(), 54 | transforms.ToPILImage(), 55 | tu.UndoPreparePilLabel(), 56 | transforms.ToTensor(), 57 | tu.Squeeze(), 58 | ]) 59 | label = lbl_trans(label) 60 | 61 | if split == 'train': 62 | if exp_dict['dataset']['transform_mode'] == 2: 63 | da = micnn_augmentor.Data_Augmentation() 64 | image, label = da.run(image[None,None], label[None,None]) 65 | 66 | if exp_dict['dataset']['transform_mode'] == 3: 67 | if np.random.rand() < 0.5: 68 | da = micnn_augmentor.Data_Augmentation() 69 | image, label = da.run(image[None,None], label[None,None].numpy()) 70 | label = torch.LongTensor(label).squeeze() 71 | 72 | image = image.squeeze() 73 | label = label.squeeze() 74 | # hu.save_image('tmp_after.png', image) 75 | 76 | image = torch.FloatTensor(image)[None] 77 | # assert image.min() >= 0 78 | # assert image.max() <= 1 79 | 80 | normalize = transforms.Normalize((0.5,), (0.5,)) 81 | image = normalize(image) 82 | return image, label 83 | 84 | elif transform_name == 'mdai_basic': 85 | img_trans = transforms.Compose([ 86 | tu.Threshold(min=-1000, max=50), 87 | transforms.ToPILImage(), 88 | transforms.ToTensor(), 89 | transforms.Normalize( 90 | mean=torch.tensor([-653.2204]), 91 | std=torch.tensor([628.5188]) 92 | ) 93 | ]) 94 | class_map = tu.get_class_map(exp_dict['n_classes']) 95 | lbl_trans = transforms.Compose([ 96 | tu.GroupLabels(class_map), 97 | tu.PreparePilLabel(), 98 | transforms.ToPILImage(), 99 | tu.UndoPreparePilLabel(), 100 | transforms.ToTensor(), 101 | tu.Squeeze(), 102 | ]) 103 | image, label = img_trans(image), lbl_trans(label) 104 | if split == 'train': 105 | if exp_dict['dataset'].get('transform_mode') == 3: 106 | if np.random.rand() < 0.5: 107 | da = micnn_augmentor.Data_Augmentation() 108 | image, label = da.run(image.numpy()[None], label.numpy()[None,None]) 109 | image, label = torch.FloatTensor(image).squeeze()[None], torch.LongTensor(label).squeeze() 110 | if label is None: 111 | return image 112 | return image, label 113 | 114 | elif transform_name == 'pspnet_transformer': 115 | windows = ['lung'] 116 | mean, std = tu.get_normalization_stats(windows) 117 | thresh_min, thresh_max = 
tu.get_thresholds_stats(windows) 118 | 119 | img_trans = transforms.Compose([ 120 | tu.Threshold(min=thresh_min, max=thresh_max), 121 | transforms.ToTensor(), 122 | transforms.Normalize( 123 | mean=mean, 124 | std=std 125 | ) 126 | ]) 127 | 128 | class_map = tu.get_class_map(exp_dict['n_classes']) 129 | lbl_trans = transforms.Compose([ 130 | tu.GroupLabels(class_map), 131 | tu.PreparePilLabel(), 132 | transforms.ToPILImage(), 133 | tu.UndoPreparePilLabel(), 134 | transforms.ToTensor(), 135 | tu.Squeeze(), 136 | ]) 137 | if label is None: 138 | return img_trans(image) 139 | return img_trans(image).float(), lbl_trans(label) 140 | 141 | 142 | -------------------------------------------------------------------------------- /scripts/SEAM/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from PIL import Image 5 | import multiprocessing 6 | import argparse 7 | 8 | categories = ['background','aeroplane','bicycle','bird','boat','bottle','bus','car','cat','chair','cow', 9 | 'diningtable','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor'] 10 | def do_python_eval(predict_folder, gt_folder, name_list, num_cls=21, input_type='png', threshold=1.0, printlog=False): 11 | TP = [] 12 | P = [] 13 | T = [] 14 | for i in range(num_cls): 15 | TP.append(multiprocessing.Value('i', 0, lock=True)) 16 | P.append(multiprocessing.Value('i', 0, lock=True)) 17 | T.append(multiprocessing.Value('i', 0, lock=True)) 18 | 19 | def compare(start,step,TP,P,T,input_type,threshold): 20 | for idx in range(start,len(name_list),step): 21 | name = name_list[idx] 22 | if input_type == 'png': 23 | predict_file = os.path.join(predict_folder,'%s.png'%name) 24 | predict = np.array(Image.open(predict_file)) #cv2.imread(predict_file) 25 | elif input_type == 'npy': 26 | predict_file = os.path.join(predict_folder,'%s.npy'%name) 27 | predict_dict = np.load(predict_file, allow_pickle=True).item() 28 | h, w = list(predict_dict.values())[0].shape 29 | tensor = np.zeros((21,h,w),np.float32) 30 | for key in predict_dict.keys(): 31 | tensor[key+1] = predict_dict[key] 32 | tensor[0,:,:] = threshold 33 | predict = np.argmax(tensor, axis=0).astype(np.uint8) 34 | 35 | gt_file = os.path.join(gt_folder,'%s.png'%name) 36 | gt = np.array(Image.open(gt_file)) 37 | cal = gt<255 38 | mask = (predict==gt) * cal 39 | 40 | for i in range(num_cls): 41 | P[i].acquire() 42 | P[i].value += np.sum((predict==i)*cal) 43 | P[i].release() 44 | T[i].acquire() 45 | T[i].value += np.sum((gt==i)*cal) 46 | T[i].release() 47 | TP[i].acquire() 48 | TP[i].value += np.sum((gt==i)*mask) 49 | TP[i].release() 50 | p_list = [] 51 | for i in range(8): 52 | p = multiprocessing.Process(target=compare, args=(i,8,TP,P,T,input_type,threshold)) 53 | p.start() 54 | p_list.append(p) 55 | for p in p_list: 56 | p.join() 57 | IoU = [] 58 | T_TP = [] 59 | P_TP = [] 60 | FP_ALL = [] 61 | FN_ALL = [] 62 | for i in range(num_cls): 63 | IoU.append(TP[i].value/(T[i].value+P[i].value-TP[i].value+1e-10)) 64 | T_TP.append(T[i].value/(TP[i].value+1e-10)) 65 | P_TP.append(P[i].value/(TP[i].value+1e-10)) 66 | FP_ALL.append((P[i].value-TP[i].value)/(T[i].value + P[i].value - TP[i].value + 1e-10)) 67 | FN_ALL.append((T[i].value-TP[i].value)/(T[i].value + P[i].value - TP[i].value + 1e-10)) 68 | loglist = {} 69 | for i in range(num_cls): 70 | loglist[categories[i]] = IoU[i] * 100 71 | 72 | miou = np.mean(np.array(IoU)) 73 | loglist['mIoU'] = miou * 100 74 | if printlog: 
75 | for i in range(num_cls): 76 | if i%2 != 1: 77 | print('%11s:%7.3f%%'%(categories[i],IoU[i]*100),end='\t') 78 | else: 79 | print('%11s:%7.3f%%'%(categories[i],IoU[i]*100)) 80 | print('\n======================================================') 81 | print('%11s:%7.3f%%'%('mIoU',miou*100)) 82 | return loglist 83 | 84 | def writedict(file, dictionary): 85 | s = '' 86 | for key in dictionary.keys(): 87 | sub = '%s:%s '%(key, dictionary[key]) 88 | s += sub 89 | s += '\n' 90 | file.write(s) 91 | 92 | def writelog(filepath, metric, comment): 93 | filepath = filepath 94 | logfile = open(filepath,'a') 95 | import time 96 | logfile.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 97 | logfile.write('\t%s\n'%comment) 98 | writedict(logfile, metric) 99 | logfile.write('=====================================\n') 100 | logfile.close() 101 | 102 | 103 | if __name__ == '__main__': 104 | 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument("--list", default='./VOC2012/ImageSets/Segmentation/train.txt', type=str) 107 | parser.add_argument("--predict_dir", default='./out_rw', type=str) 108 | parser.add_argument("--gt_dir", default='./VOC2012/SegmentationClass', type=str) 109 | parser.add_argument('--logfile', default='./evallog.txt',type=str) 110 | parser.add_argument('--comment', required=True, type=str) 111 | parser.add_argument('--type', default='png', choices=['npy', 'png'], type=str) 112 | parser.add_argument('--t', default=None, type=float) 113 | parser.add_argument('--curve', default=False, type=bool) 114 | args = parser.parse_args() 115 | 116 | if args.type == 'npy': 117 | assert args.t is not None or args.curve 118 | df = pd.read_csv(args.list, names=['filename']) 119 | name_list = df['filename'].values 120 | if not args.curve: 121 | loglist = do_python_eval(args.predict_dir, args.gt_dir, name_list, 21, args.type, args.t, printlog=True) 122 | writelog(args.logfile, loglist, args.comment) 123 | else: 124 | l = [] 125 | for i in range(60): 126 | t = i/100.0 127 | loglist = do_python_eval(args.predict_dir, args.gt_dir, name_list, 21, args.type, t) 128 | l.append(loglist['mIoU']) 129 | print('%d/60 background score: %.3f\tmIoU: %.3f%%'%(i, t, loglist['mIoU'])) 130 | writelog(args.logfile, {'mIoU':l}, args.comment) 131 | -------------------------------------------------------------------------------- /scripts/across_habitat.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 4 | sys.path.insert(0, path) 5 | 6 | from haven import haven_chk as hc 7 | from haven import haven_results as hr 8 | from haven import haven_utils as hu 9 | import torch 10 | import torchvision 11 | import tqdm 12 | import pandas as pd 13 | import pprint 14 | import itertools 15 | import os 16 | import pylab as plt 17 | import time 18 | import numpy as np 19 | 20 | from src import models 21 | from src import datasets 22 | from src import utils as ut 23 | from src.models import metrics 24 | 25 | import argparse 26 | 27 | from torch.utils.data import sampler 28 | from torch.utils.data.sampler import RandomSampler 29 | from torch.backends import cudnn 30 | from torch.nn import functional as F 31 | from torch.utils.data import DataLoader 32 | import pandas as pd 33 | 34 | cudnn.benchmark = True 35 | 36 | if __name__ == "__main__": 37 | savedir_base = '/mnt/public/results/toolkit/weak_supervision' 38 | 39 | hash_list = ['b04090f27c7c52bcec65f6ba455ed2d8', 40 | 
'6d4af38d64b23586e71a198de2608333', 41 | '84ced18cf5c1fb3ad5820cc1b55a38fa', 42 | '63f29eec3dbe1e03364f198ed7d4b414', 43 | '017e7441c2f581b6fee9e3ac6f574edc'] 44 | hash_dct = {'b04090f27c7c52bcec65f6ba455ed2d8': 'Fully_Supervised', 45 | '6d4af38d64b23586e71a198de2608333': 'LCFCN', 46 | '84ced18cf5c1fb3ad5820cc1b55a38fa': 'LCFCN+Affinity_(ours)', 47 | '63f29eec3dbe1e03364f198ed7d4b414': 'Point-level_Loss ', 48 | '017e7441c2f581b6fee9e3ac6f574edc': 'Cross_entropy_Loss+pseudo-mask'} 49 | datadir = '/mnt/public/datasets/DeepFish/' 50 | 51 | score_list = [] 52 | for hash_id in hash_list: 53 | fname = os.path.join('/mnt/public/predictions/habitat/%s.pkl' % hash_id) 54 | exp_dict = hu.load_json(os.path.join(savedir_base, hash_id, 'exp_dict.json')) 55 | if os.path.exists(fname): 56 | print('FOUND:', fname) 57 | val_dict = hu.load_pkl(fname) 58 | else: 59 | 60 | train_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"], 61 | split='train', 62 | datadir=datadir, 63 | exp_dict=exp_dict, 64 | dataset_size=exp_dict['dataset_size']) 65 | 66 | test_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"], 67 | split='test', 68 | datadir=datadir, 69 | exp_dict=exp_dict, 70 | dataset_size=exp_dict['dataset_size']) 71 | 72 | test_loader = DataLoader(test_set, 73 | batch_size=1, 74 | collate_fn=ut.collate_fn, 75 | num_workers=0) 76 | pprint.pprint(exp_dict) 77 | # Model 78 | # ================== 79 | model = models.get_model(model_dict=exp_dict['model'], 80 | exp_dict=exp_dict, 81 | train_set=train_set).cuda() 82 | 83 | model_path = os.path.join(savedir_base, hash_id, 'model_best.pth') 84 | 85 | # load best model 86 | model.load_state_dict(hu.torch_load(model_path)) 87 | # loop over the val_loader and saves image 88 | # get counts 89 | habitats = [] 90 | for i, batch in enumerate(test_loader): 91 | habitat = batch['meta'][0]['habitat'] 92 | habitats += [habitat] 93 | habitats = np.array(habitats) 94 | 95 | val_dict = {} 96 | val_dict_lst = [] 97 | for h in np.unique(habitats): 98 | val_meter = metrics.SegMeter(split=test_loader.dataset.split) 99 | 100 | for i, batch in enumerate(tqdm.tqdm(test_loader)): 101 | habitat = batch['meta'][0]['habitat'] 102 | if habitat != h: 103 | continue 104 | 105 | val_meter.val_on_batch(model, batch) 106 | score_dict = val_meter.get_avg_score() 107 | pprint.pprint(score_dict) 108 | 109 | val_dict[h] = val_meter.get_avg_score() 110 | val_dict_dfc = pd.DataFrame([val_meter.get_avg_score()]) 111 | val_dict_dfc.insert(0, "Habitat", h, True) 112 | val_dict_dfc.rename( 113 | columns={'test_score': 'mIoU', 'test_class0': 'IoU class 0', 'test_class1': 'IoU class 1', 114 | 'test_mae': 'MAE', 'test_game': 'GAME'}, inplace=True) 115 | val_dict_lst.append(val_dict_dfc) 116 | val_dict_df = pd.concat(val_dict_lst, axis=0) 117 | val_dict_df.to_csv(os.path.join('/mnt/public/predictions/habitat/', "%s_habitat_score_df.csv" % hash_id), 118 | index=False) 119 | val_dict_df.to_latex(os.path.join('/mnt/public/predictions/habitat/', "%s_habitat_score_df.tex" % hash_id), 120 | index=False, caption=hash_dct[hash_id], label=hash_dct[hash_id]) 121 | 122 | hu.save_pkl(fname, val_dict) 123 | 124 | val_dict['model'] = exp_dict['model'] 125 | score_list += [val_dict] 126 | 127 | print(pd.DataFrame(score_list)) 128 | # score_df = pd.DataFrame(score_list) 129 | # score_df.to_csv(os.path.join('/mnt/public/predictions/habitat/', "habitat_score_df.csv")) 130 | # score_df.to_latex(os.path.join('/mnt/public/predictions/habitat/', "habitat_score_df.tex")) 131 | 132 | 
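The commented-out lines above suggest a combined results table. Below is a minimal sketch, not part of `across_habitat.py`, of how the per-hash habitat CSVs written in the loop could be merged into a single method-by-habitat comparison. It assumes it runs after the loop in the same script (so `hash_dct` and the output directory are the ones defined there); `habitat_comparison.csv` is just an illustrative file name.
```python
import os
import pandas as pd

pred_dir = '/mnt/public/predictions/habitat/'
rows = []
for hash_id, method in hash_dct.items():
    # one CSV per experiment hash, written by the loop above
    df = pd.read_csv(os.path.join(pred_dir, '%s_habitat_score_df.csv' % hash_id))
    df.insert(0, 'Method', method)
    rows.append(df)

combined = pd.concat(rows, axis=0, ignore_index=True)
combined.to_csv(os.path.join(pred_dir, 'habitat_comparison.csv'), index=False)
# one row per habitat, one column per method
print(combined.pivot(index='Habitat', columns='Method', values='mIoU'))
```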
-------------------------------------------------------------------------------- /scripts/SEAM/tool/pyutils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import time 4 | import sys 5 | 6 | class Logger(object): 7 | def __init__(self, outfile): 8 | self.terminal = sys.stdout 9 | self.log = open(outfile, "w") 10 | sys.stdout = self 11 | 12 | def write(self, message): 13 | self.terminal.write(message) 14 | self.log.write(message) 15 | 16 | def flush(self): 17 | self.terminal.flush() 18 | 19 | 20 | class AverageMeter: 21 | def __init__(self, *keys): 22 | self.__data = dict() 23 | for k in keys: 24 | self.__data[k] = [0.0, 0] 25 | 26 | def add(self, dict): 27 | for k, v in dict.items(): 28 | self.__data[k][0] += v 29 | self.__data[k][1] += 1 30 | 31 | def get(self, *keys): 32 | if len(keys) == 1: 33 | return self.__data[keys[0]][0] / self.__data[keys[0]][1] 34 | else: 35 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] 36 | return tuple(v_list) 37 | 38 | def pop(self, key=None): 39 | if key is None: 40 | for k in self.__data.keys(): 41 | self.__data[k] = [0.0, 0] 42 | else: 43 | v = self.get(key) 44 | self.__data[key] = [0.0, 0] 45 | return v 46 | 47 | 48 | class Timer: 49 | def __init__(self, starting_msg = None): 50 | self.start = time.time() 51 | self.stage_start = self.start 52 | 53 | if starting_msg is not None: 54 | print(starting_msg, time.ctime(time.time())) 55 | 56 | 57 | def update_progress(self, progress): 58 | self.elapsed = time.time() - self.start 59 | self.est_total = self.elapsed / progress 60 | self.est_remaining = self.est_total - self.elapsed 61 | self.est_finish = int(self.start + self.est_total) 62 | 63 | 64 | def str_est_finish(self): 65 | return str(time.ctime(self.est_finish)) 66 | 67 | def get_stage_elapsed(self): 68 | return time.time() - self.stage_start 69 | 70 | def reset_stage(self): 71 | self.stage_start = time.time() 72 | 73 | 74 | from multiprocessing.pool import ThreadPool 75 | 76 | class BatchThreader: 77 | 78 | def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12): 79 | self.batch_size = batch_size 80 | self.prefetch_size = prefetch_size 81 | 82 | self.pool = ThreadPool(processes=processes) 83 | self.async_result = [] 84 | 85 | self.func = func 86 | self.left_args_list = args_list 87 | self.n_tasks = len(args_list) 88 | 89 | # initial work 90 | self.__start_works(self.__get_n_pending_works()) 91 | 92 | 93 | def __start_works(self, times): 94 | for _ in range(times): 95 | args = self.left_args_list.pop(0) 96 | self.async_result.append( 97 | self.pool.apply_async(self.func, args)) 98 | 99 | 100 | def __get_n_pending_works(self): 101 | return min((self.prefetch_size + 1) * self.batch_size - len(self.async_result) 102 | , len(self.left_args_list)) 103 | 104 | 105 | 106 | def pop_results(self): 107 | 108 | n_inwork = len(self.async_result) 109 | 110 | n_fetch = min(n_inwork, self.batch_size) 111 | rtn = [self.async_result.pop(0).get() 112 | for _ in range(n_fetch)] 113 | 114 | to_fill = self.__get_n_pending_works() 115 | if to_fill == 0: 116 | self.pool.close() 117 | else: 118 | self.__start_works(to_fill) 119 | 120 | return rtn 121 | 122 | 123 | 124 | 125 | def get_indices_of_pairs(radius, size): 126 | 127 | search_dist = [] 128 | 129 | for x in range(1, radius): 130 | search_dist.append((0, x)) 131 | 132 | for y in range(1, radius): 133 | for x in range(-radius + 1, radius): 134 | if x * x + y * y < radius * radius: 135 | search_dist.append((y, x)) 136 
| 137 | radius_floor = radius - 1 138 | 139 | full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64), 140 | (size[0], size[1])) 141 | 142 | cropped_height = size[0] - radius_floor 143 | cropped_width = size[1] - 2 * radius_floor 144 | 145 | indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor], 146 | [-1]) 147 | 148 | indices_to_list = [] 149 | 150 | for dy, dx in search_dist: 151 | indices_to = full_indices[dy:dy + cropped_height, 152 | radius_floor + dx:radius_floor + dx + cropped_width] 153 | indices_to = np.reshape(indices_to, [-1]) 154 | 155 | indices_to_list.append(indices_to) 156 | 157 | concat_indices_to = np.concatenate(indices_to_list, axis=0) 158 | 159 | return indices_from, concat_indices_to 160 | 161 | def get_indices_of_pairs_circle(radius, size): 162 | 163 | search_dist = [] 164 | 165 | for y in range(-radius + 1, radius): 166 | for x in range(-radius + 1, radius): 167 | if x * x + y * y < radius * radius and x*x+y*y!=0: 168 | search_dist.append((y, x)) 169 | 170 | radius_floor = radius - 1 171 | 172 | full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64), 173 | (size[0], size[1])) 174 | 175 | cropped_height = size[0] - 2 * radius_floor 176 | cropped_width = size[1] - 2 * radius_floor 177 | 178 | indices_from = np.reshape(full_indices[radius_floor:-radius_floor, radius_floor:-radius_floor], 179 | [-1]) 180 | 181 | indices_to_list = [] 182 | 183 | for dy, dx in search_dist: 184 | indices_to = full_indices[radius_floor + dy : radius_floor + dy + cropped_height, 185 | radius_floor + dx : radius_floor + dx + cropped_width] 186 | indices_to = np.reshape(indices_to, [-1]) 187 | 188 | indices_to_list.append(indices_to) 189 | 190 | concat_indices_to = np.concatenate(indices_to_list, axis=0) 191 | 192 | return indices_from, concat_indices_to 193 | -------------------------------------------------------------------------------- /scripts/SEAM/infer_SEAM.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import cv2 5 | import os 6 | from .voc12 import data 7 | import scipy.misc 8 | import importlib 9 | from torch.utils.data import DataLoader 10 | import torchvision 11 | from .tool import imutils, pyutils#, visualization 12 | import argparse 13 | from PIL import Image 14 | import torch.nn.functional as F 15 | import pandas as pd 16 | from .network import resnet38_SEAM 17 | import PIL.Image 18 | from torchvision.transforms import ToPILImage, ToTensor 19 | from torchvision.transforms import functional as Ff 20 | 21 | def HCW_to_CHW(tensor, sal=False): 22 | if sal: 23 | tensor = np.expand_dims(tensor, axis=0) 24 | else: 25 | tensor = np.transpose(tensor, (1, 2, 0)) 26 | return tensor 27 | 28 | def msf_img_lists(name, img, label, SEAM_model): 29 | name = name, 30 | label = label 31 | # label = ToTensor()(label) 32 | # img = ToPILImage()(img) 33 | model = SEAM_model 34 | unit = 1 35 | scales = [0.5, 1.0, 1.5, 2.0] 36 | inter_transform = torchvision.transforms.Compose( 37 | [np.asarray, 38 | model.normalize, 39 | # ToTensor(), 40 | imutils.HWC_to_CHW 41 | ]) 42 | intera_transform = torchvision.transforms.Compose( 43 | [ToTensor(), 44 | HCW_to_CHW 45 | ]) 46 | 47 | 48 | rounded_size = (int(round(img.size[0] / unit) * unit), int(round(img.size[1] / unit) * unit)) 49 | 50 | ms_img_list = [] 51 | for s in scales: 52 | target_size = (round(rounded_size[0] * s), 53 | round(rounded_size[1] * s)) 54 | s_img = img.resize(target_size, resample=PIL.Image.CUBIC) 
55 | ms_img_list.append(s_img) 56 | 57 | if inter_transform: 58 | for i in range(len(ms_img_list)): 59 | ms_img_list[i] = inter_transform(ms_img_list[i]) 60 | 61 | msf_img_list = [] 62 | for i in range(len(ms_img_list)): 63 | msf_img_list.append(ms_img_list[i]) 64 | msf_img_list.append(np.flip(ms_img_list[i], -1).copy()) 65 | 66 | for i in range(len(msf_img_list)): 67 | msf_img_list[i] = intera_transform(msf_img_list[i]) 68 | msf_img_list[i] = msf_img_list[i][None] 69 | 70 | 71 | 72 | return name, msf_img_list, label 73 | def infer_SEAM(name, img, label, weights_dir = "", model=None): 74 | 75 | weights =weights_dir 76 | # network ="SEAM.network.resnet38_SEAM" 77 | num_workers =1 78 | out_cam_pred_alpha =0.26 79 | 80 | # args = parser.parse_args() 81 | crf_alpha = [4,24] 82 | # model = getattr(importlib.import_module(network), 'Net')() 83 | if model is None: 84 | model = resnet38_SEAM.Net() 85 | model.load_state_dict(torch.load(weights)) 86 | 87 | model.eval() 88 | model.cuda() 89 | 90 | 91 | n_gpus = torch.cuda.device_count() 92 | model_replicas = torch.nn.parallel.replicate(model, list(range(n_gpus))) 93 | img_name, img_list, label = msf_img_lists(name, img, label, model) 94 | img_name = img_name[0] 95 | 96 | # for iter, (img_name, img_list, label) in enumerate(infer_data_loader): 97 | # img_name = img_name[0]; label = label[0] 98 | 99 | # img_path = voc12.data.get_img_path(img_name, voc12_root) 100 | # orig_img = np.asarray(Image.open(img_path)) 101 | orig_img = np.asarray(img) 102 | orig_img_size = orig_img.shape[:2] 103 | 104 | def _work(i, img): 105 | with torch.no_grad(): 106 | with torch.cuda.device(i%n_gpus): 107 | # img = ToTensor()(img)[None] 108 | _, cam = model_replicas[i%n_gpus](img.cuda()) 109 | cam = F.upsample(cam[:,1:,:,:], orig_img_size, mode='bilinear', align_corners=False)[0] 110 | cam = cam.cpu().numpy() * label.cpu().clone().view(20, 1, 1).numpy() 111 | if i % 2 == 1: 112 | cam = np.flip(cam, axis=-1) 113 | return cam 114 | 115 | thread_pool = pyutils.BatchThreader(_work, list(enumerate(img_list)), 116 | batch_size=12, prefetch_size=0, processes=num_workers) 117 | 118 | cam_list = thread_pool.pop_results() 119 | 120 | sum_cam = np.sum(cam_list, axis=0) 121 | sum_cam[sum_cam < 0] = 0 122 | cam_max = np.max(sum_cam, (1,2), keepdims=True) 123 | cam_min = np.min(sum_cam, (1,2), keepdims=True) 124 | sum_cam[sum_cam < cam_min+1e-5] = 0 125 | norm_cam = (sum_cam-cam_min-1e-5) / (cam_max - cam_min + 1e-5) 126 | 127 | cam_dict = {} 128 | for i in range(20): 129 | if label[i] > 1e-5: 130 | cam_dict[i] = norm_cam[i] 131 | 132 | # if out_cam is not None: 133 | # np.save(os.path.join(out_cam, img_name + '.npy'), cam_dict) 134 | # print("saved : %s"%os.path.join(out_cam, img_name + '.npy')) 135 | 136 | # if out_cam_pred is not None: 137 | bg_score = [np.ones_like(norm_cam[0])*out_cam_pred_alpha] 138 | pred = np.argmax(np.concatenate((bg_score, norm_cam)), 0) 139 | # scipy.misc.imsave(os.path.join(out_cam_pred, img_name + '.png'), pred.astype(np.uint8)) 140 | # print("saved : %s" % os.path.join(out_cam_pred, img_name + '.png')) 141 | 142 | def _crf_with_alpha(cam_dict, alpha): 143 | v = np.array(list(cam_dict.values())) 144 | bg_score = np.power(1 - np.max(v, axis=0, keepdims=True), alpha) 145 | bgcam_score = np.concatenate((bg_score, v), axis=0) 146 | crf_score = imutils.crf_inference(orig_img, bgcam_score, labels=bgcam_score.shape[0]) 147 | 148 | n_crf_al = dict() 149 | 150 | n_crf_al[0] = crf_score[0] 151 | for i, key in enumerate(cam_dict.keys()): 152 | n_crf_al[key+1] = 
crf_score[i+1] 153 | 154 | return n_crf_al 155 | 156 | # if out_crf is not None: 157 | for t in crf_alpha: 158 | crf = _crf_with_alpha(cam_dict, t) 159 | # folder = out_crf + ('_%.1f'%t) 160 | # if not os.path.exists(folder): 161 | # os.makedirs(folder) 162 | # np.save(os.path.join(folder, img_name + '.npy'), crf) 163 | # print("saved : %s" % os.path.join(folder, img_name + '.npy')) 164 | 165 | # print("DONE infer_SEAM") 166 | return cam_dict, pred, crf 167 | 168 | -------------------------------------------------------------------------------- /scripts/SEAM/train_aff.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import random 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | import voc12.data 8 | from tool import pyutils, imutils, torchutils 9 | import argparse 10 | import importlib 11 | 12 | 13 | 14 | if __name__ == '__main__': 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--batch_size", default=8, type=int) 18 | parser.add_argument("--max_epoches", default=8, type=int) 19 | parser.add_argument("--network", default="network.resnet38_aff", type=str) 20 | parser.add_argument("--lr", default=0.01, type=float) 21 | parser.add_argument("--num_workers", default=8, type=int) 22 | parser.add_argument("--wt_dec", default=5e-4, type=float) 23 | parser.add_argument("--train_list", default="voc12/train_aug.txt", type=str) 24 | parser.add_argument("--val_list", default="voc12/val.txt", type=str) 25 | parser.add_argument("--session_name", default="resnet38_aff", type=str) 26 | parser.add_argument("--crop_size", default=448, type=int) 27 | parser.add_argument("--weights", required=True, type=str) 28 | parser.add_argument("--voc12_root", default='VOC2012', type=str) 29 | parser.add_argument("--la_crf_dir", required=True, type=str) 30 | parser.add_argument("--ha_crf_dir", required=True, type=str) 31 | args = parser.parse_args() 32 | 33 | pyutils.Logger(args.session_name + '.log') 34 | 35 | print(vars(args)) 36 | 37 | model = getattr(importlib.import_module(args.network), 'Net')() 38 | 39 | print(model) 40 | 41 | 42 | train_dataset = voc12.data.VOC12AffDataset(args.train_list, label_la_dir=args.la_crf_dir, label_ha_dir=args.ha_crf_dir, 43 | voc12_root=args.voc12_root, cropsize=args.crop_size, radius=5, 44 | joint_transform_list=[ 45 | None, 46 | None, 47 | imutils.RandomCrop(args.crop_size), 48 | imutils.RandomHorizontalFlip() 49 | ], 50 | img_transform_list=[ 51 | transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), 52 | np.asarray, 53 | model.normalize, 54 | imutils.HWC_to_CHW 55 | ], 56 | label_transform_list=[ 57 | None, 58 | None, 59 | None, 60 | imutils.AvgPool2d(8) 61 | ]) 62 | def worker_init_fn(worker_id): 63 | np.random.seed(1 + worker_id) 64 | train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, 65 | pin_memory=True, drop_last=True, worker_init_fn=worker_init_fn) 66 | max_step = len(train_dataset) // args.batch_size * args.max_epoches 67 | 68 | param_groups = model.get_parameter_groups() 69 | optimizer = torchutils.PolyOptimizer([ 70 | {'params': param_groups[0], 'lr': args.lr, 'weight_decay': args.wt_dec}, 71 | {'params': param_groups[1], 'lr': 2*args.lr, 'weight_decay': 0}, 72 | {'params': param_groups[2], 'lr': 10*args.lr, 'weight_decay': args.wt_dec}, 73 | {'params': param_groups[3], 'lr': 20*args.lr, 'weight_decay': 0} 74 | ], lr=args.lr, 
weight_decay=args.wt_dec, max_step=max_step) 75 | 76 | if args.weights[-7:] == '.params': 77 | import network.resnet38d 78 | assert args.network == "network.resnet38_aff" 79 | weights_dict = network.resnet38d.convert_mxnet_to_torch(args.weights) 80 | else: 81 | weights_dict = torch.load(args.weights) 82 | 83 | model.load_state_dict(weights_dict, strict=False) 84 | model = torch.nn.DataParallel(model).cuda() 85 | model.train() 86 | 87 | avg_meter = pyutils.AverageMeter('loss', 'bg_loss', 'fg_loss', 'neg_loss', 'bg_cnt', 'fg_cnt', 'neg_cnt') 88 | 89 | timer = pyutils.Timer("Session started: ") 90 | 91 | for ep in range(args.max_epoches): 92 | 93 | for iter, pack in enumerate(train_data_loader): 94 | 95 | aff = model.forward(pack[0]) 96 | 97 | bg_label = pack[1][0].cuda(non_blocking=True) 98 | fg_label = pack[1][1].cuda(non_blocking=True) 99 | neg_label = pack[1][2].cuda(non_blocking=True) 100 | 101 | bg_count = torch.sum(bg_label) + 1e-5 102 | fg_count = torch.sum(fg_label) + 1e-5 103 | neg_count = torch.sum(neg_label) + 1e-5 104 | 105 | bg_loss = torch.sum(- bg_label * torch.log(aff + 1e-5)) / bg_count 106 | fg_loss = torch.sum(- fg_label * torch.log(aff + 1e-5)) / fg_count 107 | neg_loss = torch.sum(- neg_label * torch.log(1. + 1e-5 - aff)) / neg_count 108 | 109 | loss = bg_loss/4 + fg_loss/4 + neg_loss/2 110 | 111 | optimizer.zero_grad() 112 | loss.backward() 113 | optimizer.step() 114 | 115 | avg_meter.add({ 116 | 'loss': loss.item(), 117 | 'bg_loss': bg_loss.item(), 'fg_loss': fg_loss.item(), 'neg_loss': neg_loss.item(), 118 | 'bg_cnt': bg_count.item(), 'fg_cnt': fg_count.item(), 'neg_cnt': neg_count.item() 119 | }) 120 | 121 | if (optimizer.global_step - 1) % 50 == 0: 122 | 123 | timer.update_progress(optimizer.global_step / max_step) 124 | 125 | print('Iter:%5d/%5d' % (optimizer.global_step-1, max_step), 126 | 'loss:%.4f %.4f %.4f %.4f' % avg_meter.get('loss', 'bg_loss', 'fg_loss', 'neg_loss'), 127 | 'cnt:%.0f %.0f %.0f' % avg_meter.get('bg_cnt', 'fg_cnt', 'neg_cnt'), 128 | 'imps:%.1f' % ((iter+1) * args.batch_size / timer.get_stage_elapsed()), 129 | 'Fin:%s' % (timer.str_est_finish()), 130 | 'lr: %.4f' % (optimizer.param_groups[0]['lr']), flush=True) 131 | 132 | avg_meter.pop() 133 | 134 | 135 | else: 136 | print('') 137 | timer.reset_stage() 138 | 139 | torch.save(model.module.state_dict(), args.session_name + '.pth') 140 | -------------------------------------------------------------------------------- /src/misc/indexing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | class PathIndex: 7 | 8 | def __init__(self, radius, default_size): 9 | self.radius = radius 10 | self.radius_floor = int(np.ceil(radius) - 1) 11 | 12 | self.search_paths, self.search_dst = self.get_search_paths_dst(self.radius) 13 | 14 | self.path_indices, self.src_indices, self.dst_indices = self.get_path_indices(default_size) 15 | 16 | return 17 | 18 | def get_search_paths_dst(self, max_radius=5): 19 | 20 | coord_indices_by_length = [[] for _ in range(max_radius * 4)] 21 | 22 | search_dirs = [] 23 | 24 | for x in range(1, max_radius): 25 | search_dirs.append((0, x)) 26 | 27 | for y in range(1, max_radius): 28 | for x in range(-max_radius + 1, max_radius): 29 | if x * x + y * y < max_radius ** 2: 30 | search_dirs.append((y, x)) 31 | 32 | for dir in search_dirs: 33 | 34 | length_sq = dir[0] ** 2 + dir[1] ** 2 35 | path_coords = [] 36 | 37 | min_y, max_y = sorted((0, dir[0])) 38 | min_x, max_x = 
sorted((0, dir[1])) 39 | 40 | for y in range(min_y, max_y + 1): 41 | for x in range(min_x, max_x + 1): 42 | 43 | dist_sq = (dir[0] * x - dir[1] * y) ** 2 / length_sq 44 | 45 | if dist_sq < 1: 46 | path_coords.append([y, x]) 47 | 48 | path_coords.sort(key=lambda x: -abs(x[0]) - abs(x[1])) 49 | path_length = len(path_coords) 50 | 51 | coord_indices_by_length[path_length].append(path_coords) 52 | 53 | path_list_by_length = [np.asarray(v) for v in coord_indices_by_length if v] 54 | path_destinations = np.concatenate([p[:, 0] for p in path_list_by_length], axis=0) 55 | 56 | return path_list_by_length, path_destinations 57 | 58 | def get_path_indices(self, size): 59 | 60 | full_indices = np.reshape(np.arange(0, size[0] * size[1], dtype=np.int64), (size[0], size[1])) 61 | 62 | cropped_height = size[0] - self.radius_floor 63 | cropped_width = size[1] - 2 * self.radius_floor 64 | 65 | path_indices = [] 66 | 67 | for paths in self.search_paths: 68 | 69 | path_indices_list = [] 70 | for p in paths: 71 | 72 | coord_indices_list = [] 73 | 74 | for dy, dx in p: 75 | coord_indices = full_indices[dy:dy + cropped_height, 76 | self.radius_floor + dx:self.radius_floor + dx + cropped_width] 77 | coord_indices = np.reshape(coord_indices, [-1]) 78 | 79 | coord_indices_list.append(coord_indices) 80 | 81 | path_indices_list.append(coord_indices_list) 82 | 83 | path_indices.append(np.array(path_indices_list)) 84 | 85 | src_indices = np.reshape(full_indices[:cropped_height, self.radius_floor:self.radius_floor + cropped_width], -1) 86 | dst_indices = np.concatenate([p[:,0] for p in path_indices], axis=0) 87 | 88 | return path_indices, src_indices, dst_indices 89 | 90 | 91 | def edge_to_affinity(edge, paths_indices): 92 | 93 | aff_list = [] 94 | edge = edge.view(edge.size(0), -1) 95 | 96 | for i in range(len(paths_indices)): 97 | if isinstance(paths_indices[i], np.ndarray): 98 | paths_indices[i] = torch.from_numpy(paths_indices[i]) 99 | paths_indices[i] = paths_indices[i].cuda(non_blocking=True) 100 | 101 | for ind in paths_indices: 102 | ind_flat = ind.view(-1) 103 | dist = torch.index_select(edge, dim=-1, index=ind_flat) 104 | dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2)) 105 | aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2) 106 | aff_list.append(aff) 107 | aff_cat = torch.cat(aff_list, dim=1) 108 | 109 | return aff_cat 110 | 111 | 112 | def affinity_sparse2dense(affinity_sparse, ind_from, ind_to, n_vertices): 113 | 114 | ind_from = torch.from_numpy(ind_from) 115 | ind_to = torch.from_numpy(ind_to) 116 | 117 | affinity_sparse = affinity_sparse.view(-1).cpu() 118 | ind_from = ind_from.repeat(ind_to.size(0)).view(-1) 119 | ind_to = ind_to.view(-1) 120 | 121 | indices = torch.stack([ind_from, ind_to]) 122 | indices_tp = torch.stack([ind_to, ind_from]) 123 | 124 | indices_id = torch.stack([torch.arange(0, n_vertices).long(), torch.arange(0, n_vertices).long()]) 125 | 126 | affinity_dense = torch.sparse.FloatTensor(torch.cat([indices, indices_id, indices_tp], dim=1), 127 | torch.cat([affinity_sparse, torch.ones([n_vertices]), affinity_sparse])).to_dense().cuda() 128 | 129 | return affinity_dense 130 | 131 | 132 | def to_transition_matrix(affinity_dense, beta, times): 133 | scaled_affinity = torch.pow(affinity_dense, beta) 134 | 135 | trans_mat = scaled_affinity / torch.sum(scaled_affinity, dim=0, keepdim=True) 136 | for _ in range(times): 137 | trans_mat = torch.matmul(trans_mat, trans_mat) 138 | 139 | return trans_mat 140 | 141 | def propagate_to_edge(x, edge, radius=5, 
beta=10, exp_times=8): 142 | 143 | height, width = x.shape[-2:] 144 | 145 | hor_padded = width+radius*2 146 | ver_padded = height+radius 147 | 148 | path_index = PathIndex(radius=radius, default_size=(ver_padded, hor_padded)) 149 | 150 | edge_padded = F.pad(edge, (radius, radius, 0, radius), mode='constant', value=1.0) 151 | sparse_aff = edge_to_affinity(torch.unsqueeze(edge_padded, 0), 152 | path_index.path_indices) 153 | 154 | dense_aff = affinity_sparse2dense(sparse_aff, path_index.src_indices, 155 | path_index.dst_indices, ver_padded * hor_padded) 156 | dense_aff = dense_aff.view(ver_padded, hor_padded, ver_padded, hor_padded) 157 | dense_aff = dense_aff[:-radius, radius:-radius, :-radius, radius:-radius] 158 | dense_aff = dense_aff.reshape(height * width, height * width) 159 | 160 | trans_mat = to_transition_matrix(dense_aff, beta=beta, times=exp_times) 161 | 162 | x = x.view(-1, height, width) * (1 - edge) 163 | 164 | rw = torch.matmul(x.view(-1, height * width), trans_mat) 165 | rw = rw.view(rw.size(0), 1, height, width) 166 | 167 | return rw -------------------------------------------------------------------------------- /scripts/SEAM/infer_aff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from .tool import imutils 4 | 5 | import argparse 6 | import importlib 7 | import numpy as np 8 | 9 | from .voc12 import data 10 | from torch.utils.data import DataLoader 11 | import scipy.misc 12 | import torch.nn.functional as F 13 | import os.path 14 | from .network import resnet38_aff 15 | from torchvision.transforms import ToPILImage, ToTensor 16 | 17 | 18 | def get_indices_in_radius(height, width, radius): 19 | 20 | search_dist = [] 21 | for x in range(1, radius): 22 | search_dist.append((0, x)) 23 | 24 | for y in range(1, radius): 25 | for x in range(-radius+1, radius): 26 | if x*x + y*y < radius*radius: 27 | search_dist.append((y, x)) 28 | 29 | full_indices = np.reshape(np.arange(0, height * width, dtype=np.int64), 30 | (height, width)) 31 | radius_floor = radius-1 32 | cropped_height = height - radius_floor 33 | cropped_width = width - 2 * radius_floor 34 | 35 | indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor], [-1]) 36 | 37 | indices_from_to_list = [] 38 | 39 | for dy, dx in search_dist: 40 | 41 | indices_to = full_indices[dy:dy + cropped_height, radius_floor + dx:radius_floor + dx + cropped_width] 42 | indices_to = np.reshape(indices_to, [-1]) 43 | 44 | indices_from_to = np.stack((indices_from, indices_to), axis=1) 45 | 46 | indices_from_to_list.append(indices_from_to) 47 | 48 | concat_indices_from_to = np.concatenate(indices_from_to_list, axis=0) 49 | 50 | return concat_indices_from_to 51 | 52 | def HCW_to_CHW(tensor, sal=False): 53 | if sal: 54 | tensor = np.expand_dims(tensor, axis=0) 55 | else: 56 | tensor = np.transpose(tensor, (1, 2, 0)) 57 | return tensor 58 | 59 | 60 | def name_img(name, img, SEAM_model): 61 | name = name 62 | 63 | 64 | # label = ToTensor()(label) 65 | # img = ToPILImage()(img) 66 | model = SEAM_model 67 | unit = 1 68 | scales = [0.5, 1.0, 1.5, 2.0] 69 | inter_transform = torchvision.transforms.Compose( 70 | [np.asarray, 71 | model.normalize, 72 | # ToTensor(), 73 | imutils.HWC_to_CHW 74 | ]) 75 | intera_transform = torchvision.transforms.Compose( 76 | [ToTensor(), 77 | HCW_to_CHW 78 | ]) 79 | 80 | img = inter_transform(img) 81 | img = intera_transform(img) 82 | img = img[None] 83 | 84 | return name, img 85 | 86 | def infer_aff(name, img, 
cam_dict, weights_dir = "", model=None): 87 | 88 | weights =weights_dir 89 | # network ="network.resnet38_aff" 90 | alpha = 6 91 | beta = 8 92 | logt = 6 93 | crf = False 94 | 95 | if model is None: 96 | model = resnet38_aff.Net() 97 | model.load_state_dict(torch.load(weights), strict=False) 98 | 99 | model.eval() 100 | model.cuda() 101 | 102 | # infer_dataset = voc12.data.VOC12ImageDataset(infer_list, voc12_root=voc12_root, 103 | # transform=torchvision.transforms.Compose( 104 | # [np.asarray, 105 | # model.normalize, 106 | # imutils.HWC_to_CHW])) 107 | # infer_data_loader = DataLoader(infer_dataset, shuffle=False, num_workers=num_workers, pin_memory=True) 108 | name, img = name_img(name, img, model) 109 | 110 | # for iter, (name, img) in enumerate(infer_data_loader): 111 | 112 | # name = name[0] 113 | # print(iter) 114 | 115 | orig_shape = img.shape 116 | padded_size = (int(np.ceil(img.shape[2]/8)*8), int(np.ceil(img.shape[3]/8)*8)) 117 | 118 | p2d = (0, padded_size[1] - img.shape[3], 0, padded_size[0] - img.shape[2]) 119 | img = F.pad(img, p2d) 120 | 121 | dheight = int(np.ceil(img.shape[2]/8)) 122 | dwidth = int(np.ceil(img.shape[3]/8)) 123 | 124 | # cam = np.load(os.path.join(cam_dir, name + '.npy'), allow_pickle=True).item() 125 | cam = cam_dict 126 | 127 | cam_full_arr = np.zeros((21, orig_shape[2], orig_shape[3]), np.float32) 128 | for k, v in cam.items(): 129 | cam_full_arr[k+1] = v 130 | cam_full_arr[0] = (1 - np.max(cam_full_arr[1:], (0), keepdims=False))**alpha 131 | #cam_full_arr[0] = 0.2 132 | cam_full_arr = np.pad(cam_full_arr, ((0, 0), (0, p2d[3]), (0, p2d[1])), mode='constant') 133 | 134 | with torch.no_grad(): 135 | aff_mat = torch.pow(model.forward(img.cuda(), True), beta) 136 | 137 | trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True) 138 | for _ in range(logt): 139 | trans_mat = torch.matmul(trans_mat, trans_mat) 140 | 141 | cam_full_arr = torch.from_numpy(cam_full_arr) 142 | cam_full_arr = F.avg_pool2d(cam_full_arr, 8, 8) 143 | 144 | cam_vec = cam_full_arr.view(21, -1) 145 | 146 | cam_rw = torch.matmul(cam_vec.cuda(), trans_mat) 147 | cam_rw = cam_rw.view(1, 21, dheight, dwidth) 148 | 149 | cam_rw = torch.nn.Upsample((img.shape[2], img.shape[3]), mode='bilinear')(cam_rw) 150 | 151 | if crf: 152 | img_8 = img[0].numpy().transpose((1,2,0))#F.interpolate(img, (dheight,dwidth), mode='bilinear')[0].numpy().transpose((1,2,0)) 153 | img_8 = np.ascontiguousarray(img_8) 154 | mean = (0.485, 0.456, 0.406) 155 | std = (0.229, 0.224, 0.225) 156 | img_8[:,:,0] = (img_8[:,:,0]*std[0] + mean[0])*255 157 | img_8[:,:,1] = (img_8[:,:,1]*std[1] + mean[1])*255 158 | img_8[:,:,2] = (img_8[:,:,2]*std[2] + mean[2])*255 159 | img_8[img_8 > 255] = 255 160 | img_8[img_8 < 0] = 0 161 | img_8 = img_8.astype(np.uint8) 162 | cam_rw = cam_rw[0].cpu().numpy() 163 | cam_rw = imutils.crf_inference(img_8, cam_rw, t=1) 164 | cam_rw = torch.from_numpy(cam_rw).view(1, 21, img.shape[2], img.shape[3]).cuda() 165 | 166 | 167 | _, cam_rw_pred = torch.max(cam_rw, 1) 168 | 169 | preds = np.uint8(cam_rw_pred.cpu().data[0])[:orig_shape[2], :orig_shape[3]] 170 | probs = cam_rw.cpu().data[0][:, :orig_shape[2], :orig_shape[3]] 171 | # scipy.misc.imsave(os.path.join(out_rw, name + '.png'), res) 172 | # print("saved : %s" %os.path.join(out_rw, name + '.png')) 173 | assert probs.shape[1] == preds.shape[0] 174 | assert probs.shape[2] == preds.shape[1] 175 | print("Done infer_aff") 176 | return preds, probs 177 | -------------------------------------------------------------------------------- 
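The two entry points above are meant to be chained: infer_SEAM produces a per-class CAM dictionary (plus a background-thresholded prediction and a CRF-refined variant), and infer_aff refines that dictionary with the affinity network's random walk. Below is a minimal sketch of that hand-off, not a file in this repository: the image, label vector and weight paths are assumptions (the paths mirror scripts/test_affinity.py further down), a CUDA device is required, and note that infer_aff builds resnet38_aff.Net() with no arguments by default, whereas this repository's Net takes (n_classes, exp_dict), so in practice a pre-built model may need to be passed in through the model= argument.

# Illustrative sketch only -- assumed paths and hypothetical inputs, see note above.
from PIL import Image
import torch

from SEAM.infer_SEAM import infer_SEAM
from SEAM.infer_aff import infer_aff

img = Image.open('example.jpg').convert('RGB')   # hypothetical input image
label = torch.zeros(20)                          # VOC-style image-level label vector
label[11] = 1                                    # mark one foreground class as present

# 1) multi-scale, flip-augmented CAMs from the SEAM classifier
cam_dict, cam_pred, crf_dict = infer_SEAM(
    'example', img, label,
    weights_dir='/mnt/public/weights/resnet38_SEAM.pth')      # assumed location

# 2) random-walk refinement of the CAMs with the affinity network
preds, probs = infer_aff(
    'example', img, cam_dict,
    weights_dir='/mnt/public/weights/resnet38_aff_SEAM.pth')  # assumed location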
/scripts/across_fish.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 4 | sys.path.insert(0, path) 5 | 6 | from haven import haven_chk as hc 7 | from haven import haven_results as hr 8 | from haven import haven_utils as hu 9 | import torch 10 | import torchvision 11 | import tqdm 12 | import pandas as pd 13 | import pprint 14 | import itertools 15 | import os 16 | import pylab as plt 17 | import time 18 | import numpy as np 19 | 20 | from src import models 21 | from src import datasets 22 | from src import utils as ut 23 | from src.models import metrics 24 | 25 | import argparse 26 | 27 | from torch.utils.data import sampler 28 | from torch.utils.data.sampler import RandomSampler 29 | from torch.backends import cudnn 30 | from torch.nn import functional as F 31 | from torch.utils.data import DataLoader 32 | import pandas as pd 33 | 34 | cudnn.benchmark = True 35 | 36 | if __name__ == "__main__": 37 | savedir_base = '/mnt/public/results/toolkit/weak_supervision' 38 | 39 | # hash_list = ['b04090f27c7c52bcec65f6ba455ed2d8', 40 | # '6d4af38d64b23586e71a198de2608333', 41 | # '84ced18cf5c1fb3ad5820cc1b55a38fa', 42 | # '63f29eec3dbe1e03364f198ed7d4b414', 43 | # '017e7441c2f581b6fee9e3ac6f574edc'] 44 | 45 | # hash_dct = {'b04090f27c7c52bcec65f6ba455ed2d8': 'Fully_Supervised', 46 | # '6d4af38d64b23586e71a198de2608333': 'LCFCN', 47 | # '84ced18cf5c1fb3ad5820cc1b55a38fa': 'LCFCN+Affinity_(ours)', 48 | # '63f29eec3dbe1e03364f198ed7d4b414': 'Point-level_Loss ', 49 | # '017e7441c2f581b6fee9e3ac6f574edc': 'Cross_entropy_Loss+pseudo-mask'} 50 | hash_dct = {'a55d2c5dda331b1a0e191b104406dd1c': 'LCFCN', 51 | '13b0f4e395b6dc5368f7965c20e75612': 'A-LCFCN', 52 | 'fcc1acac9ff5c2fa776d65ac76c3892b': 'A-LCFCN + PM'} 53 | hash_list = ['a55d2c5dda331b1a0e191b104406dd1c', 54 | '13b0f4e395b6dc5368f7965c20e75612', 55 | 'fcc1acac9ff5c2fa776d65ac76c3892b'] 56 | datadir = '/mnt/public/datasets/DeepFish/' 57 | 58 | score_list = [] 59 | for hash_id in hash_list: 60 | fname = os.path.join('/mnt/public/predictions/fish/%s.pkl' % hash_id) 61 | exp_dict = hu.load_json(os.path.join(savedir_base, hash_id, 'exp_dict.json')) 62 | if os.path.exists(fname) and 0: 63 | print('FOUND:', fname) 64 | val_dict = hu.load_pkl(fname) 65 | else: 66 | 67 | train_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"], 68 | split='train', 69 | datadir=datadir, 70 | exp_dict=exp_dict, 71 | dataset_size=exp_dict['dataset_size']) 72 | 73 | test_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"], 74 | split='test', 75 | datadir=datadir, 76 | exp_dict=exp_dict, 77 | dataset_size=exp_dict['dataset_size']) 78 | 79 | test_loader = DataLoader(test_set, 80 | batch_size=1, 81 | collate_fn=ut.collate_fn, 82 | num_workers=0) 83 | pprint.pprint(exp_dict) 84 | # Model 85 | # ================== 86 | model = models.get_model(model_dict=exp_dict['model'], 87 | exp_dict=exp_dict, 88 | train_set=train_set).cuda() 89 | 90 | model_path = os.path.join(savedir_base, hash_id, 'model_best.pth') 91 | 92 | # load best model 93 | model.load_state_dict(hu.torch_load(model_path)) 94 | # loop over the val_loader and saves image 95 | # get counts 96 | counts = [] 97 | for i, batch in enumerate(test_loader): 98 | count = float((batch['points'] == 1).sum()) 99 | counts += [count] 100 | hu.save_image('.tmp/counts/%d_%d.png' % (i, len(batch['point_list'][0])//2), hu.denormalize(batch['images'], mode='rgb'), mask=batch['masks'].numpy()) 101 
| counts = np.array(counts) 102 | 103 | val_dict = {} 104 | val_dict_lst = [] 105 | for c in np.unique(counts): 106 | val_meter = metrics.SegMeter(split=test_loader.dataset.split) 107 | 108 | for i, batch in enumerate(tqdm.tqdm(test_loader)): 109 | count = float((batch['points'] == 1).sum()) 110 | if count != c: 111 | continue 112 | 113 | val_meter.val_on_batch(model, batch) 114 | score_dict = val_meter.get_avg_score() 115 | # pprint.pprint(score_dict) 116 | 117 | val_dict[c] = val_meter.get_avg_score() 118 | val_dict_dfc = pd.DataFrame([val_meter.get_avg_score()]) 119 | val_dict_dfc.insert(0, "Count", int(c), True) 120 | val_dict_dfc.rename( 121 | columns={'test_score': 'mIoU', 'test_class0': 'IoU class 0', 'test_class1': 'IoU class 1', 122 | 'test_mae': 'MAE', 'test_game': 'GAME'}, inplace=True) 123 | val_dict_lst.append(val_dict_dfc) 124 | val_dict_df = pd.concat(val_dict_lst, axis=0) 125 | val_dict_df.to_csv(os.path.join('/mnt/public/predictions/fish/', "%s_count_score_df.csv" % hash_id), 126 | index=False) 127 | val_dict_df.to_latex(os.path.join('/mnt/public/predictions/fish/', "%s_count_score_df.tex" % hash_id), 128 | index=False, caption=hash_dct[hash_id], label=hash_dct[hash_id]) 129 | 130 | hu.save_pkl(fname, val_dict) 131 | 132 | val_dict['model'] = exp_dict['model'] 133 | val_dict['hash_id'] = hash_id 134 | score_list += [val_dict] 135 | 136 | print(pd.DataFrame(score_list)) 137 | # score_df = pd.DataFrame(score_list) 138 | # score_df.to_csv(os.path.join('/mnt/public/predictions/fish/', "score_df.csv")) 139 | # score_df.to_latex(os.path.join('/mnt/public/predictions/fish/', "count_score_df.tex")) 140 | -------------------------------------------------------------------------------- /scripts/SEAM/tool/torchutils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | import os.path 6 | import random 7 | import numpy as np 8 | from tool import imutils 9 | 10 | class PolyOptimizer(torch.optim.SGD): 11 | 12 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9): 13 | super().__init__(params, lr, weight_decay) 14 | 15 | self.global_step = 0 16 | self.max_step = max_step 17 | self.momentum = momentum 18 | 19 | self.__initial_lr = [group['lr'] for group in self.param_groups] 20 | 21 | 22 | def step(self, closure=None): 23 | 24 | if self.global_step < self.max_step: 25 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 26 | 27 | for i in range(len(self.param_groups)): 28 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 29 | 30 | super().step(closure) 31 | 32 | self.global_step += 1 33 | 34 | 35 | class BatchNorm2dFixed(torch.nn.Module): 36 | 37 | def __init__(self, num_features, eps=1e-5): 38 | super(BatchNorm2dFixed, self).__init__() 39 | self.num_features = num_features 40 | self.eps = eps 41 | self.weight = torch.nn.Parameter(torch.Tensor(num_features)) 42 | self.bias = torch.nn.Parameter(torch.Tensor(num_features)) 43 | self.register_buffer('running_mean', torch.zeros(num_features)) 44 | self.register_buffer('running_var', torch.ones(num_features)) 45 | 46 | 47 | def forward(self, input): 48 | 49 | return F.batch_norm( 50 | input, self.running_mean, self.running_var, self.weight, self.bias, 51 | False, eps=self.eps) 52 | 53 | def __call__(self, x): 54 | return self.forward(x) 55 | 56 | 57 | class SegmentationDataset(Dataset): 58 | def __init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, 
cropsize=None, 59 | img_transform=None, mask_transform=None): 60 | self.img_name_list_path = img_name_list_path 61 | self.img_dir = img_dir 62 | self.label_dir = label_dir 63 | 64 | self.img_transform = img_transform 65 | self.mask_transform = mask_transform 66 | 67 | self.img_name_list = open(self.img_name_list_path).read().splitlines() 68 | 69 | self.rescale = rescale 70 | self.flip = flip 71 | self.cropsize = cropsize 72 | 73 | def __len__(self): 74 | return len(self.img_name_list) 75 | 76 | def __getitem__(self, idx): 77 | 78 | name = self.img_name_list[idx] 79 | 80 | img = Image.open(os.path.join(self.img_dir, name + '.jpg')).convert("RGB") 81 | mask = Image.open(os.path.join(self.label_dir, name + '.png')) 82 | 83 | if self.rescale is not None: 84 | s = self.rescale[0] + random.random() * (self.rescale[1] - self.rescale[0]) 85 | adj_size = (round(img.size[0]*s/8)*8, round(img.size[1]*s/8)*8) 86 | img = img.resize(adj_size, resample=Image.CUBIC) 87 | mask = img.resize(adj_size, resample=Image.NEAREST) 88 | 89 | if self.img_transform is not None: 90 | img = self.img_transform(img) 91 | if self.mask_transform is not None: 92 | mask = self.mask_transform(mask) 93 | 94 | if self.cropsize is not None: 95 | img, mask = imutils.random_crop([img, mask], self.cropsize, (0, 255)) 96 | 97 | mask = imutils.RescaleNearest(0.125)(mask) 98 | 99 | if self.flip is True and bool(random.getrandbits(1)): 100 | img = np.flip(img, 1).copy() 101 | mask = np.flip(mask, 1).copy() 102 | 103 | img = np.transpose(img, (2, 0, 1)) 104 | 105 | return name, img, mask 106 | 107 | 108 | class ExtractAffinityLabelInRadius(): 109 | 110 | def __init__(self, cropsize, radius=5): 111 | self.radius = radius 112 | 113 | self.search_dist = [] 114 | 115 | for x in range(1, radius): 116 | self.search_dist.append((0, x)) 117 | 118 | for y in range(1, radius): 119 | for x in range(-radius+1, radius): 120 | if x*x + y*y < radius*radius: 121 | self.search_dist.append((y, x)) 122 | 123 | self.radius_floor = radius-1 124 | 125 | self.crop_height = cropsize - self.radius_floor 126 | self.crop_width = cropsize - 2 * self.radius_floor 127 | return 128 | 129 | def __call__(self, label): 130 | 131 | labels_from = label[:-self.radius_floor, self.radius_floor:-self.radius_floor] 132 | labels_from = np.reshape(labels_from, [-1]) 133 | 134 | labels_to_list = [] 135 | valid_pair_list = [] 136 | 137 | for dy, dx in self.search_dist: 138 | labels_to = label[dy:dy+self.crop_height, self.radius_floor+dx:self.radius_floor+dx+self.crop_width] 139 | labels_to = np.reshape(labels_to, [-1]) 140 | 141 | valid_pair = np.logical_and(np.less(labels_to, 255), np.less(labels_from, 255)) 142 | 143 | labels_to_list.append(labels_to) 144 | valid_pair_list.append(valid_pair) 145 | 146 | bc_labels_from = np.expand_dims(labels_from, 0) 147 | concat_labels_to = np.stack(labels_to_list) 148 | concat_valid_pair = np.stack(valid_pair_list) 149 | 150 | pos_affinity_label = np.equal(bc_labels_from, concat_labels_to) 151 | 152 | bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(bc_labels_from, 0)).astype(np.float32) 153 | 154 | fg_pos_affinity_label = np.logical_and(np.logical_and(pos_affinity_label, np.not_equal(bc_labels_from, 0)), concat_valid_pair).astype(np.float32) 155 | 156 | neg_affinity_label = np.logical_and(np.logical_not(pos_affinity_label), concat_valid_pair).astype(np.float32) 157 | 158 | return bg_pos_affinity_label, fg_pos_affinity_label, neg_affinity_label 159 | 160 | class AffinityFromMaskDataset(SegmentationDataset): 161 | def 
__init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, cropsize=None, 162 | img_transform=None, mask_transform=None, radius=5): 163 | super().__init__(img_name_list_path, img_dir, label_dir, rescale, flip, cropsize, img_transform, mask_transform) 164 | 165 | self.radius = radius 166 | 167 | self.extract_aff_lab_func = ExtractAffinityLabelInRadius(cropsize=cropsize//8, radius=radius) 168 | 169 | def __getitem__(self, idx): 170 | name, img, mask = super().__getitem__(idx) 171 | 172 | aff_label = self.extract_aff_lab_func(mask) 173 | 174 | return name, img, aff_label 175 | -------------------------------------------------------------------------------- /scripts/test_affinity.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 3 | sys.path.insert(0, path) 4 | 5 | import torch 6 | import torchvision 7 | import tqdm 8 | import pandas as pd 9 | import pprint 10 | import itertools 11 | import os 12 | import pylab as plt 13 | import exp_configs 14 | import time 15 | import numpy as np 16 | 17 | from src import models 18 | from src import datasets 19 | from src import utils as ut 20 | 21 | from src import misc 22 | import argparse 23 | 24 | from torch.utils.data import sampler 25 | from torch.utils.data.sampler import RandomSampler 26 | from torch.backends import cudnn 27 | from torch.nn import functional as F 28 | from torch.utils.data import DataLoader 29 | from src import datasets 30 | # from src import optimizers 31 | import torchvision 32 | from src import models 33 | cudnn.benchmark = True 34 | 35 | from haven import haven_utils as hu 36 | from haven import haven_img as hi 37 | from haven import haven_results as hr 38 | from haven import haven_chk as hc 39 | # from src import looc_utils as lu 40 | from PIL import Image 41 | import scipy.io 42 | from src import models 43 | from src import utils as ut 44 | import exp_configs 45 | 46 | import argparse 47 | import numpy as np 48 | import time 49 | import cv2, pprint 50 | from PIL import Image 51 | import torch 52 | from SEAM.infer_SEAM import infer_SEAM 53 | from SEAM.infer_aff import infer_aff 54 | from SEAM.network import resnet38_SEAM, resnet38_aff 55 | 56 | 57 | def get_indices_in_radius(height, width, radius=5): 58 | search_dist = [] 59 | for x in range(1, radius): 60 | search_dist.append((0, x)) 61 | 62 | for y in range(1, radius): 63 | for x in range(-radius+1, radius): 64 | if x*x + y*y < radius*radius: 65 | search_dist.append((y, x)) 66 | 67 | full_indices = np.reshape(np.arange(0, height * width, dtype=np.int64), 68 | (height, width)) 69 | radius_floor = radius-1 70 | cropped_height = height - radius_floor 71 | cropped_width = width - 2 * radius_floor 72 | 73 | indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor], [-1]) 74 | 75 | indices_from_to_list = [] 76 | 77 | for dy, dx in search_dist: 78 | 79 | indices_to = full_indices[dy:dy + cropped_height, radius_floor + dx:radius_floor + dx + cropped_width] 80 | indices_to = np.reshape(indices_to, [-1]) 81 | 82 | indices_from_to = np.stack((indices_from, indices_to), axis=1) 83 | 84 | indices_from_to_list.append(indices_from_to) 85 | 86 | concat_indices_from_to = np.concatenate(indices_from_to_list, axis=0) 87 | 88 | return concat_indices_from_to 89 | 90 | def get_affinity_labels(segm_map, indices_from, indices_to, n_classes=2): 91 | # _, n_classes, _, _ = segm_map.shape 92 | segm_map_flat = np.reshape(segm_map, -1) 
93 | 94 | segm_label_from = np.expand_dims(segm_map_flat[indices_from], axis=0) 95 | segm_label_to = segm_map_flat[indices_to] 96 | 97 | valid_label = np.logical_and(np.less(segm_label_from, n_classes), np.less(segm_label_to, n_classes)) 98 | 99 | equal_label = np.equal(segm_label_from, segm_label_to) 100 | 101 | pos_affinity_label = np.logical_and(equal_label, valid_label) 102 | 103 | bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(segm_label_from, 0)).astype(np.float32) 104 | fg_pos_affinity_label = np.logical_and(pos_affinity_label, np.greater(segm_label_from, 0)).astype(np.float32) 105 | 106 | neg_affinity_label = np.logical_and(np.logical_not(equal_label), valid_label).astype(np.float32) 107 | 108 | return torch.from_numpy(bg_pos_affinity_label), torch.from_numpy(fg_pos_affinity_label), \ 109 | torch.from_numpy(neg_affinity_label) 110 | 111 | 112 | 113 | 114 | if __name__ == "__main__": 115 | exp_dict = {'batch_size': 1, 116 | 'dataset': {'n_classes': 2, 'name': 'JcuFish'}, 117 | 'dataset_size': {'train': 'all', 'val': 'all'}, 118 | 'lr': 1e-06, 119 | 'max_epoch': 100, 120 | 'model': {'base': 'fcn8_vgg16', 121 | 'loss': 'point_level', 122 | 'n_channels': 3, 123 | 'n_classes': 2, 124 | 'name': 'semseg'}, 125 | 'num_channels': 1, 126 | 'optimizer': 'adam'} 127 | pprint.pprint(exp_dict) 128 | train_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"], 129 | split="train", 130 | datadir='/mnt/public/datasets/DeepFish', 131 | exp_dict=exp_dict, 132 | dataset_size=exp_dict['dataset_size']) 133 | 134 | model_seam = resnet38_SEAM.Net().cuda() 135 | model_seam.load_state_dict(torch.load(os.path.join('/mnt/public/weights', 'resnet38_SEAM.pth'))) 136 | 137 | model_aff = resnet38_aff.Net().cuda() 138 | model_aff.load_state_dict(torch.load(os.path.join('/mnt/public/weights', 'resnet38_aff_SEAM.pth')), strict=False) 139 | 140 | # ut.generate_seam_segmentation(train_set, 141 | # path_base='/mnt/datasets/public/issam/seam', 142 | # # path_base='D:/Issam/SEAM_model/' 143 | # ) 144 | # stop 145 | model = models.get_model(model_dict=exp_dict['model'], exp_dict=exp_dict, train_set=train_set).cuda() 146 | exp_id = hu.hash_dict(exp_dict) 147 | fname = os.path.join('/mnt/public/results/toolkit/weak_supervision', exp_id, 'model.pth') 148 | model.model_base.load_state_dict(torch.load(fname)['model'], strict=False) 149 | 150 | for k in range(5): 151 | batch_id = np.where(train_set.labels)[0][k] 152 | batch = ut.collate_fn([train_set[batch_id]]) 153 | logits = F.softmax(model.model_base.forward(batch['images'].cuda()), dim=1) 154 | 155 | img = batch['images'].cuda() 156 | logits_new = model_aff.apply_affinity( batch['images'], logits, crf=0) 157 | 158 | i1 = hu.save_image('old.png', 159 | img=hu.denormalize(img, mode='rgb'), 160 | mask=logits.argmax(dim=1).cpu().numpy(), return_image=True) 161 | 162 | i2 = hu.save_image('new.png', 163 | img=hu.denormalize(img, mode='rgb'), 164 | mask=logits_new.argmax(dim=1).cpu().numpy(), return_image=True) 165 | hu.save_image('tmp/tmp%d.png' % k, np.concatenate([np.array(i1), np.array(i2)], axis=1)) 166 | print('saved %d' %k) 167 | -------------------------------------------------------------------------------- /scripts/SEAM/infer_SEAM_good.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import cv2 5 | import os 6 | import voc12.data 7 | import scipy.misc 8 | import importlib 9 | from torch.utils.data import DataLoader 10 | import torchvision 11 | from tool import 
imutils, pyutils#, visualization 12 | import argparse 13 | from PIL import Image 14 | import torch.nn.functional as F 15 | import pandas as pd 16 | import PIL.Image 17 | 18 | 19 | def msf_img_lists(name, img, label, SEAM_model): 20 | name, img, label = name, img, label 21 | model = SEAM_model 22 | unit = 1 23 | scales = [0.5, 1.0, 1.5, 2.0] 24 | inter_transform = torchvision.transforms.Compose( 25 | [np.asarray, 26 | model.normalize, 27 | imutils.HWC_to_CHW]) 28 | 29 | rounded_size = (int(round(img.size[0] / unit) * unit), int(round(img.size[1] / unit) * unit)) 30 | 31 | ms_img_list = [] 32 | for s in scales: 33 | target_size = (round(rounded_size[0] * s), 34 | round(rounded_size[1] * s)) 35 | s_img = img.resize(target_size, resample=PIL.Image.CUBIC) 36 | ms_img_list.append(s_img) 37 | 38 | if inter_transform: 39 | for i in range(len(ms_img_list)): 40 | ms_img_list[i] = inter_transform(ms_img_list[i]) 41 | 42 | msf_img_list = [] 43 | for i in range(len(ms_img_list)): 44 | msf_img_list.append(ms_img_list[i]) 45 | msf_img_list.append(np.flip(ms_img_list[i], -1).copy()) 46 | 47 | return name, msf_img_list, label 48 | 49 | def infer_SEAM(): 50 | # parser = argparse.ArgumentParser() 51 | # parser.add_argument("--weights", required=True, type=str) 52 | # parser.add_argument("--network", default="network.resnet38_SEAM", type=str) 53 | # parser.add_argument("--infer_list", default="voc12/train.txt", type=str) 54 | # parser.add_argument("--num_workers", default=8, type=int) 55 | # parser.add_argument("--voc12_root", default='VOC2012', type=str) 56 | # parser.add_argument("--out_cam", default=None, type=str) 57 | # parser.add_argument("--out_crf", default=None, type=str) 58 | # parser.add_argument("--out_cam_pred", default=None, type=str) 59 | # parser.add_argument("--out_cam_pred_alpha", default=0.26, type=float) 60 | 61 | """ 62 | --weights "C:/Users/Alzay/NewPC_OneDrive/OneDrive - James Cook University/dev/ISPS/Alz_temp/SEAM/SEAM_model/resnet38_SEAM.pth" 63 | --infer_list "D:/prototypes/_cam_dir/val.txt" 64 | --out_cam "D:/prototypes/_cam_dir" 65 | --voc12_root "D:/Datasets/Pascal_Voc_Dataset/pascal_2012/VOCdevkit/VOC2012/" 66 | --out_crf "D:\prototypes\_crf_dir" 67 | """ 68 | weights ="C:/Users/Alzay/NewPC_OneDrive/OneDrive - James Cook University/dev/ISPS/Alz_temp/SEAM/SEAM_model/resnet38_SEAM.pth" 69 | network ="network.resnet38_SEAM" 70 | num_workers =1 71 | voc12_root ="D:/Datasets/Pascal_Voc_Dataset/pascal_2012/VOCdevkit/VOC2012/" 72 | infer_list =os.path.join(voc12_root,"ImageSets", "Segmentation", "trainval.txt") 73 | # infer_list =[os.path.splitext(filename)[0] for filename in os.listdir(os.path.join(voc12_root,"JPEGImages"))] 74 | out_cam ="D:/prototypes/_cam_dir" 75 | out_crf ="D:\prototypes\_crf_dir" 76 | out_cam_pred ="D:/prototypes/out_cam_pred/" 77 | out_cam_pred_alpha =0.26 78 | 79 | # args = parser.parse_args() 80 | crf_alpha = [4,24] 81 | model = getattr(importlib.import_module(network), 'Net')() 82 | model.load_state_dict(torch.load(weights)) 83 | 84 | model.eval() 85 | model.cuda() 86 | 87 | infer_dataset = voc12.data.VOC12ClsDatasetMSF(infer_list, voc12_root=voc12_root, 88 | scales=[0.5, 1.0, 1.5, 2.0], 89 | inter_transform=torchvision.transforms.Compose( 90 | [np.asarray, 91 | model.normalize, 92 | imutils.HWC_to_CHW])) 93 | 94 | infer_data_loader = DataLoader(infer_dataset, shuffle=False, num_workers=num_workers, pin_memory=True) 95 | 96 | n_gpus = torch.cuda.device_count() 97 | model_replicas = torch.nn.parallel.replicate(model, list(range(n_gpus))) 98 | 99 | for iter, 
(img_name, img_list, label) in enumerate(infer_data_loader): 100 | img_name = img_name[0]; label = label[0] 101 | 102 | img_path = voc12.data.get_img_path(img_name, voc12_root) 103 | orig_img = np.asarray(Image.open(img_path)) 104 | orig_img_size = orig_img.shape[:2] 105 | 106 | def _work(i, img): 107 | with torch.no_grad(): 108 | with torch.cuda.device(i%n_gpus): 109 | _, cam = model_replicas[i%n_gpus](img.cuda()) 110 | cam = F.upsample(cam[:,1:,:,:], orig_img_size, mode='bilinear', align_corners=False)[0] 111 | cam = cam.cpu().numpy() * label.clone().view(20, 1, 1).numpy() 112 | if i % 2 == 1: 113 | cam = np.flip(cam, axis=-1) 114 | return cam 115 | 116 | thread_pool = pyutils.BatchThreader(_work, list(enumerate(img_list)), 117 | batch_size=12, prefetch_size=0, processes=num_workers) 118 | 119 | cam_list = thread_pool.pop_results() 120 | 121 | sum_cam = np.sum(cam_list, axis=0) 122 | sum_cam[sum_cam < 0] = 0 123 | cam_max = np.max(sum_cam, (1,2), keepdims=True) 124 | cam_min = np.min(sum_cam, (1,2), keepdims=True) 125 | sum_cam[sum_cam < cam_min+1e-5] = 0 126 | norm_cam = (sum_cam-cam_min-1e-5) / (cam_max - cam_min + 1e-5) 127 | 128 | cam_dict = {} 129 | for i in range(20): 130 | if label[i] > 1e-5: 131 | cam_dict[i] = norm_cam[i] 132 | 133 | if out_cam is not None: 134 | np.save(os.path.join(out_cam, img_name + '.npy'), cam_dict) 135 | 136 | if out_cam_pred is not None: 137 | bg_score = [np.ones_like(norm_cam[0])*out_cam_pred_alpha] 138 | pred = np.argmax(np.concatenate((bg_score, norm_cam)), 0) 139 | scipy.misc.imsave(os.path.join(out_cam_pred, img_name + '.png'), pred.astype(np.uint8)) 140 | 141 | def _crf_with_alpha(cam_dict, alpha): 142 | v = np.array(list(cam_dict.values())) 143 | bg_score = np.power(1 - np.max(v, axis=0, keepdims=True), alpha) 144 | bgcam_score = np.concatenate((bg_score, v), axis=0) 145 | crf_score = imutils.crf_inference(orig_img, bgcam_score, labels=bgcam_score.shape[0]) 146 | 147 | n_crf_al = dict() 148 | 149 | n_crf_al[0] = crf_score[0] 150 | for i, key in enumerate(cam_dict.keys()): 151 | n_crf_al[key+1] = crf_score[i+1] 152 | 153 | return n_crf_al 154 | 155 | if out_crf is not None: 156 | for t in crf_alpha: 157 | crf = _crf_with_alpha(cam_dict, t) 158 | folder = out_crf + ('_%.1f'%t) 159 | if not os.path.exists(folder): 160 | os.makedirs(folder) 161 | np.save(os.path.join(folder, img_name + '.npy'), crf) 162 | 163 | print(iter) 164 | 165 | 166 | if __name__ == '__main__': 167 | 168 | infer_SEAM() -------------------------------------------------------------------------------- /src/models/networks/fcn8_vgg16_multiscale.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision 3 | import torch 4 | from skimage import morphology as morph 5 | import numpy as np 6 | from src.modules.eprop import eprop 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | #----------- LC-FCN8 10 | class FCN8VGG16(nn.Module): 11 | def __init__(self, n_classes): 12 | super().__init__() 13 | self.n_classes = n_classes 14 | # PREDEFINE LAYERS 15 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) 16 | self.relu = nn.ReLU(inplace=True) 17 | 18 | # VGG16 PART 19 | self.conv1_1 = conv3x3(3, 64, stride=1, padding=100) 20 | self.conv1_2 = conv3x3(64, 64) 21 | 22 | self.conv2_1 = conv3x3(64, 128) 23 | self.conv2_2 = conv3x3(128, 128) 24 | 25 | self.conv3_1 = conv3x3(128, 256) 26 | self.conv3_2 = conv3x3(256, 256) 27 | self.conv3_3 = conv3x3(256, 256) 28 | 29 | self.conv4_1 = conv3x3(256, 512) 30 | 
self.conv4_2 = conv3x3(512, 512) 31 | self.conv4_3 = conv3x3(512, 512) 32 | 33 | self.conv5_1 = conv3x3(512, 512) 34 | self.conv5_2 = conv3x3(512, 512) 35 | self.conv5_3 = conv3x3(512, 512) 36 | 37 | self.fc6 = nn.Conv2d(512, 4096, kernel_size=7, stride=1, padding=0) 38 | self.dropout_f6 = nn.Dropout() 39 | self.fc7 = nn.Conv2d(4096, 4096, kernel_size=1, stride=1, padding=0) 40 | self.dropout_f7 = nn.Dropout() 41 | # SEMANTIC SEGMENTAION PART 42 | self.scoring_layer = nn.Conv2d(4096, self.n_classes, kernel_size=1, 43 | stride=1, padding=0) 44 | 45 | self.upscore2 = nn.ConvTranspose2d(self.n_classes, self.n_classes, 46 | kernel_size=4, stride=2, bias=False) 47 | self.upscore_pool4 = nn.ConvTranspose2d(self.n_classes, self.n_classes, 48 | kernel_size=4, stride=2, bias=False) 49 | self.upscore8 = nn.ConvTranspose2d(self.n_classes, self.n_classes, 50 | kernel_size=16, stride=8, bias=False) 51 | 52 | # Initilize Weights 53 | self.scoring_layer.weight.data.zero_() 54 | self.scoring_layer.bias.data.zero_() 55 | 56 | self.score_pool3 = nn.Conv2d(256, self.n_classes, kernel_size=1) 57 | self.score_pool4 = nn.Conv2d(512, self.n_classes, kernel_size=1) 58 | self.score_pool3.weight.data.zero_() 59 | self.score_pool3.bias.data.zero_() 60 | self.score_pool4.weight.data.zero_() 61 | self.score_pool4.bias.data.zero_() 62 | 63 | self.upscore2.weight.data.copy_(get_upsampling_weight(self.n_classes, self.n_classes, 4)) 64 | self.upscore_pool4.weight.data.copy_(get_upsampling_weight(self.n_classes, self.n_classes, 4)) 65 | self.upscore8.weight.data.copy_(get_upsampling_weight(self.n_classes, self.n_classes, 16)) 66 | self.eprop = eprop.EmbeddingPropagation() 67 | # Pretrained layers 68 | pth_url = 'https://download.pytorch.org/models/vgg16-397923af.pth' # download from model zoo 69 | state_dict = model_zoo.load_url(pth_url) 70 | 71 | layer_names = [layer_name for layer_name in state_dict] 72 | 73 | 74 | counter = 0 75 | for p in self.parameters(): 76 | if counter < 26: # conv1_1 to pool5 77 | p.data = state_dict[ layer_names[counter] ] 78 | elif counter == 26: # fc6 weight 79 | p.data = state_dict[ layer_names[counter] ].view(4096, 512, 7, 7) 80 | elif counter == 27: # fc6 bias 81 | p.data = state_dict[ layer_names[counter] ] 82 | elif counter == 28: # fc7 weight 83 | p.data = state_dict[ layer_names[counter] ].view(4096, 4096, 1, 1) 84 | elif counter == 29: # fc7 bias 85 | p.data = state_dict[ layer_names[counter] ] 86 | 87 | 88 | counter += 1 89 | 90 | def forward(self, x, return_features=False): 91 | n,c,h,w = x.size() 92 | # VGG16 PART 93 | conv1_1 = self.relu( self.conv1_1(x) ) 94 | conv1_2 = self.relu( self.conv1_2(conv1_1) ) 95 | pool1 = self.pool(conv1_2) 96 | 97 | conv2_1 = self.relu( self.conv2_1(pool1) ) 98 | conv2_2 = self.relu( self.conv2_2(conv2_1) ) 99 | pool2 = self.pool(conv2_2) 100 | # pool2 = self.eprop(pool2) 101 | conv3_1 = self.relu( self.conv3_1(pool2) ) 102 | conv3_2 = self.relu( self.conv3_2(conv3_1) ) 103 | conv3_3 = self.relu( self.conv3_3(conv3_2) ) 104 | pool3 = self.pool(conv3_3) 105 | 106 | conv4_1 = self.relu( self.conv4_1(pool3) ) 107 | conv4_2 = self.relu( self.conv4_2(conv4_1) ) 108 | conv4_3 = self.relu( self.conv4_3(conv4_2) ) 109 | pool4 = self.pool(conv4_3) 110 | 111 | conv5_1 = self.relu( self.conv5_1(pool4) ) 112 | conv5_2 = self.relu( self.conv5_2(conv5_1) ) 113 | conv5_3 = self.relu( self.conv5_3(conv5_2) ) 114 | pool5 = self.pool(conv5_3) 115 | 116 | fc6 = self.dropout_f6( self.relu( self.fc6(pool5) ) ) 117 | fc7 = self.dropout_f7( self.relu( self.fc7(fc6) ) ) 118 | 
119 | # SEMANTIC SEGMENTATION PART 120 | # first 121 | scores = self.scoring_layer( fc7 ) 122 | upscore2 = self.upscore2(scores) 123 | 124 | # second 125 | score_pool4 = self.score_pool4(pool4) 126 | score_pool4c = score_pool4[:, :, 5:5+upscore2.size(2), 127 | 5:5+upscore2.size(3)] 128 | upscore_pool4 = self.upscore_pool4(score_pool4c + upscore2) 129 | 130 | # third 131 | score_pool3 = self.score_pool3(pool3) 132 | score_pool3c = score_pool3[:, :, 9:9+upscore_pool4.size(2), 133 | 9:9+upscore_pool4.size(3)] 134 | 135 | output = self.upscore8(score_pool3c + upscore_pool4) 136 | if return_features: 137 | return output[:, :, 31: (31 + h), 31: (31 + w)].contiguous(), [score_pool4c, score_pool3c] 138 | else: 139 | return output[:, :, 31: (31 + h), 31: (31 + w)].contiguous() 140 | 141 | # =========================================================== 142 | # helpers 143 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 144 | """Make a 2D bilinear kernel suitable for upsampling""" 145 | factor = (kernel_size + 1) // 2 146 | if kernel_size % 2 == 1: 147 | center = factor - 1 148 | else: 149 | center = factor - 0.5 150 | og = np.ogrid[:kernel_size, :kernel_size] 151 | filt = (1 - abs(og[0] - center) / factor) * \ 152 | (1 - abs(og[1] - center) / factor) 153 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 154 | dtype=np.float64) 155 | weight[range(in_channels), range(out_channels), :, :] = filt 156 | return torch.from_numpy(weight).float() 157 | 158 | 159 | def conv3x3(in_planes, out_planes, stride=1, padding=1): 160 | "3x3 convolution with padding" 161 | return nn.Conv2d(in_planes, out_planes, kernel_size=(3,3), stride=(stride,stride), 162 | padding=(padding,padding)) 163 | 164 | def conv1x1(in_planes, out_planes, stride=1): 165 | "1x1 convolution with padding" 166 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 167 | padding=0) -------------------------------------------------------------------------------- /src/modules/lcfcn/lcfcn_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import skimage 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from skimage.segmentation import find_boundaries, watershed 6 | from scipy import ndimage 7 | from skimage import morphology as morph 8 | 9 | 10 | 11 | def compute_loss(points, probs, roi_mask=None): 12 | """ 13 | images: n x c x h x w 14 | probs: h x w (0 or 1) 15 | """ 16 | points = points.squeeze() 17 | probs = probs.squeeze() 18 | 19 | assert(points.max() <= 1) 20 | 21 | tgt_list = get_tgt_list(points, probs, roi_mask=roi_mask) 22 | 23 | # image level 24 | # pt_flat = points.view(-1) 25 | pr_flat = probs.view(-1) 26 | 27 | # compute loss 28 | loss = 0. 
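    # Each entry of tgt_list (built by get_tgt_list below) is a set of pixel indices with a
    # target label and a weight ('scale'): the image-level argmax/argmin pixels, the annotated
    # point locations, the watershed split boundaries between points, and false-positive blobs
    # that contain no points. The loop accumulates the scale-weighted binary cross-entropy of
    # the predicted probabilities over each set.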
29 | for tgt_dict in tgt_list: 30 | pr_subset = pr_flat[tgt_dict['ind_list']] 31 | # pr_subset = pr_subset.cpu() 32 | loss += tgt_dict['scale'] * F.binary_cross_entropy(pr_subset, 33 | torch.ones(pr_subset.shape, device=pr_subset.device) * tgt_dict['label'], 34 | reduction='mean') 35 | 36 | return loss 37 | 38 | @torch.no_grad() 39 | def get_tgt_list(points, probs, roi_mask=None): 40 | tgt_list = [] 41 | 42 | # image level 43 | pt_flat = points.view(-1) 44 | pr_flat = probs.view(-1) 45 | 46 | u_list = points.unique() 47 | if 0 in u_list: 48 | ind_bg = pr_flat.argmin() 49 | tgt_list += [{'scale': 1, 'ind_list':[ind_bg], 'label':0}] 50 | 51 | if 1 in u_list: 52 | ind_fg = pr_flat.argmax() 53 | tgt_list += [{'scale': 1, 'ind_list':[ind_fg], 'label':1}] 54 | 55 | # point level 56 | if 1 in u_list: 57 | ind_fg = torch.where(pt_flat==1)[0] 58 | tgt_list += [{'scale': len(ind_fg), 'ind_list':ind_fg, 'label':1}] 59 | 60 | # get blobs 61 | probs_numpy = probs.detach().cpu().numpy() 62 | blobs = get_blobs(probs_numpy, roi_mask=None) 63 | 64 | # get foreground and background blobs 65 | points = points.cpu().numpy() 66 | fg_uniques = np.unique(blobs * points) 67 | bg_uniques = [x for x in np.unique(blobs) if x not in fg_uniques] 68 | 69 | # split level 70 | # ----------- 71 | n_total = points.sum() 72 | 73 | if n_total > 1: 74 | # global split 75 | boundaries = watersplit(probs_numpy, points) 76 | ind_bg = np.where(boundaries.ravel())[0] 77 | 78 | tgt_list += [{'scale': (n_total-1), 'ind_list':ind_bg, 'label':0}] 79 | 80 | # local split 81 | for u in fg_uniques: 82 | if u == 0: 83 | continue 84 | 85 | ind = blobs==u 86 | 87 | b_points = points * ind 88 | n_points = b_points.sum() 89 | 90 | if n_points < 2: 91 | continue 92 | 93 | # local split 94 | boundaries = watersplit(probs_numpy, b_points)*ind 95 | ind_bg = np.where(boundaries.ravel())[0] 96 | 97 | tgt_list += [{'scale': (n_points - 1), 'ind_list':ind_bg, 'label':0}] 98 | 99 | # fp level 100 | for u in bg_uniques: 101 | if u == 0: 102 | continue 103 | 104 | b_mask = blobs==u 105 | if roi_mask is not None: 106 | b_mask = (roi_mask * b_mask) 107 | if b_mask.sum() == 0: 108 | pass 109 | # from haven import haven_utils as hu 110 | # hu.save_image('tmp.png', np.hstack([blobs==u, roi_mask])) 111 | # print() 112 | else: 113 | ind_bg = np.where(b_mask.ravel())[0] 114 | tgt_list += [{'scale': 1, 'ind_list':ind_bg, 'label':0}] 115 | 116 | return tgt_list 117 | 118 | 119 | def watersplit(_probs, _points): 120 | points = _points.copy() 121 | 122 | points[points != 0] = np.arange(1, points.sum()+1) 123 | points = points.astype(float) 124 | 125 | probs = ndimage.black_tophat(_probs.copy(), 7) 126 | seg = watershed(probs, points) 127 | 128 | return find_boundaries(seg) 129 | 130 | 131 | def get_blobs(probs, roi_mask=None): 132 | h, w = probs.shape 133 | 134 | pred_mask = (probs>0.5).astype('uint8') 135 | blobs = np.zeros((h, w), int) 136 | 137 | blobs = morph.label(pred_mask == 1) 138 | 139 | if roi_mask is not None: 140 | blobs = (blobs * roi_mask[None]).astype(int) 141 | 142 | return blobs 143 | 144 | 145 | def blobs2points(blobs): 146 | blobs = blobs.squeeze() 147 | points = np.zeros(blobs.shape).astype("uint8") 148 | rps = skimage.measure.regionprops(blobs) 149 | 150 | assert points.ndim == 2 151 | 152 | for r in rps: 153 | y, x = r.centroid 154 | 155 | points[int(y), int(x)] = 1 156 | 157 | return points 158 | 159 | def compute_game(pred_points, gt_points, L=1): 160 | n_rows = 2**L 161 | n_cols = 2**L 162 | 163 | pred_points = 
pred_points.astype(float).squeeze() 164 | gt_points = np.array(gt_points).astype(float).squeeze() 165 | h, w = pred_points.shape 166 | se = 0. 167 | 168 | hs, ws = h//n_rows, w//n_cols 169 | for i in range(n_rows): 170 | for j in range(n_cols): 171 | 172 | sr, er = hs*i, hs*(i+1) 173 | sc, ec = ws*j, ws*(j+1) 174 | 175 | pred_count = pred_points[sr:er, sc:ec] 176 | gt_count = gt_points[sr:er, sc:ec] 177 | 178 | se += float(abs(gt_count.sum() - pred_count.sum())) 179 | return se / (L+1) 180 | 181 | def save_tmp(fname, images, logits, radius, points): 182 | from haven import haven_utils as hu 183 | probs = F.softmax(logits, 1); 184 | mask = probs.argmax(dim=1).cpu().numpy().astype('uint8').squeeze() 185 | img_mask=hu.save_image('tmp2.png', 186 | hu.denormalize(images, mode='rgb'), 187 | mask=mask, return_image=True) 188 | hu.save_image(fname,np.array(img_mask)/255. , radius=radius, 189 | points=points) 190 | 191 | def get_random_points(mask, n_points=1, seed=1): 192 | from haven import haven_utils as hu 193 | y_list, x_list = np.where(mask) 194 | points = np.zeros(mask.squeeze().shape) 195 | with hu.random_seed(seed): 196 | for i in range(n_points): 197 | yi = np.random.choice(y_list) 198 | x_tmp = x_list[y_list == yi] 199 | xi = np.random.choice(x_tmp) 200 | points[yi, xi] = 1 201 | 202 | return points 203 | 204 | def get_points_from_mask(mask, bg_points=0): 205 | n_points = 0 206 | points = np.zeros(mask.shape) 207 | # print(np.unique(mask)) 208 | assert(len(np.setdiff1d(np.unique(mask),[0,1,2] ))==0) 209 | 210 | for c in np.unique(mask): 211 | if c == 0: 212 | continue 213 | blobs = morph.label((mask==c).squeeze()) 214 | points_class = blobs2points(blobs) 215 | 216 | ind = points_class!=0 217 | n_points += int(points_class[ind].sum()) 218 | points[ind] = c 219 | assert morph.label((mask).squeeze()).max() == n_points 220 | points[points==0] = 255 221 | if bg_points == -1: 222 | bg_points = n_points 223 | 224 | if bg_points: 225 | from haven import haven_utils as hu 226 | y_list, x_list = np.where(mask==0) 227 | with hu.random_seed(1): 228 | for i in range(bg_points): 229 | yi = np.random.choice(y_list) 230 | x_tmp = x_list[y_list == yi] 231 | xi = np.random.choice(x_tmp) 232 | points[yi, xi] = 0 233 | 234 | return points 235 | -------------------------------------------------------------------------------- /src/models/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from scipy import spatial 4 | import numpy as np 5 | import torch 6 | 7 | from collections import defaultdict 8 | from src.modules.lcfcn import lcfcn_loss 9 | from scipy import spatial 10 | import numpy as np 11 | import torch 12 | 13 | from collections import defaultdict 14 | 15 | from scipy import spatial 16 | import numpy as np 17 | import torch 18 | 19 | 20 | class SegMeter: 21 | def __init__(self, split): 22 | self.cf = None 23 | self.n_samples = 0 24 | self.split = split 25 | self.ae = 0 26 | self.game = 0 27 | 28 | def val_on_batch(self, model, batch): 29 | masks = batch["masks"].squeeze() 30 | self.n_samples += batch['images'].shape[0] 31 | pred_mask = model.predict_on_batch(batch).squeeze() 32 | 33 | # counts 34 | blobs = lcfcn_loss.get_blobs(pred_mask) 35 | points = lcfcn_loss.blobs2points(blobs) 36 | pred_counts = float(points.sum()) 37 | self.ae += np.abs(float((batch['points']==1).sum()) - pred_counts) 38 | 39 | gt_points = batch['points'].squeeze().clone() 40 | gt_points[gt_points!=1] = 0 41 | self.game += 
lcfcn_loss.compute_game(pred_points=points.squeeze(),
42 |                                              gt_points=gt_points, L=3)
43 |         #
44 |         # print(masks.sum())
45 |         ind = masks != 255
46 |         masks = masks[ind]
47 |         pred_mask = pred_mask[ind]
48 |
49 |
50 |
51 |         labels = np.arange(model.n_classes)
52 |         cf = confusion_multi_class(torch.as_tensor(pred_mask).float().cuda(), masks.cuda().float(),
53 |                                    labels=labels)
54 |
55 |
56 |         if self.cf is None:
57 |             self.cf = cf
58 |         else:
59 |             self.cf += cf
60 |
61 |     def get_avg_score(self):
62 |         # return -1
63 |         Inter = np.diag(self.cf)
64 |         G = self.cf.sum(axis=1)
65 |         P = self.cf.sum(axis=0)
66 |         union = G + P - Inter
67 |
68 |         nz = union != 0
69 |         iou = Inter / np.maximum(union, 1)
70 |         mIoU = np.mean(iou[nz])
71 |         iou[~nz] = np.nan
72 |         val_dict = {'%s_score' % self.split: mIoU}
73 |         for c in range(self.cf.shape[1]):
74 |             val_dict['%s_class%d' % (self.split, c)] = iou[c]
75 |         val_dict['%s_mae' % (self.split)] = self.ae / self.n_samples
76 |         val_dict['%s_game' % (self.split)] = self.game / self.n_samples
77 |         return val_dict
78 |
79 |
80 |
81 | def confusion_multi_class(prediction, truth, labels):
82 |     """
83 |     cf = confusion_matrix(y_true=prediction.cpu().numpy().ravel(),
84 |                           y_pred=truth.cpu().numpy().ravel(),
85 |                           labels=labels)
86 |     """
87 |     nclasses = labels.max() + 1
88 |     cf2 = torch.zeros(nclasses, nclasses, dtype=torch.float, device=prediction.device)
89 |     prediction = prediction.view(-1).long()
90 |     truth = truth.view(-1)
91 |     to_one_hot = torch.eye(int(nclasses), dtype=cf2.dtype, device=prediction.device)
92 |     for c in range(nclasses):
93 |         true_mask = (truth == c)
94 |         pred_one_hot = to_one_hot[prediction[true_mask]].sum(0)
95 |         cf2[:, c] = pred_one_hot
96 |
97 |     return cf2.cpu().numpy()
98 |
99 |
100 | def confusion_binary_class(prediction, truth):
101 |     confusion_vector = prediction / truth
102 |
103 |     tp = torch.sum(confusion_vector == 1).item()
104 |     fp = torch.sum(confusion_vector == float('inf')).item()
105 |     tn = torch.sum(torch.isnan(confusion_vector)).item()
106 |     fn = torch.sum(confusion_vector == 0).item()
107 |     cm = np.array([[tn,fp],
108 |                    [fn,tp]])
109 |     return cm
110 |
111 |
112 |
113 | class SegMeterBinary:
114 |     def __init__(self, split):
115 |         self.cf = None
116 |         self.struct_list = []
117 |         self.split = split
118 |
119 |     def val_on_batch(self, model, batch):
120 |         masks_org = batch["masks"]
121 |
122 |         pred_mask_org = model.predict_on_batch(batch)
123 |         ind = masks_org != 255
124 |         masks = masks_org[ind]
125 |         pred_mask = pred_mask_org[ind]
126 |         self.n_classes = model.n_classes
127 |         if model.n_classes == 1:
128 |             cf = confusion_binary_class(torch.as_tensor(pred_mask).float().cuda(), masks.cuda().float())
129 |         else:
130 |             labels = np.arange(model.n_classes)
131 |             cf = confusion_multi_class(torch.as_tensor(pred_mask).float().cuda(), masks.cuda().float(),
132 |                                        labels=labels)
133 |
134 |         if self.cf is None:
135 |             self.cf = cf
136 |         else:
137 |             self.cf += cf
138 |
139 |         # structure
140 |         struct_score = float(struct_metric.compute_struct_metric(pred_mask_org, masks_org))
141 |         self.struct_list += [struct_score]
142 |
143 |     def get_avg_score(self):
144 |         TP = np.diag(self.cf)
145 |         TP_FP = self.cf.sum(axis=1)
146 |         TP_FN = self.cf.sum(axis=0)
147 |         TN = TP[::-1]
148 |
149 |
150 |         FP = TP_FP - TP
151 |         FN = TP_FN - TP
152 |
153 |         iou = TP / (TP + FP + FN)
154 |         dice = 2*TP / (FP + FN + 2*TP)
155 |
156 |         iou[np.isnan(iou)] = -1
157 |         dice[np.isnan(dice)] = -1
158 |
159 |         mDice = np.mean(dice)
160 |         mIoU = np.mean(iou)
161 |
162 |         prec = TP / (TP + FP)
163 |         recall = TP / (TP + FN)
164 |         spec = TN/(TN+FP)
165 |         fscore = (( 2.0 * prec * recall ) / (prec + recall))
166 |
167 |         val_dict = {}
168 |         if self.n_classes == 1:
169 |             val_dict['%s_dice' % self.split] = dice[0]
170 |             val_dict['%s_iou' % self.split] = iou[0]
171 |
172 |             val_dict['%s_prec' % self.split] = prec[0]
173 |             val_dict['%s_recall' % self.split] = recall[0]
174 |             val_dict['%s_spec' % self.split] = spec[0]
175 |             val_dict['%s_fscore' % self.split] = fscore[0]
176 |
177 |             val_dict['%s_score' % self.split] = dice[0]
178 |         val_dict['%s_struct' % self.split] = np.mean(self.struct_list)
179 |         return val_dict
180 |
181 | # def confusion_multi_class(prediction, truth, labels):
182 | #     """
183 | #     cf = confusion_matrix(y_true=prediction.cpu().numpy().ravel(),
184 | #                           y_pred=truth.cpu().numpy().ravel(),
185 | #                           labels=labels)
186 | #     """
187 | #     nclasses = labels.max() + 1
188 | #     cf2 = torch.zeros(nclasses, nclasses, dtype=torch.float,
189 | #                       device=prediction.device)
190 | #     prediction = prediction.view(-1).long()
191 | #     truth = truth.view(-1)
192 | #     to_one_hot = torch.eye(int(nclasses), dtype=cf2.dtype,
193 | #                            device=prediction.device)
194 | #     for c in range(nclasses):
195 | #         true_mask = (truth == c)
196 | #         pred_one_hot = to_one_hot[prediction[true_mask]].sum(0)
197 | #         cf2[:, c] = pred_one_hot
198 |
199 | #     return cf2.cpu().numpy()
200 |
201 |
202 |
203 | def confusion_binary_class(pred_mask, gt_mask):
204 |     intersect = pred_mask.bool() & gt_mask.bool()
205 |
206 |     fp_tp = (pred_mask ==1).sum().item()
207 |     fn_tp = gt_mask.sum().item()
208 |     tn_fn = (pred_mask ==0).sum().item()
209 |
210 |     tp = (intersect == 1).sum().item()
211 |     fp = fp_tp - tp
212 |     fn = fn_tp - tp
213 |     tn = tn_fn - fn
214 |
215 |     cm = np.array([[tp, fp],
216 |                    [fn, tn]])
217 |     return cm
--------------------------------------------------------------------------------
/scripts/SEAM/network/resnet38_aff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.sparse as sparse
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from . import resnet38d
7 | from ..tool import pyutils
8 |
9 |
10 | def crf_inference(img, probs, t=10, scale_factor=1, labels=2):
11 |     import pydensecrf.densecrf as dcrf
12 |     from pydensecrf.utils import unary_from_softmax
13 |
14 |     h, w = img.shape[:2]
15 |     n_labels = labels
16 |
17 |     d = dcrf.DenseCRF2D(w, h, n_labels)
18 |
19 |     unary = unary_from_softmax(probs)
20 |     unary = np.ascontiguousarray(unary)
21 |
22 |     d.setUnaryEnergy(unary)
23 |     d.addPairwiseGaussian(sxy=3/scale_factor, compat=3)
24 |     d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10)
25 |     Q = d.inference(t)
26 |
27 |     return np.array(Q).reshape((n_labels, h, w))
28 |
29 | class Net(resnet38d.Net):
30 |     def __init__(self, n_classes, exp_dict):
31 |         super(Net, self).__init__()
32 |
33 |         self.f8_3 = torch.nn.Conv2d(512, 64, 1, bias=False)
34 |         self.f8_4 = torch.nn.Conv2d(1024, 128, 1, bias=False)
35 |         self.f8_5 = torch.nn.Conv2d(4096, 256, 1, bias=False)
36 |
37 |         self.f9 = torch.nn.Conv2d(448, 448, 1, bias=False)
38 |
39 |         torch.nn.init.kaiming_normal_(self.f8_3.weight)
40 |         torch.nn.init.kaiming_normal_(self.f8_4.weight)
41 |         torch.nn.init.kaiming_normal_(self.f8_5.weight)
42 |         torch.nn.init.xavier_uniform_(self.f9.weight, gain=4)
43 |
44 |         self.not_training = [self.conv1a, self.b2, self.b2_1, self.b2_2]
45 |
46 |         self.from_scratch_layers = [self.f8_3, self.f8_4, self.f8_5, self.f9]
47 |
48 |         self.predefined_featuresize = int(448//8)
49 |         self.radius = 5
50 |         self.ind_from, self.ind_to = pyutils.get_indices_of_pairs(radius=self.radius, size=(self.predefined_featuresize, self.predefined_featuresize))
51 |         self.ind_from = torch.from_numpy(self.ind_from); self.ind_to = torch.from_numpy(self.ind_to)
52 |
53 |         self.dropout7 = torch.nn.Dropout2d(0.5)
54 |
55 |         self.fc8 = nn.Conv2d(4096, n_classes, 1, bias=False)
56 |         self.beta = exp_dict['model'].get('beta', 8)
57 |         self.logt = exp_dict['model'].get('logt', 4)
58 |         return
59 |
60 |     def apply_affinity(self, img, logits, crf=False):
61 |         h_org, w_org = img.shape[2:]
62 |         padded_size = (int(np.ceil(img.shape[2]/8)*8), int(np.ceil(img.shape[3]/8)*8))
63 |         p2d = (0, padded_size[1] - img.shape[3], 0, padded_size[0] - img.shape[2])
64 |         img = F.pad(img, p2d)
65 |         beta = self.beta
66 |         logt = self.logt
67 |         aff_mat = torch.pow(self.forward(img.cuda(), True), beta)
68 |         trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
69 |
70 |         for _ in range(logt):
71 |             trans_mat = torch.matmul(trans_mat, trans_mat)
72 |
73 |         n_classes = logits.shape[1]
74 |         logits = F.pad(logits, p2d)
75 |         _,_, h, w = logits.shape
76 |         # indices_from_to = get_indices_in_radius(h, w, radius=5)
77 |         # labels = get_affinity_labels(logits.argmax(dim=1).cpu().numpy(), indices_from_to[:,0], indices_from_to[:,1])
78 |         cam = F.avg_pool2d(logits, 8, 8)
79 |         cam_vec = cam.view(n_classes, -1)
80 |         cam_vec = torch.matmul(cam_vec.cuda(), trans_mat)
81 |         logits_new = cam_vec.view(1, n_classes, h//8, w//8)
82 |
83 |         logits_new = torch.nn.functional.interpolate(logits_new, (h, w), mode='bilinear')
84 |         if crf:
85 |             img_8 = img[0].cpu().numpy().transpose((1,2,0))
86 |             img_8 = np.ascontiguousarray(img_8)
87 |             mean = (0.485, 0.456, 0.406)
88 |             std = (0.229, 0.224, 0.225)
89 |             img_8[:,:,0] = (img_8[:,:,0]*std[0] + mean[0])*255
90 |             img_8[:,:,1] = (img_8[:,:,1]*std[1] + mean[1])*255
91 |             img_8[:,:,2] = (img_8[:,:,2]*std[2] + mean[2])*255
92 |             img_8[img_8 > 255] = 255
93 |             img_8[img_8 < 0] = 0
94 |             img_8 = img_8.astype(np.uint8)
95 |             cam_rw = logits_new[0].detach().cpu().numpy()
96 |             cam_rw = crf_inference(img_8, cam_rw, t=1)
97 |             cam_rw = torch.from_numpy(cam_rw).view(1, 2, img.shape[2], img.shape[3]).cuda()
98 |             logits_new = cam_rw
99 |
100 |         logits_new = torch.nn.functional.interpolate(logits_new, (h_org, w_org), mode='bilinear')
101 |         return logits_new
102 |
103 |     def output_logits(self, x, to_dense=False):
104 |
105 |         d = super().forward_as_dict(x)
106 |
107 |         cam = self.fc8(self.dropout7(d['conv6']))
108 |         h, w = x.shape[-2:]
109 |         return torch.nn.functional.interpolate(cam, (h, w), mode='bilinear', align_corners=True)
110 |
111 |
112 |     def forward(self, x, to_dense=False):
113 |
114 |         d = super().forward_as_dict(x)
115 |
116 |         f8_3 = F.elu(self.f8_3(d['conv4']))
117 |         f8_4 = F.elu(self.f8_4(d['conv5']))
118 |         f8_5 = F.elu(self.f8_5(d['conv6']))
119 |         x = F.elu(self.f9(torch.cat([f8_3, f8_4, f8_5], dim=1)))
120 |
121 |         if x.size(2) == self.predefined_featuresize and x.size(3) == self.predefined_featuresize:
122 |             ind_from = self.ind_from
123 |             ind_to = self.ind_to
124 |         else:
125 |             min_edge = min(x.size(2), x.size(3))
126 |             radius = (min_edge-1)//2 if min_edge < self.radius*2+1 else self.radius
127 |             ind_from, ind_to = pyutils.get_indices_of_pairs(radius, (x.size(2), x.size(3)))
128 |             ind_from = torch.from_numpy(ind_from); ind_to = torch.from_numpy(ind_to)
129 |
130 |         x = x.view(x.size(0), x.size(1), -1).contiguous()
131 |         ind_from = ind_from.contiguous()
132 |         ind_to = ind_to.contiguous()
133 |
134 |         ff = torch.index_select(x, dim=2, index=ind_from.cuda(non_blocking=True))
135 |         ft = torch.index_select(x, dim=2, index=ind_to.cuda(non_blocking=True))
136 |
137 |         ff = torch.unsqueeze(ff, dim=2)
138 |         ft = ft.view(ft.size(0), ft.size(1), -1, ff.size(3))
139 |
140 |         aff = torch.exp(-torch.mean(torch.abs(ft-ff), dim=1))
141 |
142 |         if to_dense:
143 |             aff = aff.view(-1).cpu()
144 |
145 |             ind_from_exp = torch.unsqueeze(ind_from, dim=0).expand(ft.size(2), -1).contiguous().view(-1)
146 |             indices = torch.stack([ind_from_exp, ind_to])
147 |             indices_tp = torch.stack([ind_to, ind_from_exp])
148 |
149 |             area = x.size(2)
150 |             indices_id = torch.stack([torch.arange(0, area).long(), torch.arange(0, area).long()])
151 |
152 |             aff_mat = sparse.FloatTensor(torch.cat([indices, indices_id, indices_tp], dim=1),
153 |                                          torch.cat([aff, torch.ones([area]), aff])).to_dense().cuda()
154 |
155 |             return aff_mat
156 |
157 |         else:
158 |             return aff
159 |
160 |
161 |     def get_parameter_groups(self):
162 |         groups = ([], [], [], [])
163 |
164 |         for m in self.modules():
165 |
166 |             if (isinstance(m, nn.Conv2d) or isinstance(m, nn.modules.normalization.GroupNorm)):
167 |
168 |                 if m.weight.requires_grad:
169 |                     if m in self.from_scratch_layers:
170 |                         groups[2].append(m.weight)
171 |                     else:
172 |                         groups[0].append(m.weight)
173 |
174 |                 if m.bias is not None and m.bias.requires_grad:
175 |
176 |                     if m in self.from_scratch_layers:
177 |                         groups[3].append(m.bias)
178 |                     else:
179 |                         groups[1].append(m.bias)
180 |
181 |         return groups
182 |
183 |
184 |
185 |
--------------------------------------------------------------------------------
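A minimal standalone sketch, assuming only NumPy, of how the segmentation meters above turn an accumulated confusion matrix into per-class scores. The helper name scores_from_confusion and the toy numbers are illustrative, not part of the repo; the arithmetic mirrors SegMeter.get_avg_score and SegMeterBinary.get_avg_score, where cf[p, c] counts pixels predicted as class p whose ground-truth class is c.

import numpy as np

def scores_from_confusion(cf):
    # cf[p, c]: pixels predicted as class p with ground-truth class c
    tp = np.diag(cf).astype(float)            # correctly labelled pixels per class
    fp = cf.sum(axis=1) - tp                  # predicted as the class but wrong
    fn = cf.sum(axis=0) - tp                  # belong to the class but missed
    denom_iou = np.maximum(tp + fp + fn, 1)   # guard against empty classes (repo maps NaN to -1 instead)
    denom_dice = np.maximum(2 * tp + fp + fn, 1)
    return tp / denom_iou, 2 * tp / denom_dice

# toy 2-class example: 80 background hits, 10 foreground hits, 5 FP, 5 FN pixels
cf = np.array([[80., 5.],
               [5., 10.]])
iou, dice = scores_from_confusion(cf)
print(iou)   # approx [0.889, 0.5]
print(dice)  # approx [0.941, 0.667]

With the binary layout [[tp, fp], [fn, tn]] produced by the later confusion_binary_class, index 0 of these vectors is the foreground score that SegMeterBinary reports as its dice, iou, and overall score.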
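crf_inference above wraps pydensecrf's DenseCRF2D with a Gaussian and a bilateral pairwise term. A hedged usage sketch follows; the array shapes are assumptions based on unary_from_softmax, which expects class probabilities of shape (n_labels, H, W), and the random image and probabilities are placeholders.

# Assumes crf_inference from scripts/SEAM/network/resnet38_aff.py is in scope
# (imported or copied) and that pydensecrf is installed.
import numpy as np

H, W = 64, 64
img = np.random.randint(0, 255, size=(H, W, 3), dtype=np.uint8)   # placeholder RGB image
fg = np.random.rand(H, W)
probs = np.stack([1.0 - fg, fg]).astype(np.float32)               # (2, H, W), sums to 1 per pixel

refined = crf_inference(img, probs, t=10, labels=2)               # (2, H, W) refined probabilities
pseudo_mask = refined.argmax(0)                                    # hard labels after CRF smoothing

In apply_affinity the same function is called with t=1 on the upsampled, affinity-propagated scores, so the CRF there acts as a light edge-aware cleanup rather than a full inference run.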
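The core of Net.apply_affinity is a random walk over pixel affinities: the dense affinity matrix from forward(..., to_dense=True) is raised elementwise to the power beta, column-normalized into a transition matrix, squared logt times (equivalent to 2**logt walk steps), and multiplied with the 8x-downsampled class scores so that scores diffuse between mutually similar positions. Below is a minimal sketch on a made-up 4-pixel affinity matrix; only the operations mirror the code above, the values are illustrative.

import torch

aff = torch.tensor([[1.0, 0.9, 0.1, 0.1],
                    [0.9, 1.0, 0.1, 0.1],
                    [0.1, 0.1, 1.0, 0.9],
                    [0.1, 0.1, 0.9, 1.0]])      # toy pairwise affinities (pixels 0-1 and 2-3 are similar)

beta, logt = 8, 4
trans = aff.pow(beta)                            # sharpen affinities
trans = trans / trans.sum(dim=0, keepdim=True)   # column-normalized transition matrix

for _ in range(logt):                            # logt squarings == 2**logt random-walk steps
    trans = trans @ trans

cam = torch.tensor([[1.0, 0.0, 0.0, 0.0],        # class scores per pixel (2 classes x 4 pixels)
                    [0.0, 0.0, 0.0, 1.0]])
cam_rw = cam @ trans                             # propagated scores, as in torch.matmul(cam_vec, trans_mat)
print(cam_rw)

In this toy case the class-0 score at pixel 0 spreads to pixel 1 and the class-1 score at pixel 3 spreads to pixel 2, because those pairs have high affinity, while low-affinity pairs exchange almost nothing; the repo then reshapes the propagated vector back to a map and upsamples it bilinearly.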