├── README.md
├── data
│   └── list
│       ├── face_train.txt
│       └── face_val.txt
├── enviroment.yaml
├── experiments
│   ├── cityscapes
│   │   ├── ddrnet23.yaml
│   │   ├── ddrnet23_slim.yaml
│   │   └── ddrnet39.yaml
│   └── face
│       └── ddrnet23_slim.yaml
├── images
│   ├── a242.jpg
│   ├── ddrnet.png
│   ├── face.jpeg
│   ├── mobile.jpg
│   └── png.png
├── lib
│   ├── config
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── default.cpython-35.pyc
│   │   │   ├── default.cpython-37.pyc
│   │   │   ├── models.cpython-35.pyc
│   │   │   └── models.cpython-37.pyc
│   │   ├── default.py
│   │   ├── hrnet_config.py
│   │   └── models.py
│   ├── core
│   │   ├── __pycache__
│   │   │   ├── criterion.cpython-37.pyc
│   │   │   ├── function.cpython-35.pyc
│   │   │   └── function.cpython-37.pyc
│   │   ├── criterion.py
│   │   └── function.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── ade20k.cpython-35.pyc
│   │   │   ├── ade20k.cpython-37.pyc
│   │   │   ├── base_dataset.cpython-35.pyc
│   │   │   ├── base_dataset.cpython-37.pyc
│   │   │   ├── cityscapes.cpython-35.pyc
│   │   │   ├── cityscapes.cpython-37.pyc
│   │   │   ├── cocostuff.cpython-35.pyc
│   │   │   ├── cocostuff.cpython-37.pyc
│   │   │   ├── face.cpython-37.pyc
│   │   │   ├── lip.cpython-35.pyc
│   │   │   ├── lip.cpython-37.pyc
│   │   │   ├── map.cpython-35.pyc
│   │   │   ├── map.cpython-37.pyc
│   │   │   ├── parking.cpython-35.pyc
│   │   │   ├── parking.cpython-37.pyc
│   │   │   ├── pascal_ctx.cpython-35.pyc
│   │   │   └── pascal_ctx.cpython-37.pyc
│   │   ├── ade20k.py
│   │   ├── base_dataset.py
│   │   ├── cityscapes.py
│   │   ├── cocostuff.py
│   │   ├── face.py
│   │   ├── lip.py
│   │   ├── map.py
│   │   ├── parking.py
│   │   └── pascal_ctx.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── bn_helper.cpython-35.pyc
│   │   │   ├── bn_helper.cpython-37.pyc
│   │   │   ├── ddrnet_23.cpython-35.pyc
│   │   │   ├── ddrnet_23.cpython-37.pyc
│   │   │   ├── ddrnet_23_slim.cpython-35.pyc
│   │   │   ├── ddrnet_23_slim.cpython-37.pyc
│   │   │   ├── ddrnet_39.cpython-35.pyc
│   │   │   ├── ddrnet_39.cpython-37.pyc
│   │   │   ├── seg_hrnet.cpython-35.pyc
│   │   │   ├── seg_hrnet.cpython-37.pyc
│   │   │   ├── seg_hrnet_ocr.cpython-35.pyc
│   │   │   └── seg_hrnet_ocr.cpython-37.pyc
│   │   ├── bn_helper.py
│   │   ├── ddrlite.py
│   │   ├── ddrnet_23.py
│   │   ├── ddrnet_23_slim.py
│   │   ├── ddrnet_23_slim_noextra.py
│   │   ├── ddrnet_23_slim_quant.py
│   │   ├── ddrnet_39.py
│   │   ├── hrnet.py
│   │   ├── seg_hrnet.py
│   │   ├── seg_hrnet_ocr.py
│   │   └── sync_bn
│   │       ├── LICENSE
│   │       ├── __init__.py
│   │       └── inplace_abn
│   │           ├── __init__.py
│   │           ├── bn.py
│   │           ├── functions.py
│   │           └── src
│   │               ├── common.h
│   │               ├── inplace_abn.cpp
│   │               ├── inplace_abn.h
│   │               ├── inplace_abn_cpu.cpp
│   │               └── inplace_abn_cuda.cu
│   └── utils
│       ├── DenseCRF.py
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-35.pyc
│       │   ├── __init__.cpython-37.pyc
│       │   ├── distributed.cpython-35.pyc
│       │   ├── distributed.cpython-37.pyc
│       │   ├── modelsummary.cpython-35.pyc
│       │   ├── modelsummary.cpython-37.pyc
│       │   ├── utils.cpython-35.pyc
│       │   └── utils.cpython-37.pyc
│       ├── distributed.py
│       ├── modelsummary.py
│       └── utils.py
├── requirements.txt
├── tools
│   ├── __pycache__
│   │   ├── _init_paths.cpython-35.pyc
│   │   └── _init_paths.cpython-37.pyc
│   ├── _init_paths.py
│   ├── convert2jit.py
│   ├── convert2trt.py
│   ├── demo.py
│   ├── demo_img.py
│   ├── demo_img_noaug.py
│   ├── demo_img_orig.py
│   ├── eval.py
│   ├── getwts.py
│   ├── maks.py
│   ├── quantize.py
│   ├── test.py
│   ├── to_onnx.py
│   ├── train.py
│   └── train_single.py
└── train.sh
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes
2 |
3 | ## Introduction
4 | This is an unofficial implementation of [Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes](https://arxiv.org/pdf/2101.06085.pdf). The original code lives in [the official repository](https://github.com/ydhongHIT/DDRNet), and most of this code is borrowed from [DDRNet.Pytorch](https://github.com/chenjun2hao/DDRNet.pytorch); thanks for their work.
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## Main Change
15 |
16 | 1. Input size 512×512;
17 | 2. Changed the DAPPM module;
18 | 3. Data augmentation: random_brightness, random_RotateAndCrop, random_hue, random_saturation, random_contrast, ...;
19 | 4. Trained on the face segmentation dataset [Semantic_Human_Matting](https://github.com/aisegmentcn/matting_human_datasets).
20 |
21 | ## Quick start
22 |
23 | ### 1. Data preparation
24 |
25 | Download the [Semantic_Human_Matting](https://github.com/aisegmentcn/matting_human_datasets) dataset, rename the folder to `face`, and put it under the `data` folder:
26 | ```
27 | └── data
28 | ├── face
29 |     │   ├── train_images
30 |     │   ├── train_labels
31 |     │   ├── val_images
32 |     │   └── val_labels
33 | └── list
34 | ```
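
If you need to (re)build the list files under `data/list`, a minimal sketch is below. It is not part of this repo; it assumes the pairing rule visible in the shipped `data/list/face_train.txt` (every `clip_*/<name>.jpg` under `*_images` has a matching `matting_*/<name>.png` under `*_labels`) and that paths are written relative to `data/face`.

```python
# a minimal sketch for (re)building data/list/face_train.txt and face_val.txt;
# it assumes the pairing rule visible in the provided list files:
#   train_images/<id>/clip_<n>/<name>.jpg  <->  train_labels/<id>/matting_<n>/<name>.png
import glob
import os

def build_list(face_root, split):
    lines = []
    pattern = os.path.join(face_root, split + '_images', '*', '*', '*.jpg')
    for img in sorted(glob.glob(pattern)):
        rel_img = os.path.relpath(img, face_root)
        rel_lbl = (rel_img.replace(split + '_images', split + '_labels')
                          .replace('clip_', 'matting_')
                          .replace('.jpg', '.png'))
        if os.path.exists(os.path.join(face_root, rel_lbl)):
            lines.append(rel_img + ',' + rel_lbl)  # comma-separated, as in face_train.txt
    return lines

if __name__ == '__main__':
    for split in ('train', 'val'):
        with open('data/list/face_%s.txt' % split, 'w') as f:
            f.write('\n'.join(build_list('data/face', split)) + '\n')
```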
35 |
36 | ### 2. Pretrained model
37 |
38 | Download the ImageNet-pretrained model or the segmentation model from the [official repository](https://github.com/ydhongHIT/DDRNet), and put the files in the `${PROJECT}/pretrained_models` folder; make sure `MODEL.PRETRAINED` in your experiment yaml points at the downloaded file.
39 |
40 |
41 | ### 3. TRAIN
42 |
43 | Download [the ImageNet pretrained model](https://github.com/ydhongHIT/DDRNet), then train the model (trained here on 2 NVIDIA RTX 3080 GPUs):
44 |
45 | ```bash
46 | python tools/train_single.py --cfg experiments/face/ddrnet23_slim.yaml
47 | ```
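
Before launching a run it is worth checking that the merged configuration points at the right lists and weights. Below is a minimal sketch using the yacs config defined in `lib/config/default.py`; inserting `lib` into `sys.path` by hand stands in for what `tools/_init_paths.py` does for the training scripts (an assumption about that helper's role).

```python
# a minimal sketch: merge an experiment yaml into the default config and inspect it
import sys
sys.path.insert(0, 'lib')  # make the `config` package (lib/config) importable

from config import config  # the yacs CfgNode _C from lib/config/default.py

config.defrost()
config.merge_from_file('experiments/face/ddrnet23_slim.yaml')
config.freeze()

print(config.MODEL.NAME)         # ddrnet_23_slim
print(config.DATASET.TRAIN_SET)  # list/face_train.txt
print(config.TRAIN.IMAGE_SIZE)   # [512, 512]
print(config.MODEL.PRETRAINED)   # should point at your downloaded pretrained weights
```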
48 |
49 | ## Results
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 | ## Train Custom Data
60 |
61 | The only change needed is to write your own dataset class; use the implementations in `./lib/datasets` as references (a minimal sketch is shown below).
62 |
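The sketch below is patterned after `lib/datasets/ade20k.py` and `lib/datasets/cityscapes.py`. The class/file name `MyData`, the comma-separated list format (as used by the face lists) and the 2-class default are illustrative assumptions; it also assumes `BaseDataset` provides `__len__` over `self.files` and the `gen_sample` pipeline, as the bundled datasets rely on. Remember to register the new class in `lib/datasets/__init__.py` like the existing ones.

```python
# lib/datasets/mydata.py -- a sketch, not part of this repo; names here are illustrative
import os

import cv2
import numpy as np

from .base_dataset import BaseDataset


class MyData(BaseDataset):
    def __init__(self, root, list_path, num_samples=None, num_classes=2,
                 multi_scale=True, flip=True, ignore_label=255,
                 base_size=512, crop_size=(512, 512), downsample_rate=1,
                 scale_factor=16, mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]):
        super(MyData, self).__init__(ignore_label, base_size, crop_size,
                                     downsample_rate, scale_factor, mean, std)
        self.root = root
        self.list_path = list_path
        self.num_classes = num_classes
        self.class_weights = None
        self.multi_scale = multi_scale
        self.flip = flip
        # each line of the list file: <image path>,<label path>, relative to <root>
        self.img_list = [line.strip().split(',') for line in open(root + list_path)]
        self.files = [{'img': img, 'label': label,
                       'name': os.path.splitext(os.path.basename(label))[0]}
                      for img, label in self.img_list]
        if num_samples:
            self.files = self.files[:num_samples]

    def __getitem__(self, index):
        item = self.files[index]
        image = cv2.imread(os.path.join(self.root, item['img']), cv2.IMREAD_COLOR)
        label = cv2.imread(os.path.join(self.root, item['label']), cv2.IMREAD_GRAYSCALE)
        size = image.shape
        # gen_sample (from BaseDataset) applies the resizing, cropping and augmentation
        image, label = self.gen_sample(image, label, self.multi_scale, self.flip)
        return image.copy(), label.copy(), np.array(size), item['name']
```
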
63 | ## Mobile Seg
64 |
65 | Follow [TorchMobile](https://pytorch.org/mobile/home/); tested on a Snapdragon 855+, inference takes about 150 ms per image. A rough export sketch is shown below.
66 |
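For reference, an export in the spirit of `tools/convert2jit.py` (which presumably handles the TorchScript conversion) looks roughly like the sketch below. The `get_seg_model(config)` factory, the checkpoint and output file names are assumptions (check `lib/models/ddrnet_23_slim.py` and the tools scripts for the exact interfaces), and `torch.utils.mobile_optimizer` needs a much newer PyTorch than the 1.1.0 pinned in `enviroment.yaml`.

```python
# a minimal sketch of exporting the face model for PyTorch Mobile; names marked below are assumptions
import sys
import torch
sys.path.insert(0, 'lib')

from config import config
import models.ddrnet_23_slim as ddrnet_23_slim

config.defrost()
config.merge_from_file('experiments/face/ddrnet23_slim.yaml')
config.freeze()

# assumed factory name, following the get_seg_model(cfg) convention of this codebase family
model = ddrnet_23_slim.get_seg_model(config)
model.load_state_dict(torch.load('best.pth', map_location='cpu'), strict=False)  # hypothetical checkpoint
model.eval()

class MainOutput(torch.nn.Module):
    """Keep only the primary segmentation head (NUM_OUTPUTS is 2 during training)."""
    def __init__(self, net):
        super().__init__()
        self.net = net
    def forward(self, x):
        y = self.net(x)
        return y[0] if isinstance(y, (list, tuple)) else y

example = torch.randn(1, 3, 512, 512)  # matches the 512x512 face config
traced = torch.jit.trace(MainOutput(model), example)

from torch.utils.mobile_optimizer import optimize_for_mobile
optimize_for_mobile(traced)._save_for_lite_interpreter('ddrnet23_slim_face.ptl')  # hypothetical name
```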
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 | ## TensorRT
76 |
77 | https://github.com/midasklr/DDRNet.TensorRT
78 |
79 | Tested on an RTX 2070:
80 |
81 | | model | input | FPS |
82 | | -------------- | --------------- | ---- |
83 | | Pytorch-aug | (3,1024,1024) | 107 |
84 | | Pytorch-no-aug | (3,1024,1024) | 108 |
85 | | TensorRT-FP32 | (3,1024,1024) | 117 |
86 | | TensorRT-FP16 | (3,1024,1024) | 215 |
87 | | TensorRT-FP16 | (3,512,512) | 334 |
88 |
89 | Pytorch-aug means augment=True.
90 |
91 | ## Reference
92 |
93 | [1] [DDRNet](https://github.com/chenjun2hao/DDRNet.pytorch)
94 |
95 | [2] [the official repository](https://github.com/ydhongHIT/DDRNet)
96 |
97 |
--------------------------------------------------------------------------------
/data/list/face_train.txt:
--------------------------------------------------------------------------------
1 | train_images/1803240928/clip_00000000/1803240928-00000229.jpg,train_labels/1803240928/matting_00000000/1803240928-00000229.png
2 | train_images/1803240928/clip_00000000/1803240928-00000212.jpg,train_labels/1803240928/matting_00000000/1803240928-00000212.png
3 |
--------------------------------------------------------------------------------
/data/list/face_val.txt:
--------------------------------------------------------------------------------
1 | val_images/1803151818/clip_00000006/1803151818-00006554.jpg,val_labels/1803151818/matting_00000006/1803151818-00006554.png
2 | val_images/1803151818/clip_00000006/1803151818-00006721.jpg,val_labels/1803151818/matting_00000006/1803151818-00006721.png
3 |
--------------------------------------------------------------------------------
/enviroment.yaml:
--------------------------------------------------------------------------------
1 | name: torch110
2 | channels:
3 | - conda-forge/label/cf202003
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - apr=1.6.5=h516909a_2
9 | - boost-cpp=1.70.0=ha2d47e9_1
10 | - bzip2=1.0.8=h516909a_3
11 | - c-ares=1.16.1=h516909a_3
12 | - ca-certificates=2020.11.8=ha878542_0
13 | - catkin_pkg=0.4.23=pyh9f0ad1d_0
14 | - certifi=2020.11.8=py36h5fab9bb_0
15 | - cloudpickle=1.6.0=py_0
16 | - cmake=3.18.4=h1f3970d_0
17 | - console_bridge=1.0.1=hc9558a2_0
18 | - cytoolz=0.10.1=py36h516909a_0
19 | - dask-core=2.25.0=py_0
20 | - dbus=1.13.16=hb2f20db_0
21 | - decorator=4.4.2=py_0
22 | - distro=1.5.0=pyh9f0ad1d_0
23 | - docutils=0.16=py36h9880bd3_2
24 | - empy=3.3.4=pyh9f0ad1d_1
25 | - expat=2.2.9=he6710b0_2
26 | - ffmpeg=4.3.1=h167e202_0
27 | - fontconfig=2.13.0=h9420a91_0
28 | - freetype=2.10.2=h5ab3b9f_0
29 | - future=0.18.2=py36h9f0ad1d_1
30 | - gettext=0.19.8.1=hf34092f_1004
31 | - giflib=5.2.1=h516909a_2
32 | - glib=2.63.1=h5a9c865_0
33 | - gmock=1.10.0=ha770c72_4
34 | - gmp=6.2.0=he1b5a44_2
35 | - gnutls=3.6.13=h79a8f9a_0
36 | - gpgme=1.13.1=he1b5a44_1
37 | - gst-plugins-base=1.14.0=hbbd80ab_1
38 | - gstreamer=1.14.0=hb453b48_1
39 | - gtest=1.10.0=h0efe328_4
40 | - icu=58.2=he6710b0_3
41 | - imageio=2.9.0=py_0
42 | - imgviz=1.2.2=pyh9f0ad1d_0
43 | - joblib=0.16.0=py_0
44 | - jpeg=9b=h024ee3a_2
45 | - krb5=1.17.2=h926e7f8_0
46 | - labelme=4.5.6=py36h9f0ad1d_0
47 | - lame=3.100=h14c3975_1001
48 | - lcms2=2.11=h396b838_0
49 | - libapr=1.6.5=h516909a_2
50 | - libapriconv=1.2.2=h516909a_2
51 | - libaprutil=1.6.1=ha1d75be_2
52 | - libassuan=2.5.3=he1b5a44_1
53 | - libblas=3.8.0=17_openblas
54 | - libcblas=3.8.0=17_openblas
55 | - libcurl=7.71.1=hcdd3856_8
56 | - libedit=3.1.20191231=h14c3975_1
57 | - libev=4.33=h516909a_1
58 | - libffi=3.2.1=hd88cf55_4
59 | - libgcc-ng=9.1.0=hdf63c60_0
60 | - libgfortran-ng=7.5.0=hdf63c60_16
61 | - libgpg-error=1.39=he1b5a44_0
62 | - libiconv=1.16=h516909a_0
63 | - liblapack=3.8.0=17_openblas
64 | - libnghttp2=1.41.0=h8cfc5f6_2
65 | - libopenblas=0.3.10=pthreads_hb3c22a3_4
66 | - libpng=1.6.37=hbc83047_0
67 | - libssh2=1.9.0=hab1572f_5
68 | - libstdcxx-ng=9.1.0=hdf63c60_0
69 | - libtiff=4.1.0=h2733197_1
70 | - libuuid=1.0.3=h1bed415_2
71 | - libuv=1.40.0=hd18ef5c_0
72 | - libwebp=0.5.2=7
73 | - libxcb=1.14=h7b6447c_0
74 | - libxml2=2.9.10=he19cac6_1
75 | - log4cxx=0.11.0=h0856e36_0
76 | - lz4-c=1.9.2=he1b5a44_3
77 | - matplotlib-base=3.3.1=py36h5ffbc53_1
78 | - ncurses=6.2=he6710b0_1
79 | - netifaces=0.10.9=py36h8c4c3a4_1003
80 | - nettle=3.4.1=h1bed415_1002
81 | - networkx=2.5=py_0
82 | - nose=1.3.7=py_1006
83 | - numpy=1.19.1=py36h3849536_2
84 | - olefile=0.46=py_0
85 | - openh264=2.1.1=h8b12597_0
86 | - openssl=1.1.1h=h516909a_0
87 | - pcre=8.44=he6710b0_0
88 | - pillow=7.2.0=py36hb39fc2d_0
89 | - pip=20.2.2=py36_0
90 | - pkg-config=0.29.2=h36c2ea0_1008
91 | - poco=1.10.1=h876a3cc_1
92 | - pycrypto=2.6.1=py36he6145b8_1005
93 | - pyglet=1.5.7=py36h9f0ad1d_0
94 | - pyqt=5.9.2=py36h05f1152_2
95 | - python=3.6.7=h0371630_0
96 | - python-gnupg=0.4.6=pyh9f0ad1d_0
97 | - python_abi=3.6=1_cp36m
98 | - pywavelets=1.1.1=py36h68bb277_2
99 | - qt=5.9.7=h5867ecd_1
100 | - qtpy=1.9.0=py_0
101 | - readline=7.0=h7b6447c_5
102 | - rhash=1.3.6=h516909a_1001
103 | - ros-catkin=0.7.17=py36h831f99a_5
104 | - ros-class-loader=0.4.1=h8b68381_0
105 | - ros-conda-base=0.0.2=hcb32578_2
106 | - ros-conda-mutex=1.0=melodic
107 | - ros-cpp-common=0.6.12=py36he1b5a44_2
108 | - ros-environment=1.2.1=py36h831f99a_2
109 | - ros-gencpp=0.6.2=py36h831f99a_1
110 | - ros-geneus=2.2.6=py36h831f99a_1
111 | - ros-genlisp=0.4.16=py36h831f99a_1
112 | - ros-genmsg=0.5.12=py36h831f99a_1
113 | - ros-gennodejs=2.0.1=py36h831f99a_1
114 | - ros-genpy=0.6.8=py36h831f99a_1
115 | - ros-message-generation=0.4.0=h831f99a_1
116 | - ros-message-runtime=0.4.12=he1b5a44_0
117 | - ros-pluginlib=1.12.1=h8b68381_0
118 | - ros-rosbag=1.14.3=py36h8b68381_0
119 | - ros-rosbag-storage=1.14.3=hbe7f094_0
120 | - ros-rosbuild=1.14.6=he1b5a44_0
121 | - ros-rosconsole=1.13.10=h8b68381_0
122 | - ros-roscpp=1.14.3=py36h8b68381_2
123 | - ros-roscpp-serialization=0.6.12=he1b5a44_0
124 | - ros-roscpp-traits=0.6.12=he1b5a44_0
125 | - ros-rosgraph=1.14.3=py36h831f99a_1
126 | - ros-rosgraph-msgs=1.11.2=py36h831f99a_1
127 | - ros-roslib=1.14.6=py36h77863c7_4
128 | - ros-roslz4=1.14.10.1=py36h1dc43ef_1
129 | - ros-rospack=2.5.3=py36h8b68381_0
130 | - ros-rospy=1.14.3=py36he1b5a44_0
131 | - ros-rostime=0.6.12=h8b68381_0
132 | - ros-std-msgs=0.5.12=py36h831f99a_1
133 | - ros-std-srvs=1.11.2=py36h831f99a_1
134 | - ros-topic-tools=1.14.3=py36h831f99a_1
135 | - ros-xmlrpcpp=1.14.3=he1b5a44_0
136 | - rosdep=0.20.0=py36h5fab9bb_0
137 | - rosdistro=0.8.3=py36h9f0ad1d_1
138 | - rospkg=1.2.9=pyhd3deb0d_0
139 | - scikit-image=0.15.0=py36hb3f55d8_2
140 | - scikit-learn=0.23.2=py36hfb379a7_0
141 | - scipy=1.5.2=py36h3a855aa_0
142 | - setuptools=49.6.0=py36_0
143 | - sip=4.19.8=py36hf484d3e_0
144 | - six=1.15.0=pyh9f0ad1d_0
145 | - sqlite=3.32.3=h62c20be_0
146 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
147 | - tinyxml2=8.0.0=he1b5a44_1
148 | - tk=8.6.10=hbc83047_0
149 | - toolz=0.10.0=py_0
150 | - tornado=6.0.4=py36h8c4c3a4_1
151 | - unixodbc=2.3.9=h0e019cf_0
152 | - wheel=0.34.2=py36_0
153 | - x264=1!152.20180806=h14c3975_0
154 | - xz=5.2.5=h7b6447c_0
155 | - yaml=0.2.5=h516909a_0
156 | - zlib=1.2.11=h7b6447c_3
157 | - zstd=1.4.5=h6597ccf_2
158 | - pip:
159 | - absl-py==0.10.0
160 | - aniso8601==8.0.0
161 | - apex==0.1
162 | - appdirs==1.4.4
163 | - bidict==0.21.0
164 | - cachetools==4.1.1
165 | - chardet==3.0.4
166 | - click==7.1.2
167 | - cycler==0.10.0
168 | - cython==0.29.21
169 | - dataclasses==0.8
170 | - easydict==1.7
171 | - fire==0.3.1
172 | - flask==1.1.2
173 | - flask-restful==0.3.8
174 | - fvcore==0.1.2.post20200912
175 | - google-auth==1.21.3
176 | - google-auth-oauthlib==0.4.1
177 | - grpcio==1.31.0
178 | - h5py==2.10.0
179 | - idna==2.10
180 | - importlib-metadata==2.0.0
181 | - iniconfig==1.1.1
182 | - itsdangerous==1.1.0
183 | - jinja2==2.11.2
184 | - json-tricks==3.15.4
185 | - jsonschema==2.6.0
186 | - kiwisolver==1.2.0
187 | - mako==1.1.3
188 | - markdown==3.2.2
189 | - markupsafe==1.1.1
190 | - matplotlib==3.3.1
191 | - ninja==1.10.0.post2
192 | - oauthlib==3.1.0
193 | - onnx==1.4.1
194 | - onnxruntime==1.5.2
195 | - opencv-python==3.4.1.15
196 | - packaging==20.7
197 | - pandas==1.1.1
198 | - pascal-voc-writer==0.1.4
199 | - plotly==4.5.4
200 | - pluggy==0.13.1
201 | - portalocker==2.0.0
202 | - protobuf==3.13.0
203 | - ptable==0.9.2
204 | - py==1.9.0
205 | - pyasn1==0.4.8
206 | - pyasn1-modules==0.2.8
207 | - pycocotools==2.0
208 | - pycuda==2020.1
209 | - pydensecrf==1.0rc3
210 | - pyparsing==3.0.0a2
211 | - pytest==6.1.2
212 | - python-dateutil==2.8.1
213 | - python-json-logger==0.1.8
214 | - python-speech-features==0.6
215 | - pytools==2020.4.3
216 | - pytz==2020.1
217 | - pyyaml==5.3.1
218 | - requests==2.24.0
219 | - requests-oauthlib==1.3.0
220 | - requests-toolbelt==0.9.1
221 | - retrying==1.3.3
222 | - rsa==4.6
223 | - shapely==1.6.4
224 | - simplejson==3.17.2
225 | - sk-video==1.1.10
226 | - tabulate==0.8.7
227 | - tensorboard==2.3.0
228 | - tensorboard-plugin-wit==1.7.0
229 | - tensorboardx==2.1
230 | - tensorflow==1.0.0
231 | - tensorrt==5.1.5.0
232 | - termcolor==1.1.0
233 | - thop==0.0.31-2005241907
234 | - toml==0.10.2
235 | - torch==1.1.0
236 | - torchcontrib==0.0.2
237 | - torchvision==0.3.0
238 | - tqdm==4.48.2
239 | - typing==3.7.4.3
240 | - typing-extensions==3.7.4.3
241 | - urllib3==1.25.10
242 | - werkzeug==1.0.1
243 | - yacs==0.1.8
244 | - zipp==3.2.0
245 | prefix: /home/data/miniconda3/envs/torch110
246 |
247 |
--------------------------------------------------------------------------------
/experiments/cityscapes/ddrnet23.yaml:
--------------------------------------------------------------------------------
1 | CUDNN:
2 | BENCHMARK: true
3 | DETERMINISTIC: false
4 | ENABLED: true
5 | GPUS: (0,1)
6 | OUTPUT_DIR: 'output'
7 | LOG_DIR: 'log'
8 | WORKERS: 4
9 | PRINT_FREQ: 10
10 |
11 | DATASET:
12 | DATASET: cityscapes
13 | ROOT: data/
14 | TEST_SET: 'list/cityscapes/val.lst'
15 | TRAIN_SET: 'list/cityscapes/train.lst'
16 | NUM_CLASSES: 19
17 | MODEL:
18 | NAME: ddrnet_23
19 | NUM_OUTPUTS: 2
20 | PRETRAINED: "pretrained_models/DDRNet23_imagenet.pth"
21 | ALIGN_CORNERS: false
22 | LOSS:
23 | USE_OHEM: true
24 | OHEMTHRES: 0.9
25 | OHEMKEEP: 131072
26 | BALANCE_WEIGHTS: [1, 0.4]
27 | TRAIN:
28 | IMAGE_SIZE:
29 | - 1024
30 | - 1024
31 | BASE_SIZE: 2048
32 | BATCH_SIZE_PER_GPU: 8
33 | SHUFFLE: true
34 | BEGIN_EPOCH: 0
35 | END_EPOCH: 484
36 | RESUME: false
37 | OPTIMIZER: sgd
38 | LR: 0.01
39 | WD: 0.0005
40 | MOMENTUM: 0.9
41 | NESTEROV: false
42 | FLIP: true
43 | MULTI_SCALE: true
44 | DOWNSAMPLERATE: 1
45 | IGNORE_LABEL: 255
46 | SCALE_FACTOR: 16
47 | TEST:
48 | IMAGE_SIZE:
49 | - 2048
50 | - 1024
51 | BASE_SIZE: 2048
52 | BATCH_SIZE_PER_GPU: 4
53 | FLIP_TEST: true
54 | MULTI_SCALE: false
55 | MODEL_FILE: "pretrained_models/best_val.pth"
56 | OUTPUT_INDEX: 0
57 |
--------------------------------------------------------------------------------
/experiments/cityscapes/ddrnet23_slim.yaml:
--------------------------------------------------------------------------------
1 | CUDNN:
2 | BENCHMARK: true
3 | DETERMINISTIC: false
4 | ENABLED: true
5 | GPUS: 0
6 | OUTPUT_DIR: 'output'
7 | LOG_DIR: 'log'
8 | WORKERS: 4
9 | PRINT_FREQ: 10
10 |
11 | DATASET:
12 | DATASET: cityscapes
13 | ROOT: data/
14 | TEST_SET: 'list/cityscapes/val.lst'
15 | TRAIN_SET: 'list/cityscapes/train.lst'
16 | NUM_CLASSES: 19
17 | MODEL:
18 | NAME: ddrnet_23_slim
19 | NUM_OUTPUTS: 2
20 | PRETRAINED: "/home/hwits/Documents/CarVid/DDRNet/DDRNet.pytorch/best_val_smaller.pth"
21 | ALIGN_CORNERS: false
22 | LOSS:
23 | USE_OHEM: true
24 | OHEMTHRES: 0.9
25 | OHEMKEEP: 131072
26 | BALANCE_WEIGHTS: [1, 0.4]
27 | TRAIN:
28 | IMAGE_SIZE:
29 | - 1024
30 | - 1024
31 | BASE_SIZE: 2048
32 | BATCH_SIZE_PER_GPU: 8
33 | SHUFFLE: true
34 | BEGIN_EPOCH: 0
35 | END_EPOCH: 484
36 | RESUME: false
37 | OPTIMIZER: sgd
38 | LR: 0.01
39 | WD: 0.0005
40 | MOMENTUM: 0.9
41 | NESTEROV: false
42 | FLIP: true
43 | MULTI_SCALE: true
44 | DOWNSAMPLERATE: 1
45 | IGNORE_LABEL: 255
46 | SCALE_FACTOR: 16
47 | TEST:
48 | IMAGE_SIZE:
49 | - 2048
50 | - 1024
51 | BASE_SIZE: 2048
52 | BATCH_SIZE_PER_GPU: 4
53 | FLIP_TEST: false
54 | MULTI_SCALE: false
55 | SCALE_LIST: [1]
56 | #0.5,0.75,1.0,1.25,1.5,1.75
57 | MODEL_FILE: "/home/hwits/Documents/CarVid/DDRNet/DDRNet.pytorch/best_val_smaller.pth"
58 | OUTPUT_INDEX: 0
59 |
--------------------------------------------------------------------------------
/experiments/cityscapes/ddrnet39.yaml:
--------------------------------------------------------------------------------
1 | CUDNN:
2 | BENCHMARK: true
3 | DETERMINISTIC: false
4 | ENABLED: true
5 | GPUS: (0,1)
6 | OUTPUT_DIR: 'output'
7 | LOG_DIR: 'log'
8 | WORKERS: 4
9 | PRINT_FREQ: 10
10 |
11 | DATASET:
12 | DATASET: cityscapes
13 | ROOT: data/
14 | TEST_SET: 'list/cityscapes/val.lst'
15 | TRAIN_SET: 'list/cityscapes/train.lst'
16 | NUM_CLASSES: 19
17 | MODEL:
18 | NAME: ddrnet_39
19 | NUM_OUTPUTS: 2
20 | PRETRAINED: "pretrained_models/DDRNet39_imagenet.pth"
21 | LOSS:
22 | USE_OHEM: true
23 | OHEMTHRES: 0.9
24 | OHEMKEEP: 131072
25 | BALANCE_WEIGHTS: [1, 0.4]
26 | TRAIN:
27 | IMAGE_SIZE:
28 | - 1024
29 | - 1024
30 | BASE_SIZE: 2048
31 | BATCH_SIZE_PER_GPU: 8
32 | SHUFFLE: true
33 | BEGIN_EPOCH: 0
34 | END_EPOCH: 484
35 | RESUME: false
36 | OPTIMIZER: sgd
37 | LR: 0.01
38 | WD: 0.0005
39 | MOMENTUM: 0.9
40 | NESTEROV: false
41 | FLIP: true
42 | MULTI_SCALE: true
43 | DOWNSAMPLERATE: 1
44 | IGNORE_LABEL: 255
45 | SCALE_FACTOR: 16
46 | TEST:
47 | IMAGE_SIZE:
48 | - 2048
49 | - 1024
50 | BASE_SIZE: 2048
51 | BATCH_SIZE_PER_GPU: 4
52 | FLIP_TEST: true
53 | MULTI_SCALE: false
54 | MODEL_FILE:
55 | OUTPUT_INDEX: 0
56 |
--------------------------------------------------------------------------------
/experiments/face/ddrnet23_slim.yaml:
--------------------------------------------------------------------------------
1 | CUDNN:
2 | BENCHMARK: true
3 | DETERMINISTIC: false
4 | ENABLED: true
5 | GPUS: [0]
6 | OUTPUT_DIR: 'output'
7 | LOG_DIR: 'log'
8 | WORKERS: 4
9 | PRINT_FREQ: 10
10 |
11 | DATASET:
12 | DATASET: face
13 | ROOT: data/
14 | TEST_SET: 'list/face_val.txt'
15 | TRAIN_SET: 'list/face_train.txt'
16 | NUM_CLASSES: 2
17 | MODEL:
18 | NAME: ddrnet_23_slim
19 | NUM_OUTPUTS: 2
20 | PRETRAINED: "/home/kong/Documents/DDRNet.Pytorch/DDRNet.Pytorch/pretrained_models/best_val_smaller.pth"
21 | ALIGN_CORNERS: false
22 | LOSS:
23 | USE_OHEM: true
24 | OHEMTHRES: 0.9
25 | OHEMKEEP: 131072
26 | BALANCE_WEIGHTS: [1, 0.4]
27 | TRAIN:
28 | IMAGE_SIZE:
29 | - 512
30 | - 512
31 | BASE_SIZE: 512
32 | BATCH_SIZE_PER_GPU: 32
33 | SHUFFLE: true
34 | BEGIN_EPOCH: 0
35 | END_EPOCH: 50
36 | RANDOM_BRIGHTNESS: true
37 | RESUME: false
38 | OPTIMIZER: sgd
39 | LR: 0.001
40 | WD: 0.0005
41 | MOMENTUM: 0.9
42 | NESTEROV: false
43 | FLIP: true
44 | MULTI_SCALE: true
45 | DOWNSAMPLERATE: 1
46 | IGNORE_LABEL: 255
47 | SCALE_FACTOR: 15
48 | RANDOM_ROTATE: true
49 | CONTRAST: true
50 | SATURATION: true
51 | HUE: true
52 | TEST:
53 | IMAGE_SIZE:
54 | - 512
55 | - 512
56 | BASE_SIZE: 512
57 | BATCH_SIZE_PER_GPU: 4
58 | FLIP_TEST: false
59 | MULTI_SCALE: false
60 | SCALE_LIST: [1]
61 | #0.5,0.75,1.0,1.25,1.5,1.75
62 | MODEL_FILE: "/home/hwits/Documents/CarVid/DDRNet/DDRNet.pytorch/best_val_smaller.pth"
63 | OUTPUT_INDEX: 0
64 |
--------------------------------------------------------------------------------
/images/a242.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/images/a242.jpg
--------------------------------------------------------------------------------
/images/ddrnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/images/ddrnet.png
--------------------------------------------------------------------------------
/images/face.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/images/face.jpeg
--------------------------------------------------------------------------------
/images/mobile.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/images/mobile.jpg
--------------------------------------------------------------------------------
/images/png.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/images/png.png
--------------------------------------------------------------------------------
/lib/config/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | from .default import _C as config
11 | from .default import update_config
12 | from .models import MODEL_EXTRAS
13 |
--------------------------------------------------------------------------------
/lib/config/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/config/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/config/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/default.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/config/__pycache__/default.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/default.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/config/__pycache__/default.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/models.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/config/__pycache__/models.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/config/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/config/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/config/default.py:
--------------------------------------------------------------------------------
1 |
2 | # ------------------------------------------------------------------------------
3 | # Copyright (c) Microsoft
4 | # Licensed under the MIT License.
5 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import os
13 |
14 | from yacs.config import CfgNode as CN
15 |
16 |
17 | _C = CN()
18 |
19 | _C.OUTPUT_DIR = ''
20 | _C.LOG_DIR = ''
21 | _C.GPUS = [0]
22 | _C.WORKERS = 4
23 | _C.PRINT_FREQ = 20
24 | _C.AUTO_RESUME = False
25 | _C.PIN_MEMORY = True
26 | _C.RANK = 0
27 |
28 | # Cudnn related params
29 | _C.CUDNN = CN()
30 | _C.CUDNN.BENCHMARK = True
31 | _C.CUDNN.DETERMINISTIC = False
32 | _C.CUDNN.ENABLED = True
33 |
34 | # common params for NETWORK
35 | _C.MODEL = CN()
36 | _C.MODEL.NAME = 'seg_hrnet'
37 | _C.MODEL.PRETRAINED = ''
38 | _C.MODEL.ALIGN_CORNERS = True
39 | _C.MODEL.NUM_OUTPUTS = 2
40 | _C.MODEL.EXTRA = CN(new_allowed=True)
41 |
42 |
43 | _C.MODEL.OCR = CN()
44 | _C.MODEL.OCR.MID_CHANNELS = 512
45 | _C.MODEL.OCR.KEY_CHANNELS = 256
46 | _C.MODEL.OCR.DROPOUT = 0.05
47 | _C.MODEL.OCR.SCALE = 1
48 |
49 | _C.LOSS = CN()
50 | _C.LOSS.USE_OHEM = False
51 | _C.LOSS.OHEMTHRES = 0.9
52 | _C.LOSS.OHEMKEEP = 100000
53 | _C.LOSS.CLASS_BALANCE = False
54 | _C.LOSS.BALANCE_WEIGHTS = [0.5, 0.5]
55 |
56 | # DATASET related params
57 | _C.DATASET = CN()
58 | _C.DATASET.MODEL = 'train'
59 | _C.DATASET.ROOT = ''
60 | _C.DATASET.DATASET = 'cityscapes'
61 | _C.DATASET.NUM_CLASSES = 19
62 | _C.DATASET.TRAIN_SET = 'list/cityscapes/train.lst'
63 | _C.DATASET.EXTRA_TRAIN_SET = ''
64 | _C.DATASET.TEST_SET = 'list/cityscapes/val.lst'
65 |
66 | # training
67 | _C.TRAIN = CN()
68 |
69 | _C.TRAIN.FREEZE_LAYERS = ''
70 | _C.TRAIN.FREEZE_EPOCHS = -1
71 | _C.TRAIN.NONBACKBONE_KEYWORDS = []
72 | _C.TRAIN.NONBACKBONE_MULT = 10
73 |
74 | _C.TRAIN.IMAGE_SIZE = [1024, 512] # width * height
75 | _C.TRAIN.BASE_SIZE = 2048
76 | _C.TRAIN.DOWNSAMPLERATE = 1
77 | _C.TRAIN.FLIP = True
78 | _C.TRAIN.MULTI_SCALE = True
79 | _C.TRAIN.SCALE_FACTOR = 16
80 |
81 | _C.TRAIN.RANDOM_BRIGHTNESS = True
82 | _C.TRAIN.RANDOM_BRIGHTNESS_SHIFT_VALUE = 20
83 |
84 | _C.TRAIN.LR_FACTOR = 0.1
85 | _C.TRAIN.LR_STEP = [60, 80]
86 | # _C.TRAIN.LR_STEP = [90, 110]
87 | _C.TRAIN.LR = 0.01
88 | _C.TRAIN.EXTRA_LR = 0.001
89 |
90 | _C.TRAIN.OPTIMIZER = 'sgd'
91 | _C.TRAIN.MOMENTUM = 0.9
92 | _C.TRAIN.WD = 0.0001
93 | _C.TRAIN.NESTEROV = False
94 | _C.TRAIN.IGNORE_LABEL = -1
95 |
96 | _C.TRAIN.BEGIN_EPOCH = 0
97 | _C.TRAIN.END_EPOCH = 484
98 | _C.TRAIN.EXTRA_EPOCH = 0
99 |
100 | _C.TRAIN.RESUME = False
101 |
102 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32
103 | _C.TRAIN.SHUFFLE = True
104 | _C.TRAIN.HUE = True
105 | _C.TRAIN.RANDOM_HUE_VALUE = (-0.01, 0.01)
106 | _C.TRAIN.SATURATION = True
107 | _C.TRAIN.RANDOM_SAT_VALUE = (0.95, 1.05)
108 | _C.TRAIN.CONTRAST = True
109 | _C.TRAIN.RANDOM_CONTRAST_VALUE = (0.9, 1.1)
110 | _C.TRAIN.RANDOM_ROTATE = True
111 | _C.TRAIN.RANDOM_ROTATE_ANGLE = 20
112 | # only using some training samples
113 | _C.TRAIN.NUM_SAMPLES = 0
114 |
115 | # testing
116 | _C.TEST = CN()
117 |
118 | _C.TEST.IMAGE_SIZE = [2048, 1024] # width * height
119 | _C.TEST.BASE_SIZE = 2048
120 |
121 | _C.TEST.BATCH_SIZE_PER_GPU = 32
122 | # only testing some samples
123 | _C.TEST.NUM_SAMPLES = 0
124 |
125 | _C.TEST.MODEL_FILE = ''
126 | _C.TEST.FLIP_TEST = False
127 | _C.TEST.MULTI_SCALE = False
128 | _C.TEST.SCALE_LIST = [1]
129 |
130 | _C.TEST.OUTPUT_INDEX = -1
131 |
132 | # debug
133 | _C.DEBUG = CN()
134 | _C.DEBUG.DEBUG = False
135 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False
136 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
137 | _C.DEBUG.SAVE_HEATMAPS_GT = False
138 | _C.DEBUG.SAVE_HEATMAPS_PRED = False
139 |
140 |
141 | def update_config(cfg, args):
142 | cfg.defrost()
143 |
144 | cfg.merge_from_file(args.cfg)
145 | cfg.merge_from_list(args.opts)
146 |
147 | cfg.freeze()
148 |
149 |
150 | if __name__ == '__main__':
151 | import sys
152 | with open(sys.argv[1], 'w') as f:
153 | print(_C, file=f)
154 |
155 |
--------------------------------------------------------------------------------
/lib/config/hrnet_config.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Create by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn), Rainbowsecret (yuyua@microsoft.com)
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | from yacs.config import CfgNode as CN
13 |
14 |
15 | # configs for HRNet48
16 | HRNET_48 = CN()
17 | HRNET_48.FINAL_CONV_KERNEL = 1
18 |
19 | HRNET_48.STAGE1 = CN()
20 | HRNET_48.STAGE1.NUM_MODULES = 1
21 | HRNET_48.STAGE1.NUM_BRANCHES = 1
22 | HRNET_48.STAGE1.NUM_BLOCKS = [4]
23 | HRNET_48.STAGE1.NUM_CHANNELS = [64]
24 | HRNET_48.STAGE1.BLOCK = 'BOTTLENECK'
25 | HRNET_48.STAGE1.FUSE_METHOD = 'SUM'
26 |
27 | HRNET_48.STAGE2 = CN()
28 | HRNET_48.STAGE2.NUM_MODULES = 1
29 | HRNET_48.STAGE2.NUM_BRANCHES = 2
30 | HRNET_48.STAGE2.NUM_BLOCKS = [4, 4]
31 | HRNET_48.STAGE2.NUM_CHANNELS = [48, 96]
32 | HRNET_48.STAGE2.BLOCK = 'BASIC'
33 | HRNET_48.STAGE2.FUSE_METHOD = 'SUM'
34 |
35 | HRNET_48.STAGE3 = CN()
36 | HRNET_48.STAGE3.NUM_MODULES = 4
37 | HRNET_48.STAGE3.NUM_BRANCHES = 3
38 | HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4]
39 | HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192]
40 | HRNET_48.STAGE3.BLOCK = 'BASIC'
41 | HRNET_48.STAGE3.FUSE_METHOD = 'SUM'
42 |
43 | HRNET_48.STAGE4 = CN()
44 | HRNET_48.STAGE4.NUM_MODULES = 3
45 | HRNET_48.STAGE4.NUM_BRANCHES = 4
46 | HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
47 | HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384]
48 | HRNET_48.STAGE4.BLOCK = 'BASIC'
49 | HRNET_48.STAGE4.FUSE_METHOD = 'SUM'
50 |
51 |
52 | # configs for HRNet32
53 | HRNET_32 = CN()
54 | HRNET_32.FINAL_CONV_KERNEL = 1
55 |
56 | HRNET_32.STAGE1 = CN()
57 | HRNET_32.STAGE1.NUM_MODULES = 1
58 | HRNET_32.STAGE1.NUM_BRANCHES = 1
59 | HRNET_32.STAGE1.NUM_BLOCKS = [4]
60 | HRNET_32.STAGE1.NUM_CHANNELS = [64]
61 | HRNET_32.STAGE1.BLOCK = 'BOTTLENECK'
62 | HRNET_32.STAGE1.FUSE_METHOD = 'SUM'
63 |
64 | HRNET_32.STAGE2 = CN()
65 | HRNET_32.STAGE2.NUM_MODULES = 1
66 | HRNET_32.STAGE2.NUM_BRANCHES = 2
67 | HRNET_32.STAGE2.NUM_BLOCKS = [4, 4]
68 | HRNET_32.STAGE2.NUM_CHANNELS = [32, 64]
69 | HRNET_32.STAGE2.BLOCK = 'BASIC'
70 | HRNET_32.STAGE2.FUSE_METHOD = 'SUM'
71 |
72 | HRNET_32.STAGE3 = CN()
73 | HRNET_32.STAGE3.NUM_MODULES = 4
74 | HRNET_32.STAGE3.NUM_BRANCHES = 3
75 | HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4]
76 | HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128]
77 | HRNET_32.STAGE3.BLOCK = 'BASIC'
78 | HRNET_32.STAGE3.FUSE_METHOD = 'SUM'
79 |
80 | HRNET_32.STAGE4 = CN()
81 | HRNET_32.STAGE4.NUM_MODULES = 3
82 | HRNET_32.STAGE4.NUM_BRANCHES = 4
83 | HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
84 | HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
85 | HRNET_32.STAGE4.BLOCK = 'BASIC'
86 | HRNET_32.STAGE4.FUSE_METHOD = 'SUM'
87 |
88 |
89 | # configs for HRNet18
90 | HRNET_18 = CN()
91 | HRNET_18.FINAL_CONV_KERNEL = 1
92 |
93 | HRNET_18.STAGE1 = CN()
94 | HRNET_18.STAGE1.NUM_MODULES = 1
95 | HRNET_18.STAGE1.NUM_BRANCHES = 1
96 | HRNET_18.STAGE1.NUM_BLOCKS = [4]
97 | HRNET_18.STAGE1.NUM_CHANNELS = [64]
98 | HRNET_18.STAGE1.BLOCK = 'BOTTLENECK'
99 | HRNET_18.STAGE1.FUSE_METHOD = 'SUM'
100 |
101 | HRNET_18.STAGE2 = CN()
102 | HRNET_18.STAGE2.NUM_MODULES = 1
103 | HRNET_18.STAGE2.NUM_BRANCHES = 2
104 | HRNET_18.STAGE2.NUM_BLOCKS = [4, 4]
105 | HRNET_18.STAGE2.NUM_CHANNELS = [18, 36]
106 | HRNET_18.STAGE2.BLOCK = 'BASIC'
107 | HRNET_18.STAGE2.FUSE_METHOD = 'SUM'
108 |
109 | HRNET_18.STAGE3 = CN()
110 | HRNET_18.STAGE3.NUM_MODULES = 4
111 | HRNET_18.STAGE3.NUM_BRANCHES = 3
112 | HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4]
113 | HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72]
114 | HRNET_18.STAGE3.BLOCK = 'BASIC'
115 | HRNET_18.STAGE3.FUSE_METHOD = 'SUM'
116 |
117 | HRNET_18.STAGE4 = CN()
118 | HRNET_18.STAGE4.NUM_MODULES = 3
119 | HRNET_18.STAGE4.NUM_BRANCHES = 4
120 | HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
121 | HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144]
122 | HRNET_18.STAGE4.BLOCK = 'BASIC'
123 | HRNET_18.STAGE4.FUSE_METHOD = 'SUM'
124 |
125 |
126 | MODEL_CONFIGS = {
127 | 'hrnet18': HRNET_18,
128 | 'hrnet32': HRNET_32,
129 | 'hrnet48': HRNET_48,
130 | }
--------------------------------------------------------------------------------
/lib/config/models.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | from yacs.config import CfgNode as CN
12 |
13 | # high_resoluton_net related params for segmentation
14 | HIGH_RESOLUTION_NET = CN()
15 | HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
16 | HIGH_RESOLUTION_NET.STEM_INPLANES = 64
17 | HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
18 | HIGH_RESOLUTION_NET.WITH_HEAD = True
19 |
20 | HIGH_RESOLUTION_NET.STAGE2 = CN()
21 | HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
22 | HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
23 | HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
24 | HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
25 | HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
26 | HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'
27 |
28 | HIGH_RESOLUTION_NET.STAGE3 = CN()
29 | HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
30 | HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
31 | HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
32 | HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
33 | HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
34 | HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'
35 |
36 | HIGH_RESOLUTION_NET.STAGE4 = CN()
37 | HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
38 | HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
39 | HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
40 | HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
41 | HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
42 | HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'
43 |
44 | MODEL_EXTRAS = {
45 | 'seg_hrnet': HIGH_RESOLUTION_NET,
46 | }
47 |
--------------------------------------------------------------------------------
/lib/core/__pycache__/criterion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/core/__pycache__/criterion.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/core/__pycache__/function.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/core/__pycache__/function.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/core/__pycache__/function.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/core/__pycache__/function.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/core/criterion.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 | import logging
11 | from config import config
12 |
13 |
14 | class CrossEntropy(nn.Module):
15 | def __init__(self, ignore_label=-1, weight=None):
16 | super(CrossEntropy, self).__init__()
17 | self.ignore_label = ignore_label
18 | self.criterion = nn.CrossEntropyLoss(
19 | weight=weight,
20 | ignore_index=ignore_label
21 | )
22 |
23 | def _forward(self, score, target):
24 | # print("score:",score.size(),"target:",target.size())
25 | ph, pw = score.size(2), score.size(3)
26 | h, w = target.size(1), target.size(2)
27 | if ph != h or pw != w:
28 | score = F.interpolate(input=score, size=(
29 | h, w), mode='nearest')
30 | # print("score:",score.size())
31 |
32 | loss = self.criterion(score, target)
33 |
34 | return loss
35 |
36 | def forward(self, score, target):
37 | # print("score size:...",len(score))
38 |
39 | if config.MODEL.NUM_OUTPUTS == 1:
40 | score = [score]
41 |
42 | weights = config.LOSS.BALANCE_WEIGHTS
43 | assert len(weights) == len(score)
44 |
45 | return sum([w * self._forward(x, target) for (w, x) in zip(weights, score)])
46 |
47 |
48 | class OhemCrossEntropy(nn.Module):
49 | def __init__(self, ignore_label=-1, thres=0.7,
50 | min_kept=100000, weight=None):
51 | super(OhemCrossEntropy, self).__init__()
52 | self.thresh = thres
53 | self.min_kept = max(1, min_kept)
54 | self.ignore_label = ignore_label
55 | self.criterion = nn.CrossEntropyLoss(
56 | weight=weight,
57 | ignore_index=ignore_label,
58 | reduction='none'
59 | )
60 |
61 | def _ce_forward(self, score, target):
62 | ph, pw = score.size(2), score.size(3)
63 | h, w = target.size(1), target.size(2)
64 | if ph != h or pw != w:
65 | score = F.interpolate(input=score, size=(
66 | h, w), mode='nearest')
67 |
68 | loss = self.criterion(score, target)
69 |
70 | return loss
71 |
72 | def _ohem_forward(self, score, target, **kwargs):
73 | ph, pw = score.size(2), score.size(3)
74 | h, w = target.size(1), target.size(2)
75 | if ph != h or pw != w:
76 | score = F.interpolate(input=score, size=(
77 | h, w), mode='nearest')
78 | pred = F.softmax(score, dim=1)
79 | pixel_losses = self.criterion(score, target).contiguous().view(-1)
80 | mask = target.contiguous().view(-1) != self.ignore_label
81 |
82 | tmp_target = target.clone()
83 | tmp_target[tmp_target == self.ignore_label] = 0
84 | pred = pred.gather(1, tmp_target.unsqueeze(1))
85 | pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort()
86 | min_value = pred[min(self.min_kept, pred.numel() - 1)]
87 | threshold = max(min_value, self.thresh)
88 |
89 | pixel_losses = pixel_losses[mask][ind]
90 | pixel_losses = pixel_losses[pred < threshold]
91 | return pixel_losses.mean()
92 |
93 | def forward(self, score, target):
94 |
95 | if config.MODEL.NUM_OUTPUTS == 1:
96 | score = [score]
97 |
98 | weights = config.LOSS.BALANCE_WEIGHTS
99 | assert len(weights) == len(score)
100 |
101 | functions = [self._ce_forward] * \
102 | (len(weights) - 1) + [self._ohem_forward]
103 | # print("loss weight : ",weights, len(score), functions)
104 | return sum([
105 | w * func(x, target)
106 | for (w, x, func) in zip(weights, score, functions)
107 | ])
108 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | from .cityscapes import Cityscapes as cityscapes
12 | from .parking import Parking as parking
13 | from .face import Face as face
14 | from .lip import LIP as lip
15 | from .pascal_ctx import PASCALContext as pascal_ctx
16 | from .ade20k import ADE20K as ade20k
17 | from .map import MAP as map
18 | from .cocostuff import COCOStuff as cocostuff
19 |
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/ade20k.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/ade20k.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/ade20k.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/ade20k.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/base_dataset.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/base_dataset.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/base_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/base_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/cityscapes.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/cityscapes.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/cityscapes.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/cityscapes.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/cocostuff.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/cocostuff.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/cocostuff.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/cocostuff.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/face.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/face.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/lip.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/lip.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/lip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/lip.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/map.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/map.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/map.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/map.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/parking.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/parking.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/parking.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/parking.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/pascal_ctx.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/pascal_ctx.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/pascal_ctx.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/datasets/__pycache__/pascal_ctx.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/ade20k.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 |
12 | import torch
13 | from torch.nn import functional as F
14 | from PIL import Image
15 |
16 | from .base_dataset import BaseDataset
17 |
18 |
19 | class ADE20K(BaseDataset):
20 | def __init__(self,
21 | root,
22 | list_path,
23 | num_samples=None,
24 | num_classes=150,
25 | multi_scale=True,
26 | flip=True,
27 | ignore_label=-1,
28 | base_size=520,
29 | crop_size=(520, 520),
30 | downsample_rate=1,
31 | scale_factor=11,
32 | mean=[0.485, 0.456, 0.406],
33 | std=[0.229, 0.224, 0.225]):
34 |
35 | super(ADE20K, self).__init__(ignore_label, base_size,
36 | crop_size, downsample_rate, scale_factor, mean, std)
37 |
38 | self.root = root
39 | self.num_classes = num_classes
40 | self.list_path = list_path
41 | self.class_weights = None
42 |
43 | self.multi_scale = multi_scale
44 | self.flip = flip
45 | self.img_list = [line.strip().split() for line in open(root+list_path)]
46 |
47 | self.files = self.read_files()
48 | if num_samples:
49 | self.files = self.files[:num_samples]
50 |
51 | def read_files(self):
52 | files = []
53 | for item in self.img_list:
54 | image_path, label_path = item
55 | name = os.path.splitext(os.path.basename(label_path))[0]
56 | sample = {
57 | 'img': image_path,
58 | 'label': label_path,
59 | 'name': name
60 | }
61 | files.append(sample)
62 | return files
63 |
64 | def resize_image(self, image, label, size):
65 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
66 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
67 | return image, label
68 |
69 | def __getitem__(self, index):
70 | item = self.files[index]
71 | name = item["name"]
72 | image_path = os.path.join(self.root, 'ade20k', item['img'])
73 | label_path = os.path.join(self.root, 'ade20k', item['label'])
74 | image = cv2.imread(
75 | image_path,
76 | cv2.IMREAD_COLOR
77 | )
78 | label = np.array(
79 | Image.open(label_path).convert('P')
80 | )
81 | label = self.reduce_zero_label(label)
82 | size = label.shape
83 |
84 | if 'testval' in self.list_path:
85 | image = self.resize_short_length(
86 | image,
87 | short_length=self.base_size,
88 | fit_stride=8
89 | )
90 | image = self.input_transform(image)
91 | image = image.transpose((2, 0, 1))
92 |
93 | return image.copy(), label.copy(), np.array(size), name
94 |
95 | if 'val' in self.list_path:
96 | image, label = self.resize_short_length(
97 | image,
98 | label=label,
99 | short_length=self.base_size,
100 | fit_stride=8
101 | )
102 | image, label = self.rand_crop(image, label)
103 | image = self.input_transform(image)
104 | image = image.transpose((2, 0, 1))
105 |
106 | return image.copy(), label.copy(), np.array(size), name
107 |
108 | image, label = self.resize_short_length(image, label, short_length=self.base_size)
109 | image, label = self.gen_sample(image, label, self.multi_scale, self.flip)
110 |
111 | return image.copy(), label.copy(), np.array(size), name
--------------------------------------------------------------------------------
/lib/datasets/cityscapes.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 | from PIL import Image
12 |
13 | import torch
14 | from torch.nn import functional as F
15 |
16 | from .base_dataset import BaseDataset
17 |
18 | class Cityscapes(BaseDataset):
19 | def __init__(self,
20 | root,
21 | list_path,
22 | num_samples=None,
23 | num_classes=19,
24 | multi_scale=True,
25 | flip=True,
26 | ignore_label=-1,
27 | base_size=2048,
28 | crop_size=(512, 1024),
29 | downsample_rate=1,
30 | scale_factor=16,
31 | mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225]):
33 |
34 | super(Cityscapes, self).__init__(ignore_label, base_size,
35 | crop_size, downsample_rate, scale_factor, mean, std,)
36 |
37 | self.root = root
38 | self.list_path = list_path
39 | self.num_classes = num_classes
40 |
41 | self.multi_scale = multi_scale
42 | self.flip = flip
43 |
44 | self.img_list = [line.strip().split() for line in open(root+list_path)]
45 |
46 | self.files = self.read_files()
47 | if num_samples:
48 | self.files = self.files[:num_samples]
49 |
50 | self.label_mapping = {-1: ignore_label, 0: ignore_label,
51 | 1: ignore_label, 2: ignore_label,
52 | 3: ignore_label, 4: ignore_label,
53 | 5: ignore_label, 6: ignore_label,
54 | 7: 0, 8: 1, 9: ignore_label,
55 | 10: ignore_label, 11: 2, 12: 3,
56 | 13: 4, 14: ignore_label, 15: ignore_label,
57 | 16: ignore_label, 17: 5, 18: ignore_label,
58 | 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
59 | 25: 12, 26: 13, 27: 14, 28: 15,
60 | 29: ignore_label, 30: ignore_label,
61 | 31: 16, 32: 17, 33: 18}
62 | self.class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345,
63 | 1.0166, 0.9969, 0.9754, 1.0489,
64 | 0.8786, 1.0023, 0.9539, 0.9843,
65 | 1.1116, 0.9037, 1.0865, 1.0955,
66 | 1.0865, 1.1529, 1.0507]).cuda()
67 |
68 | def read_files(self):
69 | files = []
70 | if 'test' in self.list_path:
71 | for item in self.img_list:
72 | image_path = item
73 | name = os.path.splitext(os.path.basename(image_path[0]))[0]
74 | files.append({
75 | "img": image_path[0],
76 | "name": name,
77 | })
78 | else:
79 | for item in self.img_list:
80 | image_path, label_path = item
81 | name = os.path.splitext(os.path.basename(label_path))[0]
82 | files.append({
83 | "img": image_path,
84 | "label": label_path,
85 | "name": name,
86 | "weight": 1
87 | })
88 | return files
89 |
90 | def convert_label(self, label, inverse=False):
91 | temp = label.copy()
92 | if inverse:
93 | for v, k in self.label_mapping.items():
94 | label[temp == k] = v
95 | else:
96 | for k, v in self.label_mapping.items():
97 | label[temp == k] = v
98 | return label
99 |
100 | def __getitem__(self, index):
101 | item = self.files[index]
102 | name = item["name"]
103 | image = cv2.imread(os.path.join(self.root,'cityscapes',item["img"]),
104 | cv2.IMREAD_COLOR)
105 | size = image.shape
106 |
107 | if 'test' in self.list_path:
108 | image = self.input_transform(image)
109 | image = image.transpose((2, 0, 1))
110 |
111 | return image.copy(), np.array(size), name
112 |
113 | label = cv2.imread(os.path.join(self.root,'cityscapes',item["label"]),
114 | cv2.IMREAD_GRAYSCALE)
115 | label = self.convert_label(label)
116 |
117 | image, label = self.gen_sample(image, label,
118 | self.multi_scale, self.flip)
119 |
120 | return image.copy(), label.copy(), np.array(size), name
121 |
122 | def multi_scale_inference(self, config, model, image, scales=[1], flip=False):
123 | batch, _, ori_height, ori_width = image.size()
124 | assert batch == 1, "only supporting batchsize 1."
125 | image = image.numpy()[0].transpose((1,2,0)).copy()
126 | stride_h = np.int(self.crop_size[0] * 1.0)
127 | stride_w = np.int(self.crop_size[1] * 1.0)
128 | final_pred = torch.zeros([1, self.num_classes,
129 | ori_height,ori_width]).cuda()
130 | for scale in scales:
131 | new_img = self.multi_scale_aug(image=image,
132 | rand_scale=scale,
133 | rand_crop=False)
134 | height, width = new_img.shape[:-1]
135 |
136 | if scale <= 1.0:
137 | new_img = new_img.transpose((2, 0, 1))
138 | new_img = np.expand_dims(new_img, axis=0)
139 | new_img = torch.from_numpy(new_img)
140 | preds = self.inference(config, model, new_img, flip)
141 | preds = preds[:, :, 0:height, 0:width]
142 | else:
143 | new_h, new_w = new_img.shape[:-1]
144 | rows = np.int(np.ceil(1.0 * (new_h -
145 | self.crop_size[0]) / stride_h)) + 1
146 | cols = np.int(np.ceil(1.0 * (new_w -
147 | self.crop_size[1]) / stride_w)) + 1
148 | preds = torch.zeros([1, self.num_classes,
149 | new_h,new_w]).cuda()
150 | count = torch.zeros([1,1, new_h, new_w]).cuda()
151 |
152 | for r in range(rows):
153 | for c in range(cols):
154 | h0 = r * stride_h
155 | w0 = c * stride_w
156 | h1 = min(h0 + self.crop_size[0], new_h)
157 | w1 = min(w0 + self.crop_size[1], new_w)
158 | h0 = max(int(h1 - self.crop_size[0]), 0)
159 | w0 = max(int(w1 - self.crop_size[1]), 0)
160 | crop_img = new_img[h0:h1, w0:w1, :]
161 | crop_img = crop_img.transpose((2, 0, 1))
162 | crop_img = np.expand_dims(crop_img, axis=0)
163 | crop_img = torch.from_numpy(crop_img)
164 | pred = self.inference(config, model, crop_img, flip)
165 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
166 | count[:,:,h0:h1,w0:w1] += 1
167 | preds = preds / count
168 | preds = preds[:,:,:height,:width]
169 |
170 | preds = F.interpolate(
171 | preds, (ori_height, ori_width),
172 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
173 | )
174 | final_pred += preds
175 | return final_pred
176 |
177 | def get_palette(self, n):
178 | palette = [0] * (n * 3)
179 | for j in range(0, n):
180 | lab = j
181 | palette[j * 3 + 0] = 0
182 | palette[j * 3 + 1] = 0
183 | palette[j * 3 + 2] = 0
184 | i = 0
185 | while lab:
186 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
187 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
188 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
189 | i += 1
190 | lab >>= 3
191 | return palette
192 |
193 | def save_pred(self, preds, sv_path, name):
194 | palette = self.get_palette(256)
195 | preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
196 | for i in range(preds.shape[0]):
197 | pred = self.convert_label(preds[i], inverse=True)
198 | save_img = Image.fromarray(pred)
199 | save_img.putpalette(palette)
200 | save_img.save(os.path.join(sv_path, name[i]+'.png'))
201 |
202 |
203 |
204 |
--------------------------------------------------------------------------------
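
The `label_mapping` table in `cityscapes.py` collapses the 34 raw Cityscapes label IDs down to the 19 training classes (every unlisted or unreliable ID becomes `ignore_label`), and `convert_label(..., inverse=True)` restores the raw IDs before `save_pred` writes PNGs. A minimal self-contained sketch of that round trip, using a toy three-entry mapping rather than the full table (the real method rewrites the array in place; this copy-based version is only for illustration):

```python
import numpy as np

# Toy stand-in for the 34 -> 19 Cityscapes mapping; unmapped IDs go to ignore.
ignore_label = 255
label_mapping = {0: ignore_label, 7: 0, 8: 1, 11: 2}

def convert_label(label, mapping, inverse=False):
    # Same pattern as the dataset method: read from an untouched copy,
    # write the remapped values into the output.
    temp = label.copy()
    out = label.copy()
    if inverse:
        for v, k in mapping.items():
            out[temp == k] = v
    else:
        for k, v in mapping.items():
            out[temp == k] = v
    return out

raw = np.array([[7, 8], [11, 0]], dtype=np.uint8)
train_ids = convert_label(raw, label_mapping)                     # [[0, 1], [2, 255]]
restored = convert_label(train_ids, label_mapping, inverse=True)  # [[7, 8], [11, 0]]
print(train_ids)
print(restored)
```
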
/lib/datasets/cocostuff.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 |
12 | import torch
13 | from torch.nn import functional as F
14 | from PIL import Image
15 |
16 | from .base_dataset import BaseDataset
17 |
18 |
19 | class COCOStuff(BaseDataset):
20 | def __init__(self,
21 | root,
22 | list_path,
23 | num_samples=None,
24 | num_classes=171,
25 | multi_scale=True,
26 | flip=True,
27 | ignore_label=-1,
28 | base_size=520,
29 | crop_size=(520, 520),
30 | downsample_rate=1,
31 | scale_factor=11,
32 | mean=[0.485, 0.456, 0.406],
33 | std=[0.229, 0.224, 0.225]):
34 |
35 | super(COCOStuff, self).__init__(ignore_label, base_size,
36 | crop_size, downsample_rate, scale_factor, mean, std)
37 |
38 | self.root = root
39 | self.num_classes = num_classes
40 | self.list_path = list_path
41 | self.class_weights = None
42 |
43 | self.multi_scale = multi_scale
44 | self.flip = flip
45 | self.img_list = [line.strip().split() for line in open(root+list_path)]
46 |
47 | self.files = self.read_files()
48 | if num_samples:
49 | self.files = self.files[:num_samples]
50 | self.mapping = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
51 | 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39,
52 | 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
53 | 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77,
54 | 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 92, 93, 94, 95, 96,
55 | 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
56 | 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128,
57 | 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
58 | 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160,
59 | 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176,
60 | 177, 178, 179, 180, 181, 182]
61 |
62 | def read_files(self):
63 | files = []
64 | for item in self.img_list:
65 | image_path, label_path = item
66 | name = os.path.splitext(os.path.basename(label_path))[0]
67 | sample = {
68 | 'img': image_path,
69 | 'label': label_path,
70 | 'name': name
71 | }
72 | files.append(sample)
73 | return files
74 |
75 | def encode_label(self, labelmap):
76 | ret = np.ones_like(labelmap) * 255
77 | for idx, label in enumerate(self.mapping):
78 | ret[labelmap == label] = idx
79 |
80 | return ret
81 |
82 | def resize_image(self, image, label, size):
83 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
84 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
85 | return image, label
86 |
87 | def __getitem__(self, index):
88 | item = self.files[index]
89 | name = item["name"]
90 | image_path = os.path.join(self.root, 'cocostuff', item['img'])
91 | label_path = os.path.join(self.root, 'cocostuff', item['label'])
92 | image = cv2.imread(
93 | image_path,
94 | cv2.IMREAD_COLOR
95 | )
96 | label = np.array(
97 | Image.open(label_path).convert('P')
98 | )
99 | label = self.encode_label(label)
100 | label = self.reduce_zero_label(label)
101 | size = label.shape
102 |
103 | if 'testval' in self.list_path:
104 | image, border_padding = self.resize_short_length(
105 | image,
106 | short_length=self.base_size,
107 | fit_stride=8,
108 | return_padding=True
109 | )
110 | image = self.input_transform(image)
111 | image = image.transpose((2, 0, 1))
112 |
113 | return image.copy(), label.copy(), np.array(size), name, border_padding
114 |
115 | if 'val' in self.list_path:
116 | image, label = self.resize_short_length(
117 | image,
118 | label=label,
119 | short_length=self.base_size,
120 | fit_stride=8
121 | )
122 | image, label = self.rand_crop(image, label)
123 | image = self.input_transform(image)
124 | image = image.transpose((2, 0, 1))
125 |
126 | return image.copy(), label.copy(), np.array(size), name
127 |
128 | image, label = self.resize_short_length(image, label, short_length=self.base_size)
129 | image, label = self.gen_sample(image, label, self.multi_scale, self.flip)
130 |
131 | return image.copy(), label.copy(), np.array(size), name
--------------------------------------------------------------------------------
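
`encode_label` in `cocostuff.py` remaps the raw COCO-Stuff IDs listed in `self.mapping` to contiguous indices and sends every other value to 255. The same remapping can be expressed as a single 256-entry lookup table; the LUT variant below is only an illustration of what the per-class loop computes, not how the repo does it:

```python
import numpy as np

mapping = [0, 1, 2, 5, 7]          # stand-in for the 171-entry list above

def encode_label_loop(labelmap, mapping):
    # Same logic as COCOStuff.encode_label.
    ret = np.ones_like(labelmap) * 255
    for idx, label in enumerate(mapping):
        ret[labelmap == label] = idx
    return ret

def encode_label_lut(labelmap, mapping):
    # Build a 256-entry lookup table once, then index with the label map.
    lut = np.full(256, 255, dtype=np.uint8)
    lut[np.asarray(mapping)] = np.arange(len(mapping), dtype=np.uint8)
    return lut[labelmap]

raw = np.array([[0, 3], [5, 7]], dtype=np.uint8)
assert (encode_label_loop(raw, mapping) == encode_label_lut(raw, mapping)).all()
print(encode_label_lut(raw, mapping))   # [[0, 255], [3, 4]]
```
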
/lib/datasets/face.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 | from PIL import Image
12 |
13 | import torch
14 | from torch.nn import functional as F
15 |
16 | from .base_dataset import BaseDataset
17 |
18 | class Face(BaseDataset):
19 | def __init__(self,
20 | root,
21 | list_path,
22 | num_samples=None,
23 | num_classes=2,
24 | multi_scale=True,
25 | flip=True,
26 | ignore_label=-1,
27 | base_size=2048,
28 | crop_size=(512, 1024),
29 | downsample_rate=1,
30 | scale_factor=15,
31 | mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225]):
33 |
34 | super(Face, self).__init__(ignore_label, base_size,
35 | crop_size, downsample_rate, scale_factor, mean, std,)
36 |
37 | self.root = root
38 | self.list_path = list_path
39 | self.num_classes = num_classes
40 |
41 | self.multi_scale = multi_scale
42 | self.flip = flip
43 |
44 | self.img_list = [line.strip().split(",") for line in open(root+list_path)]
45 |
46 | self.files = self.read_files()
47 |
48 | if num_samples:
49 | self.files = self.files[:num_samples]
50 |
51 | self.label_mapping = {-1: 0, 0: 0}
52 | for s in range(1,256):
53 | self.label_mapping[s] = 1
54 |
55 | self.class_weights = torch.FloatTensor([1, 3]).cuda()
56 |
57 | def read_files(self):
58 | files = []
59 | if 'test' in self.list_path:
60 | for item in self.img_list:
61 | image_path = item
62 | print("os.path.basename(image_path[0]):",os.path.basename(image_path[0]))
63 | name = os.path.splitext(os.path.basename(image_path[0]))[0]
64 | files.append({
65 | "img": image_path[0],
66 | "name": name,
67 | })
68 | else:
69 | for item in self.img_list:
70 | image_path, label_path = item
71 | name = os.path.splitext(os.path.basename(label_path))[0]
72 | files.append({
73 | "img": image_path,
74 | "label": label_path,
75 | "name": name,
76 | "weight": 1
77 | })
78 | return files
79 |
80 | def convert_label(self, label, inverse=False):
81 |
82 | temp = label.copy()
83 | if inverse:
84 | for v, k in self.label_mapping.items():
85 | label[temp == k] = v
86 | else:
87 | for k, v in self.label_mapping.items():
88 | label[temp == k] = v
89 | return label
90 |
91 | def __getitem__(self, index):
92 | item = self.files[index]
93 | name = item["name"]
94 | image = cv2.imread(os.path.join(self.root,'face',item["img"]),
95 | cv2.IMREAD_COLOR)
96 |
97 | image = cv2.resize(image,(self.base_size,self.base_size))
98 |
99 | size = image.shape
100 |
101 | if 'test' in self.list_path:
102 | image = self.input_transform(image)
103 | image = image.transpose((2, 0, 1))
104 |
105 | return image.copy(), np.array(size), name
106 | label = cv2.imread(os.path.join(self.root,'face',item["label"]),
107 | cv2.IMREAD_UNCHANGED)
108 | label = label[:,:,3]
109 | label = cv2.resize(label, (self.base_size, self.base_size))
110 | # cv2.imwrite("before_label.jpg", label)
111 |
112 | image, label = self.gen_sample(image, label,
113 | self.multi_scale, self.flip)
114 |
115 | label = self.convert_label(label)
116 | #print(label.shape)
117 | # print("image : ",image.shape, label.shape)
118 | # cv2.imwrite("aug/{}_after_image.jpg".format(index), image)
119 | # cv2.imwrite("aug/{}_after_label.jpg".format(index), label)
120 |
121 |
122 | return image.copy(), label.copy(), np.array(size), name
123 |
124 | def multi_scale_inference(self, config, model, image, scales=[1], flip=False):
125 | batch, _, ori_height, ori_width = image.size()
126 | assert batch == 1, "only supporting batchsize 1."
127 | image = image.numpy()[0].transpose((1,2,0)).copy()
128 | stride_h = np.int(self.crop_size[0] * 1.0)
129 | stride_w = np.int(self.crop_size[1] * 1.0)
130 | final_pred = torch.zeros([1, self.num_classes,
131 | ori_height,ori_width]).cuda()
132 | for scale in scales:
133 | new_img = self.multi_scale_aug(image=image,
134 | rand_scale=scale,
135 | rand_crop=False)
136 | height, width = new_img.shape[:-1]
137 |
138 | if scale <= 1.0:
139 | new_img = new_img.transpose((2, 0, 1))
140 | new_img = np.expand_dims(new_img, axis=0)
141 | new_img = torch.from_numpy(new_img)
142 | preds = self.inference(config, model, new_img, flip)
143 | preds = preds[:, :, 0:height, 0:width]
144 | else:
145 | new_h, new_w = new_img.shape[:-1]
146 | rows = np.int(np.ceil(1.0 * (new_h -
147 | self.crop_size[0]) / stride_h)) + 1
148 | cols = np.int(np.ceil(1.0 * (new_w -
149 | self.crop_size[1]) / stride_w)) + 1
150 | preds = torch.zeros([1, self.num_classes,
151 | new_h,new_w]).cuda()
152 | count = torch.zeros([1,1, new_h, new_w]).cuda()
153 |
154 | for r in range(rows):
155 | for c in range(cols):
156 | h0 = r * stride_h
157 | w0 = c * stride_w
158 | h1 = min(h0 + self.crop_size[0], new_h)
159 | w1 = min(w0 + self.crop_size[1], new_w)
160 | h0 = max(int(h1 - self.crop_size[0]), 0)
161 | w0 = max(int(w1 - self.crop_size[1]), 0)
162 | crop_img = new_img[h0:h1, w0:w1, :]
163 | crop_img = crop_img.transpose((2, 0, 1))
164 | crop_img = np.expand_dims(crop_img, axis=0)
165 | crop_img = torch.from_numpy(crop_img)
166 | pred = self.inference(config, model, crop_img, flip)
167 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
168 | count[:,:,h0:h1,w0:w1] += 1
169 | preds = preds / count
170 | preds = preds[:,:,:height,:width]
171 |
172 | preds = F.interpolate(
173 | preds, (ori_height, ori_width),
174 | mode='nearest'
175 | )
176 | final_pred += preds
177 | return final_pred
178 |
179 | def get_palette(self, n):
180 | palette = [0] * (n * 3)
181 | for j in range(0, n):
182 | lab = j
183 | palette[j * 3 + 0] = 0
184 | palette[j * 3 + 1] = 0
185 | palette[j * 3 + 2] = 0
186 | i = 0
187 | while lab:
188 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
189 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
190 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
191 | i += 1
192 | lab >>= 3
193 | return palette
194 |
195 | def save_pred(self, preds, sv_path, name):
196 | palette = self.get_palette(256)
197 | preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
198 | for i in range(preds.shape[0]):
199 | pred = self.convert_label(preds[i], inverse=True)
200 | save_img = Image.fromarray(pred)
201 | save_img.putpalette(palette)
202 | save_img.save(os.path.join(sv_path, name[i]+'.png'))
203 |
204 |
205 |
206 |
--------------------------------------------------------------------------------
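
A hedged sketch of how the `Face` dataset might be wired up for training. It assumes `lib/` is on `sys.path` (as `tools/_init_paths.py` arranges for the scripts in `tools/`), a CUDA-capable PyTorch build (the constructor moves the class weights to GPU), and that `data/list/face_train.txt` lists comma-separated image/label pairs resolved under `data/face/`; the sizes and loader settings are illustrative, not taken from `experiments/face/ddrnet23_slim.yaml`:

```python
from torch.utils.data import DataLoader

from datasets.face import Face   # importable once lib/ is on sys.path

train_dataset = Face(root='data/',
                     list_path='list/face_train.txt',
                     num_classes=2,        # values below are illustrative,
                     multi_scale=True,     # not read from the face yaml
                     flip=True,
                     base_size=512,
                     crop_size=(512, 512))

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)

images, labels, size, names = next(iter(train_loader))
print(images.shape, labels.shape)   # (8, 3, H, W) and (8, H, W); 0 = background, 1 = face
```
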
/lib/datasets/lip.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 |
12 | import torch
13 | from torch.nn import functional as F
14 |
15 | from .base_dataset import BaseDataset
16 |
17 |
18 | class LIP(BaseDataset):
19 | def __init__(self,
20 | root,
21 | list_path,
22 | num_samples=None,
23 | num_classes=20,
24 | multi_scale=True,
25 | flip=True,
26 | ignore_label=-1,
27 | base_size=473,
28 | crop_size=(473, 473),
29 | downsample_rate=1,
30 | scale_factor=11,
31 | mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225]):
33 |
34 | super(LIP, self).__init__(ignore_label, base_size,
35 | crop_size, downsample_rate, scale_factor, mean, std)
36 |
37 | self.root = root
38 | self.num_classes = num_classes
39 | self.list_path = list_path
40 | self.class_weights = None
41 |
42 | self.multi_scale = multi_scale
43 | self.flip = flip
44 | self.img_list = [line.strip().split() for line in open(root+list_path)]
45 |
46 | self.files = self.read_files()
47 | if num_samples:
48 | self.files = self.files[:num_samples]
49 |
50 | def read_files(self):
51 | files = []
52 | for item in self.img_list:
53 | if 'train' in self.list_path:
54 | image_path, label_path, label_rev_path, _ = item
55 | name = os.path.splitext(os.path.basename(label_path))[0]
56 | sample = {"img": image_path,
57 | "label": label_path,
58 | "label_rev": label_rev_path,
59 | "name": name, }
60 | elif 'val' in self.list_path:
61 | image_path, label_path = item
62 | name = os.path.splitext(os.path.basename(label_path))[0]
63 | sample = {"img": image_path,
64 | "label": label_path,
65 | "name": name, }
66 | else:
67 | raise NotImplementedError('Unknown subset.')
68 | files.append(sample)
69 | return files
70 |
71 | def resize_image(self, image, label, size):
72 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
73 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
74 | return image, label
75 |
76 | def __getitem__(self, index):
77 | item = self.files[index]
78 | name = item["name"]
79 | item["img"] = item["img"].replace(
80 | "train_images", "LIP_Train").replace("val_images", "LIP_Val")
81 | item["label"] = item["label"].replace(
82 | "train_segmentations", "LIP_Train").replace("val_segmentations", "LIP_Val")
83 | image = cv2.imread(os.path.join(
84 | self.root, 'lip/TrainVal_images/', item["img"]),
85 | cv2.IMREAD_COLOR)
86 | label = cv2.imread(os.path.join(
87 | self.root, 'lip/TrainVal_parsing_annotations/',
88 | item["label"]),
89 | cv2.IMREAD_GRAYSCALE)
90 | size = label.shape
91 |
92 | if 'testval' in self.list_path:
93 | image = cv2.resize(image, self.crop_size,
94 | interpolation=cv2.INTER_LINEAR)
95 | image = self.input_transform(image)
96 | image = image.transpose((2, 0, 1))
97 |
98 | return image.copy(), label.copy(), np.array(size), name
99 |
100 | if self.flip:
101 | flip = np.random.choice(2) * 2 - 1
102 | image = image[:, ::flip, :]
103 | label = label[:, ::flip]
104 |
105 | if flip == -1:
106 | right_idx = [15, 17, 19]
107 | left_idx = [14, 16, 18]
108 | for i in range(0, 3):
109 | right_pos = np.where(label == right_idx[i])
110 | left_pos = np.where(label == left_idx[i])
111 | label[right_pos[0], right_pos[1]] = left_idx[i]
112 | label[left_pos[0], left_pos[1]] = right_idx[i]
113 |
114 | image, label = self.resize_image(image, label, self.crop_size)
115 | image, label = self.gen_sample(image, label,
116 | self.multi_scale, False)
117 |
118 | return image.copy(), label.copy(), np.array(size), name
119 |
120 | def inference(self, config, model, image, flip):
121 | size = image.size()
122 | pred = model(image)
123 | if config.MODEL.NUM_OUTPUTS > 1:
124 | pred = pred[config.TEST.OUTPUT_INDEX]
125 |
126 | pred = F.interpolate(
127 | input=pred, size=size[-2:],
128 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
129 | )
130 |
131 | if flip:
132 | flip_img = image.numpy()[:, :, :, ::-1]
133 | flip_output = model(torch.from_numpy(flip_img.copy()))
134 |
135 | if config.MODEL.NUM_OUTPUTS > 1:
136 | flip_output = flip_output[config.TEST.OUTPUT_INDEX]
137 |
138 | flip_output = F.interpolate(
139 | input=flip_output, size=size[-2:],
140 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
141 | )
142 |
143 | flip_output = flip_output.cpu()
144 | flip_pred = flip_output.cpu().numpy().copy()
145 | flip_pred[:, 14, :, :] = flip_output[:, 15, :, :]
146 | flip_pred[:, 15, :, :] = flip_output[:, 14, :, :]
147 | flip_pred[:, 16, :, :] = flip_output[:, 17, :, :]
148 | flip_pred[:, 17, :, :] = flip_output[:, 16, :, :]
149 | flip_pred[:, 18, :, :] = flip_output[:, 19, :, :]
150 | flip_pred[:, 19, :, :] = flip_output[:, 18, :, :]
151 | flip_pred = torch.from_numpy(
152 | flip_pred[:, :, :, ::-1].copy()).cuda()
153 | pred += flip_pred
154 | pred = pred * 0.5
155 | return pred.exp()
156 |
--------------------------------------------------------------------------------
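
The `inference` override in `lip.py` averages the logits of the original image with those of a horizontally flipped pass, swapping the left/right channel pairs (14/15, 16/17, 18/19) so that, for example, the mirrored image's left-shoe scores land back on the right-shoe channel. A small NumPy-only sketch of that merge step, independent of any model (`merge_flip_logits` is a name introduced here, not a repo function):

```python
import numpy as np

def merge_flip_logits(pred, flip_pred, swap_pairs=((14, 15), (16, 17), (18, 19))):
    """Average logits from the original image and its horizontal flip.

    `pred` and `flip_pred` are (N, C, H, W) arrays; `flip_pred` is the output
    on the mirrored image, so its left/right channels are swapped and its
    width axis reversed before averaging, mirroring LIP.inference.
    """
    merged = flip_pred.copy()
    for a, b in swap_pairs:
        merged[:, a], merged[:, b] = flip_pred[:, b], flip_pred[:, a]
    merged = merged[:, :, :, ::-1]          # undo the horizontal flip
    return 0.5 * (pred + merged)

pred = np.random.rand(1, 20, 4, 4).astype(np.float32)
flip_pred = np.random.rand(1, 20, 4, 4).astype(np.float32)
print(merge_flip_logits(pred, flip_pred).shape)   # (1, 20, 4, 4)
```
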
/lib/datasets/map.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 |
12 | import torch
13 | from torch.nn import functional as F
14 | from PIL import Image
15 |
16 | from .base_dataset import BaseDataset
17 |
18 |
19 | class MAP(BaseDataset):
20 | def __init__(self,
21 | root,
22 | list_path=None,
23 | num_samples=None,
24 | num_classes=17,
25 | multi_scale=True,
26 | flip=True,
27 | ignore_label=-1,
28 | base_size=520,
29 | crop_size=(520, 520),
30 | downsample_rate=1,
31 | scale_factor=11,
32 | mean=[0.485, 0.456, 0.406],
33 | std=[0.229, 0.224, 0.225]):
34 |
35 | super(MAP, self).__init__(ignore_label, base_size,
36 | crop_size, downsample_rate, scale_factor, mean, std,)
37 |
38 | self.root = root
39 | self.list_path = list_path
40 | self.num_classes = num_classes
41 |
42 | self.multi_scale = multi_scale
43 | self.flip = flip
44 |
45 | if os.path.isfile(self.list_path):
46 | self.img_list = [line.strip().split() for line in open(list_path)]
47 | elif os.path.isdir(self.list_path):
48 | self.img_list = [line.split('.')[0] for line in os.listdir(self.list_path)]
49 |
50 | self.files = self.read_files()
51 | if num_samples:
52 | self.files = self.files[:num_samples]
53 |
54 | self.class_weights = None
55 | # self.class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345,
56 | # 1.0166, 0.9969, 0.9754, 1.0489,
57 | # 0.8786, 1.0023, 0.9539, 0.9843,
58 | # 1.1116, 0.9037, 1.0865, 1.0955,
59 | # 1.0865, 1.1529, 1.0507]).cuda()
60 |
61 | def read_files(self):
62 | files = []
63 | if os.path.basename(self.list_path).split(".")[0] == "test":
64 | for item in self.img_list:
65 | image_path = item[0]
66 | name = os.path.basename(image_path).split('.')[0]
67 | files.append({
68 | "img": image_path,
69 | "name": name,
70 | })
71 | else:
72 | for item in self.img_list:
73 | image_path, label_path = "mapv3/images/{}.jpg".format(item[0]), "mapv3/annotations/{}.png".format(item[0])
74 | name = os.path.splitext(os.path.basename(label_path))[0]
75 | files.append({
76 | "img": image_path,
77 | "label": label_path,
78 | "name": name,
79 | "weight": 1
80 | })
81 | return files
82 |
83 | def convert_label(self, label, inverse=False):
84 | temp = label.copy()
85 | if inverse:
86 | for v, k in self.label_mapping.items():
87 | label[temp == k] = v
88 | else:
89 | for k, v in self.label_mapping.items():
90 | label[temp == k] = v
91 | return label
92 |
93 | def __getitem__(self, index):
94 | item = self.files[index]
95 | name = item["name"]
96 | image = cv2.imread(os.path.join(self.root,item["img"]),
97 | cv2.IMREAD_COLOR)
98 | size = image.shape
99 |
100 | if os.path.basename(self.list_path).split(".")[0] == "test":
101 | image = self.input_transform(image)
102 | image = image.transpose((2, 0, 1))
103 |
104 | return image.copy(), np.array(size), name
105 |
106 | # resize the short length to basesize
107 | # if 'testval' in self.list_path:
108 | # label = cv2.imread(os.path.join(self.root,item["label"]),
109 | # cv2.IMREAD_GRAYSCALE)
110 | # image, label = self.resize_short_length(
111 | # image,
112 | # label=label,
113 | # short_length=self.base_size,
114 | # fit_stride=8 )
115 | # size = image.shape
116 | # image = self.input_transform(image)
117 | # image = image.transpose((2, 0, 1))
118 | # label = self.label_transform(label)
119 | # return image.copy(), label.copy(), np.array(size), name
120 |
121 | label = cv2.imread(os.path.join(self.root,item["label"]),
122 | cv2.IMREAD_GRAYSCALE)
123 |
124 | image, label = self.gen_sample(image, label,
125 | self.multi_scale, self.flip)
126 |
127 | return image.copy(), label.copy(), np.array(size), name
128 |
129 | def multi_scale_inference(self, config, model, image, scales=[1], flip=False):
130 | batch, _, ori_height, ori_width = image.size()
131 | assert batch == 1, "only supporting batchsize 1."
132 | image = image.numpy()[0].transpose((1,2,0)).copy()
133 | stride_h = np.int(self.crop_size[0] * 1.0)
134 | stride_w = np.int(self.crop_size[1] * 1.0)
135 | final_pred = torch.zeros([1, self.num_classes,
136 | ori_height,ori_width]).cuda()
137 | for scale in scales:
138 | new_img = self.multi_scale_aug(image=image,
139 | rand_scale=scale,
140 | rand_crop=False)
141 | height, width = new_img.shape[:-1]
142 |
143 | if scale <= 1.0:
144 | new_img = new_img.transpose((2, 0, 1))
145 | new_img = np.expand_dims(new_img, axis=0)
146 | new_img = torch.from_numpy(new_img)
147 | preds = self.inference(config, model, new_img, flip)
148 | preds = preds[:, :, 0:height, 0:width]
149 | else:
150 | new_h, new_w = new_img.shape[:-1]
151 | rows = np.int(np.ceil(1.0 * (new_h -
152 | self.crop_size[0]) / stride_h)) + 1
153 | cols = np.int(np.ceil(1.0 * (new_w -
154 | self.crop_size[1]) / stride_w)) + 1
155 | preds = torch.zeros([1, self.num_classes,
156 | new_h,new_w]).cuda()
157 | count = torch.zeros([1,1, new_h, new_w]).cuda()
158 |
159 | for r in range(rows):
160 | for c in range(cols):
161 | h0 = r * stride_h
162 | w0 = c * stride_w
163 | h1 = min(h0 + self.crop_size[0], new_h)
164 | w1 = min(w0 + self.crop_size[1], new_w)
165 | h0 = max(int(h1 - self.crop_size[0]), 0)
166 | w0 = max(int(w1 - self.crop_size[1]), 0)
167 | crop_img = new_img[h0:h1, w0:w1, :]
168 | crop_img = crop_img.transpose((2, 0, 1))
169 | crop_img = np.expand_dims(crop_img, axis=0)
170 | crop_img = torch.from_numpy(crop_img)
171 | pred = self.inference(config, model, crop_img, flip)
172 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
173 | count[:,:,h0:h1,w0:w1] += 1
174 | preds = preds / count
175 | preds = preds[:,:,:height,:width]
176 |
177 | preds = F.interpolate(
178 | preds, (ori_height, ori_width),
179 | mode='nearest'
180 | )
181 | final_pred += preds
182 | return final_pred
183 |
184 | def get_palette(self, n):
185 | palette = [0] * (n * 3)
186 | for j in range(0, n):
187 | lab = j
188 | palette[j * 3 + 0] = 0
189 | palette[j * 3 + 1] = 0
190 | palette[j * 3 + 2] = 0
191 | i = 0
192 | while lab:
193 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
194 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
195 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
196 | i += 1
197 | lab >>= 3
198 | return palette
199 |
200 | def save_pred(self, image, preds, sv_path, name):
201 | image = image.squeeze(0)
202 | image = image.numpy().transpose((1,2,0))
203 | image *= self.std
204 | image += self.mean
205 | image *= 255.0
206 | image = image.astype(np.uint8)
207 | palette = self.get_palette(256)
208 | preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
209 | for i in range(preds.shape[0]):
210 | pred = preds[i]
211 | save_img = Image.fromarray(pred)
212 | save_img.putpalette(palette)
213 | save_img.save(os.path.join(sv_path, name[i]+'.png'))
214 |
215 | def save_pred2(self, image, preds, sv_path, name):
216 | preds = torch.argmax(preds, dim=1).squeeze(0).cpu().numpy()
217 | image = image.squeeze(0)
218 | image = image.numpy().transpose((1,2,0))
219 | image *= self.std
220 | image += self.mean
221 | image *= 255.0
222 | image = image.astype(np.uint8)
223 | colors = np.array([[0, 0, 0],
224 | [0, 0, 255],
225 | [0, 255, 0],
226 | [0, 255, 255],
227 | [255, 0, 0 ],
228 | [255, 0, 255 ],
229 | [255, 255, 0 ],
230 | [255, 255, 255 ],
231 | [0, 0, 128 ],
232 | [0, 128, 0 ],
233 | [128, 0, 0 ],
234 | [0, 128, 128 ],
235 | [128, 0, 0 ],
236 | [128, 0, 128 ],
237 | [128, 128, 0 ],
238 | [128, 128, 128 ],
239 | [192, 192, 192 ]], dtype=np.uint8)
240 | pred_color = colorEncode(preds, colors)
241 | im_vis = image * 0.5 + pred_color * 0.5
242 | im_vis = im_vis.astype(np.uint8)
243 | save_img = Image.fromarray(im_vis)
244 | save_img.save(os.path.join(sv_path, name[0]+'.png'))
245 |
246 | def colorEncode(labelmap, colors, mode='RGB'):
247 | labelmap = labelmap.astype('int')
248 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
249 | dtype=np.uint8)
250 | for label in np.unique(labelmap):
251 | if label < 0:
252 | continue
253 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
254 | np.tile(colors[label],
255 | (labelmap.shape[0], labelmap.shape[1], 1))
256 |
257 | if mode == 'BGR':
258 | return labelmap_rgb[:, :, ::-1]
259 | else:
260 | return labelmap_rgb
261 |
262 |
--------------------------------------------------------------------------------
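
`save_pred2` in `map.py` blends the de-normalised input image with a colour-coded class map produced by the module-level `colorEncode` helper. A tiny worked example of `colorEncode` on its own, with an illustrative three-colour table:

```python
import numpy as np

def colorEncode(labelmap, colors, mode='RGB'):
    # Same logic as the module-level helper in map.py.
    labelmap = labelmap.astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), dtype=np.uint8)
    for label in np.unique(labelmap):
        if label < 0:
            continue
        labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
            np.tile(colors[label], (labelmap.shape[0], labelmap.shape[1], 1))
    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    return labelmap_rgb

colors = np.array([[0, 0, 0], [0, 0, 255], [0, 255, 0]], dtype=np.uint8)
labels = np.array([[0, 1], [2, 1]])
print(colorEncode(labels, colors))
# pixel (0,0) -> [0,0,0], (0,1) and (1,1) -> [0,0,255], (1,0) -> [0,255,0]
```
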
/lib/datasets/parking.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 | from PIL import Image
12 |
13 | import torch
14 | from torch.nn import functional as F
15 |
16 | from .base_dataset import BaseDataset
17 |
18 | class Parking(BaseDataset):
19 | def __init__(self,
20 | root,
21 | list_path,
22 | num_samples=None,
23 | num_classes=2,
24 | multi_scale=True,
25 | flip=True,
26 | ignore_label=-1,
27 | base_size=2048,
28 | crop_size=(512, 1024),
29 | downsample_rate=1,
30 | scale_factor=15,
31 | mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225]):
33 |
34 | super(Parking, self).__init__(ignore_label, base_size,
35 | crop_size, downsample_rate, scale_factor, mean, std,)
36 |
37 | self.root = root
38 | self.list_path = list_path
39 | self.num_classes = num_classes
40 |
41 | self.multi_scale = multi_scale
42 | self.flip = flip
43 |
44 | self.img_list = [line.strip().split(",") for line in open(root+list_path)]
45 |
46 | self.files = self.read_files()
47 |
48 | if num_samples:
49 | self.files = self.files[:num_samples]
50 |
51 | self.label_mapping = {-1: 0, 38: 1}
52 | for s in range(38):
53 | self.label_mapping[s] = 0
54 | for s in range(39,256):
55 | self.label_mapping[s] = 0
56 |
57 | self.class_weights = torch.FloatTensor([1, 3]).cuda()
58 |
59 | def read_files(self):
60 | files = []
61 | if 'test' in self.list_path:
62 | for item in self.img_list:
63 | image_path = item
64 | print("os.path.basename(image_path[0]):",os.path.basename(image_path[0]))
65 | name = os.path.splitext(os.path.basename(image_path[0]))[0]
66 | files.append({
67 | "img": image_path[0],
68 | "name": name,
69 | })
70 | else:
71 | for item in self.img_list:
72 | image_path, label_path = item
73 | name = os.path.splitext(os.path.basename(label_path))[0]
74 | files.append({
75 | "img": image_path,
76 | "label": label_path,
77 | "name": name,
78 | "weight": 1
79 | })
80 | return files
81 |
82 | def convert_label(self, label, inverse=False):
83 |
84 | temp = label.copy()
85 | if inverse:
86 | for v, k in self.label_mapping.items():
87 | label[temp == k] = v
88 | else:
89 | for k, v in self.label_mapping.items():
90 | label[temp == k] = v
91 | return label
92 |
93 | def __getitem__(self, index):
94 | item = self.files[index]
95 | name = item["name"]
96 | image = cv2.imread(os.path.join(self.root,'parking',item["img"]),
97 | cv2.IMREAD_COLOR)
98 |
99 | image = cv2.resize(image,(self.base_size,self.base_size))
100 |
101 | size = image.shape
102 |
103 | if 'test' in self.list_path:
104 | image = self.input_transform(image)
105 | image = image.transpose((2, 0, 1))
106 |
107 | return image.copy(), np.array(size), name
108 | label = cv2.imread(os.path.join(self.root,'parking',item["label"]),
109 | cv2.IMREAD_GRAYSCALE)
110 | label = cv2.resize(label, (self.base_size, self.base_size))
111 | # cv2.imwrite("before_label.jpg", label)
112 |
113 | image, label = self.gen_sample(image, label,
114 | self.multi_scale, self.flip)
115 |
116 | label = self.convert_label(label)
117 | #print(label.shape)
118 | # print("image : ",image.shape, label.shape)
119 | # cv2.imwrite("aug/{}_after_image.jpg".format(index), image)
120 | # cv2.imwrite("aug/{}_after_label.jpg".format(index), label)
121 |
122 |
123 | return image.copy(), label.copy(), np.array(size), name
124 |
125 | def multi_scale_inference(self, config, model, image, scales=[1], flip=False):
126 | batch, _, ori_height, ori_width = image.size()
127 | assert batch == 1, "only supporting batchsize 1."
128 | image = image.numpy()[0].transpose((1,2,0)).copy()
129 | stride_h = np.int(self.crop_size[0] * 1.0)
130 | stride_w = np.int(self.crop_size[1] * 1.0)
131 | final_pred = torch.zeros([1, self.num_classes,
132 | ori_height,ori_width]).cuda()
133 | for scale in scales:
134 | new_img = self.multi_scale_aug(image=image,
135 | rand_scale=scale,
136 | rand_crop=False)
137 | height, width = new_img.shape[:-1]
138 |
139 | if scale <= 1.0:
140 | new_img = new_img.transpose((2, 0, 1))
141 | new_img = np.expand_dims(new_img, axis=0)
142 | new_img = torch.from_numpy(new_img)
143 | preds = self.inference(config, model, new_img, flip)
144 | preds = preds[:, :, 0:height, 0:width]
145 | else:
146 | new_h, new_w = new_img.shape[:-1]
147 | rows = np.int(np.ceil(1.0 * (new_h -
148 | self.crop_size[0]) / stride_h)) + 1
149 | cols = np.int(np.ceil(1.0 * (new_w -
150 | self.crop_size[1]) / stride_w)) + 1
151 | preds = torch.zeros([1, self.num_classes,
152 | new_h,new_w]).cuda()
153 | count = torch.zeros([1,1, new_h, new_w]).cuda()
154 |
155 | for r in range(rows):
156 | for c in range(cols):
157 | h0 = r * stride_h
158 | w0 = c * stride_w
159 | h1 = min(h0 + self.crop_size[0], new_h)
160 | w1 = min(w0 + self.crop_size[1], new_w)
161 | h0 = max(int(h1 - self.crop_size[0]), 0)
162 | w0 = max(int(w1 - self.crop_size[1]), 0)
163 | crop_img = new_img[h0:h1, w0:w1, :]
164 | crop_img = crop_img.transpose((2, 0, 1))
165 | crop_img = np.expand_dims(crop_img, axis=0)
166 | crop_img = torch.from_numpy(crop_img)
167 | pred = self.inference(config, model, crop_img, flip)
168 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
169 | count[:,:,h0:h1,w0:w1] += 1
170 | preds = preds / count
171 | preds = preds[:,:,:height,:width]
172 |
173 | preds = F.interpolate(
174 | preds, (ori_height, ori_width),
175 | mode='nearest'
176 | )
177 | final_pred += preds
178 | return final_pred
179 |
180 | def get_palette(self, n):
181 | palette = [0] * (n * 3)
182 | for j in range(0, n):
183 | lab = j
184 | palette[j * 3 + 0] = 0
185 | palette[j * 3 + 1] = 0
186 | palette[j * 3 + 2] = 0
187 | i = 0
188 | while lab:
189 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
190 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
191 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
192 | i += 1
193 | lab >>= 3
194 | return palette
195 |
196 | def save_pred(self, preds, sv_path, name):
197 | palette = self.get_palette(256)
198 | preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
199 | for i in range(preds.shape[0]):
200 | pred = self.convert_label(preds[i], inverse=True)
201 | save_img = Image.fromarray(pred)
202 | save_img.putpalette(palette)
203 | save_img.save(os.path.join(sv_path, name[i]+'.png'))
204 |
205 |
206 |
207 |
--------------------------------------------------------------------------------
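
`get_palette`, repeated in several of the dataset classes, builds the familiar PASCAL-VOC-style colour map: bits 0, 3, 6, ... of the class index are spread over the high bits of R, bits 1, 4, 7, ... over G, and bits 2, 5, 8, ... over B. A short worked check of the first few colours:

```python
def get_palette(n):
    # Condensed version of the method above (the palette list starts zeroed).
    palette = [0] * (n * 3)
    for j in range(n):
        lab, i = j, 0
        while lab:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3
    return palette

p = get_palette(256)
# class 0 -> (0, 0, 0), class 1 -> (128, 0, 0), class 2 -> (0, 128, 0),
# class 3 -> (128, 128, 0), class 5 -> (128, 0, 128)
print(p[3:6], p[15:18])   # [128, 0, 0] [128, 0, 128]
```
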
/lib/datasets/pascal_ctx.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # Referring to the implementation in
6 | # https://github.com/zhanghang1989/PyTorch-Encoding
7 | # ------------------------------------------------------------------------------
8 |
9 | import os
10 |
11 | import cv2
12 | import numpy as np
13 | from PIL import Image
14 |
15 | import torch
16 |
17 | from .base_dataset import BaseDataset
18 |
19 | class PASCALContext(BaseDataset):
20 | def __init__(self,
21 | root,
22 | list_path,
23 | num_samples=None,
24 | num_classes=59,
25 | multi_scale=True,
26 | flip=True,
27 | ignore_label=-1,
28 | base_size=520,
29 | crop_size=(480, 480),
30 | downsample_rate=1,
31 | scale_factor=16,
32 | mean=[0.485, 0.456, 0.406],
33 | std=[0.229, 0.224, 0.225],):
34 |
35 | super(PASCALContext, self).__init__(ignore_label, base_size,
36 | crop_size, downsample_rate, scale_factor, mean, std)
37 |
38 | self.root = os.path.join(root, 'pascal_ctx/VOCdevkit/VOC2010')
39 | self.split = list_path
40 |
41 | self.num_classes = num_classes
42 | self.class_weights = None
43 |
44 | self.multi_scale = multi_scale
45 | self.flip = flip
46 | self.crop_size = crop_size
47 |
48 | # prepare data
49 | annots = os.path.join(self.root, 'trainval_merged.json')
50 | img_path = os.path.join(self.root, 'JPEGImages')
51 | from detail import Detail
52 | if 'val' in self.split:
53 | self.detail = Detail(annots, img_path, 'val')
54 | mask_file = os.path.join(self.root, 'val.pth')
55 | elif 'train' in self.split:
56 | self.mode = 'train'
57 | self.detail = Detail(annots, img_path, 'train')
58 | mask_file = os.path.join(self.root, 'train.pth')
59 | else:
60 | raise NotImplementedError('only supporting train and val set.')
61 | self.files = self.detail.getImgs()
62 |
63 | # generate masks
64 | self._mapping = np.sort(np.array([
65 | 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22,
66 | 23, 397, 25, 284, 158, 159, 416, 33, 162, 420, 454, 295, 296,
67 | 427, 44, 45, 46, 308, 59, 440, 445, 31, 232, 65, 354, 424,
68 | 68, 326, 72, 458, 34, 207, 80, 355, 85, 347, 220, 349, 360,
69 | 98, 187, 104, 105, 366, 189, 368, 113, 115]))
70 |
71 | self._key = np.array(range(len(self._mapping))).astype('uint8')
72 |
73 | print('mask_file:', mask_file)
74 | if os.path.exists(mask_file):
75 | self.masks = torch.load(mask_file)
76 | else:
77 | self.masks = self._preprocess(mask_file)
78 |
79 | def _class_to_index(self, mask):
80 | # assert the values
81 | values = np.unique(mask)
82 | for i in range(len(values)):
83 | assert(values[i] in self._mapping)
84 | index = np.digitize(mask.ravel(), self._mapping, right=True)
85 | return self._key[index].reshape(mask.shape)
86 |
87 | def _preprocess(self, mask_file):
88 | masks = {}
 89 |         print("Preprocessing mask, this will take a while. " + \
 90 |             "But don't worry, it only runs once for each split.")
91 | for i in range(len(self.files)):
92 | img_id = self.files[i]
93 | mask = Image.fromarray(self._class_to_index(
94 | self.detail.getMask(img_id)))
95 | masks[img_id['image_id']] = mask
96 | torch.save(masks, mask_file)
97 | return masks
98 |
99 | def __getitem__(self, index):
100 | item = self.files[index]
101 | name = item['file_name']
102 | img_id = item['image_id']
103 |
104 | image = cv2.imread(os.path.join(self.detail.img_folder,name),
105 | cv2.IMREAD_COLOR)
106 | label = np.asarray(self.masks[img_id],dtype=np.int)
107 | size = image.shape
108 |
109 | if self.split == 'val':
110 | image = cv2.resize(image, self.crop_size,
111 | interpolation = cv2.INTER_LINEAR)
112 | image = self.input_transform(image)
113 | image = image.transpose((2, 0, 1))
114 |
115 | label = cv2.resize(label, self.crop_size,
116 | interpolation=cv2.INTER_NEAREST)
117 | label = self.label_transform(label)
118 | elif self.split == 'testval':
119 | # evaluate model on val dataset
120 | image = self.input_transform(image)
121 | image = image.transpose((2, 0, 1))
122 | label = self.label_transform(label)
123 | else:
124 | image, label = self.gen_sample(image, label,
125 | self.multi_scale, self.flip)
126 |
127 | return image.copy(), label.copy(), np.array(size), name
128 |
129 | def label_transform(self, label):
130 | if self.num_classes == 59:
131 | # background is ignored
132 | label = np.array(label).astype('int32') - 1
133 | label[label==-2] = -1
134 | else:
135 | label = np.array(label).astype('int32')
136 | return label
137 |
--------------------------------------------------------------------------------
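
`_class_to_index` in `pascal_ctx.py` leans on `np.digitize(..., right=True)` over the sorted `_mapping` array: since every mask value is asserted to be present in `_mapping`, `digitize` returns exactly that value's position in the sorted list, which `_key` then turns into a contiguous uint8 class index. A small worked example with a stand-in mapping:

```python
import numpy as np

mapping = np.sort(np.array([0, 2, 9, 18, 22, 33]))   # stand-in for the 59-entry table
key = np.arange(len(mapping)).astype('uint8')

mask = np.array([[0, 9], [33, 2]])
index = np.digitize(mask.ravel(), mapping, right=True)
print(index.reshape(mask.shape))
# [[0 2]
#  [5 1]]   -> each raw value replaced by its position in `mapping`
print(key[index].reshape(mask.shape).dtype)           # uint8
```
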
/lib/models/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import models.seg_hrnet
12 | import models.seg_hrnet_ocr
13 | import models.ddrnet_23_slim
14 | import models.ddrnet_23
15 | import models.ddrnet_39
--------------------------------------------------------------------------------
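
These imports make every architecture reachable as an attribute of the `models` package, so a training script can pick the network from the `MODEL.NAME` string in its YAML config. The exact lookup used by the scripts in `tools/` is not reproduced here; the helper below is a hypothetical sketch of the pattern, assuming each model module exposes a `get_seg_model(config)` factory:

```python
import importlib

def build_model(cfg):
    # cfg.MODEL.NAME would be e.g. "ddrnet_23_slim"; get_seg_model is the
    # factory each model module is assumed to expose.
    module = importlib.import_module('models.' + cfg.MODEL.NAME)
    return getattr(module, 'get_seg_model')(cfg)
```
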
/lib/models/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/bn_helper.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/bn_helper.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/bn_helper.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/bn_helper.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/ddrnet_23.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/ddrnet_23.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/ddrnet_23.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/ddrnet_23.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/ddrnet_23_slim.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/ddrnet_23_slim.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/ddrnet_23_slim.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/ddrnet_23_slim.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/ddrnet_39.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/ddrnet_39.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/ddrnet_39.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/ddrnet_39.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/seg_hrnet.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/seg_hrnet.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/seg_hrnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/seg_hrnet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/seg_hrnet_ocr.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/seg_hrnet_ocr.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/seg_hrnet_ocr.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/models/__pycache__/seg_hrnet_ocr.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/bn_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import functools
3 |
4 | if torch.__version__.startswith('0'):
5 | from .sync_bn.inplace_abn.bn import InPlaceABNSync
6 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
7 | BatchNorm2d_class = InPlaceABNSync
8 | relu_inplace = False
9 | else:
10 | BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm
11 | relu_inplace = True
--------------------------------------------------------------------------------
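
`bn_helper` chooses a BatchNorm implementation once, at import time (the Mapillary in-place ABN on PyTorch 0.x, `torch.nn.SyncBatchNorm` otherwise), and the model files import the aliases instead of hard-coding a norm layer. A minimal sketch of how such aliases are typically consumed in a conv-bn-relu block; the block itself is illustrative and not copied from the repo's model code:

```python
import torch.nn as nn

from models.bn_helper import BatchNorm2d, relu_inplace

def conv_bn_relu(in_ch, out_ch, stride=1):
    # BatchNorm2d resolves to SyncBatchNorm (or InPlaceABNSync with
    # activation='none' on old PyTorch), so the norm layer is swappable.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
        BatchNorm2d(out_ch),
        nn.ReLU(inplace=relu_inplace),
    )
```
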
/lib/models/sync_bn/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | BSD 3-Clause License
3 |
4 | Copyright (c) 2017, mapillary
5 | All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | * Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived from
19 | this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/lib/models/sync_bn/__init__.py:
--------------------------------------------------------------------------------
1 | from .inplace_abn import bn
--------------------------------------------------------------------------------
/lib/models/sync_bn/inplace_abn/__init__.py:
--------------------------------------------------------------------------------
1 | from .bn import ABN, InPlaceABN, InPlaceABNSync
2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
3 |
--------------------------------------------------------------------------------
/lib/models/sync_bn/inplace_abn/bn.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as functional
5 |
6 | try:
7 | from queue import Queue
8 | except ImportError:
9 | from Queue import Queue
10 |
11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
12 | sys.path.append(BASE_DIR)
13 | sys.path.append(os.path.join(BASE_DIR, '../src'))
14 | from functions import *
15 |
16 |
17 | class ABN(nn.Module):
18 | """Activated Batch Normalization
19 |
20 | This gathers a `BatchNorm2d` and an activation function in a single module
21 | """
22 |
23 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
24 | """Creates an Activated Batch Normalization module
25 |
26 | Parameters
27 | ----------
28 | num_features : int
29 | Number of feature channels in the input and output.
30 | eps : float
31 | Small constant to prevent numerical issues.
32 | momentum : float
 33 |             Momentum factor applied to compute running statistics as `(1 - momentum) * running_stat + momentum * sample_stat`.
34 | affine : bool
35 | If `True` apply learned scale and shift transformation after normalization.
36 | activation : str
37 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
38 | slope : float
39 | Negative slope for the `leaky_relu` activation.
40 | """
41 | super(ABN, self).__init__()
42 | self.num_features = num_features
43 | self.affine = affine
44 | self.eps = eps
45 | self.momentum = momentum
46 | self.activation = activation
47 | self.slope = slope
48 | if self.affine:
49 | self.weight = nn.Parameter(torch.ones(num_features))
50 | self.bias = nn.Parameter(torch.zeros(num_features))
51 | else:
52 | self.register_parameter('weight', None)
53 | self.register_parameter('bias', None)
54 | self.register_buffer('running_mean', torch.zeros(num_features))
55 | self.register_buffer('running_var', torch.ones(num_features))
56 | self.reset_parameters()
57 |
58 | def reset_parameters(self):
59 | nn.init.constant_(self.running_mean, 0)
60 | nn.init.constant_(self.running_var, 1)
61 | if self.affine:
62 | nn.init.constant_(self.weight, 1)
63 | nn.init.constant_(self.bias, 0)
64 |
65 | def forward(self, x):
66 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
67 | self.training, self.momentum, self.eps)
68 |
69 | if self.activation == ACT_RELU:
70 | return functional.relu(x, inplace=True)
71 | elif self.activation == ACT_LEAKY_RELU:
72 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
73 | elif self.activation == ACT_ELU:
74 | return functional.elu(x, inplace=True)
75 | else:
76 | return x
77 |
78 | def __repr__(self):
79 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
80 | ' affine={affine}, activation={activation}'
81 | if self.activation == "leaky_relu":
82 | rep += ', slope={slope})'
83 | else:
84 | rep += ')'
85 | return rep.format(name=self.__class__.__name__, **self.__dict__)
86 |
87 |
88 | class InPlaceABN(ABN):
89 | """InPlace Activated Batch Normalization"""
90 |
91 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
92 | """Creates an InPlace Activated Batch Normalization module
93 |
94 | Parameters
95 | ----------
96 | num_features : int
97 | Number of feature channels in the input and output.
98 | eps : float
99 | Small constant to prevent numerical issues.
100 | momentum : float
101 |             Momentum factor applied to compute running statistics as `(1 - momentum) * running_stat + momentum * sample_stat`.
102 | affine : bool
103 | If `True` apply learned scale and shift transformation after normalization.
104 | activation : str
105 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
106 | slope : float
107 | Negative slope for the `leaky_relu` activation.
108 | """
109 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
110 |
111 | def forward(self, x):
112 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
113 | self.training, self.momentum, self.eps, self.activation, self.slope)
114 |
115 |
116 | class InPlaceABNSync(ABN):
117 | """InPlace Activated Batch Normalization with cross-GPU synchronization
118 |
119 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`.
120 | """
121 |
122 | def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu",
123 | slope=0.01):
124 | """Creates a synchronized, InPlace Activated Batch Normalization module
125 |
126 | Parameters
127 | ----------
128 | num_features : int
129 | Number of feature channels in the input and output.
130 | devices : list of int or None
131 | IDs of the GPUs that will run the replicas of this module.
132 | eps : float
133 | Small constant to prevent numerical issues.
134 | momentum : float
135 |             Momentum factor applied to compute running statistics as `(1 - momentum) * running_stat + momentum * sample_stat`.
136 | affine : bool
137 | If `True` apply learned scale and shift transformation after normalization.
138 | activation : str
139 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
140 | slope : float
141 | Negative slope for the `leaky_relu` activation.
142 | """
143 | super(InPlaceABNSync, self).__init__(num_features, eps, momentum, affine, activation, slope)
144 | self.devices = devices if devices else list(range(torch.cuda.device_count()))
145 |
146 | # Initialize queues
147 | self.worker_ids = self.devices[1:]
148 | self.master_queue = Queue(len(self.worker_ids))
149 | self.worker_queues = [Queue(1) for _ in self.worker_ids]
150 |
151 | def forward(self, x):
152 | if x.get_device() == self.devices[0]:
153 | # Master mode
154 | extra = {
155 | "is_master": True,
156 | "master_queue": self.master_queue,
157 | "worker_queues": self.worker_queues,
158 | "worker_ids": self.worker_ids
159 | }
160 | else:
161 | # Worker mode
162 | extra = {
163 | "is_master": False,
164 | "master_queue": self.master_queue,
165 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
166 | }
167 |
168 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
169 | extra, self.training, self.momentum, self.eps, self.activation, self.slope)
170 |
171 | def __repr__(self):
172 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
173 | ' affine={affine}, devices={devices}, activation={activation}'
174 | if self.activation == "leaky_relu":
175 | rep += ', slope={slope})'
176 | else:
177 | rep += ')'
178 | return rep.format(name=self.__class__.__name__, **self.__dict__)
179 |
--------------------------------------------------------------------------------
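
`ABN` fuses batch normalisation and an activation into one module, and `InPlaceABN`/`InPlaceABNSync` perform the same computation in place (with cross-GPU statistics exchange for the sync variant) to save activation memory. A hedged usage sketch of the plain `ABN` class follows; note that importing this module also triggers the JIT build of the CUDA extension in `src/` (see `functions.py` below), so a working CUDA toolchain is assumed even for this CPU-friendly example:

```python
import torch
import torch.nn as nn

from models.sync_bn.inplace_abn.bn import ABN

# Drop-in replacement for BatchNorm2d + LeakyReLU inside a small block.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    ABN(16, activation="leaky_relu", slope=0.01),
)

x = torch.randn(2, 3, 32, 32)
print(block(x).shape)   # torch.Size([2, 16, 32, 32])
```
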
/lib/models/sync_bn/inplace_abn/functions.py:
--------------------------------------------------------------------------------
1 | from os import path
2 |
3 | import torch.autograd as autograd
4 | import torch.cuda.comm as comm
5 | from torch.autograd.function import once_differentiable
6 | from torch.utils.cpp_extension import load
7 |
8 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src")
9 | _backend = load(name="inplace_abn",
10 | extra_cflags=["-O3"],
11 | sources=[path.join(_src_path, f) for f in [
12 | "inplace_abn.cpp",
13 | "inplace_abn_cpu.cpp",
14 | "inplace_abn_cuda.cu"
15 | ]],
16 | extra_cuda_cflags=["--expt-extended-lambda"])
17 |
18 | # Activation names
19 | ACT_RELU = "relu"
20 | ACT_LEAKY_RELU = "leaky_relu"
21 | ACT_ELU = "elu"
22 | ACT_NONE = "none"
23 |
24 |
25 | def _check(fn, *args, **kwargs):
26 | success = fn(*args, **kwargs)
27 | if not success:
28 | raise RuntimeError("CUDA Error encountered in {}".format(fn))
29 |
30 |
31 | def _broadcast_shape(x):
32 | out_size = []
33 | for i, s in enumerate(x.size()):
34 | if i != 1:
35 | out_size.append(1)
36 | else:
37 | out_size.append(s)
38 | return out_size
39 |
40 |
41 | def _reduce(x):
42 | if len(x.size()) == 2:
43 | return x.sum(dim=0)
44 | else:
45 | n, c = x.size()[0:2]
46 | return x.contiguous().view((n, c, -1)).sum(2).sum(0)
47 |
48 |
49 | def _count_samples(x):
50 | count = 1
51 | for i, s in enumerate(x.size()):
52 | if i != 1:
53 | count *= s
54 | return count
55 |
56 |
57 | def _act_forward(ctx, x):
58 | if ctx.activation == ACT_LEAKY_RELU:
59 | _backend.leaky_relu_forward(x, ctx.slope)
60 | elif ctx.activation == ACT_ELU:
61 | _backend.elu_forward(x)
62 | elif ctx.activation == ACT_NONE:
63 | pass
64 |
65 |
66 | def _act_backward(ctx, x, dx):
67 | if ctx.activation == ACT_LEAKY_RELU:
68 | _backend.leaky_relu_backward(x, dx, ctx.slope)
69 | elif ctx.activation == ACT_ELU:
70 | _backend.elu_backward(x, dx)
71 | elif ctx.activation == ACT_NONE:
72 | pass
73 |
74 |
75 | class InPlaceABN(autograd.Function):
76 | @staticmethod
77 | def forward(ctx, x, weight, bias, running_mean, running_var,
78 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
79 | # Save context
80 | ctx.training = training
81 | ctx.momentum = momentum
82 | ctx.eps = eps
83 | ctx.activation = activation
84 | ctx.slope = slope
85 | ctx.affine = weight is not None and bias is not None
86 |
87 | # Prepare inputs
88 | count = _count_samples(x)
89 | x = x.contiguous()
90 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
91 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
92 |
93 | if ctx.training:
94 | mean, var = _backend.mean_var(x)
95 |
96 | # Update running stats
97 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
98 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
99 |
100 | # Mark in-place modified tensors
101 | ctx.mark_dirty(x, running_mean, running_var)
102 | else:
103 | mean, var = running_mean.contiguous(), running_var.contiguous()
104 | ctx.mark_dirty(x)
105 |
106 | # BN forward + activation
107 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
108 | _act_forward(ctx, x)
109 |
110 | # Output
111 | ctx.var = var
112 | ctx.save_for_backward(x, var, weight, bias)
113 | return x
114 |
115 | @staticmethod
116 | @once_differentiable
117 | def backward(ctx, dz):
118 | z, var, weight, bias = ctx.saved_tensors
119 | dz = dz.contiguous()
120 |
121 | # Undo activation
122 | _act_backward(ctx, z, dz)
123 |
124 | if ctx.training:
125 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
126 | else:
127 | # TODO: implement simplified CUDA backward for inference mode
128 | edz = dz.new_zeros(dz.size(1))
129 | eydz = dz.new_zeros(dz.size(1))
130 |
131 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
132 | dweight = dweight if ctx.affine else None
133 | dbias = dbias if ctx.affine else None
134 |
135 | return dx, dweight, dbias, None, None, None, None, None, None, None
136 |
137 |
138 | class InPlaceABNSync(autograd.Function):
139 | @classmethod
140 | def forward(cls, ctx, x, weight, bias, running_mean, running_var,
141 | extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
142 | # Save context
143 | cls._parse_extra(ctx, extra)
144 | ctx.training = training
145 | ctx.momentum = momentum
146 | ctx.eps = eps
147 | ctx.activation = activation
148 | ctx.slope = slope
149 | ctx.affine = weight is not None and bias is not None
150 |
151 | # Prepare inputs
152 | count = _count_samples(x) * (ctx.master_queue.maxsize + 1)
153 | x = x.contiguous()
154 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
155 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
156 |
157 | if ctx.training:
158 | mean, var = _backend.mean_var(x)
159 |
160 | if ctx.is_master:
161 | means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)]
162 | for _ in range(ctx.master_queue.maxsize):
163 | mean_w, var_w = ctx.master_queue.get()
164 | ctx.master_queue.task_done()
165 | means.append(mean_w.unsqueeze(0))
166 | vars.append(var_w.unsqueeze(0))
167 |
168 | means = comm.gather(means)
169 | vars = comm.gather(vars)
170 |
171 | mean = means.mean(0)
172 | var = (vars + (mean - means) ** 2).mean(0)
173 |
174 | tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
175 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
176 | queue.put(ts)
177 | else:
178 | ctx.master_queue.put((mean, var))
179 | mean, var = ctx.worker_queue.get()
180 | ctx.worker_queue.task_done()
181 |
182 | # Update running stats
183 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
184 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
185 |
186 | # Mark in-place modified tensors
187 | ctx.mark_dirty(x, running_mean, running_var)
188 | else:
189 | mean, var = running_mean.contiguous(), running_var.contiguous()
190 | ctx.mark_dirty(x)
191 |
192 | # BN forward + activation
193 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
194 | _act_forward(ctx, x)
195 |
196 | # Output
197 | ctx.var = var
198 | ctx.save_for_backward(x, var, weight, bias)
199 | return x
200 |
201 | @staticmethod
202 | @once_differentiable
203 | def backward(ctx, dz):
204 | z, var, weight, bias = ctx.saved_tensors
205 | dz = dz.contiguous()
206 |
207 | # Undo activation
208 | _act_backward(ctx, z, dz)
209 |
210 | if ctx.training:
211 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
212 |
213 | if ctx.is_master:
214 | edzs, eydzs = [edz], [eydz]
215 | for _ in range(len(ctx.worker_queues)):
216 | edz_w, eydz_w = ctx.master_queue.get()
217 | ctx.master_queue.task_done()
218 | edzs.append(edz_w)
219 | eydzs.append(eydz_w)
220 |
221 | edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
222 | eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)
223 |
224 | tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
225 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
226 | queue.put(ts)
227 | else:
228 | ctx.master_queue.put((edz, eydz))
229 | edz, eydz = ctx.worker_queue.get()
230 | ctx.worker_queue.task_done()
231 | else:
232 | edz = dz.new_zeros(dz.size(1))
233 | eydz = dz.new_zeros(dz.size(1))
234 |
235 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
236 | dweight = dweight if ctx.affine else None
237 | dbias = dbias if ctx.affine else None
238 |
239 | return dx, dweight, dbias, None, None, None, None, None, None, None, None
240 |
241 | @staticmethod
242 | def _parse_extra(ctx, extra):
243 | ctx.is_master = extra["is_master"]
244 | if ctx.is_master:
245 | ctx.master_queue = extra["master_queue"]
246 | ctx.worker_queues = extra["worker_queues"]
247 | ctx.worker_ids = extra["worker_ids"]
248 | else:
249 | ctx.master_queue = extra["master_queue"]
250 | ctx.worker_queue = extra["worker_queue"]
251 |
252 |
253 | inplace_abn = InPlaceABN.apply
254 | inplace_abn_sync = InPlaceABNSync.apply
255 |
256 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
257 |
--------------------------------------------------------------------------------
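A minimal usage sketch of the functional entry point defined above (not part of the repository). The import path, device and tensor shapes are illustrative; a working CUDA toolchain plus ninja are assumed, since the extension is JIT-compiled when functions.py is imported:

    import torch
    from models.sync_bn.inplace_abn.functions import inplace_abn, ACT_LEAKY_RELU  # path assumes lib/ is on sys.path

    x = torch.randn(4, 64, 32, 32, device="cuda")              # activations from a previous layer
    weight = torch.ones(64, device="cuda", requires_grad=True)
    bias = torch.zeros(64, device="cuda", requires_grad=True)
    running_mean = torch.zeros(64, device="cuda")
    running_var = torch.ones(64, device="cuda")

    # x, running_mean and running_var are modified in place (see ctx.mark_dirty above)
    y = inplace_abn(x, weight, bias, running_mean, running_var,
                    True, 0.1, 1e-5, ACT_LEAKY_RELU, 0.01)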
/lib/models/sync_bn/inplace_abn/src/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | /*
6 | * General settings
7 | */
8 | const int WARP_SIZE = 32;
9 | const int MAX_BLOCK_SIZE = 512;
10 |
11 | template <typename T>
12 | struct Pair {
13 | T v1, v2;
14 | __device__ Pair() {}
15 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
16 | __device__ Pair(T v) : v1(v), v2(v) {}
17 | __device__ Pair(int v) : v1(v), v2(v) {}
18 | __device__ Pair &operator+=(const Pair &a) {
19 | v1 += a.v1;
20 | v2 += a.v2;
21 | return *this;
22 | }
23 | };
24 |
25 | /*
26 | * Utility functions
27 | */
28 | template <typename T>
29 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
30 | unsigned int mask = 0xffffffff) {
31 | #if CUDART_VERSION >= 9000
32 | return __shfl_xor_sync(mask, value, laneMask, width);
33 | #else
34 | return __shfl_xor(value, laneMask, width);
35 | #endif
36 | }
37 |
38 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
39 |
40 | static int getNumThreads(int nElem) {
41 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
42 | for (int i = 0; i != 5; ++i) {
43 | if (nElem <= threadSizes[i]) {
44 | return threadSizes[i];
45 | }
46 | }
47 | return MAX_BLOCK_SIZE;
48 | }
49 |
50 | template <typename T>
51 | static __device__ __forceinline__ T warpSum(T val) {
52 | #if __CUDA_ARCH__ >= 300
53 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
54 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
55 | }
56 | #else
57 | __shared__ T values[MAX_BLOCK_SIZE];
58 | values[threadIdx.x] = val;
59 | __threadfence_block();
60 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
61 | for (int i = 1; i < WARP_SIZE; i++) {
62 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
63 | }
64 | #endif
65 | return val;
66 | }
67 |
68 | template <typename T>
69 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
70 | value.v1 = warpSum(value.v1);
71 | value.v2 = warpSum(value.v2);
72 | return value;
73 | }
74 |
75 | template <typename T, typename Op>
76 | __device__ T reduce(Op op, int plane, int N, int C, int S) {
77 | T sum = (T)0;
78 | for (int batch = 0; batch < N; ++batch) {
79 | for (int x = threadIdx.x; x < S; x += blockDim.x) {
80 | sum += op(batch, plane, x);
81 | }
82 | }
83 |
84 | // sum over NumThreads within a warp
85 | sum = warpSum(sum);
86 |
87 | // 'transpose', and reduce within warp again
88 | __shared__ T shared[32];
89 | __syncthreads();
90 | if (threadIdx.x % WARP_SIZE == 0) {
91 | shared[threadIdx.x / WARP_SIZE] = sum;
92 | }
93 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
94 | // zero out the other entries in shared
95 | shared[threadIdx.x] = (T)0;
96 | }
97 | __syncthreads();
98 | if (threadIdx.x / WARP_SIZE == 0) {
99 | sum = warpSum(shared[threadIdx.x]);
100 | if (threadIdx.x == 0) {
101 | shared[0] = sum;
102 | }
103 | }
104 | __syncthreads();
105 |
106 | // Everyone picks it up, should be broadcast into the whole gradInput
107 | return shared[0];
108 | }
--------------------------------------------------------------------------------
/lib/models/sync_bn/inplace_abn/src/inplace_abn.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | std::vector<at::Tensor> mean_var(at::Tensor x) {
8 | if (x.is_cuda()) {
9 | return mean_var_cuda(x);
10 | } else {
11 | return mean_var_cpu(x);
12 | }
13 | }
14 |
15 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16 | bool affine, float eps) {
17 | if (x.is_cuda()) {
18 | return forward_cuda(x, mean, var, weight, bias, affine, eps);
19 | } else {
20 | return forward_cpu(x, mean, var, weight, bias, affine, eps);
21 | }
22 | }
23 |
24 | std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
25 | bool affine, float eps) {
26 | if (z.is_cuda()) {
27 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
28 | } else {
29 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
30 | }
31 | }
32 |
33 | std::vector<at::Tensor> backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
34 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
35 | if (z.is_cuda()) {
36 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
37 | } else {
38 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
39 | }
40 | }
41 |
42 | void leaky_relu_forward(at::Tensor z, float slope) {
43 | at::leaky_relu_(z, slope);
44 | }
45 |
46 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
47 | if (z.is_cuda()) {
48 | return leaky_relu_backward_cuda(z, dz, slope);
49 | } else {
50 | return leaky_relu_backward_cpu(z, dz, slope);
51 | }
52 | }
53 |
54 | void elu_forward(at::Tensor z) {
55 | at::elu_(z);
56 | }
57 |
58 | void elu_backward(at::Tensor z, at::Tensor dz) {
59 | if (z.is_cuda()) {
60 | return elu_backward_cuda(z, dz);
61 | } else {
62 | return elu_backward_cpu(z, dz);
63 | }
64 | }
65 |
66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
67 | m.def("mean_var", &mean_var, "Mean and variance computation");
68 | m.def("forward", &forward, "In-place forward computation");
69 | m.def("edz_eydz", &edz_eydz, "First part of backward computation");
70 | m.def("backward", &backward, "Second part of backward computation");
71 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
72 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
73 | m.def("elu_forward", &elu_forward, "Elu forward computation");
74 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
75 | }
--------------------------------------------------------------------------------
/lib/models/sync_bn/inplace_abn/src/inplace_abn.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | #include <vector>
6 |
7 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9 |
10 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
11 | bool affine, float eps);
12 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
13 | bool affine, float eps);
14 |
15 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
16 | bool affine, float eps);
17 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
18 | bool affine, float eps);
19 |
20 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
21 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
22 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
23 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
24 |
25 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
26 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
27 |
28 | void elu_backward_cpu(at::Tensor z, at::Tensor dz);
29 | void elu_backward_cuda(at::Tensor z, at::Tensor dz);
--------------------------------------------------------------------------------
/lib/models/sync_bn/inplace_abn/src/inplace_abn_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | at::Tensor reduce_sum(at::Tensor x) {
8 | if (x.ndimension() == 2) {
9 | return x.sum(0);
10 | } else {
11 | auto x_view = x.view({x.size(0), x.size(1), -1});
12 | return x_view.sum(-1).sum(0);
13 | }
14 | }
15 |
16 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
17 | if (x.ndimension() == 2) {
18 | return v;
19 | } else {
20 | std::vector<int64_t> broadcast_size = {1, -1};
21 | for (int64_t i = 2; i < x.ndimension(); ++i)
22 | broadcast_size.push_back(1);
23 |
24 | return v.view(broadcast_size);
25 | }
26 | }
27 |
28 | int64_t count(at::Tensor x) {
29 | int64_t count = x.size(0);
30 | for (int64_t i = 2; i < x.ndimension(); ++i)
31 | count *= x.size(i);
32 |
33 | return count;
34 | }
35 |
36 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
37 | if (affine) {
38 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
39 | } else {
40 | return z;
41 | }
42 | }
43 |
44 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
45 | auto num = count(x);
46 | auto mean = reduce_sum(x) / num;
47 | auto diff = x - broadcast_to(mean, x);
48 | auto var = reduce_sum(diff.pow(2)) / num;
49 |
50 | return {mean, var};
51 | }
52 |
53 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
54 | bool affine, float eps) {
55 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
56 | auto mul = at::rsqrt(var + eps) * gamma;
57 |
58 | x.sub_(broadcast_to(mean, x));
59 | x.mul_(broadcast_to(mul, x));
60 | if (affine) x.add_(broadcast_to(bias, x));
61 |
62 | return x;
63 | }
64 |
65 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
66 | bool affine, float eps) {
67 | auto edz = reduce_sum(dz);
68 | auto y = invert_affine(z, weight, bias, affine, eps);
69 | auto eydz = reduce_sum(y * dz);
70 |
71 | return {edz, eydz};
72 | }
73 |
74 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
75 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
76 | auto y = invert_affine(z, weight, bias, affine, eps);
77 | auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
78 |
79 | auto num = count(z);
80 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
81 |
82 | auto dweight = at::empty(z.type(), {0});
83 | auto dbias = at::empty(z.type(), {0});
84 | if (affine) {
85 | dweight = eydz * at::sign(weight);
86 | dbias = edz;
87 | }
88 |
89 | return {dx, dweight, dbias};
90 | }
91 |
92 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
93 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
94 | int64_t count = z.numel();
95 | auto *_z = z.data<scalar_t>();
96 | auto *_dz = dz.data<scalar_t>();
97 |
98 | for (int64_t i = 0; i < count; ++i) {
99 | if (_z[i] < 0) {
100 | _z[i] *= 1 / slope;
101 | _dz[i] *= slope;
102 | }
103 | }
104 | }));
105 | }
106 |
107 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
108 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
109 | int64_t count = z.numel();
110 | auto *_z = z.data<scalar_t>();
111 | auto *_dz = dz.data<scalar_t>();
112 |
113 | for (int64_t i = 0; i < count; ++i) {
114 | if (_z[i] < 0) {
115 | _z[i] = log1p(_z[i]);
116 | _dz[i] *= (_z[i] + 1.f);
117 | }
118 | }
119 | }));
120 | }
--------------------------------------------------------------------------------
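The CPU kernels above make the activated-batch-norm math easy to check against plain PyTorch. A reference sketch (assumed, not in the repository) of the same computation as mean_var_cpu and forward_cpu:

    import torch

    def abn_reference(x, mean, var, weight, bias, eps=1e-5):
        # y = (x - mean) * rsqrt(var + eps) * (|weight| + eps) + bias, as in forward_cpu
        gamma = weight.abs() + eps
        mul = torch.rsqrt(var + eps) * gamma
        shape = (1, -1) + (1,) * (x.dim() - 2)    # broadcast over batch and spatial dims
        return (x - mean.view(shape)) * mul.view(shape) + bias.view(shape)

    x = torch.randn(2, 8, 4, 4)
    flat = x.transpose(0, 1).reshape(8, -1)       # per-channel statistics, as in mean_var_cpu
    mean, var = flat.mean(dim=1), flat.var(dim=1, unbiased=False)
    weight, bias = torch.rand(8), torch.zeros(8)
    y = abn_reference(x, mean, var, weight, bias)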
/lib/utils/DenseCRF.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Author: Kazuto Nakashima
4 | # URL: https://kazuto1011.github.io
5 | # Date: 09 January 2019
6 |
7 | import numpy as np
8 | import pydensecrf.densecrf as dcrf
9 | import pydensecrf.utils as utils
10 |
11 | class DenseCRF(object):
12 | def __init__(self, max_epochs=5, delta_aphla=80, delta_beta=3, w1=10, delta_gamma=3, w2=3):
13 | self.max_epochs = max_epochs
14 | self.delta_gamma = delta_gamma
15 | self.delta_alpha = delta_aphla
16 | self.delta_beta = delta_beta
17 | self.w1 = w1
18 | self.w2 = w2
19 |
20 | def __call__(self, image, probmap):
21 | c, h, w = probmap.shape
22 |
23 | U = utils.unary_from_softmax(probmap)
24 | U = np.ascontiguousarray(U)
25 |
26 | image = np.ascontiguousarray(image)
27 |
28 | d = dcrf.DenseCRF2D(w, h, c)
29 | d.setUnaryEnergy(U)
30 |
31 | d.addPairwiseGaussian(sxy=self.delta_gamma, compat=self.w2)
32 | d.addPairwiseBilateral(sxy=self.delta_alpha, srgb=self.delta_beta, rgbim=image, compat=self.w1)
33 |
34 | Q = d.inference(self.max_epochs)
35 | Q = np.array(Q).reshape((c, h, w))
36 |
37 | return Q
38 |
39 | # import numpy as np
40 | # import pydensecrf.densecrf as dcrf
41 | # import pydensecrf.utils as utils
42 |
43 |
44 | # class DenseCRF(object):
45 | # def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std):
46 | # self.iter_max = iter_max # iter num
47 | # self.pos_w = pos_w # the weight of the Gaussian kernel which only depends on Pixel Position
48 | # self.pos_xy_std = pos_xy_std
49 | # self.bi_w = bi_w # the weight of bilateral kernel
50 | # self.bi_xy_std = bi_xy_std
51 | # self.bi_rgb_std = bi_rgb_std
52 |
53 | # def __call__(self, image, probmap):
54 | # C, H, W = probmap.shape
55 |
56 | # U = utils.unary_from_softmax(probmap)
57 | # U = np.ascontiguousarray(U)
58 |
59 | # image = np.ascontiguousarray(image)
60 |
61 | # d = dcrf.DenseCRF2D(W, H, C)
62 | # d.setUnaryEnergy(U)
63 |
64 | # # the gaussian kernel depends only on pixel position
65 | # d.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w)
66 |
67 | # # bilateral kernel depends on both position and color
68 | # d.addPairwiseBilateral(
69 | # sxy=self.bi_xy_std, srgb=self.bi_rgb_std, rgbim=image, compat=self.bi_w
70 | # )
71 |
72 | # Q = d.inference(self.iter_max)
73 | # Q = np.array(Q).reshape((C, H, W))
74 |
75 | # return Q
--------------------------------------------------------------------------------
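A hypothetical usage example for the DenseCRF wrapper above: `image` must be an HxWx3 uint8 RGB array and `probmap` a CxHxW float32 softmax output; the arrays below are placeholders:

    import numpy as np
    from utils.DenseCRF import DenseCRF   # assumes lib/ is on sys.path, as tools/_init_paths.py arranges

    crf = DenseCRF(max_epochs=5)
    image = np.zeros((480, 640, 3), dtype=np.uint8)                # placeholder RGB frame
    probmap = np.full((19, 480, 640), 1.0 / 19, dtype=np.float32)  # placeholder softmax output
    refined = crf(image, probmap)                                  # refined CxHxW probabilities
    label = refined.argmax(axis=0)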
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/utils/__init__.py
--------------------------------------------------------------------------------
/lib/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/distributed.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/utils/__pycache__/distributed.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/distributed.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/utils/__pycache__/distributed.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/modelsummary.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/utils/__pycache__/modelsummary.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/modelsummary.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/utils/__pycache__/modelsummary.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/utils/__pycache__/utils.cpython-35.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/lib/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/distributed.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Jingyi Xie (hsfzxjy@gmail.com)
5 | # ------------------------------------------------------------------------------
6 |
7 | import torch
8 | import torch.distributed as torch_dist
9 |
10 | def is_distributed():
11 | return torch_dist.is_initialized()
12 |
13 | def get_world_size():
14 | if not torch_dist.is_initialized():
15 | return 1
16 | return torch_dist.get_world_size()
17 |
18 | def get_rank():
19 | if not torch_dist.is_initialized():
20 | return 0
21 | return torch_dist.get_rank()
--------------------------------------------------------------------------------
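A small sketch of how these helpers are typically used to restrict logging and checkpoint writing to the master process (hypothetical snippet, not from the repository):

    from utils.distributed import get_rank, get_world_size, is_distributed

    if not is_distributed() or get_rank() == 0:
        print("world size:", get_world_size())
        # write logs / save checkpoints only once, on rank 0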
/lib/utils/modelsummary.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn)
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import os
13 | import logging
14 | from collections import namedtuple
15 |
16 | import torch
17 | import torch.nn as nn
18 |
19 | def get_model_summary(model, *input_tensors, item_length=26, verbose=False):
20 | """
21 | :param model:
22 | :param input_tensors:
23 | :param item_length:
24 | :return:
25 | """
26 |
27 | summary = []
28 |
29 | ModuleDetails = namedtuple(
30 | "Layer", ["name", "input_size", "output_size", "num_parameters", "multiply_adds"])
31 | hooks = []
32 | layer_instances = {}
33 |
34 | def add_hooks(module):
35 |
36 | def hook(module, input, output):
37 | class_name = str(module.__class__.__name__)
38 |
39 | instance_index = 1
40 | if class_name not in layer_instances:
41 | layer_instances[class_name] = instance_index
42 | else:
43 | instance_index = layer_instances[class_name] + 1
44 | layer_instances[class_name] = instance_index
45 |
46 | layer_name = class_name + "_" + str(instance_index)
47 |
48 | params = 0
49 |
50 | if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \
51 | class_name.find("Linear") != -1:
52 | for param_ in module.parameters():
53 | params += param_.view(-1).size(0)
54 |
55 | flops = "Not Available"
56 | if class_name.find("Conv") != -1 and hasattr(module, "weight"):
57 | flops = (
58 | torch.prod(
59 | torch.LongTensor(list(module.weight.data.size()))) *
60 | torch.prod(
61 | torch.LongTensor(list(output.size())[2:]))).item()
62 | elif isinstance(module, nn.Linear):
63 | flops = (torch.prod(torch.LongTensor(list(output.size()))) \
64 | * input[0].size(1)).item()
65 |
66 | if isinstance(input[0], list):
67 | input = input[0]
68 | if isinstance(output, list):
69 | output = output[0]
70 |
71 | summary.append(
72 | ModuleDetails(
73 | name=layer_name,
74 | input_size=list(input[0].size()),
75 | output_size=list(output.size()),
76 | num_parameters=params,
77 | multiply_adds=flops)
78 | )
79 |
80 | if not isinstance(module, nn.ModuleList) \
81 | and not isinstance(module, nn.Sequential) \
82 | and module != model:
83 | hooks.append(module.register_forward_hook(hook))
84 |
85 | model.eval()
86 | model.apply(add_hooks)
87 |
88 | space_len = item_length
89 |
90 | model(*input_tensors)
91 | for hook in hooks:
92 | hook.remove()
93 |
94 | details = ''
95 | if verbose:
96 | details = "Model Summary" + \
97 | os.linesep + \
98 | "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format(
99 | ' ' * (space_len - len("Name")),
100 | ' ' * (space_len - len("Input Size")),
101 | ' ' * (space_len - len("Output Size")),
102 | ' ' * (space_len - len("Parameters")),
103 | ' ' * (space_len - len("Multiply Adds (Flops)"))) \
104 | + os.linesep + '-' * space_len * 5 + os.linesep
105 |
106 | params_sum = 0
107 | flops_sum = 0
108 | for layer in summary:
109 | params_sum += layer.num_parameters
110 | if layer.multiply_adds != "Not Available":
111 | flops_sum += layer.multiply_adds
112 | if verbose:
113 | details += "{}{}{}{}{}{}{}{}{}{}".format(
114 | layer.name,
115 | ' ' * (space_len - len(layer.name)),
116 | layer.input_size,
117 | ' ' * (space_len - len(str(layer.input_size))),
118 | layer.output_size,
119 | ' ' * (space_len - len(str(layer.output_size))),
120 | layer.num_parameters,
121 | ' ' * (space_len - len(str(layer.num_parameters))),
122 | layer.multiply_adds,
123 | ' ' * (space_len - len(str(layer.multiply_adds)))) \
124 | + os.linesep + '-' * space_len * 5 + os.linesep
125 |
126 | details += os.linesep \
127 | + "Total Parameters: {:,}".format(params_sum) \
128 | + os.linesep + '-' * space_len * 5 + os.linesep
129 | details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \
130 | + os.linesep + '-' * space_len * 5 + os.linesep
131 | details += "Number of Layers" + os.linesep
132 | for layer in layer_instances:
133 | details += "{} : {} layers ".format(layer, layer_instances[layer])
134 |
135 | return details
--------------------------------------------------------------------------------
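A hedged example of calling get_model_summary on a small stand-in model; the tools scripts below pass the real network together with a dummy input sized from config.TRAIN.IMAGE_SIZE in the same way:

    import torch
    import torch.nn as nn
    from utils.modelsummary import get_model_summary

    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    dump_input = torch.rand(1, 3, 64, 64)
    print(get_model_summary(model, dump_input, verbose=True))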
/lib/utils/utils.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import os
12 | import logging
13 | import time
14 | from pathlib import Path
15 |
16 | import numpy as np
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 |
22 | class FullModel(nn.Module):
23 | """
24 | Distribute the loss on multi-gpu to reduce
25 | the memory cost in the main gpu.
26 | You can check the following discussion.
27 | https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/21
28 | """
29 | def __init__(self, model, loss):
30 | super(FullModel, self).__init__()
31 | self.model = model
32 | print(model)
33 | self.loss = loss
34 |
35 | def pixel_acc(self, pred, label):
36 | # print('pre:',pred.shape, label.shape)
37 | if pred.shape[2] != label.shape[1] and pred.shape[3] != label.shape[2]:
38 | pred = F.interpolate(pred, (label.shape[1:]), mode="nearest")
39 |
40 | _, preds = torch.max(pred, dim=1)
41 | valid = (label >= 0).long()
42 | acc_sum = torch.sum(valid * (preds == label).long())
43 | pixel_sum = torch.sum(valid)
44 | acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
45 | return acc
46 |
47 | def forward(self, inputs, labels, *args, **kwargs):
48 | outputs = self.model(inputs, *args, **kwargs)
49 | # print("output:",len(outputs), outputs[0].size(),outputs[1].size())
50 | loss = self.loss(outputs, labels)
51 | acc = self.pixel_acc(outputs[1], labels)
52 | return torch.unsqueeze(loss,0), outputs, acc
53 |
54 | class AverageMeter(object):
55 | """Computes and stores the average and current value"""
56 |
57 | def __init__(self):
58 | self.initialized = False
59 | self.val = None
60 | self.avg = None
61 | self.sum = None
62 | self.count = None
63 |
64 | def initialize(self, val, weight):
65 | self.val = val
66 | self.avg = val
67 | self.sum = val * weight
68 | self.count = weight
69 | self.initialized = True
70 |
71 | def update(self, val, weight=1):
72 | if not self.initialized:
73 | self.initialize(val, weight)
74 | else:
75 | self.add(val, weight)
76 |
77 | def add(self, val, weight):
78 | self.val = val
79 | self.sum += val * weight
80 | self.count += weight
81 | self.avg = self.sum / self.count
82 |
83 | def value(self):
84 | return self.val
85 |
86 | def average(self):
87 | return self.avg
88 |
89 | def create_logger(cfg, cfg_name, phase='train'):
90 | root_output_dir = Path(cfg.OUTPUT_DIR)
91 | # set up logger
92 | if not root_output_dir.exists():
93 | print('=> creating {}'.format(root_output_dir))
94 | root_output_dir.mkdir()
95 |
96 | dataset = cfg.DATASET.DATASET
97 | model = cfg.MODEL.NAME
98 | cfg_name = os.path.basename(cfg_name).split('.')[0]
99 |
100 | final_output_dir = root_output_dir / dataset / cfg_name
101 |
102 | print('=> creating {}'.format(final_output_dir))
103 | final_output_dir.mkdir(parents=True, exist_ok=True)
104 |
105 | time_str = time.strftime('%Y-%m-%d-%H-%M')
106 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
107 | final_log_file = final_output_dir / log_file
108 | head = '%(asctime)-15s %(message)s'
109 | logging.basicConfig(filename=str(final_log_file),
110 | format=head)
111 | logger = logging.getLogger()
112 | logger.setLevel(logging.INFO)
113 | console = logging.StreamHandler()
114 | logging.getLogger('').addHandler(console)
115 |
116 | tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \
117 | (cfg_name + '_' + time_str)
118 | print('=> creating {}'.format(tensorboard_log_dir))
119 | tensorboard_log_dir.mkdir(parents=True, exist_ok=True)
120 |
121 | return logger, str(final_output_dir), str(tensorboard_log_dir)
122 |
123 | def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
124 | """
125 | Calcute the confusion matrix by given label and pred
126 | """
127 | output = pred.cpu().numpy().transpose(0, 2, 3, 1)
128 | seg_pred = np.asarray(np.argmax(output, axis=3), dtype=np.uint8)
129 | seg_gt = np.asarray(
130 | label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=np.int)
131 |
132 | ignore_index = seg_gt != ignore
133 | seg_gt = seg_gt[ignore_index]
134 | seg_pred = seg_pred[ignore_index]
135 |
136 | index = (seg_gt * num_class + seg_pred).astype('int32')
137 | label_count = np.bincount(index)
138 | confusion_matrix = np.zeros((num_class, num_class))
139 |
140 | for i_label in range(num_class):
141 | for i_pred in range(num_class):
142 | cur_index = i_label * num_class + i_pred
143 | if cur_index < len(label_count):
144 | confusion_matrix[i_label,
145 | i_pred] = label_count[cur_index]
146 | return confusion_matrix
147 |
148 | def adjust_learning_rate(optimizer, base_lr, max_iters,
149 | cur_iters, power=0.9, nbb_mult=10):
150 | lr = base_lr*((1-float(cur_iters)/max_iters)**(power))
151 | optimizer.param_groups[0]['lr'] = lr
152 | if len(optimizer.param_groups) == 2:
153 | optimizer.param_groups[1]['lr'] = lr * nbb_mult
154 | return lr
155 |
156 | import cv2
157 | from PIL import Image
158 |
159 | def colorEncode(labelmap, colors, mode='RGB'):
160 | labelmap = labelmap.astype('int')
161 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
162 | dtype=np.uint8)
163 | for label in np.unique(labelmap):
164 | if label < 0:
165 | continue
166 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
167 | np.tile(colors[label],
168 | (labelmap.shape[0], labelmap.shape[1], 1))
169 |
170 | if mode == 'BGR':
171 | return labelmap_rgb[:, :, ::-1]
172 | else:
173 | return labelmap_rgb
174 |
175 | class Vedio(object):
176 | def __init__(self, video_path):
177 | self.video_path = video_path
178 | self.cap = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), 15, (1280, 480))
179 |
180 | def addImage(self, img, colorMask):
181 | img = img[:,:,::-1]
182 | colorMask = colorMask[:,:,::-1] # shape:
183 | img = np.concatenate([img, colorMask], axis=1)
184 | self.cap.write(img)
185 |
186 | def releaseCap(self):
187 | self.cap.release()
188 |
189 |
190 | class Map16(object):
191 | def __init__(self, vedioCap, visualpoint=True):
192 | self.names = ("background", "floor", "bed", "cabinet,wardrobe,bookcase,shelf",
193 | "person", "door", "table,desk,coffee", "chair,armchair,sofa,bench,swivel,stool",
194 | "rug", "railing", "column", "refrigerator", "stairs,stairway,step", "escalator", "wall","c","b","a",
195 | "dog", "plant")
196 | self.colors = np.array([[0, 0, 0],
197 | [0, 0, 255],
198 | [0, 255, 0],
199 | [0, 255, 255],
200 | [255, 0, 0 ],
201 | [255, 0, 255 ],
202 | [255, 255, 0 ],
203 | [255, 255, 255 ],
204 | [0, 0, 128 ],
205 | [0, 128, 0 ],
206 | [128, 0, 0 ],
207 | [0, 128, 128 ],
208 | [128, 0, 0 ],
209 | [128, 0, 128 ],
210 | [128, 128, 0 ],
211 |
212 | [128, 255, 0 ],
213 | [128, 255, 128 ],
214 | [128, 128, 255 ],
215 |
216 | [128, 128, 128 ],
217 | [192, 192, 192 ],], dtype=np.uint8)
218 | self.outDir = "output/map16"
219 | self.vedioCap = vedioCap
220 | self.visualpoint = visualpoint
221 |
222 | def visualize_result(self, data, pred, dir, img_name=None):
223 | img = data
224 |
225 | pred = np.int32(pred)
226 | pixs = pred.size
227 | uniques, counts = np.unique(pred, return_counts=True)
228 | for idx in np.argsort(counts)[::-1]:
229 | name = self.names[uniques[idx]]
230 | ratio = counts[idx] / pixs * 100
231 | if ratio > 0.1:
232 | print(" {}: {:.2f}%".format(name, ratio))
233 |
234 | # calculate point
235 | if self.visualpoint:
236 | # convert to grayscale float32 before corner detection
237 | img = img.copy()
238 | img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
239 | img_gray = np.float32(img_gray)
240 | # get the corner coordinate vectors
241 | goodfeatures_corners = cv2.goodFeaturesToTrack(img_gray, 400, 0.01, 10)
242 | goodfeatures_corners = np.int0(goodfeatures_corners)
243 | # note this way of iterating over the detected corners
244 | for i in goodfeatures_corners:
245 | # each i is a nested 1x2 array, so flatten (or ravel) it to get x, y
246 | x,y = i.flatten()
247 | # cv2.circle(img,(x,y), 3, [0,255,], -1)
248 |
249 | # colorize prediction
250 | pred_color = colorEncode(pred, self.colors).astype(np.uint8)
251 |
252 | im_vis = img * 0.7 + pred_color * 0.3
253 | im_vis = im_vis.astype(np.uint8)
254 |
255 | # for vedio result show
256 | self.vedioCap.addImage(im_vis, pred_color)
257 |
258 | img_name = img_name
259 | if not os.path.exists(dir):
260 | os.makedirs(dir)
261 | Image.fromarray(im_vis).save(
262 | os.path.join(dir, img_name))
263 |
264 |
265 | def speed_test(model, size=896, iteration=100):
266 | input_t = torch.Tensor(1, 3, size, size).cuda()
267 | feed_dict = {}
268 | feed_dict['img_data'] = input_t
269 |
270 | print("start warm up")
271 |
272 | for i in range(10):
273 | model(feed_dict, segSize=(size, size))
274 |
275 | print("warm up done")
276 | start_ts = time.time()
277 | for i in range(iteration):
278 | model(feed_dict, segSize=(size, size))
279 |
280 | torch.cuda.synchronize()
281 | end_ts = time.time()
282 |
283 | t_cnt = end_ts - start_ts
284 | print("=======================================")
285 | print("FPS: %f" % (100 / t_cnt))
286 | # print(f"Inference time {t_cnt/100*1000} ms")
287 |
--------------------------------------------------------------------------------
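A sketch (not in the repository) of turning the matrix returned by get_confusion_matrix into per-class IoU and mean IoU; the shapes and class count are illustrative:

    import numpy as np
    import torch
    from utils.utils import get_confusion_matrix

    num_classes = 19
    label = torch.randint(0, num_classes, (1, 64, 64))   # ground-truth index map (N, H, W)
    pred = torch.randn(1, num_classes, 64, 64)           # raw network logits (N, C, H, W)

    cm = get_confusion_matrix(label, pred, label.shape, num_classes, ignore=-1)
    pos, res, tp = cm.sum(1), cm.sum(0), np.diag(cm)
    iou = tp / np.maximum(1.0, pos + res - tp)
    print("mIoU: %.4f" % iou.mean())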
/requirements.txt:
--------------------------------------------------------------------------------
1 | EasyDict==1.7
2 | opencv-python==3.4.1.15
3 | shapely==1.6.4
4 | Cython
5 | scipy
6 | pandas
7 | pyyaml
8 | json_tricks
9 | scikit-image
10 | yacs>=0.1.5
11 | tensorboardX>=1.6
12 | tqdm
13 | ninja
14 |
--------------------------------------------------------------------------------
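These dependencies are installed with `pip install -r requirements.txt`; PyTorch itself is not pinned here and has to be installed separately, and ninja is listed because the sync_bn inplace_abn extension is JIT-compiled through torch.utils.cpp_extension.load at import time.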
/tools/__pycache__/_init_paths.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/tools/__pycache__/_init_paths.cpython-35.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/_init_paths.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.Pytorch/decce75534bcf9e9018f88fcc52b29c0c408ecac/tools/__pycache__/_init_paths.cpython-37.pyc
--------------------------------------------------------------------------------
/tools/_init_paths.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import os.path as osp
12 | import sys
13 |
14 |
15 | def add_path(path):
16 | if path not in sys.path:
17 | sys.path.insert(0, path)
18 |
19 | this_dir = osp.dirname(__file__)
20 |
21 | lib_path = osp.join(this_dir, '..', 'lib')
22 | add_path(lib_path)
23 |
--------------------------------------------------------------------------------
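Every script under tools/ relies on this module's side effect of putting lib/ on sys.path; the typical preamble (as in the scripts that follow) is:

    import _init_paths   # noqa: F401 - prepends ../lib to sys.path
    import models
    import datasets
    from config import config, update_config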
/tools/convert2jit.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 | import time
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 |
24 | import _init_paths
25 | import models
26 | import cv2
27 | import torch.nn.functional as F
28 | import datasets
29 | from config import config
30 | from config import update_config
31 | from core.function import testval, test
32 | from utils.modelsummary import get_model_summary
33 | from utils.utils import create_logger, FullModel, speed_test
34 |
35 | def parse_args():
36 | parser = argparse.ArgumentParser(description='Train segmentation network')
37 |
38 | parser.add_argument('--cfg',
39 | help='experiment configure file name',
40 | default="experiments/cityscapes/ddrnet23_slim.yaml",
41 | type=str)
42 | parser.add_argument('opts',
43 | help="Modify config options using the command-line",
44 | default=None,
45 | nargs=argparse.REMAINDER)
46 |
47 | args = parser.parse_args()
48 | update_config(config, args)
49 |
50 | return args
51 |
52 | def main():
53 | mean = [0.485, 0.456, 0.406]
54 | std = [0.229, 0.224, 0.225]
55 | args = parse_args()
56 |
57 | logger, final_output_dir, _ = create_logger(
58 | config, args.cfg, 'test')
59 |
60 | logger.info(pprint.pformat(args))
61 | logger.info(pprint.pformat(config))
62 |
63 | # cudnn related setting
64 | cudnn.benchmark = config.CUDNN.BENCHMARK
65 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
66 | cudnn.enabled = config.CUDNN.ENABLED
67 |
68 | # build model
69 | if torch.__version__.startswith('1'):
70 | module = eval('models.'+config.MODEL.NAME)
71 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
72 | model = eval('models.'+config.MODEL.NAME +
73 | '.get_seg_model')(config)
74 |
75 | dump_input = torch.rand(
76 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
77 | )
78 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
79 |
80 | if config.TEST.MODEL_FILE:
81 | model_state_file = config.TEST.MODEL_FILE
82 | else:
83 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
84 | model_state_file = os.path.join(final_output_dir, 'best.pth')
85 | logger.info('=> loading model from {}'.format(model_state_file))
86 |
87 | pretrained_dict = torch.load('/home/kong/Documents/DDRNet.Pytorch/DDRNet.Pytorch/output/face/ddrnet23_slim/checkpoint.pth.tar')
88 | if 'state_dict' in pretrained_dict:
89 | pretrained_dict = pretrained_dict['state_dict']
90 |
91 | newstate_dict = {k:v for k,v in pretrained_dict.items() if k in model.state_dict()}
92 | # print(pretrained_dict.keys())
93 |
94 | model.load_state_dict(newstate_dict)
95 | model = model.to("cpu")
96 | print(model)
97 | model.eval()
98 | example = torch.rand(1, 3, 512, 512)
99 |
100 | model = torch.quantization.convert(model)
101 | # traced_script_module = torch.jit.trace(model, example)
102 | # traced_script_module.save("ddrnetfp32.pt")
103 | scriptedm = torch.jit.script(model)
104 | opt_model = torch.utils.mobile_optimizer.optimize_for_mobile(scriptedm)  # needs: import torch.utils.mobile_optimizer
105 | torch.jit.save(opt_model, "ddrnetint8.pt")
106 |
107 |
108 |
109 |
110 |
111 |
112 | if __name__ == '__main__':
113 | main()
114 |
--------------------------------------------------------------------------------
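Once convert2jit.py has written ddrnetint8.pt, the scripted model can be reloaded without any of the training code. A minimal sketch (the file name and 512x512 input size come from the script above, everything else is assumed):

    import torch

    model = torch.jit.load("ddrnetint8.pt", map_location="cpu")
    model.eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 512, 512))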
/tools/convert2trt.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 | import time
18 | import numpy as np
19 | from torch2trt import torch2trt
20 |
21 | import torch
22 | import torch.nn as nn
23 | import torch.backends.cudnn as cudnn
24 |
25 | import _init_paths
26 | import models
27 | import cv2
28 | import torch.nn.functional as F
29 | import datasets
30 | from config import config
31 | from config import update_config
32 | from core.function import testval, test
33 | from utils.modelsummary import get_model_summary
34 | from utils.utils import create_logger, FullModel, speed_test
35 |
36 | def parse_args():
37 | parser = argparse.ArgumentParser(description='Train segmentation network')
38 |
39 | parser.add_argument('--cfg',
40 | help='experiment configure file name',
41 | default="experiments/cityscapes/ddrnet23_slim.yaml",
42 | type=str)
43 | parser.add_argument('opts',
44 | help="Modify config options using the command-line",
45 | default=None,
46 | nargs=argparse.REMAINDER)
47 |
48 | args = parser.parse_args()
49 | update_config(config, args)
50 |
51 | return args
52 |
53 | def main():
54 | mean = [0.485, 0.456, 0.406]
55 | std = [0.229, 0.224, 0.225]
56 | args = parse_args()
57 |
58 | logger, final_output_dir, _ = create_logger(
59 | config, args.cfg, 'test')
60 |
61 | logger.info(pprint.pformat(args))
62 | logger.info(pprint.pformat(config))
63 |
64 | # cudnn related setting
65 | cudnn.benchmark = config.CUDNN.BENCHMARK
66 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
67 | cudnn.enabled = config.CUDNN.ENABLED
68 |
69 | # build model
70 | if torch.__version__.startswith('1'):
71 | module = eval('models.'+config.MODEL.NAME)
72 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
73 | model = eval('models.'+config.MODEL.NAME +
74 | '.get_seg_model')(config)
75 |
76 | dump_input = torch.rand(
77 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
78 | )
79 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
80 |
81 | if config.TEST.MODEL_FILE:
82 | model_state_file = config.TEST.MODEL_FILE
83 | else:
84 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
85 | model_state_file = os.path.join(final_output_dir, 'best.pth')
86 | logger.info('=> loading model from {}'.format(model_state_file))
87 |
88 | pretrained_dict = torch.load('/home/hwits/Documents/CarVid/DDRNet/DDRNet.pytorch/model_best_bacc.pth.tar')
89 | if 'state_dict' in pretrained_dict:
90 | pretrained_dict = pretrained_dict['state_dict']
91 | newstate_dict = {k:v for k,v in pretrained_dict.items() if k in model.state_dict()}
92 | # print(pretrained_dict.keys())
93 |
94 | model.load_state_dict(newstate_dict)
95 | model = model.cuda()
96 | print(model)
97 |
98 | x = torch.ones((1, 3, 1024, 1024)).cuda()
99 | model_trt = torch2trt(model, [x])
100 |
101 |
102 | y = model(x)
103 | print(len(y))
104 | y_trt = model_trt(x)
105 | print(y_trt.shape)
106 |
107 | print(len(y_trt))
108 |
109 | # check the output against PyTorch
110 | print(torch.max(torch.abs(y - y_trt)))
111 |
112 |
113 | with open('ddrnet.engine', "wb") as f:
114 | f.write(model_trt.engine.serialize())
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 | if __name__ == '__main__':
125 | main()
126 |
--------------------------------------------------------------------------------
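convert2trt.py serializes the raw TensorRT engine to ddrnet.engine; an alternative torch2trt pattern (assumed here, not part of the script, with a hypothetical file name) is to save and restore the converted module as a state dict:

    import torch
    from torch2trt import TRTModule

    # after conversion: torch.save(model_trt.state_dict(), "ddrnet_trt.pth")
    model_trt = TRTModule()
    model_trt.load_state_dict(torch.load("ddrnet_trt.pth"))
    y = model_trt(torch.ones((1, 3, 1024, 1024)).cuda())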
/tools/demo.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 |
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 |
24 | import _init_paths
25 | import models
26 | import datasets
27 | from config import config
28 | from config import update_config
29 | from core.function import testval, test
30 | from utils.modelsummary import get_model_summary
31 | from utils.utils import create_logger, FullModel, speed_test
32 |
33 | def parse_args():
34 | parser = argparse.ArgumentParser(description='Train segmentation network')
35 |
36 | parser.add_argument('--cfg',
37 | help='experiment configure file name',
38 | default="experiments/map/map_hrnet_ocr_w18_small_v2_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml",
39 | type=str)
40 | parser.add_argument('opts',
41 | help="Modify config options using the command-line",
42 | default=None,
43 | nargs=argparse.REMAINDER)
44 |
45 | args = parser.parse_args()
46 | update_config(config, args)
47 |
48 | return args
49 |
50 | def main():
51 | args = parse_args()
52 |
53 | logger, final_output_dir, _ = create_logger(
54 | config, args.cfg, 'test')
55 |
56 | logger.info(pprint.pformat(args))
57 | logger.info(pprint.pformat(config))
58 |
59 | # cudnn related setting
60 | cudnn.benchmark = config.CUDNN.BENCHMARK
61 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
62 | cudnn.enabled = config.CUDNN.ENABLED
63 |
64 | # build model
65 | if torch.__version__.startswith('1'):
66 | module = eval('models.'+config.MODEL.NAME)
67 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
68 | model = eval('models.'+config.MODEL.NAME +
69 | '.get_seg_model')(config)
70 |
71 | dump_input = torch.rand(
72 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
73 | )
74 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
75 |
76 | if config.TEST.MODEL_FILE:
77 | model_state_file = config.TEST.MODEL_FILE
78 | else:
79 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
80 | model_state_file = os.path.join(final_output_dir, 'best.pth')
81 | logger.info('=> loading model from {}'.format(model_state_file))
82 |
83 | pretrained_dict = torch.load(model_state_file)
84 | if 'state_dict' in pretrained_dict:
85 | pretrained_dict = pretrained_dict['state_dict']
86 | model_dict = model.state_dict()
87 | pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
88 | if k[6:] in model_dict.keys()}
89 | for k, _ in pretrained_dict.items():
90 | logger.info(
91 | '=> loading {} from pretrained model'.format(k))
92 | model_dict.update(pretrained_dict)
93 | model.load_state_dict(model_dict)
94 | model = model.cuda()
95 |
96 | # gpus = list(config.GPUS)
97 | # model = nn.DataParallel(model, device_ids=gpus).cuda()
98 |
99 | # prepare data
100 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
101 | test_dataset = eval('datasets.'+config.DATASET.DATASET)(
102 | root=config.DATASET.ROOT,
103 | list_path=config.DATASET.TEST_SET,
104 | num_samples=None,
105 | num_classes=config.DATASET.NUM_CLASSES,
106 | multi_scale=False,
107 | flip=False,
108 | ignore_label=config.TRAIN.IGNORE_LABEL,
109 | base_size=config.TEST.BASE_SIZE,
110 | crop_size=test_size,
111 | downsample_rate=1)
112 |
113 | testloader = torch.utils.data.DataLoader(
114 | test_dataset,
115 | batch_size=1,
116 | shuffle=False,
117 | num_workers=config.WORKERS,
118 | pin_memory=True)
119 |
120 | start = timeit.default_timer()
121 |
122 | test(config,
123 | test_dataset,
124 | testloader,
125 | model,
126 | sv_dir=final_output_dir+'/test_result')
127 |
128 | end = timeit.default_timer()
129 | logger.info('Mins: %d' % np.int((end-start)/60))
130 | logger.info('Done')
131 |
132 |
133 | if __name__ == '__main__':
134 | main()
135 |
--------------------------------------------------------------------------------
/tools/demo_img.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 | import time
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 |
24 | import _init_paths
25 | import models
26 | import cv2
27 | import torch.nn.functional as F
28 | import datasets
29 | from config import config
30 | from config import update_config
31 | from core.function import testval, test
32 | from utils.modelsummary import get_model_summary
33 | from utils.utils import create_logger, FullModel, speed_test
34 |
35 | def parse_args():
36 | parser = argparse.ArgumentParser(description='Train segmentation network')
37 |
38 | parser.add_argument('--cfg',
39 | help='experiment configure file name',
40 | default="experiments/cityscapes/ddrnet23_slim.yaml",
41 | type=str)
42 | parser.add_argument('opts',
43 | help="Modify config options using the command-line",
44 | default=None,
45 | nargs=argparse.REMAINDER)
46 |
47 | args = parser.parse_args()
48 | update_config(config, args)
49 |
50 | return args
51 |
52 | def main():
53 | mean = [0.485, 0.456, 0.406]
54 | std = [0.229, 0.224, 0.225]
55 | args = parse_args()
56 |
57 | logger, final_output_dir, _ = create_logger(
58 | config, args.cfg, 'test')
59 |
60 | logger.info(pprint.pformat(args))
61 | logger.info(pprint.pformat(config))
62 |
63 | # cudnn related setting
64 | cudnn.benchmark = config.CUDNN.BENCHMARK
65 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
66 | cudnn.enabled = config.CUDNN.ENABLED
67 |
68 | # build model
69 | if torch.__version__.startswith('1'):
70 | module = eval('models.'+config.MODEL.NAME)
71 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
72 | model = eval('models.'+config.MODEL.NAME +
73 | '.get_seg_model')(config)
74 |
75 | dump_input = torch.rand(
76 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
77 | )
78 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
79 |
80 | if config.TEST.MODEL_FILE:
81 | model_state_file = config.TEST.MODEL_FILE
82 | else:
83 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
84 | model_state_file = os.path.join(final_output_dir, 'best.pth')
85 | logger.info('=> loading model from {}'.format(model_state_file))
86 |
87 | pretrained_dict = torch.load('/home/kong/Documents/DDRNet.Pytorch/DDRNet.Pytorch/output/face/ddrnet23_slim/checkpoint.pth.tar')
88 | if 'state_dict' in pretrained_dict:
89 | pretrained_dict = pretrained_dict['state_dict']
90 |
91 | # print(pretrained_dict.keys())
92 | new_state = {k:v for k,v in pretrained_dict.items() if k in model.state_dict()}
93 |
94 | model.load_state_dict(new_state)
95 | model = model.cuda()
96 |
97 | torch.save(model.state_dict(), 'model_best_bacc.pth.tar', _use_new_zipfile_serialization=False)
98 |
99 |
100 | # gpus = list(config.GPUS)
101 | # model = nn.DataParallel(model, device_ids=gpus).cuda()
102 |
103 | # prepare data
104 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
105 | print(test_size)
106 | img = cv2.imread("/home/kong/Downloads/731509a96420ef3dd0cffe869a4a53cb.jpeg")
107 |
108 | img = cv2.resize(img,(512,512))
109 | image = img.astype(np.float32)[:, :, ::-1]
110 | image = image / 255.0
111 | image -= mean
112 | image /= std
113 |
114 | image = image.transpose((2,0,1))
115 | image = torch.from_numpy(image)
116 |
117 | # image = image.permute((2, 0, 1))
118 |
119 | print(image.shape)
120 | image = image.unsqueeze(0)
121 |
122 | image = image.cuda()
123 | start = time.time()
124 | print("input : ",image)
125 | for i in range(1):
126 | out= model(image)
127 | end = time.time()
128 | print("Cuda 1000 images inference time : ",1000.0/(end - start))
129 | outadd = out[0]*1.0 + out[1]*0.4
130 | out0 = out[0].squeeze(dim=0)
131 | out1 = out[1].squeeze(dim=0)
132 |
133 | print(out0.size(),out0[0,1,1],out0[1,1,1])
134 | print("out:",out0)
135 |
136 |
137 | outadd = outadd.squeeze(dim=0)
138 | out0 = F.softmax(out0,dim=0)
139 | out1 = F.softmax(out1,dim=0)
140 | outadd = F.softmax(outadd,dim=0)
141 |
142 | out0 = torch.argmax(out0,dim=0)
143 | out1 = torch.argmax(out1,dim=0)
144 | outadd = torch.argmax(outadd,dim=0)
145 |
146 | pred0 = out0.detach().cpu().numpy()
147 | pred1 = out1.detach().cpu().numpy()
148 | predadd = outadd.detach().cpu().numpy()
149 | pred0 = pred0*255
150 | pred1 = pred1*255
151 | predadd = predadd*255
152 |
153 | pred_ch = np.zeros(pred0.shape)
154 | pred_rgb0 = np.array([pred_ch,pred_ch,pred0])
155 | pred_rgb1 = np.array([pred_ch,pred_ch,pred1])
156 | pred_rgbadd = np.array([predadd,pred_ch,predadd])
157 | pred_rgb0 = pred_rgb0.transpose(1,2,0)
158 | pred_rgb1 = pred_rgb1.transpose(1,2,0)
159 | pred_rgbadd = pred_rgbadd.transpose(1,2,0)
160 | pred_rgb0 = cv2.resize(pred_rgb0,(img.shape[1],img.shape[0]))
161 | pred_rgb1 = cv2.resize(pred_rgb1,(img.shape[1],img.shape[0]))
162 | pred_rgbadd = cv2.resize(pred_rgbadd,(img.shape[1],img.shape[0]))
163 | dst=cv2.addWeighted(img,0.7,pred_rgb0.astype(np.uint8),0.3,0)
164 | dst1=cv2.addWeighted(img,0.7,pred_rgb1.astype(np.uint8),0.3,0)
165 | dstadd=cv2.addWeighted(img,0.7,pred_rgbadd.astype(np.uint8),0.3,0)
166 |
167 | imgadd = np.vstack((img,pred_rgb0,dst,pred_rgb1, dst1,pred_rgbadd, dstadd))
168 |
169 |
170 | cv2.imwrite("a242.jpg",imgadd)
171 |
172 |
173 |
174 |
175 | if __name__ == '__main__':
176 | main()
177 |
--------------------------------------------------------------------------------
/tools/demo_img_noaug.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 | import time
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 |
24 | import _init_paths
25 | import models
26 | import cv2
27 | import torch.nn.functional as F
28 | import datasets
29 | from config import config
30 | from config import update_config
31 | from core.function import testval, test
32 | from utils.modelsummary import get_model_summary
33 | from utils.utils import create_logger, FullModel, speed_test
34 |
35 | def parse_args():
36 | parser = argparse.ArgumentParser(description='Train segmentation network')
37 |
38 | parser.add_argument('--cfg',
39 | help='experiment configure file name',
40 | default="experiments/cityscapes/ddrnet23_slim.yaml",
41 | type=str)
42 | parser.add_argument('opts',
43 | help="Modify config options using the command-line",
44 | default=None,
45 | nargs=argparse.REMAINDER)
46 |
47 | args = parser.parse_args()
48 | update_config(config, args)
49 |
50 | return args
51 |
52 | def main():
53 |     mean=[0.485, 0.456, 0.406]
54 |     std=[0.229, 0.224, 0.225]
55 | args = parse_args()
56 |
57 | logger, final_output_dir, _ = create_logger(
58 | config, args.cfg, 'test')
59 |
60 | logger.info(pprint.pformat(args))
61 | logger.info(pprint.pformat(config))
62 |
63 | # cudnn related setting
64 | cudnn.benchmark = config.CUDNN.BENCHMARK
65 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
66 | cudnn.enabled = config.CUDNN.ENABLED
67 |
68 | # build model
69 | if torch.__version__.startswith('1'):
70 | module = eval('models.'+config.MODEL.NAME)
71 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
72 | model = eval('models.'+config.MODEL.NAME +
73 | '.get_seg_model')(config)
74 |
75 | dump_input = torch.rand(
76 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
77 | )
78 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
79 |
80 | if config.TEST.MODEL_FILE:
81 | model_state_file = config.TEST.MODEL_FILE
82 | else:
83 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
84 | model_state_file = os.path.join(final_output_dir, 'best.pth')
85 | logger.info('=> loading model from {}'.format(model_state_file))
86 |
87 | pretrained_dict = torch.load('/home/kong/Documents/DDRNet.Pytorch/DDRNet.Pytorch/output/face/ddrnet23_slim/checkpoint.pth.tar')
88 | if 'state_dict' in pretrained_dict:
89 | pretrained_dict = pretrained_dict['state_dict']
90 |
91 | newstate_dict = {k:v for k,v in pretrained_dict.items() if k in model.state_dict()}
92 | # print(pretrained_dict.keys())
93 |
94 | model.load_state_dict(newstate_dict)
95 | model = model.cuda()
96 | print(model)
97 |
98 | # torch.save(model.state_dict(), 'model_best_bacc.pth.tar', _use_new_zipfile_serialization=False)
99 |
100 |
101 | # gpus = list(config.GPUS)
102 | # model = nn.DataParallel(model, device_ids=gpus).cuda()
103 |
104 | # prepare data
105 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
106 | print(test_size)
107 | img = cv2.imread("/home/kong/Documents/DDRNet.Pytorch/DDRNet.Pytorch/images/418cd0c0b416d93bc5a129834537f2e1.jpeg")
108 | stat1 = time.time()
109 | img = cv2.resize(img,(512,512))
110 | image = img.astype(np.float32)[:, :, ::-1]
111 | image = image / 255.0
112 | image -= mean
113 | image /= std
114 |
115 | image = image.transpose((2,0,1))
116 | image = torch.from_numpy(image)
117 |
118 | # image = image.permute((2, 0, 1))
119 |
120 |
121 | image = image.unsqueeze(0)
122 |
123 | image = image.cuda()
124 | stat2 = time.time()
125 | print("pre-process time : ",stat2 - stat1)
126 | start = time.time()
127 | for i in range(1000):
128 | out= model(image)
129 | end = time.time()
130 | print("FPS : ",1000.0/(end - start))
131 | # print("out:",out)
132 | out0 = out.squeeze(dim=0)
133 | out0 = F.softmax(out0,dim=0)
134 |
135 |
136 | out0 = torch.argmax(out0,dim=0)
137 |
138 | pred0 = out0.detach().cpu().numpy()
139 |
140 | pred0 = pred0*255
141 |
142 |
143 | pred_ch = np.zeros(pred0.shape)
144 | pred_rgb0 = np.array([pred_ch,pred_ch,pred0])
145 |
146 | pred_rgb0 = pred_rgb0.transpose(1,2,0)
147 |
148 | pred_rgb0 = cv2.resize(pred_rgb0,(img.shape[1],img.shape[0]))
149 |
150 | dst=cv2.addWeighted(img,0.7,pred_rgb0.astype(np.uint8),0.3,0)
151 |
152 |
153 | imgadd = np.vstack((img,pred_rgb0,dst))
154 |
155 |
156 | cv2.imwrite("a222.jpg",imgadd)
157 |
158 |
159 |
160 |
161 | if __name__ == '__main__':
162 | main()
163 |
--------------------------------------------------------------------------------
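The FPS loop above times asynchronous CUDA launches with wall-clock time, so the reported number can be optimistic unless the device is synchronized before reading the clock. A small hedged sketch of a steadier measurement (warm-up plus torch.cuda.synchronize()); the helper name and iteration counts are arbitrary choices:

```python
import time
import torch

@torch.no_grad()
def benchmark_fps(model, image, iters=1000, warmup=10):
    # model and image are assumed to already live on the GPU
    for _ in range(warmup):           # let cuDNN autotuning and allocations settle
        model(image)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        model(image)
    torch.cuda.synchronize()          # wait for queued kernels before stopping the clock
    return iters / (time.time() - start)
```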
/tools/demo_img_orig.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 | import time
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 |
24 | import _init_paths
25 | import models
26 | import cv2
27 | import torch.nn.functional as F
28 | import datasets
29 | from config import config
30 | from config import update_config
31 | from core.function import testval, test
32 | from utils.modelsummary import get_model_summary
33 | from utils.utils import create_logger, FullModel, speed_test
34 |
35 | from utils.utils import Map16, Vedio
36 | # from utils.DenseCRF import DenseCRF
37 |
38 |
39 | vedioCap = Vedio('./output/cdOffice.mp4')
40 | map16 = Map16(vedioCap)
41 |
42 | def parse_args():
43 | parser = argparse.ArgumentParser(description='Train segmentation network')
44 |
45 | parser.add_argument('--cfg',
46 | help='experiment configure file name',
47 | default="experiments/cityscapes/ddrnet23_slim.yaml",
48 | type=str)
49 | parser.add_argument('opts',
50 | help="Modify config options using the command-line",
51 | default=None,
52 | nargs=argparse.REMAINDER)
53 |
54 | args = parser.parse_args()
55 | update_config(config, args)
56 |
57 | return args
58 |
59 | def main():
60 |     mean=[0.485, 0.456, 0.406]
61 |     std=[0.229, 0.224, 0.225]
62 | args = parse_args()
63 |
64 | logger, final_output_dir, _ = create_logger(
65 | config, args.cfg, 'test')
66 |
67 | logger.info(pprint.pformat(args))
68 | logger.info(pprint.pformat(config))
69 |
70 | # cudnn related setting
71 | cudnn.benchmark = config.CUDNN.BENCHMARK
72 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
73 | cudnn.enabled = config.CUDNN.ENABLED
74 |
75 | # build model
76 | if torch.__version__.startswith('1'):
77 | module = eval('models.'+config.MODEL.NAME)
78 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
79 | model = eval('models.'+config.MODEL.NAME +
80 | '.get_seg_model')(config)
81 |
82 | dump_input = torch.rand(
83 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
84 | )
85 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
86 |
87 | if config.TEST.MODEL_FILE:
88 | model_state_file = config.TEST.MODEL_FILE
89 | else:
90 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
91 | model_state_file = os.path.join(final_output_dir, 'best.pth')
92 | logger.info('=> loading model from {}'.format(model_state_file))
93 |
94 | pretrained_dict = torch.load('/home/hwits/Documents/CarVid/DDRNet/segmentation/best_val_smaller.pth')
95 | # print("pretrained_dict:",pretrained_dict.keys())
96 | if 'state_dict' in pretrained_dict:
97 | pretrained_dict = pretrained_dict['state_dict']
98 |
99 | # print(pretrained_dict.keys())
100 | new_state = {k[6:]:v for k,v in pretrained_dict.items() if k[6:] in model.state_dict()}
101 |
102 | model.load_state_dict(new_state)
103 | model = model.cuda()
104 |
105 | torch.save(model.state_dict(), 'model_best_bacc.pth.tar', _use_new_zipfile_serialization=False)
106 |
107 |
108 | # gpus = list(config.GPUS)
109 | # model = nn.DataParallel(model, device_ids=gpus).cuda()
110 |
111 | # prepare data
112 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
113 | print(test_size)
114 | img = cv2.imread("/home/hwits/berlin_000528_000019_leftImg8bit.png")
115 |
116 | img = cv2.resize(img,(1024,1024))
117 | image = img.astype(np.float32)[:, :, ::-1]
118 | image = image / 255.0
119 | image -= mean
120 | image /= std
121 |
122 | image = image.transpose((2,0,1))
123 | image = torch.from_numpy(image)
124 |
125 | # image = image.permute((2, 0, 1))
126 |
127 | print(image.shape)
128 | image = image.unsqueeze(0)
129 |
130 | image = image.cuda()
131 | start = time.time()
132 | # print("input : ",image)
133 | for i in range(1):
134 | out= model(image)
135 | end = time.time()
136 | print("out : ",out[0].shape)
137 |
138 | pred = F.interpolate(
139 | out[0], (img.shape[0],img.shape[1]),
140 | mode='nearest'
141 | )
142 | print("results : ",pred.shape)
143 | _, pred = torch.max(pred, dim=1)
144 | pred = pred.squeeze(0).cpu().numpy()
145 |
146 | map16.visualize_result(img, pred, '.', 'cityscape.jpg')
147 |
148 |
149 | print("Cuda 1000 images inference time : ",1000.0/(end - start))
150 | outadd = out[0]*1.0 + out[1]*0.4
151 | out0 = out[0].squeeze(dim=0)
152 | out1 = out[1].squeeze(dim=0)
153 |
154 |
155 |
156 |
157 |
158 | outadd = outadd.squeeze(dim=0)
159 | out0 = F.softmax(out0,dim=0)
160 | out1 = F.softmax(out1,dim=0)
161 | outadd = F.softmax(outadd,dim=0)
162 |
163 | out0 = torch.argmax(out0,dim=0)
164 | out1 = torch.argmax(out1,dim=0)
165 | outadd = torch.argmax(outadd,dim=0)
166 |
167 | pred0 = out0.detach().cpu().numpy()
168 | pred1 = out1.detach().cpu().numpy()
169 | predadd = outadd.detach().cpu().numpy()
170 | pred0 = pred0*255
171 | pred1 = pred1*255
172 | predadd = predadd*255
173 |
174 | pred_ch = np.zeros(pred0.shape)
175 | pred_rgb0 = np.array([pred_ch,pred_ch,pred0])
176 | pred_rgb1 = np.array([pred_ch,pred_ch,pred1])
177 | pred_rgbadd = np.array([predadd,pred_ch,predadd])
178 | pred_rgb0 = pred_rgb0.transpose(1,2,0)
179 | pred_rgb1 = pred_rgb1.transpose(1,2,0)
180 | pred_rgbadd = pred_rgbadd.transpose(1,2,0)
181 | pred_rgb0 = cv2.resize(pred_rgb0,(img.shape[1],img.shape[0]))
182 | pred_rgb1 = cv2.resize(pred_rgb1,(img.shape[1],img.shape[0]))
183 | pred_rgbadd = cv2.resize(pred_rgbadd,(img.shape[1],img.shape[0]))
184 | dst=cv2.addWeighted(img,0.7,pred_rgb0.astype(np.uint8),0.3,0)
185 | dst1=cv2.addWeighted(img,0.7,pred_rgb1.astype(np.uint8),0.3,0)
186 | dstadd=cv2.addWeighted(img,0.7,pred_rgbadd.astype(np.uint8),0.3,0)
187 |
188 | imgadd = np.vstack((img,pred_rgb0,dst,pred_rgb1, dst1,pred_rgbadd, dstadd))
189 |
190 |
191 | cv2.imwrite("a22.jpg",imgadd)
192 |
193 |
194 |
195 |
196 | if __name__ == '__main__':
197 | main()
198 |
--------------------------------------------------------------------------------
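demo_img_orig.py delegates colorization to Map16.visualize_result and then repeats the red-mask overlay from the other demos. A self-contained sketch of the same idea, mapping class ids to a palette and blending with the input; the random palette and function names are illustrative, and pred is assumed to already be at the image resolution:

```python
import cv2
import numpy as np

def colorize(pred, num_classes, seed=0):
    # illustrative random palette; a fixed per-class palette works the same way
    palette = np.random.default_rng(seed).integers(0, 256, size=(num_classes, 3), dtype=np.uint8)
    return palette[pred]                      # HxW class ids -> HxWx3 BGR mask

def overlay(img, pred, num_classes, alpha=0.3):
    mask = colorize(pred, num_classes)
    return cv2.addWeighted(img, 1.0 - alpha, mask, alpha, 0)
```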
/tools/eval.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 |
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 |
24 | import _init_paths
25 | import models
26 | import datasets
27 | from config import config
28 | from config import update_config
29 | from core.function import testval, test
30 | from utils.modelsummary import get_model_summary
31 | from utils.utils import create_logger, FullModel, speed_test
32 |
33 | def parse_args():
34 | parser = argparse.ArgumentParser(description='Train segmentation network')
35 |
36 | parser.add_argument('--cfg',
37 | help='experiment configure file name',
38 | default="experiments/cityscapes/ddrnet23_slim.yaml",
39 | type=str)
40 | parser.add_argument('opts',
41 | help="Modify config options using the command-line",
42 | default=None,
43 | nargs=argparse.REMAINDER)
44 |
45 | args = parser.parse_args()
46 | update_config(config, args)
47 |
48 | return args
49 |
50 | def main():
51 | args = parse_args()
52 |
53 | logger, final_output_dir, _ = create_logger(
54 | config, args.cfg, 'test')
55 |
56 | logger.info(pprint.pformat(args))
57 | logger.info(pprint.pformat(config))
58 |
59 | # cudnn related setting
60 | cudnn.benchmark = config.CUDNN.BENCHMARK
61 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
62 | cudnn.enabled = config.CUDNN.ENABLED
63 |
64 | # build model
65 | if torch.__version__.startswith('1'):
66 | module = eval('models.'+config.MODEL.NAME)
67 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
68 | model = eval('models.'+config.MODEL.NAME +
69 | '.get_seg_model')(config)
70 |
71 | # dump_input = torch.rand(
72 | # (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
73 | # )
74 | # logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
75 |
76 | if config.TEST.MODEL_FILE:
77 | model_state_file = config.TEST.MODEL_FILE
78 | else:
79 | model_state_file = os.path.join(final_output_dir, 'best.pth')
80 | # model_state_file = os.path.join(final_output_dir, 'final_state.pth')
81 | logger.info('=> loading model from {}'.format(model_state_file))
82 |
83 | pretrained_dict = torch.load(model_state_file)
84 | if 'state_dict' in pretrained_dict:
85 | pretrained_dict = pretrained_dict['state_dict']
86 | model_dict = model.state_dict()
87 | pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
88 | if k[6:] in model_dict.keys()}
89 | for k, _ in pretrained_dict.items():
90 | logger.info(
91 | '=> loading {} from pretrained model'.format(k))
92 | model_dict.update(pretrained_dict)
93 | model.load_state_dict(model_dict)
94 |
95 | gpus = list(config.GPUS)
96 | model = model.cuda()
97 | # model = nn.DataParallel(model, device_ids=gpus).cuda()
98 |
99 | # prepare data
100 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
101 | test_dataset = eval('datasets.'+config.DATASET.DATASET)(
102 | root=config.DATASET.ROOT,
103 | list_path=config.DATASET.TEST_SET,
104 | num_samples=None,
105 | num_classes=config.DATASET.NUM_CLASSES,
106 | multi_scale=False,
107 | flip=False,
108 | ignore_label=config.TRAIN.IGNORE_LABEL,
109 | base_size=config.TEST.BASE_SIZE,
110 | crop_size=test_size,
111 | downsample_rate=1)
112 |
113 | testloader = torch.utils.data.DataLoader(
114 | test_dataset,
115 | batch_size=1,
116 | shuffle=False,
117 | num_workers=config.WORKERS,
118 | pin_memory=True)
119 |
120 | start = timeit.default_timer()
121 |
122 | mean_IoU, IoU_array, pixel_acc, mean_acc = testval(config,
123 | test_dataset,
124 | testloader,
125 | model,
126 | sv_pred=False)
127 |
128 | msg = 'MeanIU: {: 4.4f}, Pixel_Acc: {: 4.4f}, \
129 | Mean_Acc: {: 4.4f}, Class IoU: '.format(mean_IoU,
130 | pixel_acc, mean_acc)
131 | logging.info(msg)
132 | logging.info(IoU_array)
133 |
134 | end = timeit.default_timer()
135 |     logger.info('Mins: %d' % int((end-start)/60))
136 | logger.info('Done')
137 |
138 |
139 | if __name__ == '__main__':
140 | main()
141 |
--------------------------------------------------------------------------------
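eval.py (like demo_img_orig.py above) drops the first six characters of every checkpoint key, i.e. the 'model.' prefix added when the network is wrapped in FullModel/DataParallel before saving. A slightly more defensive sketch that strips the prefix only when it is present and keeps only the keys the model knows; whether a given checkpoint carries the prefix depends on how it was saved, so treat that as an assumption:

```python
import torch

def load_checkpoint(model, checkpoint_path, prefix='model.'):
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state = ckpt.get('state_dict', ckpt)             # unwrap if saved as a dict
    # strip the wrapper prefix only where it exists
    state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}
    model_dict = model.state_dict()
    model_dict.update({k: v for k, v in state.items() if k in model_dict})
    model.load_state_dict(model_dict)
    return model
```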
/tools/getwts.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 | import time
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.backends.cudnn as cudnn
22 | import struct
23 | import _init_paths
24 | import models
25 | import cv2
26 | import torch.nn.functional as F
27 | import datasets
28 | from config import config
29 | from config import update_config
30 | from core.function import testval, test
31 | from utils.modelsummary import get_model_summary
32 | from utils.utils import create_logger, FullModel, speed_test
33 |
34 | def parse_args():
35 | parser = argparse.ArgumentParser(description='Train segmentation network')
36 |
37 | parser.add_argument('--cfg',
38 | help='experiment configure file name',
39 | default="experiments/cityscapes/ddrnet23_slim.yaml",
40 | type=str)
41 | parser.add_argument('opts',
42 | help="Modify config options using the command-line",
43 | default=None,
44 | nargs=argparse.REMAINDER)
45 |
46 | args = parser.parse_args()
47 | update_config(config, args)
48 |
49 | return args
50 |
51 | def main():
52 |     mean=[0.485, 0.456, 0.406]
53 |     std=[0.229, 0.224, 0.225]
54 | args = parse_args()
55 |
56 | logger, final_output_dir, _ = create_logger(
57 | config, args.cfg, 'test')
58 |
59 | logger.info(pprint.pformat(args))
60 | logger.info(pprint.pformat(config))
61 |
62 | # cudnn related setting
63 | cudnn.benchmark = config.CUDNN.BENCHMARK
64 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
65 | cudnn.enabled = config.CUDNN.ENABLED
66 |
67 | # build model
68 | if torch.__version__.startswith('1'):
69 | module = eval('models.'+config.MODEL.NAME)
70 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
71 | model = eval('models.'+config.MODEL.NAME +
72 | '.get_seg_model')(config)
73 |
74 | dump_input = torch.rand(
75 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
76 | )
77 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
78 |
79 | if config.TEST.MODEL_FILE:
80 | model_state_file = config.TEST.MODEL_FILE
81 | else:
82 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
83 | model_state_file = os.path.join(final_output_dir, 'best.pth')
84 | logger.info('=> loading model from {}'.format(model_state_file))
85 |
86 | pretrained_dict = torch.load('/home/hwits/Documents/CarVid/DDRNet/DDRNet.pytorch/model_best_bacc.pth.tar')
87 | if 'state_dict' in pretrained_dict:
88 | pretrained_dict = pretrained_dict['state_dict']
89 | newstate_dict = {k:v for k,v in pretrained_dict.items() if k in model.state_dict()}
90 | # print(pretrained_dict.keys())
91 |
92 | model.load_state_dict(newstate_dict)
93 | model = model.cuda()
94 |
95 |
96 | if True:
97 | save_wts = True
98 | print(model)
99 | if save_wts:
100 | f = open('DDRNetLite.wts', 'w')
101 | f.write('{}\n'.format(len(model.state_dict().keys())))
102 | for k, v in model.state_dict().items():
103 | print("Layer {} ; Size {}".format(k,v.cpu().numpy().shape))
104 | vr = v.reshape(-1).cpu().numpy()
105 | f.write('{} {} '.format(k, len(vr)))
106 | for vv in vr:
107 | f.write(' ')
108 | f.write(struct.pack('>f', float(vv)).hex())
109 | f.write('\n')
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 | if __name__ == '__main__':
118 | main()
119 |
--------------------------------------------------------------------------------
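getwts.py writes the weights in the plain-text .wts layout used by TensorRT sample code: a line with the tensor count, then one line per tensor of the form "name length hex hex ..." with each value packed as a big-endian float. A minimal reader for that layout, handy for sanity-checking the export; this helper is an illustration and not part of the repository:

```python
import struct
import numpy as np

def read_wts(path):
    weights = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            parts = f.readline().split()
            name, length = parts[0], int(parts[1])
            values = [struct.unpack('>f', bytes.fromhex(h))[0] for h in parts[2:]]
            assert len(values) == length, name
            weights[name] = np.array(values, dtype=np.float32)
    return weights
```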
/tools/maks.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 | import time
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 |
24 | import _init_paths
25 | import models
26 | import cv2
27 | import torch.nn.functional as F
28 | import datasets
29 | from config import config
30 | from config import update_config
31 | from core.function import testval, test
32 | from utils.modelsummary import get_model_summary
33 | from utils.utils import create_logger, FullModel, speed_test
34 |
35 | def parse_args():
36 | parser = argparse.ArgumentParser(description='Train segmentation network')
37 |
38 | parser.add_argument('--cfg',
39 | help='experiment configure file name',
40 | default="experiments/cityscapes/ddrnet23_slim.yaml",
41 | type=str)
42 | parser.add_argument('opts',
43 | help="Modify config options using the command-line",
44 | default=None,
45 | nargs=argparse.REMAINDER)
46 |
47 | args = parser.parse_args()
48 | update_config(config, args)
49 |
50 | return args
51 |
52 | def add_alpha_channel(img):
53 | """ 为jpg图像添加alpha通道 """
54 |
55 |     b_channel, g_channel, r_channel = cv2.split(img)  # split the JPG image into its B, G, R channels
56 |     alpha_channel = np.ones(b_channel.shape, dtype=b_channel.dtype) * 255  # create a fully opaque alpha channel
57 |
58 |     img_new = cv2.merge((b_channel, g_channel, r_channel, alpha_channel))  # merge the four channels
59 | return img_new
60 |
61 | def merge_img(jpg_img, png_img, y1, y2, x1, x2):
62 | """ 将png透明图像与jpg图像叠加
63 | y1,y2,x1,x2为叠加位置坐标值
64 | """
65 |
66 |     # check whether the JPG image already has 4 channels
67 | if jpg_img.shape[2] == 3:
68 | jpg_img = add_alpha_channel(jpg_img)
69 |
70 |     '''
71 |     When overlaying, a badly chosen position can push the PNG beyond the borders of the background JPG and crash the program.
72 |     The limits below clamp the overlay region so the blend still works even when the PNG extends outside the JPG.
73 |     '''
74 | yy1 = 0
75 | yy2 = png_img.shape[0]
76 | xx1 = 0
77 | xx2 = png_img.shape[1]
78 |
79 | if x1 < 0:
80 | xx1 = -x1
81 | x1 = 0
82 | if y1 < 0:
83 | yy1 = - y1
84 | y1 = 0
85 | if x2 > jpg_img.shape[1]:
86 | xx2 = png_img.shape[1] - (x2 - jpg_img.shape[1])
87 | x2 = jpg_img.shape[1]
88 | if y2 > jpg_img.shape[0]:
89 | yy2 = png_img.shape[0] - (y2 - jpg_img.shape[0])
90 | y2 = jpg_img.shape[0]
91 |
92 |     # take the alpha channel of the overlay image and divide by 255 so the values stay in [0, 1]
93 | alpha_png = png_img[yy1:yy2,xx1:xx2,3] / 255.0
94 | alpha_jpg = 1 - alpha_png
95 |
96 |     # blend the overlay onto the background
97 | for c in range(0,3):
98 | jpg_img[y1:y2, x1:x2, c] = ((alpha_jpg*jpg_img[y1:y2,x1:x2,c]) + (alpha_png*png_img[yy1:yy2,xx1:xx2,c]))
99 |
100 | return jpg_img
101 |
102 | def main():
103 |     mean=[0.485, 0.456, 0.406]
104 |     std=[0.229, 0.224, 0.225]
105 | args = parse_args()
106 |
107 | logger, final_output_dir, _ = create_logger(
108 | config, args.cfg, 'test')
109 |
110 | logger.info(pprint.pformat(args))
111 | logger.info(pprint.pformat(config))
112 |
113 | # cudnn related setting
114 | cudnn.benchmark = config.CUDNN.BENCHMARK
115 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
116 | cudnn.enabled = config.CUDNN.ENABLED
117 |
118 | # build model
119 | if torch.__version__.startswith('1'):
120 | module = eval('models.'+config.MODEL.NAME)
121 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
122 | model = eval('models.'+config.MODEL.NAME +
123 | '.get_seg_model')(config)
124 |
125 | dump_input = torch.rand(
126 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
127 | )
128 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
129 |
130 | if config.TEST.MODEL_FILE:
131 | model_state_file = config.TEST.MODEL_FILE
132 | else:
133 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
134 | model_state_file = os.path.join(final_output_dir, 'best.pth')
135 | logger.info('=> loading model from {}'.format(model_state_file))
136 |
137 | pretrained_dict = torch.load('/home/kong/Documents/DDRNet.Pytorch/DDRNet.Pytorch/output/face/ddrnet23_slim/checkpoint.pth.tar')
138 | if 'state_dict' in pretrained_dict:
139 | pretrained_dict = pretrained_dict['state_dict']
140 |
141 | # print(pretrained_dict.keys())
142 | new_state = {k:v for k,v in pretrained_dict.items() if k in model.state_dict()}
143 |
144 | model.load_state_dict(new_state)
145 | model = model.cuda()
146 |
147 | torch.save(model.state_dict(), 'model_best_bacc.pth.tar', _use_new_zipfile_serialization=False)
148 |
149 |
150 | # gpus = list(config.GPUS)
151 | # model = nn.DataParallel(model, device_ids=gpus).cuda()
152 |
153 | # prepare data
154 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
155 | print(test_size)
156 | img = cv2.imread("/home/kong/Documents/DDRNet.Pytorch/DDRNet.Pytorch/images/418cd0c0b416d93bc5a129834537f2e1.jpeg")
157 |
158 | img = cv2.resize(img,(512,512))
159 | image = img.astype(np.float32)[:, :, ::-1]
160 | image = image / 255.0
161 | image -= mean
162 | image /= std
163 |
164 | image = image.transpose((2,0,1))
165 | image = torch.from_numpy(image)
166 |
167 | # image = image.permute((2, 0, 1))
168 |
169 | # print(image.shape)
170 | image = image.unsqueeze(0)
171 |
172 | image = image.cuda()
173 | start = time.time()
174 | # print("input : ",image)
175 | for i in range(1):
176 | out= model(image)
177 | end = time.time()
178 | #print("Cuda 1000 images inference time : ",1000.0/(end - start))
179 | outadd = out[0]*1.0 + out[1]*0.4
180 | out0 = out[0].squeeze(dim=0)
181 | out1 = out[1].squeeze(dim=0)
182 |
183 | # print(out0.size(),out0[0,1,1],out0[1,1,1])
184 | # print("out:",out0)
185 |
186 |
187 | outadd = outadd.squeeze(dim=0)
188 | out0 = F.softmax(out0,dim=0)
189 | out1 = F.softmax(out1,dim=0)
190 | outadd = F.softmax(outadd,dim=0)
191 |
192 | out0 = torch.argmax(out0,dim=0)
193 | out1 = torch.argmax(out1,dim=0)
194 | outadd = torch.argmax(outadd,dim=0)
195 |
196 | pred0 = out0.detach().cpu().numpy()
197 | pred1 = out1.detach().cpu().numpy()
198 | predadd = outadd.detach().cpu().numpy()
199 | pred0 = pred0*255
200 | pred1 = pred1*255
201 | predadd = predadd*255
202 |
203 |
204 | ####================= alpha channel =========================#
205 | print("pred0:",pred0.shape, img.shape)
206 | pred0 = np.array(pred0,np.uint8)
207 | pred0up = cv2.resize(pred0,(512,512))
208 | png = np.dstack((img,pred0up))
209 | bg = cv2.imread("/home/kong/Downloads/4aeb26a89778f73261ccef283e70992f.jpeg")
210 |
211 | addpng = merge_img(bg,png,500,1012,100,612)
212 | cv2.imwrite("png.png",addpng)
213 |
214 | pred_ch = np.zeros(pred0.shape)
215 | pred_rgb0 = np.array([pred_ch,pred_ch,pred0])
216 | pred_rgb1 = np.array([pred_ch,pred_ch,pred1])
217 | pred_rgbadd = np.array([predadd,pred_ch,predadd])
218 | pred_rgb0 = pred_rgb0.transpose(1,2,0)
219 | pred_rgb1 = pred_rgb1.transpose(1,2,0)
220 | pred_rgbadd = pred_rgbadd.transpose(1,2,0)
221 | pred_rgb0 = cv2.resize(pred_rgb0,(img.shape[1],img.shape[0]))
222 | pred_rgb1 = cv2.resize(pred_rgb1,(img.shape[1],img.shape[0]))
223 | pred_rgbadd = cv2.resize(pred_rgbadd,(img.shape[1],img.shape[0]))
224 | dst=cv2.addWeighted(img,0.7,pred_rgb0.astype(np.uint8),0.3,0)
225 | dst1=cv2.addWeighted(img,0.7,pred_rgb1.astype(np.uint8),0.3,0)
226 | dstadd=cv2.addWeighted(img,0.7,pred_rgbadd.astype(np.uint8),0.3,0)
227 |
228 | imgadd = np.hstack((img,dstadd))
229 |
230 |
231 | cv2.imwrite("a242.jpg",imgadd)
232 |
233 |
234 |
235 |
236 | if __name__ == '__main__':
237 | main()
238 |
--------------------------------------------------------------------------------
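maks.py turns the predicted face mask into the alpha channel of a 4-channel image and composites it onto a new background with merge_img. A simplified, self-contained sketch of that compositing step; unlike merge_img it assumes the pasted region lies fully inside the background, and the function name is an illustration:

```python
import numpy as np

def paste_with_mask(background, foreground, mask, x, y):
    # mask is a single-channel uint8 image: 255 keeps the foreground, 0 keeps the background
    h, w = foreground.shape[:2]
    alpha = (mask.astype(np.float32) / 255.0)[..., None]           # HxWx1 in [0, 1]
    roi = background[y:y + h, x:x + w].astype(np.float32)
    blended = alpha * foreground.astype(np.float32) + (1.0 - alpha) * roi
    background[y:y + h, x:x + w] = blended.astype(np.uint8)
    return background
```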
/tools/quantize.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 | import time
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 |
24 | import _init_paths
25 | import models
26 | import cv2
27 | import torch.nn.functional as F
28 | import datasets
29 | from config import config
30 | from config import update_config
31 | from core.function import testval, test
32 | from utils.modelsummary import get_model_summary
33 | from utils.utils import create_logger, FullModel, speed_test
34 |
35 | def parse_args():
36 | parser = argparse.ArgumentParser(description='Train segmentation network')
37 |
38 | parser.add_argument('--cfg',
39 | help='experiment configure file name',
40 | default="experiments/cityscapes/ddrnet23_slim.yaml",
41 | type=str)
42 | parser.add_argument('opts',
43 | help="Modify config options using the command-line",
44 | default=None,
45 | nargs=argparse.REMAINDER)
46 |
47 | args = parser.parse_args()
48 | update_config(config, args)
49 |
50 | return args
51 |
52 | def main():
53 |     mean=[0.485, 0.456, 0.406]
54 |     std=[0.229, 0.224, 0.225]
55 | args = parse_args()
56 |
57 | logger, final_output_dir, _ = create_logger(
58 | config, args.cfg, 'test')
59 |
60 | logger.info(pprint.pformat(args))
61 | logger.info(pprint.pformat(config))
62 |
63 | # cudnn related setting
64 | cudnn.benchmark = config.CUDNN.BENCHMARK
65 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
66 | cudnn.enabled = config.CUDNN.ENABLED
67 |
68 | # build model
69 | if torch.__version__.startswith('1'):
70 | module = eval('models.'+config.MODEL.NAME)
71 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
72 | model = eval('models.'+config.MODEL.NAME +
73 | '.get_seg_model')(config)
74 |
75 | dump_input = torch.rand(
76 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
77 | )
78 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
79 |
80 | if config.TEST.MODEL_FILE:
81 | model_state_file = config.TEST.MODEL_FILE
82 | else:
83 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
84 | model_state_file = os.path.join(final_output_dir, 'best.pth')
85 | logger.info('=> loading model from {}'.format(model_state_file))
86 |
87 | pretrained_dict = torch.load('/home/kong/Documents/DDRNet.Pytorch/DDRNet.Pytorch/output/face/ddrnet23_slim/checkpoint.pth.tar')
88 | if 'state_dict' in pretrained_dict:
89 | pretrained_dict = pretrained_dict['state_dict']
90 |
91 | newstate_dict = {k:v for k,v in pretrained_dict.items() if k in model.state_dict()}
92 | # print(pretrained_dict.keys())
93 |
94 | model.load_state_dict(newstate_dict)
95 | model = model.to("cpu")
96 | print(model)
97 | model.eval()
98 | example = torch.rand(1, 3, 512, 512)
99 |
100 | model = torch.quantization.convert(model)
101 | # traced_script_module = torch.jit.trace(model, example)
102 | # traced_script_module.save("ddrnetfp32.pt")
103 |     from torch.utils.mobile_optimizer import optimize_for_mobile
104 |     scriptedm = torch.jit.script(model)
105 |     opt_model = optimize_for_mobile(scriptedm)
106 |     torch.jit.save(opt_model, "ddrnetint8.pt")
107 |
108 |
109 |
110 |
111 |
112 | if __name__ == '__main__':
113 | main()
114 |
--------------------------------------------------------------------------------
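quantize.py calls torch.quantization.convert on a float model without a prepare/calibrate step, so the saved ddrnetint8.pt would still hold float weights unless the model was prepared elsewhere. A hedged sketch of the usual eager-mode post-training static quantization flow; it assumes the network has QuantStub/DeQuantStub entry and exit points and that all of its modules are quantizable, which has not been verified for DDRNet here:

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

def quantize_and_export(model, calibration_batches, out_path="ddrnetint8.pt"):
    model = model.eval().to("cpu")
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    prepared = torch.quantization.prepare(model)            # insert observers
    with torch.no_grad():
        for batch in calibration_batches:                   # run representative data
            prepared(batch)
    quantized = torch.quantization.convert(prepared)        # swap in int8 modules
    scripted = torch.jit.script(quantized)
    torch.jit.save(optimize_for_mobile(scripted), out_path)
    return quantized
```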
/tools/test.py:
--------------------------------------------------------------------------------
1 |
2 | # Ltrain = "data/mapv3/1_train.lst"
3 | # train_list = []
4 | # Cval = "data/mapv3/testval.txt"
5 | # val_list = []
6 |
7 | # with open(Ltrain, "r") as f:
8 | # datas = f.readlines()
9 | # for data in datas:
10 | # t = data.split()[0]
11 | # t = t.split(".")[0]
12 | # train_list.append(t)
13 |
14 | # with open(Cval, "r") as f:
15 | # datas = f.readlines()
16 | # for data in datas:
17 | # val_list.append(data.split()[0])
18 |
19 | # temp = list(set(train_list).intersection(set(val_list)))
20 |
21 | # t = 1
22 |
23 |
24 |
--------------------------------------------------------------------------------
/tools/to_onnx.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import onnx
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.backends.cudnn as cudnn
15 |
16 | import _init_paths
17 | import models
18 | import datasets
19 | from config import config
20 | from config import update_config
21 | from core.function import testval, test
22 | from utils.modelsummary import get_model_summary
23 | from utils.utils import create_logger, FullModel, speed_test
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser(description='Train segmentation network')
27 |
28 | parser.add_argument('--cfg',
29 | help='experiment configure file name',
30 | default="experiments/cityscapes/ddrnet23_slim.yaml",
31 | type=str)
32 | parser.add_argument('opts',
33 | help="Modify config options using the command-line",
34 | default=None,
35 | nargs=argparse.REMAINDER)
36 |
37 | args = parser.parse_args()
38 | update_config(config, args)
39 |
40 | return args
41 |
42 | class onnx_net(nn.Module):
43 | def __init__(self, model):
44 | super(onnx_net, self).__init__()
45 | self.backone = model
46 |
47 | def forward(self, x):
48 | x1, x2 = self.backone(x)
49 | y = F.interpolate(x1, size=(480,640), mode='bilinear')
50 | # y = F.softmax(y, dim=1)
51 | y = torch.argmax(y, dim=1)
52 |
53 | return y
54 |
55 |
56 | def main():
57 | args = parse_args()
58 |
59 | logger, final_output_dir, _ = create_logger(
60 | config, args.cfg, 'test')
61 |
62 | logger.info(pprint.pformat(args))
63 | logger.info(pprint.pformat(config))
64 |
65 | # cudnn related setting
66 | cudnn.benchmark = config.CUDNN.BENCHMARK
67 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
68 | cudnn.enabled = config.CUDNN.ENABLED
69 |
70 | # build model
71 | if torch.__version__.startswith('1'):
72 | module = eval('models.'+config.MODEL.NAME)
73 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
74 | model = eval('models.'+config.MODEL.NAME +
75 | '.get_seg_model')(config)
76 |
77 | if config.TEST.MODEL_FILE:
78 | model_state_file = config.TEST.MODEL_FILE
79 | else:
80 | model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth')
81 | # model_state_file = os.path.join(final_output_dir, 'final_state.pth')
82 | logger.info('=> loading model from {}'.format(model_state_file))
83 |
84 | pretrained_dict = torch.load(model_state_file)
85 | if 'state_dict' in pretrained_dict:
86 | pretrained_dict = pretrained_dict['state_dict']
87 | model_dict = model.state_dict()
88 | pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
89 | if k[6:] in model_dict.keys()}
90 | for k, _ in pretrained_dict.items():
91 | logger.info(
92 | '=> loading {} from pretrained model'.format(k))
93 | model_dict.update(pretrained_dict)
94 | model.load_state_dict(model_dict)
95 |
96 | net = onnx_net(model)
97 | net = net.eval()
98 |
99 | # x = torch.randn((1, 3, 512, 384))
100 | x = torch.randn((1,3,480,640))
101 | torch_out = net(x)
102 |
103 | # output_path = "output/tensorrt/resnet50/resnet50_bilinear.onnx"
104 | output_path = "output/ddrnet23_slim.onnx"
105 | torch.onnx.export(net, # model being run
106 | x, # model input (or a tuple for multiple inputs)
107 | output_path, # where to save the model (can be a file or file-like object)
108 | export_params=True, # store the trained parameter weights inside the model file
109 | opset_version=11, # the ONNX version to export the model to
110 | do_constant_folding=True, # whether to execute constant folding for optimization
111 | input_names = ['inputx'], # the model's input names
112 | output_names = ['outputy'], # the model's output names
113 | verbose=True,
114 | )
115 | # onnx.checker.check_model(output_path)
116 |
117 |
118 |
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
--------------------------------------------------------------------------------
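to_onnx.py exports the graph but leaves the checker commented out. A small hedged sketch that validates the exported file and compares an onnxruntime forward pass with the PyTorch wrapper's output; onnxruntime is an extra dependency that this repository does not list, and a few pixel-level differences at class boundaries are possible:

```python
import numpy as np
import onnx
import onnxruntime as ort
import torch

def check_export(onnx_path, net, input_name="inputx"):
    onnx.checker.check_model(onnx.load(onnx_path))          # structural validation
    x = torch.randn(1, 3, 480, 640)
    with torch.no_grad():
        torch_out = net(x).numpy()
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (ort_out,) = sess.run(None, {input_name: x.numpy()})
    print("mismatched pixels:", float((torch_out != ort_out).mean()))
```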
/train.sh:
--------------------------------------------------------------------------------
1 | python tools/train_single.py --cfg experiments/face/ddrnet23_slim.yaml
2 |
--------------------------------------------------------------------------------