├── .gitignore
├── README.md
├── assets
│   ├── bird_4008.gif
│   ├── bird_4648.gif
│   ├── bird_5602.gif
│   ├── sg_bird_4008.gif
│   ├── sg_bird_4648.gif
│   ├── sg_bird_5602.gif
│   ├── target_bird_4008.png
│   ├── target_bird_4648.png
│   └── target_bird_5602.png
├── environment.yml
├── input
│   ├── target_bird_4008.png
│   ├── target_bird_4648.png
│   └── target_bird_5602.png
├── output
│   ├── baseline_painted_target_bird_4008.png
│   ├── baseline_painted_target_bird_4648.png
│   ├── sg_painted_target_bird_4008.png
│   └── sg_painted_target_bird_4648.png
├── semantic_guidance
│   ├── DRL
│   │   ├── actor.py
│   │   ├── critic.py
│   │   ├── ddpg.py
│   │   ├── evaluator.py
│   │   ├── multi.py
│   │   ├── rpm.py
│   │   └── wgan.py
│   ├── Renderer
│   │   ├── __init__.py
│   │   ├── model.py
│   │   └── stroke_gen.py
│   ├── __init__.py
│   ├── env_ins.py
│   ├── preprocess.py
│   ├── test.py
│   ├── test_utils.py
│   ├── train.py
│   ├── train_renderer.py
│   └── utils
│       ├── __init__.py
│       ├── dataloader.py
│       ├── tensorboard.py
│       └── util.py
└── video
    ├── baseline_target_bird_4648.mp4
    ├── sg_bird_5602.gif
    └── sg_target_bird_4648.mp4
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | #video
3 | .DS_Store
4 | .idea/
5 | baseline_spc
6 | semantic_guidance/test_v2*
7 | semantic_guidance/test_copy*
8 | files
9 | semantic_guidance/pretrained_models
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Semantic-Guidance: Distilling Object Awareness into Paintings [](https://twitter.com/intent/tweet?text=Automatically%20generate%20human-level%20paintings%20using%20a%20combination%20of%20Deep-RL%20and%20Semantic-Guidance&url=https://github.com/1jsingh/semantic-guidance&&hashtags=LearningToPaint,CVPR2021)
2 |
3 | This repository contains code for our CVPR-2021 paper on [Combining Semantic Guidance and Deep Reinforcement Learning For Generating Human Level Paintings](https://arxiv.org/pdf/2011.12589.pdf).
4 |
5 | The Semantic Guidance pipeline distills different forms of object awareness (semantic segmentation, object localization and guided backpropagation maps) into the painting process itself. The resulting agent is able to paint canvases with increased saliency of foreground objects and enhanced granularity of key image features.
6 |
7 |
11 | ## Contents
12 | * [Demo](#demo)
13 | * [Environment Setup](#environment-setup)
14 | * [Dataset and Preprocessing](#dataset-and-preprocessing)
15 | * [Training](#training)
16 | * [Testing using Pretrained Models](#testing-using-pretrained-models)
17 | * [Citation](#citation)
18 |
19 |
20 | ## Demo
21 | Traditional reinforcement learning based methods for the "*learning to paint*" problem show poor performance on real-world datasets with high variance in the position, scale and saliency of foreground objects. To address this, we propose a Semantic Guidance pipeline which distills object-awareness knowledge into the painting process, and thereby learns to generate semantically accurate canvases under adverse painting conditions.
22 |
23 | | Target Image | Baseline (Huang et al. 2019) | Semantic Guidance (Ours) |
24 | |:-------------:|:-------------:|:-------------:|
25 | | ![target bird 4008](assets/target_bird_4008.png) | ![baseline painting](assets/bird_4008.gif) | ![semantic guidance painting](assets/sg_bird_4008.gif) |
26 | | ![target bird 4648](assets/target_bird_4648.png) | ![baseline painting](assets/bird_4648.gif) | ![semantic guidance painting](assets/sg_bird_4648.gif) |
27 | | ![target bird 5602](assets/target_bird_5602.png) | ![baseline painting](assets/bird_5602.gif) | ![semantic guidance painting](assets/sg_bird_5602.gif) |
28 |
29 |
30 |
31 | ### Environment Setup
32 |
33 | * Set up the Python environment for running the experiments.
34 | ```bash
35 | conda env update --name semantic-guidance --file environment.yml
36 | conda activate semantic-guidance
37 | ```
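
* To confirm that the environment resolved correctly, run a minimal sanity check (a sketch; package versions are the ones pinned in `environment.yml`).
```python
# minimal environment sanity check (a sketch; versions follow environment.yml)
import cv2
import torch
import torchvision

print(torch.__version__, torchvision.__version__, cv2.__version__)  # expected: 1.6.0, 0.7.0, 4.4.0
print(torch.cuda.is_available())  # training and testing assume a CUDA-capable GPU
```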
38 |
39 | ### Dataset and Preprocessing
40 | * Download the [CUB-200-2011 Birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset and place it in the `data/cub200/CUB_200_2011/` folder.
41 | ```bash
42 | mkdir -p data/cub200 && cd data/cub200
43 | gdown https://drive.google.com/uc?id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45
44 | tar -xvzf CUB_200_2011.tgz
45 | ```
46 |
47 | * The extracted data folder should look as follows (an optional check follows the listing):
48 | ```bash
49 | data
50 | └── cub200/
51 |     └── CUB_200_2011/
52 |         ├── images/
53 |         │   └── ...
54 |         └── images.txt
55 | ```
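
* Optionally, verify the extracted dataset from the repository root (a sketch; `images.txt` in CUB-200-2011 lists 11,788 images, one `"<id> <relative_path>"` entry per line).
```python
# quick dataset check (run from the repository root)
from pathlib import Path

root = Path("data/cub200/CUB_200_2011")
with open(root / "images.txt") as f:
    images = [line.split()[1] for line in f]   # each line: "<image_id> <relative_image_path>"
print(len(images))                             # expected: 11788
print((root / "images" / images[0]).exists())  # first listed image file is present
```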
56 |
57 | * Download the differentiable neural renderer: [renderer.pkl](https://drive.google.com/file/d/1VloSGAWYRiVYv3bRfBuB0uKj2m7Cyzu8/view?usp=sharing) and place it in the `data/` folder (a quick loading check follows the download command).
58 | ```bash
59 | cd data
60 | gdown https://drive.google.com/uc?id=1VloSGAWYRiVYv3bRfBuB0uKj2m7Cyzu8
61 | ```
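
* A minimal loading check for the renderer (a sketch, run from the `semantic_guidance/` folder; it assumes `renderer.pkl` holds a `state_dict` for the `FCN` network defined in `Renderer/model.py`):
```python
# render a single random stroke with the downloaded neural renderer
import torch
from Renderer.model import FCN

renderer = FCN()
renderer.load_state_dict(torch.load("../data/renderer.pkl", map_location="cpu"))
renderer.eval()

stroke = renderer(torch.rand(1, 10))  # 10 stroke position/shape parameters in [0, 1]
print(stroke.shape)                   # torch.Size([1, 128, 128])
```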
62 |
63 | * Download the combined model for object localization and semantic segmentation from [here](https://drive.google.com/file/d/14CIdpem-85y53KkkW2oBspXh-2PPXtTs/view?usp=sharing), and place it in the `data/` folder.
64 | ```bash
65 | cd data
66 | gdown https://drive.google.com/uc?id=14CIdpem-85y53KkkW2oBspXh-2PPXtTs
67 | ```
68 |
69 | * Choose one of the following options to obtain the preprocessed data predictions (preprocessing helps facilitate faster training):
70 | * **Option 1:** run the preprocessing script to generate object localization, semantic segmentation and bounding box predictions.
71 | ```bash
72 | cd semantic_guidance
73 | python preprocess.py
74 | ```
75 |
76 | * **Option 2:** you can also directly download the preprocessed birds dataset from [here](https://drive.google.com/file/d/1s3lvo0Dn538lPghpXY1gEOTAZOTsojxJ/view?usp=sharing), and place the prediction folders in the original data directory.
77 | ```bash
78 | cd data/cub200/CUB_200_2011/
79 | gdown https://drive.google.com/uc?id=1s3lvo0Dn538lPghpXY1gEOTAZOTsojxJ
80 | unzip preprocessed-cub200-2011.zip
81 | mv preprocessed-cub200-2011/* .
82 | ```
83 |
84 | * The final data directory should look like this (an optional verification sketch follows the listing):
85 | ```bash
86 | data
87 | ├── cub200/
88 | │   └── CUB_200_2011/
89 | │       ├── images/
90 | │       │   └── ...
91 | │       ├── segmentations_pred/
92 | │       │   └── ...
93 | │       ├── gbp_global/
94 | │       │   └── ...
95 | │       ├── bounding_boxes_pred.txt
96 | │       └── images.txt
97 | ├── renderer.pkl
98 | └── birds_obj_seg.pkl
99 | ```
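
* Optionally, confirm that the predictions are in place (a sketch, run from the repository root; folder names follow the listing above):
```python
# check that the preprocessed predictions used by Semantic Guidance are present
from pathlib import Path

root = Path("data/cub200/CUB_200_2011")
for name in ["segmentations_pred", "gbp_global"]:
    n_files = sum(1 for p in (root / name).rglob("*") if p.is_file())
    print(name, n_files)                                # one prediction per dataset image
print((root / "bounding_boxes_pred.txt").exists())      # predicted bounding boxes
```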
100 |
101 | ### Training
102 |
103 | * Train the baseline model from [Huang et al. 2019](https://arxiv.org/abs/1903.04411).
104 | ```bash
105 | cd semantic_guidance
106 | python train.py \
107 | --dataset cub200 \
108 | --debug \
109 | --batch_size=96 \
110 | --max_eps_len=50 \
111 | --bundle_size=5 \
112 | --exp_suffix baseline
113 | ```
114 |
115 | * Train the deep reinforcement learning based painting agent using the Semantic Guidance pipeline (the flag-to-architecture mapping is sketched after the command).
116 | ```bash
117 | cd semantic_guidance
118 | python train.py \
119 | --dataset cub200 \
120 | --debug \
121 | --batch_size=96 \
122 | --max_eps_len=50 \
123 | --bundle_size=5 \
124 | --use_bilevel \
125 | --use_gbp \
126 | --exp_suffix semantic-guidance
127 | ```
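
* For reference, the flags above set the actor's input/output sizes in `DRL/ddpg.py` (a sketch that mirrors that file; nothing here needs to be run):
```python
# flag-to-architecture mapping used in DRL/ddpg.py (semantic-guidance settings shown)
bundle_size, use_gbp, use_bilevel = 5, True, True
in_channels = 9 + use_gbp + use_bilevel + 2 * use_bilevel  # canvas(3) + target(3) + step(1) + coords(2), + gbp, seg, grid
out_dim = 13 * bundle_size * (1 + use_bilevel)             # 13 stroke params per stroke in the action bundle
print(in_channels, out_dim)                                # 13 130
```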
128 |
129 | ### Testing using Pretrained Models
130 |
131 | * Download the pretrained models for the [Baseline](https://drive.google.com/file/d/1OvN7yRia44nhD16KmjcAvxG8xICWl42p/view?usp=sharing) and [Semantic Guidance](https://drive.google.com/file/d/173p2rUQlNpp8fLA3u5s24QKJLU68QTkw/view?usp=sharing) agents. Place the downloaded models in the `./semantic_guidance/pretrained_models` directory.
132 | ```bash
133 | cd semantic_guidance
134 | mkdir pretrained_models && cd pretrained_models
135 | gdown https://drive.google.com/uc?id=1OvN7yRia44nhD16KmjcAvxG8xICWl42p
136 | gdown https://drive.google.com/uc?id=173p2rUQlNpp8fLA3u5s24QKJLU68QTkw
137 | ```
138 |
139 | * The final directory structure should look as follows:
140 | ```bash
141 | semantic-guidance
142 | └── semantic_guidance/
143 |     └── pretrained_models/
144 |         ├── actor_baseline.pkl
145 |         └── actor_semantic_guidance.pkl
146 | ```
147 |
148 | * Generate the painting sequence using the pretrained baseline agent.
149 | ```bash
150 | cd semantic_guidance
151 | python test.py \
152 | --img ../input/target_bird_4648.png \
153 | --actor pretrained_models/actor_baseline.pkl \
154 | --use_baseline
155 | ```
156 |
157 | * Use the pretrained Semantic Guidance agent to paint canvases.
158 | ```bash
159 | cd semantic_guidance
160 | python test.py \
161 | --img ../input/target_bird_4648.png \
162 | --actor pretrained_models/actor_semantic_guidance.pkl
163 | ```
164 |
165 | * The test script (`test.py`) stores the final canvas in the `./output` folder and saves a video of the painting sequence in the `./video` directory; a small sketch for comparing the baseline and Semantic Guidance outputs follows.
166 |
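* To compare the two results side by side, a minimal sketch using the output naming above (run from the repository root; assumes both test commands were run on `target_bird_4648.png`):
```python
# stitch the baseline and Semantic Guidance canvases into a single comparison image
import cv2

baseline = cv2.imread("output/baseline_painted_target_bird_4648.png")
ours = cv2.imread("output/sg_painted_target_bird_4648.png")
cv2.imwrite("output/comparison_target_bird_4648.png", cv2.hconcat([baseline, ours]))
```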
167 |
168 | ## Citation
169 |
170 | If you find this work useful in your research, please cite our paper:
171 | ```
172 | @inproceedings{singh2021combining,
173 | title={Combining Semantic Guidance and Deep Reinforcement Learning For Generating Human Level Paintings},
174 | author={Singh, Jaskirat and Zheng, Liang},
175 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
176 | pages={16387--16396},
177 | year={2021}
178 | }
179 | ```
180 |
181 |
184 |
--------------------------------------------------------------------------------
/assets/bird_4008.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/bird_4008.gif
--------------------------------------------------------------------------------
/assets/bird_4648.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/bird_4648.gif
--------------------------------------------------------------------------------
/assets/bird_5602.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/bird_5602.gif
--------------------------------------------------------------------------------
/assets/sg_bird_4008.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/sg_bird_4008.gif
--------------------------------------------------------------------------------
/assets/sg_bird_4648.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/sg_bird_4648.gif
--------------------------------------------------------------------------------
/assets/sg_bird_5602.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/sg_bird_5602.gif
--------------------------------------------------------------------------------
/assets/target_bird_4008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/target_bird_4008.png
--------------------------------------------------------------------------------
/assets/target_bird_4648.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/target_bird_4648.png
--------------------------------------------------------------------------------
/assets/target_bird_5602.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/target_bird_5602.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - conda-forge
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - bzip2=1.0.8=h7b6447c_0
7 | - ca-certificates=2020.7.22=0
8 | - certifi=2020.6.20=py36_0
9 | - freetype=2.10.2=h5ab3b9f_0
10 | - gdown=3.13.0=pyhd8ed1ab_0
11 | - gmp=6.1.2=h6c8ec71_1
12 | - gnutls=3.6.5=h71b1129_1002
13 | - htop=2.2.0=hf8c457e_1000
14 | - lame=3.100=h7b6447c_0
15 | - ld_impl_linux-64=2.33.1=h53a641e_7
16 | - libedit=3.1.20191231=h14c3975_1
17 | - libffi=3.3=he6710b0_2
18 | - libgcc-ng=9.1.0=hdf63c60_0
19 | - libiconv=1.16=h516909a_0
20 | - libopus=1.3.1=h7b6447c_0
21 | - libpng=1.6.37=hbc83047_0
22 | - libstdcxx-ng=9.1.0=hdf63c60_0
23 | - libvpx=1.7.0=h439df22_0
24 | - nano=2.9.8=hddfc1eb_1001
25 | - ncurses=6.2=he6710b0_1
26 | - nettle=3.4.1=hbb512f6_0
27 | - openh264=2.1.0=hd408876_0
28 | - openssl=1.1.1g=h7b6447c_0
29 | - pip=20.2.2=py36_0
30 | - python=3.6.10=h7579374_2
31 | - python_abi=3.6=1_cp36m
32 | - readline=8.0=h7b6447c_0
33 | - setuptools=49.6.0=py36_0
34 | - sqlite=3.32.3=h62c20be_0
35 | - tk=8.6.10=hbc83047_0
36 | - unzip=6.0=h611a1e1_0
37 | - wheel=0.34.2=py36_0
38 | - x264=1!157.20191217=h7b6447c_0
39 | - xz=5.2.5=h7b6447c_0
40 | - zlib=1.2.11=h7b6447c_3
41 | - pip:
42 | - absl-py==0.10.0
43 | - argon2-cffi==20.1.0
44 | - attrs==20.1.0
45 | - backcall==0.2.0
46 | - bleach==3.1.5
47 | - cachetools==4.1.1
48 | - cffi==1.14.2
49 | - chardet==3.0.4
50 | - cycler==0.10.0
51 | - cython==0.29.21
52 | - decorator==4.4.2
53 | - defusedxml==0.6.0
54 | - efficientnet-pytorch==0.7.0
55 | - entrypoints==0.3
56 | - ffmpeg==1.4
57 | - future==0.18.2
58 | - google-auth==1.21.1
59 | - google-auth-oauthlib==0.4.1
60 | - grpcio==1.31.0
61 | - h5py==2.10.0
62 | - idna==2.10
63 | - importlib-metadata==1.7.0
64 | - ipykernel==5.3.4
65 | - ipython==7.16.1
66 | - ipython-genutils==0.2.0
67 | - ipywidgets==7.5.1
68 | - jedi==0.17.2
69 | - jinja2==2.11.2
70 | - joblib==0.16.0
71 | - jsonschema==3.2.0
72 | - jupyter==1.0.0
73 | - jupyter-client==6.1.6
74 | - jupyter-console==6.1.0
75 | - jupyter-core==4.6.3
76 | - kiwisolver==1.2.0
77 | - markdown==3.2.2
78 | - markupsafe==1.1.1
79 | - matplotlib==3.3.1
80 | - mistune==0.8.4
81 | - nbconvert==5.6.1
82 | - nbformat==5.0.7
83 | - notebook==6.1.3
84 | - numpy==1.19.1
85 | - oauthlib==3.1.0
86 | - opencv-python==4.4.0.42
87 | - packaging==20.4
88 | - pandas==1.1.1
89 | - pandocfilters==1.4.2
90 | - parso==0.7.1
91 | - pexpect==4.8.0
92 | - pickleshare==0.7.5
93 | - pillow==7.2.0
94 | - prometheus-client==0.8.0
95 | - prompt-toolkit==3.0.6
96 | - protobuf==3.13.0
97 | - ptyprocess==0.6.0
98 | - pyasn1==0.4.8
99 | - pyasn1-modules==0.2.8
100 | - pycparser==2.20
101 | - pygments==2.6.1
102 | - pyparsing==2.4.7
103 | - pyrsistent==0.16.0
104 | - python-dateutil==2.8.1
105 | - pytz==2020.1
106 | - pyzmq==19.0.2
107 | - qtconsole==4.7.6
108 | - qtpy==1.9.0
109 | - requests==2.24.0
110 | - requests-oauthlib==1.3.0
111 | - rsa==4.6
112 | - scikit-learn==0.23.2
113 | - scipy==1.2.0
114 | - seaborn==0.10.1
115 | - send2trash==1.5.0
116 | - six==1.15.0
117 | - tensorboard==2.3.0
118 | - tensorboard-plugin-wit==1.7.0
119 | - tensorboardx==2.1
120 | - termcolor==1.1.0
121 | - terminado==0.8.3
122 | - testpath==0.4.4
123 | - threadpoolctl==2.1.0
124 | - torch==1.6.0
125 | - torchgeometry==0.1.2
126 | - torchvision==0.7.0
127 | - tornado==6.0.4
128 | - traitlets==4.3.3
129 | - urllib3==1.25.10
130 | - wcwidth==0.2.5
131 | - webencodings==0.5.1
132 | - werkzeug==1.0.1
133 | - widgetsnbextension==3.5.1
134 | - zipp==3.1.0
135 |
136 |
--------------------------------------------------------------------------------
/input/target_bird_4008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/input/target_bird_4008.png
--------------------------------------------------------------------------------
/input/target_bird_4648.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/input/target_bird_4648.png
--------------------------------------------------------------------------------
/input/target_bird_5602.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/input/target_bird_5602.png
--------------------------------------------------------------------------------
/output/baseline_painted_target_bird_4008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/output/baseline_painted_target_bird_4008.png
--------------------------------------------------------------------------------
/output/baseline_painted_target_bird_4648.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/output/baseline_painted_target_bird_4648.png
--------------------------------------------------------------------------------
/output/sg_painted_target_bird_4008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/output/sg_painted_target_bird_4008.png
--------------------------------------------------------------------------------
/output/sg_painted_target_bird_4648.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/output/sg_painted_target_bird_4648.png
--------------------------------------------------------------------------------
/semantic_guidance/DRL/actor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.nn.utils.weight_norm as weightNorm
7 |
8 | from torch.autograd import Variable
9 | import sys
10 |
11 |
12 | def conv3x3(in_planes, out_planes, stride=1):
13 | return (nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False))
14 |
15 |
16 | def cfg(depth):
17 | depth_lst = [18, 34, 50, 101, 152]
18 | assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152"
19 | cf_dict = {
20 | '18': (BasicBlock, [2, 2, 2, 2]),
21 | '34': (BasicBlock, [3, 4, 6, 3]),
22 | '50': (Bottleneck, [3, 4, 6, 3]),
23 | '101': (Bottleneck, [3, 4, 23, 3]),
24 | '152': (Bottleneck, [3, 8, 36, 3]),
25 | }
26 |
27 | return cf_dict[str(depth)]
28 |
29 |
30 | class BasicBlock(nn.Module):
31 | expansion = 1
32 |
33 | def __init__(self, in_planes, planes, stride=1):
34 | super(BasicBlock, self).__init__()
35 | self.conv1 = conv3x3(in_planes, planes, stride)
36 | self.bn1 = nn.BatchNorm2d(planes)
37 | self.conv2 = conv3x3(planes, planes)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 |
40 | self.shortcut = nn.Sequential()
41 | if stride != 1 or in_planes != self.expansion * planes:
42 | self.shortcut = nn.Sequential(
43 | (nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)),
44 | nn.BatchNorm2d(self.expansion * planes)
45 | )
46 |
47 | def forward(self, x):
48 | out = F.relu(self.bn1(self.conv1(x)))
49 | out = self.bn2(self.conv2(out))
50 | out += self.shortcut(x)
51 | out = F.relu(out)
52 |
53 | return out
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 4
58 |
59 | def __init__(self, in_planes, planes, stride=1):
60 | super(Bottleneck, self).__init__()
61 | self.conv1 = (nn.Conv2d(in_planes, planes, kernel_size=1, bias=False))
62 | self.conv2 = (nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False))
63 | self.conv3 = (nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False))
64 | self.bn1 = nn.BatchNorm2d(planes)
65 | self.bn2 = nn.BatchNorm2d(planes)
66 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
67 |
68 | self.shortcut = nn.Sequential()
69 | if stride != 1 or in_planes != self.expansion * planes:
70 | self.shortcut = nn.Sequential(
71 | (nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)),
72 | )
73 |
74 | def forward(self, x):
75 | out = F.relu(self.bn1(self.conv1(x)))
76 | out = F.relu(self.bn2(self.conv2(out)))
77 | out = self.bn3(self.conv3(out))
78 | out += self.shortcut(x)
79 | out = F.relu(out)
80 |
81 | return out
82 |
83 |
84 | class ResNet(nn.Module):
85 | def __init__(self, num_inputs, depth, num_outputs, high_res=False):
86 | super(ResNet, self).__init__()
87 | self.in_planes = 64
88 |
89 | block, num_blocks = cfg(depth)
90 |
91 | self.conv1 = conv3x3(num_inputs, 64, 2)
92 | self.bn1 = nn.BatchNorm2d(64)
93 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2)
94 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
95 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
96 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
97 | self.fc = nn.Linear(512, num_outputs)
98 | self.high_res = high_res
99 |
100 | def _make_layer(self, block, planes, num_blocks, stride):
101 | strides = [stride] + [1] * (num_blocks - 1)
102 | layers = []
103 |
104 | for stride in strides:
105 | layers.append(block(self.in_planes, planes, stride))
106 | self.in_planes = planes * block.expansion
107 |
108 | return nn.Sequential(*layers)
109 |
110 | def forward(self, x):
111 | x = F.relu(self.bn1(self.conv1(x)))
112 | x = self.layer1(x)
113 | x = self.layer2(x)
114 | x = self.layer3(x)
115 | x = self.layer4(x)
116 | if self.high_res is True:
117 | x = F.avg_pool2d(x, 8)
118 | else:
119 | x = F.avg_pool2d(x, 4)
120 | x = x.view(x.size(0), -1)
121 | x = self.fc(x)
122 | x = torch.sigmoid(x)
123 | return x
124 |
--------------------------------------------------------------------------------
/semantic_guidance/DRL/critic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.utils.weight_norm as weightNorm
5 |
6 | from torch.autograd import Variable
7 | import sys
8 |
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | return weightNorm(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True))
12 |
13 |
14 | class TReLU(nn.Module):
15 | def __init__(self):
16 | super(TReLU, self).__init__()
17 | self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
18 | self.alpha.data.fill_(0)
19 |
20 | def forward(self, x):
21 | x = F.relu(x - self.alpha) + self.alpha
22 | return x
23 |
24 |
25 | def cfg(depth):
26 | depth_lst = [18, 34, 50, 101, 152]
27 | assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152"
28 | cf_dict = {
29 | '18': (BasicBlock, [2, 2, 2, 2]),
30 | '34': (BasicBlock, [3, 4, 6, 3]),
31 | '50': (Bottleneck, [3, 4, 6, 3]),
32 | '101': (Bottleneck, [3, 4, 23, 3]),
33 | '152': (Bottleneck, [3, 8, 36, 3]),
34 | }
35 |
36 | return cf_dict[str(depth)]
37 |
38 |
39 | class BasicBlock(nn.Module):
40 | expansion = 1
41 |
42 | def __init__(self, in_planes, planes, stride=1):
43 | super(BasicBlock, self).__init__()
44 | self.conv1 = conv3x3(in_planes, planes, stride)
45 | self.conv2 = conv3x3(planes, planes)
46 |
47 | self.shortcut = nn.Sequential()
48 | if stride != 1 or in_planes != self.expansion * planes:
49 | self.shortcut = nn.Sequential(
50 | weightNorm(nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=True)),
51 | )
52 | self.relu_1 = TReLU()
53 | self.relu_2 = TReLU()
54 |
55 | def forward(self, x):
56 | out = self.relu_1(self.conv1(x))
57 | out = self.conv2(out)
58 | out += self.shortcut(x)
59 | out = self.relu_2(out)
60 |
61 | return out
62 |
63 |
64 | class Bottleneck(nn.Module):
65 | expansion = 4
66 |
67 | def __init__(self, in_planes, planes, stride=1):
68 | super(Bottleneck, self).__init__()
69 | self.conv1 = weightNorm(nn.Conv2d(in_planes, planes, kernel_size=1, bias=True))
70 | self.conv2 = weightNorm(nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True))
71 | self.conv3 = weightNorm(nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=True))
72 | self.relu_1 = TReLU()
73 | self.relu_2 = TReLU()
74 | self.relu_3 = TReLU()
75 |
76 | self.shortcut = nn.Sequential()
77 | if stride != 1 or in_planes != self.expansion * planes:
78 | self.shortcut = nn.Sequential(
79 | weightNorm(nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=True)),
80 | )
81 |
82 | def forward(self, x):
83 | out = self.relu_1(self.conv1(x))
84 | out = self.relu_2(self.conv2(out))
85 | out = self.conv3(out)
86 | out += self.shortcut(x)
87 | out = self.relu_3(out)
88 |
89 | return out
90 |
91 |
92 | class ResNet_wobn(nn.Module):
93 | def __init__(self, num_inputs, depth, num_outputs, high_res=False):
94 | super(ResNet_wobn, self).__init__()
95 | self.in_planes = 64
96 |
97 | block, num_blocks = cfg(depth)
98 |
99 | self.conv1 = conv3x3(num_inputs, 64, 2)
100 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2)
101 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
102 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
103 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
104 | self.fc = nn.Linear(512, num_outputs)
105 | self.relu_1 = TReLU()
106 | self.high_res = high_res
107 |
108 | def _make_layer(self, block, planes, num_blocks, stride):
109 | strides = [stride] + [1] * (num_blocks - 1)
110 | layers = []
111 |
112 | for stride in strides:
113 | layers.append(block(self.in_planes, planes, stride))
114 | self.in_planes = planes * block.expansion
115 |
116 | return nn.Sequential(*layers)
117 |
118 | def forward(self, x):
119 | x = self.relu_1(self.conv1(x))
120 | x = self.layer1(x)
121 | x = self.layer2(x)
122 | x = self.layer3(x)
123 | x = self.layer4(x)
124 | if self.high_res is True:
125 | x = F.avg_pool2d(x, 8)
126 | else:
127 | x = F.avg_pool2d(x, 4)
128 | x = x.view(x.size(0), -1)
129 | x = self.fc(x)
130 | return x
131 |
--------------------------------------------------------------------------------
/semantic_guidance/DRL/ddpg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.optim import Adam, SGD
6 | from torch.distributions import Categorical
7 | import Renderer.model as renderer
8 | from DRL.rpm import rpm
9 | from DRL.actor import ResNet
10 | from DRL.critic import ResNet_wobn
11 | from DRL.wgan import WGAN
12 | import utils.util as util
13 | import torch.distributed as dist
14 | from torch.nn.parallel import DistributedDataParallel as DDP
15 | from DRL.multi import fastenv
16 |
17 | # setup device
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 | criterion = nn.MSELoss()
20 |
21 |
22 | def cal_trans(s, t):
23 | return (s.transpose(0, 3) * t).transpose(0, 3)
24 |
25 |
26 | class DDPG(object):
27 | """
28 | Model-based DDPG agent class
29 | """
30 |
31 | def __init__(self, batch_size=64, nenv=1, max_eps_len=40, tau=0.001, discount=0.9, rmsize=800,
32 | writer=None, load_path=None, output_path=None, dataset='celeba', use_gbp=False, use_bilevel=False,
33 | gbp_coef=1.0, seggt_coef=1.0, gpu=0, distributed=False, high_res=False,
34 | bundle_size=5):
35 | # hyperparameters
36 | self.max_eps_len = max_eps_len
37 | self.nenv = nenv
38 | self.batch_size = batch_size
39 | self.gpu = gpu
40 | self.distributed = distributed
41 | self.bundle_size = bundle_size
42 |
43 | # set torch device
44 | torch.cuda.set_device(gpu)
45 |
46 | # gbp and seggt rewards
47 | self.use_gbp = use_gbp
48 | self.use_bilevel = use_bilevel
49 | self.gbp_coef = gbp_coef
50 | self.seggt_coef = seggt_coef
51 |
52 | # Multi-res and high-res
53 | self.high_res = high_res
54 |
55 | # environment
56 | self.env = fastenv(max_eps_len, nenv, writer, dataset, use_gbp, use_bilevel, gpu=gpu, high_res=self.high_res,
57 | bundle_size=bundle_size)
58 |
59 | # setup local and target actor, critic networks
60 |         # input channels: canvas (3) + target (3) + stepnum (1) + coordconv (2) = 9, plus optional gbp (1), seg (1) and grid (2)
61 |         # output: 13 stroke params (10 position + 3 color) per stroke x bundle_size, doubled when use_bilevel is set
62 | self.actor = ResNet(9 + use_gbp + use_bilevel + 2*use_bilevel, 18, 13 * bundle_size * (1 + use_bilevel), self.high_res)
63 | self.actor_target = ResNet(9 + use_gbp + use_bilevel + 2*use_bilevel, 18, 13 * bundle_size * (1 + use_bilevel),
64 | self.high_res)
65 | self.critic = ResNet_wobn(3 + 9 + use_gbp + use_bilevel + 2*use_bilevel, 18, 1,
66 | self.high_res) # add the last canvas for better prediction
67 | self.critic_target = ResNet_wobn(3 + 9 + use_gbp + use_bilevel + 2*use_bilevel, 18, 1, self.high_res)
68 |
69 | for param in self.actor_target.parameters():
70 | param.requires_grad = False
71 |
72 | for param in self.critic_target.parameters():
73 | param.requires_grad = False
74 |
75 | # define gan
76 | self.wgan = WGAN(gpu, distributed, high_res=self.high_res)
77 | self.wgan_bg = WGAN(gpu, distributed, high_res=self.high_res)
78 |
79 | # transfer models to gpu
80 | self.choose_device()
81 |
82 | # optimizers for actor/critic models
83 | self.actor_optim = Adam(self.actor.parameters(), lr=1e-2)
84 | self.critic_optim = Adam(self.critic.parameters(), lr=1e-2)
85 |
86 | # load actor/critic/gan models if given a load path
87 | if (load_path != None):
88 | self.load_weights(load_path)
89 |
90 | # set same initial weights for local and target networks
91 | util.hard_update(self.actor_target, self.actor)
92 | util.hard_update(self.critic_target, self.critic)
93 |
94 |         # create replay buffer (stores up to rmsize * max_eps_len past transitions)
95 | self.memory = rpm(rmsize * max_eps_len, gpu)
96 |
97 | # training hyper-parameters
98 | self.tau = tau
99 | self.discount = discount
100 |
101 | # initialize summary logs
102 | self.writer = writer
103 | self.log = 0
104 |
105 | # initialize state, action logs
106 | self.state = [None] * self.nenv # Most recent state
107 | self.action = [None] * self.nenv # Most recent action
108 |
109 | self.get_coord_feats()
110 |
111 | def get_coord_feats(self):
112 | # x,y coordinates for 128 x 128 or 256x256 image
113 | if self.high_res is True:
114 | coord = torch.zeros([1, 2, 256, 256])
115 | for i in range(256):
116 | for j in range(256):
117 | coord[0, 0, i, j] = i / 255.
118 | coord[0, 1, i, j] = j / 255.
119 | else:
120 | coord = torch.zeros([1, 2, 128, 128])
121 | for i in range(128):
122 | for j in range(128):
123 | coord[0, 0, i, j] = i / 127.
124 | coord[0, 1, i, j] = j / 127.
125 |
126 | self.coord = coord.cuda(self.gpu)
127 |
128 | def act(self, state, target=False):
129 | """
130 | take action for current state
131 | :param state: merged state (canvas,gt,t/max_eps_len,coordinates)
132 | :param target: bool for whether the target actor model should be used.
133 | :return: action (stroke parameters)
134 | """
135 | if self.high_res is True:
136 | state_list = [state[:, :6].float() / 255, state[:, 6:7].float() / self.max_eps_len,
137 | self.coord.expand(state.shape[0], 2, 256, 256)]
138 | else:
139 | state_list = [state[:, :6].float() / 255, state[:, 6:7].float() / self.max_eps_len,
140 | self.coord.expand(state.shape[0], 2, 128, 128)]
141 |
142 | if self.use_gbp:
143 | state_list += [state[:, 7:8].float() / 255.]
144 | if self.use_bilevel:
145 | state_list += [state[:, 7 + self.use_gbp: 7 + self.use_gbp + 1 + 2].float() / 255.]
146 | # state_list += [state[:, 7 + self.use_gbp + 1: 7 + self.use_gbp + 1].float() / 255.]
147 | # define merged state (canvas,gt,t/max_eps_len,coordinates)
148 | state = torch.cat(state_list, 1)
149 |
150 | if target:
151 | return self.actor_target(state)
152 | else:
153 | return self.actor(state)
154 |
155 | def update_gan(self, state):
156 | """
157 | update WGAN based on current state (first 3 channels for canvas and last three for groundtruth)
158 | :param state:
159 | :return: None
160 | """
161 | # get canvas and groundtruth from the state
162 | canvas = state[:, :3].float() / 255
163 | gt = state[:, 3: 6].float() / 255
164 |
165 | # update gan based on canvas and background groundtruth images
166 | fake, real, penal = self.wgan_bg.update(canvas, gt)
167 |
168 | if self.use_bilevel:
169 | seggt = state[:, 7 + self.use_gbp: 7 + self.use_gbp + 1].float() / 255.
170 | grid = state[:, 7 + self.use_gbp + 1: 7 + self.use_gbp + 1 + 2].float() / 255.
171 | grid = 2 * grid - 1
172 | canvas = torch.nn.functional.grid_sample(canvas * seggt, grid.permute(0, 2, 3, 1))
173 | gt = torch.nn.functional.grid_sample(gt * seggt, grid.permute(0, 2, 3, 1))
174 |
175 | # update gan based on canvas and groundtruth images
176 | fake, real, penal = self.wgan.update(canvas, gt)
177 |
178 |
179 | def evaluate(self, state, action, target=False):
180 | """
181 | compute model performance (rewards) for given state, action (used for both training and testing
182 | based on whether target or local network is used)
183 | :param state: combined state (canvas,ground-truth)
184 | :param action: stroke parameters (10 for position and 3 for color)
185 | :param target: bool for whether the target critic model should be used.
186 | :return: critic value + gan reward (used for training actor in model-based DDPG), gan reward
187 | """
188 |
189 | # get canvas, ground-truth, time from merged state (gt,canvas,t)
190 | gt = state[:, 3: 6].float() / 255
191 | canvas0 = state[:, :3].float() / 255
192 | T = state[:, 6: 7]
193 |
194 | if self.use_gbp:
195 | gbpgt = state[:, 7:8].float() / 255.
196 | if self.use_bilevel:
197 | seggt = state[:, 7 + self.use_gbp: 7 + self.use_gbp + 1].float() / 255.
198 | grid = state[:, 7 + self.use_gbp + 1: 7 + self.use_gbp + 1 + 2].float() / 255.
199 | grid_ = 2 * grid - 1
200 |
201 | # update canvas given current action
202 | if self.use_bilevel:
203 | canvas1, stroke_masks = self.env.env.decode_parallel(action, canvas0,
204 | mask=self.use_gbp or self.use_bilevel,
205 | seg_mask=seggt)
206 | else:
207 | canvas1, stroke_masks = self.env.env.decode(action, canvas0, mask=self.use_gbp or self.use_bilevel)
208 |
209 | # if self.use_bilevel:
210 | # # canvas1 = canvas1 * seggt
211 | # # canvas0 = canvas0 * seggt
212 |
213 | # compute bg gan reward based on difference between wgan distances (L_t - L_t-1)
214 | bg_reward = self.wgan_bg.cal_reward(canvas1, gt) - self.wgan_bg.cal_reward(canvas0, gt)
215 | bg_reward = bg_reward.view(-1)
216 | # gan_reward = ((canvas0 - gt) ** 2).mean(1).mean(1).mean(1) - ((canvas1 - gt) ** 2).mean(1).mean(1).mean(1)
217 |
218 | if self.use_gbp:
219 | gbp_reward = (((canvas0 - gt) * gbpgt) ** 2).mean(1).sum(1).sum(1) \
220 | - (((canvas1 - gt) * gbpgt) ** 2).mean(1).sum(1).sum(1)
221 |
222 | gbp_reward = gbp_reward / torch.sum(gbpgt, dim=(1, 2, 3))
223 | gbp_reward = 1e3 * gbp_reward
224 | else:
225 | gbp_reward = torch.tensor(0.)
226 |
227 | if self.use_bilevel:
228 | canvas1_ = torch.nn.functional.grid_sample(canvas1 * seggt, grid_.permute(0, 2, 3, 1))
229 | canvas0_ = torch.nn.functional.grid_sample(canvas0 * seggt, grid_.permute(0, 2, 3, 1))
230 | # gt_, canvas0_, canvas1_ = self.nalignment(gt,canvas0_,canvas1_)
231 | gt_ = torch.nn.functional.grid_sample(gt * seggt, grid_.permute(0, 2, 3, 1))
232 | # compute foreground reward
233 | foreground_reward = self.wgan.cal_reward(canvas1_, gt_) - self.wgan.cal_reward(canvas0_, gt_)
234 | # foreground_reward = ((canvas0_ - gt_) ** 2).mean(1).mean(1).mean(1) - ((canvas1_ - gt_) ** 2).mean(1).mean(1).mean(1)
235 | foreground_reward = 2e0 * foreground_reward.view(-1)
236 | else:
237 | foreground_reward = torch.tensor(0.)
238 |
239 | # total reward
240 | total_reward = bg_reward + gbp_reward + foreground_reward
241 |
242 | # get new merged state
243 | if self.high_res is True:
244 | coord_ = self.coord.expand(state.shape[0], 2, 256, 256)
245 | else:
246 | coord_ = self.coord.expand(state.shape[0], 2, 128, 128)
247 | state_list = [canvas0, canvas1, gt, (T + 1).float() / self.max_eps_len, coord_]
248 |
249 | if self.use_gbp:
250 | state_list += [state[:, 7:8].float() / 255.]
251 | if self.use_bilevel:
252 | state_list += [seggt]
253 | state_list += [grid]
254 |
255 | # compute merged state
256 | merged_state = torch.cat(state_list, 1)
257 | # canvas0 is not necessarily added
258 |
259 | if target:
260 | # compute Q from target network
261 | Q = self.critic_target(merged_state)
262 | else:
263 | # compute Q from local network
264 | Q = self.critic(merged_state)
265 | return (Q + total_reward), total_reward
266 |
267 | def update_policy(self, lr):
268 | """
269 | update actor, critic using current replay memory buffer and given learning rate
270 | :param lr: learning rate
271 | :return: negative policy loss (current expected reward), value loss
272 | """
273 | self.log += 1
274 |
275 | # set different learning rate for actor and critic
276 | for param_group in self.critic_optim.param_groups:
277 | param_group['lr'] = lr[0]
278 | for param_group in self.actor_optim.param_groups:
279 | param_group['lr'] = lr[1]
280 |
281 | # sample a batch from the replay buffer
282 | state, action, reward, next_state, terminal = self.memory.sample_batch(self.batch_size, device)
283 |
284 | # update gan model
285 | self.update_gan(next_state)
286 |
287 | # Q-learning: Q(s,a) = r(s,a) + gamma * Q(s',a')
288 | with torch.no_grad():
289 | next_action = self.act(next_state, True)
290 | target_q, _ = self.evaluate(next_state, next_action, True)
291 | target_q = self.discount * ((1 - terminal.float()).view(-1, 1)) * target_q
292 |
293 | # add r(s,a) to Q(s,a)
294 | cur_q, step_reward = self.evaluate(state, action)
295 | target_q += step_reward.detach()
296 |
297 | # critic loss and update
298 | value_loss = criterion(cur_q, target_q)
299 | self.critic.zero_grad()
300 | value_loss.backward(retain_graph=True)
301 | self.critic_optim.step()
302 |
303 | # actor loss and update
304 | action = self.act(state)
305 | pre_q, _ = self.evaluate(state.detach(), action)
306 | policy_loss = -pre_q.mean()
307 | self.actor.zero_grad()
308 | policy_loss.backward(retain_graph=True)
309 | self.actor_optim.step()
310 |
311 | # Soft-update target networks for both actor and critic
312 | util.soft_update(self.actor_target, self.actor, self.tau)
313 | util.soft_update(self.critic_target, self.critic, self.tau)
314 |
315 | return -policy_loss, value_loss
316 |
317 | def observe(self, reward, state, done, step):
318 | """
319 | Store observed sample in replay buffer
320 | :param reward:
321 | :param state:
322 | :param done:
323 | :param step: step count within an episode
324 | :return: None
325 | """
326 | s0 = self.state.clone().detach().cpu() # torch.tensor(self.state, device='cpu')
327 | a = util.to_tensor(self.action, "cpu")
328 | r = util.to_tensor(reward, "cpu")
329 | s1 = state.clone().detach().cpu() # torch.tensor(state, device='cpu')
330 | d = util.to_tensor(done.astype('float32'), "cpu")
331 | for i in range(self.nenv):
332 | self.memory.append([s0[i], a[i], r[i], s1[i], d[i]])
333 | self.state = state
334 |
335 | def noise_action(self, noise_factor, state, action):
336 | """
337 |         Add Gaussian noise to continuous actions (stroke params) with zero mean and standard deviation self.noise_level[i]
338 | :param noise_factor:
339 | :param state:
340 | :param action:
341 | :return: action (stroke params) clipped between 0,1
342 | """
343 | noise = np.zeros(action.shape)
344 | for i in range(self.nenv):
345 | action[i] = action[i] + np.random.normal(0, self.noise_level[i], action.shape[1:]).astype('float32')
346 | return np.clip(action.astype('float32'), 0, 1)
347 |
348 | def select_action(self, state, return_fix=False, noise_factor=0):
349 | """
350 | compute action given a state and noise_factor
351 | :param state:
352 | :param return_fix:
353 | :param noise_factor:
354 | :return:
355 | """
356 | self.eval()
357 | # compute action
358 | with torch.no_grad():
359 | action = self.act(state)
360 | action = util.to_numpy(action)
361 | # add noise to action
362 | if noise_factor > 0:
363 | action = self.noise_action(noise_factor, state, action)
364 | self.train()
365 |
366 | self.action = action
367 | if return_fix:
368 | return action
369 | return self.action
370 |
371 | def reset(self, obs, factor):
372 | self.state = obs
373 | self.noise_level = np.random.uniform(0, factor, self.nenv)
374 |
375 | def decode(self, x, canvas, mask=False): # b * (10 + 3)
376 | """
377 | Update canvas given stroke parameters x
378 | :param x: stroke parameters
379 | :param canvas: current canvas state
380 | :return: updated canvas with stroke drawn
381 | """
382 | # 13 stroke parameters (10 position and 3 RGB color)
383 | x = x.view(-1, 10 + 3)
384 |
385 | # get stroke on an empty canvas given 10 positional parameters
386 | stroke = 1 - self.decoder(x[:, :10])
387 | stroke = stroke.view(-1, 128, 128, 1)
388 |
389 | # add color to the stroke
390 | color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
391 | stroke = stroke.permute(0, 3, 1, 2)
392 | color_stroke = color_stroke.permute(0, 3, 1, 2)
393 |
394 | # draw bundle_size=5 strokes at a time (action bundle)
395 | stroke = stroke.view(-1, self.bundle_size, 1, 128, 128)
396 | color_stroke = color_stroke.view(-1, self.bundle_size, 3, 128, 128)
397 | for i in range(self.bundle_size):
398 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
399 |
400 | # also return stroke mask if required
401 | stroke_mask = None
402 | if mask:
403 | stroke_mask = (stroke != 0).float() # -1, bundle_size, 1, width, width
404 |
405 | return canvas, stroke_mask
406 |
407 | def load_weights(self, path, map_location=None, num_episodes=0):
408 | """
409 | load actor,critic,gan from given paths
410 | :param path:
411 | :return:
412 | """
413 | if map_location is None:
414 | map_location = {'cuda:0': 'cuda:{}'.format(self.gpu)}
415 |
416 | if path is None: return
417 |
418 | self.actor.load_state_dict(
419 | torch.load('{}/actor_{:05}.pkl'.format(path, num_episodes), map_location=map_location))
420 | self.critic.load_state_dict(
421 | torch.load('{}/critic_{:05}.pkl'.format(path, num_episodes), map_location=map_location))
422 | self.wgan.load_gan(path, map_location, num_episodes)
423 | self.wgan_bg.load_gan(path, map_location, num_episodes + 1)
424 |
425 | def save_model(self, path, num_episodes):
426 | """
427 | save trained actor,critic,gan models
428 | :param path: save parent dir
429 | :return: None
430 | """
431 | if self.gpu == 0:
432 | torch.save(self.actor.state_dict(), "{}/actor_{:05}.pkl".format(path, num_episodes))
433 | torch.save(self.critic.state_dict(), '{}/critic_{:05}.pkl'.format(path, num_episodes))
434 | self.wgan.save_gan(path, num_episodes)
435 | self.wgan_bg.save_gan(path, num_episodes + 1)
436 |
437 | # Use a barrier() to make sure that process 1 loads the model after process
438 | # 0 saves it.
439 | dist.barrier()
440 | # configure map_location properly
441 | device_pairs = [(0, self.gpu)]
442 | map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
443 | self.load_weights(path, map_location, num_episodes)
444 | dist.barrier()
445 | print("done saving on cuda:{}".format(self.gpu))
446 |
447 | def eval(self):
448 | """
449 | set actor, critic in eval mode
450 | :return: None
451 | """
452 | self.actor.eval()
453 | self.actor_target.eval()
454 | self.critic.eval()
455 | self.critic_target.eval()
456 |
457 | def train(self):
458 | """
459 | set actor, critic in train mode
460 | :return: None
461 | """
462 | self.actor.train()
463 | self.actor_target.train()
464 | self.critic.train()
465 | self.critic_target.train()
466 |
467 | def choose_device(self):
468 | """
469 | transfer renderer, actor, critic to device
470 | :return: None
471 | """
472 | self.actor.cuda(self.gpu)
473 | self.actor_target.cuda(self.gpu)
474 | self.critic.cuda(self.gpu)
475 | self.critic_target.cuda(self.gpu)
476 |
477 | if self.distributed:
478 | self.actor = DDP(self.actor, device_ids=[self.gpu])
479 | self.critic = DDP(self.critic, device_ids=[self.gpu])
480 |
--------------------------------------------------------------------------------
/semantic_guidance/DRL/evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils.util import *
3 |
4 | class Evaluator(object):
5 |
6 | def __init__(self, args, writer):
7 | self.val_num_eps = args.val_num_eps
8 | self.max_eps_len = args.max_eps_len
9 | self.nenv = args.nenv
10 | self.writer = writer
11 | self.log = 0
12 |
13 | def __call__(self, env, policy, debug=False):
14 | observation = None
15 | for episode in range(self.val_num_eps):
16 | # reset at the start of episode
17 | observation = env.reset(test=True, episode=episode)
18 | episode_steps = 0
19 | episode_reward = 0.
20 | assert observation is not None
21 | # start episode
22 | episode_reward = np.zeros(self.nenv)
23 | while (episode_steps < self.max_eps_len or not self.max_eps_len):
24 | action = policy(observation)
25 | observation, reward, done, (step_num) = env.step(action)
26 | episode_reward += reward
27 | episode_steps += 1
28 | env.save_image(self.log, episode_steps)
29 | dist = env.get_dist()
30 | self.log += 1
31 | return episode_reward, dist
32 |
--------------------------------------------------------------------------------
/semantic_guidance/DRL/multi.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from env_ins import Paint
5 | import utils.util as util
6 |
7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8 |
9 |
10 | class fastenv:
11 | def __init__(self, max_episode_length=10, nenv=64,
12 | writer=None, dataset='cub200', use_gbp=False, use_bilevel=False, gpu=0, high_res=False,
13 | bundle_size=5):
14 | self.max_episode_length = max_episode_length
15 | self.nenv = nenv
16 | self.env = Paint(self.nenv, self.max_episode_length, dataset=dataset, use_gbp=use_gbp,
17 | use_bilevel=use_bilevel, gpu=gpu, high_res=high_res, bundle_size=bundle_size)
18 | self.env.load_data()
19 | self.observation_space = self.env.observation_space
20 | self.action_space = self.env.action_space
21 | self.writer = writer
22 | self.test = False
23 | self.log = 0
24 | self.gpu = gpu
25 | self.use_gbp = use_gbp
26 | self.use_bilevel = use_bilevel
27 |
28 | def nalignment(self, gt, canvas0):
29 | gt_ = (gt - self.mean) / self.std
30 | predictions = self.loc_model(gt_)
31 | M = torch.matmul(predictions, self.P)
32 | M = M - self.c_
33 | M = M.view(-1, 2, 3)
34 | grid = torch.nn.functional.affine_grid(M, gt.size())
35 | z_gt = torch.nn.functional.grid_sample(gt, grid.float())
36 | z_canvas0 = torch.nn.functional.grid_sample(canvas0, grid.float())
37 | # z_canvas1 = torch.nn.functional.grid_sample(canvas1, grid.float())
38 | return z_gt, z_canvas0
39 |
40 | def save_image(self, log, step):
41 | if self.gpu == 0:
42 | for i in range(self.nenv):
43 | if self.env.imgid[i] <= 10:
44 | # write background images
45 | canvas = util.to_numpy(self.env.canvas[i, :3].permute(1, 2, 0))
46 | self.writer.add_image('{}/canvas_{}.png'.format(str(self.env.imgid[i]), str(step)), canvas, log)
47 | if step == self.max_episode_length:
48 | if self.use_bilevel:
49 | # z_gt, z_canvas = self.nalignment(self.env.gt[:,:3].float() / 255,self.env.canvas[:,:3].float() / 255)
50 | grid = self.env.grid[:, :2].float() / 255
51 | grid = 2 * grid - 1
52 | z_gt = torch.nn.functional.grid_sample(self.env.gt[:, :3].float() / 255, grid.permute(0, 2, 3, 1))
53 | z_canvas = torch.nn.functional.grid_sample(self.env.canvas[:, :3].float() / 255,
54 | grid.permute(0, 2, 3, 1))
55 | for i in range(self.nenv):
56 | if self.env.imgid[i] < 50:
57 | # write background images
58 | gt = util.to_numpy(self.env.gt[i, :3].permute(1, 2, 0))
59 | canvas = util.to_numpy(self.env.canvas[i, :3].permute(1, 2, 0))
60 | self.writer.add_image(str(self.env.imgid[i]) + '/_target.png', gt, log)
61 | self.writer.add_image(str(self.env.imgid[i]) + '/_canvas.png', canvas, log)
62 | if self.use_bilevel:
63 | # # also write foreground images
64 | gt = util.to_numpy(z_gt[i, :3].permute(1, 2, 0))
65 | canvas = util.to_numpy(z_canvas[i, :3].permute(1, 2, 0))
66 | self.writer.add_image(str(self.env.imgid[i]) + '_foreground/_target.png', gt, log)
67 | self.writer.add_image(str(self.env.imgid[i]) + '_foreground/_canvas.png', canvas, log)
68 |
69 | def step(self, action):
70 | with torch.no_grad():
71 | ob, r, d, _ = self.env.step(torch.tensor(action).cuda(self.gpu))
72 | return ob, r, d, _
73 |
74 | def get_dist(self):
75 | return util.to_numpy(
76 | (((self.env.gt[:, :3].float() - self.env.canvas[:, :3].float()) / 255) ** 2).mean(1).mean(1).mean(1))
77 |
78 | def reset(self, test=False, episode=0):
79 | self.test = test
80 | ob = self.env.reset(self.test, episode * self.nenv)
81 | return ob
82 |
--------------------------------------------------------------------------------
/semantic_guidance/DRL/rpm.py:
--------------------------------------------------------------------------------
1 | # from collections import deque
2 | import numpy as np
3 | import random
4 | import torch
5 | import pickle as pickle
6 |
7 | class rpm(object):
8 | # replay memory
9 | def __init__(self, buffer_size, gpu=0):
10 | self.buffer_size = buffer_size
11 | self.buffer = []
12 | self.index = 0
13 | self.gpu = gpu
14 |
15 | def append(self, obj):
16 | if self.size() > self.buffer_size:
17 | print('buffer size larger than set value, trimming...')
18 | self.buffer = self.buffer[(self.size() - self.buffer_size):]
19 | elif self.size() == self.buffer_size:
20 | self.buffer[self.index] = obj
21 | self.index += 1
22 | self.index %= self.buffer_size
23 | else:
24 | self.buffer.append(obj)
25 |
26 | def size(self):
27 | return len(self.buffer)
28 |
29 | def sample_batch(self, batch_size, device, only_state=False):
30 | if self.size() < batch_size:
31 | batch = random.sample(self.buffer, self.size())
32 | else:
33 | batch = random.sample(self.buffer, batch_size)
34 |
35 | if only_state:
36 | res = torch.stack(tuple(item[3] for item in batch), dim=0)
37 | return res.cuda(self.gpu)
38 | else:
39 | item_count = 5
40 | res = []
41 | for i in range(5):
42 | k = torch.stack(tuple(item[i] for item in batch), dim=0)
43 | res.append(k.cuda(self.gpu))
44 | return res[0], res[1], res[2], res[3], res[4]
45 |
--------------------------------------------------------------------------------
/semantic_guidance/DRL/wgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.optim import Adam, SGD
5 | from torch import autograd
6 | from torch.autograd import Variable
7 | import torch.nn.functional as F
8 | from torch.autograd import grad as torch_grad
9 | import torch.nn.utils.weight_norm as weightNorm
10 | import utils.util as util
11 | from torch.nn.parallel import DistributedDataParallel as DDP
12 |
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 |
15 |
16 | # dim = 128
17 | # LAMBDA = 10 # Gradient penalty lambda hyperparameter
18 |
19 |
20 | class TReLU(nn.Module):
21 | def __init__(self):
22 | super(TReLU, self).__init__()
23 | self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
24 | self.alpha.data.fill_(0)
25 |
26 | def forward(self, x):
27 | x = F.relu(x - self.alpha) + self.alpha
28 | return x
29 |
30 |
31 | class Discriminator(nn.Module):
32 | def __init__(self, high_res=False):
33 | super(Discriminator, self).__init__()
34 |
35 | self.conv0 = weightNorm(nn.Conv2d(6, 16, 5, 2, 2))
36 | self.conv1 = weightNorm(nn.Conv2d(16, 32, 5, 2, 2))
37 | self.conv2 = weightNorm(nn.Conv2d(32, 64, 5, 2, 2))
38 | self.conv3 = weightNorm(nn.Conv2d(64, 128, 5, 2, 2))
39 | self.conv4 = weightNorm(nn.Conv2d(128, 1, 5, 2, 2))
40 | self.relu0 = TReLU()
41 | self.relu1 = TReLU()
42 | self.relu2 = TReLU()
43 | self.relu3 = TReLU()
44 | self.high_res = high_res
45 |
46 | def forward(self, x):
47 | x = self.conv0(x)
48 | x = self.relu0(x)
49 | x = self.conv1(x)
50 | x = self.relu1(x)
51 | x = self.conv2(x)
52 | x = self.relu2(x)
53 | x = self.conv3(x)
54 | x = self.relu3(x)
55 | x = self.conv4(x)
56 | if self.high_res is True:
57 | x = F.avg_pool2d(x, 8)
58 | else:
59 | x = F.avg_pool2d(x, 4)
60 | x = x.view(-1, 1)
61 | return x
62 |
63 |
64 | class WGAN:
65 | def __init__(self, gpu=0, distributed=False, dim=128, high_res=False):
66 | self.gpu = gpu
67 | self.distributed = distributed
68 | self.high_res = high_res
69 |
70 | if self.high_res is True:
71 | self.dim = 256
72 | else:
73 | self.dim = 128
74 |
75 | self.netD = Discriminator(high_res=self.high_res)
76 | self.target_netD = Discriminator(high_res=self.high_res)
77 |
78 | self.netD = self.netD.cuda(gpu)
79 | self.target_netD = self.target_netD.cuda(gpu)
80 |
81 | for param in self.target_netD.parameters():
82 | param.requires_grad = False
83 |
84 | if distributed:
85 | self.netD = DDP(self.netD, device_ids=[self.gpu])
86 | # self.target_netD = DDP(self.target_netD, device_ids=[self.gpu])
87 |
88 | util.hard_update(self.target_netD, self.netD)
89 |
90 | self.optimizerD = Adam(self.netD.parameters(), lr=3e-4, betas=(0.5, 0.999))
91 | # self.dim = dim
92 | self.LAMBDA = 10 # Gradient penalty lambda hyperparameter
93 |
94 | def cal_gradient_penalty(self, real_data, fake_data, batch_size):
95 | alpha = torch.rand(batch_size, 1)
96 | alpha = alpha.expand(batch_size, int(real_data.nelement() / batch_size)).contiguous()
97 | alpha = alpha.view(batch_size, 6, self.dim, self.dim)
98 | alpha = alpha.cuda(self.gpu)
99 | fake_data = fake_data.view(batch_size, 6, self.dim, self.dim)
100 | interpolates = Variable(alpha * real_data.data + ((1 - alpha) * fake_data.data), requires_grad=True)
101 | disc_interpolates = self.netD(interpolates)
102 | gradients = autograd.grad(disc_interpolates, interpolates,
103 | grad_outputs=torch.ones(disc_interpolates.size()).cuda(self.gpu),
104 | create_graph=True, retain_graph=True)[0]
105 | gradients = gradients.view(gradients.size(0), -1)
106 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
107 | return gradient_penalty
108 |
109 | def cal_reward(self, fake_data, real_data):
110 | return self.target_netD(torch.cat([real_data, fake_data], 1))
111 |
112 | def save_gan(self, path, num_episodes):
113 | # self.netD.cpu()
114 | torch.save(self.netD.state_dict(), '{}/wgan_{:05}.pkl'.format(path, num_episodes))
115 | # self.netD.cuda(self.gpu)
116 |
117 | def load_gan(self, path, map_location, num_episodes):
118 | self.netD.load_state_dict(torch.load('{}/wgan_{:05}.pkl'.format(path, num_episodes), map_location=map_location))
119 |
120 | def random_masks(self):
121 | """
122 | Generate random masks for complement discriminator (need to make mask overlap with the stroke)
123 | :return: mask
124 | """
125 | # initialize mask
126 | mask = np.ones((3, self.dim, self.dim))
127 |
128 | # generate one of 4 random masks
129 | choose = 1 # np.random.randint(0, 1)
130 | if choose == 0:
131 | mask[:, :self.dim // 2] = 0
132 | elif choose == 1:
133 | mask[:, :, :self.dim // 2] = 0
134 | elif choose == 2:
135 | mask[:, :, self.dim // 2:] = 0
136 | elif choose == 3:
137 | mask[:, self.dim // 2:] = 0
138 |
139 | return mask
140 |
141 | def update(self, fake_data, real_data):
142 | fake_data = fake_data.detach()
143 | real_data = real_data.detach()
144 |
145 | # standard conditional training for discriminator
146 | fake = torch.cat([real_data, fake_data], 1)
147 | real = torch.cat([real_data, real_data], 1)
148 |
149 | # # complement discriminator conditional training for discriminator
150 | # mask = torch.tensor(random_masks()).float().to(device)
151 | # fake = torch.cat([(1 - mask) * real_data, mask * fake_data], 1)
152 | # real = torch.cat([(1 - mask) * real_data, mask * real_data], 1)
153 |
154 | # compute discriminator scores for real and fake data
155 | D_real = self.netD(real)
156 | D_fake = self.netD(fake)
157 |
158 | gradient_penalty = self.cal_gradient_penalty(real, fake, real.shape[0])
159 | self.optimizerD.zero_grad()
160 | D_cost = D_fake.mean() - D_real.mean() + gradient_penalty
161 | D_cost.backward()
162 | self.optimizerD.step()
163 | util.soft_update(self.target_netD, self.netD, 0.001)
164 | return D_fake.mean(), D_real.mean(), gradient_penalty
165 |
--------------------------------------------------------------------------------
/semantic_guidance/Renderer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/semantic_guidance/Renderer/__init__.py
--------------------------------------------------------------------------------
/semantic_guidance/Renderer/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.utils.weight_norm as weightNorm
5 |
6 |
7 | class FCN(nn.Module):
8 | def __init__(self, high_res=False):
9 | super(FCN, self).__init__()
10 | self.fc1 = (nn.Linear(10, 512))
11 | self.fc2 = (nn.Linear(512, 1024))
12 | self.fc3 = (nn.Linear(1024, 2048))
13 | self.fc4 = (nn.Linear(2048, 4096))
14 | self.high_res = high_res
15 |
16 | if self.high_res is True:
17 | self.conv1 = (nn.Conv2d(16, 64, 3, 1, 1))
18 | self.conv2 = (nn.Conv2d(64, 64, 3, 1, 1))
19 | self.conv3 = (nn.Conv2d(16, 32, 3, 1, 1))
20 | self.conv4 = (nn.Conv2d(32, 32, 3, 1, 1))
21 | self.conv5 = (nn.Conv2d(8, 16, 3, 1, 1))
22 | self.conv6 = (nn.Conv2d(16, 16, 3, 1, 1))
23 | self.conv7 = (nn.Conv2d(4, 8, 3, 1, 1))
24 | self.conv8 = (nn.Conv2d(8, 4, 3, 1, 1))
25 | else:
26 | self.conv1 = (nn.Conv2d(16, 32, 3, 1, 1))
27 | self.conv2 = (nn.Conv2d(32, 32, 3, 1, 1))
28 | self.conv3 = (nn.Conv2d(8, 16, 3, 1, 1))
29 | self.conv4 = (nn.Conv2d(16, 16, 3, 1, 1))
30 | self.conv5 = (nn.Conv2d(4, 8, 3, 1, 1))
31 | self.conv6 = (nn.Conv2d(8, 4, 3, 1, 1))
32 | self.pixel_shuffle = nn.PixelShuffle(2)
33 |
34 | def forward(self, x):
35 | x = F.relu(self.fc1(x))
36 | x = F.relu(self.fc2(x))
37 | x = F.relu(self.fc3(x))
38 | x = F.relu(self.fc4(x))
39 | x = x.view(-1, 16, 16, 16)
40 | x = F.relu(self.conv1(x))
41 | x = self.pixel_shuffle(self.conv2(x))
42 | x = F.relu(self.conv3(x))
43 | x = self.pixel_shuffle(self.conv4(x))
44 | x = F.relu(self.conv5(x))
45 | x = self.pixel_shuffle(self.conv6(x))
46 | if self.high_res is True:
47 | x = F.relu(self.conv7(x))
48 | x = self.pixel_shuffle(self.conv8(x))
49 | x = torch.sigmoid(x)
50 | if self.high_res is True:
51 | return 1 - x.view(-1, 256, 256)
52 | else:
53 | return 1 - x.view(-1, 128, 128)
54 |
--------------------------------------------------------------------------------
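
The `FCN` class above is the differentiable neural renderer: it maps 10 positional stroke parameters to a single-channel stroke map. A quick shape sanity check is sketched below, assuming it is run from the `semantic_guidance/` directory; an untrained renderer only produces noise, the trained weights are loaded from `../data/renderer.pkl` elsewhere in the code.

```python
import torch
from Renderer.model import FCN

renderer = FCN(high_res=False)
params = torch.rand(4, 10)      # 10 positional stroke parameters per stroke
strokes = renderer(params)      # -> (4, 128, 128), values in [0, 1]
print(strokes.shape, strokes.min().item(), strokes.max().item())
```
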
/semantic_guidance/Renderer/stroke_gen.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | def normal(x, width):
5 | """
6 | scale stroke parameter x ([0,1]) based on width of the canvas
7 | :param x: stroke parameter x ([0,1])
8 | :param width: width of canvas
9 | :return: scaled parameter
10 | """
11 | return (int)(x * (width - 1) + 0.5)
12 |
13 | def draw(f, width=128, mask = False):
14 | """
15 |     Draw a quadratic Bezier stroke on an empty canvas
16 | :param f: stroke parameters (x0, y0, x1, y1, x2, y2, z0, z2, w0, w2)
17 | :param width: width of the canvas
18 | :param mask: boolean on whether mask is required for the canvas
19 |     :return: painted canvas (close to zero at stroke locations and one otherwise),
20 |              stroke_mask (ones at stroke locations and zero otherwise)
21 | """
22 | # read stroke parameters (10 positional parameters)
23 | x0, y0, x1, y1, x2, y2, z0, z2, w0, w2 = f
24 | x1 = x0 + (x2 - x0) * x1
25 | y1 = y0 + (y2 - y0) * y1
26 | x0 = normal(x0, width * 2)
27 | x1 = normal(x1, width * 2)
28 | x2 = normal(x2, width * 2)
29 | y0 = normal(y0, width * 2)
30 | y1 = normal(y1, width * 2)
31 | y2 = normal(y2, width * 2)
32 | z0 = (int)(1 + z0 * width // 2)
33 | z2 = (int)(1 + z2 * width // 2)
34 |
35 | # initialize empty canvas
36 | canvas = np.zeros([width * 2, width * 2]).astype('float32')
37 | tmp = 1. / 100
38 |
39 |     # the Bezier curve is approximated by 100 overlapping circles
40 | for i in range(100):
41 | t = i * tmp
42 | x = (int)((1-t) * (1-t) * x0 + 2 * t * (1-t) * x1 + t * t * x2)
43 | y = (int)((1-t) * (1-t) * y0 + 2 * t * (1-t) * y1 + t * t * y2)
44 | z = (int)((1-t) * z0 + t * z2)
45 | w = (1-t) * w0 + t * w2
46 | cv2.circle(canvas, (y, x), z, w, -1)
47 |
48 | # return mask if required
49 | if mask:
50 | stroke_mask = (canvas!=0).astype(np.int32)
51 | return 1 - cv2.resize(canvas, dsize=(width, width)), stroke_mask
52 | return 1 - cv2.resize(canvas, dsize=(width, width))
53 |
--------------------------------------------------------------------------------
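
`draw` is the OpenCV reference renderer that the neural renderer is trained to imitate. A minimal usage sketch, mirroring how `train_renderer.py` samples ground-truth strokes (the output filename is just an example):

```python
import cv2
import numpy as np
from Renderer.stroke_gen import draw

# rasterize a random Bezier stroke
f = np.random.uniform(0, 1, 10)       # (x0, y0, x1, y1, x2, y2, z0, z2, w0, w2)
canvas = draw(f, width=128)           # (128, 128) float canvas: 1 = empty, ~0 = stroke
cv2.imwrite('stroke_example.png', (canvas * 255).astype(np.uint8))
```
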
/semantic_guidance/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/semantic_guidance/env_ins.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import torch
4 | import numpy as np
5 | import argparse
6 | import torchvision.transforms as transforms
7 | import cv2
8 | # from DRL.ddpg import decode
9 | import utils.util as util
10 | from utils.dataloader import ImageDataset
11 | import random
12 | from Renderer.model import FCN
13 | import pandas as pd
14 |
15 | # define device
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |
18 |
19 | class Paint:
20 | def __init__(self, batch_size, max_eps_len, dataset='cub200', use_gbp=False, use_bilevel=False, gpu=0,
21 | width=128, high_res=False, bundle_size=5):
22 | self.batch_size = batch_size
23 | self.max_eps_len = max_eps_len
24 |         self.action_space = 13  # 13 parameters per stroke (10 position + 3 RGB color)
25 | self.use_gbp = use_gbp
26 | self.use_bilevel = use_bilevel
27 | self.width = width
28 | self.high_res = high_res
29 | self.bundle_size = bundle_size
30 |
31 | if self.high_res:
32 | self.width = 256
33 |
34 | # define and load renderer
35 | self.decoder = FCN(self.high_res)
36 | self.decoder.load_state_dict(torch.load('../data/renderer{}.pkl'.format("_256" if self.high_res else "")))
37 | self.decoder.cuda(gpu)
38 |
39 | for param in self.decoder.parameters():
40 | param.requires_grad = False
41 |
42 | # gpu and distributed parameters
43 | self.gpu = gpu
44 |
45 | # define observation space
46 | self.observation_space = (self.batch_size, self.width, self.width, 7 + use_gbp + use_bilevel)
47 | self.test = False
48 |
49 | # dataset name
50 | self.dataset = dataset
51 | if dataset == 'cub200':
52 | df = pd.read_csv('../data/cub200/CUB_200_2011/images.txt', sep=' ', index_col=0, names=['idx', 'img_names'])
53 | img_names = list(df['img_names'])
54 | self.data = np.array(
55 | ["../data/cub200/CUB_200_2011/images/{}.jpg".format(img_name[:-4]) for img_name in img_names])
56 | self.seg = np.array(
57 | ["../data/cub200/CUB_200_2011/segmentations_pred/{}.jpg".format(img_name[:-4]) for img_name in
58 | img_names])
59 | self.gbp = np.array(
60 | ["../data/cub200/CUB_200_2011/gbp_global/{}.jpg".format(img_name[:-4]) for img_name in img_names])
61 |
62 | df_ = pd.read_csv('../data/cub200/CUB_200_2011/bounding_boxes_pred.txt', sep=' ', index_col=0)
63 | x, y, w, h = np.array(df_['x']).astype(int), np.array(df_['y']).astype(int), np.array(df_['w']).astype(int), \
64 | np.array(df_['h']).astype(int)
65 | self.bbox = np.array(list(zip(x, y, w, h)))
66 |
67 | # random shuffle data
68 | shuffled_indices = np.arange(len(self.data)).astype(np.int32)
69 | np.random.shuffle(shuffled_indices)
70 | self.data = self.data[shuffled_indices]
71 | self.seg = self.seg[shuffled_indices]
72 | self.bbox = self.bbox[shuffled_indices]
73 | self.gbp = self.gbp[shuffled_indices]
74 |
75 | def load_data(self, num_test=2000):
76 | # random shuffle data
77 | shuffled_indices = np.arange(len(self.data)).astype(np.int32)
78 | np.random.shuffle(shuffled_indices)
79 |
80 | # divide data into train and test
81 | train_data, test_data = self.data[shuffled_indices[num_test:]], self.data[shuffled_indices[:num_test]]
82 |
83 | # divide gbp data
84 | if self.use_gbp:
85 | gbp_train, gbp_test = self.gbp[shuffled_indices[num_test:]], self.gbp[shuffled_indices[:num_test]]
86 | else:
87 | gbp_train, gbp_test = None, None
88 |
89 | if self.use_bilevel:
90 | seg_train, seg_test = self.seg[shuffled_indices[num_test:]], self.seg[
91 | shuffled_indices[:num_test]]
92 | bbox_train, bbox_test = self.bbox[shuffled_indices[num_test:]], self.bbox[
93 | shuffled_indices[:num_test]]
94 | else:
95 | seg_train, seg_test = None, None
96 | bbox_train, bbox_test = None, None
97 |
98 | # create train and test data
99 | self.train_dataset = ImageDataset(train_data, gbp_list=gbp_train, seg_list=seg_train,
100 | bbox_list=bbox_train, high_res=self.high_res)
101 | self.test_dataset = ImageDataset(test_data, gbp_list=gbp_test, seg_list=seg_test,
102 | bbox_list=bbox_test, high_res=self.high_res)
103 |
104 | # record train test split
105 | self.num_train, self.num_test = len(train_data), num_test
106 |
107 | def reset(self, test=False, begin_num=False):
108 | self.test = test
109 | # self.imgid = [0] * self.batch_size
110 | self.gt = torch.zeros([self.batch_size, 3, self.width, self.width], dtype=torch.uint8).cuda(self.gpu)
111 | if self.use_gbp:
112 | self.gbp_gt = torch.zeros([self.batch_size, 1, self.width, self.width], dtype=torch.uint8).cuda(self.gpu)
113 | if self.use_bilevel:
114 | self.seg_gt = torch.zeros([self.batch_size, 1, self.width, self.width], dtype=torch.uint8).cuda(self.gpu)
115 | self.grid = torch.zeros([self.batch_size, 2, self.width, self.width], dtype=torch.uint8).cuda(self.gpu)
116 |
117 | # get ground truths and corresponding idxs
118 | if test:
119 | self.imgid = (begin_num + np.arange(self.batch_size)) % self.num_test
120 | for i in range(self.batch_size):
121 | img, gbp_gt, seg_gt, grid = self.test_dataset[self.imgid[i]]
122 | self.gt[i] = img
123 | if self.use_gbp:
124 | self.gbp_gt[i, :] = gbp_gt
125 | if self.use_bilevel:
126 | self.seg_gt[i, :] = seg_gt
127 | self.grid[i, :] = grid
128 | else:
129 | self.imgid = np.random.choice(np.arange(self.num_train), self.batch_size, replace=False)
130 | for i in range(self.batch_size):
131 | img, gbp_gt, seg_gt, grid = self.train_dataset[self.imgid[i]]
132 | self.gt[i] = img
133 | if self.use_gbp:
134 | self.gbp_gt[i, :] = gbp_gt
135 | if self.use_bilevel:
136 | self.seg_gt[i, :] = seg_gt
137 | self.grid[i, :] = grid
138 |
139 | self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)
140 | self.stepnum = 0
141 | self.canvas = torch.zeros([self.batch_size, 3, self.width, self.width], dtype=torch.uint8).cuda(self.gpu)
142 | self.lastdis = self.ini_dis = self.cal_dis()
143 | return self.observation()
144 |
145 | def observation(self):
146 | # canvas B * 3 * width * width
147 | # gt B * 3 * width * width
148 | # T B * 1 * width * width
149 | ob = []
150 | T = torch.ones([self.batch_size, 1, self.width, self.width], dtype=torch.uint8) * self.stepnum
151 |
152 | # canvas, img, T
153 | obs_list = [self.canvas, self.gt, T.cuda(self.gpu)]
154 |
155 | if self.use_gbp:
156 | obs_list += [self.gbp_gt]
157 | if self.use_bilevel:
158 | obs_list += [self.seg_gt]
159 | obs_list += [self.grid]
160 |
161 | return torch.cat(obs_list, 1)
162 |
163 | def cal_trans(self, s, t):
164 | return (s.transpose(0, 3) * t).transpose(0, 3)
165 |
166 | def step(self, action):
167 | if self.use_bilevel:
168 | self.canvas = (
169 | self.decode_parallel(action, self.canvas.float() / 255, seg_mask=self.seg_gt.float() / 255)[
170 | 0] * 255).byte()
171 | else:
172 | self.canvas = (self.decode(action, self.canvas.float() / 255)[0] * 255).byte()
173 |
174 | self.stepnum += 1
175 | ob = self.observation()
176 | # ob = ob[:, :7, :, :]
177 | done = (self.stepnum == self.max_eps_len)
178 | reward = self.cal_reward() # np.array([0.] * self.batch_size)
179 | return ob.detach(), reward, np.array([done] * self.batch_size), None
180 |
181 | def cal_dis(self):
182 | return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)
183 |
184 | def cal_reward(self):
185 |         """Reward: decrease in L2 distance between canvas and ground truth, normalized by the initial distance"""
186 | dis = self.cal_dis()
187 | reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)
188 | self.lastdis = dis
189 | return util.to_numpy(reward)
190 |
191 | def decode_parallel(self, x, canvas, seg_mask=None, mask=False):
192 | canvas, _ = self.decode(x[:, :13 * self.bundle_size], canvas, mask, 1 - seg_mask)
193 | canvas, _ = self.decode(x[:, 13 * self.bundle_size:], canvas, mask, seg_mask)
194 | return canvas, _
195 |
196 | def decode(self, x, canvas, mask=False, seg_mask=None): # b * (10 + 3)
197 | """
198 | Update canvas given stroke parameters x
199 |         :param x: stroke parameters (N, 13 * bundle_size)
200 | :param canvas: current canvas state
201 | :return: updated canvas with stroke drawn
202 | """
203 | # 13 stroke parameters (10 position and 3 RGB color)
204 | x = x.contiguous().view(-1, 10 + 3)
205 |
206 | # get stroke on an empty canvas given 10 positional parameters
207 | stroke = 1 - self.decoder(x[:, :10])
208 | if self.high_res is True:
209 | stroke = stroke.view(-1, 256, 256, 1)
210 | else:
211 | stroke = stroke.view(-1, 128, 128, 1)
212 |
213 | # add color to the stroke
214 | color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
215 | stroke = stroke.permute(0, 3, 1, 2)
216 | color_stroke = color_stroke.permute(0, 3, 1, 2)
217 |
218 |         # draw bundle_size strokes at a time (action bundle)
219 | if self.high_res is True:
220 | stroke = stroke.view(-1, self.bundle_size, 1, 256, 256)
221 | color_stroke = color_stroke.view(-1, self.bundle_size, 3, 256, 256)
222 | else:
223 | stroke = stroke.view(-1, self.bundle_size, 1, 128, 128)
224 | color_stroke = color_stroke.view(-1, self.bundle_size, 3, 128, 128)
225 |
226 | for i in range(self.bundle_size):
227 | if seg_mask is not None:
228 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i] * seg_mask
229 | else:
230 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
231 |
232 | # also return stroke mask if required
233 | stroke_mask = None
234 | if mask:
235 | stroke_mask = (stroke != 0).float() # -1, bundle_size, 1, width, width
236 |
237 | return canvas, stroke_mask
238 |
--------------------------------------------------------------------------------
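
`cal_reward` above credits the agent with the fraction of the initial canvas-to-target L2 distance that the latest action bundle removed. A small numeric illustration with made-up values:

```python
# per-step reward used by Paint.cal_reward(): fractional decrease of the L2 distance
# between canvas and target, normalized by the distance of the initial (blank) canvas
ini_dis, last_dis, dis = 0.20, 0.12, 0.09      # illustrative values only
reward = (last_dis - dis) / (ini_dis + 1e-8)   # = 0.15: this step removed 15% of the initial error
print(round(reward, 3))
```
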
/semantic_guidance/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, sys
3 | import pandas as pd
4 | import cv2
5 |
6 | from PIL import Image
7 | import torch
8 | import torchvision
9 | from torchvision import transforms, utils
10 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
11 | from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
12 |
13 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
14 | print("using device: {}".format(device))
15 |
16 |
17 | # define function for getting combined object localization and semantic segmentation prediction model
18 | def get_model_instance_segmentation(num_classes):
19 |     # load an instance segmentation model pre-trained on COCO
20 | model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
21 |
22 | # get number of input features for the classifier
23 | in_features = model.roi_heads.box_predictor.cls_score.in_features
24 | # replace the pre-trained head with a new one
25 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
26 |
27 | # now get the number of input features for the mask classifier
28 | in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
29 | hidden_layer = 256
30 | # and replace the mask predictor with a new one
31 | model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
32 | hidden_layer,
33 | num_classes)
34 |
35 | return model
36 |
37 |
38 | # our dataset has two classes only - background and foreground
39 | num_classes = 2
40 | # load pretrained object localization and semantic segmentation model
41 | model = get_model_instance_segmentation(num_classes)
42 | model.load_state_dict(torch.load('../data/birds_obj_seg.pkl', map_location={'cuda:0': 'cpu'}))
43 | model.to(device)
44 | model.eval()
45 |
46 | transform_test = transforms.Compose([
47 | transforms.ToTensor(),
48 | ])
49 |
50 | expert_model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True,
51 | **{'topN': 6, 'device': 'cpu', 'num_classes': 200})
52 | expert_model.eval()
53 |
54 | from torch.nn import ReLU
55 |
56 | gbp_transform = transforms.Compose([
57 | transforms.ToPILImage(),
58 | transforms.Resize((440, 440)),
59 | transforms.CenterCrop((440, 440)),
60 | # transforms.RandomHorizontalFlip(), # only if train
61 | transforms.ToTensor(),
62 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
63 | ])
64 |
65 |
66 | def convert_to_grayscale(im_as_arr):
67 | """
68 | Converts 3d image to grayscale
69 | Args:
70 | im_as_arr (numpy arr): RGB image with shape (D,W,H)
71 | returns:
72 | grayscale_im (numpy_arr): Grayscale image with shape (1,W,D)
73 | """
74 | grayscale_im = np.sum(np.abs(im_as_arr), axis=0)
75 | im_max = np.percentile(grayscale_im, 99)
76 | im_min = np.min(grayscale_im)
77 | grayscale_im = (np.clip((grayscale_im - im_min) / (im_max - im_min), 0, 1))
78 | grayscale_im = np.expand_dims(grayscale_im, axis=0)
79 | return grayscale_im
80 |
81 |
82 | class GuidedBackprop():
83 | """
84 | Produces gradients generated with guided back propagation from the given image
85 | """
86 |
87 | def __init__(self, model):
88 | self.model = model
89 | self.gradients = None
90 | # Put model in evaluation mode
91 | self.model.eval()
92 | self.update_relus()
93 | self.hook_layers()
94 |
95 | def hook_layers(self):
96 | def hook_function(module, grad_in, grad_out):
97 | # print (module,grad_in,grad_out)
98 | self.gradients = grad_in[0]
99 |
100 | # Register hook to the first layer
101 | first_layer = list(self.model.children())[0]
102 | first_layer.register_backward_hook(hook_function)
103 |
104 | def update_relus(self):
105 | """
106 |         Updates relu activation functions so that they only return positive gradients
107 | """
108 |
109 | def relu_hook_function(module, grad_in, grad_out):
110 | """
111 | If there is a negative gradient, changes it to zero
112 | """
113 | if isinstance(module, ReLU):
114 | return (torch.clamp(grad_in[0], min=0.0),)
115 |
116 | # Loop through layers, hook up ReLUs with relu_hook_function
117 | for module in self.model.modules():
118 | if isinstance(module, ReLU):
119 | module.register_backward_hook(relu_hook_function)
120 |
121 | def generate_gradients(self, input_image):
122 | input_image.requires_grad = True
123 | # Forward pass
124 | model_output = self.model(input_image)[0]
125 | target_class = torch.argmax(model_output).item()
126 |
127 | # Zero gradients
128 | self.model.zero_grad()
129 | # Target for backprop
130 | one_hot_output = torch.FloatTensor(1, model_output.size()[-1]).zero_()
131 | one_hot_output[0][target_class] = 1
132 |
133 | # print (one_hot_output)
134 | # Backward pass
135 | model_output.backward(gradient=one_hot_output)
136 | # Convert Pytorch variable to numpy array
137 | # [0] to get rid of the first channel (1,3,224,224)
138 | gradients_as_arr = self.gradients.data.numpy()[0]
139 | return gradients_as_arr
140 |
141 |
142 | # df_ = pd.read_csv('data/cub200/CUB_200_2011/bounding_boxes.txt', sep=' ', index_col=0,
143 | # names=['idx', 'x', 'y', 'w', 'h'])
144 | # df_ = df_.copy()
145 | # data = []
146 |
147 | # get input image list
148 | df = pd.read_csv('../data/cub200/CUB_200_2011/images.txt', sep=' ', index_col=0, names=['idx', 'img_names'])
149 | img_names = list(df['img_names'])
150 | img_list = np.array(["../data/cub200/CUB_200_2011/images/{}.jpg".format(img_name[:-4]) for img_name in img_names])
151 |
152 | # predicted bounding box data
153 | pred_bbox_data = []
154 | GBP = GuidedBackprop(expert_model.pretrained_model)
155 |
156 | # create output directories for storing data
157 | gbp_dir = '../data/cub200/CUB_200_2011/gbp_global/'
158 | segmentations_dir = '../data/cub200/CUB_200_2011/segmentations_pred/'
159 | if not os.path.exists(gbp_dir):
160 | os.makedirs(gbp_dir)
161 | if not os.path.exists(segmentations_dir):
162 | os.makedirs(segmentations_dir)
163 |
164 | for choose in range(len(img_list)):
165 | # get image segmentation and bounding box predictions
166 | img_name = img_list[choose]
167 | img = Image.open(img_name) # cv2.cvtColor(cv2.imread(img_name, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
168 | img_ = transform_test(img).unsqueeze(0).to(device)
169 | with torch.no_grad():
170 | prediction = model(img_)
171 | bbox_pred = prediction[0]['boxes'][0]
172 |     x, y, w, h = bbox_pred.detach().cpu().numpy().astype(int)  # Mask R-CNN boxes are (x1, y1, x2, y2)
173 | seg_mask = prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy()
174 |
175 | pred_bbox_data.append([x, y, w - x, h - y])
176 | # df_.iloc[choose]['x'] = x
177 | # df_.iloc[choose]['y'] = y
178 | # df_.iloc[choose]['w'] = w - x
179 | # df_.iloc[choose]['h'] = h - y
180 |
181 | path_list = img_name.split('/')
182 | path_list[-3] = 'segmentations_pred'
183 | seg_name = '/'.join(path_list)
184 | # seg_name = "{}/segmentations_pred2/{}".format(img_name[:24], img_name[32:])
185 | if not os.path.exists(os.path.dirname(seg_name)):
186 | os.makedirs(os.path.dirname(seg_name))
187 | cv2.imwrite(seg_name, seg_mask)
188 |
189 | # get guided backpropagation maps from the expert model
190 | img = cv2.cvtColor(cv2.imread(img_name, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
191 | H, W, C = img.shape
192 | img = img[y:h, x:w].astype(np.uint8)
193 | seg_mask = seg_mask.astype(float) / 255.
194 | seg_mask = seg_mask[y:h, x: w]
195 |
196 | img = img.astype(float) * np.expand_dims(seg_mask, -1)
197 | img = img.astype(np.uint8)
198 |
199 | scaled_img = gbp_transform(img)
200 | torch_images = scaled_img.unsqueeze(0)
201 |
202 | guided_grads = GBP.generate_gradients(torch_images) # .transpose(1,2,0)
203 | gbp = convert_to_grayscale(guided_grads)[0]
204 | gbp = cv2.resize(gbp, (w-x, h-y))
205 | gbp = (255 * gbp).astype(np.uint8)
206 |
207 | seg_mask = (seg_mask > 0.5).astype(np.uint8)
208 | gbp = gbp * seg_mask
209 |
210 | gbp_global = np.zeros((H, W)).astype(np.uint8)
211 | w, h = min(w, W), min(h, H)
212 | gbp_global[y:h, x:w] = gbp
213 |
214 | path_list = img_name.split('/')
215 | path_list[-3] = 'gbp_global'
216 | gbp_name = '/'.join(path_list)
217 | if not os.path.exists(os.path.dirname(gbp_name)):
218 | os.makedirs(os.path.dirname(gbp_name))
219 | cv2.imwrite(gbp_name, gbp_global)
220 |
221 | if choose % 100 == 0:
222 | print("Processed :{}/{} images ...".format(choose, len(img_list)))
223 |
224 | # save bounding box predictions
225 | df_ = pd.DataFrame(data=pred_bbox_data, columns=['x', 'y', 'w', 'h'])
226 | df_.to_csv('../data/cub200/CUB_200_2011/bounding_boxes_pred.txt', sep=' ')
227 |
--------------------------------------------------------------------------------
/semantic_guidance/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import argparse
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torchvision import transforms, utils
9 | from PIL import Image
10 | from DRL.actor import *
11 | from Renderer.stroke_gen import *
12 | # from Renderer.model import *
13 | import pandas as pd
14 | from DRL.actor import ResNet
15 | from Renderer.model import FCN
16 | import matplotlib.pyplot as plt
17 | from test_utils import *
18 | from collections import OrderedDict
19 |
20 | # define device
21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 |
23 | # parse input arguments
24 | parser = argparse.ArgumentParser(description='Paint canvases using semantic guidance')
25 | parser.add_argument('--max_eps_len', default=50, type=int, help='max length for episode')
26 | parser.add_argument('--actor', default='pretrained_models/actor_semantic_guidance.pkl', type=str, help='actor model')
27 | parser.add_argument('--use_baseline', action='store_true', help='use baseline model instead of semantic guidance')
28 | parser.add_argument('--renderer', default='../data/renderer.pkl', type=str, help='renderer model')
29 | parser.add_argument('--img', default='../image/test.png', type=str, help='test image')
30 | args = parser.parse_args()
31 |
32 | # width of canvas used
33 | width = 128
34 | # batch size and timestep image (T)
35 | n_batch = 1
36 | T = torch.ones([n_batch, 1, width, width], dtype=torch.float32).to(device)
37 | coord = torch.zeros([n_batch, 2, width, width])
38 | for i in range(width):
39 | for j in range(width):
40 | coord[:, 0, i, j] = i / (width - 1.)
41 | coord[:, 1, i, j] = j / (width - 1.)
42 | coord = coord.to(device) # Coordconv
43 |
44 | # define, load actor and set to eval mode
45 | bundle_size = 5
46 | use_gbp, use_bilevel, use_neural_alignment = False, False, False
47 | if not args.use_baseline:
48 | use_gbp, use_bilevel, use_neural_alignment = True, True, True
49 | bundle_size = int(np.ceil(bundle_size / 2))
50 | actor = ResNet(9 + use_gbp + use_bilevel + 2 * use_neural_alignment, 18, 13 * bundle_size * (1 + use_bilevel))
51 |
52 | # load trained actor model
53 | state_dict = torch.load(args.actor, map_location={'cuda:0': 'cpu'})
54 | new_state_dict = OrderedDict()
55 | for k, v in state_dict.items():
56 | name = k.replace("module.", "") # k[7:] # remove `module.`
57 | new_state_dict[name] = v
58 | actor.load_state_dict(new_state_dict)
59 | actor = actor.to(device).eval()
60 |
61 | if not args.use_baseline:
62 | # load object localization and semantic segmentation model
63 | seg_model = get_model_instance_segmentation(num_classes=2)
64 | seg_model.load_state_dict(torch.load('../data/birds_obj_seg.pkl', map_location={'cuda:0': 'cpu'}))
65 | for param in seg_model.parameters():
66 | param.requires_grad = False
67 | seg_model.to(device)
68 | seg_model.eval()
69 |
70 | # load expert model for getting guided backpropagation maps
71 | expert_model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True,
72 | **{'topN': 6, 'device': 'cpu', 'num_classes': 200})
73 | expert_model.eval()
74 | GBP = GuidedBackprop(expert_model.pretrained_model)
75 |
76 | # get segmentation, bbox and guided backpropagation map predictions
77 | img = Image.open(args.img) # cv2.cvtColor(cv2.imread(img_name, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
78 | img_ = transform_test(img).unsqueeze(0).to(device)
79 | with torch.no_grad():
80 | prediction = seg_model(img_)
81 | bbox_pred = prediction[0]['boxes'][0]
82 |     x, y, w, h = bbox_pred.detach().cpu().numpy().astype(int)  # Mask R-CNN boxes are (x1, y1, x2, y2)
83 | seg_mask = prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy()
84 | classgt_mask = seg_mask.copy()
85 |
86 | # get guided backpropagation maps from the expert model
87 | img = cv2.cvtColor(cv2.imread(args.img, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
88 | H, W, C = img.shape
89 | img = img[y:h, x:w].astype(np.uint8)
90 | seg_mask = seg_mask.astype(float) / 255.
91 | seg_mask = seg_mask[y:h, x: w]
92 |
93 | img = img.astype(float) * np.expand_dims(seg_mask, -1)
94 | img = img.astype(np.uint8)
95 |
96 | scaled_img = gbp_transform(img)
97 | torch_images = scaled_img.unsqueeze(0)
98 |
99 | guided_grads = GBP.generate_gradients(torch_images) # .transpose(1,2,0)
100 | gbp = convert_to_grayscale(guided_grads)[0]
101 | gbp = cv2.resize(gbp, (w - x, h - y))
102 | gbp = (255 * gbp).astype(np.uint8)
103 |
104 | seg_mask = (seg_mask > 0.5).astype(np.uint8)
105 | gbp = gbp * seg_mask
106 |
107 | gbp_global = np.zeros((H, W)).astype(np.uint8)
108 | w, h = min(w, W), min(h, H)
109 | gbp_global[y:h, x:w] = gbp
110 |
111 | # get grid for spatial transformer network
112 | w, h = w - x, h - y
113 | x, y, w, h = x / W, y / H, w / W, h / H
114 | Affine_Mat_w = [w, 0, (2 * x + w - 1)]
115 | Affine_Mat_h = [0, h, (2 * y + h - 1)]
116 | M = np.c_[Affine_Mat_w, Affine_Mat_h].T
117 | M = torch.tensor(M).unsqueeze(0)
118 | grid = torch.nn.functional.affine_grid(M, (1, 3, 128, 128)) # (1,128,128,2)
119 | grid = (grid + 1) / 2 # scale between 0,1
120 | grid = torch.tensor(grid * 255, dtype=torch.uint8).permute(0, 3, 1, 2)
121 | grid = grid.to(device).float() / 255.
122 |
123 | # load segmentation image
124 | classgt_img = cv2.resize(classgt_mask, (128, 128))
125 | classgt_img = torch.tensor(classgt_img, dtype=torch.uint8)
126 | classgt_img = classgt_img.unsqueeze(0).unsqueeze(0).to(device).float() / 255.
127 |
128 | # load guided backpropagation map
129 | gbp_global = cv2.resize(gbp_global, (128, 128))
130 | gbp_img = torch.tensor(gbp_global, dtype=torch.uint8)
131 | gbp_img = gbp_img.unsqueeze(0).unsqueeze(0).to(device).float() / 255.
132 |
133 | # load image
134 | image = cv2.cvtColor(cv2.imread(args.img, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
135 | image = cv2.resize(image, (128, 128))
136 | image = torch.tensor(image, dtype=torch.uint8).permute(2, 0, 1)
137 | img = image.unsqueeze(0).to(device).float() / 255.
138 |
139 | # initialize empty canvas
140 | canvas = torch.zeros([n_batch, 3, width, width]).to(device)
141 | max_eps_len = args.max_eps_len
142 | # initialize canvas frames for generating video
143 | canvas_frames = []
144 | canvas_ = canvas.permute(0, 2, 3, 1).cpu().numpy()
145 | canvas_frames.append(canvas_[0])
146 |
147 | with torch.no_grad():
148 | # generate canvas episode
149 | for i in range(max_eps_len):
150 | stroke_type = 'both' # 'fg' if i > 1 else 'both'
151 | stepnum = T * i / max_eps_len
152 | if args.use_baseline:
153 | actions = actor(torch.cat([canvas, img, stepnum, coord], 1))
154 | canvas, _, frames = decode(actions, canvas, get_frames=True, bundle_size=bundle_size)
155 | else:
156 | actions = actor(torch.cat([canvas, img, stepnum, coord, gbp_img, classgt_img, grid], 1))
157 | canvas, _, frames = decode_parallel(actions, canvas, seg_mask=classgt_img, stroke_type=stroke_type,
158 | get_frames=True, bundle_size=bundle_size)
159 | canvas_frames += frames
160 |
161 | # get final painted canvas
162 | canvas_ = canvas.permute(0, 2, 3, 1).cpu().numpy()
163 | canvas_ = np.array(255 * canvas_[0]).astype(np.uint8)
164 | # H, W, C = cv2.imread(args.img, cv2.IMREAD_COLOR).shape
165 | H, W = 250, 250
166 | painted_canvas = cv2.cvtColor(cv2.resize(canvas_, (W, H)), cv2.COLOR_RGB2BGR)
167 |
168 | # save generated canvas
169 | save_img_name = "../output/{}_painted_{}".format('baseline' if args.use_baseline else 'sg', os.path.basename(args.img))
170 | print("\nSaving painted canvas to {}\n".format(save_img_name))
171 | cv2.imwrite(save_img_name, painted_canvas)
172 |
173 | # generate and save video for the painting episode
174 | video_name = '../video/{}_{}.mp4'.format('baseline' if args.use_baseline else 'sg', os.path.basename(args.img)[:-4])
175 | video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), 20, (W, H))
176 | for image in canvas_frames:
177 | image = np.array(255 * image, dtype=np.uint8)
178 | image = cv2.resize(image, (W, H))
179 | video.write(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
180 | video.release()
181 | print("\nSaving painting video to {}\n".format(video_name))
182 |
--------------------------------------------------------------------------------
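
The actor input in `test.py` is assembled by concatenating several image-shaped channels. The short tally below shows why the baseline actor takes 9 input channels and the semantic-guidance actor takes 13, matching `ResNet(9 + use_gbp + use_bilevel + 2 * use_neural_alignment, ...)` above; the key names are descriptive labels, not identifiers from the code.

```python
# channel bookkeeping for the actor input tensor assembled in test.py
baseline = {'canvas': 3, 'target image': 3, 'stepnum T': 1, 'coordconv': 2}
semantic_guidance = dict(baseline, **{'gbp map': 1, 'segmentation mask': 1, 'stn grid': 2})
print(sum(baseline.values()), sum(semantic_guidance.values()))   # 9 and 13
```
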
/semantic_guidance/test_utils.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
3 | from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
4 | from torchvision import transforms, utils
5 | import os
6 | import cv2
7 | import torch
8 | import numpy as np
9 | import argparse
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torchvision import transforms, utils
13 | from Renderer.model import FCN
14 | from torch.nn import ReLU
15 |
16 | # define device
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 |
19 |
20 | def get_model_instance_segmentation(num_classes):
21 |     # load an instance segmentation model pre-trained on COCO
22 | model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
23 |
24 | # get number of input features for the classifier
25 | in_features = model.roi_heads.box_predictor.cls_score.in_features
26 | # replace the pre-trained head with a new one
27 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
28 |
29 | # now get the number of input features for the mask classifier
30 | in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
31 | hidden_layer = 256
32 | # and replace the mask predictor with a new one
33 | model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
34 | hidden_layer,
35 | num_classes)
36 |
37 | return model
38 |
39 |
40 | # renderer model
41 | Decoder = FCN()
42 | Decoder.load_state_dict(torch.load('../data/renderer.pkl'))
43 | Decoder = Decoder.to(device).eval()
44 |
45 | def decode_parallel(x, canvas, seg_mask=None, mask=False, stroke_type='both', get_frames=False, bundle_size=3):
46 | # print (seg_mask.shape)
47 |
48 | bg, fg = True, True
49 |
50 | if stroke_type == 'fg':
51 | bg = False
52 | if stroke_type == 'bg':
53 | fg = False
54 |
55 | canvas_frames = []
56 | if get_frames:
57 | if bg:
58 | canvas, _, frames = decode(x[:, :13 * bundle_size], canvas, mask, 1 - seg_mask, get_frames=get_frames,bundle_size=bundle_size)
59 | canvas_frames += frames
60 | if fg:
61 | canvas, _, frames = decode(x[:, 13 * bundle_size:], canvas, mask, seg_mask, get_frames=get_frames,bundle_size=bundle_size)
62 | canvas_frames += frames
63 | return canvas, _, canvas_frames
64 |
65 | if bg:
66 | canvas, _ = decode(x[:, :13 * bundle_size], canvas, mask, 1 - seg_mask, get_frames=get_frames,bundle_size=bundle_size)
67 | if fg:
68 | canvas, _ = decode(x[:, 13 * bundle_size:], canvas, mask, seg_mask, get_frames=get_frames,bundle_size=bundle_size)
69 | return canvas, _
70 |
71 |
72 | def decode(x, canvas, mask=False, seg_mask=None, get_frames=False, bundle_size=5): # b * (10 + 3)
73 | """
74 | Update canvas given stroke parameters x
75 | """
76 | # 13 stroke parameters (10 position and 3 RGB color)
77 | x = x.contiguous().view(-1, 10 + 3)
78 |
79 | # get stroke on an empty canvas given 10 positional parameters
80 | stroke = 1 - Decoder(x[:, :10])
81 |
82 | stroke = stroke.view(-1, 128, 128, 1)
83 |
84 | # add color to the stroke
85 | color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
86 | stroke = stroke.permute(0, 3, 1, 2)
87 | color_stroke = color_stroke.permute(0, 3, 1, 2)
88 |
89 | stroke = stroke.view(-1, bundle_size, 1, 128, 128)
90 | color_stroke = color_stroke.view(-1, bundle_size, 3, 128, 128)
91 |
92 | canvas_frames = []
93 | for i in range(bundle_size):
94 | if seg_mask is not None:
95 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i] * seg_mask
96 | else:
97 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
98 | canvas_ = canvas.permute(0, 2, 3, 1).cpu().numpy()
99 | canvas_frames.append(canvas_[0])
100 |
101 | # also return stroke mask if required
102 | stroke_mask = None
103 | if mask:
104 | stroke_mask = (stroke != 0).float() # -1, bundle_size, 1, width, width
105 |
106 | if get_frames:
107 | return canvas, stroke_mask, canvas_frames
108 | return canvas, stroke_mask
109 |
110 |
111 | gbp_transform = transforms.Compose([
112 | transforms.ToPILImage(),
113 | transforms.Resize((440, 440)),
114 | transforms.CenterCrop((440, 440)),
115 | # transforms.RandomHorizontalFlip(), # only if train
116 | transforms.ToTensor(),
117 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
118 | ])
119 |
120 | transform_test = transforms.Compose([
121 | transforms.ToTensor(),
122 | ])
123 |
124 |
125 | def convert_to_grayscale(im_as_arr):
126 | """
127 | Converts 3d image to grayscale
128 | Args:
129 | im_as_arr (numpy arr): RGB image with shape (D,W,H)
130 | returns:
131 | grayscale_im (numpy_arr): Grayscale image with shape (1,W,D)
132 | """
133 | grayscale_im = np.sum(np.abs(im_as_arr), axis=0)
134 | im_max = np.percentile(grayscale_im, 99)
135 | im_min = np.min(grayscale_im)
136 | grayscale_im = (np.clip((grayscale_im - im_min) / (im_max - im_min), 0, 1))
137 | grayscale_im = np.expand_dims(grayscale_im, axis=0)
138 | return grayscale_im
139 |
140 |
141 | class GuidedBackprop():
142 | """
143 | Produces gradients generated with guided back propagation from the given image
144 | """
145 |
146 | def __init__(self, model):
147 | self.model = model
148 | self.gradients = None
149 | # Put model in evaluation mode
150 | self.model.eval()
151 | self.update_relus()
152 | self.hook_layers()
153 |
154 | def hook_layers(self):
155 | def hook_function(module, grad_in, grad_out):
156 | # print (module,grad_in,grad_out)
157 | self.gradients = grad_in[0]
158 |
159 | # Register hook to the first layer
160 | first_layer = list(self.model.children())[0]
161 | first_layer.register_backward_hook(hook_function)
162 |
163 | def update_relus(self):
164 | """
165 |         Updates relu activation functions so that they only return positive gradients
166 | """
167 |
168 | def relu_hook_function(module, grad_in, grad_out):
169 | """
170 | If there is a negative gradient, changes it to zero
171 | """
172 | if isinstance(module, ReLU):
173 | return (torch.clamp(grad_in[0], min=0.0),)
174 |
175 | # Loop through layers, hook up ReLUs with relu_hook_function
176 | for module in self.model.modules():
177 | if isinstance(module, ReLU):
178 | module.register_backward_hook(relu_hook_function)
179 |
180 | def generate_gradients(self, input_image):
181 | input_image.requires_grad = True
182 | # Forward pass
183 | model_output = self.model(input_image)[0]
184 | target_class = torch.argmax(model_output).item()
185 |
186 | # Zero gradients
187 | self.model.zero_grad()
188 | # Target for backprop
189 | one_hot_output = torch.FloatTensor(1, model_output.size()[-1]).zero_()
190 | one_hot_output[0][target_class] = 1
191 |
192 | # print (one_hot_output)
193 | # Backward pass
194 | model_output.backward(gradient=one_hot_output)
195 | # Convert Pytorch variable to numpy array
196 | # [0] to get rid of the first channel (1,3,224,224)
197 | gradients_as_arr = self.gradients.data.numpy()[0]
198 | return gradients_as_arr
199 |
--------------------------------------------------------------------------------
/semantic_guidance/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import cv2
3 | import random
4 | import numpy as np
5 | import argparse
6 | from DRL.evaluator import Evaluator
7 | from utils.util import *
8 | from utils.tensorboard import TensorBoard
9 | import time
10 |
11 | from DRL.ddpg import DDPG
12 |
13 | import argparse
14 | import torch.distributed as dist
15 | import sys, os
16 | import torch
17 | import torch.multiprocessing as mp
18 | import numpy as np
19 |
20 | # exp = os.path.abspath('.').split('/')[-1]
21 | # writer = TensorBoard('../train_log/{}'.format(exp))
22 | # os.system('ln -sf ../train_log/{} ./log'.format(exp))
23 | os.makedirs('./model', exist_ok=True)
24 |
25 |
26 | def train(agent, evaluate, writer, args, gpu, distributed):
27 | """
28 | :param agent: DDPG agent
29 |     :param evaluate: Evaluator used for periodic validation
30 | :param writer: tensorboard summary writer
31 | :return: None
32 | """
33 | ## hyperparameters
34 | # total timesteps for training
35 | train_timesteps = args.train_timesteps
36 | # number of parallel environments for faster sample collection
37 | nenv = args.nenv
38 | # number of episodes between validation tests
39 | val_interval = args.val_interval
40 | # maximum length of a single painting episode (number of brush strokes)
41 | max_eps_len = args.max_eps_len
42 | # number of training steps per episode
43 | train_steps_per_eps = args.train_steps_per_eps
44 | # noise factor used in training
45 | noise_factor = args.noise_factor
46 |
47 | ## display progress
48 | # verbose: print training progress if true
49 | debug = args.debug
50 |
51 | ## load and save directories
52 | # path for stored model if resuming training
53 | load_path = args.load_path
54 | # parent directory for storing trained models
55 | output = args.output
56 |
57 | ## intializations
58 | # get current time stamp
59 | time_stamp = time.time()
60 | # initialize training steps
61 | step = episode = episode_steps = 0
62 | # initialize total reward
63 | tot_reward = 0.
64 | # initialize state
65 | observation = None
66 |
67 | # synchronize initial models
68 | agent.save_model(output, episode)
69 |
70 | # begin training
71 | while step <= train_timesteps:
72 | # update training steps
73 | step += 1
74 | # steps within an episode (cannot be greater than max_eps_len)
75 | episode_steps += 1
76 |
77 | ## take a step in concurrent environments and store samples to the replay buffer
78 | # reset if it is the start of episode
79 | if observation is None:
80 | observation = agent.env.reset()
81 | agent.reset(observation, noise_factor)
82 | action = agent.select_action(observation, noise_factor=noise_factor)
83 | observation, reward, done, _ = agent.env.step(action)
84 | # store r,s,a tuple to replay memory
85 | agent.observe(reward, observation, done, step)
86 |
87 | # store progress and train if episode is done
88 | if episode_steps >= max_eps_len and max_eps_len:
89 | if step > args.warmup:
90 | # compute validation results
91 | if episode > 0 and val_interval > 0 and episode % val_interval == 0:
92 | reward, dist = evaluate(agent.env, agent.select_action, debug=debug)
93 | if debug and gpu == 0:
94 | prRed('Step_{:07d}: mean_reward:{:.3f} mean_dist:{:.3f} var_dist:{:.3f}'.format(step - 1,
95 | np.mean(reward),
96 | np.mean(dist),
97 | np.var(dist)))
98 | writer.add_scalar('validate/mean_reward', np.mean(reward), step)
99 | writer.add_scalar('validate/mean_dist', np.mean(dist), step)
100 | writer.add_scalar('validate/var_dist', np.var(dist), step)
101 |
102 | # save latest model
103 | if episode % 200 == 0:
104 | agent.save_model(output, episode)
105 |
106 | # get training time and update timestamp
107 | train_time_interval = time.time() - time_stamp
108 | time_stamp = time.time()
109 |
110 | # initialize total expected value and overall value loss for the episode
111 | tot_Q = 0.
112 | tot_value_loss = 0.
113 |
114 | if step > args.warmup:
115 | # step learning rate schedule after (1e4,2e4) episodes
116 | # also note lr[0],lr[1] are for updating critic,actor respectively.
117 | if step < 10000 * max_eps_len:
118 | lr = (3e-4, 1e-3)
119 | elif step < 20000 * max_eps_len:
120 | lr = (1e-4, 3e-4)
121 | else:
122 | lr = (3e-5, 1e-4)
123 |
124 | # perform several training steps after each episode
125 | for i in range(train_steps_per_eps):
126 | # train the agent
127 | Q, value_loss = agent.update_policy(lr)
128 | tot_Q += Q.data.cpu().numpy()
129 | tot_value_loss += value_loss.data.cpu().numpy()
130 |
131 | # store training performance summaries
132 | if gpu == 0:
133 | writer.add_scalar('train/critic_lr', lr[0], step)
134 | writer.add_scalar('train/actor_lr', lr[1], step)
135 | writer.add_scalar('train/Q', tot_Q / train_steps_per_eps, step)
136 | writer.add_scalar('train/critic_loss', tot_value_loss / train_steps_per_eps, step)
137 |
138 | # display training progress
139 | if debug and gpu == 0:
140 | prBlack('#{}: steps:{} interval_time:{:.2f} train_time:{:.2f}' \
141 | .format(episode, step, train_time_interval, time.time() - time_stamp))
142 |
143 | # reset/update timestamp and episode stats
144 | time_stamp = time.time()
145 | observation = None
146 | episode_steps = 0
147 | episode += 1
148 |
149 |
150 | def setup(args):
151 | os.environ['MASTER_ADDR'] = 'localhost'
152 | os.environ['MASTER_PORT'] = '12354'
153 |
154 | # initialize the process group
155 | dist.init_process_group("nccl", rank=args.rank, world_size=args.world_size)
156 | # limit number of threads per process
157 | setup_pytorch_for_mpi(args)
158 | torch.backends.cudnn.benchmark = True
159 | torch.backends.cudnn.deterministic = False
160 |
161 |
162 | def cleanup():
163 | dist.destroy_process_group()
164 |
165 |
166 | def setup_pytorch_for_mpi(args):
167 | """
168 | Avoid slowdowns caused by each separate process's PyTorch using
169 | more than its fair share of CPU resources.
170 | """
171 | print('Proc %d: Reporting original number of Torch threads as %d.' % (args.rank, torch.get_num_threads()),
172 | flush=True)
173 | if torch.get_num_threads() == 1:
174 | return
175 | fair_num_threads = max(int(torch.get_num_threads() / args.world_size), 1)
176 |     fair_num_threads = 1  # override: force a single Torch thread per process
177 | torch.set_num_threads(fair_num_threads)
178 | print('Proc %d: Reporting new number of Torch threads as %d.' % (args.rank, torch.get_num_threads()), flush=True)
179 |
180 |
181 | def demo(gpu, args):
182 | # rank of the current process
183 | args.rank = args.nr * args.gpus + gpu
184 | # setup dist process
185 | setup(args)
186 |
187 | # Random seed
188 | seed = 10000 * gpu + gpu
189 | torch.manual_seed(seed)
190 | np.random.seed(seed)
191 | random.seed(seed)
192 |
193 |     # summary writer
194 | writer = TensorBoard(args.LOG_DIR)
195 |
196 | # setup concurrent environments
197 | # define agent
198 | agent = DDPG(args.batch_size, args.nenv, args.max_eps_len,
199 | args.tau, args.discount, args.rmsize,
200 | writer, args.load_path, args.output, args.dataset,
201 | use_gbp=args.use_gbp, use_bilevel=args.use_bilevel,
202 | gbp_coef=args.gbp_coef, seggt_coef=args.seggt_coef,
203 | gpu=gpu, distributed=False,
204 | high_res=args.high_res, bundle_size=args.bundle_size)
205 | evaluate = Evaluator(args, writer)
206 |
207 | # display state, action space info
208 | if gpu == 0:
209 | print('observation_space', agent.env.observation_space, 'action_space', agent.env.action_space)
210 |
211 | # begin training
212 | train(agent, evaluate, writer, args, gpu=gpu, distributed=True)
213 |
214 | if gpu == 0:
215 | print("Training finished")
216 |
217 | cleanup()
218 |
219 |
220 | if __name__ == "__main__":
221 | parser = argparse.ArgumentParser(description='Learning to Paint')
222 |
223 | # hyper-parameter
224 | parser.add_argument('--dataset', type=str, default='cub200', choices=['cub200'],
225 | help='dataset')
226 | parser.add_argument('--warmup', default=400, type=int,
227 |                         help='number of timesteps used only to fill the replay memory before training starts')
228 | parser.add_argument('--discount', default=0.95 ** 5, type=float, help='discount factor (gamma)')
229 | parser.add_argument('--batch_size', default=96, type=int, help='minibatch size')
230 | parser.add_argument('--bundle_size', default=5, type=int, help='action bundle size')
231 | parser.add_argument('--rmsize', default=800, type=int, help='replay memory size')
232 | parser.add_argument('--nenv', default=96, type=int,
233 |                         help='number of concurrent environments')
234 | parser.add_argument('--tau', default=0.001, type=float, help='moving average for target network')
235 | parser.add_argument('--max_eps_len', default=40, type=int, help='max length for episode (*)')
236 | parser.add_argument('--noise_factor', default=0, type=float, help='noise level for parameter space noise')
237 | parser.add_argument('--val_interval', default=50, type=int, help='episode interval for performing validation')
238 | parser.add_argument('--val_num_eps', default=5, type=int, help='episodes used for performing validation')
239 | parser.add_argument('--train_timesteps', default=int(2e6), type=int, help='total training steps')
240 | parser.add_argument('--train_steps_per_eps', default=10, type=int, help='number of training steps per episode')
241 | parser.add_argument('--load_path', default=None, type=str, help='Load model and resume training')
242 | parser.add_argument('--exp_suffix', default='base', type=str,
243 | help='suffix for providing additional experiment info')
244 | parser.add_argument('--output', default='./model', type=str, help='Output path for storing model')
245 | parser.add_argument('--use_gbp', action='store_true', help='use gbp info along with rgb image')
246 | parser.add_argument('--use_bilevel', action='store_true', help='use semantic class maps info along with rgb image')
247 | parser.add_argument('--bundled_seggt', action='store_true',
248 | help='all strokes in a bundle should belong to same class')
249 | parser.add_argument('--gbp_coef', default=1.0, type=float, help='coefficient for gbp reward')
250 | parser.add_argument('--seggt_coef', default=1.0, type=float, help='coefficient for seggt reward')
251 | parser.add_argument('--high_res', action='store_true', help='use high resolution gt(256 x 256)')
252 | parser.add_argument('--debug', dest='debug', action='store_true', help='print some info')
253 | parser.add_argument('--seed', default=1234, type=int, help='random seed')
254 | parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
255 |                         help='number of nodes (default: 1)')
256 | parser.add_argument('-g', '--gpus', default=1, type=int, help='number of gpus per node')
257 | parser.add_argument('-nr', '--nr', default=0, type=int, help='ranking within the nodes')
258 |
259 | # parse args
260 | args = parser.parse_args()
261 |
262 |     # if using the bilevel painting procedure, use only half of the brush strokes for the foreground
263 |     # and the rest for the background
264 | if args.use_bilevel:
265 | args.bundle_size = int(np.ceil(args.bundle_size/2))
266 |
267 | # log directory
268 | exp_type = os.path.abspath('.').split('/')[-1]
269 | LOG_DIR = "../train_log/{}/".format(exp_type)
270 | # choose log directory based on the experiment
271 | exp_name = "{}/nenv{}_batchsize{}_maxstep_{}_tau{}_memsize{}_{}{}{}_{}".format(args.dataset, args.nenv,
272 | args.batch_size,
273 | args.max_eps_len, args.tau,
274 | args.rmsize,
275 | "bundlesize{}".format(args.bundle_size),
276 | "_gbp{}".format(args.gbp_coef) if args.use_gbp else "",
277 | "_seggt{}".format(args.seggt_coef) if args.use_bilevel else "",
278 | args.exp_suffix)
279 |
280 | # create summary writer
281 | LOG_DIR += exp_name
282 | args.LOG_DIR = LOG_DIR
283 |
284 | # create output directory
285 | args.output = get_output_folder(args.output + "/" + exp_name, "Paint")
286 |
287 | # total number of processes
288 | args.world_size = args.gpus * args.nodes
289 |
290 | print("starting")
291 | # launch multiple processes
292 | mp.spawn(demo, nprocs=args.gpus, args=(args,), join=True)
293 |
--------------------------------------------------------------------------------
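
With `--use_bilevel`, the argument post-processing above halves the action bundle so that one pass paints the background and one paints the foreground. A small arithmetic check of the resulting sizes, consistent with the actor output dimension `13 * bundle_size * (1 + use_bilevel)` used in `test.py`:

```python
import numpy as np

bundle_size = int(np.ceil(5 / 2))   # -> 3 strokes per pass (from the default 5-stroke bundle)
action_dim = 13 * bundle_size * 2   # 13 parameters per stroke, one pass each for bg and fg
print(bundle_size, action_dim)      # 3 78
```
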
/semantic_guidance/train_renderer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import argparse
4 | import numpy as np
5 |
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from utils.tensorboard import TensorBoard
9 | from Renderer.model import FCN
10 | from Renderer.stroke_gen import *
11 |
12 | def save_model():
13 | if use_cuda:
14 | net.cpu()
15 | torch.save(net.state_dict(), "data/renderer.pkl")
16 | if use_cuda:
17 | net.cuda()
18 |
19 |
20 | def load_weights():
21 | pretrained_dict = torch.load("data/renderer.pkl")
22 | model_dict = net.state_dict()
23 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
24 | model_dict.update(pretrained_dict)
25 | net.load_state_dict(model_dict)
26 |
27 |
28 | if __name__ == "__main__":
29 | parser = argparse.ArgumentParser(description='Learning to Paint')
30 |
31 | # hyper-parameter
32 |     parser.add_argument('--high_res', action='store_true', help='train the high resolution (256 x 256) renderer')
33 | args = parser.parse_args()
34 |
35 | writer = TensorBoard("./train_log/")
36 | import torch.optim as optim
37 |
38 | criterion = nn.MSELoss()
39 | net = FCN(args.high_res)
40 | optimizer = optim.Adam(net.parameters(), lr=3e-6)
41 | batch_size = 64
42 |
43 | use_cuda = torch.cuda.is_available()
44 | step = 0
45 |
46 |     load_weights()  # resume from an existing data/renderer.pkl; comment out when training from scratch
47 | while step < 500000:
48 | net.train()
49 | train_batch = []
50 | ground_truth = []
51 | for i in range(batch_size):
52 | f = np.random.uniform(0, 1, 10)
53 | train_batch.append(f)
54 | if args.high_res is True:
55 | ground_truth.append(draw(f, width=256))
56 | else:
57 | ground_truth.append(draw(f))
58 |
59 | train_batch = torch.tensor(train_batch).float()
60 | ground_truth = torch.tensor(ground_truth).float()
61 | if use_cuda:
62 | net = net.cuda()
63 | train_batch = train_batch.cuda()
64 | ground_truth = ground_truth.cuda()
65 | gen = net(train_batch)
66 | optimizer.zero_grad()
67 | loss = criterion(gen, ground_truth)
68 | loss.backward()
69 | optimizer.step()
70 | print(step, loss.item())
71 | if step < 200000:
72 | lr = 1e-4
73 | elif step < 400000:
74 | lr = 1e-5
75 | else:
76 | lr = 1e-6
77 | for param_group in optimizer.param_groups:
78 | param_group["lr"] = lr
79 | writer.add_scalar("train/loss", loss.item(), step)
80 | if step % 100 == 0:
81 | net.eval()
82 | gen = net(train_batch)
83 | loss = criterion(gen, ground_truth)
84 | writer.add_scalar("val/loss", loss.item(), step)
85 | for i in range(32):
86 | G = gen[i].cpu().data.numpy()
87 | GT = ground_truth[i].cpu().data.numpy()
88 | writer.add_image("train/gen{}.png".format(i), G, step)
89 | writer.add_image("train/ground_truth{}.png".format(i), GT, step)
90 | if step % 1000 == 0:
91 | save_model()
92 | step += 1
93 |
--------------------------------------------------------------------------------
/semantic_guidance/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/semantic_guidance/utils/__init__.py
--------------------------------------------------------------------------------
/semantic_guidance/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import torch
4 | import numpy as np
5 | from torchvision import transforms, utils
6 | import cv2
7 |
8 |
9 | def adjust_gamma(image, gamma=1.0):
10 | # build a lookup table mapping the pixel values [0, 255] to
11 | # their adjusted gamma values
12 | invGamma = 1.0 / gamma
13 | table = np.array([((i / 255.0) ** invGamma) * 255
14 | for i in np.arange(0, 256)]).astype("uint8")
15 | # apply gamma correction using the lookup table
16 | return cv2.LUT(image, table)
17 |
18 |
19 | class ImageDataset:
20 | """
21 | Dataset for loading images at run-time
22 | """
23 |
24 | def __init__(self, img_list, gbp_list=None, seg_list=None, transform=None, width=128,
25 | bbox_list=None, high_res=False):
26 | """
27 | Args:
28 |             img_list (list): list of image file paths
29 | transform (callable, optional): Optional transform to be applied
30 | on a sample.
31 | """
32 | self.img_list = img_list
33 | self.gbp_list = gbp_list
34 | self.seg_list = seg_list
35 | self.transform = transform
36 | self.width = width
37 | self.high_res = high_res
38 | self.bbox_list = bbox_list
39 |
40 | if self.high_res is True:
41 | self.width = 256
42 |
43 | # custom transforms
44 | self.horizontal_flip = transforms.RandomHorizontalFlip(p=1.0)
45 | self.resize = transforms.Resize((width, width))
46 | self.toTensor = transforms.ToTensor()
47 |
48 | def __len__(self):
49 | return len(self.img_list)
50 |
51 | def __getitem__(self, idx, gbp_img=None):
52 | if torch.is_tensor(idx):
53 | idx = idx.tolist()
54 |
55 | # flip horizontal
56 | flip_horizontal = np.random.rand() > 0.5
57 |
58 | # read rgb image
59 | img_name = self.img_list[idx]
60 | image = cv2.cvtColor(cv2.imread(img_name), cv2.COLOR_BGR2RGB)
61 | H, W, C = image.shape
62 | image = cv2.resize(image, (self.width, self.width))
63 | # image = adjust_gamma(image, gamma=1.5)
64 | if flip_horizontal:
65 | image = cv2.flip(image, 1)
66 | image = torch.tensor(image, dtype=torch.uint8).permute(2, 0, 1)
67 |
68 | # initialize gbp, seg_img image
69 | gbp_img, seg_img, grid = None, None, None
70 |
71 | # read gbp image
72 | if self.gbp_list is not None:
73 | gbp_fname = self.gbp_list[idx]
74 |
75 | gbp_img = cv2.cvtColor(cv2.imread(gbp_fname), cv2.COLOR_BGR2GRAY)
76 | gbp_img = cv2.resize(gbp_img, (self.width, self.width))
77 | if flip_horizontal:
78 | gbp_img = cv2.flip(gbp_img, 1)
79 | gbp_img = torch.tensor(gbp_img, dtype=torch.uint8).unsqueeze(0)
80 |
81 | if self.seg_list is not None:
82 | seg_fname = self.seg_list[idx]
83 | seg_img = cv2.cvtColor(cv2.imread(seg_fname), cv2.COLOR_BGR2GRAY)
84 | seg_img = cv2.resize(seg_img, (self.width, self.width))
85 | if flip_horizontal:
86 | seg_img = cv2.flip(seg_img, 1)
87 | # convert to tensor
88 | seg_img = torch.tensor(seg_img, dtype=torch.uint8).unsqueeze(0)
89 |
90 | # create grid
91 | x, y, w, h = self.bbox_list[idx]
92 | x, y, w, h = x / W, y / H, w / W, h / H
93 | if flip_horizontal:
94 | x = 1 - x - w
95 | Affine_Mat_w = [w, 0, (2 * x + w - 1)]
96 | Affine_Mat_h = [0, h, (2 * y + h - 1)]
97 | M = np.c_[Affine_Mat_w, Affine_Mat_h].T
98 | M = torch.tensor(M).unsqueeze(0)
99 | grid = torch.nn.functional.affine_grid(M, (1, 3, 128, 128)) # (1,128,128,2)
100 | grid = (grid + 1) / 2 # scale between 0,1
101 | grid = torch.tensor(grid * 255, dtype=torch.uint8).permute(0, 3, 1, 2)
102 |
103 | return image, gbp_img, seg_img, grid
--------------------------------------------------------------------------------
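
The bounding-box grid built in `__getitem__` encodes the predicted object location as a 2-channel image that the actor can consume. Below is a minimal sketch of the same construction for a single normalized box (x, y, w, h); the box values are made up for illustration.

```python
import torch
import torch.nn.functional as F

# normalized bounding box (x, y, w, h) of the foreground object
x, y, w, h = 0.25, 0.30, 0.50, 0.40

# 2x3 affine matrix mapping the full canvas onto the box region (same construction as above)
M = torch.tensor([[w, 0.0, 2 * x + w - 1],
                  [0.0, h, 2 * y + h - 1]], dtype=torch.float32).unsqueeze(0)
grid = F.affine_grid(M, (1, 3, 128, 128))   # (1, 128, 128, 2), sampling locations in [-1, 1]
grid = (grid + 1) / 2                       # rescale to [0, 1]
grid = grid.permute(0, 3, 1, 2)             # -> (1, 2, 128, 128) two-channel location map
print(grid.shape)
```
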
/semantic_guidance/utils/tensorboard.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import scipy.misc
3 | from io import BytesIO
4 | import tensorboardX as tb
5 | from tensorboardX.summary import Summary
6 |
7 | class TensorBoard(object):
8 | def __init__(self, model_dir):
9 | self.summary_writer = tb.FileWriter(model_dir)
10 |
11 | def add_image(self, tag, img, step):
12 | summary = Summary()
13 | bio = BytesIO()
14 |
15 |         if isinstance(img, str):
16 |             img = Image.open(img)
17 |         elif isinstance(img, Image.Image):
18 |             pass
19 |         else:
20 |             img = scipy.misc.toimage(img)  # numpy array; needs an older SciPy that still ships toimage
21 |
22 | img.save(bio, format="png")
23 | image_summary = Summary.Image(encoded_image_string=bio.getvalue())
24 | summary.value.add(tag=tag, image=image_summary)
25 | self.summary_writer.add_summary(summary, global_step=step)
26 |
27 | def add_scalar(self, tag, value, step):
28 | summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
29 | self.summary_writer.add_summary(summary, global_step=step)
30 |
--------------------------------------------------------------------------------
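For reference, a hedged sketch of how this `TensorBoard` wrapper could be driven during training (log directory, tags and values are made up for illustration):

```python
# Illustrative sketch only -- the log directory, tags and values are arbitrary.
from utils.tensorboard import TensorBoard

writer = TensorBoard("train_log/demo-run")

for step in range(100):
    loss = 1.0 / (step + 1)                      # dummy scalar
    writer.add_scalar("train/critic_loss", loss, step)

# add_image accepts a file path or a PIL.Image directly; numpy arrays fall back to
# scipy.misc.toimage, which is only available in older SciPy releases.
writer.add_image("train/canvas", "some_canvas.png", step)  # hypothetical file path
```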
/semantic_guidance/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.autograd import Variable
4 |
5 | USE_CUDA = torch.cuda.is_available()
6 |
7 | def prRed(prt): print("\033[91m {}\033[00m" .format(prt))
8 | def prGreen(prt): print("\033[92m {}\033[00m" .format(prt))
9 | def prYellow(prt): print("\033[93m {}\033[00m" .format(prt))
10 | def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt))
11 | def prPurple(prt): print("\033[95m {}\033[00m" .format(prt))
12 | def prCyan(prt): print("\033[96m {}\033[00m" .format(prt))
13 | def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt))
14 | def prBlack(prt): print("\033[98m {}\033[00m" .format(prt))
15 |
16 | def to_numpy(var):
17 | return var.cpu().data.numpy() if USE_CUDA else var.data.numpy()
18 |
19 | def to_tensor(ndarray, device):
20 | return torch.tensor(ndarray, dtype=torch.float, device=device)
21 |
22 | def soft_update(target, source, tau):
23 | for target_param, param in zip(target.parameters(), source.parameters()):
24 | target_param.data.copy_(
25 | target_param.data * (1.0 - tau) + param.data * tau
26 | )
27 |
28 | def hard_update(target, source):
29 | for m1, m2 in zip(target.modules(), source.modules()):
30 | m1._buffers = m2._buffers.copy()
31 | for target_param, param in zip(target.parameters(), source.parameters()):
32 | target_param.data.copy_(param.data)
33 |
34 | def get_output_folder(parent_dir, env_name):
35 | """Return save folder.
36 |
37 | Assumes folders in the parent_dir have suffix -run{run
38 | number}. Finds the highest run number and sets the output folder
39 | to that number + 1. This is just convenient so that if you run the
40 | same script multiple times tensorboard can plot all of the results
41 | on the same plots with different names.
42 |
43 | Parameters
44 | ----------
45 | parent_dir: str
46 | Path of the directory containing all experiment runs.
47 |
48 | Returns
49 | -------
50 |     str
51 |         Path to this run's save directory, i.e. parent_dir/env_name-run{id}.
52 | """
53 | os.makedirs(parent_dir, exist_ok=True)
54 | experiment_id = 0
55 | for folder_name in os.listdir(parent_dir):
56 | if not os.path.isdir(os.path.join(parent_dir, folder_name)):
57 | continue
58 |         try:
59 |             run_number = int(folder_name.split('-run')[-1])
60 |             if run_number > experiment_id:
61 |                 experiment_id = run_number
62 |         except ValueError:
63 |             pass
64 | experiment_id += 1
65 |
66 | parent_dir = os.path.join(parent_dir, env_name)
67 | parent_dir = parent_dir + '-run{}'.format(experiment_id)
68 | os.makedirs(parent_dir, exist_ok=True)
69 | return parent_dir
70 |
--------------------------------------------------------------------------------
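The `soft_update` and `hard_update` helpers implement the target-network updates typical of DDPG-style training (Polyak averaging). A small self-contained sketch of their effect, using toy networks and an arbitrary `tau` (the import again assumes the `semantic_guidance` working directory):

```python
# Toy demonstration of hard_update / soft_update -- network sizes and tau are arbitrary.
import torch
import torch.nn as nn
from utils.util import soft_update, hard_update

source = nn.Linear(4, 2)
target = nn.Linear(4, 2)

hard_update(target, source)                      # target <- source (exact copy)
assert torch.allclose(target.weight, source.weight)

with torch.no_grad():
    source.weight += 0.1                         # pretend a gradient step moved the source

tau = 0.001
soft_update(target, source, tau)                 # target <- (1 - tau) * target + tau * source
print((target.weight - source.weight).abs().max())  # ~0.1 * (1 - tau): the target lags behind
```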
/video/baseline_target_bird_4648.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/video/baseline_target_bird_4648.mp4
--------------------------------------------------------------------------------
/video/sg_bird_5602.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/video/sg_bird_5602.gif
--------------------------------------------------------------------------------
/video/sg_target_bird_4648.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/video/sg_target_bird_4648.mp4
--------------------------------------------------------------------------------