├── README.md ├── datasets.py ├── metadata ├── colormap_coarse.csv ├── train_list.txt └── val_list.txt ├── models.py ├── opts.py ├── segmentTool ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── models.cpython-36.pyc │ ├── resnet.cpython-36.pyc │ └── resnext.cpython-36.pyc ├── lib │ ├── nn │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ ├── batchnorm.cpython-36.pyc │ │ │ │ ├── comm.cpython-36.pyc │ │ │ │ └── replicate.cpython-36.pyc │ │ │ ├── batchnorm.py │ │ │ ├── comm.py │ │ │ ├── replicate.py │ │ │ ├── tests │ │ │ │ ├── test_numeric_batchnorm.py │ │ │ │ └── test_sync_batchnorm.py │ │ │ └── unittest.py │ │ └── parallel │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── data_parallel.cpython-36.pyc │ │ │ └── data_parallel.py │ └── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── th.cpython-36.pyc │ │ ├── data │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── dataloader.cpython-36.pyc │ │ │ ├── dataset.cpython-36.pyc │ │ │ └── sampler.cpython-36.pyc │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── distributed.py │ │ └── sampler.py │ │ └── th.py ├── models.py ├── resnet.py └── resnext.py ├── test.py ├── test_carla.py ├── test_seq.py ├── tools ├── get_trainning_data_from_house3d.py ├── process_data_for_VPN.py ├── process_transfer_driving_data.py └── process_transfer_indoor_data.py ├── train.py ├── train_carla.py ├── train_transfer.py ├── transform.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Cross-view Semantic Segmentation for Sensing Surroundings 2 | 3 | We release the code of the View Parsing Networks, the main model for Cross-view Semantic Segmentation task. 4 | 5 | 6 | **[Cross-View Semantic Segmentation for Sensing Surroundings 7 | ](https://arxiv.org/pdf/1906.03560.pdf)** 8 |
9 | [Bowen Pan](http://people.csail.mit.edu/bpan/), 10 | [Jiankai Sun](), 11 | [Ho Yin Tiga Leung](), 12 | [Alex Andonian](https://www.alexandonian.com/), and 13 | [Bolei Zhou](http://bzhou.ie.cuhk.edu.hk/) 14 |
15 | IEEE Robotics and Automation Letters 16 |
17 | In IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS) 2020 18 |
19 | [[Paper]](https://arxiv.org/pdf/1906.03560.pdf) 20 | [[Project Page]](https://decisionforce.github.io/VPN/) 21 | 22 | ``` 23 | @ARTICLE{pan2019crossview, 24 | author={B. {Pan} and J. {Sun} and H. Y. T. {Leung} and A. {Andonian} and B. {Zhou}}, 25 | journal={IEEE Robotics and Automation Letters}, 26 | title={Cross-View Semantic Segmentation for Sensing Surroundings}, 27 | year={2020}, 28 | volume={5}, 29 | number={3}, 30 | pages={4867-4873}, 31 | } 32 | ``` 33 | 34 | 35 | ### Requirements 36 | - Install the [House3D](https://github.com/facebookresearch/House3D) simulator or the [Gibson](http://gibsonenv.stanford.edu) simulator. 37 | - Software: Ubuntu 16.04.3 LTS, CUDA>=8.0, Python>=3.5, PyTorch>=0.4.0 38 | 39 | ## Train and test VPN 40 | 41 | ### Data processing (using House3D as an example) 42 | - Use [get_training_data_from_house3d.py](https://github.com/pbw-Berwin/View-Parsing-Network/blob/master/tools/get_trainning_data_from_house3d.py) to extract data from the House3D environment. 43 | - Use [process_data_for_VPN.py](https://github.com/pbw-Berwin/View-Parsing-Network/blob/master/tools/process_data_for_VPN.py) to split the training and validation sets and to visualize the data. 44 | 45 | ### Training Command 46 | ``` 47 | # Training in indoor-room scenarios, using RGB input modality, with 8 input views. 48 | python -u train.py --fc-dim 256 --use-depth false --use-mask false --transform-type fc --input-resolution 400 --label-res 25 --store-name [STORE_NAME] --n-views 8 --batch-size 48 -j 10 --data_root [PATH_TO_DATASET_ROOT] --train-list [PATH_TO_TRAIN_LIST] --eval-list [PATH_TO_EVAL_LIST] 49 | 50 | # Training in driving-traffic scenarios, using RGB input modality, with 6 input views. 51 | python -u train_carla.py --fc-dim 256 --use-depth false --use-mask false --transform-type fc --input-resolution 400 --label-res 25 --store-name [STORE_NAME] --n-views 6 --batch-size 48 -j 10 --data_root [PATH_TO_DATASET_ROOT] --train-list [PATH_TO_TRAIN_LIST] --eval-list [PATH_TO_EVAL_LIST] 52 | ``` 53 | 54 | ### Testing Command 55 | ``` 56 | # Testing in indoor-room scenarios, using RGB input modality, with 8 input views. 57 | python -u test.py --fc-dim 256 --use-depth false --use-mask false --transform-type fc --input-resolution 400 --label-res 25 --store-name [STORE_NAME] --n-views 8 --batch-size 4 --test-views 8 --data_root [PATH_TO_DATASET_ROOT] --eval-list [PATH_TO_EVAL_LIST] --num-class [NUM_CLASS] -j 10 --weights [PATH_TO_PRETRAIN_MODEL] 58 | 59 | # Testing in driving-traffic scenarios, using RGB input modality, with 6 input views. 60 | python -u test_carla.py --fc-dim 256 --use-depth false --use-mask false --transform-type fc --input-resolution 400 --label-res 25 --store-name [STORE_NAME] --n-views 6 --batch-size 4 --test-views 6 --data_root [PATH_TO_DATASET_ROOT] --eval-list [PATH_TO_EVAL_LIST] --num-class [NUM_CLASS] -j 10 --weights [PATH_TO_PRETRAIN_MODEL] 61 | ``` 62 | 63 | ## Transfer learning for sim-to-real adaptation 64 | 65 | ### Data processing (using indoor-room scenarios as an example) 66 | - Use [process_transfer_indoor_data.py](https://github.com/pbw-Berwin/View-Parsing-Network/blob/master/tools/process_transfer_indoor_data.py) to split the source-domain and target-domain sets. 67 | 68 | ### Training Command 69 | ``` 70 | # Training in indoor-room scenarios, using RGB input modality, with 8 input views. 
71 | python -u train_transfer.py --task-id [TASK_NAME] --num-class [NUM_CLASS] --learning-rate-D 3e-6 --iter-size-G 1 --iter-size-D 1 --snapshot-dir ./snapshot --batch-size 20 --tensorboard true --n-views 6 --train_source_list [PATH_TO_TRAIN_LIST] --train_target_list [PATH_TO_EVAL_LIST] --VPN-weights [PATH_TO_PRETRAINED_WEIGHT] --scenarios indoor 72 | ``` 73 | -------------------------------------------------------------------------------- /metadata/colormap_coarse.csv: -------------------------------------------------------------------------------- 1 | name,r,g,b 2 | ottoman,0,0,0 3 | storage_bench,255,255,0 4 | mirror,28,230,255 5 | shower,255,52,255 6 | kitchen_appliance,255,74,70 7 | sofa,0,137,65 8 | outdoor_lamp,0,111,166 9 | tripod,163,0,89 10 | toilet,122,73,0 11 | cart,0,0,166 12 | Ground,99,255,172 13 | decoration,183,151,98 14 | arch,0,77,67 15 | unknown,143,176,255 16 | pet,153,125,135 17 | computer,90,0,7 18 | stairs,128,150,147 19 | bed,254,255,230 20 | heater,27,68,0 21 | plant,79,198,1 22 | mailbox,59,93,255 23 | fence,74,59,83 24 | shelving,255,47,128 25 | safe,97,97,90 26 | person,186,9,0 27 | desk,107,121,0 28 | television,0,194,160 29 | stand,255,170,146 30 | table_and_chair,255,144,201 31 | whiteboard,185,3,170 32 | ceiling,209,97,0 33 | coffin,221,239,255 34 | window,0,0,53 35 | drinkbar,123,79,75 36 | picture_frame,161,194,153 37 | chair,48,0,24 38 | wood_board,10,166,216 39 | headstone,1,51,73 40 | OTHER,0,132,111 41 | wardrobe_cabinet,55,33,1 42 | floor,255,181,0 43 | fireplace,194,255,237 44 | cloth,160,121,191 45 | empty,204,7,68 46 | tv_stand,194,255,153 47 | garage_door,0,30,9 48 | workplace,0,72,156 49 | table,111,0,98 50 | hanging_kitchen_cabinet,12,189,102 51 | shoes,238,195,255 52 | trash_can,69,109,117 53 | kitchen_set,183,123,104 54 | ATM,122,135,161 55 | bench_chair,120,141,102 56 | hanger,136,85,120 57 | outdoor_seating,250,208,159 58 | bathtub,255,138,154 59 | kitchenware,209,87,160 60 | outdoor_cover,190,196,89 61 | dressing_table,69,102,72 62 | bathroom_stuff,0,134,237 63 | vehicle,136,111,76 64 | candle,52,54,45 65 | dresser,180,168,189 66 | Ceiling,0,166,170 67 | clock,69,44,44 68 | pool,99,99,117 69 | door,163,200,201 70 | fan,255,145,63 71 | magazines,147,138,129 72 | kitchen_cabinet,87,83,41 73 | partition,0,254,207 74 | recreation,176,91,111 75 | trinket,140,208,255 76 | Box,59,151,0 77 | wall,4,247,87 78 | roof,200,161,161 79 | books,30,110,0 80 | gym_equipment,121,0,215 81 | sink,167,117,0 82 | rug,99,103,169 83 | outdoor_spring,160,88,55 84 | indoor_lamp,107,0,44 85 | household_appliance,119,38,0 86 | pillow,215,144,255 87 | air_conditioner,155,151,0 88 | toy,84,158,121 89 | grill,255,246,159 90 | vase,32,22,37 91 | column,114,65,143 92 | switch,188,35,255 93 | shoes_cabinet,153,173,192 94 | curtain,58,36,101 95 | music,146,35,41 96 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : models.py 4 | # Author : Bowen Pan 5 | # Email : panbowen0607@gmail.com 6 | # Date : 09/18/2018 7 | # 8 | # Distributed under terms of the MIT license. 
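#
# Editor's note (added comment, not in the original file): models.py defines the View
# Parsing Network components described below. TransformModule learns one two-layer MLP
# per input view over the flattened feature grid and sums the transformed views;
# SumModule is the parameter-free baseline that simply sums the view features;
# VPNModel chains a shared encoder, the view transform, and a top-view decoder;
# FCDiscriminator provides the convolutional discriminator used for adversarial
# training (cf. the --learning-rate-D / --iter-size-D options of train_transfer.py).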
9 | 10 | import torch 11 | from torch import nn 12 | import utils 13 | from collections import OrderedDict 14 | import torch.nn.functional as F 15 | from segmentTool.models import * 16 | import numpy as np 17 | from itertools import combinations 18 | 19 | builder = ModelBuilder() 20 | 21 | class TransformModule(nn.Module): 22 | def __init__(self, dim=25, num_view=8): 23 | super(TransformModule, self).__init__() 24 | self.num_view = num_view 25 | self.dim = dim 26 | self.mat_list = nn.ModuleList() 27 | for i in range(self.num_view): 28 | fc_transform = nn.Sequential( 29 | nn.Linear(dim * dim, dim * dim), 30 | nn.ReLU(), 31 | nn.Linear(dim * dim, dim * dim), 32 | nn.ReLU() 33 | ) 34 | self.mat_list += [fc_transform] 35 | 36 | def forward(self, x): 37 | # shape x: B, V, C, H, W 38 | x = x.view(list(x.size()[:3]) + [self.dim * self.dim,]) 39 | view_comb = self.mat_list[0](x[:, 0]) 40 | for index in range(x.size(1))[1:]: 41 | view_comb += self.mat_list[index](x[:, index]) 42 | view_comb = view_comb.view(list(view_comb.size()[:2]) + [self.dim, self.dim]) 43 | return view_comb 44 | 45 | 46 | class SumModule(nn.Module): 47 | def __init__(self): 48 | super(SumModule, self).__init__() 49 | 50 | def forward(self, x): 51 | # shape x: B, V, C, H, W 52 | x = torch.sum(x, dim=1, keepdim=False) 53 | return x 54 | 55 | 56 | class VPNModel(nn.Module): 57 | def __init__(self, config): 58 | super(VPNModel, self).__init__() 59 | self.num_views = config.num_views 60 | self.output_size = config.output_size 61 | self.transform_type = config.transform_type 62 | print('Views number: ' + str(self.num_views)) 63 | print('Transform Type: ', self.transform_type) 64 | self.encoder = builder.build_encoder( 65 | arch=config.encoder, 66 | fc_dim=config.fc_dim, 67 | ) 68 | if self.transform_type == 'fc': 69 | self.transform_module = TransformModule(dim=self.output_size, num_view=self.num_views) 70 | elif self.transform_type == 'sum': 71 | self.transform_module = SumModule() 72 | self.decoder = builder.build_decoder( 73 | arch=config.decoder, 74 | fc_dim=config.fc_dim, 75 | num_class=config.num_class, 76 | use_softmax=False, 77 | ) 78 | 79 | def forward(self, x, return_feat=False): 80 | B, N, C, H, W = x.view([-1, self.num_views, int(x.size()[1] / self.num_views)] \ 81 | + list(x.size()[2:])).size() 82 | 83 | x = x.view(B*N, C, H, W) 84 | x = self.encoder(x)[0] 85 | x = x.view([B, N] + list(x.size()[1:])) 86 | x = self.transform_module(x) 87 | if return_feat: 88 | x, feat = self.decoder([x], return_feat=return_feat) 89 | else: 90 | x = self.decoder([x]) 91 | x = x.transpose(1,2).transpose(2,3).contiguous() 92 | if return_feat: 93 | feat = feat.transpose(1,2).transpose(2,3).contiguous() 94 | return x, feat 95 | return x 96 | 97 | 98 | 99 | class FCDiscriminator(nn.Module): 100 | def __init__(self, num_classes, ndf = 64): 101 | super(FCDiscriminator, self).__init__() 102 | 103 | self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 104 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 105 | self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 106 | self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 107 | self.classifier = nn.Conv2d(ndf*8, 1, kernel_size=4, stride=2, padding=1) 108 | 109 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 110 | #self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 111 | #self.sigmoid = nn.Sigmoid() 112 | 113 | 114 | def forward(self, x): 115 | x = self.conv1(x) 116 | x = 
self.leaky_relu(x) 117 | x = self.conv2(x) 118 | x = self.leaky_relu(x) 119 | x = self.conv3(x) 120 | x = self.leaky_relu(x) 121 | x = self.conv4(x) 122 | x = self.leaky_relu(x) 123 | x = self.classifier(x) 124 | #x = self.up_sample(x) 125 | #x = self.sigmoid(x) 126 | return x 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def str2bool(v): 4 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 5 | return True 6 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 7 | return False 8 | else: 9 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 10 | 11 | parser = argparse.ArgumentParser(description="PyTorch implementation of CMP") 12 | parser.add_argument('--data_root', type=str, default='/data/vision/oliva/scenedataset/syntheticscene/TopViewMaskDataset') 13 | parser.add_argument('--test-dir', type=str, default='') 14 | parser.add_argument('--train-list', type=str, default='./metadata/train_list.txt') 15 | parser.add_argument('--eval-list', type=str, default='./metadata/val_list.txt') 16 | parser.add_argument('--start-lr', type=float, default=2e-4) 17 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 18 | metavar='W', help='weight decay (default: 5e-4)') 19 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 20 | help='momentum') 21 | parser.add_argument('--transform-type', type=str, default='fc') 22 | parser.add_argument('--use-mask', type=str2bool, nargs='?', const=False) 23 | parser.add_argument('--use-depth', type=str2bool, nargs='?', const=False) 24 | parser.add_argument('--input-resolution', default=400, type=int, metavar='N') 25 | parser.add_argument('--label-resolution', default=25, type=int, metavar='N') 26 | parser.add_argument('--fc-dim', default=256, type=int, metavar='N') 27 | parser.add_argument('--segSize', default=256, type=int, metavar='N') 28 | parser.add_argument('--log-root', type=str, default='./log') 29 | parser.add_argument('--encoder', type=str, default='resnet18') 30 | parser.add_argument('--decoder', type=str, default='ppm_bilinear') 31 | parser.add_argument('--store-name', type=str, default='') 32 | parser.add_argument('--start_epoch', type=int, default=0) 33 | parser.add_argument('--epochs', type=int, default=40) 34 | parser.add_argument('--n-views', type=int, default=8) 35 | parser.add_argument('--lr_steps', default=[10], type=float, nargs="+", 36 | metavar='LRSteps', help='epochs to decay learning rate by 10') 37 | parser.add_argument('--print-freq', '-p', default=10, type=int, 38 | metavar='N', help='print frequency (default: 10)') 39 | parser.add_argument('--print-img-freq', default=100, type=int, 40 | metavar='N', help='print frequency (default: 100)') 41 | parser.add_argument('--ckpt-freq', default=1, type=int, 42 | metavar='N', help='save frequency (default: 2)') 43 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 44 | help='evaluate model on validation set') 45 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 46 | help='path to latest checkpoint (default: none)') 47 | parser.add_argument('--tflogdir', default='', type=str, metavar="PATH") 48 | parser.add_argument('--logname', type=str, default="") 49 | parser.add_argument('-b', '--batch-size', default=104, type=int, 50 | metavar='N', help='mini-batch size (default: 256)') 51 | parser.add_argument('--scale_size', 
default=224, type=int) 52 | parser.add_argument('--num-class', default=94, type=int) 53 | parser.add_argument('-j', '--num_workers', default=4, type=int, metavar='N', 54 | help='number of data loading workers (default: 4)') 55 | parser.add_argument('--root_model', type=str, default='model') 56 | parser.add_argument('--weights', type=str, 57 | default='') 58 | parser.add_argument('--visualize', type=str, default='./visualize') 59 | parser.add_argument('--centralSize', type=int, default=12) 60 | parser.add_argument('--mapSize', type=int, default=1000) 61 | parser.add_argument('--ppi', type=int, default=4) 62 | parser.add_argument('--scale', type=float, default=2) 63 | parser.add_argument('--trajectory-file', type=str, default='') 64 | parser.add_argument('--real-scale', type=float, default=1.2) 65 | parser.add_argument('--use-topdown', type=str2bool, default=False) 66 | parser.add_argument('--visual-input', type=str2bool, default=False) 67 | -------------------------------------------------------------------------------- /segmentTool/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ModelBuilder, SegmentationModule 2 | -------------------------------------------------------------------------------- /segmentTool/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/__pycache__/resnext.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/__pycache__/resnext.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /segmentTool/lib/nn/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/modules/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/modules/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/modules/__pycache__/replicate.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
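#
# Editor's note (added comment): the SynchronizedBatchNorm layers below reduce batch
# statistics over every GPU replica instead of normalizing per device. Each replica
# ships its per-channel sum and squared sum to the master copy through the
# SyncMaster/SlavePipe pair from comm.py; the master computes mean = sum / n and
# inv_std = max(var, eps) ** -0.5 from the pooled sums and broadcasts both back, so
# all replicas normalize with identical statistics.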
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | # customed batch norm statistics 49 | self._moving_average_fraction = 1. - momentum 50 | self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features)) 51 | self.register_buffer('_tmp_running_var', torch.ones(self.num_features)) 52 | self.register_buffer('_running_iter', torch.ones(1)) 53 | self._tmp_running_mean = self.running_mean.clone() * self._running_iter 54 | self._tmp_running_var = self.running_var.clone() * self._running_iter 55 | 56 | def forward(self, input): 57 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 58 | if not (self._is_parallel and self.training): 59 | return F.batch_norm( 60 | input, self.running_mean, self.running_var, self.weight, self.bias, 61 | self.training, self.momentum, self.eps) 62 | 63 | # Resize the input to (B, C, -1). 64 | input_shape = input.size() 65 | input = input.view(input.size(0), self.num_features, -1) 66 | 67 | # Compute the sum and square-sum. 68 | sum_size = input.size(0) * input.size(2) 69 | input_sum = _sum_ft(input) 70 | input_ssum = _sum_ft(input ** 2) 71 | 72 | # Reduce-and-broadcast the statistics. 73 | if self._parallel_id == 0: 74 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 75 | else: 76 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 77 | 78 | # Compute the output. 79 | if self.affine: 80 | # MJY:: Fuse the multiplication for speed. 81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 82 | else: 83 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 84 | 85 | # Reshape it. 86 | return output.view(input_shape) 87 | 88 | def __data_parallel_replicate__(self, ctx, copy_id): 89 | self._is_parallel = True 90 | self._parallel_id = copy_id 91 | 92 | # parallel_id == 0 means master device. 
93 | if self._parallel_id == 0: 94 | ctx.sync_master = self._sync_master 95 | else: 96 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 97 | 98 | def _data_parallel_master(self, intermediates): 99 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 101 | 102 | to_reduce = [i[1][:2] for i in intermediates] 103 | to_reduce = [j for i in to_reduce for j in i] # flatten 104 | target_gpus = [i[1].sum.get_device() for i in intermediates] 105 | 106 | sum_size = sum([i[1].sum_size for i in intermediates]) 107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 108 | 109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 110 | 111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 112 | 113 | outputs = [] 114 | for i, rec in enumerate(intermediates): 115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 116 | 117 | return outputs 118 | 119 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): 120 | """return *dest* by `dest := dest*alpha + delta*beta + bias`""" 121 | return dest * alpha + delta * beta + bias 122 | 123 | def _compute_mean_std(self, sum_, ssum, size): 124 | """Compute the mean and standard-deviation with sum and square-sum. This method 125 | also maintains the moving average on the master device.""" 126 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 127 | mean = sum_ / size 128 | sumvar = ssum - sum_ * mean 129 | unbias_var = sumvar / (size - 1) 130 | bias_var = sumvar / size 131 | 132 | self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction) 133 | self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction) 134 | self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction) 135 | 136 | self.running_mean = self._tmp_running_mean / self._running_iter 137 | self.running_var = self._tmp_running_var / self._running_iter 138 | 139 | return mean, bias_var.clamp(self.eps) ** -0.5 140 | 141 | 142 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 143 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 144 | mini-batch. 145 | 146 | .. math:: 147 | 148 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 149 | 150 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 151 | standard-deviation are reduced across all devices during training. 152 | 153 | For example, when one uses `nn.DataParallel` to wrap the network during 154 | training, PyTorch's implementation normalize the tensor on each device using 155 | the statistics only on that device, which accelerated the computation and 156 | is also easy to implement, but the statistics might be inaccurate. 157 | Instead, in this synchronized version, the statistics will be computed 158 | over all training samples distributed on multiple devices. 159 | 160 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 161 | as the built-in PyTorch implementation. 162 | 163 | The mean and standard-deviation are calculated per-dimension over 164 | the mini-batches and gamma and beta are learnable parameter vectors 165 | of size C (where C is the input size). 166 | 167 | During training, this layer keeps a running estimate of its computed mean 168 | and variance. 
The running sum is kept with a default momentum of 0.1. 169 | 170 | During evaluation, this running mean/variance is used for normalization. 171 | 172 | Because the BatchNorm is done over the `C` dimension, computing statistics 173 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 174 | 175 | Args: 176 | num_features: num_features from an expected input of size 177 | `batch_size x num_features [x width]` 178 | eps: a value added to the denominator for numerical stability. 179 | Default: 1e-5 180 | momentum: the value used for the running_mean and running_var 181 | computation. Default: 0.1 182 | affine: a boolean value that when set to ``True``, gives the layer learnable 183 | affine parameters. Default: ``True`` 184 | 185 | Shape: 186 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 187 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 188 | 189 | Examples: 190 | >>> # With Learnable Parameters 191 | >>> m = SynchronizedBatchNorm1d(100) 192 | >>> # Without Learnable Parameters 193 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 194 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 195 | >>> output = m(input) 196 | """ 197 | 198 | def _check_input_dim(self, input): 199 | if input.dim() != 2 and input.dim() != 3: 200 | raise ValueError('expected 2D or 3D input (got {}D input)' 201 | .format(input.dim())) 202 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 203 | 204 | 205 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 206 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 207 | of 3d inputs 208 | 209 | .. math:: 210 | 211 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 212 | 213 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 214 | standard-deviation are reduced across all devices during training. 215 | 216 | For example, when one uses `nn.DataParallel` to wrap the network during 217 | training, PyTorch's implementation normalize the tensor on each device using 218 | the statistics only on that device, which accelerated the computation and 219 | is also easy to implement, but the statistics might be inaccurate. 220 | Instead, in this synchronized version, the statistics will be computed 221 | over all training samples distributed on multiple devices. 222 | 223 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 224 | as the built-in PyTorch implementation. 225 | 226 | The mean and standard-deviation are calculated per-dimension over 227 | the mini-batches and gamma and beta are learnable parameter vectors 228 | of size C (where C is the input size). 229 | 230 | During training, this layer keeps a running estimate of its computed mean 231 | and variance. The running sum is kept with a default momentum of 0.1. 232 | 233 | During evaluation, this running mean/variance is used for normalization. 234 | 235 | Because the BatchNorm is done over the `C` dimension, computing statistics 236 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 237 | 238 | Args: 239 | num_features: num_features from an expected input of 240 | size batch_size x num_features x height x width 241 | eps: a value added to the denominator for numerical stability. 242 | Default: 1e-5 243 | momentum: the value used for the running_mean and running_var 244 | computation. Default: 0.1 245 | affine: a boolean value that when set to ``True``, gives the layer learnable 246 | affine parameters. 
Default: ``True`` 247 | 248 | Shape: 249 | - Input: :math:`(N, C, H, W)` 250 | - Output: :math:`(N, C, H, W)` (same shape as input) 251 | 252 | Examples: 253 | >>> # With Learnable Parameters 254 | >>> m = SynchronizedBatchNorm2d(100) 255 | >>> # Without Learnable Parameters 256 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 257 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 258 | >>> output = m(input) 259 | """ 260 | 261 | def _check_input_dim(self, input): 262 | if input.dim() != 4: 263 | raise ValueError('expected 4D input (got {}D input)' 264 | .format(input.dim())) 265 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 266 | 267 | 268 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 269 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 270 | of 4d inputs 271 | 272 | .. math:: 273 | 274 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 275 | 276 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 277 | standard-deviation are reduced across all devices during training. 278 | 279 | For example, when one uses `nn.DataParallel` to wrap the network during 280 | training, PyTorch's implementation normalize the tensor on each device using 281 | the statistics only on that device, which accelerated the computation and 282 | is also easy to implement, but the statistics might be inaccurate. 283 | Instead, in this synchronized version, the statistics will be computed 284 | over all training samples distributed on multiple devices. 285 | 286 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 287 | as the built-in PyTorch implementation. 288 | 289 | The mean and standard-deviation are calculated per-dimension over 290 | the mini-batches and gamma and beta are learnable parameter vectors 291 | of size C (where C is the input size). 292 | 293 | During training, this layer keeps a running estimate of its computed mean 294 | and variance. The running sum is kept with a default momentum of 0.1. 295 | 296 | During evaluation, this running mean/variance is used for normalization. 297 | 298 | Because the BatchNorm is done over the `C` dimension, computing statistics 299 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 300 | or Spatio-temporal BatchNorm 301 | 302 | Args: 303 | num_features: num_features from an expected input of 304 | size batch_size x num_features x depth x height x width 305 | eps: a value added to the denominator for numerical stability. 306 | Default: 1e-5 307 | momentum: the value used for the running_mean and running_var 308 | computation. Default: 0.1 309 | affine: a boolean value that when set to ``True``, gives the layer learnable 310 | affine parameters. 
Default: ``True`` 311 | 312 | Shape: 313 | - Input: :math:`(N, C, D, H, W)` 314 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 315 | 316 | Examples: 317 | >>> # With Learnable Parameters 318 | >>> m = SynchronizedBatchNorm3d(100) 319 | >>> # Without Learnable Parameters 320 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 321 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 322 | >>> output = m(input) 323 | """ 324 | 325 | def _check_input_dim(self, input): 326 | if input.dim() != 5: 327 | raise ValueError('expected 5D input (got {}D input)' 328 | .format(input.dim())) 329 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 330 | -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register an slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually is the device id. 
84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages were first collected from each devices (including the master device), and then 100 | an callback will be invoked to compute the message to be sent back to each devices 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belongs to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 
39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
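#
# Editor's note (added comment): these tests compare SynchronizedBatchNorm1d/2d,
# wrapped in DataParallelWithCallback over GPUs 0 and 1, against the stock
# nn.BatchNorm layers, checking outputs, input gradients, and the running mean/var
# buffers in both training and evaluation mode.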
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /segmentTool/lib/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /segmentTool/lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /segmentTool/lib/nn/parallel/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/parallel/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/nn/parallel/__pycache__/data_parallel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/parallel/__pycache__/data_parallel.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | from torch.autograd import Variable 7 | import collections 8 | from torch.nn.parallel._functions import Gather 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | def async_copy_to(obj, dev, main_stream=None): 13 | if torch.is_tensor(obj): 14 | obj = Variable(obj) 15 | if isinstance(obj, Variable): 16 | v = obj.cuda(dev, async=True) 17 | if main_stream is not None: 18 | v.data.record_stream(main_stream) 19 | return v 20 | elif isinstance(obj, collections.Mapping): 21 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 22 | elif isinstance(obj, collections.Sequence): 23 | return [async_copy_to(o, dev, main_stream) for o in obj] 24 | else: 25 | return obj 26 | 27 | 28 | def dict_gather(outputs, target_device, dim=0): 29 | """ 30 | Gathers variables from different GPUs on a specified device 31 | (-1 means the CPU), with dictionary support. 
32 | """ 33 | def gather_map(outputs): 34 | out = outputs[0] 35 | if isinstance(out, Variable): 36 | # MJY(20180330) HACK:: force nr_dims > 0 37 | if out.dim() == 0: 38 | outputs = [o.unsqueeze(0) for o in outputs] 39 | return Gather.apply(target_device, dim, *outputs) 40 | elif out is None: 41 | return None 42 | elif isinstance(out, collections.Mapping): 43 | return {k: gather_map([o[k] for o in outputs]) for k in out} 44 | elif isinstance(out, collections.Sequence): 45 | return type(out)(map(gather_map, zip(*outputs))) 46 | return gather_map(outputs) 47 | 48 | 49 | class DictGatherDataParallel(nn.DataParallel): 50 | def gather(self, outputs, output_device): 51 | return dict_gather(outputs, output_device, dim=self.dim) 52 | 53 | 54 | class UserScatteredDataParallel(DictGatherDataParallel): 55 | def scatter(self, inputs, kwargs, device_ids): 56 | assert len(inputs) == 1 57 | inputs = inputs[0] 58 | inputs = _async_copy_stream(inputs, device_ids) 59 | inputs = [[i] for i in inputs] 60 | assert len(kwargs) == 0 61 | kwargs = [{} for _ in range(len(inputs))] 62 | 63 | return inputs, kwargs 64 | 65 | 66 | def user_scattered_collate(batch): 67 | return batch 68 | 69 | 70 | def _async_copy(inputs, device_ids): 71 | nr_devs = len(device_ids) 72 | assert type(inputs) in (tuple, list) 73 | assert len(inputs) == nr_devs 74 | 75 | outputs = [] 76 | for i, dev in zip(inputs, device_ids): 77 | with cuda.device(dev): 78 | outputs.append(async_copy_to(i, dev)) 79 | 80 | return tuple(outputs) 81 | 82 | 83 | def _async_copy_stream(inputs, device_ids): 84 | nr_devs = len(device_ids) 85 | assert type(inputs) in (tuple, list) 86 | assert len(inputs) == nr_devs 87 | 88 | outputs = [] 89 | streams = [_get_stream(d) for d in device_ids] 90 | for i, dev, stream in zip(inputs, device_ids, streams): 91 | with cuda.device(dev): 92 | main_stream = cuda.current_stream() 93 | with cuda.stream(stream): 94 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 95 | main_stream.wait_stream(stream) 96 | 97 | return outputs 98 | 99 | 100 | """Adapted from: torch/nn/parallel/_functions.py""" 101 | # background streams used for copying 102 | _streams = None 103 | 104 | 105 | def _get_stream(device): 106 | """Gets a background stream for copying between CPU and GPU""" 107 | global _streams 108 | if device == -1: 109 | return None 110 | if _streams is None: 111 | _streams = [None] * cuda.device_count() 112 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 113 | return _streams[device] 114 | -------------------------------------------------------------------------------- /segmentTool/lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /segmentTool/lib/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/utils/__pycache__/th.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/utils/__pycache__/th.cpython-36.pyc 
-------------------------------------------------------------------------------- /segmentTool/lib/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataset import Dataset, TensorDataset, ConcatDataset 3 | from .dataloader import DataLoader 4 | -------------------------------------------------------------------------------- /segmentTool/lib/utils/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/utils/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/utils/data/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/utils/data/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/utils/data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/utils/data/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/utils/data/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/utils/data/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /segmentTool/lib/utils/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as multiprocessing 3 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ 4 | _remove_worker_pids, _error_if_any_worker_fails 5 | from .sampler import SequentialSampler, RandomSampler, BatchSampler 6 | import signal 7 | import functools 8 | import collections 9 | import re 10 | import sys 11 | import threading 12 | import traceback 13 | from torch._six import string_classes, int_classes 14 | import numpy as np 15 | 16 | if sys.version_info[0] == 2: 17 | import Queue as queue 18 | else: 19 | import queue 20 | 21 | 22 | class ExceptionWrapper(object): 23 | r"Wraps an exception plus traceback to communicate across threads" 24 | 25 | def __init__(self, exc_info): 26 | self.exc_type = exc_info[0] 27 | self.exc_msg = "".join(traceback.format_exception(*exc_info)) 28 | 29 | 30 | _use_shared_memory = False 31 | """Whether to use shared memory in default_collate""" 32 | 33 | 34 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): 35 | global _use_shared_memory 36 | _use_shared_memory = True 37 | 38 | # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal 39 | # module's handlers are executed after Python returns from C low-level 40 | # handlers, likely when the same fatal signal happened again already. 41 | # https://docs.python.org/3/library/signal.html Sec. 
18.8.1.1 42 | _set_worker_signal_handlers() 43 | 44 | torch.set_num_threads(1) 45 | torch.manual_seed(seed) 46 | np.random.seed(seed) 47 | 48 | if init_fn is not None: 49 | init_fn(worker_id) 50 | 51 | while True: 52 | r = index_queue.get() 53 | if r is None: 54 | break 55 | idx, batch_indices = r 56 | try: 57 | samples = collate_fn([dataset[i] for i in batch_indices]) 58 | except Exception: 59 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 60 | else: 61 | data_queue.put((idx, samples)) 62 | 63 | 64 | def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): 65 | if pin_memory: 66 | torch.cuda.set_device(device_id) 67 | 68 | while True: 69 | try: 70 | r = in_queue.get() 71 | except Exception: 72 | if done_event.is_set(): 73 | return 74 | raise 75 | if r is None: 76 | break 77 | if isinstance(r[1], ExceptionWrapper): 78 | out_queue.put(r) 79 | continue 80 | idx, batch = r 81 | try: 82 | if pin_memory: 83 | batch = pin_memory_batch(batch) 84 | except Exception: 85 | out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 86 | else: 87 | out_queue.put((idx, batch)) 88 | 89 | numpy_type_map = { 90 | 'float64': torch.DoubleTensor, 91 | 'float32': torch.FloatTensor, 92 | 'float16': torch.HalfTensor, 93 | 'int64': torch.LongTensor, 94 | 'int32': torch.IntTensor, 95 | 'int16': torch.ShortTensor, 96 | 'int8': torch.CharTensor, 97 | 'uint8': torch.ByteTensor, 98 | } 99 | 100 | 101 | def default_collate(batch): 102 | "Puts each data field into a tensor with outer dimension batch size" 103 | 104 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 105 | elem_type = type(batch[0]) 106 | if torch.is_tensor(batch[0]): 107 | out = None 108 | if _use_shared_memory: 109 | # If we're in a background process, concatenate directly into a 110 | # shared memory tensor to avoid an extra copy 111 | numel = sum([x.numel() for x in batch]) 112 | storage = batch[0].storage()._new_shared(numel) 113 | out = batch[0].new(storage) 114 | return torch.stack(batch, 0, out=out) 115 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 116 | and elem_type.__name__ != 'string_': 117 | elem = batch[0] 118 | if elem_type.__name__ == 'ndarray': 119 | # array of string classes and object 120 | if re.search('[SaUO]', elem.dtype.str) is not None: 121 | raise TypeError(error_msg.format(elem.dtype)) 122 | 123 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 124 | if elem.shape == (): # scalars 125 | py_type = float if elem.dtype.name.startswith('float') else int 126 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 127 | elif isinstance(batch[0], int_classes): 128 | return torch.LongTensor(batch) 129 | elif isinstance(batch[0], float): 130 | return torch.DoubleTensor(batch) 131 | elif isinstance(batch[0], string_classes): 132 | return batch 133 | elif isinstance(batch[0], collections.Mapping): 134 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 135 | elif isinstance(batch[0], collections.Sequence): 136 | transposed = zip(*batch) 137 | return [default_collate(samples) for samples in transposed] 138 | 139 | raise TypeError((error_msg.format(type(batch[0])))) 140 | 141 | 142 | def pin_memory_batch(batch): 143 | if torch.is_tensor(batch): 144 | return batch.pin_memory() 145 | elif isinstance(batch, string_classes): 146 | return batch 147 | elif isinstance(batch, collections.Mapping): 148 | return {k: pin_memory_batch(sample) for k, sample in batch.items()} 149 | elif isinstance(batch, 
collections.Sequence): 150 | return [pin_memory_batch(sample) for sample in batch] 151 | else: 152 | return batch 153 | 154 | 155 | _SIGCHLD_handler_set = False 156 | """Whether SIGCHLD handler is set for DataLoader worker failures. Only one 157 | handler needs to be set for all DataLoaders in a process.""" 158 | 159 | 160 | def _set_SIGCHLD_handler(): 161 | # Windows doesn't support SIGCHLD handler 162 | if sys.platform == 'win32': 163 | return 164 | # can't set signal in child threads 165 | if not isinstance(threading.current_thread(), threading._MainThread): 166 | return 167 | global _SIGCHLD_handler_set 168 | if _SIGCHLD_handler_set: 169 | return 170 | previous_handler = signal.getsignal(signal.SIGCHLD) 171 | if not callable(previous_handler): 172 | previous_handler = None 173 | 174 | def handler(signum, frame): 175 | # This following call uses `waitid` with WNOHANG from C side. Therefore, 176 | # Python can still get and update the process status successfully. 177 | _error_if_any_worker_fails() 178 | if previous_handler is not None: 179 | previous_handler(signum, frame) 180 | 181 | signal.signal(signal.SIGCHLD, handler) 182 | _SIGCHLD_handler_set = True 183 | 184 | 185 | class DataLoaderIter(object): 186 | "Iterates once over the DataLoader's dataset, as specified by the sampler" 187 | 188 | def __init__(self, loader): 189 | self.dataset = loader.dataset 190 | self.collate_fn = loader.collate_fn 191 | self.batch_sampler = loader.batch_sampler 192 | self.num_workers = loader.num_workers 193 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 194 | self.timeout = loader.timeout 195 | self.done_event = threading.Event() 196 | 197 | self.sample_iter = iter(self.batch_sampler) 198 | 199 | if self.num_workers > 0: 200 | self.worker_init_fn = loader.worker_init_fn 201 | self.index_queue = multiprocessing.SimpleQueue() 202 | self.worker_result_queue = multiprocessing.SimpleQueue() 203 | self.batches_outstanding = 0 204 | self.worker_pids_set = False 205 | self.shutdown = False 206 | self.send_idx = 0 207 | self.rcvd_idx = 0 208 | self.reorder_dict = {} 209 | 210 | base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0] 211 | self.workers = [ 212 | multiprocessing.Process( 213 | target=_worker_loop, 214 | args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, 215 | base_seed + i, self.worker_init_fn, i)) 216 | for i in range(self.num_workers)] 217 | 218 | if self.pin_memory or self.timeout > 0: 219 | self.data_queue = queue.Queue() 220 | if self.pin_memory: 221 | maybe_device_id = torch.cuda.current_device() 222 | else: 223 | # do not initialize cuda context if not necessary 224 | maybe_device_id = None 225 | self.worker_manager_thread = threading.Thread( 226 | target=_worker_manager_loop, 227 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, 228 | maybe_device_id)) 229 | self.worker_manager_thread.daemon = True 230 | self.worker_manager_thread.start() 231 | else: 232 | self.data_queue = self.worker_result_queue 233 | 234 | for w in self.workers: 235 | w.daemon = True # ensure that the worker exits on process exit 236 | w.start() 237 | 238 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) 239 | _set_SIGCHLD_handler() 240 | self.worker_pids_set = True 241 | 242 | # prime the prefetch loop 243 | for _ in range(2 * self.num_workers): 244 | self._put_indices() 245 | 246 | def __len__(self): 247 | return len(self.batch_sampler) 248 | 249 | def _get_batch(self): 250 | if self.timeout > 0: 251 | try: 252 | 
return self.data_queue.get(timeout=self.timeout) 253 | except queue.Empty: 254 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) 255 | else: 256 | return self.data_queue.get() 257 | 258 | def __next__(self): 259 | if self.num_workers == 0: # same-process loading 260 | indices = next(self.sample_iter) # may raise StopIteration 261 | batch = self.collate_fn([self.dataset[i] for i in indices]) 262 | if self.pin_memory: 263 | batch = pin_memory_batch(batch) 264 | return batch 265 | 266 | # check if the next sample has already been generated 267 | if self.rcvd_idx in self.reorder_dict: 268 | batch = self.reorder_dict.pop(self.rcvd_idx) 269 | return self._process_next_batch(batch) 270 | 271 | if self.batches_outstanding == 0: 272 | self._shutdown_workers() 273 | raise StopIteration 274 | 275 | while True: 276 | assert (not self.shutdown and self.batches_outstanding > 0) 277 | idx, batch = self._get_batch() 278 | self.batches_outstanding -= 1 279 | if idx != self.rcvd_idx: 280 | # store out-of-order samples 281 | self.reorder_dict[idx] = batch 282 | continue 283 | return self._process_next_batch(batch) 284 | 285 | next = __next__ # Python 2 compatibility 286 | 287 | def __iter__(self): 288 | return self 289 | 290 | def _put_indices(self): 291 | assert self.batches_outstanding < 2 * self.num_workers 292 | indices = next(self.sample_iter, None) 293 | if indices is None: 294 | return 295 | self.index_queue.put((self.send_idx, indices)) 296 | self.batches_outstanding += 1 297 | self.send_idx += 1 298 | 299 | def _process_next_batch(self, batch): 300 | self.rcvd_idx += 1 301 | self._put_indices() 302 | if isinstance(batch, ExceptionWrapper): 303 | raise batch.exc_type(batch.exc_msg) 304 | return batch 305 | 306 | def __getstate__(self): 307 | # TODO: add limited pickling support for sharing an iterator 308 | # across multiple threads for HOGWILD. 309 | # Probably the best way to do this is by moving the sample pushing 310 | # to a separate thread and then just sharing the data queue 311 | # but signalling the end is tricky without a non-blocking API 312 | raise NotImplementedError("DataLoaderIterator cannot be pickled") 313 | 314 | def _shutdown_workers(self): 315 | try: 316 | if not self.shutdown: 317 | self.shutdown = True 318 | self.done_event.set() 319 | # if worker_manager_thread is waiting to put 320 | while not self.data_queue.empty(): 321 | self.data_queue.get() 322 | for _ in self.workers: 323 | self.index_queue.put(None) 324 | # done_event should be sufficient to exit worker_manager_thread, 325 | # but be safe here and put another None 326 | self.worker_result_queue.put(None) 327 | finally: 328 | # removes pids no matter what 329 | if self.worker_pids_set: 330 | _remove_worker_pids(id(self)) 331 | self.worker_pids_set = False 332 | 333 | def __del__(self): 334 | if self.num_workers > 0: 335 | self._shutdown_workers() 336 | 337 | 338 | class DataLoader(object): 339 | """ 340 | Data loader. Combines a dataset and a sampler, and provides 341 | single- or multi-process iterators over the dataset. 342 | 343 | Arguments: 344 | dataset (Dataset): dataset from which to load the data. 345 | batch_size (int, optional): how many samples per batch to load 346 | (default: 1). 347 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 348 | at every epoch (default: False). 349 | sampler (Sampler, optional): defines the strategy to draw samples from 350 | the dataset. If specified, ``shuffle`` must be False. 
351 | batch_sampler (Sampler, optional): like sampler, but returns a batch of 352 | indices at a time. Mutually exclusive with batch_size, shuffle, 353 | sampler, and drop_last. 354 | num_workers (int, optional): how many subprocesses to use for data 355 | loading. 0 means that the data will be loaded in the main process. 356 | (default: 0) 357 | collate_fn (callable, optional): merges a list of samples to form a mini-batch. 358 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors 359 | into CUDA pinned memory before returning them. 360 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 361 | if the dataset size is not divisible by the batch size. If ``False`` and 362 | the size of dataset is not divisible by the batch size, then the last batch 363 | will be smaller. (default: False) 364 | timeout (numeric, optional): if positive, the timeout value for collecting a batch 365 | from workers. Should always be non-negative. (default: 0) 366 | worker_init_fn (callable, optional): If not None, this will be called on each 367 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as 368 | input, after seeding and before data loading. (default: None) 369 | 370 | .. note:: By default, each worker will have its PyTorch seed set to 371 | ``base_seed + worker_id``, where ``base_seed`` is a long generated 372 | by main process using its RNG. You may use ``torch.initial_seed()`` to access 373 | this value in :attr:`worker_init_fn`, which can be used to set other seeds 374 | (e.g. NumPy) before data loading. 375 | 376 | .. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an 377 | unpicklable object, e.g., a lambda function. 378 | """ 379 | 380 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 381 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, 382 | timeout=0, worker_init_fn=None): 383 | self.dataset = dataset 384 | self.batch_size = batch_size 385 | self.num_workers = num_workers 386 | self.collate_fn = collate_fn 387 | self.pin_memory = pin_memory 388 | self.drop_last = drop_last 389 | self.timeout = timeout 390 | self.worker_init_fn = worker_init_fn 391 | 392 | if timeout < 0: 393 | raise ValueError('timeout option should be non-negative') 394 | 395 | if batch_sampler is not None: 396 | if batch_size > 1 or shuffle or sampler is not None or drop_last: 397 | raise ValueError('batch_sampler is mutually exclusive with ' 398 | 'batch_size, shuffle, sampler, and drop_last') 399 | 400 | if sampler is not None and shuffle: 401 | raise ValueError('sampler is mutually exclusive with shuffle') 402 | 403 | if self.num_workers < 0: 404 | raise ValueError('num_workers cannot be negative; ' 405 | 'use num_workers=0 to disable multiprocessing.') 406 | 407 | if batch_sampler is None: 408 | if sampler is None: 409 | if shuffle: 410 | sampler = RandomSampler(dataset) 411 | else: 412 | sampler = SequentialSampler(dataset) 413 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 414 | 415 | self.sampler = sampler 416 | self.batch_sampler = batch_sampler 417 | 418 | def __iter__(self): 419 | return DataLoaderIter(self) 420 | 421 | def __len__(self): 422 | return len(self.batch_sampler) 423 | -------------------------------------------------------------------------------- /segmentTool/lib/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | 
from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. All subclasses should override 12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | 26 | class TensorDataset(Dataset): 27 | """Dataset wrapping data and target tensors. 28 | 29 | Each sample will be retrieved by indexing both tensors along the first 30 | dimension. 31 | 32 | Arguments: 33 | data_tensor (Tensor): contains sample data. 34 | target_tensor (Tensor): contains sample targets (labels). 35 | """ 36 | 37 | def __init__(self, data_tensor, target_tensor): 38 | assert data_tensor.size(0) == target_tensor.size(0) 39 | self.data_tensor = data_tensor 40 | self.target_tensor = target_tensor 41 | 42 | def __getitem__(self, index): 43 | return self.data_tensor[index], self.target_tensor[index] 44 | 45 | def __len__(self): 46 | return self.data_tensor.size(0) 47 | 48 | 49 | class ConcatDataset(Dataset): 50 | """ 51 | Dataset to concatenate multiple datasets. 52 | Purpose: useful to assemble different existing datasets, possibly 53 | large-scale datasets as the concatenation operation is done in an 54 | on-the-fly manner. 55 | 56 | Arguments: 57 | datasets (iterable): List of datasets to be concatenated 58 | """ 59 | 60 | @staticmethod 61 | def cumsum(sequence): 62 | r, s = [], 0 63 | for e in sequence: 64 | l = len(e) 65 | r.append(l + s) 66 | s += l 67 | return r 68 | 69 | def __init__(self, datasets): 70 | super(ConcatDataset, self).__init__() 71 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 72 | self.datasets = list(datasets) 73 | self.cumulative_sizes = self.cumsum(self.datasets) 74 | 75 | def __len__(self): 76 | return self.cumulative_sizes[-1] 77 | 78 | def __getitem__(self, idx): 79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 80 | if dataset_idx == 0: 81 | sample_idx = idx 82 | else: 83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 84 | return self.datasets[dataset_idx][sample_idx] 85 | 86 | @property 87 | def cummulative_sizes(self): 88 | warnings.warn("cummulative_sizes attribute is renamed to " 89 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 90 | return self.cumulative_sizes 91 | 92 | 93 | class Subset(Dataset): 94 | def __init__(self, dataset, indices): 95 | self.dataset = dataset 96 | self.indices = indices 97 | 98 | def __getitem__(self, idx): 99 | return self.dataset[self.indices[idx]] 100 | 101 | def __len__(self): 102 | return len(self.indices) 103 | 104 | 105 | def random_split(dataset, lengths): 106 | """ 107 | Randomly split a dataset into non-overlapping new datasets of given lengths 108 | ds 109 | 110 | Arguments: 111 | dataset (Dataset): Dataset to be split 112 | lengths (iterable): lengths of splits to be produced 113 | """ 114 | if sum(lengths) != len(dataset): 115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 116 | 117 | indices = randperm(sum(lengths)) 118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 119 | 
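A small illustrative sketch of how the Dataset helpers above compose. It is not repository code; the import path is assumed from the directory layout and the tensors are toy data.

```
# Hypothetical example: TensorDataset wraps paired tensors, ConcatDataset
# chains datasets via cumulative sizes, and random_split returns
# non-overlapping Subset views of the original dataset.
import torch
from segmentTool.lib.utils.data.dataset import (
    TensorDataset, ConcatDataset, random_split)

data = torch.randn(10, 3)
labels = torch.zeros(10).long()
pairs = TensorDataset(data, labels)

both = ConcatDataset([pairs, pairs])      # len(both) == 20
train, val = random_split(pairs, [8, 2])  # lengths must sum to len(pairs)
print(len(both), len(train), len(val))    # 20 8 2
```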
-------------------------------------------------------------------------------- /segmentTool/lib/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from .sampler import Sampler 4 | from torch.distributed import get_world_size, get_rank 5 | 6 | 7 | class DistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = get_world_size() 28 | if rank is None: 29 | rank = get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = list(torch.randperm(len(self.dataset), generator=g)) 42 | 43 | # add extra samples to make it evenly divisible 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | offset = self.num_samples * self.rank 49 | indices = indices[offset:offset + self.num_samples] 50 | assert len(indices) == self.num_samples 51 | 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | -------------------------------------------------------------------------------- /segmentTool/lib/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterators. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 
41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement. 58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list) : a list of weights, not necessary summing up to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement. 80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch. 
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /segmentTool/lib/utils/th.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import collections 5 | 6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 7 | 8 | def as_variable(obj): 9 | if isinstance(obj, Variable): 10 | return obj 11 | if isinstance(obj, collections.Sequence): 12 | return [as_variable(v) for v in obj] 13 | elif isinstance(obj, collections.Mapping): 14 | return {k: as_variable(v) for k, v in obj.items()} 15 | else: 16 | return Variable(obj) 17 | 18 | def as_numpy(obj): 19 | if isinstance(obj, collections.Sequence): 20 | return [as_numpy(v) for v in obj] 21 | elif isinstance(obj, collections.Mapping): 22 | return {k: as_numpy(v) for k, v in obj.items()} 23 | elif isinstance(obj, Variable): 24 | return obj.data.cpu().numpy() 25 | elif torch.is_tensor(obj): 26 | return obj.cpu().numpy() 27 | else: 28 | return np.array(obj) 29 | 30 | def mark_volatile(obj): 31 | if torch.is_tensor(obj): 32 | obj = Variable(obj) 33 | if isinstance(obj, Variable): 34 | obj.no_grad = True 35 | return obj 36 | elif isinstance(obj, collections.Mapping): 37 | return {k: mark_volatile(o) for k, o in obj.items()} 38 | elif isinstance(obj, collections.Sequence): 39 | return [mark_volatile(o) for o in obj] 40 | else: 41 | return obj 42 | -------------------------------------------------------------------------------- /segmentTool/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from segmentTool.lib.nn import SynchronizedBatchNorm2d 7 | 8 | try: 9 | from urllib import urlretrieve 10 | except ImportError: 11 | from urllib.request import urlretrieve 12 | 13 | 14 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101'] # resnet101 is coming soon! 
15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 19 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 20 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | "3x3 convolution with padding" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = SynchronizedBatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = SynchronizedBatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = SynchronizedBatchNorm2d(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = SynchronizedBatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = SynchronizedBatchNorm2d(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, num_classes=1000): 104 | self.inplanes = 128 105 | super(ResNet, self).__init__() 106 | self.conv1 = conv3x3(3, 64, stride=2) 107 | self.bn1 = SynchronizedBatchNorm2d(64) 108 | self.relu1 = nn.ReLU(inplace=True) 109 | self.conv2 = conv3x3(64, 64) 110 | self.bn2 = SynchronizedBatchNorm2d(64) 111 | self.relu2 = nn.ReLU(inplace=True) 112 | self.conv3 = conv3x3(64, 128) 113 | self.bn3 = SynchronizedBatchNorm2d(128) 114 | self.relu3 = nn.ReLU(inplace=True) 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | 117 | self.layer1 = self._make_layer(block, 64, layers[0]) 118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 121 | self.avgpool = nn.AvgPool2d(7, stride=1) 122 | self.fc = nn.Linear(512 * block.expansion, num_classes) 123 | 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | n = m.kernel_size[0] * m.kernel_size[1] 
* m.out_channels 127 | m.weight.data.normal_(0, math.sqrt(2. / n)) 128 | elif isinstance(m, SynchronizedBatchNorm2d): 129 | m.weight.data.fill_(1) 130 | m.bias.data.zero_() 131 | 132 | def _make_layer(self, block, planes, blocks, stride=1): 133 | downsample = None 134 | if stride != 1 or self.inplanes != planes * block.expansion: 135 | downsample = nn.Sequential( 136 | nn.Conv2d(self.inplanes, planes * block.expansion, 137 | kernel_size=1, stride=stride, bias=False), 138 | SynchronizedBatchNorm2d(planes * block.expansion), 139 | ) 140 | 141 | layers = [] 142 | layers.append(block(self.inplanes, planes, stride, downsample)) 143 | self.inplanes = planes * block.expansion 144 | for i in range(1, blocks): 145 | layers.append(block(self.inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | x = self.relu1(self.bn1(self.conv1(x))) 151 | x = self.relu2(self.bn2(self.conv2(x))) 152 | x = self.relu3(self.bn3(self.conv3(x))) 153 | x = self.maxpool(x) 154 | 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x = self.layer4(x) 159 | 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | x = self.fc(x) 163 | 164 | return x 165 | 166 | def resnet18(pretrained=False, **kwargs): 167 | """Constructs a ResNet-18 model. 168 | 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on Places 171 | """ 172 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(load_url(model_urls['resnet18'])) 175 | return model 176 | 177 | ''' 178 | def resnet34(pretrained=False, **kwargs): 179 | """Constructs a ResNet-34 model. 180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on Places 183 | """ 184 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | model.load_state_dict(load_url(model_urls['resnet34'])) 187 | return model 188 | ''' 189 | 190 | def resnet50(pretrained=False, **kwargs): 191 | """Constructs a ResNet-50 model. 192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on Places 195 | """ 196 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 199 | return model 200 | 201 | 202 | def resnet101(pretrained=False, **kwargs): 203 | """Constructs a ResNet-101 model. 204 | 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on Places 207 | """ 208 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 211 | return model 212 | 213 | # def resnet152(pretrained=False, **kwargs): 214 | # """Constructs a ResNet-152 model. 
215 | # 216 | # Args: 217 | # pretrained (bool): If True, returns a model pre-trained on Places 218 | # """ 219 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 220 | # if pretrained: 221 | # model.load_state_dict(load_url(model_urls['resnet152'])) 222 | # return model 223 | 224 | def load_url(url, model_dir='./pretrained', map_location=None): 225 | if not os.path.exists(model_dir): 226 | os.makedirs(model_dir) 227 | filename = url.split('/')[-1] 228 | cached_file = os.path.join(model_dir, filename) 229 | if not os.path.exists(cached_file): 230 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 231 | urlretrieve(url, cached_file) 232 | return torch.load(cached_file, map_location=map_location) 233 | -------------------------------------------------------------------------------- /segmentTool/resnext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from segmentTool.lib.nn import SynchronizedBatchNorm2d 7 | 8 | try: 9 | from urllib import urlretrieve 10 | except ImportError: 11 | from urllib.request import urlretrieve 12 | 13 | 14 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101 15 | 16 | 17 | model_urls = { 18 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', 19 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | "3x3 convolution with padding" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | 28 | 29 | class GroupBottleneck(nn.Module): 30 | expansion = 2 31 | 32 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): 33 | super(GroupBottleneck, self).__init__() 34 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 35 | self.bn1 = SynchronizedBatchNorm2d(planes) 36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 37 | padding=1, groups=groups, bias=False) 38 | self.bn2 = SynchronizedBatchNorm2d(planes) 39 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) 40 | self.bn3 = SynchronizedBatchNorm2d(planes * 2) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class ResNeXt(nn.Module): 69 | 70 | def __init__(self, block, layers, groups=32, num_classes=1000): 71 | self.inplanes = 128 72 | super(ResNeXt, self).__init__() 73 | self.conv1 = conv3x3(3, 64, stride=2) 74 | self.bn1 = SynchronizedBatchNorm2d(64) 75 | self.relu1 = nn.ReLU(inplace=True) 76 | self.conv2 = conv3x3(64, 64) 77 | self.bn2 = SynchronizedBatchNorm2d(64) 78 | self.relu2 = nn.ReLU(inplace=True) 79 | self.conv3 = conv3x3(64, 128) 80 | self.bn3 = SynchronizedBatchNorm2d(128) 81 | self.relu3 = nn.ReLU(inplace=True) 82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 83 | 84 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) 85 | self.layer2 = 
self._make_layer(block, 256, layers[1], stride=2, groups=groups) 86 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) 87 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) 88 | self.avgpool = nn.AvgPool2d(7, stride=1) 89 | self.fc = nn.Linear(1024 * block.expansion, num_classes) 90 | 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups 94 | m.weight.data.normal_(0, math.sqrt(2. / n)) 95 | elif isinstance(m, SynchronizedBatchNorm2d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | 99 | def _make_layer(self, block, planes, blocks, stride=1, groups=1): 100 | downsample = None 101 | if stride != 1 or self.inplanes != planes * block.expansion: 102 | downsample = nn.Sequential( 103 | nn.Conv2d(self.inplanes, planes * block.expansion, 104 | kernel_size=1, stride=stride, bias=False), 105 | SynchronizedBatchNorm2d(planes * block.expansion), 106 | ) 107 | 108 | layers = [] 109 | layers.append(block(self.inplanes, planes, stride, groups, downsample)) 110 | self.inplanes = planes * block.expansion 111 | for i in range(1, blocks): 112 | layers.append(block(self.inplanes, planes, groups=groups)) 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | x = self.relu1(self.bn1(self.conv1(x))) 118 | x = self.relu2(self.bn2(self.conv2(x))) 119 | x = self.relu3(self.bn3(self.conv3(x))) 120 | x = self.maxpool(x) 121 | 122 | x = self.layer1(x) 123 | x = self.layer2(x) 124 | x = self.layer3(x) 125 | x = self.layer4(x) 126 | 127 | x = self.avgpool(x) 128 | x = x.view(x.size(0), -1) 129 | x = self.fc(x) 130 | 131 | return x 132 | 133 | 134 | ''' 135 | def resnext50(pretrained=False, **kwargs): 136 | """Constructs a ResNet-50 model. 137 | 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on Places 140 | """ 141 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) 142 | if pretrained: 143 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False) 144 | return model 145 | ''' 146 | 147 | 148 | def resnext101(pretrained=False, **kwargs): 149 | """Constructs a ResNet-101 model. 150 | 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on Places 153 | """ 154 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) 155 | if pretrained: 156 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False) 157 | return model 158 | 159 | 160 | # def resnext152(pretrained=False, **kwargs): 161 | # """Constructs a ResNeXt-152 model. 162 | # 163 | # Args: 164 | # pretrained (bool): If True, returns a model pre-trained on Places 165 | # """ 166 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) 167 | # if pretrained: 168 | # model.load_state_dict(load_url(model_urls['resnext152'])) 169 | # return model 170 | 171 | 172 | def load_url(url, model_dir='./pretrained', map_location=None): 173 | if not os.path.exists(model_dir): 174 | os.makedirs(model_dir) 175 | filename = url.split('/')[-1] 176 | cached_file = os.path.join(model_dir, filename) 177 | if not os.path.exists(cached_file): 178 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 179 | urlretrieve(url, cached_file) 180 | return torch.load(cached_file, map_location=map_location) 181 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : test.py 4 | # Author : Bowen Pan 5 | # Email : panbowen0607@gmail.com 6 | # Date : 09/25/2018 7 | # 8 | # Distributed under terms of the MIT license. 9 | 10 | """ 11 | 12 | """ 13 | from utils import Foo 14 | from models import VPNModel 15 | from datasets import OVMDataset 16 | from opts import parser 17 | from transform import * 18 | import torchvision 19 | import torch 20 | from torch import nn 21 | import os 22 | import time 23 | import cv2 24 | import dominate 25 | from dominate.tags import * 26 | 27 | mean_rgb = [0.485, 0.456, 0.406] 28 | std_rgb = [0.229, 0.224, 0.225] 29 | 30 | def main(): 31 | global args, web_path, best_prec1 32 | best_prec1 = 0 33 | args = parser.parse_args() 34 | network_config = Foo( 35 | encoder=args.encoder, 36 | decoder=args.decoder, 37 | fc_dim=args.fc_dim, 38 | num_views=args.n_views, 39 | num_class=94, 40 | transform_type=args.transform_type, 41 | output_size=args.label_resolution, 42 | ) 43 | 44 | val_dataset = OVMDataset(args.data_root, args.eval_list, 45 | transform=torchvision.transforms.Compose([ 46 | Stack(roll=True), 47 | ToTorchFormatTensor(div=True), 48 | GroupNormalize(mean_rgb, std_rgb) 49 | ]), 50 | num_views=network_config.num_views, input_size=args.input_resolution, 51 | label_size=args.segSize, use_mask=args.use_mask, use_depth=args.use_depth, is_train=False) 52 | 53 | val_loader = torch.utils.data.DataLoader( 54 | val_dataset, batch_size=args.batch_size, 55 | num_workers=args.num_workers, shuffle=False, 56 | pin_memory=True 57 | ) 58 | 59 | 60 | mapper = VPNModel(network_config) 61 | mapper = nn.DataParallel(mapper.cuda()) 62 | 63 | if args.weights: 64 | if os.path.isfile(args.weights): 65 | print(("=> loading checkpoint '{}'".format(args.weights))) 66 | checkpoint = torch.load(args.weights) 67 | args.start_epoch = checkpoint['epoch'] 68 | mapper.load_state_dict(checkpoint['state_dict']) 69 | print(("=> loaded checkpoint '{}' (epoch {})" 70 | .format(args.evaluate, checkpoint['epoch']))) 71 | else: 72 | print(("=> no checkpoint found at '{}'".format(args.weights))) 73 | 74 | 75 | criterion = nn.NLLLoss(weight=None, size_average=True) 76 | eval(val_loader, mapper, criterion) 77 | 78 | web_path = os.path.join(args.visualize, args.store_name) 79 | if os.path.isdir(web_path): 80 | pass 81 | else: 82 | os.makedirs(web_path) 83 | 84 | with dominate.document(title=web_path) as web: 85 | for step in range(len(val_loader)): 86 | if step % args.print_freq == 0: 87 | h2('Step {}'.format(step*args.batch_size)) 88 | with table(border = 1, style = 'table-layout: fixed;'): 89 | with tr(): 90 | for i in range(args.test_views): 91 | path = 'Step-{}-{}.png'.format(step * args.batch_size, i) 92 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 93 | img(style='width:128px', src=path) 94 | path = 'Step-{}-pred.png'.format(step * args.batch_size) 95 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 96 | img(style='width:128px', src=path) 97 | path = 'Step-{}-gt.png'.format(step * args.batch_size) 98 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 99 | img(style='width:128px', src=path) 100 | 101 | with open(os.path.join(web_path, 'index.html'), 'w') as fp: 102 | fp.write(web.render()) 103 | 104 | 105 | 106 | def eval(val_loader, mapper, criterion): 107 | batch_time = AverageMeter() 108 | data_time = AverageMeter() 109 | losses = AverageMeter() 110 | top1 = AverageMeter() 111 | top5 = AverageMeter() 112 | 113 | mapper.eval() 114 
| 115 | end = time.time() 116 | 117 | web_path = os.path.join(args.visualize, args.store_name) 118 | if os.path.isdir(web_path): 119 | pass 120 | else: 121 | os.makedirs(web_path) 122 | 123 | prec_stat = {} 124 | for i in range(args.num_class): 125 | prec_stat[str(i)] = {'intersec': 0, 'union': 0, 'all': 0} 126 | 127 | with open('./metadata/colormap_coarse.csv') as f: 128 | lines = f.readlines() 129 | cat = [] 130 | for line in lines: 131 | line = line.rstrip() 132 | cat.append(line) 133 | cat = cat[1:] 134 | label_dic = {} 135 | for i, value in enumerate(cat): 136 | key = str(i) 137 | label_dic[key] = [int(x) for x in value.split(',')[1:]] 138 | 139 | for step, (rgb_stack, target, rgb_origin, OverMaskOrigin) in enumerate(val_loader): 140 | data_time.update(time.time() - end) 141 | with torch.no_grad(): 142 | input_rgb_var = torch.autograd.variable(rgb_stack).cuda() 143 | _, output = mapper(x=input_rgb_var, return_feat=True) 144 | target_var = target.cuda() 145 | target_var = target_var.view(-1) 146 | upsample = output.view(-1, args.label_resolution, args.label_resolution, args.num_class).transpose(3,2).transpose(2,1).contiguous() 147 | upsample = nn.functional.upsample(upsample, size=args.segSize, mode='bilinear', align_corners=False) 148 | upsample = nn.functional.softmax(upsample, dim=1) 149 | output = torch.log(upsample.transpose(1,2).transpose(2,3).contiguous().view(-1, args.num_class)) 150 | _, pred = upsample.data.topk(1, 1, True, True) 151 | pred = pred.squeeze(1) 152 | loss = criterion(output, target_var) 153 | losses.update(loss.data[0], input_rgb_var.size(0)) 154 | prec_stat = count_mean_accuracy(output.data, target_var.data, prec_stat) 155 | prec1 = accuracy(output.data, target_var.data, topk=(1,))[0] 156 | 157 | batch_time.update(time.time() - end) 158 | end = time.time() 159 | 160 | if step % args.print_freq == 0: 161 | output = ('Test: [{0}/{1}]\t' 162 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 163 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 164 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 165 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 166 | step + 1, len(val_loader), batch_time=batch_time, 167 | data_time=data_time, loss=losses, top1=top1)) 168 | print(output) 169 | 170 | pred = np.uint8(pred.cpu()[0]) 171 | predMask = np.uint8(np.zeros((args.segSize, args.segSize, 3))) 172 | for i, _ in enumerate(pred): 173 | for j, _ in enumerate(pred[0]): 174 | key = str(pred[i][j]) 175 | predMask[i,j] = label_dic[key] 176 | predMask = cv2.resize(predMask[:, :, ::-1], (256, 256), interpolation=cv2.INTER_NEAREST) 177 | cv2.imwrite(os.path.join(web_path, 'Step-{}-pred.png'.format(step * args.batch_size, i)), predMask) 178 | 179 | gtMask = OverMaskOrigin[0].cpu().numpy() 180 | gtMask = cv2.resize(gtMask, (256, 256), interpolation=cv2.INTER_NEAREST) 181 | cv2.imwrite(os.path.join(web_path, 'Step-{}-gt.png'.format(step * args.batch_size, i)), gtMask) 182 | 183 | rgb = rgb_origin.cpu().numpy()[0] 184 | for i in range(args.test_views): 185 | cv2.imwrite(os.path.join(web_path, 'Step-{}-{}.png'.format(step * args.batch_size, i)), cv2.resize(rgb[(i + args.view_bias) % 8], (256, 256), interpolation=cv2.INTER_NEAREST)) 186 | 187 | sum_acc = 0 188 | counted_cat = 0 189 | sum_iou = 0 190 | for key in prec_stat: 191 | if int(prec_stat[key]['all']) != 0: 192 | acc = prec_stat[key]['intersec'] / (prec_stat[key]['all'] + 1e-10) 193 | iou = prec_stat[key]['intersec'] / (prec_stat[key]['union'] + 1e-10) 194 | sum_acc += acc 195 | sum_iou += iou 196 | counted_cat += 1 197 | 
mean_acc = sum_acc / counted_cat 198 | mean_iou = sum_iou / counted_cat 199 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Mean Prec@1 {meantop:.3f} Mean IoU {meaniou:.3f} Loss {loss.avg:.5f}' 200 | .format(top1=top1, loss=losses, meantop=mean_acc, meaniou=mean_iou)) 201 | print(output) 202 | output_best = '\nBest Prec@1 of: %.3f' % (best_prec1) 203 | print(output_best) 204 | 205 | return top1.avg 206 | 207 | class AverageMeter(object): 208 | """Computes and stores the average and current value""" 209 | def __init__(self): 210 | self.reset() 211 | 212 | def reset(self): 213 | self.val = 0 214 | self.avg = 0 215 | self.sum = 0 216 | self.count = 0 217 | 218 | def update(self, val, n=1): 219 | self.val = val 220 | self.sum += val * n 221 | self.count += n 222 | self.avg = self.sum / self.count 223 | 224 | def count_mean_accuracy(output, target, prec_stat): 225 | _, pred = output.topk(1, 1, True, True) 226 | pred = pred.squeeze(1) 227 | for key in prec_stat.keys(): 228 | label = int(key) 229 | pred_map = np.uint8(pred.cpu().numpy() == label) 230 | target_map = np.uint8(target.cpu().numpy() == label) 231 | intersection_t = pred_map * (pred_map == target_map) 232 | union_t = pred_map + target_map - intersection_t 233 | prec_stat[key]['intersec'] += np.sum(intersection_t) 234 | prec_stat[key]['union'] += np.sum(union_t) 235 | prec_stat[key]['all'] += np.sum(target_map) 236 | return prec_stat 237 | 238 | def accuracy(output, target, topk=(1,)): 239 | """Computes the precision@k for the specified values of k""" 240 | maxk = max(topk) 241 | batch_size = target.size(0) 242 | 243 | _, pred = output.topk(maxk, 1, True, True) 244 | pred = pred.t() 245 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 246 | 247 | res = [] 248 | for k in topk: 249 | correct_k = correct[:k].view(-1).float().sum(0) 250 | res.append(correct_k.mul_(100.0 / batch_size)) 251 | return res 252 | 253 | if __name__=='__main__': 254 | main() 255 | -------------------------------------------------------------------------------- /test_carla.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : test.py 4 | # Author : Bowen Pan 5 | # Email : panbowen0607@gmail.com 6 | # Date : 09/25/2018 7 | # 8 | # Distributed under terms of the MIT license. 
9 | 10 | """ 11 | 12 | """ 13 | from utils import Foo 14 | from models import VPNModel 15 | from datasets import OVMDataset 16 | from opts import parser 17 | from transform import * 18 | import torchvision 19 | import torch 20 | from torch import nn 21 | from torch.optim.lr_scheduler import MultiStepLR 22 | from torch import optim 23 | import os 24 | import time 25 | from torch.nn.utils import clip_grad_norm 26 | # from examples.cognitive_mapping.Logger import Logger 27 | import cv2 28 | import shutil 29 | import dominate 30 | from dominate.tags import * 31 | 32 | mean_rgb = [0.485, 0.456, 0.406] 33 | std_rgb = [0.229, 0.224, 0.225] 34 | 35 | def main(): 36 | global args, web_path, best_prec1 37 | parser.add_argument('--test-views', type=int, default=94) 38 | parser.add_argument('--view-bias', type=int, default=8) 39 | 40 | best_prec1 = 0 41 | args = parser.parse_args() 42 | network_config = Foo( 43 | encoder=args.encoder, 44 | decoder=args.decoder, 45 | fc_dim=args.fc_dim, 46 | num_views=args.n_views, 47 | num_class=args.num_class, 48 | transform_type=args.transform_type, 49 | output_size=args.label_resolution, 50 | ) 51 | 52 | val_dataset = OVMDataset(args.data_root, args.eval_list, 53 | transform=torchvision.transforms.Compose([ 54 | Stack(roll=True), 55 | ToTorchFormatTensor(div=True), 56 | GroupNormalize(mean_rgb, std_rgb) 57 | ]), 58 | num_views=network_config.num_views, input_size=args.input_resolution, 59 | label_size=args.segSize, use_mask=args.use_mask, use_depth=args.use_depth, is_train=False) 60 | 61 | val_loader = torch.utils.data.DataLoader( 62 | val_dataset, batch_size=args.batch_size, 63 | num_workers=args.num_workers, shuffle=False, 64 | pin_memory=True 65 | ) 66 | 67 | 68 | mapper = VPNModel(network_config) 69 | mapper = nn.DataParallel(mapper.cuda()) 70 | 71 | if args.weights: 72 | if os.path.isfile(args.weights): 73 | print(("=> loading checkpoint '{}'".format(args.weights))) 74 | checkpoint = torch.load(args.weights) 75 | args.start_epoch = checkpoint['epoch'] 76 | mapper.load_state_dict(checkpoint['state_dict']) 77 | print(("=> loaded checkpoint '{}' (epoch {})" 78 | .format(args.evaluate, checkpoint['epoch']))) 79 | else: 80 | print(("=> no checkpoint found at '{}'".format(args.weights))) 81 | 82 | 83 | criterion = nn.NLLLoss(weight=None, size_average=True) 84 | eval(val_loader, mapper, criterion) 85 | 86 | web_path = os.path.join(args.visualize, args.store_name) 87 | if os.path.isdir(web_path): 88 | pass 89 | else: 90 | os.makedirs(web_path) 91 | 92 | with dominate.document(title=web_path) as web: 93 | for step in range(len(val_loader)): 94 | if step % args.print_freq == 0: 95 | h2('Step {}'.format(step*args.batch_size)) 96 | with table(border = 1, style = 'table-layout: fixed;'): 97 | with tr(): 98 | for i in range(args.test_views): 99 | path = 'Step-{}-{}.png'.format(step * args.batch_size, i) 100 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 101 | img(style='width:128px', src=path) 102 | path = 'Step-{}-pred.png'.format(step * args.batch_size) 103 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 104 | img(style='width:128px', src=path) 105 | path = 'Step-{}-gt.png'.format(step * args.batch_size) 106 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 107 | img(style='width:128px', src=path) 108 | 109 | with open(os.path.join(web_path, 'index.html'), 'w') as fp: 110 | fp.write(web.render()) 111 | 112 | 113 | 114 | def eval(val_loader, mapper, criterion): 115 | batch_time = AverageMeter() 
116 | data_time = AverageMeter() 117 | losses = AverageMeter() 118 | top1 = AverageMeter() 119 | top5 = AverageMeter() 120 | best_prec1 = 0 121 | 122 | mapper.eval() 123 | 124 | end = time.time() 125 | 126 | web_path = os.path.join(args.visualize, args.store_name) 127 | if os.path.isdir(web_path): 128 | pass 129 | else: 130 | os.makedirs(web_path) 131 | 132 | prec_stat = {} 133 | for i in range(args.num_class): 134 | prec_stat[str(i)] = {'intersec': 0, 'union': 0, 'all': 0} 135 | 136 | with open('./metadata/colormap_coarse.csv') as f: 137 | lines = f.readlines() 138 | cat = [] 139 | for line in lines: 140 | line = line.rstrip() 141 | cat.append(line) 142 | cat = cat[1:] 143 | label_dic = {} 144 | for i, value in enumerate(cat): 145 | key = str(i) 146 | label_dic[key] = [int(x) for x in value.split(',')[1:]] 147 | 148 | for step, (rgb_stack, target, rgb_origin, OverMaskOrigin) in enumerate(val_loader): 149 | data_time.update(time.time() - end) 150 | with torch.no_grad(): 151 | input_rgb_var = torch.autograd.variable(rgb_stack).cuda() 152 | _, output = mapper(x=input_rgb_var, return_feat=True) 153 | target_var = target.cuda() 154 | target_var = target_var.view(-1) 155 | upsample = output.view(-1, args.label_resolution, args.label_resolution, args.num_class).transpose(3,2).transpose(2,1).contiguous() 156 | upsample = nn.functional.upsample(upsample, size=args.segSize, mode='bilinear', align_corners=False) 157 | upsample = nn.functional.softmax(upsample, dim=1) 158 | output = torch.log(upsample.transpose(1,2).transpose(2,3).contiguous().view(-1, args.num_class)) 159 | _, pred = upsample.data.topk(1, 1, True, True) 160 | pred = pred.squeeze(1) 161 | loss = criterion(output, target_var) 162 | losses.update(loss.item(), input_rgb_var.size(0)) 163 | prec_stat = count_mean_accuracy(output.data, target_var.data, prec_stat) 164 | prec1 = accuracy(output.data, target_var.data, topk=(1,))[0] 165 | top1.update(prec1.item(), rgb_stack.size(0)) 166 | best_prec1 = max(prec1, best_prec1) 167 | 168 | batch_time.update(time.time() - end) 169 | end = time.time() 170 | 171 | if step % args.print_freq == 0: 172 | output = ('Test: [{0}/{1}]\t' 173 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 174 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 175 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 176 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 177 | step + 1, len(val_loader), batch_time=batch_time, 178 | data_time=data_time, loss=losses, top1=top1)) 179 | print(output) 180 | 181 | pred = np.uint8(pred.cpu()[0]) 182 | predMask = np.uint8(np.zeros((args.segSize, args.segSize, 3))) 183 | for i, _ in enumerate(pred): 184 | for j, _ in enumerate(pred[0]): 185 | key = str(pred[i][j]) 186 | predMask[i,j] = label_dic[key] 187 | predMask = cv2.resize(predMask[:, :, ::-1], (256, 256), interpolation=cv2.INTER_NEAREST) 188 | cv2.imwrite(os.path.join(web_path, 'Step-{}-pred.png'.format(step * args.batch_size, i)), predMask) 189 | 190 | gtMask = OverMaskOrigin[0].cpu().numpy() 191 | print('gtMask.shape: ', gtMask.shape) 192 | gt_rgb = np.uint8(np.zeros((gtMask.shape[0], gtMask.shape[0], 3))) 193 | for i, _ in enumerate(gtMask): 194 | for j, _ in enumerate(gtMask[0]): 195 | key = str(gtMask[i][j]) 196 | gt_rgb[i,j] = label_dic[key] 197 | 198 | gt_rgb = cv2.resize(gt_rgb[:, :, ::-1], (256, 256), interpolation=cv2.INTER_NEAREST) 199 | cv2.imwrite(os.path.join(web_path, 'Step-{}-gt.png'.format(step * args.batch_size, i)), gt_rgb) 200 | 201 | rgb = rgb_origin.cpu().numpy()[0] 202 | for i in range(args.test_views): 203 
| cv2.imwrite(os.path.join(web_path, 'Step-{}-{}.png'.format(step * args.batch_size, i)), cv2.resize(rgb[(i + args.view_bias) % 8], (256, 256), interpolation=cv2.INTER_NEAREST)) 204 | 205 | sum_acc = 0 206 | counted_cat = 0 207 | sum_iou = 0 208 | for key in prec_stat: 209 | if int(prec_stat[key]['all']) != 0: 210 | acc = prec_stat[key]['intersec'] / (prec_stat[key]['all'] + 1e-10) 211 | iou = prec_stat[key]['intersec'] / (prec_stat[key]['union'] + 1e-10) 212 | sum_acc += acc 213 | sum_iou += iou 214 | counted_cat += 1 215 | mean_acc = sum_acc / counted_cat 216 | mean_iou = sum_iou / counted_cat 217 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Mean Prec@1 {meantop:.3f} Mean IoU {meaniou:.3f} Loss {loss.avg:.5f}' 218 | .format(top1=top1, loss=losses, meantop=mean_acc, meaniou=mean_iou)) 219 | print(output) 220 | output_best = '\nBest Prec@1 of: %.3f' % (best_prec1) 221 | print(output_best) 222 | 223 | return top1.avg 224 | 225 | class AverageMeter(object): 226 | """Computes and stores the average and current value""" 227 | def __init__(self): 228 | self.reset() 229 | 230 | def reset(self): 231 | self.val = 0 232 | self.avg = 0 233 | self.sum = 0 234 | self.count = 0 235 | 236 | def update(self, val, n=1): 237 | self.val = val 238 | self.sum += val * n 239 | self.count += n 240 | self.avg = self.sum / self.count 241 | 242 | def count_mean_accuracy(output, target, prec_stat): 243 | _, pred = output.topk(1, 1, True, True) 244 | pred = pred.squeeze(1) 245 | for key in prec_stat.keys(): 246 | label = int(key) 247 | pred_map = np.uint8(pred.cpu().numpy() == label) 248 | target_map = np.uint8(target.cpu().numpy() == label) 249 | intersection_t = pred_map * (pred_map == target_map) 250 | union_t = pred_map + target_map - intersection_t 251 | prec_stat[key]['intersec'] += np.sum(intersection_t) 252 | prec_stat[key]['union'] += np.sum(union_t) 253 | prec_stat[key]['all'] += np.sum(target_map) 254 | return prec_stat 255 | 256 | def accuracy(output, target, topk=(1,)): 257 | """Computes the precision@k for the specified values of k""" 258 | maxk = max(topk) 259 | batch_size = target.size(0) 260 | 261 | _, pred = output.topk(maxk, 1, True, True) 262 | pred = pred.t() 263 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 264 | 265 | res = [] 266 | for k in topk: 267 | correct_k = correct[:k].view(-1).float().sum(0) 268 | res.append(correct_k.mul_(100.0 / batch_size)) 269 | return res 270 | 271 | if __name__=='__main__': 272 | main() 273 | -------------------------------------------------------------------------------- /test_seq.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : test.py 4 | # Author : Bowen Pan 5 | # Email : panbowen0607@gmail.com 6 | # Date : 09/25/2018 7 | # 8 | # Distributed under terms of the MIT license. 
9 | 10 | """ 11 | 12 | """ 13 | from utils import Foo 14 | from models import VPNModel 15 | from datasets import Seq_OVMDataset 16 | from opts import parser 17 | from transform import * 18 | import torchvision 19 | import torch 20 | from torch import nn 21 | from torch.optim.lr_scheduler import MultiStepLR 22 | from torch import optim 23 | import os 24 | import time 25 | from torch.nn.utils import clip_grad_norm 26 | import cv2 27 | import shutil 28 | import dominate 29 | from dominate.tags import * 30 | import moviepy.editor as mpy 31 | 32 | 33 | mean_rgb = [0.485, 0.456, 0.406] 34 | std_rgb = [0.229, 0.224, 0.225] 35 | 36 | def main(): 37 | global args, web_path, best_prec1 38 | best_prec1 = 0 39 | args = parser.parse_args() 40 | network_config = Foo( 41 | encoder=args.encoder, 42 | decoder=args.decoder, 43 | fc_dim=args.fc_dim, 44 | num_views=args.n_views, 45 | num_class=94, 46 | transform_type=args.transform_type, 47 | output_size=args.label_resolution, 48 | ) 49 | 50 | val_dataset = Seq_OVMDataset(args.test_dir, pix_file=args.pix_file, 51 | transform=torchvision.transforms.Compose([ 52 | Stack(roll=True), 53 | ToTorchFormatTensor(div=True), 54 | GroupNormalize(mean_rgb, std_rgb) 55 | ]), 56 | n_views=network_config.num_views, resolution=args.input_resolution, 57 | label_res=args.label_resolution, use_mask=args.use_mask, is_train=False) 58 | 59 | val_loader = torch.utils.data.DataLoader( 60 | val_dataset, batch_size=1, 61 | shuffle=False, pin_memory=True 62 | ) 63 | 64 | 65 | mapper = VPNModel(network_config) 66 | mapper = nn.DataParallel(mapper.cuda()) 67 | 68 | if args.weights: 69 | if os.path.isfile(args.weights): 70 | print(("=> loading checkpoint '{}'".format(args.weights))) 71 | checkpoint = torch.load(args.weights) 72 | args.start_epoch = checkpoint['epoch'] 73 | mapper.load_state_dict(checkpoint['state_dict']) 74 | print(("=> loaded checkpoint '{}' (epoch {})" 75 | .format(args.evaluate, checkpoint['epoch']))) 76 | else: 77 | print(("=> no checkpoint found at '{}'".format(args.weights))) 78 | 79 | web_path = os.path.join(args.visualize, args.store_name) 80 | criterion = nn.NLLLoss(weight=None, size_average=True) 81 | eval(val_loader, mapper, criterion, web_path) 82 | 83 | web_path = os.path.join(args.visualize, args.store_name) 84 | 85 | 86 | def eval(val_loader, mapper, criterion, web_path): 87 | batch_time = AverageMeter() 88 | data_time = AverageMeter() 89 | losses = AverageMeter() 90 | top1 = AverageMeter() 91 | top5 = AverageMeter() 92 | 93 | mapper.eval() 94 | 95 | end = time.time() 96 | if os.path.isdir(web_path): 97 | pass 98 | else: 99 | os.makedirs(web_path) 100 | 101 | frames = [] 102 | 103 | prec_stat = {} 104 | for i in range(args.num_class): 105 | prec_stat[str(i)] = {'intersec': 0, 'all': 0} 106 | 107 | with open('./metadata/colormap_coarse.csv') as f: 108 | lines = f.readlines() 109 | cat = [] 110 | for line in lines: 111 | line = line.rstrip() 112 | cat.append(line) 113 | cat = cat[1:] 114 | label_dic = {} 115 | for i, value in enumerate(cat): 116 | key = str(i) 117 | label_dic[key] = [int(x) for x in value.split(',')[1:]] 118 | 119 | reachable = [10, 40, 43, 45, 64, 80] 120 | 121 | for step, (rgb_stack, target, rgb_origin, topmap, OverMaskOrigin) in enumerate(val_loader): 122 | data_time.update(time.time() - end) 123 | with torch.no_grad(): 124 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda() 125 | _, output = mapper(x=input_rgb_var, test_comb=[x * int(args.n_views / args.test_views) for x in list(range(args.test_views))], return_feat=True) 126 | 
target_var = target.cuda() 127 | target_var = target_var.view(-1) 128 | output = output.view(-1, args.num_class) 129 | upsample = output.view(-1, args.label_resolution, args.label_resolution, args.num_class).transpose(3,2).transpose(2,1).contiguous() 130 | upsample = nn.functional.upsample(upsample, size=args.segSize, mode='bilinear', align_corners=False) 131 | upsample = nn.functional.softmax(upsample, dim=1) 132 | freemap = upsample.data.index_select(dim=1, index=torch.Tensor(reachable).long().cuda()) 133 | freemap = freemap.sum(dim=1, keepdim=False) 134 | output = nn.functional.log_softmax(output, dim=1) 135 | _, pred = upsample.data.topk(1, 1, True, True) 136 | pred = pred.squeeze(1) 137 | loss = criterion(output, target_var) 138 | losses.update(loss.data[0], input_rgb_var.size(0)) 139 | prec_stat = count_mean_accuracy(output.data, target_var.data, prec_stat) 140 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5)) 141 | top1.update(prec1[0], rgb_stack.size(0)) 142 | top5.update(prec5[0], rgb_stack.size(0)) 143 | 144 | batch_time.update(time.time() - end) 145 | end = time.time() 146 | 147 | if step % 1 == 0: 148 | output = ('Test: [{0}/{1}]\t' 149 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 150 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 151 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 152 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 153 | step + 1, len(val_loader), batch_time=batch_time, 154 | data_time=data_time, loss=losses, top1=top1)) 155 | print(output) 156 | 157 | syn_map = np.zeros((256*3, 256*4 + 40, 3)) 158 | syn_map += 255 159 | 160 | pred = np.uint8(pred.cpu()[0]) 161 | predMask = np.uint8(np.zeros((args.segSize, args.segSize, 3))) 162 | for i, _ in enumerate(pred): 163 | for j, _ in enumerate(pred[0]): 164 | key = str(pred[i][j]) 165 | predMask[i,j] = label_dic[key] 166 | predMask = cv2.resize(predMask[:, :, ::-1], (256, 256), interpolation=cv2.INTER_NEAREST) 167 | 168 | gtMask = OverMaskOrigin[0].cpu().numpy() 169 | gtMask = cv2.resize(gtMask, (256, 256), interpolation=cv2.INTER_NEAREST) 170 | syn_map[:256, 256*3 + 40:] = gtMask 171 | syn_map[256:256*2, 256*3 + 40:] = predMask 172 | 173 | rgb = rgb_origin.cpu().numpy()[0] 174 | topmap = topmap.cpu().numpy() 175 | 176 | freemap = np.uint8(freemap[0]*255) 177 | freemap = cv2.applyColorMap(freemap, cv2.COLORMAP_JET) 178 | freemap = cv2.resize(freemap, (256, 256), interpolation=cv2.INTER_NEAREST) 179 | syn_map[256*2:256*3, 256*3 + 40:] = freemap 180 | 181 | orient_rank = [-2, 6, -2, 0, -1, 4, -2, 2, -2] 182 | for i, orient in enumerate(orient_rank): 183 | if orient >= 0: 184 | syn_map[(i%3)*256:(i%3+1)*256, int(i/3)*256:int(i/3+1)*256] = cv2.resize(rgb[orient], (256, 256)) 185 | elif orient == -1: 186 | syn_map[(i%3)*256:(i%3+1)*256, int(i/3)*256:int(i/3+1)*256] = topmap[:, :, ::-1] 187 | elif orient == -2: 188 | syn_map[(i%3)*256:(i%3+1)*256, int(i/3)*256:int(i/3+1)*256] = 240 189 | for i in range(2): 190 | syn_map[256 * (i + 1) - 3: 256 * (i + 1) + 3] = 255 191 | for i in range(3): 192 | syn_map[:, 256 * (i + 1) -3: 256 * (i + 1) + 3] = 255 193 | cv2.imwrite(os.path.join(web_path, 'syn_step%d.jpg'%(step+1)), syn_map) 194 | frames.append(syn_map[:, :, ::-1]) 195 | clip = mpy.ImageSequenceClip(frames, fps=8) 196 | clip.write_videofile(os.path.join(web_path, 'OverviewSemVideo.mp4'), fps=8) 197 | 198 | 199 | sum_acc = 0 200 | counted_cat = 0 201 | for key in prec_stat: 202 | if int(prec_stat[key]['all']) != 0: 203 | acc = prec_stat[key]['intersec'] / (prec_stat[key]['all'] + 1e-10) 204 | 
sum_acc += acc 205 | counted_cat += 1 206 | mean_acc = sum_acc / counted_cat 207 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Mean Prec@1 {meantop:.3f} Loss {loss.avg:.5f}' 208 | .format(top1=top1, loss=losses, meantop=mean_acc)) 209 | print(output) 210 | output_best = '\nBest Prec@1 of: %.3f' % (best_prec1) 211 | print(output_best) 212 | 213 | return top1.avg 214 | 215 | class AverageMeter(object): 216 | """Computes and stores the average and current value""" 217 | def __init__(self): 218 | self.reset() 219 | 220 | def reset(self): 221 | self.val = 0 222 | self.avg = 0 223 | self.sum = 0 224 | self.count = 0 225 | 226 | def update(self, val, n=1): 227 | self.val = val 228 | self.sum += val * n 229 | self.count += n 230 | self.avg = self.sum / self.count 231 | 232 | def count_mean_accuracy(output, target, prec_stat): 233 | _, pred = output.topk(1, 1, True, True) 234 | pred = pred.squeeze(1) 235 | for key in prec_stat.keys(): 236 | label = int(key) 237 | pred_map = np.uint8(pred.cpu().numpy() == label) 238 | target_map = np.uint8(target.cpu().numpy() == label) 239 | intersection_t = pred_map * (pred_map == target_map) 240 | prec_stat[key]['intersec'] += np.sum(intersection_t) 241 | prec_stat[key]['all'] += np.sum(target_map) 242 | return prec_stat 243 | 244 | def accuracy(output, target, topk=(1,)): 245 | """Computes the precision@k for the specified values of k""" 246 | maxk = max(topk) 247 | batch_size = target.size(0) 248 | 249 | _, pred = output.topk(maxk, 1, True, True) 250 | pred = pred.t() 251 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 252 | 253 | res = [] 254 | for k in topk: 255 | correct_k = correct[:k].view(-1).float().sum(0) 256 | res.append(correct_k.mul_(100.0 / batch_size)) 257 | return res 258 | 259 | if __name__=='__main__': 260 | main() 261 | -------------------------------------------------------------------------------- /tools/get_trainning_data_from_house3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import csv 4 | import cv2 5 | import pickle 6 | import os 7 | 8 | import gym 9 | from gym import spaces 10 | from House3D.house import House 11 | from House3D.core import Environment, MultiHouseEnv 12 | from House3D import objrender 13 | from House3D.objrender import RenderMode 14 | 15 | __all__ = ['CollectDataTask'] 16 | 17 | ############################################### 18 | # Task related definitions and configurations 19 | ############################################### 20 | flag_print_debug_info = False # flag for printing debug info 21 | 22 | dist_reward_scale = 1.0 # 2.0 23 | collision_penalty_reward = 0.3 # penalty for collision 24 | correct_move_reward = None # reward when moving closer to target 25 | stay_room_reward = 0.1 # reward when staying in the correct target room 26 | indicator_reward = 0.5 27 | success_reward = 10 # reward when success 28 | 29 | time_penalty_reward = 0.1 # penalty for each time step 30 | delta_reward_coef = 0.5 31 | speed_reward_coef = 1.0 32 | 33 | success_distance_range = 1.0 34 | success_stay_time_steps = 5 35 | success_see_target_time_steps = 2 # time steps required for success under the "see" criteria 36 | 37 | # sensitivity setting 38 | rotation_sensitivity = 30 # 45 # maximum rotation per time step 39 | default_move_sensitivity = 0.5 # 1.0 # maximum movement per time step 40 | 41 | # discrete action space actions, totally <13> actions 42 | # Fwd, L, R, LF, RF, Lrot, Rrot, Bck, s-Fwd, s-L, s-R, s-Lrot, s-Rrot 43 | 
discrete_actions = [(1., 0., 0.), (0., 1., 0.), (0., -1., 0.), (0.5, 0.5, 0.), (0.5, -0.5, 0.), 44 | (0., 0., 1.), (0., 0., -1.), 45 | (-0.4, 0., 0.), 46 | (0.4, 0., 0.), (0., 0.4, 0.), (0., -0.4, 0.), 47 | (0., 0., 0.4), (0., 0., -0.4)] 48 | n_discrete_actions = len(discrete_actions) 49 | 50 | # criteria for seeing the object 51 | n_pixel_for_object_see = 450 # need at least see 450 pixels for success under default resolution 120 x 90 52 | n_pixel_for_object_sense = 50 53 | L_pixel_reward_range = n_pixel_for_object_see - n_pixel_for_object_sense 54 | pixel_object_reward = 0.4 55 | 56 | 57 | ################# 58 | # Util Functions 59 | ################# 60 | 61 | def reset_see_criteria(resolution): 62 | total_pixel = resolution[0] * resolution[1] 63 | global n_pixel_for_object_see, n_pixel_for_object_sense, L_pixel_reward_range 64 | n_pixel_for_object_see = max(int(total_pixel * 0.045), 5) 65 | n_pixel_for_object_sense = max(int(total_pixel * 0.005), 1) 66 | L_pixel_reward_range = n_pixel_for_object_see - n_pixel_for_object_sense 67 | 68 | 69 | class CollectDataTask(gym.Env): 70 | def __init__(self, env, 71 | seed=None, 72 | reward_type='delta', 73 | hardness=None, 74 | move_sensitivity=None, 75 | segment_input=True, 76 | joint_visual_signal=False, 77 | depth_signal=True, 78 | max_steps=-1, 79 | success_measure='see', 80 | discrete_action=False): 81 | self.env = env 82 | assert isinstance(env, Environment), '[RoomNavTask] env must be an instance of Environment!' 83 | if env.resolution != (120, 90): reset_see_criteria(env.resolution) 84 | self.resolution = resolution = env.resolution 85 | assert reward_type in [None, 'none', 'linear', 'indicator', 'delta', 'speed'] 86 | self.reward_type = reward_type 87 | self.colorDataFile = self.env.config['colorFile'] 88 | self.segment_input = segment_input 89 | self.joint_visual_signal = joint_visual_signal 90 | self.depth_signal = depth_signal 91 | n_channel = 3 92 | if segment_input: 93 | self.env.set_render_mode('semantic') 94 | else: 95 | self.env.set_render_mode('rgb') 96 | if joint_visual_signal: n_channel += 3 97 | if depth_signal: n_channel += 1 98 | self._observation_shape = (resolution[0], resolution[1], n_channel) 99 | self._observation_space = spaces.Box(0, 255, shape=self._observation_shape) 100 | 101 | self.max_steps = max_steps 102 | 103 | self.discrete_action = discrete_action 104 | if discrete_action: 105 | self._action_space = spaces.Discrete(n_discrete_actions) 106 | else: 107 | self._action_space = spaces.Tuple([spaces.Box(0, 1, shape=(4,)), spaces.Box(0, 1, shape=(2,))]) 108 | 109 | if seed is not None: 110 | np.random.seed(seed) 111 | random.seed(seed) 112 | 113 | # configs 114 | self.move_sensitivity = (move_sensitivity or default_move_sensitivity) # at most * meters per frame 115 | self.rot_sensitivity = rotation_sensitivity 116 | self.dist_scale = dist_reward_scale or 1 117 | self.successRew = success_reward 118 | self.inroomRew = stay_room_reward or 0.2 119 | self.colideRew = collision_penalty_reward or 0.02 120 | self.goodMoveRew = correct_move_reward or 0.0 121 | 122 | self.last_obs = None 123 | self.last_info = None 124 | self._object_cnt = 0 125 | 126 | # config hardness 127 | self.hardness = None 128 | self.availCoors = None 129 | self._availCoorsDict = None 130 | self.reset_hardness(hardness) 131 | 132 | # temp storage 133 | self.collision_flag = False 134 | 135 | # episode counter 136 | self.current_episode_step = 0 137 | 138 | # config success measure 139 | assert success_measure in ['stay', 'see'] 140 | 
self.success_measure = success_measure 141 | print('[RoomNavTask] >> Success Measure = <{}>'.format(success_measure)) 142 | self.success_stay_cnt = 0 143 | if success_measure == 'see': 144 | self.room_target_object = dict() 145 | self._load_target_object_data(self.env.config['roomTargetFile']) 146 | 147 | def _load_target_object_data(self, roomTargetFile): 148 | with open(roomTargetFile) as csvFile: 149 | reader = csv.DictReader(csvFile) 150 | for row in reader: 151 | c = np.array((row['r'], row['g'], row['b']), dtype=np.uint8) 152 | room = row['target_room'] 153 | if room not in self.room_target_object: 154 | self.room_target_object[room] = [] 155 | self.room_target_object[room].append(c) 156 | 157 | """ 158 | reset the target room type to navigate to 159 | when target is None, a valid target will be randomly selected 160 | """ 161 | 162 | def reset_target(self, target=None): 163 | if target is None: 164 | target = random.choice(self.house.all_desired_roomTypes) 165 | else: 166 | assert target in self.house.all_desired_roomTypes, '[RoomNavTask] desired target <{}> does not exist in the current house!'.format( 167 | target) 168 | if self.house.setTargetRoom(target): # target room changed!!! 169 | _id = self.house._id 170 | if self.house.targetRoomTp not in self._availCoorsDict[_id]: 171 | if self.hardness is None: 172 | self.availCoors = self.house.connectedCoors 173 | else: 174 | allowed_dist = self.house.maxConnDist * self.hardness 175 | self.availCoors = [c for c in self.house.connectedCoors 176 | if self.house.connMap[c[0], c[1]] <= allowed_dist] 177 | self._availCoorsDict[_id][self.house.targetRoomTp] = self.availCoors 178 | else: 179 | self.availCoors = self._availCoorsDict[_id][self.house.targetRoomTp] 180 | 181 | @property 182 | def house(self): 183 | return self.env.house 184 | 185 | """ 186 | gym api: reset function 187 | when target is not None, we will set the target room type to navigate to 188 | """ 189 | 190 | def reset(self, target=None): 191 | # clear episode steps 192 | self.current_episode_step = 0 193 | self.success_stay_cnt = 0 194 | self._object_cnt = 0 195 | 196 | # reset house 197 | self.env.reset_house() 198 | 199 | self.house.targetRoomTp = None # [NOTE] IMPORTANT! clear this!!!!! 
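# What follows in reset(): re-sample a target room, draw a reachable birth grid
# cell from self.availCoors, move the agent there, and return the stacked
# observation (an extra RGB block is prepended when joint_visual_signal is set,
# and a depth channel is appended when depth_signal is set).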
200 | 201 | # reset target room 202 | self.reset_target(target=target) # randomly reset 203 | 204 | # general birth place 205 | gx, gy = random.choice(self.availCoors) 206 | self.collision_flag = False 207 | # generate state 208 | x, y = self.house.to_coor(gx, gy, True) 209 | self.env.reset(x=x, y=y) 210 | self.last_obs = self.env.render() 211 | if self.joint_visual_signal: 212 | self.last_obs = np.concatenate([self.env.render(mode='rgb'), self.last_obs], axis=-1) 213 | ret_obs = self.last_obs 214 | if self.depth_signal: 215 | dep_sig = self.env.render(mode='depth') 216 | if dep_sig.shape[-1] > 1: 217 | dep_sig = dep_sig[..., 0:1] 218 | ret_obs = np.concatenate([ret_obs, dep_sig], axis=-1) 219 | self.last_info = self.info 220 | return ret_obs 221 | 222 | def _apply_action(self, action): 223 | if self.discrete_action: 224 | return discrete_actions[action] 225 | else: 226 | rot = action[1][0] - action[1][1] 227 | act = action[0] 228 | return (act[0] - act[1]), (act[2] - act[3]), rot 229 | 230 | def _is_success(self, raw_dist): 231 | if raw_dist > 0: 232 | self.success_stay_cnt = 0 233 | return False 234 | if self.success_measure == 'stay': 235 | self.success_stay_cnt += 1 236 | return self.success_stay_cnt >= success_stay_time_steps 237 | # self.success_measure == 'see' 238 | flag_see_target_objects = False 239 | object_color_list = self.room_target_object[self.house.targetRoomTp] 240 | if (self.last_obs is not None) and self.segment_input: 241 | seg_obs = self.last_obs if not self.joint_visual_signal else self.last_obs[:, :, 3:6] 242 | else: 243 | seg_obs = self.env.render(mode='semantic') 244 | self._object_cnt = 0 245 | for c in object_color_list: 246 | cur_n = np.sum(np.all(seg_obs == c, axis=2)) 247 | self._object_cnt += cur_n 248 | if self._object_cnt >= n_pixel_for_object_see: 249 | flag_see_target_objects = True 250 | break 251 | if flag_see_target_objects: 252 | self.success_stay_cnt += 1 253 | else: 254 | self.success_stay_cnt = 0 # did not see any target objects! 
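# Under the 'see' criterion, success therefore requires that the agent has
# reached the target (raw_dist is no longer positive) and has seen at least
# n_pixel_for_object_see target-colored pixels for success_see_target_time_steps
# consecutive steps (450 pixels at the default 120 x 90 resolution, for 2 steps;
# the pixel threshold is rescaled by reset_see_criteria() for other resolutions).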
255 | return self.success_stay_cnt >= success_see_target_time_steps 256 | 257 | @property 258 | def observation_space(self): 259 | return self._observation_space 260 | 261 | @property 262 | def action_space(self): 263 | return self._action_space 264 | 265 | @property 266 | def info(self): 267 | ret = self.env.info 268 | gx, gy = ret['grid'] 269 | ret['dist'] = dist = self.house.connMap[gx, gy] 270 | ret['scaled_dist'] = self.house.getScaledDist(gx, gy) 271 | ret['optsteps'] = int(dist / (self.move_sensitivity / self.house.grid_det) + 0.5) 272 | ret['collision'] = int(self.collision_flag) 273 | ret['target_room'] = self.house.targetRoomTp 274 | return ret 275 | 276 | """ 277 | return all the available target room types of the current house 278 | """ 279 | 280 | def get_avail_targets(self): 281 | return self.house.all_desired_roomTypes 282 | 283 | """ 284 | reset the hardness of the task 285 | """ 286 | 287 | def reset_hardness(self, hardness=None): 288 | self.hardness = hardness 289 | if hardness is None: 290 | self.availCoors = self.house.connectedCoors 291 | else: 292 | allowed_dist = self.house.maxConnDist * hardness 293 | self.availCoors = [c for c in self.house.connectedCoors 294 | if self.house.connMap[c[0], c[1]] <= allowed_dist] 295 | n_house = self.env.num_house 296 | self._availCoorsDict = [dict() for i in range(n_house)] 297 | self._availCoorsDict[self.house._id][self.house.targetRoomTp] = self.availCoors 298 | 299 | """ 300 | recover the state (location) of the agent from the info dictionary 301 | """ 302 | 303 | def set_state(self, info): 304 | if len(info['pos']) == 2: 305 | self.env.reset(x=info['pos'][0], y=info['pos'][1], yaw=info['yaw']) 306 | else: 307 | self.env.reset(x=info['pos'][0], y=info['pos'][1], z=info['pos'][2], yaw=info['yaw']) 308 | 309 | """ 310 | return 2d topdown map 311 | """ 312 | 313 | def get_2dmap(self): 314 | return self.env.gen_2dmap() 315 | 316 | """ 317 | show a image. 318 | if img is not None, it will visualize the img to monitor 319 | if img is None, it will return a img object of the observation 320 | Note: whenever this function is called with img not None, the task cannot perform rendering any more!!!! 
321 | """ 322 | 323 | def show(self, img=None): 324 | return self.env.show(img=img, 325 | renderMapLoc=(None if img is not None else self.info['grid']), 326 | renderSegment=True, 327 | display=(img is not None)) 328 | 329 | def debug_show(self): 330 | return self.env.debug_render() 331 | 332 | 333 | if __name__ == '__main__': 334 | from House3D.common import load_config 335 | 336 | modes = [RenderMode.RGB, RenderMode.SEMANTIC, RenderMode.INSTANCE, RenderMode.DEPTH] 337 | 338 | api = objrender.RenderAPI(w=400, h=400, device=0) 339 | cfg = load_config('config.json') 340 | 341 | House_dir = '/data/vision/oliva/scenedataset/activevision/suncg/suncg_data/house' 342 | dataset_root = '/data/vision/oliva/scenedataset/syntheticscene/TopViewMaskDataset' 343 | if not os.path.isdir(dataset_root): 344 | os.mkdir(dataset_root) 345 | 346 | houses = os.listdir(House_dir) 347 | available_houses = [] 348 | for house in houses: 349 | if os.path.isfile(os.path.join(House_dir, house, 'house.obj')): 350 | available_houses.append(house) 351 | if len(available_houses) == 500: 352 | break 353 | ruined = [] 354 | for house in available_houses: 355 | if os.path.isdir(os.path.join(dataset_root, house)): 356 | print(house + ' Already exist!') 357 | continue 358 | try: 359 | env = Environment(api, house, cfg) 360 | except: 361 | ruined.append(house) 362 | print('House id ' + house + ' can not be used') 363 | continue 364 | print('House ' + house + ' extracting...') 365 | os.mkdir(os.path.join(dataset_root, house)) 366 | task = CollectDataTask(env, hardness=0.6, discrete_action=True) 367 | house_obj = task.house 368 | connectedCoor = house_obj.connectedCoors 369 | cam = task.env.cam 370 | for coor in connectedCoor: 371 | gx, gy = coor 372 | if house_obj.connMap[gx, gy] != 0 or gx % 10 != 0 or gy % 10 != 0: 373 | continue 374 | coor_dir = os.path.join(dataset_root, house, 'gx_{}_gy_{}'.format(gx, gy)) 375 | if not os.path.isdir(coor_dir): 376 | os.mkdir(coor_dir) 377 | print(coor) 378 | x, y = house_obj.to_coor(gx, gy, True) 379 | info = {'pos': [x, y], 'yaw': 0} 380 | task.set_state(info) 381 | top_map, _ = task.env.gen_localmap(scale=3) 382 | top_map = top_map[:, :, ::-1] 383 | b, g, r = top_map[200, 200] 384 | if b < 80 and b > 30 and g > 30 and g < 80 and r > 240: 385 | cv2.imwrite(os.path.join(coor_dir, 'topdown_view.png'), top_map) 386 | for yaw in [0, 45]: 387 | info = {'pos': [x, y], 'yaw': yaw} 388 | task.set_state(info) 389 | for i, mod in enumerate(['rgb', 'sem', 'ins', 'depth']): 390 | cube_map = task.env.render_cube_map(modes[i]) 391 | cube_map = cube_map[:, :4 * cube_map.shape[1] // 6] 392 | if mod == 'depth': 393 | infmask = cube_map[:, :, 1] 394 | cube_map = cube_map[:, :, 0] * (infmask == 0) 395 | else: 396 | cube_map = cube_map[:, :, ::-1] 397 | cv2.imwrite(os.path.join(coor_dir, 'mode={}_{}.png'.format(mod, yaw)), cube_map) 398 | info = {'pos': [x, y, 3.6], 'yaw': 0} 399 | task.set_state(info) 400 | cube_map = task.env.render_cube_map(modes[1]) 401 | cube_map = cube_map[:, - cube_map.shape[1] // 6 :] 402 | cube_map = cube_map[:, :, ::-1] 403 | cv2.imwrite(os.path.join(coor_dir, 'OverviewMask.png'), cube_map) 404 | 405 | print(ruined) 406 | print(len(ruined)) 407 | -------------------------------------------------------------------------------- /tools/process_data_for_VPN.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | from dominate.tags import * 4 | import dominate 5 | 6 | dataset_root = '/mnt/lustre/share/VPN_driving_scene/mp3d' 7 | 
scene_list = os.listdir(dataset_root) 8 | train_list = scene_list[:int(5*len(scene_list)/6)] 9 | val_list = scene_list[int(5*len(scene_list)/6):] 10 | train_output = [] 11 | val_output = [] 12 | data_type = ['rgb', 'sem', 'depth', 'ins', ] 13 | 14 | 15 | for scene_id in train_list: 16 | print('Processing ' + scene_id) 17 | coor_set = os.listdir(os.path.join(dataset_root, scene_id)) 18 | for coor in coor_set: 19 | coor_path = os.path.join(scene_id, coor) 20 | skip = False 21 | if not os.path.isfile(os.path.join(dataset_root, coor_path, 'topdown-semantics.png')): 22 | skip = True 23 | if skip: 24 | continue 25 | train_output.append(coor_path) 26 | web_path = os.path.join(dataset_root, coor_path) 27 | with dominate.document(title=web_path) as web: 28 | for mod in data_type: 29 | h2('Mode=%s, Yaw=%d'%(mod, 0)) 30 | with table(border=1, style='table-layout: fixed;'): 31 | with tr(): 32 | path = 'mode=%s_%d.png'%(mod, 0) 33 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 34 | img(style='height:256px', src=path) 35 | h2('Mode=%s, Yaw=%d' % (mod, 45)) 36 | with table(border=1, style='table-layout: fixed;'): 37 | with tr(): 38 | path = 'mode=%s_%d.png' % (mod, 45) 39 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 40 | img(style='height:256px', src=path) 41 | h2('Top-view Semantic Mask') 42 | with table(border=1, style='table-layout: fixed;'): 43 | with tr(): 44 | path = 'OverviewMask.png' 45 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 46 | img(style='width:256px', src=path) 47 | # with open(os.path.join(web_path, 'index.html'), 'w') as fp: 48 | # fp.write(web.render()) 49 | print('Index of ' + os.path.join(scene_id, coor) + ' generated!') 50 | 51 | for scene_id in val_list: 52 | print('Processing ' + scene_id) 53 | coor_set = os.listdir(os.path.join(dataset_root, scene_id)) 54 | for coor in coor_set: 55 | coor_path = os.path.join(scene_id, coor) 56 | skip = False 57 | if not os.path.isfile(os.path.join(dataset_root, coor_path, 'topdown-semantics.png')): 58 | # print(os.path.join(coor_path, '0', 'topdown_view.jpg') + ' is lost') 59 | skip = True 60 | if skip: 61 | continue 62 | val_output.append(coor_path) 63 | web_path = os.path.join(dataset_root, coor_path) 64 | with dominate.document(title=web_path) as web: 65 | for mod in data_type: 66 | h2('Mode=%s, Yaw=%d'%(mod, 0)) 67 | with table(border=1, style='table-layout: fixed;'): 68 | with tr(): 69 | path = 'mode=%s_%d.png'%(mod, 0) 70 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 71 | img(style='height:256px', src=path) 72 | h2('Mode=%s, Yaw=%d' % (mod, 45)) 73 | with table(border=1, style='table-layout: fixed;'): 74 | with tr(): 75 | path = 'mode=%s_%d.png' % (mod, 45) 76 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 77 | img(style='height:256px', src=path) 78 | h2('Top-view Semantic Mask') 79 | with table(border=1, style='table-layout: fixed;'): 80 | with tr(): 81 | path = 'OverviewMask.png' 82 | with td(style='word-wrap: break-word;', halign='center', valign='top'): 83 | img(style='width:256px', src=path) 84 | # with open(os.path.join(web_path, 'index.html'), 'w') as fp: 85 | # fp.write(web.render()) 86 | print('Index of ' + os.path.join(scene_id, coor) + ' generated!') 87 | 88 | 89 | with open('train_list.txt','w') as f: 90 | print('Writing train file ...') 91 | f.write('\n'.join(train_output)) 92 | 93 | with open('val_list.txt', 'w') as f: 94 | print('Writing validation file ...') 95 | 
f.write('\n'.join(val_output)) 96 | -------------------------------------------------------------------------------- /tools/process_transfer_driving_data.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | 4 | CARLA_data_root = '../data/Carla_Dataset_v1' 5 | scene_list = sorted(os.listdir(CARLA_data_root)) 6 | train_list = scene_list[:int(5*len(scene_list)/6)] 7 | val_list = scene_list[int(5*len(scene_list)/6):] 8 | train_output = [] 9 | val_output = [] 10 | # data_type = ['rgb', 'sem', 'depth', 'ins', ] 11 | 12 | 13 | for scene_id in train_list: 14 | print('Processing ' + scene_id) 15 | coor_set = sorted(os.listdir(os.path.join(CARLA_data_root, scene_id, 'CAM_RGB_FRONT'))) 16 | 17 | for coor in coor_set: 18 | coor_path = os.path.join(scene_id, 'CAM_RGB_FRONT', coor) 19 | train_output.append(coor_path) 20 | 21 | with open('../tools/train_source_list.txt', 'w') as f: 22 | print('Writing train_source_list file ...') 23 | f.write('\n'.join(train_output)) 24 | 25 | for scene_id in val_list: 26 | print('Processing ' + scene_id) 27 | coor_set = sorted(os.listdir(os.path.join(CARLA_data_root, scene_id, 'CAM_RGB_FRONT'))) 28 | 29 | for coor in coor_set: 30 | coor_path = os.path.join(scene_id, 'CAM_RGB_FRONT', coor) 31 | val_output.append(coor_path) 32 | 33 | with open('../tools/val_source_list.txt', 'w') as f: 34 | print('Writing val_source_list file ...') 35 | f.write('\n'.join(val_output)) 36 | -------------------------------------------------------------------------------- /tools/process_transfer_indoor_data.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | 4 | CARLA_data_root = '../data/Carla_Dataset_v1' 5 | dataset_root = '../data/nuScenes' 6 | scene_list = sorted(os.listdir(CARLA_data_root)) 7 | train_list = scene_list[:int(5*len(scene_list)/6)] 8 | # val_list = scene_list[int(5*len(scene_list)/6):] 9 | train_output = [] 10 | # val_output = [] 11 | # data_type = ['rgb', 'sem', 'depth', 'ins', ] 12 | 13 | 14 | for scene_id in train_list: 15 | print('Processing ' + scene_id) 16 | coor_set = sorted(os.listdir(os.path.join(CARLA_data_root, scene_id, 'CAM_RGB_FRONT'))) 17 | 18 | for coor in coor_set: 19 | coor_path = os.path.join(scene_id, 'CAM_RGB_FRONT', coor) 20 | train_output.append(coor_path) 21 | 22 | with open('train_source_list.txt', 'w') as f: 23 | print('Writing train_source_list file ...') 24 | f.write('\n'.join(train_output)) 25 | 26 | nuScenes_data_root = '../data/nuScenes' 27 | scene_list = sorted(os.listdir(nuScenes_data_root)) 28 | train_list = scene_list 29 | train_output = [] 30 | 31 | for scene_id in train_list: 32 | print('Processing ' + scene_id) 33 | coor_set = sorted(os.listdir(os.path.join(nuScenes_data_root, scene_id, 'CAM_RGB_FRONT'))) 34 | coor_set = coor_set[:18220] 35 | 36 | for coor in coor_set: 37 | coor_path = os.path.join(scene_id, 'CAM_RGB_FRONT', coor) 38 | train_output.append(coor_path) 39 | 40 | 41 | with open('train_target_list.txt', 'w') as f: 42 | print('Writing train_target_list file ...') 43 | f.write('\n'.join(train_output)) 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils import Foo 2 | from models import VPNModel 3 | from datasets import OVMDataset 4 | from opts import parser 5 | from transform import * 6 | import torchvision 7 | import torch 8 | from torch import nn 9 | from 
torch import optim 10 | import os 11 | import time 12 | import shutil 13 | 14 | mean_rgb = [0.485, 0.456, 0.406] 15 | std_rgb = [0.229, 0.224, 0.225] 16 | 17 | def main(): 18 | global args, best_prec1 19 | best_prec1 = 0 20 | args = parser.parse_args() 21 | network_config = Foo( 22 | encoder=args.encoder, 23 | decoder=args.decoder, 24 | fc_dim=args.fc_dim, 25 | output_size=args.label_resolution, 26 | num_views=args.n_views, 27 | num_class=args.num_class, 28 | transform_type=args.transform_type, 29 | ) 30 | train_dataset = OVMDataset(args.data_root, args.train_list, 31 | transform=torchvision.transforms.Compose([ 32 | Stack(roll=True), 33 | ToTorchFormatTensor(div=True), 34 | GroupNormalize(mean_rgb, std_rgb) 35 | ]), 36 | num_views=network_config.num_views, input_size=args.input_resolution, 37 | label_size=args.label_resolution, use_mask=args.use_mask, use_depth=args.use_depth) 38 | val_dataset = OVMDataset(args.data_root, args.eval_list, 39 | transform=torchvision.transforms.Compose([ 40 | Stack(roll=True), 41 | ToTorchFormatTensor(div=True), 42 | GroupNormalize(mean_rgb, std_rgb) 43 | ]), 44 | num_views=network_config.num_views, input_size=args.input_resolution, 45 | label_size=args.label_resolution, use_mask=args.use_mask, use_depth=args.use_depth) 46 | train_loader = torch.utils.data.DataLoader( 47 | train_dataset, batch_size=args.batch_size, 48 | num_workers=args.num_workers, shuffle=True, 49 | pin_memory=True 50 | ) 51 | 52 | val_loader = torch.utils.data.DataLoader( 53 | val_dataset, batch_size=args.batch_size, 54 | num_workers=args.num_workers, shuffle=False, 55 | pin_memory=True 56 | ) 57 | 58 | mapper = VPNModel(network_config) 59 | mapper = nn.DataParallel(mapper.cuda()) 60 | 61 | if args.resume: 62 | if os.path.isfile(args.resume): 63 | print(("=> loading checkpoint '{}'".format(args.resume))) 64 | checkpoint = torch.load(args.resume) 65 | args.start_epoch = checkpoint['epoch'] 66 | mapper.load_state_dict(checkpoint['state_dict']) 67 | print(("=> loaded checkpoint '{}' (epoch {})" 68 | .format(args.evaluate, checkpoint['epoch']))) 69 | else: 70 | print(("=> no checkpoint found at '{}'".format(args.resume))) 71 | 72 | 73 | criterion = nn.NLLLoss(weight=None, size_average=True) 74 | optimizer = optim.Adam(mapper.parameters(), 75 | lr=args.start_lr, betas=(0.95, 0.999)) 76 | 77 | if not os.path.isdir(args.log_root): 78 | os.mkdir(args.log_root) 79 | log_train = open(os.path.join(args.log_root, '%s.csv' % args.store_name), 'w') 80 | 81 | for epoch in range(args.start_epoch, args.epochs): 82 | adjust_learning_rate(optimizer, epoch, args.lr_steps) 83 | train(train_loader, mapper, criterion, optimizer, epoch, log_train) 84 | 85 | if (epoch + 1) % args.ckpt_freq == 0 or epoch == args.epochs - 1: 86 | prec1 = eval(val_loader, mapper, criterion, log_train, epoch) 87 | is_best = prec1 > best_prec1 88 | best_prec1 = max(prec1, best_prec1) 89 | save_checkpoint({ 90 | 'epoch': epoch + 1, 91 | 'arch': network_config.encoder, 92 | 'state_dict': mapper.state_dict(), 93 | 'best_prec1': best_prec1, 94 | }, is_best) 95 | 96 | def train(train_loader, mapper, criterion, optimizer, epoch, log): 97 | batch_time = AverageMeter() 98 | data_time = AverageMeter() 99 | losses = AverageMeter() 100 | top1 = AverageMeter() 101 | top5 = AverageMeter() 102 | 103 | mapper.train() 104 | 105 | end = time.time() 106 | for step, data in enumerate(train_loader): 107 | rgb_stack, target = data 108 | data_time.update(time.time() - end) 109 | target_var = target.cuda() 110 | input_rgb_var = 
torch.autograd.Variable(rgb_stack).cuda() 111 | output = mapper(input_rgb_var) 112 | target_var = target_var.view(-1) 113 | output = output.view(-1, args.num_class) 114 | loss = criterion(output, target_var) 115 | losses.update(loss.data[0], input_rgb_var.size(0)) 116 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5)) 117 | top1.update(prec1[0], rgb_stack.size(0)) 118 | top5.update(prec5[0], rgb_stack.size(0)) 119 | 120 | optimizer.zero_grad() 121 | 122 | loss.backward() 123 | optimizer.step() 124 | 125 | batch_time.update(time.time() - end) 126 | end = time.time() 127 | 128 | if step % args.print_freq == 0: 129 | output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 130 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 131 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 132 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 133 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 134 | epoch + 1, step + 1, len(train_loader), batch_time=batch_time, 135 | data_time=data_time, loss=losses, top1=top1, lr=optimizer.param_groups[-1]['lr'])) 136 | print(output) 137 | log.write(output + '\n') 138 | log.flush() 139 | 140 | def eval(val_loader, mapper, criterion, log, epoch): 141 | batch_time = AverageMeter() 142 | data_time = AverageMeter() 143 | losses = AverageMeter() 144 | top1 = AverageMeter() 145 | top5 = AverageMeter() 146 | 147 | mapper.eval() 148 | 149 | end = time.time() 150 | for step, (rgb_stack, target) in enumerate(val_loader): 151 | data_time.update(time.time() - end) 152 | with torch.no_grad(): 153 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda() 154 | output = mapper(input_rgb_var) 155 | target_var = target.cuda() 156 | target_var = target_var.view(-1) 157 | output = output.view(-1, args.num_class) 158 | loss = criterion(output, target_var) 159 | losses.update(loss.data[0], input_rgb_var.size(0)) 160 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5)) 161 | top1.update(prec1[0], rgb_stack.size(0)) 162 | top5.update(prec5[0], rgb_stack.size(0)) 163 | 164 | batch_time.update(time.time() - end) 165 | end = time.time() 166 | 167 | if step % args.print_freq == 0: 168 | output = ('Test: [{0}][{1}/{2}]\t' 169 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 170 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 171 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 172 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 173 | epoch + 1, step + 1, len(val_loader), batch_time=batch_time, 174 | data_time=data_time, loss=losses, top1=top1)) 175 | print(output) 176 | log.write(output + '\n') 177 | log.flush() 178 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Loss {loss.avg:.5f}' 179 | .format(top1=top1, loss=losses)) 180 | print(output) 181 | output_best = '\nBest Prec@1: %.3f' % (best_prec1) 182 | print(output_best) 183 | log.write(output + ' ' + output_best + '\n') 184 | log.flush() 185 | return top1.avg 186 | 187 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 188 | torch.save(state, '%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name)) 189 | if is_best: 190 | shutil.copyfile('%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name), '%s/%s_best.pth.tar' % (args.root_model, args.store_name)) 191 | 192 | 193 | def adjust_learning_rate(optimizer, epoch, lr_steps): 194 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 195 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 196 | lr = args.start_lr * decay 197 | decay = args.weight_decay 198 | for param_group in 
optimizer.param_groups: 199 | param_group['lr'] = lr 200 | param_group['weight_decay'] = decay 201 | 202 | class AverageMeter(object): 203 | """Computes and stores the average and current value""" 204 | def __init__(self): 205 | self.reset() 206 | 207 | def reset(self): 208 | self.val = 0 209 | self.avg = 0 210 | self.sum = 0 211 | self.count = 0 212 | 213 | def update(self, val, n=1): 214 | self.val = val 215 | self.sum += val * n 216 | self.count += n 217 | self.avg = self.sum / self.count 218 | 219 | def accuracy(output, target, topk=(1,)): 220 | """Computes the precision@k for the specified values of k""" 221 | maxk = max(topk) 222 | batch_size = target.size(0) 223 | 224 | _, pred = output.topk(maxk, 1, True, True) 225 | pred = pred.t() 226 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 227 | 228 | res = [] 229 | for k in topk: 230 | correct_k = correct[:k].view(-1).float().sum(0) 231 | res.append(correct_k.mul_(100.0 / batch_size)) 232 | return res 233 | 234 | if __name__=='__main__': 235 | main() 236 | -------------------------------------------------------------------------------- /train_carla.py: -------------------------------------------------------------------------------- 1 | from utils import Foo 2 | from models import VPNModel 3 | from datasets import OVMDataset 4 | from opts import parser 5 | from transform import * 6 | import torchvision 7 | import torch 8 | from torch import nn 9 | from torch.optim.lr_scheduler import MultiStepLR 10 | from torch import optim 11 | import os 12 | import time 13 | from torch.nn.utils import clip_grad_norm 14 | # from examples.cognitive_mapping.Logger import Logger 15 | import cv2 16 | import shutil 17 | 18 | mean_rgb = [0.485, 0.456, 0.406] 19 | std_rgb = [0.229, 0.224, 0.225] 20 | 21 | def main(): 22 | global args, best_prec1 23 | best_prec1 = 0 24 | args = parser.parse_args() 25 | network_config = Foo( 26 | encoder=args.encoder, 27 | decoder=args.decoder, 28 | fc_dim=args.fc_dim, 29 | output_size=args.label_resolution, 30 | num_views=args.n_views, 31 | num_class=args.num_class, 32 | transform_type=args.transform_type, 33 | ) 34 | train_dataset = OVMDataset(args.data_root, args.train_list, 35 | transform=torchvision.transforms.Compose([ 36 | Stack(roll=True), 37 | ToTorchFormatTensor(div=True), 38 | GroupNormalize(mean_rgb, std_rgb) 39 | ]), 40 | num_views=network_config.num_views, input_size=args.input_resolution, 41 | label_size=args.label_resolution, use_mask=args.use_mask, use_depth=args.use_depth) 42 | val_dataset = OVMDataset(args.data_root, args.eval_list, 43 | transform=torchvision.transforms.Compose([ 44 | Stack(roll=True), 45 | ToTorchFormatTensor(div=True), 46 | GroupNormalize(mean_rgb, std_rgb) 47 | ]), 48 | num_views=network_config.num_views, input_size=args.input_resolution, 49 | label_size=args.label_resolution, use_mask=args.use_mask, use_depth=args.use_depth) 50 | train_loader = torch.utils.data.DataLoader( 51 | train_dataset, batch_size=args.batch_size, 52 | num_workers=args.num_workers, shuffle=True, 53 | pin_memory=True 54 | ) 55 | 56 | val_loader = torch.utils.data.DataLoader( 57 | val_dataset, batch_size=args.batch_size, 58 | num_workers=args.num_workers, shuffle=False, 59 | pin_memory=True 60 | ) 61 | 62 | mapper = VPNModel(network_config) 63 | mapper = nn.DataParallel(mapper.cuda()) 64 | 65 | if args.resume: 66 | if os.path.isfile(args.resume): 67 | print(("=> loading checkpoint '{}'".format(args.resume))) 68 | checkpoint = torch.load(args.resume) 69 | args.start_epoch = checkpoint['epoch'] 70 | 
mapper.load_state_dict(checkpoint['state_dict']) 71 | print(("=> loaded checkpoint '{}' (epoch {})" 72 | .format(args.evaluate, checkpoint['epoch']))) 73 | else: 74 | print(("=> no checkpoint found at '{}'".format(args.resume))) 75 | 76 | 77 | criterion = nn.NLLLoss(weight=None, size_average=True) 78 | optimizer = optim.Adam(mapper.parameters(), 79 | lr=args.start_lr, betas=(0.95, 0.999)) 80 | 81 | if not os.path.isdir(args.log_root): 82 | os.mkdir(args.log_root) 83 | log_train = open(os.path.join(args.log_root, '%s.csv' % args.store_name), 'w') 84 | 85 | for epoch in range(args.start_epoch, args.epochs): 86 | adjust_learning_rate(optimizer, epoch, args.lr_steps) 87 | train(train_loader, mapper, criterion, optimizer, epoch, log_train) 88 | 89 | if (epoch + 1) % args.ckpt_freq == 0 or epoch == args.epochs - 1: 90 | prec1 = eval(val_loader, mapper, criterion, log_train, epoch) 91 | is_best = prec1 > best_prec1 92 | best_prec1 = max(prec1, best_prec1) 93 | save_checkpoint({ 94 | 'epoch': epoch + 1, 95 | 'arch': network_config.encoder, 96 | 'state_dict': mapper.state_dict(), 97 | 'best_prec1': best_prec1, 98 | }, is_best) 99 | 100 | def train(train_loader, mapper, criterion, optimizer, epoch, log): 101 | batch_time = AverageMeter() 102 | data_time = AverageMeter() 103 | losses = AverageMeter() 104 | top1 = AverageMeter() 105 | top5 = AverageMeter() 106 | 107 | mapper.train() 108 | 109 | end = time.time() 110 | for step, data in enumerate(train_loader): 111 | rgb_stack, target = data 112 | data_time.update(time.time() - end) 113 | target_var = target.cuda() 114 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda() 115 | output = mapper(input_rgb_var) 116 | target_var = target_var.view(-1) 117 | output = output.view(-1, args.num_class) 118 | loss = criterion(output, target_var) 119 | losses.update(loss.item(), input_rgb_var.size(0)) 120 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5)) 121 | top1.update(prec1.item(), rgb_stack.size(0)) 122 | top5.update(prec5.item(), rgb_stack.size(0)) 123 | 124 | optimizer.zero_grad() 125 | 126 | loss.backward() 127 | optimizer.step() 128 | 129 | batch_time.update(time.time() - end) 130 | end = time.time() 131 | 132 | if step % args.print_freq == 0: 133 | output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 134 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 135 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 136 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 137 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 138 | epoch + 1, step + 1, len(train_loader), batch_time=batch_time, 139 | data_time=data_time, loss=losses, top1=top1, lr=optimizer.param_groups[-1]['lr'])) 140 | print(output) 141 | log.write(output + '\n') 142 | log.flush() 143 | 144 | def eval(val_loader, mapper, criterion, log, epoch): 145 | batch_time = AverageMeter() 146 | data_time = AverageMeter() 147 | losses = AverageMeter() 148 | top1 = AverageMeter() 149 | top5 = AverageMeter() 150 | 151 | mapper.eval() 152 | 153 | end = time.time() 154 | for step, (rgb_stack, target) in enumerate(val_loader): 155 | data_time.update(time.time() - end) 156 | with torch.no_grad(): 157 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda() 158 | output = mapper(input_rgb_var) 159 | target_var = target.cuda() 160 | target_var = target_var.view(-1) 161 | output = output.view(-1, args.num_class) 162 | loss = criterion(output, target_var) 163 | losses.update(loss.item(), input_rgb_var.size(0)) 164 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5)) 165 | 
top1.update(prec1.item(), rgb_stack.size(0)) 166 | top5.update(prec5.item(), rgb_stack.size(0)) 167 | 168 | batch_time.update(time.time() - end) 169 | end = time.time() 170 | 171 | if step % args.print_freq == 0: 172 | output = ('Test: [{0}][{1}/{2}]\t' 173 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 174 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 175 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 176 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 177 | epoch + 1, step + 1, len(val_loader), batch_time=batch_time, 178 | data_time=data_time, loss=losses, top1=top1)) 179 | print(output) 180 | log.write(output + '\n') 181 | log.flush() 182 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Loss {loss.avg:.5f}' 183 | .format(top1=top1, loss=losses)) 184 | print(output) 185 | output_best = '\nBest Prec@1: %.3f' % (best_prec1) 186 | print(output_best) 187 | log.write(output + ' ' + output_best + '\n') 188 | log.flush() 189 | return top1.avg 190 | 191 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 192 | torch.save(state, '%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name)) 193 | if is_best: 194 | shutil.copyfile('%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name), '%s/%s_best.pth.tar' % (args.root_model, args.store_name)) 195 | 196 | 197 | def adjust_learning_rate(optimizer, epoch, lr_steps): 198 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 199 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 200 | lr = args.start_lr * decay 201 | decay = args.weight_decay 202 | for param_group in optimizer.param_groups: 203 | param_group['lr'] = lr 204 | param_group['weight_decay'] = decay 205 | 206 | class AverageMeter(object): 207 | """Computes and stores the average and current value""" 208 | def __init__(self): 209 | self.reset() 210 | 211 | def reset(self): 212 | self.val = 0 213 | self.avg = 0 214 | self.sum = 0 215 | self.count = 0 216 | 217 | def update(self, val, n=1): 218 | self.val = val 219 | self.sum += val * n 220 | self.count += n 221 | self.avg = self.sum / self.count 222 | 223 | def accuracy(output, target, topk=(1,)): 224 | """Computes the precision@k for the specified values of k""" 225 | maxk = max(topk) 226 | batch_size = target.size(0) 227 | 228 | _, pred = output.topk(maxk, 1, True, True) 229 | pred = pred.t() 230 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 231 | 232 | res = [] 233 | for k in topk: 234 | correct_k = correct[:k].view(-1).float().sum(0) 235 | res.append(correct_k.mul_(100.0 / batch_size)) 236 | return res 237 | 238 | if __name__=='__main__': 239 | main() 240 | -------------------------------------------------------------------------------- /train_transfer.py: -------------------------------------------------------------------------------- 1 | from utils import Foo 2 | from models import VPNModel, FCDiscriminator 3 | from datasets import House3D_Dataset, MP3D_Dataset, Carla_Dataset, nuScenes_Dataset 4 | from opts import parser 5 | from transform import * 6 | import torchvision 7 | import torch 8 | from torch import nn 9 | from torch.optim.lr_scheduler import MultiStepLR 10 | from torch import optim 11 | import os 12 | import time 13 | from torch.nn.utils import clip_grad_norm 14 | 15 | import shutil 16 | import torch.nn.functional as F 17 | import os.path as osp 18 | from tensorboardX import SummaryWriter 19 | import argparse 20 | 21 | mean_rgb = [0.485, 0.456, 0.406] 22 | std_rgb = [0.229, 0.224, 0.225] 23 | 24 | def str2bool(v): 25 | if v.lower() in ('yes', 
'true', 't', 'y', '1'): 26 | return True 27 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 28 | return False 29 | else: 30 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 31 | 32 | def main(): 33 | global args, best_prec1 34 | best_prec1 = 0 35 | 36 | parser.add_argument('--source_dir', type=str, 37 | default='/mnt/lustre/share/VPN_driving_scene/TopViewMaskDataset') 38 | parser.add_argument('--target_dir', type=str, 39 | default='/mnt/lustre/share/VPN_driving_scene/mp3d') 40 | parser.add_argument('--num-steps', type=int, default=250000) 41 | parser.add_argument('--iter-size-G', type=int, default=3) 42 | parser.add_argument('--iter-size-D', type=int, default=1) 43 | parser.add_argument('--learning-rate-D', type=float, default=1e-4) 44 | parser.add_argument('--learning-rate', type=float, default=2.5e-4) 45 | parser.add_argument("--SegSize", type=int, default=128, 46 | help="Comma-separated string with height and width of source images.") 47 | parser.add_argument("--SegSize-target", type=int, default=128, 48 | help="Comma-separated string with height and width of target images.") 49 | parser.add_argument('--resume-D', type=str) 50 | parser.add_argument('--resume-G', type=str) 51 | parser.add_argument('--train_source_list', type=str, default='./train_source_list.txt') 52 | parser.add_argument('--train_target_list', type=str, default='./train_target_list.txt') 53 | parser.add_argument('--num-classes', type=int, default=94) 54 | parser.add_argument('--power', type=float, default=0.9) 55 | parser.add_argument('--lambda_adv_target', type=float, default=0.001) 56 | parser.add_argument('--num_steps_stop', type=int, default=150000) 57 | parser.add_argument('--save_pred_every', type=int, default=5000) 58 | parser.add_argument('--snapshot-dir', type=str, default='/mnt/lustre/panbowen/VPN-transfer/snapshot/') 59 | parser.add_argument("--tensorboard", type=str2bool, default=True) 60 | parser.add_argument("--tf-logdir", type=str, default='/mnt/lustre/panbowen/VPN-transfer/tf_log/', 61 | help="Path to the directory of log.") 62 | parser.add_argument('--VPN-weights', type=str) 63 | parser.add_argument('--task-id', type=str) 64 | parser.add_argument('--scenario', type=str, default='indoor') 65 | 66 | args = parser.parse_args() 67 | 68 | network_config = Foo( 69 | encoder=args.encoder, 70 | decoder=args.decoder, 71 | fc_dim=args.fc_dim, 72 | output_size=args.label_resolution, 73 | num_views=args.n_views, 74 | num_class=94, 75 | transform_type=args.transform_type, 76 | ) 77 | 78 | if args.scenario == 'indoor': 79 | train_source_dataset = House3D_Dataset(args.source_dir, args.train_source_list, 80 | transform=torchvision.transforms.Compose([ 81 | Stack(roll=True), 82 | ToTorchFormatTensor(div=True), 83 | GroupNormalize(mean_rgb, std_rgb) 84 | ]), 85 | num_views=network_config.num_views, input_size=args.input_resolution, 86 | label_size=args.SegSize) 87 | train_target_dataset = MP3D_Dataset(args.target_dir, args.train_target_list, 88 | transform=torchvision.transforms.Compose([ 89 | Stack(roll=True), 90 | ToTorchFormatTensor(div=True), 91 | GroupNormalize(mean_rgb, std_rgb) 92 | ]), 93 | num_views=network_config.num_views, input_size=args.input_resolution, 94 | label_size=args.SegSize_target) 95 | elif args.scenario == 'traffic': 96 | train_source_dataset = Carla_Dataset(args.source_dir, args.train_source_list, 97 | transform=torchvision.transforms.Compose([ 98 | Stack(roll=True), 99 | ToTorchFormatTensor(div=True), 100 | GroupNormalize(mean_rgb, std_rgb) 101 | ]), 102 | 
num_views=network_config.num_views, input_size=args.input_resolution, 103 | label_size=args.SegSize) 104 | train_target_dataset = nuScenes_Dataset(args.target_dir, args.train_target_list, 105 | transform=torchvision.transforms.Compose([ 106 | Stack(roll=True), 107 | ToTorchFormatTensor(div=True), 108 | GroupNormalize(mean_rgb, std_rgb) 109 | ]), 110 | num_views=network_config.num_views, input_size=args.input_resolution, 111 | label_size=args.SegSize_target) 112 | 113 | source_loader = torch.utils.data.DataLoader( 114 | train_source_dataset, batch_size=args.batch_size, 115 | num_workers=args.num_workers, shuffle=True, 116 | pin_memory=True 117 | ) 118 | 119 | target_loader = torch.utils.data.DataLoader( 120 | train_target_dataset, batch_size=args.batch_size, 121 | num_workers=args.num_workers, shuffle=True, 122 | pin_memory=True 123 | ) 124 | 125 | mapper = VPNModel(network_config) 126 | mapper = nn.DataParallel(mapper.cuda()) 127 | mapper.train() 128 | 129 | model_D1 = FCDiscriminator(num_classes=args.num_classes) 130 | model_D1 = nn.DataParallel(model_D1.cuda()) 131 | model_D1.train() 132 | 133 | if args.VPN_weights: 134 | if os.path.isfile(args.VPN_weights): 135 | print(("=> loading checkpoint '{}'".format(args.VPN_weights))) 136 | checkpoint = torch.load(args.VPN_weights) 137 | args.start_epoch = checkpoint['epoch'] 138 | mapper.load_state_dict(checkpoint['state_dict']) 139 | print(("=> loaded checkpoint '{}' (epoch {})" 140 | .format(args.VPN_weights, checkpoint['epoch']))) 141 | else: 142 | print(("=> no checkpoint found at '{}'".format(args.VPN_weights))) 143 | 144 | resume_iter = None 145 | if args.resume_G: 146 | if os.path.isfile(args.resume_G): 147 | print(("=> loading checkpoint '{}'".format(args.resume_G))) 148 | state_dict = torch.load(args.resume_G) 149 | mapper.load_state_dict(state_dict) 150 | print(("=> loaded checkpoint '{}' (epoch {})" 151 | .format(args.evaluate, checkpoint['epoch']))) 152 | resume_iter = int(args.resume_G.split('_')[-1].split('.')[0]) 153 | else: 154 | print(("=> no checkpoint found at '{}'".format(args.resume_G))) 155 | 156 | if args.resume_D: 157 | if os.path.isfile(args.resume_D): 158 | print(("=> loading checkpoint '{}'".format(args.resume_D))) 159 | checkpoint = torch.load(args.resume_D) 160 | args.start_epoch = checkpoint['epoch'] 161 | model_D1.load_state_dict(checkpoint['state_dict']) 162 | print(("=> loaded checkpoint '{}' (epoch {})" 163 | .format(args.evaluate, checkpoint['epoch']))) 164 | else: 165 | print(("=> no checkpoint found at '{}'".format(args.resume_D))) 166 | 167 | optimizer = optim.SGD(mapper.parameters(), 168 | lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) 169 | optimizer.zero_grad() 170 | 171 | optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) 172 | optimizer_D1.zero_grad() 173 | 174 | 175 | criterion_seg = nn.NLLLoss(weight=None, size_average=True) 176 | criterion_bce = nn.BCEWithLogitsLoss() 177 | 178 | train(source_loader, target_loader, mapper, model_D1, 179 | criterion_seg, criterion_bce, optimizer, optimizer_D1, resume_iter) 180 | 181 | def train(source_loader, target_loader, mapper, model_D1, seg_loss, bce_loss, optimizer, optimizer_D1, resume_iter): 182 | source_loader_iter = enumerate(source_loader) 183 | 184 | target_loader_iter = enumerate(target_loader) 185 | 186 | # set up tensor board 187 | if args.tensorboard: 188 | if not os.path.exists(os.path.join(args.tf_logdir, args.task_id)): 189 | os.makedirs(os.path.join(args.tf_logdir, 
args.task_id)) 190 | 191 | writer = SummaryWriter(os.path.join(args.tf_logdir, args.task_id)) 192 | 193 | interp = nn.Upsample(size=(args.SegSize, args.SegSize), mode='bilinear', align_corners=True) 194 | interp_target = nn.Upsample(size=(args.SegSize_target, args.SegSize_target), mode='bilinear', align_corners=True) 195 | source_label = 0 196 | target_label = 1 197 | 198 | for i_iter in range(args.num_steps): 199 | if resume_iter is not None and i_iter < resume_iter: 200 | continue 201 | 202 | loss_seg_value = 0 203 | loss_adv_target_value = 0 204 | loss_D_value = 0 205 | 206 | optimizer.zero_grad() 207 | adjust_learning_rate(optimizer, i_iter) 208 | 209 | optimizer_D1.zero_grad() 210 | adjust_learning_rate_D(optimizer_D1, i_iter) 211 | 212 | # train G 213 | # don't accumulate grads in D 214 | for param in model_D1.parameters(): 215 | param.requires_grad = False 216 | 217 | for sub_i in range(args.iter_size_G): 218 | # train with source 219 | 220 | try: 221 | _, batch = source_loader_iter.__next__() 222 | except: 223 | source_loader_iter = enumerate(source_loader) 224 | _, batch = source_loader_iter.__next__() 225 | rgb_stack, label = batch 226 | label_var = label.cuda() 227 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda() 228 | 229 | _, pred_feat = mapper(input_rgb_var, return_feat=True) 230 | 231 | pred_feat = pred_feat.transpose(3, 2).transpose(2, 1).contiguous() 232 | pred = interp(pred_feat) 233 | 234 | pred = F.log_softmax(pred, dim=1) 235 | pred = pred.transpose(1, 2).transpose(2, 3).contiguous() 236 | 237 | label_var = label_var.view(-1) 238 | output = pred.view(-1, args.num_class) 239 | 240 | loss_seg = seg_loss(output, label_var) 241 | loss = loss_seg / args.iter_size_G 242 | loss.backward() 243 | loss_seg_value += loss_seg.item() / args.iter_size_G 244 | 245 | # train with target 246 | try: 247 | _, batch = target_loader_iter.__next__() 248 | except: 249 | target_loader_iter = enumerate(target_loader) 250 | _, batch = target_loader_iter.__next__() 251 | rgb_stack = batch 252 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda() 253 | _, pred_target = mapper(input_rgb_var, return_feat=True) 254 | pred_target = pred_target.transpose(3, 2).transpose(2, 1).contiguous() 255 | pred_target = interp_target(pred_target) 256 | pred_target = F.log_softmax(pred_target, dim=1) 257 | D_out = model_D1(torch.exp(pred_target)) 258 | 259 | loss_adv_target = bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).cuda()) 260 | loss = args.lambda_adv_target * loss_adv_target / args.iter_size_G 261 | loss.backward() 262 | loss_adv_target_value += loss_adv_target.item() / args.iter_size_G 263 | 264 | 265 | # train D 266 | # bring back requires_grad 267 | for param in model_D1.parameters(): 268 | param.requires_grad = True 269 | 270 | for sub_i in range(args.iter_size_D): 271 | 272 | # train with source 273 | 274 | pred = pred.detach() 275 | pred = pred.transpose(3, 2).transpose(2, 1).contiguous() 276 | D_out = model_D1(torch.exp(pred)) 277 | 278 | loss_D = bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).cuda()) 279 | loss_D = loss_D / args.iter_size_D / 2 280 | loss_D.backward() 281 | 282 | loss_D_value += loss_D.item() 283 | 284 | # train with target 285 | pred_target = pred_target.detach() 286 | 287 | D_out = model_D1(torch.exp(pred_target)) 288 | 289 | loss_D = bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(target_label).cuda()) 290 | 291 | loss_D = loss_D / args.iter_size_D / 2 292 | loss_D.backward() 293 | loss_D_value += 
loss_D.item() 294 | 295 | optimizer.step() 296 | optimizer_D1.step() 297 | 298 | if args.tensorboard: 299 | scalar_info = { 300 | 'loss_seg': loss_seg_value, 301 | 'loss_adv': loss_adv_target_value, 302 | 'loss_D': loss_D_value, 303 | } 304 | 305 | if i_iter % 10 == 0: 306 | print(args.tf_logdir) 307 | for key, val in scalar_info.items(): 308 | writer.add_scalar(key, val, i_iter) 309 | 310 | print( 311 | 'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f}, loss_D = {4:.3f} '.format( 312 | i_iter, args.num_steps, loss_seg_value, loss_adv_target_value, loss_D_value)) 313 | 314 | if i_iter >= args.num_steps_stop - 1: 315 | print('save model ...') 316 | torch.save(mapper.state_dict(), osp.join(args.snapshot_dir, 'House3D_' + str(args.num_steps_stop) + '.pth')) 317 | torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'House3D_' + str(args.num_steps_stop) + '_D.pth')) 318 | break 319 | 320 | if i_iter % args.save_pred_every == 0 and i_iter != 0: 321 | print('taking snapshot ...') 322 | torch.save(mapper.state_dict(), osp.join(args.snapshot_dir, 'House3D_' + str(i_iter) + '.pth')) 323 | torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'House3D_' + str(i_iter) + '_D.pth')) 324 | 325 | if args.tensorboard: 326 | writer.close() 327 | 328 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 329 | torch.save(state, '%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name)) 330 | if is_best: 331 | shutil.copyfile('%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name), '%s/%s_best.pth.tar' % (args.root_model, args.store_name)) 332 | 333 | def lr_poly(base_lr, iter, max_iter, power): 334 | return base_lr * ((1 - float(iter) / max_iter) ** (power)) 335 | 336 | 337 | def adjust_learning_rate(optimizer, i_iter): 338 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power) 339 | optimizer.param_groups[0]['lr'] = lr 340 | if len(optimizer.param_groups) > 1: 341 | optimizer.param_groups[1]['lr'] = lr * 10 342 | 343 | 344 | def adjust_learning_rate_D(optimizer, i_iter): 345 | lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power) 346 | optimizer.param_groups[0]['lr'] = lr 347 | if len(optimizer.param_groups) > 1: 348 | optimizer.param_groups[1]['lr'] = lr * 10 349 | 350 | class AverageMeter(object): 351 | """Computes and stores the average and current value""" 352 | def __init__(self): 353 | self.reset() 354 | 355 | def reset(self): 356 | self.val = 0 357 | self.avg = 0 358 | self.sum = 0 359 | self.count = 0 360 | 361 | def update(self, val, n=1): 362 | self.val = val 363 | self.sum += val * n 364 | self.count += n 365 | self.avg = self.sum / self.count 366 | 367 | def accuracy(output, target, topk=(1,)): 368 | """Computes the precision@k for the specified values of k""" 369 | maxk = max(topk) 370 | batch_size = target.size(0) 371 | 372 | _, pred = output.topk(maxk, 1, True, True) 373 | pred = pred.t() 374 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 375 | 376 | res = [] 377 | for k in topk: 378 | correct_k = correct[:k].view(-1).float().sum(0) 379 | res.append(correct_k.mul_(100.0 / batch_size)) 380 | return res 381 | 382 | if __name__=='__main__': 383 | main() 384 | -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | 
class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | 44 | 45 | class GroupRandomHorizontalFlip(object): 46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 47 | """ 48 | def __init__(self, is_flow=False): 49 | self.is_flow = is_flow 50 | 51 | def __call__(self, img_group, is_flow=False): 52 | v = random.random() 53 | if v < 0.5: 54 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 55 | if self.is_flow: 56 | for i in range(0, len(ret), 2): 57 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 58 | return ret 59 | else: 60 | return img_group 61 | 62 | 63 | class GroupNormalize(object): 64 | def __init__(self, mean, std): 65 | self.mean = mean 66 | self.std = std 67 | 68 | def __call__(self, tensor): 69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 70 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 71 | 72 | # TODO: make efficient 73 | for t, m, s in zip(tensor, rep_mean, rep_std): 74 | t.sub_(m).div_(s) 75 | 76 | return tensor 77 | 78 | 79 | class GroupScale(object): 80 | """ Rescales the input PIL.Image to the given 'size'. 81 | 'size' will be the size of the smaller edge. 
82 | For example, if height > width, then image will be 83 | rescaled to (size * height / width, size) 84 | size: size of the smaller edge 85 | interpolation: Default: PIL.Image.BILINEAR 86 | """ 87 | 88 | def __init__(self, size, interpolation=Image.BILINEAR): 89 | self.worker = torchvision.transforms.Scale(size, interpolation) 90 | 91 | def __call__(self, img_group): 92 | return [self.worker(img) for img in img_group] 93 | 94 | 95 | class GroupOverSample(object): 96 | def __init__(self, crop_size, scale_size=None): 97 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 98 | 99 | if scale_size is not None: 100 | self.scale_worker = GroupScale(scale_size) 101 | else: 102 | self.scale_worker = None 103 | 104 | def __call__(self, img_group): 105 | 106 | if self.scale_worker is not None: 107 | img_group = self.scale_worker(img_group) 108 | 109 | image_w, image_h = img_group[0].size 110 | crop_w, crop_h = self.crop_size 111 | 112 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 113 | oversample_group = list() 114 | for o_w, o_h in offsets: 115 | normal_group = list() 116 | flip_group = list() 117 | for i, img in enumerate(img_group): 118 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 119 | normal_group.append(crop) 120 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 121 | 122 | if img.mode == 'L' and i % 2 == 0: 123 | flip_group.append(ImageOps.invert(flip_crop)) 124 | else: 125 | flip_group.append(flip_crop) 126 | 127 | oversample_group.extend(normal_group) 128 | oversample_group.extend(flip_group) 129 | return oversample_group 130 | 131 | 132 | class GroupMultiScaleCrop(object): 133 | 134 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 135 | self.scales = scales if scales is not None else [1, .875, .75, .66] 136 | self.max_distort = max_distort 137 | self.fix_crop = fix_crop 138 | self.more_fix_crop = more_fix_crop 139 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 140 | self.interpolation = Image.BILINEAR 141 | 142 | def __call__(self, img_group): 143 | 144 | im_size = img_group[0].size 145 | 146 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 147 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 148 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 149 | for img in crop_img_group] 150 | return ret_img_group 151 | 152 | def _sample_crop_size(self, im_size): 153 | image_w, image_h = im_size[0], im_size[1] 154 | 155 | # find a crop size 156 | base_size = min(image_w, image_h) 157 | crop_sizes = [int(base_size * x) for x in self.scales] 158 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 159 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 160 | 161 | pairs = [] 162 | for i, h in enumerate(crop_h): 163 | for j, w in enumerate(crop_w): 164 | if abs(i - j) <= self.max_distort: 165 | pairs.append((w, h)) 166 | 167 | crop_pair = random.choice(pairs) 168 | if not self.fix_crop: 169 | w_offset = random.randint(0, image_w - crop_pair[0]) 170 | h_offset = random.randint(0, image_h - crop_pair[1]) 171 | else: 172 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 173 | 174 | return crop_pair[0], crop_pair[1], w_offset, h_offset 175 | 176 | def 
_sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 177 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 178 | return random.choice(offsets) 179 | 180 | @staticmethod 181 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 182 | w_step = (image_w - crop_w) // 4 183 | h_step = (image_h - crop_h) // 4 184 | 185 | ret = list() 186 | ret.append((0, 0)) # upper left 187 | ret.append((4 * w_step, 0)) # upper right 188 | ret.append((0, 4 * h_step)) # lower left 189 | ret.append((4 * w_step, 4 * h_step)) # lower right 190 | ret.append((2 * w_step, 2 * h_step)) # center 191 | 192 | if more_fix_crop: 193 | ret.append((0, 2 * h_step)) # center left 194 | ret.append((4 * w_step, 2 * h_step)) # center right 195 | ret.append((2 * w_step, 4 * h_step)) # lower center 196 | ret.append((2 * w_step, 0 * h_step)) # upper center 197 | 198 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 199 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 200 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 201 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 202 | 203 | return ret 204 | 205 | 206 | class GroupRandomSizedCrop(object): 207 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 208 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio. 209 | This is popularly used to train the Inception networks. 210 | size: size of the smaller edge 211 | interpolation: Default: PIL.Image.BILINEAR 212 | """ 213 | def __init__(self, size, interpolation=Image.BILINEAR): 214 | self.size = size 215 | self.interpolation = interpolation 216 | 217 | def __call__(self, img_group): 218 | for attempt in range(10): 219 | area = img_group[0].size[0] * img_group[0].size[1] 220 | target_area = random.uniform(0.08, 1.0) * area 221 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 222 | 223 | w = int(round(math.sqrt(target_area * aspect_ratio))) 224 | h = int(round(math.sqrt(target_area / aspect_ratio))) 225 | 226 | if random.random() < 0.5: 227 | w, h = h, w 228 | 229 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 230 | x1 = random.randint(0, img_group[0].size[0] - w) 231 | y1 = random.randint(0, img_group[0].size[1] - h) 232 | found = True 233 | break 234 | else: 235 | found = False 236 | x1 = 0 237 | y1 = 0 238 | 239 | if found: 240 | out_group = list() 241 | for img in img_group: 242 | img = img.crop((x1, y1, x1 + w, y1 + h)) 243 | assert(img.size == (w, h)) 244 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 245 | return out_group 246 | else: 247 | # Fallback 248 | scale = GroupScale(self.size, interpolation=self.interpolation) 249 | crop = GroupRandomCrop(self.size) 250 | return crop(scale(img_group)) 251 | 252 | 253 | class Stack(object): 254 | 255 | def __init__(self, roll=False): 256 | self.roll = roll 257 | 258 | def __call__(self, img_group): 259 | if self.roll: 260 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 261 | else: 262 | return np.concatenate(img_group, axis=2) 263 | 264 | 265 | class ToTorchFormatTensor(object): 266 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 267 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 268 | def __init__(self, div=True, float=True): 269 | self.div = div 270 | self.float = float 271 | 272 | def __call__(self, pic): 273 | if isinstance(pic, np.ndarray): 274 | # handle numpy array 275 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 276 | else: 277 | # handle PIL Image 278 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 279 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 280 | # put it from HWC to CHW format 281 | # yikes, this transpose takes 80% of the loading time/CPU 282 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 283 | if self.div: 284 | img = img.div(255) 285 | return img.float() if self.float else img.long() 286 | 287 | 288 | class IdentityTransform(object): 289 | 290 | def __call__(self, data): 291 | return data 292 | 293 | 294 | if __name__ == "__main__": 295 | trans = torchvision.transforms.Compose([ 296 | GroupScale(256), 297 | GroupRandomCrop(224), 298 | Stack(), 299 | ToTorchFormatTensor(), 300 | GroupNormalize( 301 | mean=[.485, .456, .406], 302 | std=[.229, .224, .225] 303 | )] 304 | ) 305 | 306 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 307 | 308 | color_group = [im] * 3 309 | rst = trans(color_group) 310 | 311 | gray_group = [im.convert('L')] * 9 312 | gray_rst = trans(gray_group) 313 | 314 | trans2 = torchvision.transforms.Compose([ 315 | GroupRandomSizedCrop(256), 316 | Stack(), 317 | ToTorchFormatTensor(), 318 | GroupNormalize( 319 | mean=[.485, .456, .406], 320 | std=[.229, .224, .225]) 321 | ]) 322 | print(trans2(color_group)) 323 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np, os, time 2 | from six.moves import xrange 3 | import logging 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | class Foo(object): 8 | def __init__(self, **kwargs): 9 | self.__dict__.update(kwargs) 10 | def __str__(self): 11 | str_ = '' 12 | for v in vars(self).keys(): 13 | a = getattr(self, v) 14 | str__ = str(a) 15 | str__ = str__.replace('\n', 
'\n ') 16 | str_ += '{:s}: {:s}'.format(v, str__) 17 | str_ += '\n' 18 | return str_ 19 | 20 | 21 | def get_flow(t, theta, map_size): 22 | """ 23 | Rotates the map by theta and translates the rotated map by t. 24 | 25 | Assume that the robot rotates by an angle theta and then moves forward by 26 | translation t. This function returns the flow field. For every pixel in 27 | the new image it tells us which pixel in the original image it came from: 28 | NewI(x, y) = OldI(flow_x(x,y), flow_y(x,y)). 29 | 30 | Assume there is a point p in the original image. Robot rotates by R and moves 31 | forward by t. p1 = Rt*p; p2 = p1 - t (the world moves in the opposite direction). 32 | So, p2 = Rt*p - t, thus p2 came from R*(p2+t), which is what this function 33 | calculates. 34 | 35 | t: ... x 2 (translation for B batches of N motions each). 36 | theta: ... x 1 (rotation for B batches of N motions each). 37 | 38 | Output: ... x map_size x map_size x 2 39 | """ 40 | B = t.view(-1, 2).size()[0] 41 | tx, ty = torch.unbind(t.view(-1, 1, 1, 1, 2), dim=4) # Bx1x1x1 42 | theta = theta.view(-1, 1, 1, 1) 43 | # c = tf.constant((map_size - 1.) / 2., dtype=tf.float32) 44 | c = Variable(torch.Tensor([(map_size[0] - 1.) / 2.]).double()) # assumes a square map 45 | x, y = np.meshgrid(np.arange(map_size[0]), np.arange(map_size[1])) 46 | x = Variable(torch.from_numpy(x).double()).view(1, map_size[0], map_size[1], 1) 47 | y = Variable(torch.from_numpy(y).double()).view(1, map_size[0], map_size[1], 1) 48 | # x = tf.constant(x[np.newaxis, :, :, np.newaxis], dtype=tf.float32, name='x', 49 | # shape=[1, map_size, map_size, 1]) 50 | # y = tf.constant(y[np.newaxis, :, :, np.newaxis], dtype=tf.float32, name='y', 51 | # shape=[1, map_size, map_size, 1]) 52 | 53 | tx = tx - c.expand(tx.size()) 54 | x = x.expand([B] + list(x.size()[1:])) 55 | x = x + tx.expand(x.size()) 56 | ty = ty - c.expand(ty.size()) 57 | y = y.expand([B] + list(y.size()[1:])) 58 | y = y + ty.expand(y.size()) # BxHxWx1 59 | # x = x - (-tx + c.expand(tx.size())) #1xHxWx1 60 | # y = y - (-ty + c.expand(ty.size())) 61 | 62 | sin_theta = torch.sin(theta) #Bx1x1x1 63 | cos_theta = torch.cos(theta) 64 | xr = x * cos_theta.expand(x.size()) - y * sin_theta.expand(y.size()) 65 | yr = x * sin_theta.expand(x.size()) + y * cos_theta.expand(y.size()) # BxHxWx1 66 | # xr = cos_theta * x - sin_theta * y 67 | # yr = sin_theta * x + cos_theta * y 68 | 69 | xr = xr + c.expand(xr.size()) 70 | yr = yr + c.expand(yr.size()) 71 | 72 | flow = torch.stack([xr, yr], dim=-1) 73 | sh = list(t.size()[:-1]) + [map_size[0], map_size[1], 2] 74 | # sh = tf.unstack(tf.shape(t), axis=0) 75 | # sh = tf.stack(sh[:-1] + [tf.constant(_, dtype=tf.int32) for _ in [map_size, map_size, 2]]) 76 | flow = flow.view(sh) 77 | return flow 78 | 79 | 80 | def dense_resample(im, flow_im, output_valid_mask=False): 81 | """ Resample reward at particular locations. 82 | Args: 83 | im: ...xHxW matrix to sample from. 84 | flow_im: ...xHxWx2 matrix, samples the image using absolute offsets as given 85 | by the flow_im. 
86 | """ 87 | valid_mask = None 88 | 89 | x, y = torch.unbind(flow_im, dim=-1) 90 | x = x.view(-1) 91 | y = y.view(-1) 92 | 93 | # constants 94 | # shape = tf.unstack(tf.shape(im)) 95 | # channels = shape[-1] 96 | shape = im.size() 97 | width = shape[-1] 98 | height = shape[-2] 99 | num_batch = 1 100 | for dim in shape[:-2]: 101 | num_batch *= dim 102 | zero = Variable(torch.Tensor([0]).double()) 103 | # num_batch = tf.cast(tf.reduce_prod(tf.stack(shape[:-3])), 'int32') 104 | # zero = tf.constant(0, dtype=tf.int32) 105 | 106 | # Round up and down. 107 | x0 = torch.floor(x) 108 | x1 = x0 + 1 109 | y0 = torch.floor(y) 110 | y1 = y0 + 1 111 | 112 | x0 = x0.clamp(0, width - 1) 113 | x1 = x1.clamp(0, width - 1) 114 | y0 = y0.clamp(0, height - 1) 115 | y1 = y1.clamp(0, height - 1) 116 | dim2 = width 117 | dim1 = width * height 118 | 119 | # Create base index 120 | base = torch.arange(num_batch).float() * dim1 121 | base = base.view(-1, 1) 122 | # base = tf.reshape(tf.range(num_batch) * dim1, shape=[-1, 1]) 123 | base = base.expand(base.size()[0], height * width).view(-1) # batch_size * H * W 124 | # base = tf.reshape(tf.tile(base, [1, height * width]), shape=[-1]) 125 | 126 | base_y0 = base + y0.expand(base.size()) * dim2 127 | base_y1 = base + y1.expand(base.size()) * dim2 128 | idx_a = base_y0 + x0.expand(base_y0.size()) 129 | idx_b = base_y1 + x0.expand(base_y1.size()) 130 | idx_c = base_y0 + x1.expand(base_y0.size()) 131 | idx_d = base_y1 + x1.expand(base_y1.size()) 132 | 133 | # use indices to lookup pixels in the flat image and restore channels dim 134 | # sh = tf.stack([tf.constant(-1, dtype=tf.int32), channels]) 135 | im_flat = im.view(-1) 136 | # im_flat = tf.cast(tf.reshape(im, sh), dtype=tf.float32) 137 | pixel_a = torch.gather(im_flat, 0, idx_a.long()) 138 | pixel_b = torch.gather(im_flat, 0, idx_b.long()) 139 | pixel_c = torch.gather(im_flat, 0, idx_c.long()) 140 | pixel_d = torch.gather(im_flat, 0, idx_d.long()) 141 | 142 | # and finally calculate interpolated values 143 | # x1_f = tf.to_float(x1) 144 | # y1_f = tf.to_float(y1) 145 | x1_f = x1.float() 146 | y1_f = y1.float() 147 | 148 | wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1) 149 | wb = torch.unsqueeze(((x1_f - x) * (1.0 - (y1_f - y))), 1) 150 | wc = torch.unsqueeze(((1.0 - (x1_f - x)) * (y1_f - y)), 1) 151 | wd = torch.unsqueeze(((1.0 - (x1_f - x)) * (1.0 - (y1_f - y))), 1) 152 | 153 | output = wa * pixel_a.unsqueeze(1) + wb * pixel_b.unsqueeze(1) + wc * pixel_c.unsqueeze(1) + wd * pixel_d.unsqueeze(1) 154 | # output = tf.reshape(output, shape=tf.shape(im)) 155 | output = output.view(im.size()) 156 | return output, valid_mask 157 | --------------------------------------------------------------------------------
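
For quick reference, the group transforms in `transform.py` are composed the same way the training scripts build their dataset transforms: a list of per-view PIL images is stacked along the channel axis, converted to a CHW float tensor, and normalized per channel. Below is a minimal, illustrative sketch of that pipeline; the ImageNet mean/std values and the six-view input are assumptions for this example only (the scripts pass their own `mean_rgb`/`std_rgb` and `--n-views`).

```
import torchvision
from PIL import Image
from transform import Stack, ToTorchFormatTensor, GroupNormalize

# Assumed normalization statistics and view count -- illustrative only.
mean_rgb = [0.485, 0.456, 0.406]
std_rgb = [0.229, 0.224, 0.225]
n_views = 6

preprocess = torchvision.transforms.Compose([
    Stack(roll=True),                  # concatenate the views along the channel axis (RGB -> BGR)
    ToTorchFormatTensor(div=True),     # HWC uint8 array -> CHW float tensor in [0, 1]
    GroupNormalize(mean_rgb, std_rgb)  # per-channel normalization, repeated across views
])

# One dummy 400x400 image per view, matching --input-resolution 400 in the README commands.
views = [Image.new('RGB', (400, 400)) for _ in range(n_views)]
tensor = preprocess(views)
print(tensor.shape)  # torch.Size([3 * n_views, 400, 400]) -> [18, 400, 400]
```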