├── README.md
├── datasets.py
├── metadata
│   ├── colormap_coarse.csv
│   ├── train_list.txt
│   └── val_list.txt
├── models.py
├── opts.py
├── segmentTool
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── models.cpython-36.pyc
│   │   ├── resnet.cpython-36.pyc
│   │   └── resnext.cpython-36.pyc
│   ├── lib
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   └── __init__.cpython-36.pyc
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   │   ├── batchnorm.cpython-36.pyc
│   │   │   │   │   ├── comm.cpython-36.pyc
│   │   │   │   │   └── replicate.cpython-36.pyc
│   │   │   │   ├── batchnorm.py
│   │   │   │   ├── comm.py
│   │   │   │   ├── replicate.py
│   │   │   │   ├── tests
│   │   │   │   │   ├── test_numeric_batchnorm.py
│   │   │   │   │   └── test_sync_batchnorm.py
│   │   │   │   └── unittest.py
│   │   │   └── parallel
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   ├── __init__.cpython-36.pyc
│   │   │       │   └── data_parallel.cpython-36.pyc
│   │   │       └── data_parallel.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-36.pyc
│   │       │   └── th.cpython-36.pyc
│   │       ├── data
│   │       │   ├── __init__.py
│   │       │   ├── __pycache__
│   │       │   │   ├── __init__.cpython-36.pyc
│   │       │   │   ├── dataloader.cpython-36.pyc
│   │       │   │   ├── dataset.cpython-36.pyc
│   │       │   │   └── sampler.cpython-36.pyc
│   │       │   ├── dataloader.py
│   │       │   ├── dataset.py
│   │       │   ├── distributed.py
│   │       │   └── sampler.py
│   │       └── th.py
│   ├── models.py
│   ├── resnet.py
│   └── resnext.py
├── test.py
├── test_carla.py
├── test_seq.py
├── tools
│   ├── get_trainning_data_from_house3d.py
│   ├── process_data_for_VPN.py
│   ├── process_transfer_driving_data.py
│   └── process_transfer_indoor_data.py
├── train.py
├── train_carla.py
├── train_transfer.py
├── transform.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-view Semantic Segmentation for Sensing Surroundings
2 |
3 | We release the code of the View Parsing Network (VPN), the main model for the cross-view semantic segmentation task.
4 |
5 |
6 | **[Cross-View Semantic Segmentation for Sensing Surroundings
7 | ](https://arxiv.org/pdf/1906.03560.pdf)**
8 |
9 | [Bowen Pan](http://people.csail.mit.edu/bpan/),
10 | [Jiankai Sun](),
11 | [Ho Yin Tiga Leung](),
12 | [Alex Andonian](https://www.alexandonian.com/), and
13 | [Bolei Zhou](http://bzhou.ie.cuhk.edu.hk/)
14 |
15 | IEEE Robotics and Automation Letters
16 |
17 | In IEEE International Conference on Intelligent Robots and Systems (IROS) 2020
18 |
19 | [[Paper]](https://arxiv.org/pdf/1906.03560.pdf)
20 | [[Project Page]](https://decisionforce.github.io/VPN/)
21 |
22 | ```
23 | @ARTICLE{pan2019crossview,
24 | author={B. {Pan} and J. {Sun} and H. Y. T. {Leung} and A. {Andonian} and B. {Zhou}},
25 | journal={IEEE Robotics and Automation Letters},
26 | title={Cross-View Semantic Segmentation for Sensing Surroundings},
27 | year={2020},
28 | volume={5},
29 | number={3},
30 | pages={4867-4873},
31 | }
32 | ```
33 |
34 |
35 | ### Requirement
36 | - Install the [House3D](https://github.com/facebookresearch/House3D) simulator or the [Gibson](http://gibsonenv.stanford.edu) simulator.
37 | - Software: Ubuntu 16.04.3 LTS, CUDA>=8.0, Python>=3.5, PyTorch>=0.4.0
38 |
39 | ## Train and test VPN
40 |
41 | ### Data processing (using House3D as an example)
42 | - Use [get_trainning_data_from_house3d.py](https://github.com/pbw-Berwin/View-Parsing-Network/blob/master/tools/get_trainning_data_from_house3d.py) to extract data from the House3D environment.
43 | - Use [process_data_for_VPN.py](https://github.com/pbw-Berwin/View-Parsing-Network/blob/master/tools/process_data_for_VPN.py) to split the training and validation sets and to visualize the data (a colormap-based visualization sketch is shown below).
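
For visualization, each semantic class can be mapped to the RGB triplet listed in `metadata/colormap_coarse.csv`. The snippet below is a minimal sketch (it is not one of the released scripts) and assumes that class indices follow the row order of that CSV, which may differ from your label encoding.

```python
import csv
import numpy as np

def load_colormap(path='metadata/colormap_coarse.csv'):
    """Read the class colormap into a (num_class, 3) uint8 array."""
    colors = []
    with open(path) as f:
        for row in csv.DictReader(f):
            colors.append((int(row['r']), int(row['g']), int(row['b'])))
    return np.array(colors, dtype=np.uint8)

def colorize(label_map, colormap):
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    return colormap[label_map]
```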
44 |
45 | ### Training Command
46 | ```
47 | # Training in indoor-room scenarios, using RGB input modality, with 8 input views.
48 | python -u train.py --fc-dim 256 --use-depth false --use-mask false --transform-type fc --input-resolution 400 --label-res 25 --store-name [STORE_NAME] --n-views 8 --batch-size 48 -j 10 --data_root [PATH_TO_DATASET_ROOT] --train-list [PATH_TO_TRAIN_LIST] --eval-list [PATH_TO_EVAL_LIST]
49 |
50 | # Training in driving-traffic scenarios, using RGB input modality, with 6 input views.
51 | python -u train_carla.py --fc-dim 256 --use-depth false --use-mask false --transform-type fc --input-resolution 400 --label-res 25 --store-name [STORE_NAME] --n-views 6 --batch-size 48 -j 10 --data_root [PATH_TO_DATASET_ROOT] --train-list [PATH_TO_TRAIN_LIST] --eval-list [PATH_TO_EVAL_LIST]
52 | ```
53 |
54 | ### Testing Command
55 | ```
56 | # Testing in indoor-room scenarios, using RGB input modality, with 8 input views.
57 | python -u test.py --fc-dim 256 --use-depth false --use-mask false --transform-type fc --input-resolution 400 --label-res 25 --store-name [STORE_NAME] --n-views 8 --batch-size 4 --test-views 8 --data_root [PATH_TO_DATASET_ROOT] --eval-list [PATH_TO_EVAL_LIST] --num-class [NUM_CLASS] -j 10 --weights [PATH_TO_PRETRAIN_MODEL]
58 |
59 | # Testing in driving-traffic scenarios, using RGB input modality, with 6 input views.
60 | python -u test_carla.py --fc-dim 256 --use-depth false --use-mask false --transform-type fc --input-resolution 400 --label-res 25 --store-name [STORE_NAME] --n-views 6 --batch-size 4 --test-views 6 --data_root [PATH_TO_DATASET_ROOT] --eval-list [PATH_TO_EVAL_LIST] --num-class [NUM_CLASS] -j 10 --weights [PATH_TO_PRETRAIN_MODEL]
61 | ```
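
For quick experimentation outside the provided scripts, the VPN model in `models.py` can also be driven directly. The sketch below is illustrative only: the configuration fields mirror the attributes read by `VPNModel`, the values come from the defaults in `opts.py`, and the checkpoint path is a placeholder.

```python
from argparse import Namespace
import torch
from models import VPNModel  # requires the segmentTool package from this repository

# Defaults taken from opts.py (8 views, 400x400 RGB inputs, 25x25 top-view labels).
config = Namespace(num_views=8, output_size=25, transform_type='fc',
                   encoder='resnet18', decoder='ppm_bilinear',
                   fc_dim=256, num_class=94)
model = VPNModel(config).eval()

# state = torch.load('PATH_TO_PRETRAIN_MODEL', map_location='cpu')
# model.load_state_dict(state['state_dict'])  # key layout may differ per checkpoint

with torch.no_grad():
    views = torch.randn(1, 8 * 3, 400, 400)   # 8 RGB views stacked along the channel dim
    topdown = model(views)                     # (B, H, W, num_class) logits
```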
62 |
63 | ## Transfer learning for sim-to-real adaptation
64 |
65 | ### Data processing (using indoor-room scenarios as an example)
66 | - Use [process_transfer_indoor_data.py](https://github.com/pbw-Berwin/View-Parsing-Network/blob/master/tools/process_transfer_indoor_data.py) to split the source-domain set and the target-domain set.
67 |
68 | ### Training Command
69 | ```
70 | # Training in indoor-room scenarios, using RGB input modality, with 8 input views.
71 | python -u train_transfer.py --task-id [TASK_NAME] --num-class [NUM_CLASS] --learning-rate-D 3e-6 --iter-size-G 1 --iter-size-D 1 --snapshot-dir ./snapshot --batch-size 20 --tensorboard true --n-views 6 --train_source_list [PATH_TO_TRAIN_LIST] --train_target_list [PATH_TO_EVAL_LIST] --VPN-weights [PATH_TO_PRETRAINED_WEIGHT] --scenarios indoor
72 | ```
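
The adaptation stage trains the VPN against the `FCDiscriminator` defined in `models.py`. The sketch below is only a schematic of that adversarial objective, not the code in `train_transfer.py` (which additionally handles iteration sizes, learning-rate schedules, and snapshots); the domain labels and the absence of loss weighting are assumptions.

```python
import torch
import torch.nn.functional as F
from models import FCDiscriminator

D = FCDiscriminator(num_classes=94)   # takes per-class prediction maps as input
SOURCE, TARGET = 0.0, 1.0             # assumed domain labels

def discriminator_loss(pred_source, pred_target):
    # pred_*: (B, num_class, H, W) softmax outputs of the VPN on each domain
    out_s = D(pred_source.detach())
    out_t = D(pred_target.detach())
    return (F.binary_cross_entropy_with_logits(out_s, torch.full_like(out_s, SOURCE)) +
            F.binary_cross_entropy_with_logits(out_t, torch.full_like(out_t, TARGET)))

def generator_adv_loss(pred_target):
    # push target-domain predictions to be classified as source-domain
    out_t = D(pred_target)
    return F.binary_cross_entropy_with_logits(out_t, torch.full_like(out_t, SOURCE))
```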
73 |
--------------------------------------------------------------------------------
/metadata/colormap_coarse.csv:
--------------------------------------------------------------------------------
1 | name,r,g,b
2 | ottoman,0,0,0
3 | storage_bench,255,255,0
4 | mirror,28,230,255
5 | shower,255,52,255
6 | kitchen_appliance,255,74,70
7 | sofa,0,137,65
8 | outdoor_lamp,0,111,166
9 | tripod,163,0,89
10 | toilet,122,73,0
11 | cart,0,0,166
12 | Ground,99,255,172
13 | decoration,183,151,98
14 | arch,0,77,67
15 | unknown,143,176,255
16 | pet,153,125,135
17 | computer,90,0,7
18 | stairs,128,150,147
19 | bed,254,255,230
20 | heater,27,68,0
21 | plant,79,198,1
22 | mailbox,59,93,255
23 | fence,74,59,83
24 | shelving,255,47,128
25 | safe,97,97,90
26 | person,186,9,0
27 | desk,107,121,0
28 | television,0,194,160
29 | stand,255,170,146
30 | table_and_chair,255,144,201
31 | whiteboard,185,3,170
32 | ceiling,209,97,0
33 | coffin,221,239,255
34 | window,0,0,53
35 | drinkbar,123,79,75
36 | picture_frame,161,194,153
37 | chair,48,0,24
38 | wood_board,10,166,216
39 | headstone,1,51,73
40 | OTHER,0,132,111
41 | wardrobe_cabinet,55,33,1
42 | floor,255,181,0
43 | fireplace,194,255,237
44 | cloth,160,121,191
45 | empty,204,7,68
46 | tv_stand,194,255,153
47 | garage_door,0,30,9
48 | workplace,0,72,156
49 | table,111,0,98
50 | hanging_kitchen_cabinet,12,189,102
51 | shoes,238,195,255
52 | trash_can,69,109,117
53 | kitchen_set,183,123,104
54 | ATM,122,135,161
55 | bench_chair,120,141,102
56 | hanger,136,85,120
57 | outdoor_seating,250,208,159
58 | bathtub,255,138,154
59 | kitchenware,209,87,160
60 | outdoor_cover,190,196,89
61 | dressing_table,69,102,72
62 | bathroom_stuff,0,134,237
63 | vehicle,136,111,76
64 | candle,52,54,45
65 | dresser,180,168,189
66 | Ceiling,0,166,170
67 | clock,69,44,44
68 | pool,99,99,117
69 | door,163,200,201
70 | fan,255,145,63
71 | magazines,147,138,129
72 | kitchen_cabinet,87,83,41
73 | partition,0,254,207
74 | recreation,176,91,111
75 | trinket,140,208,255
76 | Box,59,151,0
77 | wall,4,247,87
78 | roof,200,161,161
79 | books,30,110,0
80 | gym_equipment,121,0,215
81 | sink,167,117,0
82 | rug,99,103,169
83 | outdoor_spring,160,88,55
84 | indoor_lamp,107,0,44
85 | household_appliance,119,38,0
86 | pillow,215,144,255
87 | air_conditioner,155,151,0
88 | toy,84,158,121
89 | grill,255,246,159
90 | vase,32,22,37
91 | column,114,65,143
92 | switch,188,35,255
93 | shoes_cabinet,153,173,192
94 | curtain,58,36,101
95 | music,146,35,41
96 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : models.py
4 | # Author : Bowen Pan
5 | # Email : panbowen0607@gmail.com
6 | # Date : 09/18/2018
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | import torch
11 | from torch import nn
12 | import utils
13 | from collections import OrderedDict
14 | import torch.nn.functional as F
15 | from segmentTool.models import *
16 | import numpy as np
17 | from itertools import combinations
18 |
19 | builder = ModelBuilder()
20 |
21 | class TransformModule(nn.Module):
22 | def __init__(self, dim=25, num_view=8):
23 | super(TransformModule, self).__init__()
24 | self.num_view = num_view
25 | self.dim = dim
26 | self.mat_list = nn.ModuleList()
27 | for i in range(self.num_view):
28 | fc_transform = nn.Sequential(
29 | nn.Linear(dim * dim, dim * dim),
30 | nn.ReLU(),
31 | nn.Linear(dim * dim, dim * dim),
32 | nn.ReLU()
33 | )
34 | self.mat_list += [fc_transform]
35 |
36 | def forward(self, x):
37 | # shape x: B, V, C, H, W
38 | x = x.view(list(x.size()[:3]) + [self.dim * self.dim,])
39 | view_comb = self.mat_list[0](x[:, 0])
40 | for index in range(x.size(1))[1:]:
41 | view_comb += self.mat_list[index](x[:, index])
42 | view_comb = view_comb.view(list(view_comb.size()[:2]) + [self.dim, self.dim])
43 | return view_comb
44 |
45 |
46 | class SumModule(nn.Module):
47 | def __init__(self):
48 | super(SumModule, self).__init__()
49 |
50 | def forward(self, x):
51 | # shape x: B, V, C, H, W
52 | x = torch.sum(x, dim=1, keepdim=False)
53 | return x
54 |
55 |
56 | class VPNModel(nn.Module):
57 | def __init__(self, config):
58 | super(VPNModel, self).__init__()
59 | self.num_views = config.num_views
60 | self.output_size = config.output_size
61 | self.transform_type = config.transform_type
62 | print('Views number: ' + str(self.num_views))
63 | print('Transform Type: ', self.transform_type)
64 | self.encoder = builder.build_encoder(
65 | arch=config.encoder,
66 | fc_dim=config.fc_dim,
67 | )
68 | if self.transform_type == 'fc':
69 | self.transform_module = TransformModule(dim=self.output_size, num_view=self.num_views)
70 | elif self.transform_type == 'sum':
71 | self.transform_module = SumModule()
72 | self.decoder = builder.build_decoder(
73 | arch=config.decoder,
74 | fc_dim=config.fc_dim,
75 | num_class=config.num_class,
76 | use_softmax=False,
77 | )
78 |
79 | def forward(self, x, return_feat=False):
80 | B, N, C, H, W = x.view([-1, self.num_views, int(x.size()[1] / self.num_views)] \
81 | + list(x.size()[2:])).size()
82 |
83 | x = x.view(B*N, C, H, W)
84 | x = self.encoder(x)[0]
85 | x = x.view([B, N] + list(x.size()[1:]))
86 | x = self.transform_module(x)
87 | if return_feat:
88 | x, feat = self.decoder([x], return_feat=return_feat)
89 | else:
90 | x = self.decoder([x])
91 | x = x.transpose(1,2).transpose(2,3).contiguous()
92 | if return_feat:
93 | feat = feat.transpose(1,2).transpose(2,3).contiguous()
94 | return x, feat
95 | return x
96 |
97 |
98 |
99 | class FCDiscriminator(nn.Module):
100 | def __init__(self, num_classes, ndf = 64):
101 | super(FCDiscriminator, self).__init__()
102 |
103 | self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
104 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
105 | self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
106 | self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
107 | self.classifier = nn.Conv2d(ndf*8, 1, kernel_size=4, stride=2, padding=1)
108 |
109 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
110 | #self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
111 | #self.sigmoid = nn.Sigmoid()
112 |
113 |
114 | def forward(self, x):
115 | x = self.conv1(x)
116 | x = self.leaky_relu(x)
117 | x = self.conv2(x)
118 | x = self.leaky_relu(x)
119 | x = self.conv3(x)
120 | x = self.leaky_relu(x)
121 | x = self.conv4(x)
122 | x = self.leaky_relu(x)
123 | x = self.classifier(x)
124 | #x = self.up_sample(x)
125 | #x = self.sigmoid(x)
126 | return x
127 |
128 |
129 |
130 |
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def str2bool(v):
4 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
5 | return True
6 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
7 | return False
8 | else:
9 | raise argparse.ArgumentTypeError('Unsupported value encountered.')
10 |
11 | parser = argparse.ArgumentParser(description="PyTorch implementation of the View Parsing Network (VPN)")
12 | parser.add_argument('--data_root', type=str, default='/data/vision/oliva/scenedataset/syntheticscene/TopViewMaskDataset')
13 | parser.add_argument('--test-dir', type=str, default='')
14 | parser.add_argument('--train-list', type=str, default='./metadata/train_list.txt')
15 | parser.add_argument('--eval-list', type=str, default='./metadata/val_list.txt')
16 | parser.add_argument('--start-lr', type=float, default=2e-4)
17 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
18 | metavar='W', help='weight decay (default: 5e-4)')
19 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
20 | help='momentum')
21 | parser.add_argument('--transform-type', type=str, default='fc')
22 | parser.add_argument('--use-mask', type=str2bool, nargs='?', const=False)
23 | parser.add_argument('--use-depth', type=str2bool, nargs='?', const=False)
24 | parser.add_argument('--input-resolution', default=400, type=int, metavar='N')
25 | parser.add_argument('--label-resolution', default=25, type=int, metavar='N')
26 | parser.add_argument('--fc-dim', default=256, type=int, metavar='N')
27 | parser.add_argument('--segSize', default=256, type=int, metavar='N')
28 | parser.add_argument('--log-root', type=str, default='./log')
29 | parser.add_argument('--encoder', type=str, default='resnet18')
30 | parser.add_argument('--decoder', type=str, default='ppm_bilinear')
31 | parser.add_argument('--store-name', type=str, default='')
32 | parser.add_argument('--start_epoch', type=int, default=0)
33 | parser.add_argument('--epochs', type=int, default=40)
34 | parser.add_argument('--n-views', type=int, default=8)
35 | parser.add_argument('--lr_steps', default=[10], type=float, nargs="+",
36 | metavar='LRSteps', help='epochs to decay learning rate by 10')
37 | parser.add_argument('--print-freq', '-p', default=10, type=int,
38 | metavar='N', help='print frequency (default: 10)')
39 | parser.add_argument('--print-img-freq', default=100, type=int,
40 | metavar='N', help='print frequency (default: 100)')
41 | parser.add_argument('--ckpt-freq', default=1, type=int,
42 | metavar='N', help='save frequency (default: 1)')
43 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
44 | help='evaluate model on validation set')
45 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
46 | help='path to latest checkpoint (default: none)')
47 | parser.add_argument('--tflogdir', default='', type=str, metavar="PATH")
48 | parser.add_argument('--logname', type=str, default="")
49 | parser.add_argument('-b', '--batch-size', default=104, type=int,
50 | metavar='N', help='mini-batch size (default: 104)')
51 | parser.add_argument('--scale_size', default=224, type=int)
52 | parser.add_argument('--num-class', default=94, type=int)
53 | parser.add_argument('-j', '--num_workers', default=4, type=int, metavar='N',
54 | help='number of data loading workers (default: 4)')
55 | parser.add_argument('--root_model', type=str, default='model')
56 | parser.add_argument('--weights', type=str,
57 | default='')
58 | parser.add_argument('--visualize', type=str, default='./visualize')
59 | parser.add_argument('--centralSize', type=int, default=12)
60 | parser.add_argument('--mapSize', type=int, default=1000)
61 | parser.add_argument('--ppi', type=int, default=4)
62 | parser.add_argument('--scale', type=float, default=2)
63 | parser.add_argument('--trajectory-file', type=str, default='')
64 | parser.add_argument('--real-scale', type=float, default=1.2)
65 | parser.add_argument('--use-topdown', type=str2bool, default=False)
66 | parser.add_argument('--visual-input', type=str2bool, default=False)
67 |
--------------------------------------------------------------------------------
/segmentTool/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import ModelBuilder, SegmentationModule
2 |
--------------------------------------------------------------------------------
/segmentTool/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/__pycache__/resnext.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/__pycache__/resnext.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
3 |
--------------------------------------------------------------------------------
/segmentTool/lib/nn/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/__pycache__/batchnorm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/modules/__pycache__/batchnorm.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/__pycache__/comm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/modules/__pycache__/comm.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/__pycache__/replicate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/modules/__pycache__/replicate.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 | """sum over the first and last dimention"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 | """add new dementions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | # custom batch norm statistics
49 | self._moving_average_fraction = 1. - momentum
50 | self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
51 | self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
52 | self.register_buffer('_running_iter', torch.ones(1))
53 | self._tmp_running_mean = self.running_mean.clone() * self._running_iter
54 | self._tmp_running_var = self.running_var.clone() * self._running_iter
55 |
56 | def forward(self, input):
57 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
58 | if not (self._is_parallel and self.training):
59 | return F.batch_norm(
60 | input, self.running_mean, self.running_var, self.weight, self.bias,
61 | self.training, self.momentum, self.eps)
62 |
63 | # Resize the input to (B, C, -1).
64 | input_shape = input.size()
65 | input = input.view(input.size(0), self.num_features, -1)
66 |
67 | # Compute the sum and square-sum.
68 | sum_size = input.size(0) * input.size(2)
69 | input_sum = _sum_ft(input)
70 | input_ssum = _sum_ft(input ** 2)
71 |
72 | # Reduce-and-broadcast the statistics.
73 | if self._parallel_id == 0:
74 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
75 | else:
76 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
77 |
78 | # Compute the output.
79 | if self.affine:
80 | # MJY:: Fuse the multiplication for speed.
81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
82 | else:
83 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
84 |
85 | # Reshape it.
86 | return output.view(input_shape)
87 |
88 | def __data_parallel_replicate__(self, ctx, copy_id):
89 | self._is_parallel = True
90 | self._parallel_id = copy_id
91 |
92 | # parallel_id == 0 means master device.
93 | if self._parallel_id == 0:
94 | ctx.sync_master = self._sync_master
95 | else:
96 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
97 |
98 | def _data_parallel_master(self, intermediates):
99 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
101 |
102 | to_reduce = [i[1][:2] for i in intermediates]
103 | to_reduce = [j for i in to_reduce for j in i] # flatten
104 | target_gpus = [i[1].sum.get_device() for i in intermediates]
105 |
106 | sum_size = sum([i[1].sum_size for i in intermediates])
107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
108 |
109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
110 |
111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
112 |
113 | outputs = []
114 | for i, rec in enumerate(intermediates):
115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
116 |
117 | return outputs
118 |
119 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
120 | """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
121 | return dest * alpha + delta * beta + bias
122 |
123 | def _compute_mean_std(self, sum_, ssum, size):
124 | """Compute the mean and standard-deviation with sum and square-sum. This method
125 | also maintains the moving average on the master device."""
126 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
127 | mean = sum_ / size
128 | sumvar = ssum - sum_ * mean
129 | unbias_var = sumvar / (size - 1)
130 | bias_var = sumvar / size
131 |
132 | self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
133 | self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
134 | self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
135 |
136 | self.running_mean = self._tmp_running_mean / self._running_iter
137 | self.running_var = self._tmp_running_var / self._running_iter
138 |
139 | return mean, bias_var.clamp(self.eps) ** -0.5
140 |
141 |
142 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
143 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
144 | mini-batch.
145 |
146 | .. math::
147 |
148 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
149 |
150 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
151 | standard-deviation are reduced across all devices during training.
152 |
153 | For example, when one uses `nn.DataParallel` to wrap the network during
154 | training, PyTorch's implementation normalizes the tensor on each device using
155 | the statistics only on that device, which accelerates the computation and
156 | is also easy to implement, but the statistics might be inaccurate.
157 | Instead, in this synchronized version, the statistics will be computed
158 | over all training samples distributed on multiple devices.
159 |
160 | Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
161 | as the built-in PyTorch implementation.
162 |
163 | The mean and standard-deviation are calculated per-dimension over
164 | the mini-batches and gamma and beta are learnable parameter vectors
165 | of size C (where C is the input size).
166 |
167 | During training, this layer keeps a running estimate of its computed mean
168 | and variance. The running sum is kept with a default momentum of 0.1.
169 |
170 | During evaluation, this running mean/variance is used for normalization.
171 |
172 | Because the BatchNorm is done over the `C` dimension, computing statistics
173 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
174 |
175 | Args:
176 | num_features: num_features from an expected input of size
177 | `batch_size x num_features [x width]`
178 | eps: a value added to the denominator for numerical stability.
179 | Default: 1e-5
180 | momentum: the value used for the running_mean and running_var
181 | computation. Default: 0.1
182 | affine: a boolean value that when set to ``True``, gives the layer learnable
183 | affine parameters. Default: ``True``
184 |
185 | Shape:
186 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
187 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
188 |
189 | Examples:
190 | >>> # With Learnable Parameters
191 | >>> m = SynchronizedBatchNorm1d(100)
192 | >>> # Without Learnable Parameters
193 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
194 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
195 | >>> output = m(input)
196 | """
197 |
198 | def _check_input_dim(self, input):
199 | if input.dim() != 2 and input.dim() != 3:
200 | raise ValueError('expected 2D or 3D input (got {}D input)'
201 | .format(input.dim()))
202 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
203 |
204 |
205 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
206 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
207 | of 3d inputs
208 |
209 | .. math::
210 |
211 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
212 |
213 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
214 | standard-deviation are reduced across all devices during training.
215 |
216 | For example, when one uses `nn.DataParallel` to wrap the network during
217 | training, PyTorch's implementation normalizes the tensor on each device using
218 | the statistics only on that device, which accelerates the computation and
219 | is also easy to implement, but the statistics might be inaccurate.
220 | Instead, in this synchronized version, the statistics will be computed
221 | over all training samples distributed on multiple devices.
222 |
223 | Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
224 | as the built-in PyTorch implementation.
225 |
226 | The mean and standard-deviation are calculated per-dimension over
227 | the mini-batches and gamma and beta are learnable parameter vectors
228 | of size C (where C is the input size).
229 |
230 | During training, this layer keeps a running estimate of its computed mean
231 | and variance. The running sum is kept with a default momentum of 0.1.
232 |
233 | During evaluation, this running mean/variance is used for normalization.
234 |
235 | Because the BatchNorm is done over the `C` dimension, computing statistics
236 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
237 |
238 | Args:
239 | num_features: num_features from an expected input of
240 | size batch_size x num_features x height x width
241 | eps: a value added to the denominator for numerical stability.
242 | Default: 1e-5
243 | momentum: the value used for the running_mean and running_var
244 | computation. Default: 0.1
245 | affine: a boolean value that when set to ``True``, gives the layer learnable
246 | affine parameters. Default: ``True``
247 |
248 | Shape:
249 | - Input: :math:`(N, C, H, W)`
250 | - Output: :math:`(N, C, H, W)` (same shape as input)
251 |
252 | Examples:
253 | >>> # With Learnable Parameters
254 | >>> m = SynchronizedBatchNorm2d(100)
255 | >>> # Without Learnable Parameters
256 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
257 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
258 | >>> output = m(input)
259 | """
260 |
261 | def _check_input_dim(self, input):
262 | if input.dim() != 4:
263 | raise ValueError('expected 4D input (got {}D input)'
264 | .format(input.dim()))
265 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
266 |
267 |
268 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
269 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
270 | of 4d inputs
271 |
272 | .. math::
273 |
274 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
275 |
276 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
277 | standard-deviation are reduced across all devices during training.
278 |
279 | For example, when one uses `nn.DataParallel` to wrap the network during
280 | training, PyTorch's implementation normalizes the tensor on each device using
281 | the statistics only on that device, which accelerates the computation and
282 | is also easy to implement, but the statistics might be inaccurate.
283 | Instead, in this synchronized version, the statistics will be computed
284 | over all training samples distributed on multiple devices.
285 |
286 | Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
287 | as the built-in PyTorch implementation.
288 |
289 | The mean and standard-deviation are calculated per-dimension over
290 | the mini-batches and gamma and beta are learnable parameter vectors
291 | of size C (where C is the input size).
292 |
293 | During training, this layer keeps a running estimate of its computed mean
294 | and variance. The running sum is kept with a default momentum of 0.1.
295 |
296 | During evaluation, this running mean/variance is used for normalization.
297 |
298 | Because the BatchNorm is done over the `C` dimension, computing statistics
299 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
300 | or Spatio-temporal BatchNorm
301 |
302 | Args:
303 | num_features: num_features from an expected input of
304 | size batch_size x num_features x depth x height x width
305 | eps: a value added to the denominator for numerical stability.
306 | Default: 1e-5
307 | momentum: the value used for the running_mean and running_var
308 | computation. Default: 0.1
309 | affine: a boolean value that when set to ``True``, gives the layer learnable
310 | affine parameters. Default: ``True``
311 |
312 | Shape:
313 | - Input: :math:`(N, C, D, H, W)`
314 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
315 |
316 | Examples:
317 | >>> # With Learnable Parameters
318 | >>> m = SynchronizedBatchNorm3d(100)
319 | >>> # Without Learnable Parameters
320 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
321 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
322 | >>> output = m(input)
323 | """
324 |
325 | def _check_input_dim(self, input):
326 | if input.dim() != 5:
327 | raise ValueError('expected 5D input (got {}D input)'
328 | .format(input.dim()))
329 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
330 |
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 | - During the replication, as data parallel will trigger a callback on each module, all slave devices should
60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected
62 | and passed to a registered callback.
63 | - After receiving the messages, the master device should gather the information and determine the message to be passed
64 | back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def register_slave(self, identifier):
79 | """
80 | Register a slave device.
81 |
82 | Args:
83 | identifier: an identifier, usually is the device id.
84 |
85 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 |
87 | """
88 | if self._activated:
89 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 | self._activated = False
91 | self._registry.clear()
92 | future = FutureResult()
93 | self._registry[identifier] = _MasterRegistry(future)
94 | return SlavePipe(identifier, self._queue, future)
95 |
96 | def run_master(self, master_msg):
97 | """
98 | Main entry for the master device in each forward pass.
99 | The messages are first collected from each device (including the master device), and then
100 | a callback will be invoked to compute the message to be sent back to each device
101 | (including the master device).
102 |
103 | Args:
104 | master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |
107 | Returns: the message to be sent back to the master device.
108 |
109 | """
110 | self._activated = True
111 |
112 | intermediates = [(0, master_msg)]
113 | for i in range(self.nr_slaves):
114 | intermediates.append(self._queue.get())
115 |
116 | results = self._master_callback(intermediates)
117 | assert results[0][0] == 0, 'The first result should belong to the master.'
118 |
119 | for i, res in results:
120 | if i == 0:
121 | continue
122 | self._registry[i].result.put(res)
123 |
124 | for i in range(self.nr_slaves):
125 | assert self._queue.get() is True
126 |
127 | return results[0][1]
128 |
129 | @property
130 | def nr_slaves(self):
131 | return len(self._registry)
132 |
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 | Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55 | original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/tests/test_numeric_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_numeric_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm.unittest import TorchTestCase
16 |
17 |
18 | def handy_var(a, unbias=True):
19 | n = a.size(0)
20 | asum = a.sum(dim=0)
21 | as_sum = (a ** 2).sum(dim=0) # a square sum
22 | sumvar = as_sum - asum * asum / n
23 | if unbias:
24 | return sumvar / (n - 1)
25 | else:
26 | return sumvar / n
27 |
28 |
29 | class NumericTestCase(TorchTestCase):
30 | def testNumericBatchNorm(self):
31 | a = torch.rand(16, 10)
32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False)
33 | bn.train()
34 |
35 | a_var1 = Variable(a, requires_grad=True)
36 | b_var1 = bn(a_var1)
37 | loss1 = b_var1.sum()
38 | loss1.backward()
39 |
40 | a_var2 = Variable(a, requires_grad=True)
41 | a_mean2 = a_var2.mean(dim=0, keepdim=True)
42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44 | b_var2 = (a_var2 - a_mean2) / a_std2
45 | loss2 = b_var2.sum()
46 | loss2.backward()
47 |
48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49 | self.assertTensorClose(bn.running_var, handy_var(a))
50 | self.assertTensorClose(a_var1.data, a_var2.data)
51 | self.assertTensorClose(b_var1.data, b_var2.data)
52 | self.assertTensorClose(a_var1.grad, a_var2.grad)
53 |
54 |
55 | if __name__ == '__main__':
56 | unittest.main()
57 |
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/tests/test_sync_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_sync_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16 | from sync_batchnorm.unittest import TorchTestCase
17 |
18 |
19 | def handy_var(a, unbias=True):
20 | n = a.size(0)
21 | asum = a.sum(dim=0)
22 | as_sum = (a ** 2).sum(dim=0) # a square sum
23 | sumvar = as_sum - asum * asum / n
24 | if unbias:
25 | return sumvar / (n - 1)
26 | else:
27 | return sumvar / n
28 |
29 |
30 | def _find_bn(module):
31 | for m in module.modules():
32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33 | return m
34 |
35 |
36 | class SyncTestCase(TorchTestCase):
37 | def _syncParameters(self, bn1, bn2):
38 | bn1.reset_parameters()
39 | bn2.reset_parameters()
40 | if bn1.affine and bn2.affine:
41 | bn2.weight.data.copy_(bn1.weight.data)
42 | bn2.bias.data.copy_(bn1.bias.data)
43 |
44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45 | """Check the forward and backward for the customized batch normalization."""
46 | bn1.train(mode=is_train)
47 | bn2.train(mode=is_train)
48 |
49 | if cuda:
50 | input = input.cuda()
51 |
52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53 |
54 | input1 = Variable(input, requires_grad=True)
55 | output1 = bn1(input1)
56 | output1.sum().backward()
57 | input2 = Variable(input, requires_grad=True)
58 | output2 = bn2(input2)
59 | output2.sum().backward()
60 |
61 | self.assertTensorClose(input1.data, input2.data)
62 | self.assertTensorClose(output1.data, output2.data)
63 | self.assertTensorClose(input1.grad, input2.grad)
64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66 |
67 | def testSyncBatchNormNormalTrain(self):
68 | bn = nn.BatchNorm1d(10)
69 | sync_bn = SynchronizedBatchNorm1d(10)
70 |
71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72 |
73 | def testSyncBatchNormNormalEval(self):
74 | bn = nn.BatchNorm1d(10)
75 | sync_bn = SynchronizedBatchNorm1d(10)
76 |
77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78 |
79 | def testSyncBatchNormSyncTrain(self):
80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83 |
84 | bn.cuda()
85 | sync_bn.cuda()
86 |
87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88 |
89 | def testSyncBatchNormSyncEval(self):
90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93 |
94 | bn.cuda()
95 | sync_bn.cuda()
96 |
97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98 |
99 | def testSyncBatchNorm2DSyncTrain(self):
100 | bn = nn.BatchNorm2d(10)
101 | sync_bn = SynchronizedBatchNorm2d(10)
102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103 |
104 | bn.cuda()
105 | sync_bn.cuda()
106 |
107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108 |
109 |
110 | if __name__ == '__main__':
111 | unittest.main()
112 |
--------------------------------------------------------------------------------
/segmentTool/lib/nn/modules/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/segmentTool/lib/nn/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
2 |
--------------------------------------------------------------------------------
/segmentTool/lib/nn/parallel/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/parallel/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/nn/parallel/__pycache__/data_parallel.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/nn/parallel/__pycache__/data_parallel.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/nn/parallel/data_parallel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf8 -*-
2 |
3 | import torch.cuda as cuda
4 | import torch.nn as nn
5 | import torch
6 | from torch.autograd import Variable
7 | import collections
8 | from torch.nn.parallel._functions import Gather
9 |
10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11 |
12 | def async_copy_to(obj, dev, main_stream=None):
13 | if torch.is_tensor(obj):
14 | obj = Variable(obj)
15 | if isinstance(obj, Variable):
16 | v = obj.cuda(dev, non_blocking=True)
17 | if main_stream is not None:
18 | v.data.record_stream(main_stream)
19 | return v
20 | elif isinstance(obj, collections.Mapping):
21 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
22 | elif isinstance(obj, collections.Sequence):
23 | return [async_copy_to(o, dev, main_stream) for o in obj]
24 | else:
25 | return obj
26 |
27 |
28 | def dict_gather(outputs, target_device, dim=0):
29 | """
30 | Gathers variables from different GPUs on a specified device
31 | (-1 means the CPU), with dictionary support.
32 | """
33 | def gather_map(outputs):
34 | out = outputs[0]
35 | if isinstance(out, Variable):
36 | # MJY(20180330) HACK:: force nr_dims > 0
37 | if out.dim() == 0:
38 | outputs = [o.unsqueeze(0) for o in outputs]
39 | return Gather.apply(target_device, dim, *outputs)
40 | elif out is None:
41 | return None
42 | elif isinstance(out, collections.Mapping):
43 | return {k: gather_map([o[k] for o in outputs]) for k in out}
44 | elif isinstance(out, collections.Sequence):
45 | return type(out)(map(gather_map, zip(*outputs)))
46 | return gather_map(outputs)
47 |
48 |
49 | class DictGatherDataParallel(nn.DataParallel):
50 | def gather(self, outputs, output_device):
51 | return dict_gather(outputs, output_device, dim=self.dim)
52 |
53 |
54 | class UserScatteredDataParallel(DictGatherDataParallel):
55 | def scatter(self, inputs, kwargs, device_ids):
56 | assert len(inputs) == 1
57 | inputs = inputs[0]
58 | inputs = _async_copy_stream(inputs, device_ids)
59 | inputs = [[i] for i in inputs]
60 | assert len(kwargs) == 0
61 | kwargs = [{} for _ in range(len(inputs))]
62 |
63 | return inputs, kwargs
64 |
65 |
66 | def user_scattered_collate(batch):
67 | return batch
68 |
69 |
70 | def _async_copy(inputs, device_ids):
71 | nr_devs = len(device_ids)
72 | assert type(inputs) in (tuple, list)
73 | assert len(inputs) == nr_devs
74 |
75 | outputs = []
76 | for i, dev in zip(inputs, device_ids):
77 | with cuda.device(dev):
78 | outputs.append(async_copy_to(i, dev))
79 |
80 | return tuple(outputs)
81 |
82 |
83 | def _async_copy_stream(inputs, device_ids):
84 | nr_devs = len(device_ids)
85 | assert type(inputs) in (tuple, list)
86 | assert len(inputs) == nr_devs
87 |
88 | outputs = []
89 | streams = [_get_stream(d) for d in device_ids]
90 | for i, dev, stream in zip(inputs, device_ids, streams):
91 | with cuda.device(dev):
92 | main_stream = cuda.current_stream()
93 | with cuda.stream(stream):
94 | outputs.append(async_copy_to(i, dev, main_stream=main_stream))
95 | main_stream.wait_stream(stream)
96 |
97 | return outputs
98 |
99 |
100 | """Adapted from: torch/nn/parallel/_functions.py"""
101 | # background streams used for copying
102 | _streams = None
103 |
104 |
105 | def _get_stream(device):
106 | """Gets a background stream for copying between CPU and GPU"""
107 | global _streams
108 | if device == -1:
109 | return None
110 | if _streams is None:
111 | _streams = [None] * cuda.device_count()
112 | if _streams[device] is None: _streams[device] = cuda.Stream(device)
113 | return _streams[device]
114 |
--------------------------------------------------------------------------------
/segmentTool/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .th import *
2 |
--------------------------------------------------------------------------------
/segmentTool/lib/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/utils/__pycache__/th.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbw-Berwin/View-Parsing-Network/9ccfcce978aecbe6c469256013e6ca94a2436b67/segmentTool/lib/utils/__pycache__/th.cpython-36.pyc
--------------------------------------------------------------------------------
/segmentTool/lib/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .dataset import Dataset, TensorDataset, ConcatDataset
3 | from .dataloader import DataLoader
4 |
--------------------------------------------------------------------------------
/segmentTool/lib/utils/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.multiprocessing as multiprocessing
3 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
4 | _remove_worker_pids, _error_if_any_worker_fails
5 | from .sampler import SequentialSampler, RandomSampler, BatchSampler
6 | import signal
7 | import functools
8 | import collections
9 | import re
10 | import sys
11 | import threading
12 | import traceback
13 | from torch._six import string_classes, int_classes
14 | import numpy as np
15 |
16 | if sys.version_info[0] == 2:
17 | import Queue as queue
18 | else:
19 | import queue
20 |
21 |
22 | class ExceptionWrapper(object):
23 | r"Wraps an exception plus traceback to communicate across threads"
24 |
25 | def __init__(self, exc_info):
26 | self.exc_type = exc_info[0]
27 | self.exc_msg = "".join(traceback.format_exception(*exc_info))
28 |
29 |
30 | _use_shared_memory = False
31 | """Whether to use shared memory in default_collate"""
32 |
33 |
34 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
35 | global _use_shared_memory
36 | _use_shared_memory = True
37 |
38 |     # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
39 | # module's handlers are executed after Python returns from C low-level
40 | # handlers, likely when the same fatal signal happened again already.
41 | # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
42 | _set_worker_signal_handlers()
43 |
44 | torch.set_num_threads(1)
45 | torch.manual_seed(seed)
46 | np.random.seed(seed)
47 |
48 | if init_fn is not None:
49 | init_fn(worker_id)
50 |
51 | while True:
52 | r = index_queue.get()
53 | if r is None:
54 | break
55 | idx, batch_indices = r
56 | try:
57 | samples = collate_fn([dataset[i] for i in batch_indices])
58 | except Exception:
59 | data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
60 | else:
61 | data_queue.put((idx, samples))
62 |
63 |
64 | def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
65 | if pin_memory:
66 | torch.cuda.set_device(device_id)
67 |
68 | while True:
69 | try:
70 | r = in_queue.get()
71 | except Exception:
72 | if done_event.is_set():
73 | return
74 | raise
75 | if r is None:
76 | break
77 | if isinstance(r[1], ExceptionWrapper):
78 | out_queue.put(r)
79 | continue
80 | idx, batch = r
81 | try:
82 | if pin_memory:
83 | batch = pin_memory_batch(batch)
84 | except Exception:
85 | out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
86 | else:
87 | out_queue.put((idx, batch))
88 |
89 | numpy_type_map = {
90 | 'float64': torch.DoubleTensor,
91 | 'float32': torch.FloatTensor,
92 | 'float16': torch.HalfTensor,
93 | 'int64': torch.LongTensor,
94 | 'int32': torch.IntTensor,
95 | 'int16': torch.ShortTensor,
96 | 'int8': torch.CharTensor,
97 | 'uint8': torch.ByteTensor,
98 | }
99 |
100 |
101 | def default_collate(batch):
102 | "Puts each data field into a tensor with outer dimension batch size"
103 |
104 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
105 | elem_type = type(batch[0])
106 | if torch.is_tensor(batch[0]):
107 | out = None
108 | if _use_shared_memory:
109 | # If we're in a background process, concatenate directly into a
110 | # shared memory tensor to avoid an extra copy
111 | numel = sum([x.numel() for x in batch])
112 | storage = batch[0].storage()._new_shared(numel)
113 | out = batch[0].new(storage)
114 | return torch.stack(batch, 0, out=out)
115 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
116 | and elem_type.__name__ != 'string_':
117 | elem = batch[0]
118 | if elem_type.__name__ == 'ndarray':
119 | # array of string classes and object
120 | if re.search('[SaUO]', elem.dtype.str) is not None:
121 | raise TypeError(error_msg.format(elem.dtype))
122 |
123 | return torch.stack([torch.from_numpy(b) for b in batch], 0)
124 | if elem.shape == (): # scalars
125 | py_type = float if elem.dtype.name.startswith('float') else int
126 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
127 | elif isinstance(batch[0], int_classes):
128 | return torch.LongTensor(batch)
129 | elif isinstance(batch[0], float):
130 | return torch.DoubleTensor(batch)
131 | elif isinstance(batch[0], string_classes):
132 | return batch
133 | elif isinstance(batch[0], collections.Mapping):
134 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
135 | elif isinstance(batch[0], collections.Sequence):
136 | transposed = zip(*batch)
137 | return [default_collate(samples) for samples in transposed]
138 |
139 | raise TypeError((error_msg.format(type(batch[0]))))
140 |
141 |
142 | def pin_memory_batch(batch):
143 | if torch.is_tensor(batch):
144 | return batch.pin_memory()
145 | elif isinstance(batch, string_classes):
146 | return batch
147 | elif isinstance(batch, collections.Mapping):
148 | return {k: pin_memory_batch(sample) for k, sample in batch.items()}
149 | elif isinstance(batch, collections.Sequence):
150 | return [pin_memory_batch(sample) for sample in batch]
151 | else:
152 | return batch
153 |
154 |
155 | _SIGCHLD_handler_set = False
156 | """Whether SIGCHLD handler is set for DataLoader worker failures. Only one
157 | handler needs to be set for all DataLoaders in a process."""
158 |
159 |
160 | def _set_SIGCHLD_handler():
161 | # Windows doesn't support SIGCHLD handler
162 | if sys.platform == 'win32':
163 | return
164 | # can't set signal in child threads
165 | if not isinstance(threading.current_thread(), threading._MainThread):
166 | return
167 | global _SIGCHLD_handler_set
168 | if _SIGCHLD_handler_set:
169 | return
170 | previous_handler = signal.getsignal(signal.SIGCHLD)
171 | if not callable(previous_handler):
172 | previous_handler = None
173 |
174 | def handler(signum, frame):
175 | # This following call uses `waitid` with WNOHANG from C side. Therefore,
176 | # Python can still get and update the process status successfully.
177 | _error_if_any_worker_fails()
178 | if previous_handler is not None:
179 | previous_handler(signum, frame)
180 |
181 | signal.signal(signal.SIGCHLD, handler)
182 | _SIGCHLD_handler_set = True
183 |
184 |
185 | class DataLoaderIter(object):
186 | "Iterates once over the DataLoader's dataset, as specified by the sampler"
187 |
188 | def __init__(self, loader):
189 | self.dataset = loader.dataset
190 | self.collate_fn = loader.collate_fn
191 | self.batch_sampler = loader.batch_sampler
192 | self.num_workers = loader.num_workers
193 | self.pin_memory = loader.pin_memory and torch.cuda.is_available()
194 | self.timeout = loader.timeout
195 | self.done_event = threading.Event()
196 |
197 | self.sample_iter = iter(self.batch_sampler)
198 |
199 | if self.num_workers > 0:
200 | self.worker_init_fn = loader.worker_init_fn
201 | self.index_queue = multiprocessing.SimpleQueue()
202 | self.worker_result_queue = multiprocessing.SimpleQueue()
203 | self.batches_outstanding = 0
204 | self.worker_pids_set = False
205 | self.shutdown = False
206 | self.send_idx = 0
207 | self.rcvd_idx = 0
208 | self.reorder_dict = {}
209 |
210 | base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
211 | self.workers = [
212 | multiprocessing.Process(
213 | target=_worker_loop,
214 | args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
215 | base_seed + i, self.worker_init_fn, i))
216 | for i in range(self.num_workers)]
217 |
218 | if self.pin_memory or self.timeout > 0:
219 | self.data_queue = queue.Queue()
220 | if self.pin_memory:
221 | maybe_device_id = torch.cuda.current_device()
222 | else:
223 | # do not initialize cuda context if not necessary
224 | maybe_device_id = None
225 | self.worker_manager_thread = threading.Thread(
226 | target=_worker_manager_loop,
227 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
228 | maybe_device_id))
229 | self.worker_manager_thread.daemon = True
230 | self.worker_manager_thread.start()
231 | else:
232 | self.data_queue = self.worker_result_queue
233 |
234 | for w in self.workers:
235 | w.daemon = True # ensure that the worker exits on process exit
236 | w.start()
237 |
238 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
239 | _set_SIGCHLD_handler()
240 | self.worker_pids_set = True
241 |
242 | # prime the prefetch loop
243 | for _ in range(2 * self.num_workers):
244 | self._put_indices()
245 |
246 | def __len__(self):
247 | return len(self.batch_sampler)
248 |
249 | def _get_batch(self):
250 | if self.timeout > 0:
251 | try:
252 | return self.data_queue.get(timeout=self.timeout)
253 | except queue.Empty:
254 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
255 | else:
256 | return self.data_queue.get()
257 |
258 | def __next__(self):
259 | if self.num_workers == 0: # same-process loading
260 | indices = next(self.sample_iter) # may raise StopIteration
261 | batch = self.collate_fn([self.dataset[i] for i in indices])
262 | if self.pin_memory:
263 | batch = pin_memory_batch(batch)
264 | return batch
265 |
266 | # check if the next sample has already been generated
267 | if self.rcvd_idx in self.reorder_dict:
268 | batch = self.reorder_dict.pop(self.rcvd_idx)
269 | return self._process_next_batch(batch)
270 |
271 | if self.batches_outstanding == 0:
272 | self._shutdown_workers()
273 | raise StopIteration
274 |
275 | while True:
276 | assert (not self.shutdown and self.batches_outstanding > 0)
277 | idx, batch = self._get_batch()
278 | self.batches_outstanding -= 1
279 | if idx != self.rcvd_idx:
280 | # store out-of-order samples
281 | self.reorder_dict[idx] = batch
282 | continue
283 | return self._process_next_batch(batch)
284 |
285 | next = __next__ # Python 2 compatibility
286 |
287 | def __iter__(self):
288 | return self
289 |
290 | def _put_indices(self):
291 | assert self.batches_outstanding < 2 * self.num_workers
292 | indices = next(self.sample_iter, None)
293 | if indices is None:
294 | return
295 | self.index_queue.put((self.send_idx, indices))
296 | self.batches_outstanding += 1
297 | self.send_idx += 1
298 |
299 | def _process_next_batch(self, batch):
300 | self.rcvd_idx += 1
301 | self._put_indices()
302 | if isinstance(batch, ExceptionWrapper):
303 | raise batch.exc_type(batch.exc_msg)
304 | return batch
305 |
306 | def __getstate__(self):
307 | # TODO: add limited pickling support for sharing an iterator
308 | # across multiple threads for HOGWILD.
309 | # Probably the best way to do this is by moving the sample pushing
310 | # to a separate thread and then just sharing the data queue
311 | # but signalling the end is tricky without a non-blocking API
312 | raise NotImplementedError("DataLoaderIterator cannot be pickled")
313 |
314 | def _shutdown_workers(self):
315 | try:
316 | if not self.shutdown:
317 | self.shutdown = True
318 | self.done_event.set()
319 | # if worker_manager_thread is waiting to put
320 | while not self.data_queue.empty():
321 | self.data_queue.get()
322 | for _ in self.workers:
323 | self.index_queue.put(None)
324 | # done_event should be sufficient to exit worker_manager_thread,
325 | # but be safe here and put another None
326 | self.worker_result_queue.put(None)
327 | finally:
328 | # removes pids no matter what
329 | if self.worker_pids_set:
330 | _remove_worker_pids(id(self))
331 | self.worker_pids_set = False
332 |
333 | def __del__(self):
334 | if self.num_workers > 0:
335 | self._shutdown_workers()
336 |
337 |
338 | class DataLoader(object):
339 | """
340 | Data loader. Combines a dataset and a sampler, and provides
341 | single- or multi-process iterators over the dataset.
342 |
343 | Arguments:
344 | dataset (Dataset): dataset from which to load the data.
345 | batch_size (int, optional): how many samples per batch to load
346 | (default: 1).
347 | shuffle (bool, optional): set to ``True`` to have the data reshuffled
348 | at every epoch (default: False).
349 | sampler (Sampler, optional): defines the strategy to draw samples from
350 | the dataset. If specified, ``shuffle`` must be False.
351 | batch_sampler (Sampler, optional): like sampler, but returns a batch of
352 | indices at a time. Mutually exclusive with batch_size, shuffle,
353 | sampler, and drop_last.
354 | num_workers (int, optional): how many subprocesses to use for data
355 | loading. 0 means that the data will be loaded in the main process.
356 | (default: 0)
357 | collate_fn (callable, optional): merges a list of samples to form a mini-batch.
358 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors
359 | into CUDA pinned memory before returning them.
360 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
361 | if the dataset size is not divisible by the batch size. If ``False`` and
362 | the size of dataset is not divisible by the batch size, then the last batch
363 | will be smaller. (default: False)
364 | timeout (numeric, optional): if positive, the timeout value for collecting a batch
365 | from workers. Should always be non-negative. (default: 0)
366 | worker_init_fn (callable, optional): If not None, this will be called on each
367 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
368 | input, after seeding and before data loading. (default: None)
369 |
370 | .. note:: By default, each worker will have its PyTorch seed set to
371 | ``base_seed + worker_id``, where ``base_seed`` is a long generated
372 | by main process using its RNG. You may use ``torch.initial_seed()`` to access
373 | this value in :attr:`worker_init_fn`, which can be used to set other seeds
374 | (e.g. NumPy) before data loading.
375 |
376 |     .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
377 | unpicklable object, e.g., a lambda function.
378 | """
379 |
380 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
381 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
382 | timeout=0, worker_init_fn=None):
383 | self.dataset = dataset
384 | self.batch_size = batch_size
385 | self.num_workers = num_workers
386 | self.collate_fn = collate_fn
387 | self.pin_memory = pin_memory
388 | self.drop_last = drop_last
389 | self.timeout = timeout
390 | self.worker_init_fn = worker_init_fn
391 |
392 | if timeout < 0:
393 | raise ValueError('timeout option should be non-negative')
394 |
395 | if batch_sampler is not None:
396 | if batch_size > 1 or shuffle or sampler is not None or drop_last:
397 | raise ValueError('batch_sampler is mutually exclusive with '
398 | 'batch_size, shuffle, sampler, and drop_last')
399 |
400 | if sampler is not None and shuffle:
401 | raise ValueError('sampler is mutually exclusive with shuffle')
402 |
403 | if self.num_workers < 0:
404 | raise ValueError('num_workers cannot be negative; '
405 | 'use num_workers=0 to disable multiprocessing.')
406 |
407 | if batch_sampler is None:
408 | if sampler is None:
409 | if shuffle:
410 | sampler = RandomSampler(dataset)
411 | else:
412 | sampler = SequentialSampler(dataset)
413 | batch_sampler = BatchSampler(sampler, batch_size, drop_last)
414 |
415 | self.sampler = sampler
416 | self.batch_sampler = batch_sampler
417 |
418 | def __iter__(self):
419 | return DataLoaderIter(self)
420 |
421 | def __len__(self):
422 | return len(self.batch_sampler)
423 |
--------------------------------------------------------------------------------
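
A short usage sketch of this vendored `DataLoader`, assuming the repository root is on `PYTHONPATH`; the `worker_init_fn` follows the seeding note in the docstring above (each worker's `torch.initial_seed()` is `base_seed + worker_id`), and the random tensors are placeholders:

```python
import numpy as np
import torch

from segmentTool.lib.utils.data import DataLoader, TensorDataset

def seed_numpy(worker_id):
    # Derive the NumPy seed from the per-worker PyTorch seed, as suggested
    # in the DataLoader docstring.
    np.random.seed(torch.initial_seed() % 2 ** 32)

data = torch.randn(100, 3, 8, 8)
labels = torch.randint(0, 10, (100,))

loader = DataLoader(TensorDataset(data, labels),
                    batch_size=16, shuffle=True, num_workers=2,
                    worker_init_fn=seed_numpy)

images, targets = next(iter(loader))
print(images.shape, targets.shape)  # torch.Size([16, 3, 8, 8]) torch.Size([16])
```
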
/segmentTool/lib/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import warnings
3 |
4 | from torch._utils import _accumulate
5 | from torch import randperm
6 |
7 |
8 | class Dataset(object):
9 | """An abstract class representing a Dataset.
10 |
11 | All other datasets should subclass it. All subclasses should override
12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``,
13 | supporting integer indexing in range from 0 to len(self) exclusive.
14 | """
15 |
16 | def __getitem__(self, index):
17 | raise NotImplementedError
18 |
19 | def __len__(self):
20 | raise NotImplementedError
21 |
22 | def __add__(self, other):
23 | return ConcatDataset([self, other])
24 |
25 |
26 | class TensorDataset(Dataset):
27 | """Dataset wrapping data and target tensors.
28 |
29 | Each sample will be retrieved by indexing both tensors along the first
30 | dimension.
31 |
32 | Arguments:
33 | data_tensor (Tensor): contains sample data.
34 | target_tensor (Tensor): contains sample targets (labels).
35 | """
36 |
37 | def __init__(self, data_tensor, target_tensor):
38 | assert data_tensor.size(0) == target_tensor.size(0)
39 | self.data_tensor = data_tensor
40 | self.target_tensor = target_tensor
41 |
42 | def __getitem__(self, index):
43 | return self.data_tensor[index], self.target_tensor[index]
44 |
45 | def __len__(self):
46 | return self.data_tensor.size(0)
47 |
48 |
49 | class ConcatDataset(Dataset):
50 | """
51 | Dataset to concatenate multiple datasets.
52 | Purpose: useful to assemble different existing datasets, possibly
53 | large-scale datasets as the concatenation operation is done in an
54 | on-the-fly manner.
55 |
56 | Arguments:
57 | datasets (iterable): List of datasets to be concatenated
58 | """
59 |
60 | @staticmethod
61 | def cumsum(sequence):
62 | r, s = [], 0
63 | for e in sequence:
64 | l = len(e)
65 | r.append(l + s)
66 | s += l
67 | return r
68 |
69 | def __init__(self, datasets):
70 | super(ConcatDataset, self).__init__()
71 | assert len(datasets) > 0, 'datasets should not be an empty iterable'
72 | self.datasets = list(datasets)
73 | self.cumulative_sizes = self.cumsum(self.datasets)
74 |
75 | def __len__(self):
76 | return self.cumulative_sizes[-1]
77 |
78 | def __getitem__(self, idx):
79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80 | if dataset_idx == 0:
81 | sample_idx = idx
82 | else:
83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84 | return self.datasets[dataset_idx][sample_idx]
85 |
86 | @property
87 | def cummulative_sizes(self):
88 | warnings.warn("cummulative_sizes attribute is renamed to "
89 | "cumulative_sizes", DeprecationWarning, stacklevel=2)
90 | return self.cumulative_sizes
91 |
92 |
93 | class Subset(Dataset):
94 | def __init__(self, dataset, indices):
95 | self.dataset = dataset
96 | self.indices = indices
97 |
98 | def __getitem__(self, idx):
99 | return self.dataset[self.indices[idx]]
100 |
101 | def __len__(self):
102 | return len(self.indices)
103 |
104 |
105 | def random_split(dataset, lengths):
106 | """
107 |     Randomly split a dataset into non-overlapping new datasets of given lengths.
108 |
109 |
110 | Arguments:
111 | dataset (Dataset): Dataset to be split
112 | lengths (iterable): lengths of splits to be produced
113 | """
114 | if sum(lengths) != len(dataset):
115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116 |
117 | indices = randperm(sum(lengths))
118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
119 |
--------------------------------------------------------------------------------
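
A sketch of how these dataset helpers compose (`Subset` and `random_split` are defined above but not re-exported by the package `__init__`, so they are imported from the module directly); the tensors are placeholders:

```python
import torch

from segmentTool.lib.utils.data.dataset import TensorDataset, ConcatDataset, random_split

a = TensorDataset(torch.randn(6, 4), torch.zeros(6))
b = TensorDataset(torch.randn(4, 4), torch.ones(4))

both = ConcatDataset([a, b])   # equivalent to a + b via Dataset.__add__
x, y = both[7]                 # bisect over cumulative sizes [6, 10] -> sample 1 of b
print(len(both), y)            # 10 tensor(1.)

train, val = random_split(both, [8, 2])  # two non-overlapping random Subsets
print(len(train), len(val))              # 8 2
```
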
/segmentTool/lib/utils/data/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from .sampler import Sampler
4 | from torch.distributed import get_world_size, get_rank
5 |
6 |
7 | class DistributedSampler(Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset.
9 |
10 | It is especially useful in conjunction with
11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
12 | process can pass a DistributedSampler instance as a DataLoader sampler,
13 | and load a subset of the original dataset that is exclusive to it.
14 |
15 | .. note::
16 | Dataset is assumed to be of constant size.
17 |
18 | Arguments:
19 | dataset: Dataset used for sampling.
20 | num_replicas (optional): Number of processes participating in
21 | distributed training.
22 | rank (optional): Rank of the current process within num_replicas.
23 | """
24 |
25 | def __init__(self, dataset, num_replicas=None, rank=None):
26 | if num_replicas is None:
27 | num_replicas = get_world_size()
28 | if rank is None:
29 | rank = get_rank()
30 | self.dataset = dataset
31 | self.num_replicas = num_replicas
32 | self.rank = rank
33 | self.epoch = 0
34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35 | self.total_size = self.num_samples * self.num_replicas
36 |
37 | def __iter__(self):
38 | # deterministically shuffle based on epoch
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch)
41 | indices = list(torch.randperm(len(self.dataset), generator=g))
42 |
43 | # add extra samples to make it evenly divisible
44 | indices += indices[:(self.total_size - len(indices))]
45 | assert len(indices) == self.total_size
46 |
47 | # subsample
48 | offset = self.num_samples * self.rank
49 | indices = indices[offset:offset + self.num_samples]
50 | assert len(indices) == self.num_samples
51 |
52 | return iter(indices)
53 |
54 | def __len__(self):
55 | return self.num_samples
56 |
57 | def set_epoch(self, epoch):
58 | self.epoch = epoch
59 |
--------------------------------------------------------------------------------
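
A sketch of the intended usage: each process wraps the same dataset in a `DistributedSampler` and calls `set_epoch` once per epoch so the deterministic shuffle changes over epochs while staying consistent across ranks. This assumes `torch.distributed.init_process_group(...)` has already been called in every process; the dataset below is a placeholder:

```python
import torch

from segmentTool.lib.utils.data import DataLoader, TensorDataset
from segmentTool.lib.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(1000, 4), torch.zeros(1000))
sampler = DistributedSampler(dataset)   # rank / world size read from torch.distributed
loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=2)

for epoch in range(3):
    sampler.set_epoch(epoch)            # re-seeds the per-epoch shuffle
    for inputs, targets in loader:
        pass                            # training step would go here
```
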
/segmentTool/lib/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Sampler(object):
5 | """Base class for all Samplers.
6 |
7 | Every Sampler subclass has to provide an __iter__ method, providing a way
8 | to iterate over indices of dataset elements, and a __len__ method that
9 | returns the length of the returned iterators.
10 | """
11 |
12 | def __init__(self, data_source):
13 | pass
14 |
15 | def __iter__(self):
16 | raise NotImplementedError
17 |
18 | def __len__(self):
19 | raise NotImplementedError
20 |
21 |
22 | class SequentialSampler(Sampler):
23 | """Samples elements sequentially, always in the same order.
24 |
25 | Arguments:
26 | data_source (Dataset): dataset to sample from
27 | """
28 |
29 | def __init__(self, data_source):
30 | self.data_source = data_source
31 |
32 | def __iter__(self):
33 | return iter(range(len(self.data_source)))
34 |
35 | def __len__(self):
36 | return len(self.data_source)
37 |
38 |
39 | class RandomSampler(Sampler):
40 | """Samples elements randomly, without replacement.
41 |
42 | Arguments:
43 | data_source (Dataset): dataset to sample from
44 | """
45 |
46 | def __init__(self, data_source):
47 | self.data_source = data_source
48 |
49 | def __iter__(self):
50 | return iter(torch.randperm(len(self.data_source)).long())
51 |
52 | def __len__(self):
53 | return len(self.data_source)
54 |
55 |
56 | class SubsetRandomSampler(Sampler):
57 | """Samples elements randomly from a given list of indices, without replacement.
58 |
59 | Arguments:
60 | indices (list): a list of indices
61 | """
62 |
63 | def __init__(self, indices):
64 | self.indices = indices
65 |
66 | def __iter__(self):
67 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
68 |
69 | def __len__(self):
70 | return len(self.indices)
71 |
72 |
73 | class WeightedRandomSampler(Sampler):
74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75 |
76 | Arguments:
77 |         weights (list) : a list of weights, not necessarily summing up to one
78 | num_samples (int): number of samples to draw
79 | replacement (bool): if ``True``, samples are drawn with replacement.
80 | If not, they are drawn without replacement, which means that when a
81 | sample index is drawn for a row, it cannot be drawn again for that row.
82 | """
83 |
84 | def __init__(self, weights, num_samples, replacement=True):
85 | self.weights = torch.DoubleTensor(weights)
86 | self.num_samples = num_samples
87 | self.replacement = replacement
88 |
89 | def __iter__(self):
90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91 |
92 | def __len__(self):
93 | return self.num_samples
94 |
95 |
96 | class BatchSampler(object):
97 | """Wraps another sampler to yield a mini-batch of indices.
98 |
99 | Args:
100 | sampler (Sampler): Base sampler.
101 | batch_size (int): Size of mini-batch.
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if
103 | its size would be less than ``batch_size``
104 |
105 | Example:
106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110 | """
111 |
112 | def __init__(self, sampler, batch_size, drop_last):
113 | self.sampler = sampler
114 | self.batch_size = batch_size
115 | self.drop_last = drop_last
116 |
117 | def __iter__(self):
118 | batch = []
119 | for idx in self.sampler:
120 | batch.append(idx)
121 | if len(batch) == self.batch_size:
122 | yield batch
123 | batch = []
124 | if len(batch) > 0 and not self.drop_last:
125 | yield batch
126 |
127 | def __len__(self):
128 | if self.drop_last:
129 | return len(self.sampler) // self.batch_size
130 | else:
131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
132 |
--------------------------------------------------------------------------------
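
Beyond the `BatchSampler` example in the docstring, a common composition of these samplers is a train/validation split over one dataset plus class re-balancing, sketched below with placeholder data:

```python
import torch

from segmentTool.lib.utils.data import DataLoader, TensorDataset
from segmentTool.lib.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler

dataset = TensorDataset(torch.randn(100, 4), torch.zeros(100))

perm = torch.randperm(100)
train_idx, val_idx = perm[:80].tolist(), perm[80:].tolist()

# Shuffled mini-batches restricted to each index subset, without copying data.
train_loader = DataLoader(dataset, batch_size=16, sampler=SubsetRandomSampler(train_idx))
val_loader = DataLoader(dataset, batch_size=16, sampler=SubsetRandomSampler(val_idx))

# Oversample the first five (hypothetically rare) samples ten to one.
weights = [10.0 if i < 5 else 1.0 for i in range(100)]
balanced_loader = DataLoader(dataset, batch_size=16,
                             sampler=WeightedRandomSampler(weights, num_samples=100))
```
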
/segmentTool/lib/utils/th.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import numpy as np
4 | import collections
5 |
6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile']
7 |
8 | def as_variable(obj):
9 | if isinstance(obj, Variable):
10 | return obj
11 | if isinstance(obj, collections.Sequence):
12 | return [as_variable(v) for v in obj]
13 | elif isinstance(obj, collections.Mapping):
14 | return {k: as_variable(v) for k, v in obj.items()}
15 | else:
16 | return Variable(obj)
17 |
18 | def as_numpy(obj):
19 | if isinstance(obj, collections.Sequence):
20 | return [as_numpy(v) for v in obj]
21 | elif isinstance(obj, collections.Mapping):
22 | return {k: as_numpy(v) for k, v in obj.items()}
23 | elif isinstance(obj, Variable):
24 | return obj.data.cpu().numpy()
25 | elif torch.is_tensor(obj):
26 | return obj.cpu().numpy()
27 | else:
28 | return np.array(obj)
29 |
30 | def mark_volatile(obj):
31 | if torch.is_tensor(obj):
32 | obj = Variable(obj)
33 | if isinstance(obj, Variable):
34 | obj.no_grad = True
35 | return obj
36 | elif isinstance(obj, collections.Mapping):
37 | return {k: mark_volatile(o) for k, o in obj.items()}
38 | elif isinstance(obj, collections.Sequence):
39 | return [mark_volatile(o) for o in obj]
40 | else:
41 | return obj
42 |
--------------------------------------------------------------------------------
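
These helpers walk nested lists and dicts, which is convenient when a batch mixes tensors inside containers; a small sketch with placeholder tensors:

```python
import torch

from segmentTool.lib.utils.th import as_variable, as_numpy

batch = {'rgb': [torch.randn(3, 8, 8) for _ in range(4)],
         'label': torch.zeros(8, 8).long()}

wrapped = as_variable(batch)   # wraps every tensor, preserving the dict/list structure
arrays = as_numpy(wrapped)     # converts back to nested NumPy arrays (via .cpu())
print(type(arrays['rgb'][0]), arrays['label'].dtype)  # <class 'numpy.ndarray'> int64
```
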
/segmentTool/resnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import math
6 | from segmentTool.lib.nn import SynchronizedBatchNorm2d
7 |
8 | try:
9 | from urllib import urlretrieve
10 | except ImportError:
11 | from urllib.request import urlretrieve
12 |
13 |
14 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
19 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
20 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1):
25 | "3x3 convolution with padding"
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 | padding=1, bias=False)
28 |
29 |
30 | class BasicBlock(nn.Module):
31 | expansion = 1
32 |
33 | def __init__(self, inplanes, planes, stride=1, downsample=None):
34 | super(BasicBlock, self).__init__()
35 | self.conv1 = conv3x3(inplanes, planes, stride)
36 | self.bn1 = SynchronizedBatchNorm2d(planes)
37 | self.relu = nn.ReLU(inplace=True)
38 | self.conv2 = conv3x3(planes, planes)
39 | self.bn2 = SynchronizedBatchNorm2d(planes)
40 | self.downsample = downsample
41 | self.stride = stride
42 |
43 | def forward(self, x):
44 | residual = x
45 |
46 | out = self.conv1(x)
47 | out = self.bn1(out)
48 | out = self.relu(out)
49 |
50 | out = self.conv2(out)
51 | out = self.bn2(out)
52 |
53 | if self.downsample is not None:
54 | residual = self.downsample(x)
55 |
56 | out += residual
57 | out = self.relu(out)
58 |
59 | return out
60 |
61 |
62 | class Bottleneck(nn.Module):
63 | expansion = 4
64 |
65 | def __init__(self, inplanes, planes, stride=1, downsample=None):
66 | super(Bottleneck, self).__init__()
67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
68 | self.bn1 = SynchronizedBatchNorm2d(planes)
69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
70 | padding=1, bias=False)
71 | self.bn2 = SynchronizedBatchNorm2d(planes)
72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
73 | self.bn3 = SynchronizedBatchNorm2d(planes * 4)
74 | self.relu = nn.ReLU(inplace=True)
75 | self.downsample = downsample
76 | self.stride = stride
77 |
78 | def forward(self, x):
79 | residual = x
80 |
81 | out = self.conv1(x)
82 | out = self.bn1(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv2(out)
86 | out = self.bn2(out)
87 | out = self.relu(out)
88 |
89 | out = self.conv3(out)
90 | out = self.bn3(out)
91 |
92 | if self.downsample is not None:
93 | residual = self.downsample(x)
94 |
95 | out += residual
96 | out = self.relu(out)
97 |
98 | return out
99 |
100 |
101 | class ResNet(nn.Module):
102 |
103 | def __init__(self, block, layers, num_classes=1000):
104 | self.inplanes = 128
105 | super(ResNet, self).__init__()
106 | self.conv1 = conv3x3(3, 64, stride=2)
107 | self.bn1 = SynchronizedBatchNorm2d(64)
108 | self.relu1 = nn.ReLU(inplace=True)
109 | self.conv2 = conv3x3(64, 64)
110 | self.bn2 = SynchronizedBatchNorm2d(64)
111 | self.relu2 = nn.ReLU(inplace=True)
112 | self.conv3 = conv3x3(64, 128)
113 | self.bn3 = SynchronizedBatchNorm2d(128)
114 | self.relu3 = nn.ReLU(inplace=True)
115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
116 |
117 | self.layer1 = self._make_layer(block, 64, layers[0])
118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
121 | self.avgpool = nn.AvgPool2d(7, stride=1)
122 | self.fc = nn.Linear(512 * block.expansion, num_classes)
123 |
124 | for m in self.modules():
125 | if isinstance(m, nn.Conv2d):
126 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
127 | m.weight.data.normal_(0, math.sqrt(2. / n))
128 | elif isinstance(m, SynchronizedBatchNorm2d):
129 | m.weight.data.fill_(1)
130 | m.bias.data.zero_()
131 |
132 | def _make_layer(self, block, planes, blocks, stride=1):
133 | downsample = None
134 | if stride != 1 or self.inplanes != planes * block.expansion:
135 | downsample = nn.Sequential(
136 | nn.Conv2d(self.inplanes, planes * block.expansion,
137 | kernel_size=1, stride=stride, bias=False),
138 | SynchronizedBatchNorm2d(planes * block.expansion),
139 | )
140 |
141 | layers = []
142 | layers.append(block(self.inplanes, planes, stride, downsample))
143 | self.inplanes = planes * block.expansion
144 | for i in range(1, blocks):
145 | layers.append(block(self.inplanes, planes))
146 |
147 | return nn.Sequential(*layers)
148 |
149 | def forward(self, x):
150 | x = self.relu1(self.bn1(self.conv1(x)))
151 | x = self.relu2(self.bn2(self.conv2(x)))
152 | x = self.relu3(self.bn3(self.conv3(x)))
153 | x = self.maxpool(x)
154 |
155 | x = self.layer1(x)
156 | x = self.layer2(x)
157 | x = self.layer3(x)
158 | x = self.layer4(x)
159 |
160 | x = self.avgpool(x)
161 | x = x.view(x.size(0), -1)
162 | x = self.fc(x)
163 |
164 | return x
165 |
166 | def resnet18(pretrained=False, **kwargs):
167 | """Constructs a ResNet-18 model.
168 |
169 | Args:
170 | pretrained (bool): If True, returns a model pre-trained on Places
171 | """
172 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
173 | if pretrained:
174 | model.load_state_dict(load_url(model_urls['resnet18']))
175 | return model
176 |
177 | '''
178 | def resnet34(pretrained=False, **kwargs):
179 | """Constructs a ResNet-34 model.
180 |
181 | Args:
182 | pretrained (bool): If True, returns a model pre-trained on Places
183 | """
184 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
185 | if pretrained:
186 | model.load_state_dict(load_url(model_urls['resnet34']))
187 | return model
188 | '''
189 |
190 | def resnet50(pretrained=False, **kwargs):
191 | """Constructs a ResNet-50 model.
192 |
193 | Args:
194 | pretrained (bool): If True, returns a model pre-trained on Places
195 | """
196 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
197 | if pretrained:
198 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
199 | return model
200 |
201 |
202 | def resnet101(pretrained=False, **kwargs):
203 | """Constructs a ResNet-101 model.
204 |
205 | Args:
206 | pretrained (bool): If True, returns a model pre-trained on Places
207 | """
208 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
209 | if pretrained:
210 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
211 | return model
212 |
213 | # def resnet152(pretrained=False, **kwargs):
214 | # """Constructs a ResNet-152 model.
215 | #
216 | # Args:
217 | # pretrained (bool): If True, returns a model pre-trained on Places
218 | # """
219 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
220 | # if pretrained:
221 | # model.load_state_dict(load_url(model_urls['resnet152']))
222 | # return model
223 |
224 | def load_url(url, model_dir='./pretrained', map_location=None):
225 | if not os.path.exists(model_dir):
226 | os.makedirs(model_dir)
227 | filename = url.split('/')[-1]
228 | cached_file = os.path.join(model_dir, filename)
229 | if not os.path.exists(cached_file):
230 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
231 | urlretrieve(url, cached_file)
232 | return torch.load(cached_file, map_location=map_location)
233 |
--------------------------------------------------------------------------------
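
A quick sketch of building the backbone above; with `pretrained=True` the weights would be downloaded by `load_url` into `./pretrained`, which needs network access. The 224x224 input matches the 7x7 average pooling at the end:

```python
import torch

from segmentTool.resnet import resnet50

net = resnet50(pretrained=False)   # set pretrained=True to download the weights
net.eval()

x = torch.randn(2, 3, 224, 224)    # 224 -> 112 -> 56 -> 28 -> 14 -> 7 before avgpool
with torch.no_grad():
    logits = net(x)
print(logits.shape)                # torch.Size([2, 1000])
```
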
/segmentTool/resnext.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import math
6 | from segmentTool.lib.nn import SynchronizedBatchNorm2d
7 |
8 | try:
9 | from urllib import urlretrieve
10 | except ImportError:
11 | from urllib.request import urlretrieve
12 |
13 |
14 | __all__ = ['ResNeXt', 'resnext101']  # only resnext101 is supported
15 |
16 |
17 | model_urls = {
18 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth',
19 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'
20 | }
21 |
22 |
23 | def conv3x3(in_planes, out_planes, stride=1):
24 | "3x3 convolution with padding"
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26 | padding=1, bias=False)
27 |
28 |
29 | class GroupBottleneck(nn.Module):
30 | expansion = 2
31 |
32 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None):
33 | super(GroupBottleneck, self).__init__()
34 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
35 | self.bn1 = SynchronizedBatchNorm2d(planes)
36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
37 | padding=1, groups=groups, bias=False)
38 | self.bn2 = SynchronizedBatchNorm2d(planes)
39 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
40 | self.bn3 = SynchronizedBatchNorm2d(planes * 2)
41 | self.relu = nn.ReLU(inplace=True)
42 | self.downsample = downsample
43 | self.stride = stride
44 |
45 | def forward(self, x):
46 | residual = x
47 |
48 | out = self.conv1(x)
49 | out = self.bn1(out)
50 | out = self.relu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 | out = self.relu(out)
55 |
56 | out = self.conv3(out)
57 | out = self.bn3(out)
58 |
59 | if self.downsample is not None:
60 | residual = self.downsample(x)
61 |
62 | out += residual
63 | out = self.relu(out)
64 |
65 | return out
66 |
67 |
68 | class ResNeXt(nn.Module):
69 |
70 | def __init__(self, block, layers, groups=32, num_classes=1000):
71 | self.inplanes = 128
72 | super(ResNeXt, self).__init__()
73 | self.conv1 = conv3x3(3, 64, stride=2)
74 | self.bn1 = SynchronizedBatchNorm2d(64)
75 | self.relu1 = nn.ReLU(inplace=True)
76 | self.conv2 = conv3x3(64, 64)
77 | self.bn2 = SynchronizedBatchNorm2d(64)
78 | self.relu2 = nn.ReLU(inplace=True)
79 | self.conv3 = conv3x3(64, 128)
80 | self.bn3 = SynchronizedBatchNorm2d(128)
81 | self.relu3 = nn.ReLU(inplace=True)
82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
83 |
84 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups)
85 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups)
86 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups)
87 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups)
88 | self.avgpool = nn.AvgPool2d(7, stride=1)
89 | self.fc = nn.Linear(1024 * block.expansion, num_classes)
90 |
91 | for m in self.modules():
92 | if isinstance(m, nn.Conv2d):
93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
94 | m.weight.data.normal_(0, math.sqrt(2. / n))
95 | elif isinstance(m, SynchronizedBatchNorm2d):
96 | m.weight.data.fill_(1)
97 | m.bias.data.zero_()
98 |
99 | def _make_layer(self, block, planes, blocks, stride=1, groups=1):
100 | downsample = None
101 | if stride != 1 or self.inplanes != planes * block.expansion:
102 | downsample = nn.Sequential(
103 | nn.Conv2d(self.inplanes, planes * block.expansion,
104 | kernel_size=1, stride=stride, bias=False),
105 | SynchronizedBatchNorm2d(planes * block.expansion),
106 | )
107 |
108 | layers = []
109 | layers.append(block(self.inplanes, planes, stride, groups, downsample))
110 | self.inplanes = planes * block.expansion
111 | for i in range(1, blocks):
112 | layers.append(block(self.inplanes, planes, groups=groups))
113 |
114 | return nn.Sequential(*layers)
115 |
116 | def forward(self, x):
117 | x = self.relu1(self.bn1(self.conv1(x)))
118 | x = self.relu2(self.bn2(self.conv2(x)))
119 | x = self.relu3(self.bn3(self.conv3(x)))
120 | x = self.maxpool(x)
121 |
122 | x = self.layer1(x)
123 | x = self.layer2(x)
124 | x = self.layer3(x)
125 | x = self.layer4(x)
126 |
127 | x = self.avgpool(x)
128 | x = x.view(x.size(0), -1)
129 | x = self.fc(x)
130 |
131 | return x
132 |
133 |
134 | '''
135 | def resnext50(pretrained=False, **kwargs):
136 |     """Constructs a ResNeXt-50 model.
137 |
138 | Args:
139 | pretrained (bool): If True, returns a model pre-trained on Places
140 | """
141 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs)
142 | if pretrained:
143 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False)
144 | return model
145 | '''
146 |
147 |
148 | def resnext101(pretrained=False, **kwargs):
149 |     """Constructs a ResNeXt-101 model.
150 |
151 | Args:
152 | pretrained (bool): If True, returns a model pre-trained on Places
153 | """
154 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
155 | if pretrained:
156 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False)
157 | return model
158 |
159 |
160 | # def resnext152(pretrained=False, **kwargs):
161 | # """Constructs a ResNeXt-152 model.
162 | #
163 | # Args:
164 | # pretrained (bool): If True, returns a model pre-trained on Places
165 | # """
166 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs)
167 | # if pretrained:
168 | # model.load_state_dict(load_url(model_urls['resnext152']))
169 | # return model
170 |
171 |
172 | def load_url(url, model_dir='./pretrained', map_location=None):
173 | if not os.path.exists(model_dir):
174 | os.makedirs(model_dir)
175 | filename = url.split('/')[-1]
176 | cached_file = os.path.join(model_dir, filename)
177 | if not os.path.exists(cached_file):
178 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
179 | urlretrieve(url, cached_file)
180 | return torch.load(cached_file, map_location=map_location)
181 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : test.py
4 | # Author : Bowen Pan
5 | # Email : panbowen0607@gmail.com
6 | # Date : 09/25/2018
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | """
11 |
12 | """
13 | from utils import Foo
14 | from models import VPNModel
15 | from datasets import OVMDataset
16 | from opts import parser
17 | from transform import *
18 | import torchvision
19 | import torch
20 | from torch import nn
21 | import os
22 | import time
23 | import cv2
24 | import dominate
25 | from dominate.tags import *
26 |
27 | mean_rgb = [0.485, 0.456, 0.406]
28 | std_rgb = [0.229, 0.224, 0.225]
29 |
30 | def main():
31 | global args, web_path, best_prec1
32 | best_prec1 = 0
33 | args = parser.parse_args()
34 | network_config = Foo(
35 | encoder=args.encoder,
36 | decoder=args.decoder,
37 | fc_dim=args.fc_dim,
38 | num_views=args.n_views,
39 | num_class=94,
40 | transform_type=args.transform_type,
41 | output_size=args.label_resolution,
42 | )
43 |
44 | val_dataset = OVMDataset(args.data_root, args.eval_list,
45 | transform=torchvision.transforms.Compose([
46 | Stack(roll=True),
47 | ToTorchFormatTensor(div=True),
48 | GroupNormalize(mean_rgb, std_rgb)
49 | ]),
50 | num_views=network_config.num_views, input_size=args.input_resolution,
51 | label_size=args.segSize, use_mask=args.use_mask, use_depth=args.use_depth, is_train=False)
52 |
53 | val_loader = torch.utils.data.DataLoader(
54 | val_dataset, batch_size=args.batch_size,
55 | num_workers=args.num_workers, shuffle=False,
56 | pin_memory=True
57 | )
58 |
59 |
60 | mapper = VPNModel(network_config)
61 | mapper = nn.DataParallel(mapper.cuda())
62 |
63 | if args.weights:
64 | if os.path.isfile(args.weights):
65 | print(("=> loading checkpoint '{}'".format(args.weights)))
66 | checkpoint = torch.load(args.weights)
67 | args.start_epoch = checkpoint['epoch']
68 | mapper.load_state_dict(checkpoint['state_dict'])
69 | print(("=> loaded checkpoint '{}' (epoch {})"
70 |                   .format(args.weights, checkpoint['epoch'])))
71 | else:
72 | print(("=> no checkpoint found at '{}'".format(args.weights)))
73 |
74 |
75 | criterion = nn.NLLLoss(weight=None, size_average=True)
76 | eval(val_loader, mapper, criterion)
77 |
78 | web_path = os.path.join(args.visualize, args.store_name)
79 | if os.path.isdir(web_path):
80 | pass
81 | else:
82 | os.makedirs(web_path)
83 |
84 | with dominate.document(title=web_path) as web:
85 | for step in range(len(val_loader)):
86 | if step % args.print_freq == 0:
87 | h2('Step {}'.format(step*args.batch_size))
88 | with table(border = 1, style = 'table-layout: fixed;'):
89 | with tr():
90 | for i in range(args.test_views):
91 | path = 'Step-{}-{}.png'.format(step * args.batch_size, i)
92 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
93 | img(style='width:128px', src=path)
94 | path = 'Step-{}-pred.png'.format(step * args.batch_size)
95 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
96 | img(style='width:128px', src=path)
97 | path = 'Step-{}-gt.png'.format(step * args.batch_size)
98 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
99 | img(style='width:128px', src=path)
100 |
101 | with open(os.path.join(web_path, 'index.html'), 'w') as fp:
102 | fp.write(web.render())
103 |
104 |
105 |
106 | def eval(val_loader, mapper, criterion):
107 | batch_time = AverageMeter()
108 | data_time = AverageMeter()
109 | losses = AverageMeter()
110 | top1 = AverageMeter()
111 | top5 = AverageMeter()
112 |
113 | mapper.eval()
114 |
115 | end = time.time()
116 |
117 | web_path = os.path.join(args.visualize, args.store_name)
118 | if os.path.isdir(web_path):
119 | pass
120 | else:
121 | os.makedirs(web_path)
122 |
123 | prec_stat = {}
124 | for i in range(args.num_class):
125 | prec_stat[str(i)] = {'intersec': 0, 'union': 0, 'all': 0}
126 |
127 | with open('./metadata/colormap_coarse.csv') as f:
128 | lines = f.readlines()
129 | cat = []
130 | for line in lines:
131 | line = line.rstrip()
132 | cat.append(line)
133 | cat = cat[1:]
134 | label_dic = {}
135 | for i, value in enumerate(cat):
136 | key = str(i)
137 | label_dic[key] = [int(x) for x in value.split(',')[1:]]
138 |
139 | for step, (rgb_stack, target, rgb_origin, OverMaskOrigin) in enumerate(val_loader):
140 | data_time.update(time.time() - end)
141 | with torch.no_grad():
142 |             input_rgb_var = torch.autograd.Variable(rgb_stack).cuda()
143 | _, output = mapper(x=input_rgb_var, return_feat=True)
144 | target_var = target.cuda()
145 | target_var = target_var.view(-1)
146 | upsample = output.view(-1, args.label_resolution, args.label_resolution, args.num_class).transpose(3,2).transpose(2,1).contiguous()
147 | upsample = nn.functional.upsample(upsample, size=args.segSize, mode='bilinear', align_corners=False)
148 | upsample = nn.functional.softmax(upsample, dim=1)
149 | output = torch.log(upsample.transpose(1,2).transpose(2,3).contiguous().view(-1, args.num_class))
150 | _, pred = upsample.data.topk(1, 1, True, True)
151 | pred = pred.squeeze(1)
152 | loss = criterion(output, target_var)
153 |                 losses.update(loss.item(), input_rgb_var.size(0))
154 | prec_stat = count_mean_accuracy(output.data, target_var.data, prec_stat)
155 |                 prec1 = accuracy(output.data, target_var.data, topk=(1,))[0]
156 |                 top1.update(prec1.item(), rgb_stack.size(0))
157 | batch_time.update(time.time() - end)
158 | end = time.time()
159 |
160 | if step % args.print_freq == 0:
161 | output = ('Test: [{0}/{1}]\t'
162 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
163 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
164 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
165 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
166 | step + 1, len(val_loader), batch_time=batch_time,
167 | data_time=data_time, loss=losses, top1=top1))
168 | print(output)
169 |
170 | pred = np.uint8(pred.cpu()[0])
171 | predMask = np.uint8(np.zeros((args.segSize, args.segSize, 3)))
172 | for i, _ in enumerate(pred):
173 | for j, _ in enumerate(pred[0]):
174 | key = str(pred[i][j])
175 | predMask[i,j] = label_dic[key]
176 | predMask = cv2.resize(predMask[:, :, ::-1], (256, 256), interpolation=cv2.INTER_NEAREST)
177 | cv2.imwrite(os.path.join(web_path, 'Step-{}-pred.png'.format(step * args.batch_size, i)), predMask)
178 |
179 | gtMask = OverMaskOrigin[0].cpu().numpy()
180 | gtMask = cv2.resize(gtMask, (256, 256), interpolation=cv2.INTER_NEAREST)
181 | cv2.imwrite(os.path.join(web_path, 'Step-{}-gt.png'.format(step * args.batch_size, i)), gtMask)
182 |
183 | rgb = rgb_origin.cpu().numpy()[0]
184 | for i in range(args.test_views):
185 | cv2.imwrite(os.path.join(web_path, 'Step-{}-{}.png'.format(step * args.batch_size, i)), cv2.resize(rgb[(i + args.view_bias) % 8], (256, 256), interpolation=cv2.INTER_NEAREST))
186 |
187 | sum_acc = 0
188 | counted_cat = 0
189 | sum_iou = 0
190 | for key in prec_stat:
191 | if int(prec_stat[key]['all']) != 0:
192 | acc = prec_stat[key]['intersec'] / (prec_stat[key]['all'] + 1e-10)
193 | iou = prec_stat[key]['intersec'] / (prec_stat[key]['union'] + 1e-10)
194 | sum_acc += acc
195 | sum_iou += iou
196 | counted_cat += 1
197 | mean_acc = sum_acc / counted_cat
198 | mean_iou = sum_iou / counted_cat
199 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Mean Prec@1 {meantop:.3f} Mean IoU {meaniou:.3f} Loss {loss.avg:.5f}'
200 | .format(top1=top1, loss=losses, meantop=mean_acc, meaniou=mean_iou))
201 | print(output)
202 | output_best = '\nBest Prec@1 of: %.3f' % (best_prec1)
203 | print(output_best)
204 |
205 | return top1.avg
206 |
207 | class AverageMeter(object):
208 | """Computes and stores the average and current value"""
209 | def __init__(self):
210 | self.reset()
211 |
212 | def reset(self):
213 | self.val = 0
214 | self.avg = 0
215 | self.sum = 0
216 | self.count = 0
217 |
218 | def update(self, val, n=1):
219 | self.val = val
220 | self.sum += val * n
221 | self.count += n
222 | self.avg = self.sum / self.count
223 |
224 | def count_mean_accuracy(output, target, prec_stat):
225 | _, pred = output.topk(1, 1, True, True)
226 | pred = pred.squeeze(1)
227 | for key in prec_stat.keys():
228 | label = int(key)
229 | pred_map = np.uint8(pred.cpu().numpy() == label)
230 | target_map = np.uint8(target.cpu().numpy() == label)
231 | intersection_t = pred_map * (pred_map == target_map)
232 | union_t = pred_map + target_map - intersection_t
233 | prec_stat[key]['intersec'] += np.sum(intersection_t)
234 | prec_stat[key]['union'] += np.sum(union_t)
235 | prec_stat[key]['all'] += np.sum(target_map)
236 | return prec_stat
237 |
238 | def accuracy(output, target, topk=(1,)):
239 | """Computes the precision@k for the specified values of k"""
240 | maxk = max(topk)
241 | batch_size = target.size(0)
242 |
243 | _, pred = output.topk(maxk, 1, True, True)
244 | pred = pred.t()
245 | correct = pred.eq(target.view(1, -1).expand_as(pred))
246 |
247 | res = []
248 | for k in topk:
249 | correct_k = correct[:k].view(-1).float().sum(0)
250 | res.append(correct_k.mul_(100.0 / batch_size))
251 | return res
252 |
253 | if __name__=='__main__':
254 | main()
255 |
--------------------------------------------------------------------------------
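
The per-class bookkeeping in `count_mean_accuracy` and the reduction at the end of `eval` boil down to per-class pixel accuracy and IoU; a tiny worked example with two classes and four pixels (placeholder values, not results from the paper):

```python
import numpy as np

pred   = np.array([0, 0, 1, 1])
target = np.array([0, 1, 1, 1])

prec_stat = {str(i): {'intersec': 0, 'union': 0, 'all': 0} for i in range(2)}
for key in prec_stat:
    label = int(key)
    pred_map = np.uint8(pred == label)
    target_map = np.uint8(target == label)
    inter = pred_map * (pred_map == target_map)          # true positives for this class
    prec_stat[key]['intersec'] += inter.sum()
    prec_stat[key]['union'] += (pred_map + target_map - inter).sum()
    prec_stat[key]['all'] += target_map.sum()

# class 0: acc 1/1, IoU 1/2; class 1: acc 2/3, IoU 2/3
accs = [s['intersec'] / (s['all'] + 1e-10) for s in prec_stat.values() if s['all'] > 0]
ious = [s['intersec'] / (s['union'] + 1e-10) for s in prec_stat.values() if s['all'] > 0]
print(np.mean(accs), np.mean(ious))   # 0.8333..., 0.5833...
```
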
/test_carla.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : test.py
4 | # Author : Bowen Pan
5 | # Email : panbowen0607@gmail.com
6 | # Date : 09/25/2018
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | """
11 |
12 | """
13 | from utils import Foo
14 | from models import VPNModel
15 | from datasets import OVMDataset
16 | from opts import parser
17 | from transform import *
18 | import torchvision
19 | import torch
20 | from torch import nn
21 | from torch.optim.lr_scheduler import MultiStepLR
22 | from torch import optim
23 | import os
24 | import time
25 | from torch.nn.utils import clip_grad_norm
26 | # from examples.cognitive_mapping.Logger import Logger
27 | import cv2
28 | import shutil
29 | import dominate
30 | from dominate.tags import *
31 |
32 | mean_rgb = [0.485, 0.456, 0.406]
33 | std_rgb = [0.229, 0.224, 0.225]
34 |
35 | def main():
36 | global args, web_path, best_prec1
37 | parser.add_argument('--test-views', type=int, default=94)
38 | parser.add_argument('--view-bias', type=int, default=8)
39 |
40 | best_prec1 = 0
41 | args = parser.parse_args()
42 | network_config = Foo(
43 | encoder=args.encoder,
44 | decoder=args.decoder,
45 | fc_dim=args.fc_dim,
46 | num_views=args.n_views,
47 | num_class=args.num_class,
48 | transform_type=args.transform_type,
49 | output_size=args.label_resolution,
50 | )
51 |
52 | val_dataset = OVMDataset(args.data_root, args.eval_list,
53 | transform=torchvision.transforms.Compose([
54 | Stack(roll=True),
55 | ToTorchFormatTensor(div=True),
56 | GroupNormalize(mean_rgb, std_rgb)
57 | ]),
58 | num_views=network_config.num_views, input_size=args.input_resolution,
59 | label_size=args.segSize, use_mask=args.use_mask, use_depth=args.use_depth, is_train=False)
60 |
61 | val_loader = torch.utils.data.DataLoader(
62 | val_dataset, batch_size=args.batch_size,
63 | num_workers=args.num_workers, shuffle=False,
64 | pin_memory=True
65 | )
66 |
67 |
68 | mapper = VPNModel(network_config)
69 | mapper = nn.DataParallel(mapper.cuda())
70 |
71 | if args.weights:
72 | if os.path.isfile(args.weights):
73 | print(("=> loading checkpoint '{}'".format(args.weights)))
74 | checkpoint = torch.load(args.weights)
75 | args.start_epoch = checkpoint['epoch']
76 | mapper.load_state_dict(checkpoint['state_dict'])
77 | print(("=> loaded checkpoint '{}' (epoch {})"
78 |                   .format(args.weights, checkpoint['epoch'])))
79 | else:
80 | print(("=> no checkpoint found at '{}'".format(args.weights)))
81 |
82 |
83 | criterion = nn.NLLLoss(weight=None, size_average=True)
84 | eval(val_loader, mapper, criterion)
85 |
86 | web_path = os.path.join(args.visualize, args.store_name)
87 | if os.path.isdir(web_path):
88 | pass
89 | else:
90 | os.makedirs(web_path)
91 |
92 | with dominate.document(title=web_path) as web:
93 | for step in range(len(val_loader)):
94 | if step % args.print_freq == 0:
95 | h2('Step {}'.format(step*args.batch_size))
96 | with table(border = 1, style = 'table-layout: fixed;'):
97 | with tr():
98 | for i in range(args.test_views):
99 | path = 'Step-{}-{}.png'.format(step * args.batch_size, i)
100 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
101 | img(style='width:128px', src=path)
102 | path = 'Step-{}-pred.png'.format(step * args.batch_size)
103 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
104 | img(style='width:128px', src=path)
105 | path = 'Step-{}-gt.png'.format(step * args.batch_size)
106 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
107 | img(style='width:128px', src=path)
108 |
109 | with open(os.path.join(web_path, 'index.html'), 'w') as fp:
110 | fp.write(web.render())
111 |
112 |
113 |
114 | def eval(val_loader, mapper, criterion):
115 | batch_time = AverageMeter()
116 | data_time = AverageMeter()
117 | losses = AverageMeter()
118 | top1 = AverageMeter()
119 | top5 = AverageMeter()
120 | best_prec1 = 0
121 |
122 | mapper.eval()
123 |
124 | end = time.time()
125 |
126 | web_path = os.path.join(args.visualize, args.store_name)
127 | if os.path.isdir(web_path):
128 | pass
129 | else:
130 | os.makedirs(web_path)
131 |
132 | prec_stat = {}
133 | for i in range(args.num_class):
134 | prec_stat[str(i)] = {'intersec': 0, 'union': 0, 'all': 0}
135 |
136 | with open('./metadata/colormap_coarse.csv') as f:
137 | lines = f.readlines()
138 | cat = []
139 | for line in lines:
140 | line = line.rstrip()
141 | cat.append(line)
142 | cat = cat[1:]
143 | label_dic = {}
144 | for i, value in enumerate(cat):
145 | key = str(i)
146 | label_dic[key] = [int(x) for x in value.split(',')[1:]]
147 |
148 | for step, (rgb_stack, target, rgb_origin, OverMaskOrigin) in enumerate(val_loader):
149 | data_time.update(time.time() - end)
150 | with torch.no_grad():
151 |             input_rgb_var = torch.autograd.Variable(rgb_stack).cuda()
152 | _, output = mapper(x=input_rgb_var, return_feat=True)
153 | target_var = target.cuda()
154 | target_var = target_var.view(-1)
155 | upsample = output.view(-1, args.label_resolution, args.label_resolution, args.num_class).transpose(3,2).transpose(2,1).contiguous()
156 | upsample = nn.functional.upsample(upsample, size=args.segSize, mode='bilinear', align_corners=False)
157 | upsample = nn.functional.softmax(upsample, dim=1)
158 | output = torch.log(upsample.transpose(1,2).transpose(2,3).contiguous().view(-1, args.num_class))
159 | _, pred = upsample.data.topk(1, 1, True, True)
160 | pred = pred.squeeze(1)
161 | loss = criterion(output, target_var)
162 | losses.update(loss.item(), input_rgb_var.size(0))
163 | prec_stat = count_mean_accuracy(output.data, target_var.data, prec_stat)
164 | prec1 = accuracy(output.data, target_var.data, topk=(1,))[0]
165 | top1.update(prec1.item(), rgb_stack.size(0))
166 | best_prec1 = max(prec1, best_prec1)
167 |
168 | batch_time.update(time.time() - end)
169 | end = time.time()
170 |
171 | if step % args.print_freq == 0:
172 | output = ('Test: [{0}/{1}]\t'
173 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
174 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
175 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
176 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
177 | step + 1, len(val_loader), batch_time=batch_time,
178 | data_time=data_time, loss=losses, top1=top1))
179 | print(output)
180 |
181 | pred = np.uint8(pred.cpu()[0])
182 | predMask = np.uint8(np.zeros((args.segSize, args.segSize, 3)))
183 | for i, _ in enumerate(pred):
184 | for j, _ in enumerate(pred[0]):
185 | key = str(pred[i][j])
186 | predMask[i,j] = label_dic[key]
187 | predMask = cv2.resize(predMask[:, :, ::-1], (256, 256), interpolation=cv2.INTER_NEAREST)
188 | cv2.imwrite(os.path.join(web_path, 'Step-{}-pred.png'.format(step * args.batch_size)), predMask)
189 |
190 | gtMask = OverMaskOrigin[0].cpu().numpy()
191 | print('gtMask.shape: ', gtMask.shape)
192 | gt_rgb = np.uint8(np.zeros((gtMask.shape[0], gtMask.shape[0], 3)))
193 | for i, _ in enumerate(gtMask):
194 | for j, _ in enumerate(gtMask[0]):
195 | key = str(gtMask[i][j])
196 | gt_rgb[i,j] = label_dic[key]
197 |
198 | gt_rgb = cv2.resize(gt_rgb[:, :, ::-1], (256, 256), interpolation=cv2.INTER_NEAREST)
199 | cv2.imwrite(os.path.join(web_path, 'Step-{}-gt.png'.format(step * args.batch_size)), gt_rgb)
200 |
201 | rgb = rgb_origin.cpu().numpy()[0]
202 | for i in range(args.test_views):
203 | cv2.imwrite(os.path.join(web_path, 'Step-{}-{}.png'.format(step * args.batch_size, i)), cv2.resize(rgb[(i + args.view_bias) % 8], (256, 256), interpolation=cv2.INTER_NEAREST))
204 |
205 | sum_acc = 0
206 | counted_cat = 0
207 | sum_iou = 0
208 | for key in prec_stat:
209 | if int(prec_stat[key]['all']) != 0:
210 | acc = prec_stat[key]['intersec'] / (prec_stat[key]['all'] + 1e-10)
211 | iou = prec_stat[key]['intersec'] / (prec_stat[key]['union'] + 1e-10)
212 | sum_acc += acc
213 | sum_iou += iou
214 | counted_cat += 1
215 | mean_acc = sum_acc / counted_cat
216 | mean_iou = sum_iou / counted_cat
217 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Mean Prec@1 {meantop:.3f} Mean IoU {meaniou:.3f} Loss {loss.avg:.5f}'
218 | .format(top1=top1, loss=losses, meantop=mean_acc, meaniou=mean_iou))
219 | print(output)
220 | output_best = '\nBest Prec@1 of: %.3f' % (best_prec1)
221 | print(output_best)
222 |
223 | return top1.avg
224 |
225 | class AverageMeter(object):
226 | """Computes and stores the average and current value"""
227 | def __init__(self):
228 | self.reset()
229 |
230 | def reset(self):
231 | self.val = 0
232 | self.avg = 0
233 | self.sum = 0
234 | self.count = 0
235 |
236 | def update(self, val, n=1):
237 | self.val = val
238 | self.sum += val * n
239 | self.count += n
240 | self.avg = self.sum / self.count
241 |
242 | def count_mean_accuracy(output, target, prec_stat):
243 | _, pred = output.topk(1, 1, True, True)
244 | pred = pred.squeeze(1)
245 | for key in prec_stat.keys():
246 | label = int(key)
247 | pred_map = np.uint8(pred.cpu().numpy() == label)
248 | target_map = np.uint8(target.cpu().numpy() == label)
249 | intersection_t = pred_map * (pred_map == target_map)
250 | union_t = pred_map + target_map - intersection_t
251 | prec_stat[key]['intersec'] += np.sum(intersection_t)
252 | prec_stat[key]['union'] += np.sum(union_t)
253 | prec_stat[key]['all'] += np.sum(target_map)
254 | return prec_stat
255 |
256 | def accuracy(output, target, topk=(1,)):
257 | """Computes the precision@k for the specified values of k"""
258 | maxk = max(topk)
259 | batch_size = target.size(0)
260 |
261 | _, pred = output.topk(maxk, 1, True, True)
262 | pred = pred.t()
263 | correct = pred.eq(target.view(1, -1).expand_as(pred))
264 |
265 | res = []
266 | for k in topk:
267 | correct_k = correct[:k].view(-1).float().sum(0)
268 | res.append(correct_k.mul_(100.0 / batch_size))
269 | return res
270 |
271 | if __name__=='__main__':
272 | main()
273 |
--------------------------------------------------------------------------------
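
The evaluation loop in test.py above accumulates, for every class, the intersection, union, and ground-truth pixel counts over the whole validation set, and only averages the classes that actually occur. A minimal, self-contained sketch of that bookkeeping (the 3-class 4x4 maps are illustrative, not tied to the repo's settings):

```python
import numpy as np

def update_stats(pred, target, stats):
    """Accumulate per-class intersection / union / ground-truth pixel counts."""
    for cls, s in stats.items():
        label = int(cls)
        pred_map = (pred == label)
        target_map = (target == label)
        s['intersec'] += np.sum(pred_map & target_map)
        s['union'] += np.sum(pred_map | target_map)
        s['all'] += np.sum(target_map)
    return stats

# toy example: 3 classes on a 4x4 map
stats = {str(c): {'intersec': 0, 'union': 0, 'all': 0} for c in range(3)}
pred = np.random.randint(0, 3, (4, 4))
target = np.random.randint(0, 3, (4, 4))
stats = update_stats(pred, target, stats)

# only average over classes that appear in the ground truth
accs = [s['intersec'] / (s['all'] + 1e-10) for s in stats.values() if s['all'] > 0]
ious = [s['intersec'] / (s['union'] + 1e-10) for s in stats.values() if s['all'] > 0]
print('mean acc: %.3f   mean IoU: %.3f' % (np.mean(accs), np.mean(ious)))
```

Skipping absent classes mirrors the `prec_stat[key]['all'] != 0` guard in test.py, so categories that never appear in the ground truth do not drag the means down.
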
/test_seq.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : test_seq.py
4 | # Author : Bowen Pan
5 | # Email : panbowen0607@gmail.com
6 | # Date : 09/25/2018
7 | #
8 | # Distributed under terms of the MIT license.
9 |
10 | """
11 |
12 | """
13 | from utils import Foo
14 | from models import VPNModel
15 | from datasets import Seq_OVMDataset
16 | from opts import parser
17 | from transform import *
18 | import torchvision
19 | import torch
20 | from torch import nn
21 | from torch.optim.lr_scheduler import MultiStepLR
22 | from torch import optim
23 | import os
24 | import time
25 | from torch.nn.utils import clip_grad_norm
26 | import cv2
27 | import shutil
28 | import dominate
29 | from dominate.tags import *
30 | import moviepy.editor as mpy
31 | import numpy as np
32 |
33 | mean_rgb = [0.485, 0.456, 0.406]
34 | std_rgb = [0.229, 0.224, 0.225]
35 |
36 | def main():
37 | global args, web_path, best_prec1
38 | best_prec1 = 0
39 | args = parser.parse_args()
40 | network_config = Foo(
41 | encoder=args.encoder,
42 | decoder=args.decoder,
43 | fc_dim=args.fc_dim,
44 | num_views=args.n_views,
45 | num_class=94,
46 | transform_type=args.transform_type,
47 | output_size=args.label_resolution,
48 | )
49 |
50 | val_dataset = Seq_OVMDataset(args.test_dir, pix_file=args.pix_file,
51 | transform=torchvision.transforms.Compose([
52 | Stack(roll=True),
53 | ToTorchFormatTensor(div=True),
54 | GroupNormalize(mean_rgb, std_rgb)
55 | ]),
56 | n_views=network_config.num_views, resolution=args.input_resolution,
57 | label_res=args.label_resolution, use_mask=args.use_mask, is_train=False)
58 |
59 | val_loader = torch.utils.data.DataLoader(
60 | val_dataset, batch_size=1,
61 | shuffle=False, pin_memory=True
62 | )
63 |
64 |
65 | mapper = VPNModel(network_config)
66 | mapper = nn.DataParallel(mapper.cuda())
67 |
68 | if args.weights:
69 | if os.path.isfile(args.weights):
70 | print(("=> loading checkpoint '{}'".format(args.weights)))
71 | checkpoint = torch.load(args.weights)
72 | args.start_epoch = checkpoint['epoch']
73 | mapper.load_state_dict(checkpoint['state_dict'])
74 | print(("=> loaded checkpoint '{}' (epoch {})"
75 | .format(args.weights, checkpoint['epoch'])))
76 | else:
77 | print(("=> no checkpoint found at '{}'".format(args.weights)))
78 |
79 | web_path = os.path.join(args.visualize, args.store_name)
80 | criterion = nn.NLLLoss(weight=None, size_average=True)
81 | eval(val_loader, mapper, criterion, web_path)
82 |
83 | web_path = os.path.join(args.visualize, args.store_name)
84 |
85 |
86 | def eval(val_loader, mapper, criterion, web_path):
87 | batch_time = AverageMeter()
88 | data_time = AverageMeter()
89 | losses = AverageMeter()
90 | top1 = AverageMeter()
91 | top5 = AverageMeter()
92 |
93 | mapper.eval()
94 |
95 | end = time.time()
96 | if os.path.isdir(web_path):
97 | pass
98 | else:
99 | os.makedirs(web_path)
100 |
101 | frames = []
102 |
103 | prec_stat = {}
104 | for i in range(args.num_class):
105 | prec_stat[str(i)] = {'intersec': 0, 'all': 0}
106 |
107 | with open('./metadata/colormap_coarse.csv') as f:
108 | lines = f.readlines()
109 | cat = []
110 | for line in lines:
111 | line = line.rstrip()
112 | cat.append(line)
113 | cat = cat[1:]
114 | label_dic = {}
115 | for i, value in enumerate(cat):
116 | key = str(i)
117 | label_dic[key] = [int(x) for x in value.split(',')[1:]]
118 |
119 | reachable = [10, 40, 43, 45, 64, 80]
120 |
121 | for step, (rgb_stack, target, rgb_origin, topmap, OverMaskOrigin) in enumerate(val_loader):
122 | data_time.update(time.time() - end)
123 | with torch.no_grad():
124 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda()
125 | _, output = mapper(x=input_rgb_var, test_comb=[x * int(args.n_views / args.test_views) for x in list(range(args.test_views))], return_feat=True)
126 | target_var = target.cuda()
127 | target_var = target_var.view(-1)
128 | output = output.view(-1, args.num_class)
129 | upsample = output.view(-1, args.label_resolution, args.label_resolution, args.num_class).transpose(3,2).transpose(2,1).contiguous()
130 | upsample = nn.functional.upsample(upsample, size=args.segSize, mode='bilinear', align_corners=False)
131 | upsample = nn.functional.softmax(upsample, dim=1)
132 | freemap = upsample.data.index_select(dim=1, index=torch.Tensor(reachable).long().cuda())
133 | freemap = freemap.sum(dim=1, keepdim=False)
134 | output = nn.functional.log_softmax(output, dim=1)
135 | _, pred = upsample.data.topk(1, 1, True, True)
136 | pred = pred.squeeze(1)
137 | loss = criterion(output, target_var)
138 | losses.update(loss.item(), input_rgb_var.size(0))
139 | prec_stat = count_mean_accuracy(output.data, target_var.data, prec_stat)
140 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5))
141 | top1.update(prec1.item(), rgb_stack.size(0))
142 | top5.update(prec5.item(), rgb_stack.size(0))
143 |
144 | batch_time.update(time.time() - end)
145 | end = time.time()
146 |
147 | if step % 1 == 0:
148 | output = ('Test: [{0}/{1}]\t'
149 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
150 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
151 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
152 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
153 | step + 1, len(val_loader), batch_time=batch_time,
154 | data_time=data_time, loss=losses, top1=top1))
155 | print(output)
156 |
157 | syn_map = np.zeros((256*3, 256*4 + 40, 3))
158 | syn_map += 255
159 |
160 | pred = np.uint8(pred.cpu()[0])
161 | predMask = np.uint8(np.zeros((args.segSize, args.segSize, 3)))
162 | for i, _ in enumerate(pred):
163 | for j, _ in enumerate(pred[0]):
164 | key = str(pred[i][j])
165 | predMask[i,j] = label_dic[key]
166 | predMask = cv2.resize(predMask[:, :, ::-1], (256, 256), interpolation=cv2.INTER_NEAREST)
167 |
168 | gtMask = OverMaskOrigin[0].cpu().numpy()
169 | gtMask = cv2.resize(gtMask, (256, 256), interpolation=cv2.INTER_NEAREST)
170 | syn_map[:256, 256*3 + 40:] = gtMask
171 | syn_map[256:256*2, 256*3 + 40:] = predMask
172 |
173 | rgb = rgb_origin.cpu().numpy()[0]
174 | topmap = topmap.cpu().numpy()
175 |
176 | freemap = np.uint8(freemap[0].cpu().numpy() * 255)
177 | freemap = cv2.applyColorMap(freemap, cv2.COLORMAP_JET)
178 | freemap = cv2.resize(freemap, (256, 256), interpolation=cv2.INTER_NEAREST)
179 | syn_map[256*2:256*3, 256*3 + 40:] = freemap
180 |
181 | orient_rank = [-2, 6, -2, 0, -1, 4, -2, 2, -2]
182 | for i, orient in enumerate(orient_rank):
183 | if orient >= 0:
184 | syn_map[(i%3)*256:(i%3+1)*256, int(i/3)*256:int(i/3+1)*256] = cv2.resize(rgb[orient], (256, 256))
185 | elif orient == -1:
186 | syn_map[(i%3)*256:(i%3+1)*256, int(i/3)*256:int(i/3+1)*256] = topmap[:, :, ::-1]
187 | elif orient == -2:
188 | syn_map[(i%3)*256:(i%3+1)*256, int(i/3)*256:int(i/3+1)*256] = 240
189 | for i in range(2):
190 | syn_map[256 * (i + 1) - 3: 256 * (i + 1) + 3] = 255
191 | for i in range(3):
192 | syn_map[:, 256 * (i + 1) -3: 256 * (i + 1) + 3] = 255
193 | cv2.imwrite(os.path.join(web_path, 'syn_step%d.jpg'%(step+1)), syn_map)
194 | frames.append(syn_map[:, :, ::-1])
195 | clip = mpy.ImageSequenceClip(frames, fps=8)
196 | clip.write_videofile(os.path.join(web_path, 'OverviewSemVideo.mp4'), fps=8)
197 |
198 |
199 | sum_acc = 0
200 | counted_cat = 0
201 | for key in prec_stat:
202 | if int(prec_stat[key]['all']) != 0:
203 | acc = prec_stat[key]['intersec'] / (prec_stat[key]['all'] + 1e-10)
204 | sum_acc += acc
205 | counted_cat += 1
206 | mean_acc = sum_acc / counted_cat
207 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Mean Prec@1 {meantop:.3f} Loss {loss.avg:.5f}'
208 | .format(top1=top1, loss=losses, meantop=mean_acc))
209 | print(output)
210 | output_best = '\nBest Prec@1 of: %.3f' % (best_prec1)
211 | print(output_best)
212 |
213 | return top1.avg
214 |
215 | class AverageMeter(object):
216 | """Computes and stores the average and current value"""
217 | def __init__(self):
218 | self.reset()
219 |
220 | def reset(self):
221 | self.val = 0
222 | self.avg = 0
223 | self.sum = 0
224 | self.count = 0
225 |
226 | def update(self, val, n=1):
227 | self.val = val
228 | self.sum += val * n
229 | self.count += n
230 | self.avg = self.sum / self.count
231 |
232 | def count_mean_accuracy(output, target, prec_stat):
233 | _, pred = output.topk(1, 1, True, True)
234 | pred = pred.squeeze(1)
235 | for key in prec_stat.keys():
236 | label = int(key)
237 | pred_map = np.uint8(pred.cpu().numpy() == label)
238 | target_map = np.uint8(target.cpu().numpy() == label)
239 | intersection_t = pred_map * (pred_map == target_map)
240 | prec_stat[key]['intersec'] += np.sum(intersection_t)
241 | prec_stat[key]['all'] += np.sum(target_map)
242 | return prec_stat
243 |
244 | def accuracy(output, target, topk=(1,)):
245 | """Computes the precision@k for the specified values of k"""
246 | maxk = max(topk)
247 | batch_size = target.size(0)
248 |
249 | _, pred = output.topk(maxk, 1, True, True)
250 | pred = pred.t()
251 | correct = pred.eq(target.view(1, -1).expand_as(pred))
252 |
253 | res = []
254 | for k in topk:
255 | correct_k = correct[:k].view(-1).float().sum(0)
256 | res.append(correct_k.mul_(100.0 / batch_size))
257 | return res
258 |
259 | if __name__=='__main__':
260 | main()
261 |
--------------------------------------------------------------------------------
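
test_seq.py additionally derives a "free-space" map by summing the softmax probabilities of the classes listed in `reachable` before colorizing it with a JET colormap. A minimal sketch of that probability-summing step, assuming an arbitrary `(1, C, H, W)` probability tensor:

```python
import torch

num_class, H, W = 94, 128, 128
probs = torch.softmax(torch.randn(1, num_class, H, W), dim=1)  # per-pixel class probabilities

reachable = [10, 40, 43, 45, 64, 80]             # the "walkable" class ids used above
idx = torch.tensor(reachable, dtype=torch.long)
freemap = probs.index_select(dim=1, index=idx)   # keep only the reachable classes
freemap = freemap.sum(dim=1)                     # (1, H, W): free-space probability per pixel
print(freemap.shape, float(freemap.max()))
```

Because the class probabilities sum to 1 at every pixel, the summed map can be read directly as the probability that a pixel belongs to a traversable category.
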
/tools/get_trainning_data_from_house3d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import csv
4 | import cv2
5 | import pickle
6 | import os
7 |
8 | import gym
9 | from gym import spaces
10 | from House3D.house import House
11 | from House3D.core import Environment, MultiHouseEnv
12 | from House3D import objrender
13 | from House3D.objrender import RenderMode
14 |
15 | __all__ = ['CollectDataTask']
16 |
17 | ###############################################
18 | # Task related definitions and configurations
19 | ###############################################
20 | flag_print_debug_info = False # flag for printing debug info
21 |
22 | dist_reward_scale = 1.0 # 2.0
23 | collision_penalty_reward = 0.3 # penalty for collision
24 | correct_move_reward = None # reward when moving closer to target
25 | stay_room_reward = 0.1 # reward when staying in the correct target room
26 | indicator_reward = 0.5
27 | success_reward = 10 # reward when success
28 |
29 | time_penalty_reward = 0.1 # penalty for each time step
30 | delta_reward_coef = 0.5
31 | speed_reward_coef = 1.0
32 |
33 | success_distance_range = 1.0
34 | success_stay_time_steps = 5
35 | success_see_target_time_steps = 2 # time steps required for success under the "see" criteria
36 |
37 | # sensitivity setting
38 | rotation_sensitivity = 30 # 45 # maximum rotation per time step
39 | default_move_sensitivity = 0.5 # 1.0 # maximum movement per time step
40 |
41 | # discrete action space actions, totally <13> actions
42 | # Fwd, L, R, LF, RF, Lrot, Rrot, Bck, s-Fwd, s-L, s-R, s-Lrot, s-Rrot
43 | discrete_actions = [(1., 0., 0.), (0., 1., 0.), (0., -1., 0.), (0.5, 0.5, 0.), (0.5, -0.5, 0.),
44 | (0., 0., 1.), (0., 0., -1.),
45 | (-0.4, 0., 0.),
46 | (0.4, 0., 0.), (0., 0.4, 0.), (0., -0.4, 0.),
47 | (0., 0., 0.4), (0., 0., -0.4)]
48 | n_discrete_actions = len(discrete_actions)
49 |
50 | # criteria for seeing the object
51 | n_pixel_for_object_see = 450 # need to see at least 450 pixels for success under the default resolution 120 x 90
52 | n_pixel_for_object_sense = 50
53 | L_pixel_reward_range = n_pixel_for_object_see - n_pixel_for_object_sense
54 | pixel_object_reward = 0.4
55 |
56 |
57 | #################
58 | # Util Functions
59 | #################
60 |
61 | def reset_see_criteria(resolution):
62 | total_pixel = resolution[0] * resolution[1]
63 | global n_pixel_for_object_see, n_pixel_for_object_sense, L_pixel_reward_range
64 | n_pixel_for_object_see = max(int(total_pixel * 0.045), 5)
65 | n_pixel_for_object_sense = max(int(total_pixel * 0.005), 1)
66 | L_pixel_reward_range = n_pixel_for_object_see - n_pixel_for_object_sense
67 |
68 |
69 | class CollectDataTask(gym.Env):
70 | def __init__(self, env,
71 | seed=None,
72 | reward_type='delta',
73 | hardness=None,
74 | move_sensitivity=None,
75 | segment_input=True,
76 | joint_visual_signal=False,
77 | depth_signal=True,
78 | max_steps=-1,
79 | success_measure='see',
80 | discrete_action=False):
81 | self.env = env
82 | assert isinstance(env, Environment), '[RoomNavTask] env must be an instance of Environment!'
83 | if env.resolution != (120, 90): reset_see_criteria(env.resolution)
84 | self.resolution = resolution = env.resolution
85 | assert reward_type in [None, 'none', 'linear', 'indicator', 'delta', 'speed']
86 | self.reward_type = reward_type
87 | self.colorDataFile = self.env.config['colorFile']
88 | self.segment_input = segment_input
89 | self.joint_visual_signal = joint_visual_signal
90 | self.depth_signal = depth_signal
91 | n_channel = 3
92 | if segment_input:
93 | self.env.set_render_mode('semantic')
94 | else:
95 | self.env.set_render_mode('rgb')
96 | if joint_visual_signal: n_channel += 3
97 | if depth_signal: n_channel += 1
98 | self._observation_shape = (resolution[0], resolution[1], n_channel)
99 | self._observation_space = spaces.Box(0, 255, shape=self._observation_shape)
100 |
101 | self.max_steps = max_steps
102 |
103 | self.discrete_action = discrete_action
104 | if discrete_action:
105 | self._action_space = spaces.Discrete(n_discrete_actions)
106 | else:
107 | self._action_space = spaces.Tuple([spaces.Box(0, 1, shape=(4,)), spaces.Box(0, 1, shape=(2,))])
108 |
109 | if seed is not None:
110 | np.random.seed(seed)
111 | random.seed(seed)
112 |
113 | # configs
114 | self.move_sensitivity = (move_sensitivity or default_move_sensitivity) # at most * meters per frame
115 | self.rot_sensitivity = rotation_sensitivity
116 | self.dist_scale = dist_reward_scale or 1
117 | self.successRew = success_reward
118 | self.inroomRew = stay_room_reward or 0.2
119 | self.colideRew = collision_penalty_reward or 0.02
120 | self.goodMoveRew = correct_move_reward or 0.0
121 |
122 | self.last_obs = None
123 | self.last_info = None
124 | self._object_cnt = 0
125 |
126 | # config hardness
127 | self.hardness = None
128 | self.availCoors = None
129 | self._availCoorsDict = None
130 | self.reset_hardness(hardness)
131 |
132 | # temp storage
133 | self.collision_flag = False
134 |
135 | # episode counter
136 | self.current_episode_step = 0
137 |
138 | # config success measure
139 | assert success_measure in ['stay', 'see']
140 | self.success_measure = success_measure
141 | print('[RoomNavTask] >> Success Measure = <{}>'.format(success_measure))
142 | self.success_stay_cnt = 0
143 | if success_measure == 'see':
144 | self.room_target_object = dict()
145 | self._load_target_object_data(self.env.config['roomTargetFile'])
146 |
147 | def _load_target_object_data(self, roomTargetFile):
148 | with open(roomTargetFile) as csvFile:
149 | reader = csv.DictReader(csvFile)
150 | for row in reader:
151 | c = np.array((row['r'], row['g'], row['b']), dtype=np.uint8)
152 | room = row['target_room']
153 | if room not in self.room_target_object:
154 | self.room_target_object[room] = []
155 | self.room_target_object[room].append(c)
156 |
157 | """
158 | reset the target room type to navigate to
159 | when target is None, a valid target will be randomly selected
160 | """
161 |
162 | def reset_target(self, target=None):
163 | if target is None:
164 | target = random.choice(self.house.all_desired_roomTypes)
165 | else:
166 | assert target in self.house.all_desired_roomTypes, '[RoomNavTask] desired target <{}> does not exist in the current house!'.format(
167 | target)
168 | if self.house.setTargetRoom(target): # target room changed!!!
169 | _id = self.house._id
170 | if self.house.targetRoomTp not in self._availCoorsDict[_id]:
171 | if self.hardness is None:
172 | self.availCoors = self.house.connectedCoors
173 | else:
174 | allowed_dist = self.house.maxConnDist * self.hardness
175 | self.availCoors = [c for c in self.house.connectedCoors
176 | if self.house.connMap[c[0], c[1]] <= allowed_dist]
177 | self._availCoorsDict[_id][self.house.targetRoomTp] = self.availCoors
178 | else:
179 | self.availCoors = self._availCoorsDict[_id][self.house.targetRoomTp]
180 |
181 | @property
182 | def house(self):
183 | return self.env.house
184 |
185 | """
186 | gym api: reset function
187 | when target is not None, we will set the target room type to navigate to
188 | """
189 |
190 | def reset(self, target=None):
191 | # clear episode steps
192 | self.current_episode_step = 0
193 | self.success_stay_cnt = 0
194 | self._object_cnt = 0
195 |
196 | # reset house
197 | self.env.reset_house()
198 |
199 | self.house.targetRoomTp = None # [NOTE] IMPORTANT! clear this!!!!!
200 |
201 | # reset target room
202 | self.reset_target(target=target) # randomly reset
203 |
204 | # general birth place
205 | gx, gy = random.choice(self.availCoors)
206 | self.collision_flag = False
207 | # generate state
208 | x, y = self.house.to_coor(gx, gy, True)
209 | self.env.reset(x=x, y=y)
210 | self.last_obs = self.env.render()
211 | if self.joint_visual_signal:
212 | self.last_obs = np.concatenate([self.env.render(mode='rgb'), self.last_obs], axis=-1)
213 | ret_obs = self.last_obs
214 | if self.depth_signal:
215 | dep_sig = self.env.render(mode='depth')
216 | if dep_sig.shape[-1] > 1:
217 | dep_sig = dep_sig[..., 0:1]
218 | ret_obs = np.concatenate([ret_obs, dep_sig], axis=-1)
219 | self.last_info = self.info
220 | return ret_obs
221 |
222 | def _apply_action(self, action):
223 | if self.discrete_action:
224 | return discrete_actions[action]
225 | else:
226 | rot = action[1][0] - action[1][1]
227 | act = action[0]
228 | return (act[0] - act[1]), (act[2] - act[3]), rot
229 |
230 | def _is_success(self, raw_dist):
231 | if raw_dist > 0:
232 | self.success_stay_cnt = 0
233 | return False
234 | if self.success_measure == 'stay':
235 | self.success_stay_cnt += 1
236 | return self.success_stay_cnt >= success_stay_time_steps
237 | # self.success_measure == 'see'
238 | flag_see_target_objects = False
239 | object_color_list = self.room_target_object[self.house.targetRoomTp]
240 | if (self.last_obs is not None) and self.segment_input:
241 | seg_obs = self.last_obs if not self.joint_visual_signal else self.last_obs[:, :, 3:6]
242 | else:
243 | seg_obs = self.env.render(mode='semantic')
244 | self._object_cnt = 0
245 | for c in object_color_list:
246 | cur_n = np.sum(np.all(seg_obs == c, axis=2))
247 | self._object_cnt += cur_n
248 | if self._object_cnt >= n_pixel_for_object_see:
249 | flag_see_target_objects = True
250 | break
251 | if flag_see_target_objects:
252 | self.success_stay_cnt += 1
253 | else:
254 | self.success_stay_cnt = 0 # did not see any target objects!
255 | return self.success_stay_cnt >= success_see_target_time_steps
256 |
257 | @property
258 | def observation_space(self):
259 | return self._observation_space
260 |
261 | @property
262 | def action_space(self):
263 | return self._action_space
264 |
265 | @property
266 | def info(self):
267 | ret = self.env.info
268 | gx, gy = ret['grid']
269 | ret['dist'] = dist = self.house.connMap[gx, gy]
270 | ret['scaled_dist'] = self.house.getScaledDist(gx, gy)
271 | ret['optsteps'] = int(dist / (self.move_sensitivity / self.house.grid_det) + 0.5)
272 | ret['collision'] = int(self.collision_flag)
273 | ret['target_room'] = self.house.targetRoomTp
274 | return ret
275 |
276 | """
277 | return all the available target room types of the current house
278 | """
279 |
280 | def get_avail_targets(self):
281 | return self.house.all_desired_roomTypes
282 |
283 | """
284 | reset the hardness of the task
285 | """
286 |
287 | def reset_hardness(self, hardness=None):
288 | self.hardness = hardness
289 | if hardness is None:
290 | self.availCoors = self.house.connectedCoors
291 | else:
292 | allowed_dist = self.house.maxConnDist * hardness
293 | self.availCoors = [c for c in self.house.connectedCoors
294 | if self.house.connMap[c[0], c[1]] <= allowed_dist]
295 | n_house = self.env.num_house
296 | self._availCoorsDict = [dict() for i in range(n_house)]
297 | self._availCoorsDict[self.house._id][self.house.targetRoomTp] = self.availCoors
298 |
299 | """
300 | recover the state (location) of the agent from the info dictionary
301 | """
302 |
303 | def set_state(self, info):
304 | if len(info['pos']) == 2:
305 | self.env.reset(x=info['pos'][0], y=info['pos'][1], yaw=info['yaw'])
306 | else:
307 | self.env.reset(x=info['pos'][0], y=info['pos'][1], z=info['pos'][2], yaw=info['yaw'])
308 |
309 | """
310 | return 2d topdown map
311 | """
312 |
313 | def get_2dmap(self):
314 | return self.env.gen_2dmap()
315 |
316 | """
317 | show an image.
318 | if img is not None, it will visualize the given img on the monitor
319 | if img is None, it will return an img object of the current observation
320 | Note: once this function is called with a non-None img, the task can no longer perform rendering!
321 | """
322 |
323 | def show(self, img=None):
324 | return self.env.show(img=img,
325 | renderMapLoc=(None if img is not None else self.info['grid']),
326 | renderSegment=True,
327 | display=(img is not None))
328 |
329 | def debug_show(self):
330 | return self.env.debug_render()
331 |
332 |
333 | if __name__ == '__main__':
334 | from House3D.common import load_config
335 |
336 | modes = [RenderMode.RGB, RenderMode.SEMANTIC, RenderMode.INSTANCE, RenderMode.DEPTH]
337 |
338 | api = objrender.RenderAPI(w=400, h=400, device=0)
339 | cfg = load_config('config.json')
340 |
341 | House_dir = '/data/vision/oliva/scenedataset/activevision/suncg/suncg_data/house'
342 | dataset_root = '/data/vision/oliva/scenedataset/syntheticscene/TopViewMaskDataset'
343 | if not os.path.isdir(dataset_root):
344 | os.mkdir(dataset_root)
345 |
346 | houses = os.listdir(House_dir)
347 | available_houses = []
348 | for house in houses:
349 | if os.path.isfile(os.path.join(House_dir, house, 'house.obj')):
350 | available_houses.append(house)
351 | if len(available_houses) == 500:
352 | break
353 | ruined = []
354 | for house in available_houses:
355 | if os.path.isdir(os.path.join(dataset_root, house)):
356 | print(house + ' already exists!')
357 | continue
358 | try:
359 | env = Environment(api, house, cfg)
360 | except:
361 | ruined.append(house)
362 | print('House id ' + house + ' cannot be used')
363 | continue
364 | print('House ' + house + ' extracting...')
365 | os.mkdir(os.path.join(dataset_root, house))
366 | task = CollectDataTask(env, hardness=0.6, discrete_action=True)
367 | house_obj = task.house
368 | connectedCoor = house_obj.connectedCoors
369 | cam = task.env.cam
370 | for coor in connectedCoor:
371 | gx, gy = coor
372 | if house_obj.connMap[gx, gy] != 0 or gx % 10 != 0 or gy % 10 != 0:
373 | continue
374 | coor_dir = os.path.join(dataset_root, house, 'gx_{}_gy_{}'.format(gx, gy))
375 | if not os.path.isdir(coor_dir):
376 | os.mkdir(coor_dir)
377 | print(coor)
378 | x, y = house_obj.to_coor(gx, gy, True)
379 | info = {'pos': [x, y], 'yaw': 0}
380 | task.set_state(info)
381 | top_map, _ = task.env.gen_localmap(scale=3)
382 | top_map = top_map[:, :, ::-1]
383 | b, g, r = top_map[200, 200]
384 | if b < 80 and b > 30 and g > 30 and g < 80 and r > 240:
385 | cv2.imwrite(os.path.join(coor_dir, 'topdown_view.png'), top_map)
386 | for yaw in [0, 45]:
387 | info = {'pos': [x, y], 'yaw': yaw}
388 | task.set_state(info)
389 | for i, mod in enumerate(['rgb', 'sem', 'ins', 'depth']):
390 | cube_map = task.env.render_cube_map(modes[i])
391 | cube_map = cube_map[:, :4 * cube_map.shape[1] // 6]
392 | if mod == 'depth':
393 | infmask = cube_map[:, :, 1]
394 | cube_map = cube_map[:, :, 0] * (infmask == 0)
395 | else:
396 | cube_map = cube_map[:, :, ::-1]
397 | cv2.imwrite(os.path.join(coor_dir, 'mode={}_{}.png'.format(mod, yaw)), cube_map)
398 | info = {'pos': [x, y, 3.6], 'yaw': 0}
399 | task.set_state(info)
400 | cube_map = task.env.render_cube_map(modes[1])
401 | cube_map = cube_map[:, - cube_map.shape[1] // 6 :]
402 | cube_map = cube_map[:, :, ::-1]
403 | cv2.imwrite(os.path.join(coor_dir, 'OverviewMask.png'), cube_map)
404 |
405 | print(ruined)
406 | print(len(ruined))
407 |
--------------------------------------------------------------------------------
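
The collection script above renders House3D cube maps and slices them by columns: the first four faces become the horizontal views, and the last face, rendered again from an elevated pose, becomes the top-down mask. A small sketch of that slicing on a dummy array, assuming the six faces are laid out side by side along the width (which is what the `cube_map.shape[1] // 6` arithmetic implies):

```python
import numpy as np

face = 90                                                  # illustrative face resolution
cube_map = np.zeros((face, 6 * face, 3), dtype=np.uint8)   # six faces concatenated along width

side_views = cube_map[:, :4 * cube_map.shape[1] // 6]      # first four faces: horizontal views
top_face = cube_map[:, -cube_map.shape[1] // 6:]           # last face: top-down view

print(side_views.shape)  # (90, 360, 3)
print(top_face.shape)    # (90, 90, 3)
```
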
/tools/process_data_for_VPN.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | from dominate.tags import *
4 | import dominate
5 |
6 | dataset_root = '/mnt/lustre/share/VPN_driving_scene/mp3d'
7 | scene_list = os.listdir(dataset_root)
8 | train_list = scene_list[:int(5*len(scene_list)/6)]
9 | val_list = scene_list[int(5*len(scene_list)/6):]
10 | train_output = []
11 | val_output = []
12 | data_type = ['rgb', 'sem', 'depth', 'ins', ]
13 |
14 |
15 | for scene_id in train_list:
16 | print('Processing ' + scene_id)
17 | coor_set = os.listdir(os.path.join(dataset_root, scene_id))
18 | for coor in coor_set:
19 | coor_path = os.path.join(scene_id, coor)
20 | skip = False
21 | if not os.path.isfile(os.path.join(dataset_root, coor_path, 'topdown-semantics.png')):
22 | skip = True
23 | if skip:
24 | continue
25 | train_output.append(coor_path)
26 | web_path = os.path.join(dataset_root, coor_path)
27 | with dominate.document(title=web_path) as web:
28 | for mod in data_type:
29 | h2('Mode=%s, Yaw=%d'%(mod, 0))
30 | with table(border=1, style='table-layout: fixed;'):
31 | with tr():
32 | path = 'mode=%s_%d.png'%(mod, 0)
33 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
34 | img(style='height:256px', src=path)
35 | h2('Mode=%s, Yaw=%d' % (mod, 45))
36 | with table(border=1, style='table-layout: fixed;'):
37 | with tr():
38 | path = 'mode=%s_%d.png' % (mod, 45)
39 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
40 | img(style='height:256px', src=path)
41 | h2('Top-view Semantic Mask')
42 | with table(border=1, style='table-layout: fixed;'):
43 | with tr():
44 | path = 'OverviewMask.png'
45 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
46 | img(style='width:256px', src=path)
47 | # with open(os.path.join(web_path, 'index.html'), 'w') as fp:
48 | # fp.write(web.render())
49 | print('Index of ' + os.path.join(scene_id, coor) + ' generated!')
50 |
51 | for scene_id in val_list:
52 | print('Processing ' + scene_id)
53 | coor_set = os.listdir(os.path.join(dataset_root, scene_id))
54 | for coor in coor_set:
55 | coor_path = os.path.join(scene_id, coor)
56 | skip = False
57 | if not os.path.isfile(os.path.join(dataset_root, coor_path, 'topdown-semantics.png')):
58 | # print(os.path.join(coor_path, '0', 'topdown_view.jpg') + ' is lost')
59 | skip = True
60 | if skip:
61 | continue
62 | val_output.append(coor_path)
63 | web_path = os.path.join(dataset_root, coor_path)
64 | with dominate.document(title=web_path) as web:
65 | for mod in data_type:
66 | h2('Mode=%s, Yaw=%d'%(mod, 0))
67 | with table(border=1, style='table-layout: fixed;'):
68 | with tr():
69 | path = 'mode=%s_%d.png'%(mod, 0)
70 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
71 | img(style='height:256px', src=path)
72 | h2('Mode=%s, Yaw=%d' % (mod, 45))
73 | with table(border=1, style='table-layout: fixed;'):
74 | with tr():
75 | path = 'mode=%s_%d.png' % (mod, 45)
76 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
77 | img(style='height:256px', src=path)
78 | h2('Top-view Semantic Mask')
79 | with table(border=1, style='table-layout: fixed;'):
80 | with tr():
81 | path = 'OverviewMask.png'
82 | with td(style='word-wrap: break-word;', halign='center', valign='top'):
83 | img(style='width:256px', src=path)
84 | # with open(os.path.join(web_path, 'index.html'), 'w') as fp:
85 | # fp.write(web.render())
86 | print('Index of ' + os.path.join(scene_id, coor) + ' generated!')
87 |
88 |
89 | with open('train_list.txt','w') as f:
90 | print('Writing train file ...')
91 | f.write('\n'.join(train_output))
92 |
93 | with open('val_list.txt', 'w') as f:
94 | print('Writing validation file ...')
95 | f.write('\n'.join(val_output))
96 |
--------------------------------------------------------------------------------
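
Both test.py and process_data_for_VPN.py assemble their index.html galleries with the dominate library: tags instantiated inside a `with dominate.document(...)` block are appended to the page body, and `render()` serializes the whole document. A minimal sketch (the title and image file name are illustrative):

```python
import dominate
from dominate.tags import h2, table, tr, td, img

# tags created inside the document context are appended to the page body
with dominate.document(title='VPN visualization') as web:
    h2('Step 0')
    with table(border=1, style='table-layout: fixed;'):
        with tr():
            with td(style='word-wrap: break-word;'):
                img(src='Step-0-pred.png', style='width:128px')

with open('index.html', 'w') as fp:
    fp.write(web.render())
```
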
/tools/process_transfer_driving_data.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 |
4 | CARLA_data_root = '../data/Carla_Dataset_v1'
5 | scene_list = sorted(os.listdir(CARLA_data_root))
6 | train_list = scene_list[:int(5*len(scene_list)/6)]
7 | val_list = scene_list[int(5*len(scene_list)/6):]
8 | train_output = []
9 | val_output = []
10 | # data_type = ['rgb', 'sem', 'depth', 'ins', ]
11 |
12 |
13 | for scene_id in train_list:
14 | print('Processing ' + scene_id)
15 | coor_set = sorted(os.listdir(os.path.join(CARLA_data_root, scene_id, 'CAM_RGB_FRONT')))
16 |
17 | for coor in coor_set:
18 | coor_path = os.path.join(scene_id, 'CAM_RGB_FRONT', coor)
19 | train_output.append(coor_path)
20 |
21 | with open('../tools/train_source_list.txt', 'w') as f:
22 | print('Writing train_source_list file ...')
23 | f.write('\n'.join(train_output))
24 |
25 | for scene_id in val_list:
26 | print('Processing ' + scene_id)
27 | coor_set = sorted(os.listdir(os.path.join(CARLA_data_root, scene_id, 'CAM_RGB_FRONT')))
28 |
29 | for coor in coor_set:
30 | coor_path = os.path.join(scene_id, 'CAM_RGB_FRONT', coor)
31 | val_output.append(coor_path)
32 |
33 | with open('../tools/val_source_list.txt', 'w') as f:
34 | print('Writing val_source_list file ...')
35 | f.write('\n'.join(val_output))
36 |
--------------------------------------------------------------------------------
/tools/process_transfer_indoor_data.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 |
4 | CARLA_data_root = '../data/Carla_Dataset_v1'
5 | dataset_root = '../data/nuScenes'
6 | scene_list = sorted(os.listdir(CARLA_data_root))
7 | train_list = scene_list[:int(5*len(scene_list)/6)]
8 | # val_list = scene_list[int(5*len(scene_list)/6):]
9 | train_output = []
10 | # val_output = []
11 | # data_type = ['rgb', 'sem', 'depth', 'ins', ]
12 |
13 |
14 | for scene_id in train_list:
15 | print('Processing ' + scene_id)
16 | coor_set = sorted(os.listdir(os.path.join(CARLA_data_root, scene_id, 'CAM_RGB_FRONT')))
17 |
18 | for coor in coor_set:
19 | coor_path = os.path.join(scene_id, 'CAM_RGB_FRONT', coor)
20 | train_output.append(coor_path)
21 |
22 | with open('train_source_list.txt', 'w') as f:
23 | print('Writing train_source_list file ...')
24 | f.write('\n'.join(train_output))
25 |
26 | nuScenes_data_root = '../data/nuScenes'
27 | scene_list = sorted(os.listdir(nuScenes_data_root))
28 | train_list = scene_list
29 | train_output = []
30 |
31 | for scene_id in train_list:
32 | print('Processing ' + scene_id)
33 | coor_set = sorted(os.listdir(os.path.join(nuScenes_data_root, scene_id, 'CAM_RGB_FRONT')))
34 | coor_set = coor_set[:18220]
35 |
36 | for coor in coor_set:
37 | coor_path = os.path.join(scene_id, 'CAM_RGB_FRONT', coor)
38 | train_output.append(coor_path)
39 |
40 |
41 | with open('train_target_list.txt', 'w') as f:
42 | print('Writing train_target_list file ...')
43 | f.write('\n'.join(train_output))
44 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from utils import Foo
2 | from models import VPNModel
3 | from datasets import OVMDataset
4 | from opts import parser
5 | from transform import *
6 | import torchvision
7 | import torch
8 | from torch import nn
9 | from torch import optim
10 | import os
11 | import time
12 | import shutil
13 | import numpy as np
14 | mean_rgb = [0.485, 0.456, 0.406]
15 | std_rgb = [0.229, 0.224, 0.225]
16 |
17 | def main():
18 | global args, best_prec1
19 | best_prec1 = 0
20 | args = parser.parse_args()
21 | network_config = Foo(
22 | encoder=args.encoder,
23 | decoder=args.decoder,
24 | fc_dim=args.fc_dim,
25 | output_size=args.label_resolution,
26 | num_views=args.n_views,
27 | num_class=args.num_class,
28 | transform_type=args.transform_type,
29 | )
30 | train_dataset = OVMDataset(args.data_root, args.train_list,
31 | transform=torchvision.transforms.Compose([
32 | Stack(roll=True),
33 | ToTorchFormatTensor(div=True),
34 | GroupNormalize(mean_rgb, std_rgb)
35 | ]),
36 | num_views=network_config.num_views, input_size=args.input_resolution,
37 | label_size=args.label_resolution, use_mask=args.use_mask, use_depth=args.use_depth)
38 | val_dataset = OVMDataset(args.data_root, args.eval_list,
39 | transform=torchvision.transforms.Compose([
40 | Stack(roll=True),
41 | ToTorchFormatTensor(div=True),
42 | GroupNormalize(mean_rgb, std_rgb)
43 | ]),
44 | num_views=network_config.num_views, input_size=args.input_resolution,
45 | label_size=args.label_resolution, use_mask=args.use_mask, use_depth=args.use_depth)
46 | train_loader = torch.utils.data.DataLoader(
47 | train_dataset, batch_size=args.batch_size,
48 | num_workers=args.num_workers, shuffle=True,
49 | pin_memory=True
50 | )
51 |
52 | val_loader = torch.utils.data.DataLoader(
53 | val_dataset, batch_size=args.batch_size,
54 | num_workers=args.num_workers, shuffle=False,
55 | pin_memory=True
56 | )
57 |
58 | mapper = VPNModel(network_config)
59 | mapper = nn.DataParallel(mapper.cuda())
60 |
61 | if args.resume:
62 | if os.path.isfile(args.resume):
63 | print(("=> loading checkpoint '{}'".format(args.resume)))
64 | checkpoint = torch.load(args.resume)
65 | args.start_epoch = checkpoint['epoch']
66 | mapper.load_state_dict(checkpoint['state_dict'])
67 | print(("=> loaded checkpoint '{}' (epoch {})"
68 | .format(args.resume, checkpoint['epoch'])))
69 | else:
70 | print(("=> no checkpoint found at '{}'".format(args.resume)))
71 |
72 |
73 | criterion = nn.NLLLoss(weight=None, size_average=True)
74 | optimizer = optim.Adam(mapper.parameters(),
75 | lr=args.start_lr, betas=(0.95, 0.999))
76 |
77 | if not os.path.isdir(args.log_root):
78 | os.mkdir(args.log_root)
79 | log_train = open(os.path.join(args.log_root, '%s.csv' % args.store_name), 'w')
80 |
81 | for epoch in range(args.start_epoch, args.epochs):
82 | adjust_learning_rate(optimizer, epoch, args.lr_steps)
83 | train(train_loader, mapper, criterion, optimizer, epoch, log_train)
84 |
85 | if (epoch + 1) % args.ckpt_freq == 0 or epoch == args.epochs - 1:
86 | prec1 = eval(val_loader, mapper, criterion, log_train, epoch)
87 | is_best = prec1 > best_prec1
88 | best_prec1 = max(prec1, best_prec1)
89 | save_checkpoint({
90 | 'epoch': epoch + 1,
91 | 'arch': network_config.encoder,
92 | 'state_dict': mapper.state_dict(),
93 | 'best_prec1': best_prec1,
94 | }, is_best)
95 |
96 | def train(train_loader, mapper, criterion, optimizer, epoch, log):
97 | batch_time = AverageMeter()
98 | data_time = AverageMeter()
99 | losses = AverageMeter()
100 | top1 = AverageMeter()
101 | top5 = AverageMeter()
102 |
103 | mapper.train()
104 |
105 | end = time.time()
106 | for step, data in enumerate(train_loader):
107 | rgb_stack, target = data
108 | data_time.update(time.time() - end)
109 | target_var = target.cuda()
110 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda()
111 | output = mapper(input_rgb_var)
112 | target_var = target_var.view(-1)
113 | output = output.view(-1, args.num_class)
114 | loss = criterion(output, target_var)
115 | losses.update(loss.item(), input_rgb_var.size(0))
116 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5))
117 | top1.update(prec1.item(), rgb_stack.size(0))
118 | top5.update(prec5.item(), rgb_stack.size(0))
119 |
120 | optimizer.zero_grad()
121 |
122 | loss.backward()
123 | optimizer.step()
124 |
125 | batch_time.update(time.time() - end)
126 | end = time.time()
127 |
128 | if step % args.print_freq == 0:
129 | output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
130 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
131 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
132 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
133 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
134 | epoch + 1, step + 1, len(train_loader), batch_time=batch_time,
135 | data_time=data_time, loss=losses, top1=top1, lr=optimizer.param_groups[-1]['lr']))
136 | print(output)
137 | log.write(output + '\n')
138 | log.flush()
139 |
140 | def eval(val_loader, mapper, criterion, log, epoch):
141 | batch_time = AverageMeter()
142 | data_time = AverageMeter()
143 | losses = AverageMeter()
144 | top1 = AverageMeter()
145 | top5 = AverageMeter()
146 |
147 | mapper.eval()
148 |
149 | end = time.time()
150 | for step, (rgb_stack, target) in enumerate(val_loader):
151 | data_time.update(time.time() - end)
152 | with torch.no_grad():
153 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda()
154 | output = mapper(input_rgb_var)
155 | target_var = target.cuda()
156 | target_var = target_var.view(-1)
157 | output = output.view(-1, args.num_class)
158 | loss = criterion(output, target_var)
159 | losses.update(loss.item(), input_rgb_var.size(0))
160 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5))
161 | top1.update(prec1.item(), rgb_stack.size(0))
162 | top5.update(prec5.item(), rgb_stack.size(0))
163 |
164 | batch_time.update(time.time() - end)
165 | end = time.time()
166 |
167 | if step % args.print_freq == 0:
168 | output = ('Test: [{0}][{1}/{2}]\t'
169 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
170 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
171 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
172 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
173 | epoch + 1, step + 1, len(val_loader), batch_time=batch_time,
174 | data_time=data_time, loss=losses, top1=top1))
175 | print(output)
176 | log.write(output + '\n')
177 | log.flush()
178 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Loss {loss.avg:.5f}'
179 | .format(top1=top1, loss=losses))
180 | print(output)
181 | output_best = '\nBest Prec@1: %.3f' % (best_prec1)
182 | print(output_best)
183 | log.write(output + ' ' + output_best + '\n')
184 | log.flush()
185 | return top1.avg
186 |
187 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
188 | torch.save(state, '%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name))
189 | if is_best:
190 | shutil.copyfile('%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name), '%s/%s_best.pth.tar' % (args.root_model, args.store_name))
191 |
192 |
193 | def adjust_learning_rate(optimizer, epoch, lr_steps):
194 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
195 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
196 | lr = args.start_lr * decay
197 | decay = args.weight_decay
198 | for param_group in optimizer.param_groups:
199 | param_group['lr'] = lr
200 | param_group['weight_decay'] = decay
201 |
202 | class AverageMeter(object):
203 | """Computes and stores the average and current value"""
204 | def __init__(self):
205 | self.reset()
206 |
207 | def reset(self):
208 | self.val = 0
209 | self.avg = 0
210 | self.sum = 0
211 | self.count = 0
212 |
213 | def update(self, val, n=1):
214 | self.val = val
215 | self.sum += val * n
216 | self.count += n
217 | self.avg = self.sum / self.count
218 |
219 | def accuracy(output, target, topk=(1,)):
220 | """Computes the precision@k for the specified values of k"""
221 | maxk = max(topk)
222 | batch_size = target.size(0)
223 |
224 | _, pred = output.topk(maxk, 1, True, True)
225 | pred = pred.t()
226 | correct = pred.eq(target.view(1, -1).expand_as(pred))
227 |
228 | res = []
229 | for k in topk:
230 | correct_k = correct[:k].view(-1).float().sum(0)
231 | res.append(correct_k.mul_(100.0 / batch_size))
232 | return res
233 |
234 | if __name__=='__main__':
235 | main()
236 |
--------------------------------------------------------------------------------
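
adjust_learning_rate in train.py implements a step schedule: the base learning rate is multiplied by 0.1 once for every milestone in `lr_steps` that the current epoch has already passed, while the weight decay stays constant. A minimal sketch with illustrative milestones:

```python
import numpy as np

def step_lr(start_lr, epoch, lr_steps):
    """Multiply the base LR by 0.1 for every milestone the epoch has passed."""
    decay = 0.1 ** int(np.sum(epoch >= np.array(lr_steps)))
    return start_lr * decay

# illustrative schedule: decay at epochs 30 and 60
for epoch in (0, 29, 30, 59, 60, 90):
    print(epoch, step_lr(2.5e-4, epoch, [30, 60]))
```
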
/train_carla.py:
--------------------------------------------------------------------------------
1 | from utils import Foo
2 | from models import VPNModel
3 | from datasets import OVMDataset
4 | from opts import parser
5 | from transform import *
6 | import torchvision
7 | import torch
8 | from torch import nn
9 | from torch.optim.lr_scheduler import MultiStepLR
10 | from torch import optim
11 | import os
12 | import time
13 | from torch.nn.utils import clip_grad_norm
14 | # from examples.cognitive_mapping.Logger import Logger
15 | import cv2
16 | import shutil
17 | import numpy as np
18 | mean_rgb = [0.485, 0.456, 0.406]
19 | std_rgb = [0.229, 0.224, 0.225]
20 |
21 | def main():
22 | global args, best_prec1
23 | best_prec1 = 0
24 | args = parser.parse_args()
25 | network_config = Foo(
26 | encoder=args.encoder,
27 | decoder=args.decoder,
28 | fc_dim=args.fc_dim,
29 | output_size=args.label_resolution,
30 | num_views=args.n_views,
31 | num_class=args.num_class,
32 | transform_type=args.transform_type,
33 | )
34 | train_dataset = OVMDataset(args.data_root, args.train_list,
35 | transform=torchvision.transforms.Compose([
36 | Stack(roll=True),
37 | ToTorchFormatTensor(div=True),
38 | GroupNormalize(mean_rgb, std_rgb)
39 | ]),
40 | num_views=network_config.num_views, input_size=args.input_resolution,
41 | label_size=args.label_resolution, use_mask=args.use_mask, use_depth=args.use_depth)
42 | val_dataset = OVMDataset(args.data_root, args.eval_list,
43 | transform=torchvision.transforms.Compose([
44 | Stack(roll=True),
45 | ToTorchFormatTensor(div=True),
46 | GroupNormalize(mean_rgb, std_rgb)
47 | ]),
48 | num_views=network_config.num_views, input_size=args.input_resolution,
49 | label_size=args.label_resolution, use_mask=args.use_mask, use_depth=args.use_depth)
50 | train_loader = torch.utils.data.DataLoader(
51 | train_dataset, batch_size=args.batch_size,
52 | num_workers=args.num_workers, shuffle=True,
53 | pin_memory=True
54 | )
55 |
56 | val_loader = torch.utils.data.DataLoader(
57 | val_dataset, batch_size=args.batch_size,
58 | num_workers=args.num_workers, shuffle=False,
59 | pin_memory=True
60 | )
61 |
62 | mapper = VPNModel(network_config)
63 | mapper = nn.DataParallel(mapper.cuda())
64 |
65 | if args.resume:
66 | if os.path.isfile(args.resume):
67 | print(("=> loading checkpoint '{}'".format(args.resume)))
68 | checkpoint = torch.load(args.resume)
69 | args.start_epoch = checkpoint['epoch']
70 | mapper.load_state_dict(checkpoint['state_dict'])
71 | print(("=> loaded checkpoint '{}' (epoch {})"
72 | .format(args.resume, checkpoint['epoch'])))
73 | else:
74 | print(("=> no checkpoint found at '{}'".format(args.resume)))
75 |
76 |
77 | criterion = nn.NLLLoss(weight=None, size_average=True)
78 | optimizer = optim.Adam(mapper.parameters(),
79 | lr=args.start_lr, betas=(0.95, 0.999))
80 |
81 | if not os.path.isdir(args.log_root):
82 | os.mkdir(args.log_root)
83 | log_train = open(os.path.join(args.log_root, '%s.csv' % args.store_name), 'w')
84 |
85 | for epoch in range(args.start_epoch, args.epochs):
86 | adjust_learning_rate(optimizer, epoch, args.lr_steps)
87 | train(train_loader, mapper, criterion, optimizer, epoch, log_train)
88 |
89 | if (epoch + 1) % args.ckpt_freq == 0 or epoch == args.epochs - 1:
90 | prec1 = eval(val_loader, mapper, criterion, log_train, epoch)
91 | is_best = prec1 > best_prec1
92 | best_prec1 = max(prec1, best_prec1)
93 | save_checkpoint({
94 | 'epoch': epoch + 1,
95 | 'arch': network_config.encoder,
96 | 'state_dict': mapper.state_dict(),
97 | 'best_prec1': best_prec1,
98 | }, is_best)
99 |
100 | def train(train_loader, mapper, criterion, optimizer, epoch, log):
101 | batch_time = AverageMeter()
102 | data_time = AverageMeter()
103 | losses = AverageMeter()
104 | top1 = AverageMeter()
105 | top5 = AverageMeter()
106 |
107 | mapper.train()
108 |
109 | end = time.time()
110 | for step, data in enumerate(train_loader):
111 | rgb_stack, target = data
112 | data_time.update(time.time() - end)
113 | target_var = target.cuda()
114 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda()
115 | output = mapper(input_rgb_var)
116 | target_var = target_var.view(-1)
117 | output = output.view(-1, args.num_class)
118 | loss = criterion(output, target_var)
119 | losses.update(loss.item(), input_rgb_var.size(0))
120 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5))
121 | top1.update(prec1.item(), rgb_stack.size(0))
122 | top5.update(prec5.item(), rgb_stack.size(0))
123 |
124 | optimizer.zero_grad()
125 |
126 | loss.backward()
127 | optimizer.step()
128 |
129 | batch_time.update(time.time() - end)
130 | end = time.time()
131 |
132 | if step % args.print_freq == 0:
133 | output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
134 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
135 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
136 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
137 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
138 | epoch + 1, step + 1, len(train_loader), batch_time=batch_time,
139 | data_time=data_time, loss=losses, top1=top1, lr=optimizer.param_groups[-1]['lr']))
140 | print(output)
141 | log.write(output + '\n')
142 | log.flush()
143 |
144 | def eval(val_loader, mapper, criterion, log, epoch):
145 | batch_time = AverageMeter()
146 | data_time = AverageMeter()
147 | losses = AverageMeter()
148 | top1 = AverageMeter()
149 | top5 = AverageMeter()
150 |
151 | mapper.eval()
152 |
153 | end = time.time()
154 | for step, (rgb_stack, target) in enumerate(val_loader):
155 | data_time.update(time.time() - end)
156 | with torch.no_grad():
157 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda()
158 | output = mapper(input_rgb_var)
159 | target_var = target.cuda()
160 | target_var = target_var.view(-1)
161 | output = output.view(-1, args.num_class)
162 | loss = criterion(output, target_var)
163 | losses.update(loss.item(), input_rgb_var.size(0))
164 | prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5))
165 | top1.update(prec1.item(), rgb_stack.size(0))
166 | top5.update(prec5.item(), rgb_stack.size(0))
167 |
168 | batch_time.update(time.time() - end)
169 | end = time.time()
170 |
171 | if step % args.print_freq == 0:
172 | output = ('Test: [{0}][{1}/{2}]\t'
173 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
174 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
175 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
176 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
177 | epoch + 1, step + 1, len(val_loader), batch_time=batch_time,
178 | data_time=data_time, loss=losses, top1=top1))
179 | print(output)
180 | log.write(output + '\n')
181 | log.flush()
182 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Loss {loss.avg:.5f}'
183 | .format(top1=top1, loss=losses))
184 | print(output)
185 | output_best = '\nBest Prec@1: %.3f' % (best_prec1)
186 | print(output_best)
187 | log.write(output + ' ' + output_best + '\n')
188 | log.flush()
189 | return top1.avg
190 |
191 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
192 | torch.save(state, '%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name))
193 | if is_best:
194 | shutil.copyfile('%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name), '%s/%s_best.pth.tar' % (args.root_model, args.store_name))
195 |
196 |
197 | def adjust_learning_rate(optimizer, epoch, lr_steps):
198 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
199 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
200 | lr = args.start_lr * decay
201 | decay = args.weight_decay
202 | for param_group in optimizer.param_groups:
203 | param_group['lr'] = lr
204 | param_group['weight_decay'] = decay
205 |
206 | class AverageMeter(object):
207 | """Computes and stores the average and current value"""
208 | def __init__(self):
209 | self.reset()
210 |
211 | def reset(self):
212 | self.val = 0
213 | self.avg = 0
214 | self.sum = 0
215 | self.count = 0
216 |
217 | def update(self, val, n=1):
218 | self.val = val
219 | self.sum += val * n
220 | self.count += n
221 | self.avg = self.sum / self.count
222 |
223 | def accuracy(output, target, topk=(1,)):
224 | """Computes the precision@k for the specified values of k"""
225 | maxk = max(topk)
226 | batch_size = target.size(0)
227 |
228 | _, pred = output.topk(maxk, 1, True, True)
229 | pred = pred.t()
230 | correct = pred.eq(target.view(1, -1).expand_as(pred))
231 |
232 | res = []
233 | for k in topk:
234 | correct_k = correct[:k].view(-1).float().sum(0)
235 | res.append(correct_k.mul_(100.0 / batch_size))
236 | return res
237 |
238 | if __name__=='__main__':
239 | main()
240 |
--------------------------------------------------------------------------------
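
Both training scripts save the state_dict of the nn.DataParallel wrapper, so every parameter key carries a `module.` prefix; the test scripts can load the checkpoints directly because they wrap the model in nn.DataParallel first. If a checkpoint ever needs to be loaded into an unwrapped model, the prefix has to be stripped. A small sketch of that (the helper name and paths are hypothetical, not part of this repo):

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel adds to parameter names."""
    return OrderedDict((k.replace('module.', '', 1), v) for k, v in state_dict.items())

# toy demonstration with fake keys
wrapped = OrderedDict([('module.encoder.conv1.weight', 0), ('module.decoder.fc.bias', 1)])
print(list(strip_module_prefix(wrapped).keys()))

# usage sketch (path and model are illustrative):
#   checkpoint = torch.load('VPN_best.pth.tar', map_location='cpu')
#   model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))
```
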
/train_transfer.py:
--------------------------------------------------------------------------------
1 | from utils import Foo
2 | from models import VPNModel, FCDiscriminator
3 | from datasets import House3D_Dataset, MP3D_Dataset, Carla_Dataset, nuScenes_Dataset
4 | from opts import parser
5 | from transform import *
6 | import torchvision
7 | import torch
8 | from torch import nn
9 | from torch.optim.lr_scheduler import MultiStepLR
10 | from torch import optim
11 | import os
12 | import time
13 | from torch.nn.utils import clip_grad_norm
14 |
15 | import shutil
16 | import torch.nn.functional as F
17 | import os.path as osp
18 | from tensorboardX import SummaryWriter
19 | import argparse
20 |
21 | mean_rgb = [0.485, 0.456, 0.406]
22 | std_rgb = [0.229, 0.224, 0.225]
23 |
24 | def str2bool(v):
25 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
26 | return True
27 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
28 | return False
29 | else:
30 | raise argparse.ArgumentTypeError('Unsupported value encountered.')
31 |
32 | def main():
33 | global args, best_prec1
34 | best_prec1 = 0
35 |
36 | parser.add_argument('--source_dir', type=str,
37 | default='/mnt/lustre/share/VPN_driving_scene/TopViewMaskDataset')
38 | parser.add_argument('--target_dir', type=str,
39 | default='/mnt/lustre/share/VPN_driving_scene/mp3d')
40 | parser.add_argument('--num-steps', type=int, default=250000)
41 | parser.add_argument('--iter-size-G', type=int, default=3)
42 | parser.add_argument('--iter-size-D', type=int, default=1)
43 | parser.add_argument('--learning-rate-D', type=float, default=1e-4)
44 | parser.add_argument('--learning-rate', type=float, default=2.5e-4)
45 | parser.add_argument("--SegSize", type=int, default=128,
46 | help="Comma-separated string with height and width of source images.")
47 | parser.add_argument("--SegSize-target", type=int, default=128,
48 | help="Comma-separated string with height and width of target images.")
49 | parser.add_argument('--resume-D', type=str)
50 | parser.add_argument('--resume-G', type=str)
51 | parser.add_argument('--train_source_list', type=str, default='./train_source_list.txt')
52 | parser.add_argument('--train_target_list', type=str, default='./train_target_list.txt')
53 | parser.add_argument('--num-classes', type=int, default=94)
54 | parser.add_argument('--power', type=float, default=0.9)
55 | parser.add_argument('--lambda_adv_target', type=float, default=0.001)
56 | parser.add_argument('--num_steps_stop', type=int, default=150000)
57 | parser.add_argument('--save_pred_every', type=int, default=5000)
58 | parser.add_argument('--snapshot-dir', type=str, default='/mnt/lustre/panbowen/VPN-transfer/snapshot/')
59 | parser.add_argument("--tensorboard", type=str2bool, default=True)
60 | parser.add_argument("--tf-logdir", type=str, default='/mnt/lustre/panbowen/VPN-transfer/tf_log/',
61 | help="Path to the directory of log.")
62 | parser.add_argument('--VPN-weights', type=str)
63 | parser.add_argument('--task-id', type=str)
64 | parser.add_argument('--scenario', type=str, default='indoor')
65 |
66 | args = parser.parse_args()
67 |
68 | network_config = Foo(
69 | encoder=args.encoder,
70 | decoder=args.decoder,
71 | fc_dim=args.fc_dim,
72 | output_size=args.label_resolution,
73 | num_views=args.n_views,
74 | num_class=94,
75 | transform_type=args.transform_type,
76 | )
77 |
78 | if args.scenario == 'indoor':
79 | train_source_dataset = House3D_Dataset(args.source_dir, args.train_source_list,
80 | transform=torchvision.transforms.Compose([
81 | Stack(roll=True),
82 | ToTorchFormatTensor(div=True),
83 | GroupNormalize(mean_rgb, std_rgb)
84 | ]),
85 | num_views=network_config.num_views, input_size=args.input_resolution,
86 | label_size=args.SegSize)
87 | train_target_dataset = MP3D_Dataset(args.target_dir, args.train_target_list,
88 | transform=torchvision.transforms.Compose([
89 | Stack(roll=True),
90 | ToTorchFormatTensor(div=True),
91 | GroupNormalize(mean_rgb, std_rgb)
92 | ]),
93 | num_views=network_config.num_views, input_size=args.input_resolution,
94 | label_size=args.SegSize_target)
95 | elif args.scenario == 'traffic':
96 | train_source_dataset = Carla_Dataset(args.source_dir, args.train_source_list,
97 | transform=torchvision.transforms.Compose([
98 | Stack(roll=True),
99 | ToTorchFormatTensor(div=True),
100 | GroupNormalize(mean_rgb, std_rgb)
101 | ]),
102 | num_views=network_config.num_views, input_size=args.input_resolution,
103 | label_size=args.SegSize)
104 | train_target_dataset = nuScenes_Dataset(args.target_dir, args.train_target_list,
105 | transform=torchvision.transforms.Compose([
106 | Stack(roll=True),
107 | ToTorchFormatTensor(div=True),
108 | GroupNormalize(mean_rgb, std_rgb)
109 | ]),
110 | num_views=network_config.num_views, input_size=args.input_resolution,
111 | label_size=args.SegSize_target)
112 |
113 | source_loader = torch.utils.data.DataLoader(
114 | train_source_dataset, batch_size=args.batch_size,
115 | num_workers=args.num_workers, shuffle=True,
116 | pin_memory=True
117 | )
118 |
119 | target_loader = torch.utils.data.DataLoader(
120 | train_target_dataset, batch_size=args.batch_size,
121 | num_workers=args.num_workers, shuffle=True,
122 | pin_memory=True
123 | )
124 |
125 | mapper = VPNModel(network_config)
126 | mapper = nn.DataParallel(mapper.cuda())
127 | mapper.train()
128 |
129 | model_D1 = FCDiscriminator(num_classes=args.num_classes)
130 | model_D1 = nn.DataParallel(model_D1.cuda())
131 | model_D1.train()
132 |
133 | if args.VPN_weights:
134 | if os.path.isfile(args.VPN_weights):
135 | print(("=> loading checkpoint '{}'".format(args.VPN_weights)))
136 | checkpoint = torch.load(args.VPN_weights)
137 | args.start_epoch = checkpoint['epoch']
138 | mapper.load_state_dict(checkpoint['state_dict'])
139 | print(("=> loaded checkpoint '{}' (epoch {})"
140 | .format(args.VPN_weights, checkpoint['epoch'])))
141 | else:
142 | print(("=> no checkpoint found at '{}'".format(args.VPN_weights)))
143 |
144 | resume_iter = None
145 | if args.resume_G:
146 | if os.path.isfile(args.resume_G):
147 | print(("=> loading checkpoint '{}'".format(args.resume_G)))
148 | state_dict = torch.load(args.resume_G)
149 | mapper.load_state_dict(state_dict)
150 |             print(("=> loaded checkpoint '{}'"
151 |                    .format(args.resume_G)))
152 | resume_iter = int(args.resume_G.split('_')[-1].split('.')[0])
153 | else:
154 | print(("=> no checkpoint found at '{}'".format(args.resume_G)))
155 |
156 | if args.resume_D:
157 | if os.path.isfile(args.resume_D):
158 | print(("=> loading checkpoint '{}'".format(args.resume_D)))
159 | checkpoint = torch.load(args.resume_D)
160 | args.start_epoch = checkpoint['epoch']
161 | model_D1.load_state_dict(checkpoint['state_dict'])
162 |             print(("=> loaded checkpoint '{}' (epoch {})"
163 |                    .format(args.resume_D, checkpoint['epoch'])))
164 | else:
165 | print(("=> no checkpoint found at '{}'".format(args.resume_D)))
166 |
167 | optimizer = optim.SGD(mapper.parameters(),
168 | lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
169 | optimizer.zero_grad()
170 |
171 | optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
172 | optimizer_D1.zero_grad()
173 |
174 |
175 | criterion_seg = nn.NLLLoss(weight=None, size_average=True)
176 | criterion_bce = nn.BCEWithLogitsLoss()
177 |
178 | train(source_loader, target_loader, mapper, model_D1,
179 | criterion_seg, criterion_bce, optimizer, optimizer_D1, resume_iter)
180 |
181 | def train(source_loader, target_loader, mapper, model_D1, seg_loss, bce_loss, optimizer, optimizer_D1, resume_iter):
182 | source_loader_iter = enumerate(source_loader)
183 |
184 | target_loader_iter = enumerate(target_loader)
185 |
186 | # set up tensor board
187 | if args.tensorboard:
188 | if not os.path.exists(os.path.join(args.tf_logdir, args.task_id)):
189 | os.makedirs(os.path.join(args.tf_logdir, args.task_id))
190 |
191 | writer = SummaryWriter(os.path.join(args.tf_logdir, args.task_id))
192 |
193 | interp = nn.Upsample(size=(args.SegSize, args.SegSize), mode='bilinear', align_corners=True)
194 | interp_target = nn.Upsample(size=(args.SegSize_target, args.SegSize_target), mode='bilinear', align_corners=True)
195 | source_label = 0
196 | target_label = 1
197 |
198 | for i_iter in range(args.num_steps):
199 | if resume_iter is not None and i_iter < resume_iter:
200 | continue
201 |
202 | loss_seg_value = 0
203 | loss_adv_target_value = 0
204 | loss_D_value = 0
205 |
206 | optimizer.zero_grad()
207 | adjust_learning_rate(optimizer, i_iter)
208 |
209 | optimizer_D1.zero_grad()
210 | adjust_learning_rate_D(optimizer_D1, i_iter)
211 |
212 | # train G
213 | # don't accumulate grads in D
214 | for param in model_D1.parameters():
215 | param.requires_grad = False
216 |
217 | for sub_i in range(args.iter_size_G):
218 | # train with source
219 |
220 |             try:
221 |                 _, batch = next(source_loader_iter)
222 |             except StopIteration:
223 |                 source_loader_iter = enumerate(source_loader)
224 |                 _, batch = next(source_loader_iter)
225 | rgb_stack, label = batch
226 | label_var = label.cuda()
227 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda()
228 |
229 | _, pred_feat = mapper(input_rgb_var, return_feat=True)
230 |
231 | pred_feat = pred_feat.transpose(3, 2).transpose(2, 1).contiguous()
232 | pred = interp(pred_feat)
233 |
234 | pred = F.log_softmax(pred, dim=1)
235 | pred = pred.transpose(1, 2).transpose(2, 3).contiguous()
236 |
237 | label_var = label_var.view(-1)
238 |             output = pred.view(-1, args.num_classes)
239 |
240 | loss_seg = seg_loss(output, label_var)
241 | loss = loss_seg / args.iter_size_G
242 | loss.backward()
243 | loss_seg_value += loss_seg.item() / args.iter_size_G
244 |
245 | # train with target
246 |             try:
247 |                 _, batch = next(target_loader_iter)
248 |             except StopIteration:
249 |                 target_loader_iter = enumerate(target_loader)
250 |                 _, batch = next(target_loader_iter)
251 | rgb_stack = batch
252 | input_rgb_var = torch.autograd.Variable(rgb_stack).cuda()
253 | _, pred_target = mapper(input_rgb_var, return_feat=True)
254 | pred_target = pred_target.transpose(3, 2).transpose(2, 1).contiguous()
255 | pred_target = interp_target(pred_target)
256 | pred_target = F.log_softmax(pred_target, dim=1)
257 | D_out = model_D1(torch.exp(pred_target))
258 |
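            # Adversarial term for the generator: the target-domain prediction is
            # labelled as "source" so the mapper learns to fool the discriminator.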
259 | loss_adv_target = bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).cuda())
260 | loss = args.lambda_adv_target * loss_adv_target / args.iter_size_G
261 | loss.backward()
262 | loss_adv_target_value += loss_adv_target.item() / args.iter_size_G
263 |
264 |
265 | # train D
266 | # bring back requires_grad
267 | for param in model_D1.parameters():
268 | param.requires_grad = True
269 |
270 | for sub_i in range(args.iter_size_D):
271 |
272 | # train with source
273 |
274 | pred = pred.detach()
275 | pred = pred.transpose(3, 2).transpose(2, 1).contiguous()
276 | D_out = model_D1(torch.exp(pred))
277 |
278 | loss_D = bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).cuda())
279 | loss_D = loss_D / args.iter_size_D / 2
280 | loss_D.backward()
281 |
282 | loss_D_value += loss_D.item()
283 |
284 | # train with target
285 | pred_target = pred_target.detach()
286 |
287 | D_out = model_D1(torch.exp(pred_target))
288 |
289 | loss_D = bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(target_label).cuda())
290 |
291 | loss_D = loss_D / args.iter_size_D / 2
292 | loss_D.backward()
293 | loss_D_value += loss_D.item()
294 |
295 | optimizer.step()
296 | optimizer_D1.step()
297 |
298 | if args.tensorboard:
299 | scalar_info = {
300 | 'loss_seg': loss_seg_value,
301 | 'loss_adv': loss_adv_target_value,
302 | 'loss_D': loss_D_value,
303 | }
304 |
305 | if i_iter % 10 == 0:
306 | print(args.tf_logdir)
307 | for key, val in scalar_info.items():
308 | writer.add_scalar(key, val, i_iter)
309 |
310 | print(
311 | 'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f}, loss_D = {4:.3f} '.format(
312 | i_iter, args.num_steps, loss_seg_value, loss_adv_target_value, loss_D_value))
313 |
314 | if i_iter >= args.num_steps_stop - 1:
315 | print('save model ...')
316 | torch.save(mapper.state_dict(), osp.join(args.snapshot_dir, 'House3D_' + str(args.num_steps_stop) + '.pth'))
317 | torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'House3D_' + str(args.num_steps_stop) + '_D.pth'))
318 | break
319 |
320 | if i_iter % args.save_pred_every == 0 and i_iter != 0:
321 | print('taking snapshot ...')
322 | torch.save(mapper.state_dict(), osp.join(args.snapshot_dir, 'House3D_' + str(i_iter) + '.pth'))
323 | torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'House3D_' + str(i_iter) + '_D.pth'))
324 |
325 | if args.tensorboard:
326 | writer.close()
327 |
328 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
329 | torch.save(state, '%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name))
330 | if is_best:
331 | shutil.copyfile('%s/%s_checkpoint.pth.tar' % (args.root_model, args.store_name), '%s/%s_best.pth.tar' % (args.root_model, args.store_name))
332 |
333 | def lr_poly(base_lr, iter, max_iter, power):
334 | return base_lr * ((1 - float(iter) / max_iter) ** (power))
335 |
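# lr_poly implements the usual polynomial ("poly") decay used by the two helpers
# below: lr(i) = base_lr * (1 - i / num_steps) ** power. With the defaults here
# (base_lr=2.5e-4, power=0.9, num_steps=250000) the rate is roughly 1.34e-4 at the
# halfway point and reaches 0 at the final step.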
336 |
337 | def adjust_learning_rate(optimizer, i_iter):
338 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
339 | optimizer.param_groups[0]['lr'] = lr
340 | if len(optimizer.param_groups) > 1:
341 | optimizer.param_groups[1]['lr'] = lr * 10
342 |
343 |
344 | def adjust_learning_rate_D(optimizer, i_iter):
345 | lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power)
346 | optimizer.param_groups[0]['lr'] = lr
347 | if len(optimizer.param_groups) > 1:
348 | optimizer.param_groups[1]['lr'] = lr * 10
349 |
350 | class AverageMeter(object):
351 | """Computes and stores the average and current value"""
352 | def __init__(self):
353 | self.reset()
354 |
355 | def reset(self):
356 | self.val = 0
357 | self.avg = 0
358 | self.sum = 0
359 | self.count = 0
360 |
361 | def update(self, val, n=1):
362 | self.val = val
363 | self.sum += val * n
364 | self.count += n
365 | self.avg = self.sum / self.count
366 |
367 | def accuracy(output, target, topk=(1,)):
368 | """Computes the precision@k for the specified values of k"""
369 | maxk = max(topk)
370 | batch_size = target.size(0)
371 |
372 | _, pred = output.topk(maxk, 1, True, True)
373 | pred = pred.t()
374 | correct = pred.eq(target.view(1, -1).expand_as(pred))
375 |
376 | res = []
377 | for k in topk:
378 | correct_k = correct[:k].view(-1).float().sum(0)
379 | res.append(correct_k.mul_(100.0 / batch_size))
380 | return res
381 |
382 | if __name__=='__main__':
383 | main()
384 |
--------------------------------------------------------------------------------
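The training loop above follows the standard adversarial domain-adaptation recipe: the VPN mapper acts as the generator, and `FCDiscriminator` tries to tell source-domain predictions from target-domain ones. The toy sketch below is only a distilled restatement of that alternating G/D update on random tensors; the `G`/`D` stand-ins, the random inputs, and the hyper-parameter values are illustrative, not part of the repository.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes, lambda_adv = 94, 0.001
G = nn.Conv2d(3, num_classes, 1)       # stand-in for the VPN mapper
D = nn.Conv2d(num_classes, 1, 1)       # stand-in for FCDiscriminator
opt_G = torch.optim.SGD(G.parameters(), lr=2.5e-4, momentum=0.9)
opt_D = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.9, 0.99))
bce = nn.BCEWithLogitsLoss()

src_rgb = torch.randn(2, 3, 128, 128)
src_lbl = torch.randint(0, num_classes, (2, 128, 128))
tgt_rgb = torch.randn(2, 3, 128, 128)

opt_G.zero_grad(); opt_D.zero_grad()

# 1) generator step: supervised loss on source + adversarial loss on target
for p in D.parameters():
    p.requires_grad = False
pred_src = G(src_rgb)
loss_seg = F.cross_entropy(pred_src, src_lbl)
pred_tgt = G(tgt_rgb)
d_out = D(F.softmax(pred_tgt, dim=1))
loss_adv = bce(d_out, torch.zeros_like(d_out))   # pretend target is "source" (label 0)
(loss_seg + lambda_adv * loss_adv).backward()

# 2) discriminator step on detached predictions: source -> 0, target -> 1
for p in D.parameters():
    p.requires_grad = True
d_src = D(F.softmax(pred_src.detach(), dim=1))
d_tgt = D(F.softmax(pred_tgt.detach(), dim=1))
(0.5 * (bce(d_src, torch.zeros_like(d_src)) + bce(d_tgt, torch.ones_like(d_tgt)))).backward()

opt_G.step(); opt_D.step()
```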
/transform.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import random
3 | from PIL import Image, ImageOps
4 | import numpy as np
5 | import numbers
6 | import math
7 | import torch
8 |
9 |
10 | class GroupRandomCrop(object):
11 | def __init__(self, size):
12 | if isinstance(size, numbers.Number):
13 | self.size = (int(size), int(size))
14 | else:
15 | self.size = size
16 |
17 | def __call__(self, img_group):
18 |
19 | w, h = img_group[0].size
20 | th, tw = self.size
21 |
22 | out_images = list()
23 |
24 | x1 = random.randint(0, w - tw)
25 | y1 = random.randint(0, h - th)
26 |
27 | for img in img_group:
28 | assert(img.size[0] == w and img.size[1] == h)
29 | if w == tw and h == th:
30 | out_images.append(img)
31 | else:
32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
33 |
34 | return out_images
35 |
36 |
37 | class GroupCenterCrop(object):
38 | def __init__(self, size):
39 | self.worker = torchvision.transforms.CenterCrop(size)
40 |
41 | def __call__(self, img_group):
42 | return [self.worker(img) for img in img_group]
43 |
44 |
45 | class GroupRandomHorizontalFlip(object):
46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
47 | """
48 | def __init__(self, is_flow=False):
49 | self.is_flow = is_flow
50 |
51 | def __call__(self, img_group, is_flow=False):
52 | v = random.random()
53 | if v < 0.5:
54 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
55 | if self.is_flow:
56 | for i in range(0, len(ret), 2):
57 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping
58 | return ret
59 | else:
60 | return img_group
61 |
62 |
63 | class GroupNormalize(object):
64 | def __init__(self, mean, std):
65 | self.mean = mean
66 | self.std = std
67 |
68 | def __call__(self, tensor):
69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
70 | rep_std = self.std * (tensor.size()[0]//len(self.std))
71 |
72 | # TODO: make efficient
73 | for t, m, s in zip(tensor, rep_mean, rep_std):
74 | t.sub_(m).div_(s)
75 |
76 | return tensor
77 |
78 |
79 | class GroupScale(object):
80 | """ Rescales the input PIL.Image to the given 'size'.
81 | 'size' will be the size of the smaller edge.
82 | For example, if height > width, then image will be
83 | rescaled to (size * height / width, size)
84 | size: size of the smaller edge
85 | interpolation: Default: PIL.Image.BILINEAR
86 | """
87 |
88 | def __init__(self, size, interpolation=Image.BILINEAR):
89 | self.worker = torchvision.transforms.Scale(size, interpolation)
90 |
91 | def __call__(self, img_group):
92 | return [self.worker(img) for img in img_group]
93 |
94 |
95 | class GroupOverSample(object):
96 | def __init__(self, crop_size, scale_size=None):
97 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
98 |
99 | if scale_size is not None:
100 | self.scale_worker = GroupScale(scale_size)
101 | else:
102 | self.scale_worker = None
103 |
104 | def __call__(self, img_group):
105 |
106 | if self.scale_worker is not None:
107 | img_group = self.scale_worker(img_group)
108 |
109 | image_w, image_h = img_group[0].size
110 | crop_w, crop_h = self.crop_size
111 |
112 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)
113 | oversample_group = list()
114 | for o_w, o_h in offsets:
115 | normal_group = list()
116 | flip_group = list()
117 | for i, img in enumerate(img_group):
118 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
119 | normal_group.append(crop)
120 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
121 |
122 | if img.mode == 'L' and i % 2 == 0:
123 | flip_group.append(ImageOps.invert(flip_crop))
124 | else:
125 | flip_group.append(flip_crop)
126 |
127 | oversample_group.extend(normal_group)
128 | oversample_group.extend(flip_group)
129 | return oversample_group
130 |
131 |
132 | class GroupMultiScaleCrop(object):
133 |
134 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
135 |         self.scales = scales if scales is not None else [1, .875, .75, .66]
136 | self.max_distort = max_distort
137 | self.fix_crop = fix_crop
138 | self.more_fix_crop = more_fix_crop
139 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
140 | self.interpolation = Image.BILINEAR
141 |
142 | def __call__(self, img_group):
143 |
144 | im_size = img_group[0].size
145 |
146 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
147 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
148 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
149 | for img in crop_img_group]
150 | return ret_img_group
151 |
152 | def _sample_crop_size(self, im_size):
153 | image_w, image_h = im_size[0], im_size[1]
154 |
155 | # find a crop size
156 | base_size = min(image_w, image_h)
157 | crop_sizes = [int(base_size * x) for x in self.scales]
158 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
159 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
160 |
161 | pairs = []
162 | for i, h in enumerate(crop_h):
163 | for j, w in enumerate(crop_w):
164 | if abs(i - j) <= self.max_distort:
165 | pairs.append((w, h))
166 |
167 | crop_pair = random.choice(pairs)
168 | if not self.fix_crop:
169 | w_offset = random.randint(0, image_w - crop_pair[0])
170 | h_offset = random.randint(0, image_h - crop_pair[1])
171 | else:
172 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
173 |
174 | return crop_pair[0], crop_pair[1], w_offset, h_offset
175 |
176 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
177 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
178 | return random.choice(offsets)
179 |
180 | @staticmethod
181 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
182 | w_step = (image_w - crop_w) // 4
183 | h_step = (image_h - crop_h) // 4
184 |
185 | ret = list()
186 | ret.append((0, 0)) # upper left
187 | ret.append((4 * w_step, 0)) # upper right
188 | ret.append((0, 4 * h_step)) # lower left
189 | ret.append((4 * w_step, 4 * h_step)) # lower right
190 | ret.append((2 * w_step, 2 * h_step)) # center
191 |
192 | if more_fix_crop:
193 | ret.append((0, 2 * h_step)) # center left
194 | ret.append((4 * w_step, 2 * h_step)) # center right
195 | ret.append((2 * w_step, 4 * h_step)) # lower center
196 | ret.append((2 * w_step, 0 * h_step)) # upper center
197 |
198 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
199 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
200 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
201 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
202 |
203 | return ret
204 |
205 |
206 | class GroupRandomSizedCrop(object):
207 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
208 |     and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
209 | This is popularly used to train the Inception networks
210 | size: size of the smaller edge
211 | interpolation: Default: PIL.Image.BILINEAR
212 | """
213 | def __init__(self, size, interpolation=Image.BILINEAR):
214 | self.size = size
215 | self.interpolation = interpolation
216 |
217 | def __call__(self, img_group):
218 | for attempt in range(10):
219 | area = img_group[0].size[0] * img_group[0].size[1]
220 | target_area = random.uniform(0.08, 1.0) * area
221 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
222 |
223 | w = int(round(math.sqrt(target_area * aspect_ratio)))
224 | h = int(round(math.sqrt(target_area / aspect_ratio)))
225 |
226 | if random.random() < 0.5:
227 | w, h = h, w
228 |
229 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
230 | x1 = random.randint(0, img_group[0].size[0] - w)
231 | y1 = random.randint(0, img_group[0].size[1] - h)
232 | found = True
233 | break
234 | else:
235 | found = False
236 | x1 = 0
237 | y1 = 0
238 |
239 | if found:
240 | out_group = list()
241 | for img in img_group:
242 | img = img.crop((x1, y1, x1 + w, y1 + h))
243 | assert(img.size == (w, h))
244 | out_group.append(img.resize((self.size, self.size), self.interpolation))
245 | return out_group
246 | else:
247 | # Fallback
248 | scale = GroupScale(self.size, interpolation=self.interpolation)
249 | crop = GroupRandomCrop(self.size)
250 | return crop(scale(img_group))
251 |
252 |
253 | class Stack(object):
254 |
255 | def __init__(self, roll=False):
256 | self.roll = roll
257 |
258 | def __call__(self, img_group):
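        # With roll=True each frame's channels are reversed (RGB -> BGR) before the
        # frames are concatenated along the channel axis; roll=False keeps RGB order.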
259 | if self.roll:
260 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
261 | else:
262 | return np.concatenate(img_group, axis=2)
263 |
264 |
265 | class ToTorchFormatTensor(object):
266 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
267 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
268 | def __init__(self, div=True, float=True):
269 | self.div = div
270 | self.float = float
271 |
272 | def __call__(self, pic):
273 | if isinstance(pic, np.ndarray):
274 | # handle numpy array
275 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
276 | else:
277 | # handle PIL Image
278 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
279 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
280 | # put it from HWC to CHW format
281 | # yikes, this transpose takes 80% of the loading time/CPU
282 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
283 | if self.div:
284 | img = img.div(255)
285 | return img.float() if self.float else img.long()
286 |
287 |
288 | class IdentityTransform(object):
289 |
290 | def __call__(self, data):
291 | return data
292 |
293 |
294 | if __name__ == "__main__":
295 | trans = torchvision.transforms.Compose([
296 | GroupScale(256),
297 | GroupRandomCrop(224),
298 | Stack(),
299 | ToTorchFormatTensor(),
300 | GroupNormalize(
301 | mean=[.485, .456, .406],
302 | std=[.229, .224, .225]
303 | )]
304 | )
305 |
306 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
307 |
308 | color_group = [im] * 3
309 | rst = trans(color_group)
310 |
311 | gray_group = [im.convert('L')] * 9
312 | gray_rst = trans(gray_group)
313 |
314 | trans2 = torchvision.transforms.Compose([
315 | GroupRandomSizedCrop(256),
316 | Stack(),
317 | ToTorchFormatTensor(),
318 | GroupNormalize(
319 | mean=[.485, .456, .406],
320 | std=[.229, .224, .225])
321 | ])
322 | print(trans2(color_group))
323 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np, os, time
2 | from six.moves import xrange
3 | import logging
4 | import torch
5 | from torch.autograd import Variable
6 |
7 | class Foo(object):
8 | def __init__(self, **kwargs):
9 | self.__dict__.update(kwargs)
10 | def __str__(self):
11 | str_ = ''
12 | for v in vars(self).keys():
13 | a = getattr(self, v)
14 | str__ = str(a)
15 | str__ = str__.replace('\n', '\n ')
16 | str_ += '{:s}: {:s}'.format(v, str__)
17 | str_ += '\n'
18 | return str_
19 |
20 |
21 | def get_flow(t, theta, map_size):
22 | """
23 | Rotates the map by theta and translates the rotated map by t.
24 |
25 | Assume that the robot rotates by an angle theta and then moves forward by
26 | translation t. This function returns the flow field field. For every pixel in
27 | the new image it tells us which pixel in the original image it came from:
28 | NewI(x, y) = OldI(flow_x(x,y), flow_y(x,y)).
29 |
30 | Assume there is a point p in the original image. Robot rotates by R and moves
31 | forward by t. p1 = Rt*p; p2 = p1 - t; (the world moves in opposite direction.
32 | So, p2 = Rt*p - t, thus p2 came from R*(p2+t), which is what this function
33 | calculates.
34 |
35 | t: ... x 2 (translation for B batches of N motions each).
36 | theta: ... x 1 (rotation for B batches of N motions each).
37 |
38 | Output: ... x map_size x map_size x 2
39 | """
40 | B = t.view(-1, 2).size()[0]
41 |     tx, ty = torch.unbind(t.view(-1, 1, 1, 1, 2), dim=4)  # Bx1x1x1
42 |     theta = theta.view(-1, 1, 1, 1)
43 | # c = tf.constant((map_size - 1.) / 2., dtype=tf.float32)
44 |     c = Variable(torch.Tensor([(map_size[0] - 1.) / 2.]).double())
45 |     x, y = np.meshgrid(np.arange(map_size[0]), np.arange(map_size[1]))
46 |     x = Variable(torch.from_numpy(x).double()).view(1, map_size[0], map_size[1], 1)
47 |     y = Variable(torch.from_numpy(y).double()).view(1, map_size[0], map_size[1], 1)
48 | # x = tf.constant(x[np.newaxis, :, :, np.newaxis], dtype=tf.float32, name='x',
49 | # shape=[1, map_size, map_size, 1])
50 | # y = tf.constant(y[np.newaxis, :, :, np.newaxis], dtype=tf.float32, name='y',
51 | # shape=[1, map_size, map_size, 1])
52 |
53 | tx = tx - c.expand(tx.size())
54 |     x = x.expand(B, *x.size()[1:])
55 | x = x + tx.expand(x.size())
56 | ty = ty - c.expand(ty.size())
57 |     y = y.expand(B, *y.size()[1:])
58 | y = y + ty.expand(y.size()) # BxHxWx1
59 | # x = x - (-tx + c.expand(tx.size())) #1xHxWx1
60 | # y = y - (-ty + c.expand(ty.size()))
61 |
62 | sin_theta = torch.sin(theta) #Bx1x1x1
63 | cos_theta = torch.cos(theta)
64 | xr = x * cos_theta.expand(x.size()) - y * sin_theta.expand(y.size())
65 | yr = x * sin_theta.expand(x.size()) + y * cos_theta.expand(y.size()) # BxHxWx1
66 | # xr = cos_theta * x - sin_theta * y
67 | # yr = sin_theta * x + cos_theta * y
68 |
69 | xr = xr + c.expand(xr.size())
70 | yr = yr + c.expand(yr.size())
71 |
72 |     flow = torch.stack([xr, yr], dim=-1)
73 |     sh = list(t.size()[:-1]) + [map_size[0], map_size[1], 2]
74 | # sh = tf.unstack(tf.shape(t), axis=0)
75 | # sh = tf.stack(sh[:-1] + [tf.constant(_, dtype=tf.int32) for _ in [map_size, map_size, 2]])
76 |     flow = flow.view(sh)
77 | return flow
78 |
79 |
80 | def dense_resample(im, flow_im, output_valid_mask=False):
81 | """ Resample reward at particular locations.
82 | Args:
83 | im: ...xHxW matrix to sample from.
84 | flow_im: ...xHxWx2 matrix, samples the image using absolute offsets as given
85 | by the flow_im.
86 | """
87 | valid_mask = None
88 |
89 |     x, y = torch.unbind(flow_im, dim=-1)
90 | x = x.view(-1)
91 | y = y.view(-1)
92 |
93 | # constants
94 | # shape = tf.unstack(tf.shape(im))
95 | # channels = shape[-1]
96 | shape = im.size()
97 | width = shape[-1]
98 | height = shape[-2]
99 | num_batch = 1
100 | for dim in shape[:-2]:
101 | num_batch *= dim
102 | zero = Variable(torch.Tensor([0]).double())
103 | # num_batch = tf.cast(tf.reduce_prod(tf.stack(shape[:-3])), 'int32')
104 | # zero = tf.constant(0, dtype=tf.int32)
105 |
106 | # Round up and down.
107 | x0 = torch.floor(x)
108 | x1 = x0 + 1
109 | y0 = torch.floor(y)
110 | y1 = y0 + 1
111 |
112 | x0 = x0.clamp(0, width - 1)
113 | x1 = x1.clamp(0, width - 1)
114 | y0 = y0.clamp(0, height - 1)
115 | y1 = y1.clamp(0, height - 1)
116 | dim2 = width
117 | dim1 = width * height
118 |
119 | # Create base index
120 |     base = torch.arange(num_batch).float() * dim1
121 | base = base.view(-1, 1)
122 | # base = tf.reshape(tf.range(num_batch) * dim1, shape=[-1, 1])
123 |     base = base.expand(base.size()[0], height * width).contiguous().view(-1)  # batch_size * H * W
124 | # base = tf.reshape(tf.tile(base, [1, height * width]), shape=[-1])
125 |
126 | base_y0 = base + y0.expand(base.size()) * dim2
127 | base_y1 = base + y1.expand(base.size()) * dim2
128 | idx_a = base_y0 + x0.expand(base_y0.size())
129 | idx_b = base_y1 + x0.expand(base_y1.size())
130 | idx_c = base_y0 + x1.expand(base_y0.size())
131 | idx_d = base_y1 + x1.expand(base_y1.size())
132 |
133 | # use indices to lookup pixels in the flat image and restore channels dim
134 | # sh = tf.stack([tf.constant(-1, dtype=tf.int32), channels])
135 |     im_flat = im.contiguous().view(-1)
136 | # im_flat = tf.cast(tf.reshape(im, sh), dtype=tf.float32)
137 |     pixel_a = torch.gather(im_flat, 0, idx_a.long())
138 |     pixel_b = torch.gather(im_flat, 0, idx_b.long())
139 |     pixel_c = torch.gather(im_flat, 0, idx_c.long())
140 |     pixel_d = torch.gather(im_flat, 0, idx_d.long())
141 |
142 | # and finally calculate interpolated values
143 | # x1_f = tf.to_float(x1)
144 | # y1_f = tf.to_float(y1)
145 | x1_f = x1.float()
146 | y1_f = y1.float()
147 |
148 | wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
149 | wb = torch.unsqueeze(((x1_f - x) * (1.0 - (y1_f - y))), 1)
150 | wc = torch.unsqueeze(((1.0 - (x1_f - x)) * (y1_f - y)), 1)
151 | wd = torch.unsqueeze(((1.0 - (x1_f - x)) * (1.0 - (y1_f - y))), 1)
152 |
153 | output = wa * pixel_a.unsqueeze(1) + wb * pixel_b.unsqueeze(1) + wc * pixel_c.unsqueeze(1) + wd * pixel_d.unsqueeze(1)
154 | # output = tf.reshape(output, shape=tf.shape(im))
155 | output = output.view(im.size())
156 | return output, valid_mask
157 |
--------------------------------------------------------------------------------
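`dense_resample` does bilinear resampling by hand: it gathers the four integer-neighbour pixels of every flow coordinate and blends them with the usual bilinear weights. In recent PyTorch versions the same operation can be written with `torch.nn.functional.grid_sample`, provided the absolute pixel coordinates are first normalised to [-1, 1]. The sketch below is a rough equivalent under that assumption (a single 1xCxHxW image and an HxWx2 flow in (x, y) pixel order), not a drop-in replacement for the function above.

```
import torch
import torch.nn.functional as F

def dense_resample_grid_sample(im, flow_im):
    """im: 1xCxHxW tensor, flow_im: HxWx2 absolute (x, y) pixel coordinates."""
    _, _, H, W = im.shape
    # grid_sample expects sampling coordinates normalised to [-1, 1]
    gx = 2.0 * flow_im[..., 0] / (W - 1) - 1.0
    gy = 2.0 * flow_im[..., 1] / (H - 1) - 1.0
    grid = torch.stack([gx, gy], dim=-1).unsqueeze(0)   # 1xHxWx2
    # padding_mode='border' mimics the coordinate clamping in dense_resample
    return F.grid_sample(im, grid, mode='bilinear',
                         padding_mode='border', align_corners=True)
```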