├── .gitignore ├── README.md ├── assets ├── bird_4008.gif ├── bird_4648.gif ├── bird_5602.gif ├── sg_bird_4008.gif ├── sg_bird_4648.gif ├── sg_bird_5602.gif ├── target_bird_4008.png ├── target_bird_4648.png └── target_bird_5602.png ├── environment.yml ├── input ├── target_bird_4008.png ├── target_bird_4648.png └── target_bird_5602.png ├── output ├── baseline_painted_target_bird_4008.png ├── baseline_painted_target_bird_4648.png ├── sg_painted_target_bird_4008.png └── sg_painted_target_bird_4648.png ├── semantic_guidance ├── DRL │ ├── actor.py │ ├── critic.py │ ├── ddpg.py │ ├── evaluator.py │ ├── multi.py │ ├── rpm.py │ └── wgan.py ├── Renderer │ ├── __init__.py │ ├── model.py │ └── stroke_gen.py ├── __init__.py ├── env_ins.py ├── preprocess.py ├── test.py ├── test_utils.py ├── train.py ├── train_renderer.py └── utils │ ├── __init__.py │ ├── dataloader.py │ ├── tensorboard.py │ └── util.py └── video ├── baseline_target_bird_4648.mp4 ├── sg_bird_5602.gif └── sg_target_bird_4648.mp4 /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | #video 3 | .DS_Store 4 | .idea/ 5 | baseline_spc 6 | semantic_guidance/test_v2* 7 | semantic_guidance/test_copy* 8 | files 9 | semantic_guidance/pretrained_models -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Semantic-Guidance: Distilling Object Awareness into Paintings [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Automatically%20generate%20human-level%20paintings%20using%20a%20combination%20of%20Deep-RL%20and%20Semantic-Guidance&url=https://github.com/1jsingh/semantic-guidance&&hashtags=LearningToPaint,CVPR2021) 2 | 3 | This repository contains code for our CVPR-2021 paper on [Combining Semantic Guidance and Deep Reinforcement Learning For Generating Human Level Paintings](https://arxiv.org/pdf/2011.12589.pdf). 4 | 5 | The Semantic Guidance pipeline distills different forms of object awareness (semantic segmentation, object localization and guided backpropagation maps) into the painting process itself. The resulting agent is able to paint canvases with increased saliency of foreground objects and enhanced granularity of key image features. 6 | 7 | 11 | ## Contents 12 | * [Demo](#demo) 13 | * [Environment Setup](#environment-setup) 14 | * [Dataset and Preprocessing](#dataset-and-preprocessing) 15 | * [Training](#training) 16 | * [Testing using Pretrained Models](#testing-using-pretrained-models) 17 | * [Citation](#citation) 18 | 19 | 20 | ## Demo 21 | Traditional reinforcement learning based methods for the "*learning to paint*" problem, show poor performance on real world datasets with high variance in position, scale and saliency of the foreground objects. To address this we propose a semantic guidance pipeline, which distills object awareness knowledge into the painting process, and thereby learns to generate semantically accurate canvases under adverse painting conditions. 22 | 23 | | Target Image | Baseline (Huang et al. 2019) | Semantic Guidance (Ours) | 24 | |:-------------:|:-------------:|:-------------:| 25 | |||| 26 | |||| 27 | |||| 28 | 29 | 30 | 31 | ### Environment Setup 32 | 33 | * Set up the python environment for running the experiments. 
34 | ```bash
35 | conda env update --name semantic-guidance --file environment.yml
36 | conda activate semantic-guidance
37 | ```
38 | 
39 | ### Dataset and Preprocessing
40 | * Download the [CUB-200-2011 Birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset and place it in the `data/cub200/CUB_200_2011/` folder.
41 | ```bash
42 | mkdir -p data/cub200 && cd data/cub200
43 | gdown https://drive.google.com/uc?id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45
44 | tar -xvzf CUB_200_2011.tgz
45 | ```
46 | 
47 | * The final data folder should look as follows:
48 | ```bash
49 | data
50 | ├── cub200/
51 | │   └── CUB_200_2011/
52 | │       └── images/
53 | │           └── ...
54 | │       └── images.txt
55 | ```
56 | 
57 | * Download the differentiable neural renderer: [renderer.pkl](https://drive.google.com/file/d/1VloSGAWYRiVYv3bRfBuB0uKj2m7Cyzu8/view?usp=sharing) and place it in the `data/` folder.
58 | ```bash
59 | cd data
60 | gdown https://drive.google.com/uc?id=1VloSGAWYRiVYv3bRfBuB0uKj2m7Cyzu8
61 | ```
62 | 
63 | * Download the combined model for object localization and semantic segmentation from [here](https://drive.google.com/file/d/14CIdpem-85y53KkkW2oBspXh-2PPXtTs/view?usp=sharing), and place it in the `data/` folder.
64 | ```bash
65 | cd data
66 | gdown https://drive.google.com/uc?id=14CIdpem-85y53KkkW2oBspXh-2PPXtTs
67 | ```
68 | 
69 | * Choose one of the following options to get the preprocessed data predictions (preprocessing helps facilitate faster training):
70 |     * **Option 1:** run the preprocessing script to generate object localization, semantic segmentation and bounding box predictions.
71 |     ```bash
72 |     cd semantic_guidance
73 |     python preprocess.py
74 |     ```
75 | 
76 |     * **Option 2:** you can also directly download the preprocessed birds dataset from [here](https://drive.google.com/file/d/1s3lvo0Dn538lPghpXY1gEOTAZOTsojxJ/view?usp=sharing) and place the prediction folders in the original data directory.
77 |     ```bash
78 |     cd data/cub200/CUB_200_2011/
79 |     gdown https://drive.google.com/uc?id=1s3lvo0Dn538lPghpXY1gEOTAZOTsojxJ
80 |     unzip preprocessed-cub200-2011.zip
81 |     mv preprocessed-cub200-2011/* .
82 |     ```
83 | 
84 | * The final data directory should look as follows:
85 | ```bash
86 | data
87 | ├── cub200/
88 | │   └── CUB_200_2011/
89 | │       └── images/
90 | │           └── ...
91 | │       └── segmentations_pred/
92 | │           └── ...
93 | │       └── gbp_global/
94 | │           └── ...
95 | │       └── bounding_boxes_pred.txt
96 | │       └── images.txt
97 | └── renderer.pkl
98 | └── birds_obj_seg.pkl
99 | ```
100 | 
101 | ### Training
102 | 
103 | * Train the baseline model from [Huang et al. 2019](https://arxiv.org/abs/1903.04411).
104 | ```bash
105 | cd semantic_guidance
106 | python train.py \
107 | --dataset cub200 \
108 | --debug \
109 | --batch_size=96 \
110 | --max_eps_len=50 \
111 | --bundle_size=5 \
112 | --exp_suffix baseline
113 | ```
114 | 
115 | * Train the deep reinforcement learning based painting agent using the Semantic Guidance pipeline.
116 | ```bash
117 | cd semantic_guidance
118 | python train.py \
119 | --dataset cub200 \
120 | --debug \
121 | --batch_size=96 \
122 | --max_eps_len=50 \
123 | --bundle_size=5 \
124 | --use_bilevel \
125 | --use_gbp \
126 | --exp_suffix semantic-guidance
127 | ```
128 | 
129 | ### Testing using Pretrained Models
130 | 
131 | * Download the pretrained models for the [Baseline](https://drive.google.com/file/d/1OvN7yRia44nhD16KmjcAvxG8xICWl42p/view?usp=sharing) and [Semantic Guidance](https://drive.google.com/file/d/173p2rUQlNpp8fLA3u5s24QKJLU68QTkw/view?usp=sharing) agents.
Place the downloaded models in `./semantic_guidance/pretrained_models` directory. 132 | ```bash 133 | cd semantic_guidance 134 | mkdir pretrained_models && cd pretrained_models 135 | gdown https://drive.google.com/uc?id=1OvN7yRia44nhD16KmjcAvxG8xICWl42p 136 | gdown https://drive.google.com/uc?id=173p2rUQlNpp8fLA3u5s24QKJLU68QTkw 137 | ``` 138 | 139 | * The final directory structure should look as follows, 140 | ```bash 141 | semantic-guidance 142 | ├── semantic_guidance/ 143 | │ └── pretrained_models/ 144 | │ └── actor_baseline.pkl 145 | │ └── actor_semantic_guidance.pkl 146 | ``` 147 | 148 | * Generate the painting sequence using pretrained baseline agent. 149 | ```bash 150 | cd semantic_guidance 151 | python test.py \ 152 | --img ../input/target_bird_4648.png \ 153 | --actor pretrained_models/actor_baseline.pkl \ 154 | --use_baseline 155 | ``` 156 | 157 | * Use the pretrained Semantic Guidance agent to paint canvases. 158 | ```bash 159 | cd semantic_guidance 160 | python test.py \ 161 | --img ../input/target_bird_4648.png \ 162 | --actor pretrained_models/actor_semantic_guidance.pkl 163 | ``` 164 | 165 | * The `test` script stores the final canvas state in the `./output` folder and saves a video for the painting sequence in the `./video` directory. 166 | 167 | 168 | # Citation 169 | 170 | If you find this work useful in your research, please cite our paper: 171 | ``` 172 | @inproceedings{singh2021combining, 173 | title={Combining Semantic Guidance and Deep Reinforcement Learning For Generating Human Level Paintings}, 174 | author={Singh, Jaskirat and Zheng, Liang}, 175 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 176 | pages={16387--16396}, 177 | year={2021} 178 | } 179 | ``` 180 | 181 | 184 | -------------------------------------------------------------------------------- /assets/bird_4008.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/bird_4008.gif -------------------------------------------------------------------------------- /assets/bird_4648.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/bird_4648.gif -------------------------------------------------------------------------------- /assets/bird_5602.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/bird_5602.gif -------------------------------------------------------------------------------- /assets/sg_bird_4008.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/sg_bird_4008.gif -------------------------------------------------------------------------------- /assets/sg_bird_4648.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/sg_bird_4648.gif -------------------------------------------------------------------------------- /assets/sg_bird_5602.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/sg_bird_5602.gif -------------------------------------------------------------------------------- /assets/target_bird_4008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/target_bird_4008.png -------------------------------------------------------------------------------- /assets/target_bird_4648.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/target_bird_4648.png -------------------------------------------------------------------------------- /assets/target_bird_5602.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/assets/target_bird_5602.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - bzip2=1.0.8=h7b6447c_0 7 | - ca-certificates=2020.7.22=0 8 | - certifi=2020.6.20=py36_0 9 | - freetype=2.10.2=h5ab3b9f_0 10 | - gdown=3.13.0=pyhd8ed1ab_0 11 | - gmp=6.1.2=h6c8ec71_1 12 | - gnutls=3.6.5=h71b1129_1002 13 | - htop=2.2.0=hf8c457e_1000 14 | - lame=3.100=h7b6447c_0 15 | - ld_impl_linux-64=2.33.1=h53a641e_7 16 | - libedit=3.1.20191231=h14c3975_1 17 | - libffi=3.3=he6710b0_2 18 | - libgcc-ng=9.1.0=hdf63c60_0 19 | - libiconv=1.16=h516909a_0 20 | - libopus=1.3.1=h7b6447c_0 21 | - libpng=1.6.37=hbc83047_0 22 | - libstdcxx-ng=9.1.0=hdf63c60_0 23 | - libvpx=1.7.0=h439df22_0 24 | - nano=2.9.8=hddfc1eb_1001 25 | - ncurses=6.2=he6710b0_1 26 | - nettle=3.4.1=hbb512f6_0 27 | - openh264=2.1.0=hd408876_0 28 | - openssl=1.1.1g=h7b6447c_0 29 | - pip=20.2.2=py36_0 30 | - python=3.6.10=h7579374_2 31 | - python_abi=3.6=1_cp36m 32 | - readline=8.0=h7b6447c_0 33 | - setuptools=49.6.0=py36_0 34 | - sqlite=3.32.3=h62c20be_0 35 | - tk=8.6.10=hbc83047_0 36 | - unzip=6.0=h611a1e1_0 37 | - wheel=0.34.2=py36_0 38 | - x264=1!157.20191217=h7b6447c_0 39 | - xz=5.2.5=h7b6447c_0 40 | - zlib=1.2.11=h7b6447c_3 41 | - pip: 42 | - absl-py==0.10.0 43 | - argon2-cffi==20.1.0 44 | - attrs==20.1.0 45 | - backcall==0.2.0 46 | - bleach==3.1.5 47 | - cachetools==4.1.1 48 | - cffi==1.14.2 49 | - chardet==3.0.4 50 | - cycler==0.10.0 51 | - cython==0.29.21 52 | - decorator==4.4.2 53 | - defusedxml==0.6.0 54 | - efficientnet-pytorch==0.7.0 55 | - entrypoints==0.3 56 | - ffmpeg==1.4 57 | - future==0.18.2 58 | - google-auth==1.21.1 59 | - google-auth-oauthlib==0.4.1 60 | - grpcio==1.31.0 61 | - h5py==2.10.0 62 | - idna==2.10 63 | - importlib-metadata==1.7.0 64 | - ipykernel==5.3.4 65 | - ipython==7.16.1 66 | - ipython-genutils==0.2.0 67 | - ipywidgets==7.5.1 68 | - jedi==0.17.2 69 | - jinja2==2.11.2 70 | - joblib==0.16.0 71 | - jsonschema==3.2.0 72 | - jupyter==1.0.0 73 | - jupyter-client==6.1.6 74 | - jupyter-console==6.1.0 75 | - jupyter-core==4.6.3 76 | - kiwisolver==1.2.0 77 | - markdown==3.2.2 78 | - markupsafe==1.1.1 79 | - matplotlib==3.3.1 80 | - mistune==0.8.4 81 | - nbconvert==5.6.1 82 | - nbformat==5.0.7 83 | - notebook==6.1.3 84 | - numpy==1.19.1 85 | - oauthlib==3.1.0 86 | 
- opencv-python==4.4.0.42 87 | - packaging==20.4 88 | - pandas==1.1.1 89 | - pandocfilters==1.4.2 90 | - parso==0.7.1 91 | - pexpect==4.8.0 92 | - pickleshare==0.7.5 93 | - pillow==7.2.0 94 | - prometheus-client==0.8.0 95 | - prompt-toolkit==3.0.6 96 | - protobuf==3.13.0 97 | - ptyprocess==0.6.0 98 | - pyasn1==0.4.8 99 | - pyasn1-modules==0.2.8 100 | - pycparser==2.20 101 | - pygments==2.6.1 102 | - pyparsing==2.4.7 103 | - pyrsistent==0.16.0 104 | - python-dateutil==2.8.1 105 | - pytz==2020.1 106 | - pyzmq==19.0.2 107 | - qtconsole==4.7.6 108 | - qtpy==1.9.0 109 | - requests==2.24.0 110 | - requests-oauthlib==1.3.0 111 | - rsa==4.6 112 | - scikit-learn==0.23.2 113 | - scipy==1.2.0 114 | - seaborn==0.10.1 115 | - send2trash==1.5.0 116 | - six==1.15.0 117 | - tensorboard==2.3.0 118 | - tensorboard-plugin-wit==1.7.0 119 | - tensorboardx==2.1 120 | - termcolor==1.1.0 121 | - terminado==0.8.3 122 | - testpath==0.4.4 123 | - threadpoolctl==2.1.0 124 | - torch==1.6.0 125 | - torchgeometry==0.1.2 126 | - torchvision==0.7.0 127 | - tornado==6.0.4 128 | - traitlets==4.3.3 129 | - urllib3==1.25.10 130 | - wcwidth==0.2.5 131 | - webencodings==0.5.1 132 | - werkzeug==1.0.1 133 | - widgetsnbextension==3.5.1 134 | - zipp==3.1.0 135 | 136 | -------------------------------------------------------------------------------- /input/target_bird_4008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/input/target_bird_4008.png -------------------------------------------------------------------------------- /input/target_bird_4648.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/input/target_bird_4648.png -------------------------------------------------------------------------------- /input/target_bird_5602.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/input/target_bird_5602.png -------------------------------------------------------------------------------- /output/baseline_painted_target_bird_4008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/output/baseline_painted_target_bird_4008.png -------------------------------------------------------------------------------- /output/baseline_painted_target_bird_4648.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/output/baseline_painted_target_bird_4648.png -------------------------------------------------------------------------------- /output/sg_painted_target_bird_4008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/output/sg_painted_target_bird_4008.png -------------------------------------------------------------------------------- /output/sg_painted_target_bird_4648.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/output/sg_painted_target_bird_4648.png -------------------------------------------------------------------------------- /semantic_guidance/DRL/actor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.utils.weight_norm as weightNorm 7 | 8 | from torch.autograd import Variable 9 | import sys 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return (nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)) 14 | 15 | 16 | def cfg(depth): 17 | depth_lst = [18, 34, 50, 101, 152] 18 | assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152" 19 | cf_dict = { 20 | '18': (BasicBlock, [2, 2, 2, 2]), 21 | '34': (BasicBlock, [3, 4, 6, 3]), 22 | '50': (Bottleneck, [3, 4, 6, 3]), 23 | '101': (Bottleneck, [3, 4, 23, 3]), 24 | '152': (Bottleneck, [3, 8, 36, 3]), 25 | } 26 | 27 | return cf_dict[str(depth)] 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, in_planes, planes, stride=1): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(in_planes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | 40 | self.shortcut = nn.Sequential() 41 | if stride != 1 or in_planes != self.expansion * planes: 42 | self.shortcut = nn.Sequential( 43 | (nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)), 44 | nn.BatchNorm2d(self.expansion * planes) 45 | ) 46 | 47 | def forward(self, x): 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = self.bn2(self.conv2(out)) 50 | out += self.shortcut(x) 51 | out = F.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, in_planes, planes, stride=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = (nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)) 62 | self.conv2 = (nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)) 63 | self.conv3 = (nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 67 | 68 | self.shortcut = nn.Sequential() 69 | if stride != 1 or in_planes != self.expansion * planes: 70 | self.shortcut = nn.Sequential( 71 | (nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)), 72 | ) 73 | 74 | def forward(self, x): 75 | out = F.relu(self.bn1(self.conv1(x))) 76 | out = F.relu(self.bn2(self.conv2(out))) 77 | out = self.bn3(self.conv3(out)) 78 | out += self.shortcut(x) 79 | out = F.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, num_inputs, depth, num_outputs, high_res=False): 86 | super(ResNet, self).__init__() 87 | self.in_planes = 64 88 | 89 | block, num_blocks = cfg(depth) 90 | 91 | self.conv1 = conv3x3(num_inputs, 64, 2) 92 | self.bn1 = nn.BatchNorm2d(64) 93 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2) 94 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 95 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 96 | self.layer4 = self._make_layer(block, 512, 
num_blocks[3], stride=2) 97 | self.fc = nn.Linear(512, num_outputs) 98 | self.high_res = high_res 99 | 100 | def _make_layer(self, block, planes, num_blocks, stride): 101 | strides = [stride] + [1] * (num_blocks - 1) 102 | layers = [] 103 | 104 | for stride in strides: 105 | layers.append(block(self.in_planes, planes, stride)) 106 | self.in_planes = planes * block.expansion 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | x = F.relu(self.bn1(self.conv1(x))) 112 | x = self.layer1(x) 113 | x = self.layer2(x) 114 | x = self.layer3(x) 115 | x = self.layer4(x) 116 | if self.high_res is True: 117 | x = F.avg_pool2d(x, 8) 118 | else: 119 | x = F.avg_pool2d(x, 4) 120 | x = x.view(x.size(0), -1) 121 | x = self.fc(x) 122 | x = torch.sigmoid(x) 123 | return x 124 | -------------------------------------------------------------------------------- /semantic_guidance/DRL/critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.utils.weight_norm as weightNorm 5 | 6 | from torch.autograd import Variable 7 | import sys 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return weightNorm(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)) 12 | 13 | 14 | class TReLU(nn.Module): 15 | def __init__(self): 16 | super(TReLU, self).__init__() 17 | self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True) 18 | self.alpha.data.fill_(0) 19 | 20 | def forward(self, x): 21 | x = F.relu(x - self.alpha) + self.alpha 22 | return x 23 | 24 | 25 | def cfg(depth): 26 | depth_lst = [18, 34, 50, 101, 152] 27 | assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152" 28 | cf_dict = { 29 | '18': (BasicBlock, [2, 2, 2, 2]), 30 | '34': (BasicBlock, [3, 4, 6, 3]), 31 | '50': (Bottleneck, [3, 4, 6, 3]), 32 | '101': (Bottleneck, [3, 4, 23, 3]), 33 | '152': (Bottleneck, [3, 8, 36, 3]), 34 | } 35 | 36 | return cf_dict[str(depth)] 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(BasicBlock, self).__init__() 44 | self.conv1 = conv3x3(in_planes, planes, stride) 45 | self.conv2 = conv3x3(planes, planes) 46 | 47 | self.shortcut = nn.Sequential() 48 | if stride != 1 or in_planes != self.expansion * planes: 49 | self.shortcut = nn.Sequential( 50 | weightNorm(nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=True)), 51 | ) 52 | self.relu_1 = TReLU() 53 | self.relu_2 = TReLU() 54 | 55 | def forward(self, x): 56 | out = self.relu_1(self.conv1(x)) 57 | out = self.conv2(out) 58 | out += self.shortcut(x) 59 | out = self.relu_2(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, in_planes, planes, stride=1): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = weightNorm(nn.Conv2d(in_planes, planes, kernel_size=1, bias=True)) 70 | self.conv2 = weightNorm(nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)) 71 | self.conv3 = weightNorm(nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=True)) 72 | self.relu_1 = TReLU() 73 | self.relu_2 = TReLU() 74 | self.relu_3 = TReLU() 75 | 76 | self.shortcut = nn.Sequential() 77 | if stride != 1 or in_planes != self.expansion * planes: 78 | self.shortcut = nn.Sequential( 79 | weightNorm(nn.Conv2d(in_planes, self.expansion * planes, 
kernel_size=1, stride=stride, bias=True)), 80 | ) 81 | 82 | def forward(self, x): 83 | out = self.relu_1(self.conv1(x)) 84 | out = self.relu_2(self.conv2(out)) 85 | out = self.conv3(out) 86 | out += self.shortcut(x) 87 | out = self.relu_3(out) 88 | 89 | return out 90 | 91 | 92 | class ResNet_wobn(nn.Module): 93 | def __init__(self, num_inputs, depth, num_outputs, high_res=False): 94 | super(ResNet_wobn, self).__init__() 95 | self.in_planes = 64 96 | 97 | block, num_blocks = cfg(depth) 98 | 99 | self.conv1 = conv3x3(num_inputs, 64, 2) 100 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2) 101 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 102 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 103 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 104 | self.fc = nn.Linear(512, num_outputs) 105 | self.relu_1 = TReLU() 106 | self.high_res = high_res 107 | 108 | def _make_layer(self, block, planes, num_blocks, stride): 109 | strides = [stride] + [1] * (num_blocks - 1) 110 | layers = [] 111 | 112 | for stride in strides: 113 | layers.append(block(self.in_planes, planes, stride)) 114 | self.in_planes = planes * block.expansion 115 | 116 | return nn.Sequential(*layers) 117 | 118 | def forward(self, x): 119 | x = self.relu_1(self.conv1(x)) 120 | x = self.layer1(x) 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | if self.high_res is True: 125 | x = F.avg_pool2d(x, 8) 126 | else: 127 | x = F.avg_pool2d(x, 4) 128 | x = x.view(x.size(0), -1) 129 | x = self.fc(x) 130 | return x 131 | -------------------------------------------------------------------------------- /semantic_guidance/DRL/ddpg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.optim import Adam, SGD 6 | from torch.distributions import Categorical 7 | import Renderer.model as renderer 8 | from DRL.rpm import rpm 9 | from DRL.actor import ResNet 10 | from DRL.critic import ResNet_wobn 11 | from DRL.wgan import WGAN 12 | import utils.util as util 13 | import torch.distributed as dist 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | from DRL.multi import fastenv 16 | 17 | # setup device 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | criterion = nn.MSELoss() 20 | 21 | 22 | def cal_trans(s, t): 23 | return (s.transpose(0, 3) * t).transpose(0, 3) 24 | 25 | 26 | class DDPG(object): 27 | """ 28 | Model-based DDPG agent class 29 | """ 30 | 31 | def __init__(self, batch_size=64, nenv=1, max_eps_len=40, tau=0.001, discount=0.9, rmsize=800, 32 | writer=None, load_path=None, output_path=None, dataset='celeba', use_gbp=False, use_bilevel=False, 33 | gbp_coef=1.0, seggt_coef=1.0, gpu=0, distributed=False, high_res=False, 34 | bundle_size=5): 35 | # hyperparameters 36 | self.max_eps_len = max_eps_len 37 | self.nenv = nenv 38 | self.batch_size = batch_size 39 | self.gpu = gpu 40 | self.distributed = distributed 41 | self.bundle_size = bundle_size 42 | 43 | # set torch device 44 | torch.cuda.set_device(gpu) 45 | 46 | # gbp and seggt rewards 47 | self.use_gbp = use_gbp 48 | self.use_bilevel = use_bilevel 49 | self.gbp_coef = gbp_coef 50 | self.seggt_coef = seggt_coef 51 | 52 | # Multi-res and high-res 53 | self.high_res = high_res 54 | 55 | # environment 56 | self.env = fastenv(max_eps_len, nenv, writer, dataset, use_gbp, use_bilevel, gpu=gpu, 
high_res=self.high_res, 57 | bundle_size=bundle_size) 58 | 59 | # setup local and target actor, critic networks 60 | # input: target, canvas, stepnum, coordconv + gbp 3 + 3 + 1 + 2 61 | # output: (10+3)*5 (action bundle) 62 | self.actor = ResNet(9 + use_gbp + use_bilevel + 2*use_bilevel, 18, 13 * bundle_size * (1 + use_bilevel), self.high_res) 63 | self.actor_target = ResNet(9 + use_gbp + use_bilevel + 2*use_bilevel, 18, 13 * bundle_size * (1 + use_bilevel), 64 | self.high_res) 65 | self.critic = ResNet_wobn(3 + 9 + use_gbp + use_bilevel + 2*use_bilevel, 18, 1, 66 | self.high_res) # add the last canvas for better prediction 67 | self.critic_target = ResNet_wobn(3 + 9 + use_gbp + use_bilevel + 2*use_bilevel, 18, 1, self.high_res) 68 | 69 | for param in self.actor_target.parameters(): 70 | param.requires_grad = False 71 | 72 | for param in self.critic_target.parameters(): 73 | param.requires_grad = False 74 | 75 | # define gan 76 | self.wgan = WGAN(gpu, distributed, high_res=self.high_res) 77 | self.wgan_bg = WGAN(gpu, distributed, high_res=self.high_res) 78 | 79 | # transfer models to gpu 80 | self.choose_device() 81 | 82 | # optimizers for actor/critic models 83 | self.actor_optim = Adam(self.actor.parameters(), lr=1e-2) 84 | self.critic_optim = Adam(self.critic.parameters(), lr=1e-2) 85 | 86 | # load actor/critic/gan models if given a load path 87 | if (load_path != None): 88 | self.load_weights(load_path) 89 | 90 | # set same initial weights for local and target networks 91 | util.hard_update(self.actor_target, self.actor) 92 | util.hard_update(self.critic_target, self.critic) 93 | 94 | # Create replay buffer (to store max rmsize hyperparameters) 95 | self.memory = rpm(rmsize * max_eps_len, gpu) 96 | 97 | # training hyper-parameters 98 | self.tau = tau 99 | self.discount = discount 100 | 101 | # initialize summary logs 102 | self.writer = writer 103 | self.log = 0 104 | 105 | # initialize state, action logs 106 | self.state = [None] * self.nenv # Most recent state 107 | self.action = [None] * self.nenv # Most recent action 108 | 109 | self.get_coord_feats() 110 | 111 | def get_coord_feats(self): 112 | # x,y coordinates for 128 x 128 or 256x256 image 113 | if self.high_res is True: 114 | coord = torch.zeros([1, 2, 256, 256]) 115 | for i in range(256): 116 | for j in range(256): 117 | coord[0, 0, i, j] = i / 255. 118 | coord[0, 1, i, j] = j / 255. 119 | else: 120 | coord = torch.zeros([1, 2, 128, 128]) 121 | for i in range(128): 122 | for j in range(128): 123 | coord[0, 0, i, j] = i / 127. 124 | coord[0, 1, i, j] = j / 127. 125 | 126 | self.coord = coord.cuda(self.gpu) 127 | 128 | def act(self, state, target=False): 129 | """ 130 | take action for current state 131 | :param state: merged state (canvas,gt,t/max_eps_len,coordinates) 132 | :param target: bool for whether the target actor model should be used. 133 | :return: action (stroke parameters) 134 | """ 135 | if self.high_res is True: 136 | state_list = [state[:, :6].float() / 255, state[:, 6:7].float() / self.max_eps_len, 137 | self.coord.expand(state.shape[0], 2, 256, 256)] 138 | else: 139 | state_list = [state[:, :6].float() / 255, state[:, 6:7].float() / self.max_eps_len, 140 | self.coord.expand(state.shape[0], 2, 128, 128)] 141 | 142 | if self.use_gbp: 143 | state_list += [state[:, 7:8].float() / 255.] 144 | if self.use_bilevel: 145 | state_list += [state[:, 7 + self.use_gbp: 7 + self.use_gbp + 1 + 2].float() / 255.] 146 | # state_list += [state[:, 7 + self.use_gbp + 1: 7 + self.use_gbp + 1].float() / 255.] 
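        # Channel layout of the raw observation tensor `state` (see also evaluate() below):
        #   [0:3] canvas (RGB), [3:6] target image (RGB), [6] current step index,
        #   [7] guided backpropagation map (only when use_gbp is set), followed by the
        #   predicted segmentation mask (1 ch) and bounding-box sampling grid (2 ch)
        #   when use_bilevel is set; two CoordConv x/y channels from self.coord are
        #   also added to state_list above before the actor forward pass.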
147 | # define merged state (canvas,gt,t/max_eps_len,coordinates) 148 | state = torch.cat(state_list, 1) 149 | 150 | if target: 151 | return self.actor_target(state) 152 | else: 153 | return self.actor(state) 154 | 155 | def update_gan(self, state): 156 | """ 157 | update WGAN based on current state (first 3 channels for canvas and last three for groundtruth) 158 | :param state: 159 | :return: None 160 | """ 161 | # get canvas and groundtruth from the state 162 | canvas = state[:, :3].float() / 255 163 | gt = state[:, 3: 6].float() / 255 164 | 165 | # update gan based on canvas and background groundtruth images 166 | fake, real, penal = self.wgan_bg.update(canvas, gt) 167 | 168 | if self.use_bilevel: 169 | seggt = state[:, 7 + self.use_gbp: 7 + self.use_gbp + 1].float() / 255. 170 | grid = state[:, 7 + self.use_gbp + 1: 7 + self.use_gbp + 1 + 2].float() / 255. 171 | grid = 2 * grid - 1 172 | canvas = torch.nn.functional.grid_sample(canvas * seggt, grid.permute(0, 2, 3, 1)) 173 | gt = torch.nn.functional.grid_sample(gt * seggt, grid.permute(0, 2, 3, 1)) 174 | 175 | # update gan based on canvas and groundtruth images 176 | fake, real, penal = self.wgan.update(canvas, gt) 177 | 178 | 179 | def evaluate(self, state, action, target=False): 180 | """ 181 | compute model performance (rewards) for given state, action (used for both training and testing 182 | based on whether target or local network is used) 183 | :param state: combined state (canvas,ground-truth) 184 | :param action: stroke parameters (10 for position and 3 for color) 185 | :param target: bool for whether the target critic model should be used. 186 | :return: critic value + gan reward (used for training actor in model-based DDPG), gan reward 187 | """ 188 | 189 | # get canvas, ground-truth, time from merged state (gt,canvas,t) 190 | gt = state[:, 3: 6].float() / 255 191 | canvas0 = state[:, :3].float() / 255 192 | T = state[:, 6: 7] 193 | 194 | if self.use_gbp: 195 | gbpgt = state[:, 7:8].float() / 255. 196 | if self.use_bilevel: 197 | seggt = state[:, 7 + self.use_gbp: 7 + self.use_gbp + 1].float() / 255. 198 | grid = state[:, 7 + self.use_gbp + 1: 7 + self.use_gbp + 1 + 2].float() / 255. 199 | grid_ = 2 * grid - 1 200 | 201 | # update canvas given current action 202 | if self.use_bilevel: 203 | canvas1, stroke_masks = self.env.env.decode_parallel(action, canvas0, 204 | mask=self.use_gbp or self.use_bilevel, 205 | seg_mask=seggt) 206 | else: 207 | canvas1, stroke_masks = self.env.env.decode(action, canvas0, mask=self.use_gbp or self.use_bilevel) 208 | 209 | # if self.use_bilevel: 210 | # # canvas1 = canvas1 * seggt 211 | # # canvas0 = canvas0 * seggt 212 | 213 | # compute bg gan reward based on difference between wgan distances (L_t - L_t-1) 214 | bg_reward = self.wgan_bg.cal_reward(canvas1, gt) - self.wgan_bg.cal_reward(canvas0, gt) 215 | bg_reward = bg_reward.view(-1) 216 | # gan_reward = ((canvas0 - gt) ** 2).mean(1).mean(1).mean(1) - ((canvas1 - gt) ** 2).mean(1).mean(1).mean(1) 217 | 218 | if self.use_gbp: 219 | gbp_reward = (((canvas0 - gt) * gbpgt) ** 2).mean(1).sum(1).sum(1) \ 220 | - (((canvas1 - gt) * gbpgt) ** 2).mean(1).sum(1).sum(1) 221 | 222 | gbp_reward = gbp_reward / torch.sum(gbpgt, dim=(1, 2, 3)) 223 | gbp_reward = 1e3 * gbp_reward 224 | else: 225 | gbp_reward = torch.tensor(0.) 
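        # Foreground (bi-level) reward below: the predicted segmentation mask isolates the
        # object, grid_sample with the normalized bounding-box grid zooms into the localized
        # region, and the foreground WGAN critic scores the improvement from canvas0 to
        # canvas1 (weighted by a factor of 2e0).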
226 | 227 | if self.use_bilevel: 228 | canvas1_ = torch.nn.functional.grid_sample(canvas1 * seggt, grid_.permute(0, 2, 3, 1)) 229 | canvas0_ = torch.nn.functional.grid_sample(canvas0 * seggt, grid_.permute(0, 2, 3, 1)) 230 | # gt_, canvas0_, canvas1_ = self.nalignment(gt,canvas0_,canvas1_) 231 | gt_ = torch.nn.functional.grid_sample(gt * seggt, grid_.permute(0, 2, 3, 1)) 232 | # compute foreground reward 233 | foreground_reward = self.wgan.cal_reward(canvas1_, gt_) - self.wgan.cal_reward(canvas0_, gt_) 234 | # foreground_reward = ((canvas0_ - gt_) ** 2).mean(1).mean(1).mean(1) - ((canvas1_ - gt_) ** 2).mean(1).mean(1).mean(1) 235 | foreground_reward = 2e0 * foreground_reward.view(-1) 236 | else: 237 | foreground_reward = torch.tensor(0.) 238 | 239 | # total reward 240 | total_reward = bg_reward + gbp_reward + foreground_reward 241 | 242 | # get new merged state 243 | if self.high_res is True: 244 | coord_ = self.coord.expand(state.shape[0], 2, 256, 256) 245 | else: 246 | coord_ = self.coord.expand(state.shape[0], 2, 128, 128) 247 | state_list = [canvas0, canvas1, gt, (T + 1).float() / self.max_eps_len, coord_] 248 | 249 | if self.use_gbp: 250 | state_list += [state[:, 7:8].float() / 255.] 251 | if self.use_bilevel: 252 | state_list += [seggt] 253 | state_list += [grid] 254 | 255 | # compute merged state 256 | merged_state = torch.cat(state_list, 1) 257 | # canvas0 is not necessarily added 258 | 259 | if target: 260 | # compute Q from target network 261 | Q = self.critic_target(merged_state) 262 | else: 263 | # compute Q from local network 264 | Q = self.critic(merged_state) 265 | return (Q + total_reward), total_reward 266 | 267 | def update_policy(self, lr): 268 | """ 269 | update actor, critic using current replay memory buffer and given learning rate 270 | :param lr: learning rate 271 | :return: negative policy loss (current expected reward), value loss 272 | """ 273 | self.log += 1 274 | 275 | # set different learning rate for actor and critic 276 | for param_group in self.critic_optim.param_groups: 277 | param_group['lr'] = lr[0] 278 | for param_group in self.actor_optim.param_groups: 279 | param_group['lr'] = lr[1] 280 | 281 | # sample a batch from the replay buffer 282 | state, action, reward, next_state, terminal = self.memory.sample_batch(self.batch_size, device) 283 | 284 | # update gan model 285 | self.update_gan(next_state) 286 | 287 | # Q-learning: Q(s,a) = r(s,a) + gamma * Q(s',a') 288 | with torch.no_grad(): 289 | next_action = self.act(next_state, True) 290 | target_q, _ = self.evaluate(next_state, next_action, True) 291 | target_q = self.discount * ((1 - terminal.float()).view(-1, 1)) * target_q 292 | 293 | # add r(s,a) to Q(s,a) 294 | cur_q, step_reward = self.evaluate(state, action) 295 | target_q += step_reward.detach() 296 | 297 | # critic loss and update 298 | value_loss = criterion(cur_q, target_q) 299 | self.critic.zero_grad() 300 | value_loss.backward(retain_graph=True) 301 | self.critic_optim.step() 302 | 303 | # actor loss and update 304 | action = self.act(state) 305 | pre_q, _ = self.evaluate(state.detach(), action) 306 | policy_loss = -pre_q.mean() 307 | self.actor.zero_grad() 308 | policy_loss.backward(retain_graph=True) 309 | self.actor_optim.step() 310 | 311 | # Soft-update target networks for both actor and critic 312 | util.soft_update(self.actor_target, self.actor, self.tau) 313 | util.soft_update(self.critic_target, self.critic, self.tau) 314 | 315 | return -policy_loss, value_loss 316 | 317 | def observe(self, reward, state, done, step): 318 | """ 319 
| Store observed sample in replay buffer 320 | :param reward: 321 | :param state: 322 | :param done: 323 | :param step: step count within an episode 324 | :return: None 325 | """ 326 | s0 = self.state.clone().detach().cpu() # torch.tensor(self.state, device='cpu') 327 | a = util.to_tensor(self.action, "cpu") 328 | r = util.to_tensor(reward, "cpu") 329 | s1 = state.clone().detach().cpu() # torch.tensor(state, device='cpu') 330 | d = util.to_tensor(done.astype('float32'), "cpu") 331 | for i in range(self.nenv): 332 | self.memory.append([s0[i], a[i], r[i], s1[i], d[i]]) 333 | self.state = state 334 | 335 | def noise_action(self, noise_factor, state, action): 336 | """ 337 | Add gaussian noise to continuous actions (stroke params) with zero mean and self.noise_level[i] variance 338 | :param noise_factor: 339 | :param state: 340 | :param action: 341 | :return: action (stroke params) clipped between 0,1 342 | """ 343 | noise = np.zeros(action.shape) 344 | for i in range(self.nenv): 345 | action[i] = action[i] + np.random.normal(0, self.noise_level[i], action.shape[1:]).astype('float32') 346 | return np.clip(action.astype('float32'), 0, 1) 347 | 348 | def select_action(self, state, return_fix=False, noise_factor=0): 349 | """ 350 | compute action given a state and noise_factor 351 | :param state: 352 | :param return_fix: 353 | :param noise_factor: 354 | :return: 355 | """ 356 | self.eval() 357 | # compute action 358 | with torch.no_grad(): 359 | action = self.act(state) 360 | action = util.to_numpy(action) 361 | # add noise to action 362 | if noise_factor > 0: 363 | action = self.noise_action(noise_factor, state, action) 364 | self.train() 365 | 366 | self.action = action 367 | if return_fix: 368 | return action 369 | return self.action 370 | 371 | def reset(self, obs, factor): 372 | self.state = obs 373 | self.noise_level = np.random.uniform(0, factor, self.nenv) 374 | 375 | def decode(self, x, canvas, mask=False): # b * (10 + 3) 376 | """ 377 | Update canvas given stroke parameters x 378 | :param x: stroke parameters 379 | :param canvas: current canvas state 380 | :return: updated canvas with stroke drawn 381 | """ 382 | # 13 stroke parameters (10 position and 3 RGB color) 383 | x = x.view(-1, 10 + 3) 384 | 385 | # get stroke on an empty canvas given 10 positional parameters 386 | stroke = 1 - self.decoder(x[:, :10]) 387 | stroke = stroke.view(-1, 128, 128, 1) 388 | 389 | # add color to the stroke 390 | color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3) 391 | stroke = stroke.permute(0, 3, 1, 2) 392 | color_stroke = color_stroke.permute(0, 3, 1, 2) 393 | 394 | # draw bundle_size=5 strokes at a time (action bundle) 395 | stroke = stroke.view(-1, self.bundle_size, 1, 128, 128) 396 | color_stroke = color_stroke.view(-1, self.bundle_size, 3, 128, 128) 397 | for i in range(self.bundle_size): 398 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i] 399 | 400 | # also return stroke mask if required 401 | stroke_mask = None 402 | if mask: 403 | stroke_mask = (stroke != 0).float() # -1, bundle_size, 1, width, width 404 | 405 | return canvas, stroke_mask 406 | 407 | def load_weights(self, path, map_location=None, num_episodes=0): 408 | """ 409 | load actor,critic,gan from given paths 410 | :param path: 411 | :return: 412 | """ 413 | if map_location is None: 414 | map_location = {'cuda:0': 'cuda:{}'.format(self.gpu)} 415 | 416 | if path is None: return 417 | 418 | self.actor.load_state_dict( 419 | torch.load('{}/actor_{:05}.pkl'.format(path, num_episodes), map_location=map_location)) 420 | 
self.critic.load_state_dict( 421 | torch.load('{}/critic_{:05}.pkl'.format(path, num_episodes), map_location=map_location)) 422 | self.wgan.load_gan(path, map_location, num_episodes) 423 | self.wgan_bg.load_gan(path, map_location, num_episodes + 1) 424 | 425 | def save_model(self, path, num_episodes): 426 | """ 427 | save trained actor,critic,gan models 428 | :param path: save parent dir 429 | :return: None 430 | """ 431 | if self.gpu == 0: 432 | torch.save(self.actor.state_dict(), "{}/actor_{:05}.pkl".format(path, num_episodes)) 433 | torch.save(self.critic.state_dict(), '{}/critic_{:05}.pkl'.format(path, num_episodes)) 434 | self.wgan.save_gan(path, num_episodes) 435 | self.wgan_bg.save_gan(path, num_episodes + 1) 436 | 437 | # Use a barrier() to make sure that process 1 loads the model after process 438 | # 0 saves it. 439 | dist.barrier() 440 | # configure map_location properly 441 | device_pairs = [(0, self.gpu)] 442 | map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs} 443 | self.load_weights(path, map_location, num_episodes) 444 | dist.barrier() 445 | print("done saving on cuda:{}".format(self.gpu)) 446 | 447 | def eval(self): 448 | """ 449 | set actor, critic in eval mode 450 | :return: None 451 | """ 452 | self.actor.eval() 453 | self.actor_target.eval() 454 | self.critic.eval() 455 | self.critic_target.eval() 456 | 457 | def train(self): 458 | """ 459 | set actor, critic in train mode 460 | :return: None 461 | """ 462 | self.actor.train() 463 | self.actor_target.train() 464 | self.critic.train() 465 | self.critic_target.train() 466 | 467 | def choose_device(self): 468 | """ 469 | transfer renderer, actor, critic to device 470 | :return: None 471 | """ 472 | self.actor.cuda(self.gpu) 473 | self.actor_target.cuda(self.gpu) 474 | self.critic.cuda(self.gpu) 475 | self.critic_target.cuda(self.gpu) 476 | 477 | if self.distributed: 478 | self.actor = DDP(self.actor, device_ids=[self.gpu]) 479 | self.critic = DDP(self.critic, device_ids=[self.gpu]) 480 | -------------------------------------------------------------------------------- /semantic_guidance/DRL/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.util import * 3 | 4 | class Evaluator(object): 5 | 6 | def __init__(self, args, writer): 7 | self.val_num_eps = args.val_num_eps 8 | self.max_eps_len = args.max_eps_len 9 | self.nenv = args.nenv 10 | self.writer = writer 11 | self.log = 0 12 | 13 | def __call__(self, env, policy, debug=False): 14 | observation = None 15 | for episode in range(self.val_num_eps): 16 | # reset at the start of episode 17 | observation = env.reset(test=True, episode=episode) 18 | episode_steps = 0 19 | episode_reward = 0. 
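            # roll out the policy for max_eps_len steps on held-out test images,
            # logging intermediate canvases and accumulating per-environment reward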
20 | assert observation is not None 21 | # start episode 22 | episode_reward = np.zeros(self.nenv) 23 | while (episode_steps < self.max_eps_len or not self.max_eps_len): 24 | action = policy(observation) 25 | observation, reward, done, (step_num) = env.step(action) 26 | episode_reward += reward 27 | episode_steps += 1 28 | env.save_image(self.log, episode_steps) 29 | dist = env.get_dist() 30 | self.log += 1 31 | return episode_reward, dist 32 | -------------------------------------------------------------------------------- /semantic_guidance/DRL/multi.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from env_ins import Paint 5 | import utils.util as util 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | class fastenv: 11 | def __init__(self, max_episode_length=10, nenv=64, 12 | writer=None, dataset='cub200', use_gbp=False, use_bilevel=False, gpu=0, high_res=False, 13 | bundle_size=5): 14 | self.max_episode_length = max_episode_length 15 | self.nenv = nenv 16 | self.env = Paint(self.nenv, self.max_episode_length, dataset=dataset, use_gbp=use_gbp, 17 | use_bilevel=use_bilevel, gpu=gpu, high_res=high_res, bundle_size=bundle_size) 18 | self.env.load_data() 19 | self.observation_space = self.env.observation_space 20 | self.action_space = self.env.action_space 21 | self.writer = writer 22 | self.test = False 23 | self.log = 0 24 | self.gpu = gpu 25 | self.use_gbp = use_gbp 26 | self.use_bilevel = use_bilevel 27 | 28 | def nalignment(self, gt, canvas0): 29 | gt_ = (gt - self.mean) / self.std 30 | predictions = self.loc_model(gt_) 31 | M = torch.matmul(predictions, self.P) 32 | M = M - self.c_ 33 | M = M.view(-1, 2, 3) 34 | grid = torch.nn.functional.affine_grid(M, gt.size()) 35 | z_gt = torch.nn.functional.grid_sample(gt, grid.float()) 36 | z_canvas0 = torch.nn.functional.grid_sample(canvas0, grid.float()) 37 | # z_canvas1 = torch.nn.functional.grid_sample(canvas1, grid.float()) 38 | return z_gt, z_canvas0 39 | 40 | def save_image(self, log, step): 41 | if self.gpu == 0: 42 | for i in range(self.nenv): 43 | if self.env.imgid[i] <= 10: 44 | # write background images 45 | canvas = util.to_numpy(self.env.canvas[i, :3].permute(1, 2, 0)) 46 | self.writer.add_image('{}/canvas_{}.png'.format(str(self.env.imgid[i]), str(step)), canvas, log) 47 | if step == self.max_episode_length: 48 | if self.use_bilevel: 49 | # z_gt, z_canvas = self.nalignment(self.env.gt[:,:3].float() / 255,self.env.canvas[:,:3].float() / 255) 50 | grid = self.env.grid[:, :2].float() / 255 51 | grid = 2 * grid - 1 52 | z_gt = torch.nn.functional.grid_sample(self.env.gt[:, :3].float() / 255, grid.permute(0, 2, 3, 1)) 53 | z_canvas = torch.nn.functional.grid_sample(self.env.canvas[:, :3].float() / 255, 54 | grid.permute(0, 2, 3, 1)) 55 | for i in range(self.nenv): 56 | if self.env.imgid[i] < 50: 57 | # write background images 58 | gt = util.to_numpy(self.env.gt[i, :3].permute(1, 2, 0)) 59 | canvas = util.to_numpy(self.env.canvas[i, :3].permute(1, 2, 0)) 60 | self.writer.add_image(str(self.env.imgid[i]) + '/_target.png', gt, log) 61 | self.writer.add_image(str(self.env.imgid[i]) + '/_canvas.png', canvas, log) 62 | if self.use_bilevel: 63 | # # also write foreground images 64 | gt = util.to_numpy(z_gt[i, :3].permute(1, 2, 0)) 65 | canvas = util.to_numpy(z_canvas[i, :3].permute(1, 2, 0)) 66 | self.writer.add_image(str(self.env.imgid[i]) + '_foreground/_target.png', gt, log) 67 | 
self.writer.add_image(str(self.env.imgid[i]) + '_foreground/_canvas.png', canvas, log) 68 | 69 | def step(self, action): 70 | with torch.no_grad(): 71 | ob, r, d, _ = self.env.step(torch.tensor(action).cuda(self.gpu)) 72 | return ob, r, d, _ 73 | 74 | def get_dist(self): 75 | return util.to_numpy( 76 | (((self.env.gt[:, :3].float() - self.env.canvas[:, :3].float()) / 255) ** 2).mean(1).mean(1).mean(1)) 77 | 78 | def reset(self, test=False, episode=0): 79 | self.test = test 80 | ob = self.env.reset(self.test, episode * self.nenv) 81 | return ob 82 | -------------------------------------------------------------------------------- /semantic_guidance/DRL/rpm.py: -------------------------------------------------------------------------------- 1 | # from collections import deque 2 | import numpy as np 3 | import random 4 | import torch 5 | import pickle as pickle 6 | 7 | class rpm(object): 8 | # replay memory 9 | def __init__(self, buffer_size, gpu=0): 10 | self.buffer_size = buffer_size 11 | self.buffer = [] 12 | self.index = 0 13 | self.gpu = gpu 14 | 15 | def append(self, obj): 16 | if self.size() > self.buffer_size: 17 | print('buffer size larger than set value, trimming...') 18 | self.buffer = self.buffer[(self.size() - self.buffer_size):] 19 | elif self.size() == self.buffer_size: 20 | self.buffer[self.index] = obj 21 | self.index += 1 22 | self.index %= self.buffer_size 23 | else: 24 | self.buffer.append(obj) 25 | 26 | def size(self): 27 | return len(self.buffer) 28 | 29 | def sample_batch(self, batch_size, device, only_state=False): 30 | if self.size() < batch_size: 31 | batch = random.sample(self.buffer, self.size()) 32 | else: 33 | batch = random.sample(self.buffer, batch_size) 34 | 35 | if only_state: 36 | res = torch.stack(tuple(item[3] for item in batch), dim=0) 37 | return res.cuda(self.gpu) 38 | else: 39 | item_count = 5 40 | res = [] 41 | for i in range(5): 42 | k = torch.stack(tuple(item[i] for item in batch), dim=0) 43 | res.append(k.cuda(self.gpu)) 44 | return res[0], res[1], res[2], res[3], res[4] 45 | -------------------------------------------------------------------------------- /semantic_guidance/DRL/wgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.optim import Adam, SGD 5 | from torch import autograd 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | from torch.autograd import grad as torch_grad 9 | import torch.nn.utils.weight_norm as weightNorm 10 | import utils.util as util 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | # dim = 128 17 | # LAMBDA = 10 # Gradient penalty lambda hyperparameter 18 | 19 | 20 | class TReLU(nn.Module): 21 | def __init__(self): 22 | super(TReLU, self).__init__() 23 | self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True) 24 | self.alpha.data.fill_(0) 25 | 26 | def forward(self, x): 27 | x = F.relu(x - self.alpha) + self.alpha 28 | return x 29 | 30 | 31 | class Discriminator(nn.Module): 32 | def __init__(self, high_res=False): 33 | super(Discriminator, self).__init__() 34 | 35 | self.conv0 = weightNorm(nn.Conv2d(6, 16, 5, 2, 2)) 36 | self.conv1 = weightNorm(nn.Conv2d(16, 32, 5, 2, 2)) 37 | self.conv2 = weightNorm(nn.Conv2d(32, 64, 5, 2, 2)) 38 | self.conv3 = weightNorm(nn.Conv2d(64, 128, 5, 2, 2)) 39 | self.conv4 = weightNorm(nn.Conv2d(128, 1, 5, 2, 2)) 40 | self.relu0 = 
TReLU() 41 | self.relu1 = TReLU() 42 | self.relu2 = TReLU() 43 | self.relu3 = TReLU() 44 | self.high_res = high_res 45 | 46 | def forward(self, x): 47 | x = self.conv0(x) 48 | x = self.relu0(x) 49 | x = self.conv1(x) 50 | x = self.relu1(x) 51 | x = self.conv2(x) 52 | x = self.relu2(x) 53 | x = self.conv3(x) 54 | x = self.relu3(x) 55 | x = self.conv4(x) 56 | if self.high_res is True: 57 | x = F.avg_pool2d(x, 8) 58 | else: 59 | x = F.avg_pool2d(x, 4) 60 | x = x.view(-1, 1) 61 | return x 62 | 63 | 64 | class WGAN: 65 | def __init__(self, gpu=0, distributed=False, dim=128, high_res=False): 66 | self.gpu = gpu 67 | self.distributed = distributed 68 | self.high_res = high_res 69 | 70 | if self.high_res is True: 71 | self.dim = 256 72 | else: 73 | self.dim = 128 74 | 75 | self.netD = Discriminator(high_res=self.high_res) 76 | self.target_netD = Discriminator(high_res=self.high_res) 77 | 78 | self.netD = self.netD.cuda(gpu) 79 | self.target_netD = self.target_netD.cuda(gpu) 80 | 81 | for param in self.target_netD.parameters(): 82 | param.requires_grad = False 83 | 84 | if distributed: 85 | self.netD = DDP(self.netD, device_ids=[self.gpu]) 86 | # self.target_netD = DDP(self.target_netD, device_ids=[self.gpu]) 87 | 88 | util.hard_update(self.target_netD, self.netD) 89 | 90 | self.optimizerD = Adam(self.netD.parameters(), lr=3e-4, betas=(0.5, 0.999)) 91 | # self.dim = dim 92 | self.LAMBDA = 10 # Gradient penalty lambda hyperparameter 93 | 94 | def cal_gradient_penalty(self, real_data, fake_data, batch_size): 95 | alpha = torch.rand(batch_size, 1) 96 | alpha = alpha.expand(batch_size, int(real_data.nelement() / batch_size)).contiguous() 97 | alpha = alpha.view(batch_size, 6, self.dim, self.dim) 98 | alpha = alpha.cuda(self.gpu) 99 | fake_data = fake_data.view(batch_size, 6, self.dim, self.dim) 100 | interpolates = Variable(alpha * real_data.data + ((1 - alpha) * fake_data.data), requires_grad=True) 101 | disc_interpolates = self.netD(interpolates) 102 | gradients = autograd.grad(disc_interpolates, interpolates, 103 | grad_outputs=torch.ones(disc_interpolates.size()).cuda(self.gpu), 104 | create_graph=True, retain_graph=True)[0] 105 | gradients = gradients.view(gradients.size(0), -1) 106 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA 107 | return gradient_penalty 108 | 109 | def cal_reward(self, fake_data, real_data): 110 | return self.target_netD(torch.cat([real_data, fake_data], 1)) 111 | 112 | def save_gan(self, path, num_episodes): 113 | # self.netD.cpu() 114 | torch.save(self.netD.state_dict(), '{}/wgan_{:05}.pkl'.format(path, num_episodes)) 115 | # self.netD.cuda(self.gpu) 116 | 117 | def load_gan(self, path, map_location, num_episodes): 118 | self.netD.load_state_dict(torch.load('{}/wgan_{:05}.pkl'.format(path, num_episodes), map_location=map_location)) 119 | 120 | def random_masks(self): 121 | """ 122 | Generate random masks for complement discriminator (need to make mask overlap with the stroke) 123 | :return: mask 124 | """ 125 | # initialize mask 126 | mask = np.ones((3, self.dim, self.dim)) 127 | 128 | # generate one of 4 random masks 129 | choose = 1 # np.random.randint(0, 1) 130 | if choose == 0: 131 | mask[:, :self.dim // 2] = 0 132 | elif choose == 1: 133 | mask[:, :, :self.dim // 2] = 0 134 | elif choose == 2: 135 | mask[:, :, self.dim // 2:] = 0 136 | elif choose == 3: 137 | mask[:, self.dim // 2:] = 0 138 | 139 | return mask 140 | 141 | def update(self, fake_data, real_data): 142 | fake_data = fake_data.detach() 143 | real_data = real_data.detach() 
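        # Conditional WGAN-GP update: the critic scores 6-channel pairs,
        # fake = concat(target, canvas) vs. real = concat(target, target); the loss is the
        # Wasserstein estimate plus a gradient penalty, and the soft-updated target critic
        # (used in cal_reward) provides the reward signal for the painting agent.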
144 | 145 | # standard conditional training for discriminator 146 | fake = torch.cat([real_data, fake_data], 1) 147 | real = torch.cat([real_data, real_data], 1) 148 | 149 | # # complement discriminator conditional training for discriminator 150 | # mask = torch.tensor(random_masks()).float().to(device) 151 | # fake = torch.cat([(1 - mask) * real_data, mask * fake_data], 1) 152 | # real = torch.cat([(1 - mask) * real_data, mask * real_data], 1) 153 | 154 | # compute discriminator scores for real and fake data 155 | D_real = self.netD(real) 156 | D_fake = self.netD(fake) 157 | 158 | gradient_penalty = self.cal_gradient_penalty(real, fake, real.shape[0]) 159 | self.optimizerD.zero_grad() 160 | D_cost = D_fake.mean() - D_real.mean() + gradient_penalty 161 | D_cost.backward() 162 | self.optimizerD.step() 163 | util.soft_update(self.target_netD, self.netD, 0.001) 164 | return D_fake.mean(), D_real.mean(), gradient_penalty 165 | -------------------------------------------------------------------------------- /semantic_guidance/Renderer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/semantic_guidance/Renderer/__init__.py -------------------------------------------------------------------------------- /semantic_guidance/Renderer/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.utils.weight_norm as weightNorm 5 | 6 | 7 | class FCN(nn.Module): 8 | def __init__(self, high_res=False): 9 | super(FCN, self).__init__() 10 | self.fc1 = (nn.Linear(10, 512)) 11 | self.fc2 = (nn.Linear(512, 1024)) 12 | self.fc3 = (nn.Linear(1024, 2048)) 13 | self.fc4 = (nn.Linear(2048, 4096)) 14 | self.high_res = high_res 15 | 16 | if self.high_res is True: 17 | self.conv1 = (nn.Conv2d(16, 64, 3, 1, 1)) 18 | self.conv2 = (nn.Conv2d(64, 64, 3, 1, 1)) 19 | self.conv3 = (nn.Conv2d(16, 32, 3, 1, 1)) 20 | self.conv4 = (nn.Conv2d(32, 32, 3, 1, 1)) 21 | self.conv5 = (nn.Conv2d(8, 16, 3, 1, 1)) 22 | self.conv6 = (nn.Conv2d(16, 16, 3, 1, 1)) 23 | self.conv7 = (nn.Conv2d(4, 8, 3, 1, 1)) 24 | self.conv8 = (nn.Conv2d(8, 4, 3, 1, 1)) 25 | else: 26 | self.conv1 = (nn.Conv2d(16, 32, 3, 1, 1)) 27 | self.conv2 = (nn.Conv2d(32, 32, 3, 1, 1)) 28 | self.conv3 = (nn.Conv2d(8, 16, 3, 1, 1)) 29 | self.conv4 = (nn.Conv2d(16, 16, 3, 1, 1)) 30 | self.conv5 = (nn.Conv2d(4, 8, 3, 1, 1)) 31 | self.conv6 = (nn.Conv2d(8, 4, 3, 1, 1)) 32 | self.pixel_shuffle = nn.PixelShuffle(2) 33 | 34 | def forward(self, x): 35 | x = F.relu(self.fc1(x)) 36 | x = F.relu(self.fc2(x)) 37 | x = F.relu(self.fc3(x)) 38 | x = F.relu(self.fc4(x)) 39 | x = x.view(-1, 16, 16, 16) 40 | x = F.relu(self.conv1(x)) 41 | x = self.pixel_shuffle(self.conv2(x)) 42 | x = F.relu(self.conv3(x)) 43 | x = self.pixel_shuffle(self.conv4(x)) 44 | x = F.relu(self.conv5(x)) 45 | x = self.pixel_shuffle(self.conv6(x)) 46 | if self.high_res is True: 47 | x = F.relu(self.conv7(x)) 48 | x = self.pixel_shuffle(self.conv8(x)) 49 | x = torch.sigmoid(x) 50 | if self.high_res is True: 51 | return 1 - x.view(-1, 256, 256) 52 | else: 53 | return 1 - x.view(-1, 128, 128) 54 | -------------------------------------------------------------------------------- /semantic_guidance/Renderer/stroke_gen.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | def 
normal(x, width): 5 | """ 6 | scale stroke parameter x ([0,1]) based on width of the canvas 7 | :param x: stroke parameter x ([0,1]) 8 | :param width: width of canvas 9 | :return: scaled parameter 10 | """ 11 | return (int)(x * (width - 1) + 0.5) 12 | 13 | def draw(f, width=128, mask = False): 14 | """ 15 | Draw the Brezier curve on empty canvas 16 | :param f: stroke parameters (x0, y0, x1, y1, x2, y2, z0, z2, w0, w2) 17 | :param width: width of the canvas 18 | :param mask: boolean on whether mask is required for the canvas 19 | :return: painted canvas is zero at stroke locations and one otherwise, 20 | stroke_mask (ones at stroke parameters and zero otherwise) 21 | """ 22 | # read stroke parameters (10 positional parameters) 23 | x0, y0, x1, y1, x2, y2, z0, z2, w0, w2 = f 24 | x1 = x0 + (x2 - x0) * x1 25 | y1 = y0 + (y2 - y0) * y1 26 | x0 = normal(x0, width * 2) 27 | x1 = normal(x1, width * 2) 28 | x2 = normal(x2, width * 2) 29 | y0 = normal(y0, width * 2) 30 | y1 = normal(y1, width * 2) 31 | y2 = normal(y2, width * 2) 32 | z0 = (int)(1 + z0 * width // 2) 33 | z2 = (int)(1 + z2 * width // 2) 34 | 35 | # initialize empty canvas 36 | canvas = np.zeros([width * 2, width * 2]).astype('float32') 37 | tmp = 1. / 100 38 | 39 | # Brezier curve is made of 100 smaller circles 40 | for i in range(100): 41 | t = i * tmp 42 | x = (int)((1-t) * (1-t) * x0 + 2 * t * (1-t) * x1 + t * t * x2) 43 | y = (int)((1-t) * (1-t) * y0 + 2 * t * (1-t) * y1 + t * t * y2) 44 | z = (int)((1-t) * z0 + t * z2) 45 | w = (1-t) * w0 + t * w2 46 | cv2.circle(canvas, (y, x), z, w, -1) 47 | 48 | # return mask if required 49 | if mask: 50 | stroke_mask = (canvas!=0).astype(np.int32) 51 | return 1 - cv2.resize(canvas, dsize=(width, width)), stroke_mask 52 | return 1 - cv2.resize(canvas, dsize=(width, width)) 53 | -------------------------------------------------------------------------------- /semantic_guidance/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /semantic_guidance/env_ins.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import torch 4 | import numpy as np 5 | import argparse 6 | import torchvision.transforms as transforms 7 | import cv2 8 | # from DRL.ddpg import decode 9 | import utils.util as util 10 | from utils.dataloader import ImageDataset 11 | import random 12 | from Renderer.model import FCN 13 | import pandas as pd 14 | 15 | # define device 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | class Paint: 20 | def __init__(self, batch_size, max_eps_len, dataset='cub200', use_gbp=False, use_bilevel=False, gpu=0, 21 | width=128, high_res=False, bundle_size=5): 22 | self.batch_size = batch_size 23 | self.max_eps_len = max_eps_len 24 | self.action_space = (13) 25 | self.use_gbp = use_gbp 26 | self.use_bilevel = use_bilevel 27 | self.width = width 28 | self.high_res = high_res 29 | self.bundle_size = bundle_size 30 | 31 | if self.high_res: 32 | self.width = 256 33 | 34 | # define and load renderer 35 | self.decoder = FCN(self.high_res) 36 | self.decoder.load_state_dict(torch.load('../data/renderer{}.pkl'.format("_256" if self.high_res else ""))) 37 | self.decoder.cuda(gpu) 38 | 39 | for param in self.decoder.parameters(): 40 | param.requires_grad = False 41 | 42 | # gpu and distributed parameters 43 | self.gpu = gpu 44 | 45 | # define observation space 46 | 
self.observation_space = (self.batch_size, self.width, self.width, 7 + use_gbp + use_bilevel) 47 | self.test = False 48 | 49 | # dataset name 50 | self.dataset = dataset 51 | if dataset == 'cub200': 52 | df = pd.read_csv('../data/cub200/CUB_200_2011/images.txt', sep=' ', index_col=0, names=['idx', 'img_names']) 53 | img_names = list(df['img_names']) 54 | self.data = np.array( 55 | ["../data/cub200/CUB_200_2011/images/{}.jpg".format(img_name[:-4]) for img_name in img_names]) 56 | self.seg = np.array( 57 | ["../data/cub200/CUB_200_2011/segmentations_pred/{}.jpg".format(img_name[:-4]) for img_name in 58 | img_names]) 59 | self.gbp = np.array( 60 | ["../data/cub200/CUB_200_2011/gbp_global/{}.jpg".format(img_name[:-4]) for img_name in img_names]) 61 | 62 | df_ = pd.read_csv('../data/cub200/CUB_200_2011/bounding_boxes_pred.txt', sep=' ', index_col=0) 63 | x, y, w, h = np.array(df_['x']).astype(int), np.array(df_['y']).astype(int), np.array(df_['w']).astype(int), \ 64 | np.array(df_['h']).astype(int) 65 | self.bbox = np.array(list(zip(x, y, w, h))) 66 | 67 | # random shuffle data 68 | shuffled_indices = np.arange(len(self.data)).astype(np.int32) 69 | np.random.shuffle(shuffled_indices) 70 | self.data = self.data[shuffled_indices] 71 | self.seg = self.seg[shuffled_indices] 72 | self.bbox = self.bbox[shuffled_indices] 73 | self.gbp = self.gbp[shuffled_indices] 74 | 75 | def load_data(self, num_test=2000): 76 | # random shuffle data 77 | shuffled_indices = np.arange(len(self.data)).astype(np.int32) 78 | np.random.shuffle(shuffled_indices) 79 | 80 | # divide data into train and test 81 | train_data, test_data = self.data[shuffled_indices[num_test:]], self.data[shuffled_indices[:num_test]] 82 | 83 | # divide gbp data 84 | if self.use_gbp: 85 | gbp_train, gbp_test = self.gbp[shuffled_indices[num_test:]], self.gbp[shuffled_indices[:num_test]] 86 | else: 87 | gbp_train, gbp_test = None, None 88 | 89 | if self.use_bilevel: 90 | seg_train, seg_test = self.seg[shuffled_indices[num_test:]], self.seg[ 91 | shuffled_indices[:num_test]] 92 | bbox_train, bbox_test = self.bbox[shuffled_indices[num_test:]], self.bbox[ 93 | shuffled_indices[:num_test]] 94 | else: 95 | seg_train, seg_test = None, None 96 | bbox_train, bbox_test = None, None 97 | 98 | # create train and test data 99 | self.train_dataset = ImageDataset(train_data, gbp_list=gbp_train, seg_list=seg_train, 100 | bbox_list=bbox_train, high_res=self.high_res) 101 | self.test_dataset = ImageDataset(test_data, gbp_list=gbp_test, seg_list=seg_test, 102 | bbox_list=bbox_test, high_res=self.high_res) 103 | 104 | # record train test split 105 | self.num_train, self.num_test = len(train_data), num_test 106 | 107 | def reset(self, test=False, begin_num=False): 108 | self.test = test 109 | # self.imgid = [0] * self.batch_size 110 | self.gt = torch.zeros([self.batch_size, 3, self.width, self.width], dtype=torch.uint8).cuda(self.gpu) 111 | if self.use_gbp: 112 | self.gbp_gt = torch.zeros([self.batch_size, 1, self.width, self.width], dtype=torch.uint8).cuda(self.gpu) 113 | if self.use_bilevel: 114 | self.seg_gt = torch.zeros([self.batch_size, 1, self.width, self.width], dtype=torch.uint8).cuda(self.gpu) 115 | self.grid = torch.zeros([self.batch_size, 2, self.width, self.width], dtype=torch.uint8).cuda(self.gpu) 116 | 117 | # get ground truths and corresponding idxs 118 | if test: 119 | self.imgid = (begin_num + np.arange(self.batch_size)) % self.num_test 120 | for i in range(self.batch_size): 121 | img, gbp_gt, seg_gt, grid = self.test_dataset[self.imgid[i]] 122 | 
self.gt[i] = img 123 | if self.use_gbp: 124 | self.gbp_gt[i, :] = gbp_gt 125 | if self.use_bilevel: 126 | self.seg_gt[i, :] = seg_gt 127 | self.grid[i, :] = grid 128 | else: 129 | self.imgid = np.random.choice(np.arange(self.num_train), self.batch_size, replace=False) 130 | for i in range(self.batch_size): 131 | img, gbp_gt, seg_gt, grid = self.train_dataset[self.imgid[i]] 132 | self.gt[i] = img 133 | if self.use_gbp: 134 | self.gbp_gt[i, :] = gbp_gt 135 | if self.use_bilevel: 136 | self.seg_gt[i, :] = seg_gt 137 | self.grid[i, :] = grid 138 | 139 | self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1) 140 | self.stepnum = 0 141 | self.canvas = torch.zeros([self.batch_size, 3, self.width, self.width], dtype=torch.uint8).cuda(self.gpu) 142 | self.lastdis = self.ini_dis = self.cal_dis() 143 | return self.observation() 144 | 145 | def observation(self): 146 | # canvas B * 3 * width * width 147 | # gt B * 3 * width * width 148 | # T B * 1 * width * width 149 | ob = [] 150 | T = torch.ones([self.batch_size, 1, self.width, self.width], dtype=torch.uint8) * self.stepnum 151 | 152 | # canvas, img, T 153 | obs_list = [self.canvas, self.gt, T.cuda(self.gpu)] 154 | 155 | if self.use_gbp: 156 | obs_list += [self.gbp_gt] 157 | if self.use_bilevel: 158 | obs_list += [self.seg_gt] 159 | obs_list += [self.grid] 160 | 161 | return torch.cat(obs_list, 1) 162 | 163 | def cal_trans(self, s, t): 164 | return (s.transpose(0, 3) * t).transpose(0, 3) 165 | 166 | def step(self, action): 167 | if self.use_bilevel: 168 | self.canvas = ( 169 | self.decode_parallel(action, self.canvas.float() / 255, seg_mask=self.seg_gt.float() / 255)[ 170 | 0] * 255).byte() 171 | else: 172 | self.canvas = (self.decode(action, self.canvas.float() / 255)[0] * 255).byte() 173 | 174 | self.stepnum += 1 175 | ob = self.observation() 176 | # ob = ob[:, :7, :, :] 177 | done = (self.stepnum == self.max_eps_len) 178 | reward = self.cal_reward() # np.array([0.] 
* self.batch_size) 179 | return ob.detach(), reward, np.array([done] * self.batch_size), None 180 | 181 | def cal_dis(self): 182 | return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1) 183 | 184 | def cal_reward(self): 185 | """L2 loss difference between canvas and ground truth""" 186 | dis = self.cal_dis() 187 | reward = (self.lastdis - dis) / (self.ini_dis + 1e-8) 188 | self.lastdis = dis 189 | return util.to_numpy(reward) 190 | 191 | def decode_parallel(self, x, canvas, seg_mask=None, mask=False): 192 | canvas, _ = self.decode(x[:, :13 * self.bundle_size], canvas, mask, 1 - seg_mask) 193 | canvas, _ = self.decode(x[:, 13 * self.bundle_size:], canvas, mask, seg_mask) 194 | return canvas, _ 195 | 196 | def decode(self, x, canvas, mask=False, seg_mask=None): # b * (10 + 3) 197 | """ 198 | Update canvas given stroke parameters x 199 | :param x: stroke parameters (N,13*5) 200 | :param canvas: current canvas state 201 | :return: updated canvas with stroke drawn 202 | """ 203 | # 13 stroke parameters (10 position and 3 RGB color) 204 | x = x.contiguous().view(-1, 10 + 3) 205 | 206 | # get stroke on an empty canvas given 10 positional parameters 207 | stroke = 1 - self.decoder(x[:, :10]) 208 | if self.high_res is True: 209 | stroke = stroke.view(-1, 256, 256, 1) 210 | else: 211 | stroke = stroke.view(-1, 128, 128, 1) 212 | 213 | # add color to the stroke 214 | color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3) 215 | stroke = stroke.permute(0, 3, 1, 2) 216 | color_stroke = color_stroke.permute(0, 3, 1, 2) 217 | 218 | # draw bundle_size=5 strokes at a time (action bundle) 219 | if self.high_res is True: 220 | stroke = stroke.view(-1, self.bundle_size, 1, 256, 256) 221 | color_stroke = color_stroke.view(-1, self.bundle_size, 3, 256, 256) 222 | else: 223 | stroke = stroke.view(-1, self.bundle_size, 1, 128, 128) 224 | color_stroke = color_stroke.view(-1, self.bundle_size, 3, 128, 128) 225 | 226 | for i in range(self.bundle_size): 227 | if seg_mask is not None: 228 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i] * seg_mask 229 | else: 230 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i] 231 | 232 | # also return stroke mask if required 233 | stroke_mask = None 234 | if mask: 235 | stroke_mask = (stroke != 0).float() # -1, bundle_size, 1, width, width 236 | 237 | return canvas, stroke_mask 238 | -------------------------------------------------------------------------------- /semantic_guidance/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, sys 3 | import pandas as pd 4 | import cv2 5 | 6 | from PIL import Image 7 | import torch 8 | import torchvision 9 | from torchvision import transforms, utils 10 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 11 | from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor 12 | 13 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 14 | print("using device: {}".format(device)) 15 | 16 | 17 | # define function for getting combined object localization and semantic segmentation prediction model 18 | def get_model_instance_segmentation(num_classes): 19 | # load an instance segmentation model pre-trained pre-trained on COCO 20 | model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) 21 | 22 | # get number of input features for the classifier 23 | in_features = model.roi_heads.box_predictor.cls_score.in_features 24 | # replace the pre-trained 
head with a new one 25 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 26 | 27 | # now get the number of input features for the mask classifier 28 | in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels 29 | hidden_layer = 256 30 | # and replace the mask predictor with a new one 31 | model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, 32 | hidden_layer, 33 | num_classes) 34 | 35 | return model 36 | 37 | 38 | # our dataset has two classes only - background and foreground 39 | num_classes = 2 40 | # load pretrained object localization and semantic segmentation model 41 | model = get_model_instance_segmentation(num_classes) 42 | model.load_state_dict(torch.load('../data/birds_obj_seg.pkl', map_location={'cuda:0': 'cpu'})) 43 | model.to(device) 44 | model.eval() 45 | 46 | transform_test = transforms.Compose([ 47 | transforms.ToTensor(), 48 | ]) 49 | 50 | expert_model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True, 51 | **{'topN': 6, 'device': 'cpu', 'num_classes': 200}) 52 | expert_model.eval() 53 | 54 | from torch.nn import ReLU 55 | 56 | gbp_transform = transforms.Compose([ 57 | transforms.ToPILImage(), 58 | transforms.Resize((440, 440)), 59 | transforms.CenterCrop((440, 440)), 60 | # transforms.RandomHorizontalFlip(), # only if train 61 | transforms.ToTensor(), 62 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 63 | ]) 64 | 65 | 66 | def convert_to_grayscale(im_as_arr): 67 | """ 68 | Converts 3d image to grayscale 69 | Args: 70 | im_as_arr (numpy arr): RGB image with shape (D,W,H) 71 | returns: 72 | grayscale_im (numpy_arr): Grayscale image with shape (1,W,D) 73 | """ 74 | grayscale_im = np.sum(np.abs(im_as_arr), axis=0) 75 | im_max = np.percentile(grayscale_im, 99) 76 | im_min = np.min(grayscale_im) 77 | grayscale_im = (np.clip((grayscale_im - im_min) / (im_max - im_min), 0, 1)) 78 | grayscale_im = np.expand_dims(grayscale_im, axis=0) 79 | return grayscale_im 80 | 81 | 82 | class GuidedBackprop(): 83 | """ 84 | Produces gradients generated with guided back propagation from the given image 85 | """ 86 | 87 | def __init__(self, model): 88 | self.model = model 89 | self.gradients = None 90 | # Put model in evaluation mode 91 | self.model.eval() 92 | self.update_relus() 93 | self.hook_layers() 94 | 95 | def hook_layers(self): 96 | def hook_function(module, grad_in, grad_out): 97 | # print (module,grad_in,grad_out) 98 | self.gradients = grad_in[0] 99 | 100 | # Register hook to the first layer 101 | first_layer = list(self.model.children())[0] 102 | first_layer.register_backward_hook(hook_function) 103 | 104 | def update_relus(self): 105 | """ 106 | Updates relu activation functions so that it only returns positive gradients 107 | """ 108 | 109 | def relu_hook_function(module, grad_in, grad_out): 110 | """ 111 | If there is a negative gradient, changes it to zero 112 | """ 113 | if isinstance(module, ReLU): 114 | return (torch.clamp(grad_in[0], min=0.0),) 115 | 116 | # Loop through layers, hook up ReLUs with relu_hook_function 117 | for module in self.model.modules(): 118 | if isinstance(module, ReLU): 119 | module.register_backward_hook(relu_hook_function) 120 | 121 | def generate_gradients(self, input_image): 122 | input_image.requires_grad = True 123 | # Forward pass 124 | model_output = self.model(input_image)[0] 125 | target_class = torch.argmax(model_output).item() 126 | 127 | # Zero gradients 128 | self.model.zero_grad() 129 | # Target for backprop 130 | one_hot_output = 
torch.FloatTensor(1, model_output.size()[-1]).zero_() 131 | one_hot_output[0][target_class] = 1 132 | 133 | # print (one_hot_output) 134 | # Backward pass 135 | model_output.backward(gradient=one_hot_output) 136 | # Convert Pytorch variable to numpy array 137 | # [0] to get rid of the first channel (1,3,224,224) 138 | gradients_as_arr = self.gradients.data.numpy()[0] 139 | return gradients_as_arr 140 | 141 | 142 | # df_ = pd.read_csv('data/cub200/CUB_200_2011/bounding_boxes.txt', sep=' ', index_col=0, 143 | # names=['idx', 'x', 'y', 'w', 'h']) 144 | # df_ = df_.copy() 145 | # data = [] 146 | 147 | # get input image list 148 | df = pd.read_csv('../data/cub200/CUB_200_2011/images.txt', sep=' ', index_col=0, names=['idx', 'img_names']) 149 | img_names = list(df['img_names']) 150 | img_list = np.array(["../data/cub200/CUB_200_2011/images/{}.jpg".format(img_name[:-4]) for img_name in img_names]) 151 | 152 | # predicted bounded box data 153 | pred_bbox_data = [] 154 | GBP = GuidedBackprop(expert_model.pretrained_model) 155 | 156 | # create output directories for storing data 157 | gbp_dir = '../data/cub200/CUB_200_2011/gbp_global/' 158 | segmentations_dir = '../data/cub200/CUB_200_2011/segmentations_pred/' 159 | if not os.path.exists(gbp_dir): 160 | os.makedirs(gbp_dir) 161 | if not os.path.exists(segmentations_dir): 162 | os.makedirs(segmentations_dir) 163 | 164 | for choose in range(len(img_list)): 165 | # get image segmentation and bounding box predictions 166 | img_name = img_list[choose] 167 | img = Image.open(img_name) # cv2.cvtColor(cv2.imread(img_name, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 168 | img_ = transform_test(img).unsqueeze(0).to(device) 169 | with torch.no_grad(): 170 | prediction = model(img_) 171 | bbox_pred = prediction[0]['boxes'][0] 172 | x, y, w, h = bbox_pred.detach().cpu().numpy().astype(int) 173 | seg_mask = prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy() 174 | 175 | pred_bbox_data.append([x, y, w - x, h - y]) 176 | # df_.iloc[choose]['x'] = x 177 | # df_.iloc[choose]['y'] = y 178 | # df_.iloc[choose]['w'] = w - x 179 | # df_.iloc[choose]['h'] = h - y 180 | 181 | path_list = img_name.split('/') 182 | path_list[-3] = 'segmentations_pred' 183 | seg_name = '/'.join(path_list) 184 | # seg_name = "{}/segmentations_pred2/{}".format(img_name[:24], img_name[32:]) 185 | if not os.path.exists(os.path.dirname(seg_name)): 186 | os.makedirs(os.path.dirname(seg_name)) 187 | cv2.imwrite(seg_name, seg_mask) 188 | 189 | # get guided backpropagation maps from the expert model 190 | img = cv2.cvtColor(cv2.imread(img_name, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 191 | H, W, C = img.shape 192 | img = img[y:h, x:w].astype(np.uint8) 193 | seg_mask = seg_mask.astype(float) / 255. 
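    # The remainder of this loop builds the guided-backpropagation (GBP) saliency
    # map: the image and mask are cropped to the predicted bounding box, background
    # pixels are zeroed out, the crop is pushed through the NTS-Net expert
    # classifier via GuidedBackprop, the resulting gradients are converted to a
    # grayscale map, masked by the thresholded segmentation, and pasted back into
    # a full-resolution canvas that is saved under gbp_global/.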
194 | seg_mask = seg_mask[y:h, x: w] 195 | 196 | img = img.astype(float) * np.expand_dims(seg_mask, -1) 197 | img = img.astype(np.uint8) 198 | 199 | scaled_img = gbp_transform(img) 200 | torch_images = scaled_img.unsqueeze(0) 201 | 202 | guided_grads = GBP.generate_gradients(torch_images) # .transpose(1,2,0) 203 | gbp = convert_to_grayscale(guided_grads)[0] 204 | gbp = cv2.resize(gbp, (w-x, h-y)) 205 | gbp = (255 * gbp).astype(np.uint8) 206 | 207 | seg_mask = (seg_mask > 0.5).astype(np.uint8) 208 | gbp = gbp * seg_mask 209 | 210 | gbp_global = np.zeros((H, W)).astype(np.uint8) 211 | w, h = min(w, W), min(h, H) 212 | gbp_global[y:h, x:w] = gbp 213 | 214 | path_list = img_name.split('/') 215 | path_list[-3] = 'gbp_global' 216 | gbp_name = '/'.join(path_list) 217 | if not os.path.exists(os.path.dirname(gbp_name)): 218 | os.makedirs(os.path.dirname(gbp_name)) 219 | cv2.imwrite(gbp_name, gbp_global) 220 | 221 | if choose % 100 == 0: 222 | print("Processed :{}/{} images ...".format(choose, len(img_list))) 223 | 224 | # save bounding box predictions 225 | df_ = pd.DataFrame(data=pred_bbox_data, columns=['x', 'y', 'w', 'h']) 226 | df_.to_csv('../data/cub200/CUB_200_2011/bounding_boxes_pred.txt', sep=' ') 227 | -------------------------------------------------------------------------------- /semantic_guidance/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import argparse 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import transforms, utils 9 | from PIL import Image 10 | from DRL.actor import * 11 | from Renderer.stroke_gen import * 12 | # from Renderer.model import * 13 | import pandas as pd 14 | from DRL.actor import ResNet 15 | from Renderer.model import FCN 16 | import matplotlib.pyplot as plt 17 | from test_utils import * 18 | from collections import OrderedDict 19 | 20 | # define device 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | # parse input arguments 24 | parser = argparse.ArgumentParser(description='Paint canvases using semantic guidance') 25 | parser.add_argument('--max_eps_len', default=50, type=int, help='max length for episode') 26 | parser.add_argument('--actor', default='pretrained_models/actor_semantic_guidance.pkl', type=str, help='actor model') 27 | parser.add_argument('--use_baseline', action='store_true', help='use baseline model instead of semantic guidance') 28 | parser.add_argument('--renderer', default='../data/renderer.pkl', type=str, help='renderer model') 29 | parser.add_argument('--img', default='../image/test.png', type=str, help='test image') 30 | args = parser.parse_args() 31 | 32 | # width of canvas used 33 | width = 128 34 | # time image 35 | n_batch = 1 36 | T = torch.ones([n_batch, 1, width, width], dtype=torch.float32).to(device) 37 | coord = torch.zeros([n_batch, 2, width, width]) 38 | for i in range(width): 39 | for j in range(width): 40 | coord[:, 0, i, j] = i / (width - 1.) 41 | coord[:, 1, i, j] = j / (width - 1.) 
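# The two channels above are a CoordConv-style positional encoding: channel 0
# stores the normalized row index and channel 1 the normalized column index,
# both in [0, 1], giving the actor absolute position information. A vectorized
# construction (illustrative sketch only, not part of this script) would be:
#   ys, xs = torch.meshgrid(torch.linspace(0, 1, width), torch.linspace(0, 1, width))
#   coord = torch.stack([ys, xs]).unsqueeze(0).repeat(n_batch, 1, 1, 1)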
42 | coord = coord.to(device) # Coordconv 43 | 44 | # define, load actor and set to eval mode 45 | bundle_size = 5 46 | use_gbp, use_bilevel, use_neural_alignment = False, False, False 47 | if not args.use_baseline: 48 | use_gbp, use_bilevel, use_neural_alignment = True, True, True 49 | bundle_size = int(np.ceil(bundle_size / 2)) 50 | actor = ResNet(9 + use_gbp + use_bilevel + 2 * use_neural_alignment, 18, 13 * bundle_size * (1 + use_bilevel)) 51 | 52 | # load trained actor model 53 | state_dict = torch.load(args.actor, map_location={'cuda:0': 'cpu'}) 54 | new_state_dict = OrderedDict() 55 | for k, v in state_dict.items(): 56 | name = k.replace("module.", "") # k[7:] # remove `module.` 57 | new_state_dict[name] = v 58 | actor.load_state_dict(new_state_dict) 59 | actor = actor.to(device).eval() 60 | 61 | if not args.use_baseline: 62 | # load object localization and semantic segmentation model 63 | seg_model = get_model_instance_segmentation(num_classes=2) 64 | seg_model.load_state_dict(torch.load('../data/birds_obj_seg.pkl', map_location={'cuda:0': 'cpu'})) 65 | for param in seg_model.parameters(): 66 | param.requires_grad = False 67 | seg_model.to(device) 68 | seg_model.eval() 69 | 70 | # load expert model for getting guided backpropagation maps 71 | expert_model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True, 72 | **{'topN': 6, 'device': 'cpu', 'num_classes': 200}) 73 | expert_model.eval() 74 | GBP = GuidedBackprop(expert_model.pretrained_model) 75 | 76 | # get segmentation, bbox and guided backpropagation map predictions 77 | img = Image.open(args.img) # cv2.cvtColor(cv2.imread(img_name, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 78 | img_ = transform_test(img).unsqueeze(0).to(device) 79 | with torch.no_grad(): 80 | prediction = seg_model(img_) 81 | bbox_pred = prediction[0]['boxes'][0] 82 | x, y, w, h = bbox_pred.detach().cpu().numpy().astype(int) 83 | seg_mask = prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy() 84 | classgt_mask = seg_mask.copy() 85 | 86 | # get guided backpropagation maps from the expert model 87 | img = cv2.cvtColor(cv2.imread(args.img, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 88 | H, W, C = img.shape 89 | img = img[y:h, x:w].astype(np.uint8) 90 | seg_mask = seg_mask.astype(float) / 255. 91 | seg_mask = seg_mask[y:h, x: w] 92 | 93 | img = img.astype(float) * np.expand_dims(seg_mask, -1) 94 | img = img.astype(np.uint8) 95 | 96 | scaled_img = gbp_transform(img) 97 | torch_images = scaled_img.unsqueeze(0) 98 | 99 | guided_grads = GBP.generate_gradients(torch_images) # .transpose(1,2,0) 100 | gbp = convert_to_grayscale(guided_grads)[0] 101 | gbp = cv2.resize(gbp, (w - x, h - y)) 102 | gbp = (255 * gbp).astype(np.uint8) 103 | 104 | seg_mask = (seg_mask > 0.5).astype(np.uint8) 105 | gbp = gbp * seg_mask 106 | 107 | gbp_global = np.zeros((H, W)).astype(np.uint8) 108 | w, h = min(w, W), min(h, H) 109 | gbp_global[y:h, x:w] = gbp 110 | 111 | # get grid for spatial transformer network 112 | w, h = w - x, h - y 113 | x, y, w, h = x / W, y / H, w / W, h / H 114 | Affine_Mat_w = [w, 0, (2 * x + w - 1)] 115 | Affine_Mat_h = [0, h, (2 * y + h - 1)] 116 | M = np.c_[Affine_Mat_w, Affine_Mat_h].T 117 | M = torch.tensor(M).unsqueeze(0) 118 | grid = torch.nn.functional.affine_grid(M, (1, 3, 128, 128)) # (1,128,128,2) 119 | grid = (grid + 1) / 2 # scale between 0,1 120 | grid = torch.tensor(grid * 255, dtype=torch.uint8).permute(0, 3, 1, 2) 121 | grid = grid.to(device).float() / 255. 
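# The affine matrix built above is
#   M = [[w, 0, 2*x + w - 1],
#        [0, h, 2*y + h - 1]]
# with (x, y, w, h) normalized to [0, 1]. Passed to affine_grid, it maps every
# canvas pixel to its corresponding location inside the predicted object box,
# so after rescaling from [-1, 1] to [0, 1] the two grid channels serve as the
# neural-alignment input that tells the actor where the detected foreground
# object sits in the frame.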
122 | 123 | # load segmentation image 124 | classgt_img = cv2.resize(classgt_mask, (128, 128)) 125 | classgt_img = torch.tensor(classgt_img, dtype=torch.uint8) 126 | classgt_img = classgt_img.unsqueeze(0).unsqueeze(0).to(device).float() / 255. 127 | 128 | # load guided backpropagation map 129 | gbp_global = cv2.resize(gbp_global, (128, 128)) 130 | gbp_img = torch.tensor(gbp_global, dtype=torch.uint8) 131 | gbp_img = gbp_img.unsqueeze(0).unsqueeze(0).to(device).float() / 255. 132 | 133 | # load image 134 | image = cv2.cvtColor(cv2.imread(args.img, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 135 | image = cv2.resize(image, (128, 128)) 136 | image = torch.tensor(image, dtype=torch.uint8).permute(2, 0, 1) 137 | img = image.unsqueeze(0).to(device).float() / 255. 138 | 139 | # initialize empty canvas 140 | canvas = torch.zeros([n_batch, 3, width, width]).to(device) 141 | max_eps_len = args.max_eps_len 142 | # initialize canvas frames for generating video 143 | canvas_frames = [] 144 | canvas_ = canvas.permute(0, 2, 3, 1).cpu().numpy() 145 | canvas_frames.append(canvas_[0]) 146 | 147 | with torch.no_grad(): 148 | # generate canvas episode 149 | for i in range(max_eps_len): 150 | stroke_type = 'both' # 'fg' if i > 1 else 'both' 151 | stepnum = T * i / max_eps_len 152 | if args.use_baseline: 153 | actions = actor(torch.cat([canvas, img, stepnum, coord], 1)) 154 | canvas, _, frames = decode(actions, canvas, get_frames=True, bundle_size=bundle_size) 155 | else: 156 | actions = actor(torch.cat([canvas, img, stepnum, coord, gbp_img, classgt_img, grid], 1)) 157 | canvas, _, frames = decode_parallel(actions, canvas, seg_mask=classgt_img, stroke_type=stroke_type, 158 | get_frames=True, bundle_size=bundle_size) 159 | canvas_frames += frames 160 | 161 | # get final painted canvas 162 | canvas_ = canvas.permute(0, 2, 3, 1).cpu().numpy() 163 | canvas_ = np.array(255 * canvas_[0]).astype(np.uint8) 164 | # H, W, C = cv2.imread(args.img, cv2.IMREAD_COLOR).shape 165 | H, W = 250, 250 166 | painted_canvas = cv2.cvtColor(cv2.resize(canvas_, (W, H)), cv2.COLOR_RGB2BGR) 167 | 168 | # save generated canvas 169 | save_img_name = "../output/{}_painted_{}".format('baseline' if args.use_baseline else 'sg', os.path.basename(args.img)) 170 | print("\nSaving painted canvas to {}\n".format(save_img_name)) 171 | cv2.imwrite(save_img_name, painted_canvas) 172 | 173 | # generate and save video for the painting episode 174 | video_name = '../video/{}_{}.mp4'.format('baseline' if args.use_baseline else 'sg', os.path.basename(args.img)[:-4]) 175 | video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), 20, (W, H)) 176 | for image in canvas_frames: 177 | image = np.array(255 * image, dtype=np.uint8) 178 | image = cv2.resize(image, (W, H)) 179 | video.write(cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 180 | video.release() 181 | print("\nSaving painting video to {}\n".format(video_name)) 182 | -------------------------------------------------------------------------------- /semantic_guidance/test_utils.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 3 | from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor 4 | from torchvision import transforms, utils 5 | import os 6 | import cv2 7 | import torch 8 | import numpy as np 9 | import argparse 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision import transforms, utils 13 | from Renderer.model import FCN 
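# Helper module for test.py: it rebuilds the Mask R-CNN based localization and
# segmentation model (get_model_instance_segmentation), loads the neural
# renderer (FCN) used by decode()/decode_parallel() to paint stroke bundles
# (optionally collecting per-stroke canvas frames for the output video), and
# provides the GBP transforms and the GuidedBackprop class used to compute
# saliency maps at inference time.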
14 | from torch.nn import ReLU 15 | 16 | # define device 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | def get_model_instance_segmentation(num_classes): 21 | # load an instance segmentation model pre-trained pre-trained on COCO 22 | model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) 23 | 24 | # get number of input features for the classifier 25 | in_features = model.roi_heads.box_predictor.cls_score.in_features 26 | # replace the pre-trained head with a new one 27 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 28 | 29 | # now get the number of input features for the mask classifier 30 | in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels 31 | hidden_layer = 256 32 | # and replace the mask predictor with a new one 33 | model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, 34 | hidden_layer, 35 | num_classes) 36 | 37 | return model 38 | 39 | 40 | # renderer model 41 | Decoder = FCN() 42 | Decoder.load_state_dict(torch.load('../data/renderer.pkl')) 43 | Decoder = Decoder.to(device).eval() 44 | 45 | def decode_parallel(x, canvas, seg_mask=None, mask=False, stroke_type='both', get_frames=False, bundle_size=3): 46 | # print (seg_mask.shape) 47 | 48 | bg, fg = True, True 49 | 50 | if stroke_type == 'fg': 51 | bg = False 52 | if stroke_type == 'bg': 53 | fg = False 54 | 55 | canvas_frames = [] 56 | if get_frames: 57 | if bg: 58 | canvas, _, frames = decode(x[:, :13 * bundle_size], canvas, mask, 1 - seg_mask, get_frames=get_frames,bundle_size=bundle_size) 59 | canvas_frames += frames 60 | if fg: 61 | canvas, _, frames = decode(x[:, 13 * bundle_size:], canvas, mask, seg_mask, get_frames=get_frames,bundle_size=bundle_size) 62 | canvas_frames += frames 63 | return canvas, _, canvas_frames 64 | 65 | if bg: 66 | canvas, _ = decode(x[:, :13 * bundle_size], canvas, mask, 1 - seg_mask, get_frames=get_frames,bundle_size=bundle_size) 67 | if fg: 68 | canvas, _ = decode(x[:, 13 * bundle_size:], canvas, mask, seg_mask, get_frames=get_frames,bundle_size=bundle_size) 69 | return canvas, _ 70 | 71 | 72 | def decode(x, canvas, mask=False, seg_mask=None, get_frames=False, bundle_size=5): # b * (10 + 3) 73 | """ 74 | Update canvas given stroke parameters x 75 | """ 76 | # 13 stroke parameters (10 position and 3 RGB color) 77 | x = x.contiguous().view(-1, 10 + 3) 78 | 79 | # get stroke on an empty canvas given 10 positional parameters 80 | stroke = 1 - Decoder(x[:, :10]) 81 | 82 | stroke = stroke.view(-1, 128, 128, 1) 83 | 84 | # add color to the stroke 85 | color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3) 86 | stroke = stroke.permute(0, 3, 1, 2) 87 | color_stroke = color_stroke.permute(0, 3, 1, 2) 88 | 89 | stroke = stroke.view(-1, bundle_size, 1, 128, 128) 90 | color_stroke = color_stroke.view(-1, bundle_size, 3, 128, 128) 91 | 92 | canvas_frames = [] 93 | for i in range(bundle_size): 94 | if seg_mask is not None: 95 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i] * seg_mask 96 | else: 97 | canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i] 98 | canvas_ = canvas.permute(0, 2, 3, 1).cpu().numpy() 99 | canvas_frames.append(canvas_[0]) 100 | 101 | # also return stroke mask if required 102 | stroke_mask = None 103 | if mask: 104 | stroke_mask = (stroke != 0).float() # -1, bundle_size, 1, width, width 105 | 106 | if get_frames: 107 | return canvas, stroke_mask, canvas_frames 108 | return canvas, stroke_mask 109 | 110 | 111 | gbp_transform = transforms.Compose([ 112 
| transforms.ToPILImage(), 113 | transforms.Resize((440, 440)), 114 | transforms.CenterCrop((440, 440)), 115 | # transforms.RandomHorizontalFlip(), # only if train 116 | transforms.ToTensor(), 117 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 118 | ]) 119 | 120 | transform_test = transforms.Compose([ 121 | transforms.ToTensor(), 122 | ]) 123 | 124 | 125 | def convert_to_grayscale(im_as_arr): 126 | """ 127 | Converts 3d image to grayscale 128 | Args: 129 | im_as_arr (numpy arr): RGB image with shape (D,W,H) 130 | returns: 131 | grayscale_im (numpy_arr): Grayscale image with shape (1,W,D) 132 | """ 133 | grayscale_im = np.sum(np.abs(im_as_arr), axis=0) 134 | im_max = np.percentile(grayscale_im, 99) 135 | im_min = np.min(grayscale_im) 136 | grayscale_im = (np.clip((grayscale_im - im_min) / (im_max - im_min), 0, 1)) 137 | grayscale_im = np.expand_dims(grayscale_im, axis=0) 138 | return grayscale_im 139 | 140 | 141 | class GuidedBackprop(): 142 | """ 143 | Produces gradients generated with guided back propagation from the given image 144 | """ 145 | 146 | def __init__(self, model): 147 | self.model = model 148 | self.gradients = None 149 | # Put model in evaluation mode 150 | self.model.eval() 151 | self.update_relus() 152 | self.hook_layers() 153 | 154 | def hook_layers(self): 155 | def hook_function(module, grad_in, grad_out): 156 | # print (module,grad_in,grad_out) 157 | self.gradients = grad_in[0] 158 | 159 | # Register hook to the first layer 160 | first_layer = list(self.model.children())[0] 161 | first_layer.register_backward_hook(hook_function) 162 | 163 | def update_relus(self): 164 | """ 165 | Updates relu activation functions so that it only returns positive gradients 166 | """ 167 | 168 | def relu_hook_function(module, grad_in, grad_out): 169 | """ 170 | If there is a negative gradient, changes it to zero 171 | """ 172 | if isinstance(module, ReLU): 173 | return (torch.clamp(grad_in[0], min=0.0),) 174 | 175 | # Loop through layers, hook up ReLUs with relu_hook_function 176 | for module in self.model.modules(): 177 | if isinstance(module, ReLU): 178 | module.register_backward_hook(relu_hook_function) 179 | 180 | def generate_gradients(self, input_image): 181 | input_image.requires_grad = True 182 | # Forward pass 183 | model_output = self.model(input_image)[0] 184 | target_class = torch.argmax(model_output).item() 185 | 186 | # Zero gradients 187 | self.model.zero_grad() 188 | # Target for backprop 189 | one_hot_output = torch.FloatTensor(1, model_output.size()[-1]).zero_() 190 | one_hot_output[0][target_class] = 1 191 | 192 | # print (one_hot_output) 193 | # Backward pass 194 | model_output.backward(gradient=one_hot_output) 195 | # Convert Pytorch variable to numpy array 196 | # [0] to get rid of the first channel (1,3,224,224) 197 | gradients_as_arr = self.gradients.data.numpy()[0] 198 | return gradients_as_arr 199 | -------------------------------------------------------------------------------- /semantic_guidance/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import cv2 3 | import random 4 | import numpy as np 5 | import argparse 6 | from DRL.evaluator import Evaluator 7 | from utils.util import * 8 | from utils.tensorboard import TensorBoard 9 | import time 10 | 11 | from DRL.ddpg import DDPG 12 | 13 | import argparse 14 | import torch.distributed as dist 15 | import sys, os 16 | import torch 17 | import torch.multiprocessing as mp 18 | import numpy as np 19 | 20 | # exp = 
os.path.abspath('.').split('/')[-1] 21 | # writer = TensorBoard('../train_log/{}'.format(exp)) 22 | # os.system('ln -sf ../train_log/{} ./log'.format(exp)) 23 | os.system('mkdir ./model') 24 | 25 | 26 | def train(agent, evaluate, writer, args, gpu, distributed): 27 | """ 28 | :param agent: DDPG agent 29 | :param evaluate: 30 | :param writer: tensorboard summary writer 31 | :return: None 32 | """ 33 | ## hyperparameters 34 | # total timesteps for training 35 | train_timesteps = args.train_timesteps 36 | # number of parallel environments for faster sample collection 37 | nenv = args.nenv 38 | # number of episodes between validation tests 39 | val_interval = args.val_interval 40 | # maximum length of a single painting episode (number of brush strokes) 41 | max_eps_len = args.max_eps_len 42 | # number of training steps per episode 43 | train_steps_per_eps = args.train_steps_per_eps 44 | # noise factor used in training 45 | noise_factor = args.noise_factor 46 | 47 | ## display progress 48 | # verbose: print training progress if true 49 | debug = args.debug 50 | 51 | ## load and save directories 52 | # path for stored model if resuming training 53 | load_path = args.load_path 54 | # parent directory for storing trained models 55 | output = args.output 56 | 57 | ## intializations 58 | # get current time stamp 59 | time_stamp = time.time() 60 | # initialize training steps 61 | step = episode = episode_steps = 0 62 | # initialize total reward 63 | tot_reward = 0. 64 | # initialize state 65 | observation = None 66 | 67 | # synchronize initial models 68 | agent.save_model(output, episode) 69 | 70 | # begin training 71 | while step <= train_timesteps: 72 | # update training steps 73 | step += 1 74 | # steps within an episode (cannot be greater than max_eps_len) 75 | episode_steps += 1 76 | 77 | ## take a step in concurrent environments and store samples to the replay buffer 78 | # reset if it is the start of episode 79 | if observation is None: 80 | observation = agent.env.reset() 81 | agent.reset(observation, noise_factor) 82 | action = agent.select_action(observation, noise_factor=noise_factor) 83 | observation, reward, done, _ = agent.env.step(action) 84 | # store r,s,a tuple to replay memory 85 | agent.observe(reward, observation, done, step) 86 | 87 | # store progress and train if episode is done 88 | if episode_steps >= max_eps_len and max_eps_len: 89 | if step > args.warmup: 90 | # compute validation results 91 | if episode > 0 and val_interval > 0 and episode % val_interval == 0: 92 | reward, dist = evaluate(agent.env, agent.select_action, debug=debug) 93 | if debug and gpu == 0: 94 | prRed('Step_{:07d}: mean_reward:{:.3f} mean_dist:{:.3f} var_dist:{:.3f}'.format(step - 1, 95 | np.mean(reward), 96 | np.mean(dist), 97 | np.var(dist))) 98 | writer.add_scalar('validate/mean_reward', np.mean(reward), step) 99 | writer.add_scalar('validate/mean_dist', np.mean(dist), step) 100 | writer.add_scalar('validate/var_dist', np.var(dist), step) 101 | 102 | # save latest model 103 | if episode % 200 == 0: 104 | agent.save_model(output, episode) 105 | 106 | # get training time and update timestamp 107 | train_time_interval = time.time() - time_stamp 108 | time_stamp = time.time() 109 | 110 | # initialize total expected value and overall value loss for the episode 111 | tot_Q = 0. 112 | tot_value_loss = 0. 113 | 114 | if step > args.warmup: 115 | # step learning rate schedule after (1e4,2e4) episodes 116 | # also note lr[0],lr[1] are for updating critic,actor respectively. 
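                # Concretely: (critic_lr, actor_lr) starts at (3e-4, 1e-3), drops to
                # (1e-4, 3e-4) after roughly 1e4 episodes (10000 * max_eps_len env
                # steps) and to (3e-5, 1e-4) after 2e4 episodes; the chosen pair is
                # passed to agent.update_policy() for each of the
                # train_steps_per_eps updates below.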
117 | if step < 10000 * max_eps_len: 118 | lr = (3e-4, 1e-3) 119 | elif step < 20000 * max_eps_len: 120 | lr = (1e-4, 3e-4) 121 | else: 122 | lr = (3e-5, 1e-4) 123 | 124 | # perform several training steps after each episode 125 | for i in range(train_steps_per_eps): 126 | # train the agent 127 | Q, value_loss = agent.update_policy(lr) 128 | tot_Q += Q.data.cpu().numpy() 129 | tot_value_loss += value_loss.data.cpu().numpy() 130 | 131 | # store training performance summaries 132 | if gpu == 0: 133 | writer.add_scalar('train/critic_lr', lr[0], step) 134 | writer.add_scalar('train/actor_lr', lr[1], step) 135 | writer.add_scalar('train/Q', tot_Q / train_steps_per_eps, step) 136 | writer.add_scalar('train/critic_loss', tot_value_loss / train_steps_per_eps, step) 137 | 138 | # display training progress 139 | if debug and gpu == 0: 140 | prBlack('#{}: steps:{} interval_time:{:.2f} train_time:{:.2f}' \ 141 | .format(episode, step, train_time_interval, time.time() - time_stamp)) 142 | 143 | # reset/update timestamp and episode stats 144 | time_stamp = time.time() 145 | observation = None 146 | episode_steps = 0 147 | episode += 1 148 | 149 | 150 | def setup(args): 151 | os.environ['MASTER_ADDR'] = 'localhost' 152 | os.environ['MASTER_PORT'] = '12354' 153 | 154 | # initialize the process group 155 | dist.init_process_group("nccl", rank=args.rank, world_size=args.world_size) 156 | # limit number of threads per process 157 | setup_pytorch_for_mpi(args) 158 | torch.backends.cudnn.benchmark = True 159 | torch.backends.cudnn.deterministic = False 160 | 161 | 162 | def cleanup(): 163 | dist.destroy_process_group() 164 | 165 | 166 | def setup_pytorch_for_mpi(args): 167 | """ 168 | Avoid slowdowns caused by each separate process's PyTorch using 169 | more than its fair share of CPU resources. 170 | """ 171 | print('Proc %d: Reporting original number of Torch threads as %d.' % (args.rank, torch.get_num_threads()), 172 | flush=True) 173 | if torch.get_num_threads() == 1: 174 | return 175 | fair_num_threads = max(int(torch.get_num_threads() / args.world_size), 1) 176 | fair_num_threads = 1 177 | torch.set_num_threads(fair_num_threads) 178 | print('Proc %d: Reporting new number of Torch threads as %d.' 
% (args.rank, torch.get_num_threads()), flush=True) 179 | 180 | 181 | def demo(gpu, args): 182 | # rank of the current process 183 | args.rank = args.nr * args.gpus + gpu 184 | # setup dist process 185 | setup(args) 186 | 187 | # Random seed 188 | seed = 10000 * gpu + gpu 189 | torch.manual_seed(seed) 190 | np.random.seed(seed) 191 | random.seed(seed) 192 | 193 | # # summary writer 194 | writer = TensorBoard(args.LOG_DIR) 195 | 196 | # setup concurrent environments 197 | # define agent 198 | agent = DDPG(args.batch_size, args.nenv, args.max_eps_len, 199 | args.tau, args.discount, args.rmsize, 200 | writer, args.load_path, args.output, args.dataset, 201 | use_gbp=args.use_gbp, use_bilevel=args.use_bilevel, 202 | gbp_coef=args.gbp_coef, seggt_coef=args.seggt_coef, 203 | gpu=gpu, distributed=False, 204 | high_res=args.high_res, bundle_size=args.bundle_size) 205 | evaluate = Evaluator(args, writer) 206 | 207 | # display state, action space info 208 | if gpu == 0: 209 | print('observation_space', agent.env.observation_space, 'action_space', agent.env.action_space) 210 | 211 | # begin training 212 | train(agent, evaluate, writer, args, gpu=gpu, distributed=True) 213 | 214 | if gpu == 0: 215 | print("Training finished") 216 | 217 | cleanup() 218 | 219 | 220 | if __name__ == "__main__": 221 | parser = argparse.ArgumentParser(description='Learning to Paint') 222 | 223 | # hyper-parameter 224 | parser.add_argument('--dataset', type=str, default='cub200', choices=['cub200'], 225 | help='dataset') 226 | parser.add_argument('--warmup', default=400, type=int, 227 | help='timestep without training but only filling the replay memory') 228 | parser.add_argument('--discount', default=0.95 ** 5, type=float, help='discount factor (gamma)') 229 | parser.add_argument('--batch_size', default=96, type=int, help='minibatch size') 230 | parser.add_argument('--bundle_size', default=5, type=int, help='action bundle size') 231 | parser.add_argument('--rmsize', default=800, type=int, help='replay memory size') 232 | parser.add_argument('--nenv', default=96, type=int, 233 | help='concurrent environment number/ number of environments') 234 | parser.add_argument('--tau', default=0.001, type=float, help='moving average for target network') 235 | parser.add_argument('--max_eps_len', default=40, type=int, help='max length for episode (*)') 236 | parser.add_argument('--noise_factor', default=0, type=float, help='noise level for parameter space noise') 237 | parser.add_argument('--val_interval', default=50, type=int, help='episode interval for performing validation') 238 | parser.add_argument('--val_num_eps', default=5, type=int, help='episodes used for performing validation') 239 | parser.add_argument('--train_timesteps', default=int(2e6), type=int, help='total training steps') 240 | parser.add_argument('--train_steps_per_eps', default=10, type=int, help='number of training steps per episode') 241 | parser.add_argument('--load_path', default=None, type=str, help='Load model and resume training') 242 | parser.add_argument('--exp_suffix', default='base', type=str, 243 | help='suffix for providing additional experiment info') 244 | parser.add_argument('--output', default='./model', type=str, help='Output path for storing model') 245 | parser.add_argument('--use_gbp', action='store_true', help='use gbp info along with rgb image') 246 | parser.add_argument('--use_bilevel', action='store_true', help='use semantic class maps info along with rgb image') 247 | parser.add_argument('--bundled_seggt', action='store_true', 248 | help='all 
strokes in a bundle should belong to same class') 249 | parser.add_argument('--gbp_coef', default=1.0, type=float, help='coefficient for gbp reward') 250 | parser.add_argument('--seggt_coef', default=1.0, type=float, help='coefficient for seggt reward') 251 | parser.add_argument('--high_res', action='store_true', help='use high resolution gt(256 x 256)') 252 | parser.add_argument('--debug', dest='debug', action='store_true', help='print some info') 253 | parser.add_argument('--seed', default=1234, type=int, help='random seed') 254 | parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', 255 | help='number of data loading workers (default: 1)') 256 | parser.add_argument('-g', '--gpus', default=1, type=int, help='number of gpus per node') 257 | parser.add_argument('-nr', '--nr', default=0, type=int, help='ranking within the nodes') 258 | 259 | # parse args 260 | args = parser.parse_args() 261 | 262 | # if using bilevel painting procedure than use only half of brush strokes for foreground 263 | # and rest for background 264 | if args.use_bilevel: 265 | args.bundle_size = int(np.ceil(args.bundle_size/2)) 266 | 267 | # log directory 268 | exp_type = os.path.abspath('.').split('/')[-1] 269 | LOG_DIR = "../train_log/{}/".format(exp_type) 270 | # choose log directory based on the experiment 271 | exp_name = "{}/nenv{}_batchsize{}_maxstep_{}_tau{}_memsize{}_{}{}{}_{}".format(args.dataset, args.nenv, 272 | args.batch_size, 273 | args.max_eps_len, args.tau, 274 | args.rmsize, 275 | "bundlesize{}".format(args.bundle_size), 276 | "_gbp{}".format(args.gbp_coef) if args.use_gbp else "", 277 | "_seggt{}".format(args.seggt_coef) if args.use_bilevel else "", 278 | args.exp_suffix) 279 | 280 | # create summary writer 281 | LOG_DIR += exp_name 282 | args.LOG_DIR = LOG_DIR 283 | 284 | # create output directory 285 | args.output = get_output_folder(args.output + "/" + exp_name, "Paint") 286 | 287 | # total number of processes 288 | args.world_size = args.gpus * args.nodes 289 | 290 | print("starting") 291 | # launch multiple processes 292 | mp.spawn(demo, nprocs=args.gpus, args=(args,), join=True) 293 | -------------------------------------------------------------------------------- /semantic_guidance/train_renderer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import argparse 4 | import numpy as np 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from utils.tensorboard import TensorBoard 9 | from Renderer.model import FCN 10 | from Renderer.stroke_gen import * 11 | 12 | def save_model(): 13 | if use_cuda: 14 | net.cpu() 15 | torch.save(net.state_dict(), "data/renderer.pkl") 16 | if use_cuda: 17 | net.cuda() 18 | 19 | 20 | def load_weights(): 21 | pretrained_dict = torch.load("data/renderer.pkl") 22 | model_dict = net.state_dict() 23 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 24 | model_dict.update(pretrained_dict) 25 | net.load_state_dict(model_dict) 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser(description='Learning to Paint') 30 | 31 | # hyper-parameter 32 | parser.add_argument('--high_res', default=False, type=bool, help='resolution') 33 | args = parser.parse_args() 34 | 35 | writer = TensorBoard("./train_log/") 36 | import torch.optim as optim 37 | 38 | criterion = nn.MSELoss() 39 | net = FCN(args.high_res) 40 | optimizer = optim.Adam(net.parameters(), lr=3e-6) 41 | batch_size = 64 42 | 43 | use_cuda = 
torch.cuda.is_available() 44 | step = 0 45 | 46 | load_weights() 47 | while step < 500000: 48 | net.train() 49 | train_batch = [] 50 | ground_truth = [] 51 | for i in range(batch_size): 52 | f = np.random.uniform(0, 1, 10) 53 | train_batch.append(f) 54 | if args.high_res is True: 55 | ground_truth.append(draw(f, width=256)) 56 | else: 57 | ground_truth.append(draw(f)) 58 | 59 | train_batch = torch.tensor(train_batch).float() 60 | ground_truth = torch.tensor(ground_truth).float() 61 | if use_cuda: 62 | net = net.cuda() 63 | train_batch = train_batch.cuda() 64 | ground_truth = ground_truth.cuda() 65 | gen = net(train_batch) 66 | optimizer.zero_grad() 67 | loss = criterion(gen, ground_truth) 68 | loss.backward() 69 | optimizer.step() 70 | print(step, loss.item()) 71 | if step < 200000: 72 | lr = 1e-4 73 | elif step < 400000: 74 | lr = 1e-5 75 | else: 76 | lr = 1e-6 77 | for param_group in optimizer.param_groups: 78 | param_group["lr"] = lr 79 | writer.add_scalar("train/loss", loss.item(), step) 80 | if step % 100 == 0: 81 | net.eval() 82 | gen = net(train_batch) 83 | loss = criterion(gen, ground_truth) 84 | writer.add_scalar("val/loss", loss.item(), step) 85 | for i in range(32): 86 | G = gen[i].cpu().data.numpy() 87 | GT = ground_truth[i].cpu().data.numpy() 88 | writer.add_image("train/gen{}.png".format(i), G, step) 89 | writer.add_image("train/ground_truth{}.png".format(i), GT, step) 90 | if step % 1000 == 0: 91 | save_model() 92 | step += 1 93 | -------------------------------------------------------------------------------- /semantic_guidance/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/semantic_guidance/utils/__init__.py -------------------------------------------------------------------------------- /semantic_guidance/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import numpy as np 5 | from torchvision import transforms, utils 6 | import cv2 7 | 8 | 9 | def adjust_gamma(image, gamma=1.0): 10 | # build a lookup table mapping the pixel values [0, 255] to 11 | # their adjusted gamma values 12 | invGamma = 1.0 / gamma 13 | table = np.array([((i / 255.0) ** invGamma) * 255 14 | for i in np.arange(0, 256)]).astype("uint8") 15 | # apply gamma correction using the lookup table 16 | return cv2.LUT(image, table) 17 | 18 | 19 | class ImageDataset: 20 | """ 21 | Dataset for loading images at run-time 22 | """ 23 | 24 | def __init__(self, img_list, gbp_list=None, seg_list=None, transform=None, width=128, 25 | bbox_list=None, high_res=False): 26 | """ 27 | Args: 28 | img_list (string): list of images 29 | transform (callable, optional): Optional transform to be applied 30 | on a sample. 
31 | """ 32 | self.img_list = img_list 33 | self.gbp_list = gbp_list 34 | self.seg_list = seg_list 35 | self.transform = transform 36 | self.width = width 37 | self.high_res = high_res 38 | self.bbox_list = bbox_list 39 | 40 | if self.high_res is True: 41 | self.width = 256 42 | 43 | # custom transforms 44 | self.horizontal_flip = transforms.RandomHorizontalFlip(p=1.0) 45 | self.resize = transforms.Resize((width, width)) 46 | self.toTensor = transforms.ToTensor() 47 | 48 | def __len__(self): 49 | return len(self.img_list) 50 | 51 | def __getitem__(self, idx, gbp_img=None): 52 | if torch.is_tensor(idx): 53 | idx = idx.tolist() 54 | 55 | # flip horizontal 56 | flip_horizontal = np.random.rand() > 0.5 57 | 58 | # read rgb image 59 | img_name = self.img_list[idx] 60 | image = cv2.cvtColor(cv2.imread(img_name), cv2.COLOR_BGR2RGB) 61 | H, W, C = image.shape 62 | image = cv2.resize(image, (self.width, self.width)) 63 | # image = adjust_gamma(image, gamma=1.5) 64 | if flip_horizontal: 65 | image = cv2.flip(image, 1) 66 | image = torch.tensor(image, dtype=torch.uint8).permute(2, 0, 1) 67 | 68 | # initialize gbp, seg_img image 69 | gbp_img, seg_img, grid = None, None, None 70 | 71 | # read gbp image 72 | if self.gbp_list is not None: 73 | gbp_fname = self.gbp_list[idx] 74 | 75 | gbp_img = cv2.cvtColor(cv2.imread(gbp_fname), cv2.COLOR_BGR2GRAY) 76 | gbp_img = cv2.resize(gbp_img, (self.width, self.width)) 77 | if flip_horizontal: 78 | gbp_img = cv2.flip(gbp_img, 1) 79 | gbp_img = torch.tensor(gbp_img, dtype=torch.uint8).unsqueeze(0) 80 | 81 | if self.seg_list is not None: 82 | seg_fname = self.seg_list[idx] 83 | seg_img = cv2.cvtColor(cv2.imread(seg_fname), cv2.COLOR_BGR2GRAY) 84 | seg_img = cv2.resize(seg_img, (self.width, self.width)) 85 | if flip_horizontal: 86 | seg_img = cv2.flip(seg_img, 1) 87 | # convert to tensor 88 | seg_img = torch.tensor(seg_img, dtype=torch.uint8).unsqueeze(0) 89 | 90 | # create grid 91 | x, y, w, h = self.bbox_list[idx] 92 | x, y, w, h = x / W, y / H, w / W, h / H 93 | if flip_horizontal: 94 | x = 1 - x - w 95 | Affine_Mat_w = [w, 0, (2 * x + w - 1)] 96 | Affine_Mat_h = [0, h, (2 * y + h - 1)] 97 | M = np.c_[Affine_Mat_w, Affine_Mat_h].T 98 | M = torch.tensor(M).unsqueeze(0) 99 | grid = torch.nn.functional.affine_grid(M, (1, 3, 128, 128)) # (1,128,128,2) 100 | grid = (grid + 1) / 2 # scale between 0,1 101 | grid = torch.tensor(grid * 255, dtype=torch.uint8).permute(0, 3, 1, 2) 102 | 103 | return image, gbp_img, seg_img, grid -------------------------------------------------------------------------------- /semantic_guidance/utils/tensorboard.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import scipy.misc 3 | from io import BytesIO 4 | import tensorboardX as tb 5 | from tensorboardX.summary import Summary 6 | 7 | class TensorBoard(object): 8 | def __init__(self, model_dir): 9 | self.summary_writer = tb.FileWriter(model_dir) 10 | 11 | def add_image(self, tag, img, step): 12 | summary = Summary() 13 | bio = BytesIO() 14 | 15 | if type(img) == str: 16 | img = Image.open(img) 17 | elif type(img) == Image.Image: 18 | pass 19 | else: 20 | img = scipy.misc.toimage(img) 21 | 22 | img.save(bio, format="png") 23 | image_summary = Summary.Image(encoded_image_string=bio.getvalue()) 24 | summary.value.add(tag=tag, image=image_summary) 25 | self.summary_writer.add_summary(summary, global_step=step) 26 | 27 | def add_scalar(self, tag, value, step): 28 | summary = Summary(value=[Summary.Value(tag=tag, 
simple_value=value)]) 29 | self.summary_writer.add_summary(summary, global_step=step) 30 | -------------------------------------------------------------------------------- /semantic_guidance/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | USE_CUDA = torch.cuda.is_available() 6 | 7 | def prRed(prt): print("\033[91m {}\033[00m" .format(prt)) 8 | def prGreen(prt): print("\033[92m {}\033[00m" .format(prt)) 9 | def prYellow(prt): print("\033[93m {}\033[00m" .format(prt)) 10 | def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt)) 11 | def prPurple(prt): print("\033[95m {}\033[00m" .format(prt)) 12 | def prCyan(prt): print("\033[96m {}\033[00m" .format(prt)) 13 | def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt)) 14 | def prBlack(prt): print("\033[98m {}\033[00m" .format(prt)) 15 | 16 | def to_numpy(var): 17 | return var.cpu().data.numpy() if USE_CUDA else var.data.numpy() 18 | 19 | def to_tensor(ndarray, device): 20 | return torch.tensor(ndarray, dtype=torch.float, device=device) 21 | 22 | def soft_update(target, source, tau): 23 | for target_param, param in zip(target.parameters(), source.parameters()): 24 | target_param.data.copy_( 25 | target_param.data * (1.0 - tau) + param.data * tau 26 | ) 27 | 28 | def hard_update(target, source): 29 | for m1, m2 in zip(target.modules(), source.modules()): 30 | m1._buffers = m2._buffers.copy() 31 | for target_param, param in zip(target.parameters(), source.parameters()): 32 | target_param.data.copy_(param.data) 33 | 34 | def get_output_folder(parent_dir, env_name): 35 | """Return save folder. 36 | 37 | Assumes folders in the parent_dir have suffix -run{run 38 | number}. Finds the highest run number and sets the output folder 39 | to that number + 1. This is just convenient so that if you run the 40 | same script multiple times tensorboard can plot all of the results 41 | on the same plots with different names. 42 | 43 | Parameters 44 | ---------- 45 | parent_dir: str 46 | Path of the directory containing all experiment runs. 47 | 48 | Returns 49 | ------- 50 | parent_dir/run_dir 51 | Path to this run's save directory. 
52 | """ 53 | os.makedirs(parent_dir, exist_ok=True) 54 | experiment_id = 0 55 | for folder_name in os.listdir(parent_dir): 56 | if not os.path.isdir(os.path.join(parent_dir, folder_name)): 57 | continue 58 | try: 59 | folder_name = int(folder_name.split('-run')[-1]) 60 | if folder_name > experiment_id: 61 | experiment_id = folder_name 62 | except: 63 | pass 64 | experiment_id += 1 65 | 66 | parent_dir = os.path.join(parent_dir, env_name) 67 | parent_dir = parent_dir + '-run{}'.format(experiment_id) 68 | os.makedirs(parent_dir, exist_ok=True) 69 | return parent_dir 70 | -------------------------------------------------------------------------------- /video/baseline_target_bird_4648.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/video/baseline_target_bird_4648.mp4 -------------------------------------------------------------------------------- /video/sg_bird_5602.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/video/sg_bird_5602.gif -------------------------------------------------------------------------------- /video/sg_target_bird_4648.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1jsingh/semantic-guidance/092f923a0f5d42a949489bd49df8449db9d94d53/video/sg_target_bird_4648.mp4 --------------------------------------------------------------------------------