├── .gitignore
├── LICENSE
├── README.md
├── figures
│   ├── IntrinsicCompositingVideo.jpg
│   ├── astronauts.png
│   ├── teaser2.jpg
│   ├── teaser_comparison.jpg
│   ├── teaser_pipeline.jpg
│   └── user_study_comp.jpg
├── inference
│   ├── examples
│   │   ├── cone_chair
│   │   │   ├── bg.jpeg
│   │   │   ├── composite.png
│   │   │   └── mask.png
│   │   ├── lamp_candles
│   │   │   ├── bg.jpeg
│   │   │   ├── composite.png
│   │   │   └── mask.png
│   │   └── lamp_soap
│   │       ├── bg.jpeg
│   │       ├── composite.png
│   │       └── mask.png
│   └── inference.py
├── interface
│   ├── examples
│   │   ├── bgs
│   │   │   ├── blue_chairs.jpeg
│   │   │   ├── boxes.jpeg
│   │   │   ├── classroom.jpeg
│   │   │   ├── cone_org.jpeg
│   │   │   ├── dim.jpeg
│   │   │   ├── dock.jpeg
│   │   │   ├── empty_room.jpeg
│   │   │   ├── lamp.jpeg
│   │   │   ├── museum2.jpeg
│   │   │   ├── museum3.jpeg
│   │   │   └── pillar.jpeg
│   │   ├── fgs
│   │   │   ├── astro.png
│   │   │   ├── figurine.png
│   │   │   ├── soap.png
│   │   │   ├── white_bag.png
│   │   │   ├── white_chair.png
│   │   │   └── white_pot.png
│   │   └── masks
│   │       ├── astro.png
│   │       ├── figurine.png
│   │       ├── soap.png
│   │       ├── white_bag.png
│   │       ├── white_chair.png
│   │       └── white_pot.png
│   └── interface.py
├── intrinsic_compositing
│   ├── __init__.py
│   ├── albedo
│   │   ├── __init__.py
│   │   ├── model
│   │   │   ├── MiDaS
│   │   │   │   ├── __init__.py
│   │   │   │   └── midas
│   │   │   │       ├── __init__.py
│   │   │   │       ├── base_model.py
│   │   │   │       ├── blocks.py
│   │   │   │       ├── dpt_depth.py
│   │   │   │       ├── midas_net.py
│   │   │   │       ├── midas_net_custom.py
│   │   │   │       ├── transforms.py
│   │   │   │       └── vit.py
│   │   │   ├── __init__.py
│   │   │   ├── discriminator.py
│   │   │   ├── editingnetwork_trainer.py
│   │   │   ├── parametermodel.py
│   │   │   └── pix2pix
│   │   │       ├── __init__.py
│   │   │       └── models
│   │   │           ├── __init__.py
│   │   │           ├── base_model.py
│   │   │           └── networks.py
│   │   ├── pipeline.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── datautils.py
│   │       ├── depthutils.py
│   │       ├── edits.py
│   │       ├── networkutils.py
│   │       └── utils.py
│   └── shading
│       ├── __init__.py
│       └── pipeline.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2023, Chris Careaga, S. Mahdi H. Miangoleh, Yağız Aksoy, Computational Photography Laboratory. All rights reserved.
2 |
3 | This software is for academic use only. A redistribution of this
4 | software, with or without modifications, has to be for academic
5 | use only, while giving the appropriate credit to the original
6 | authors of the software. The methods implemented as a part of
7 | this software may be covered under patents or patent applications.
8 |
9 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED
10 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
11 | FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR
12 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
13 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
14 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
15 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
16 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
17 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Intrinsic Harmonization for Illumination-Aware Compositing
2 | Code for the paper: Intrinsic Harmonization for Illumination-Aware Compositing, [Chris Careaga](https://ccareaga.github.io), [S. Mahdi H. Miangoleh](https://miangoleh.github.io), and [Yağız Aksoy](https://yaksoy.github.io), Proc. SIGGRAPH Asia, 2023
3 | ### [Project Page](https://yaksoy.github.io/intrinsicCompositing) | [Paper](https://yaksoy.github.io/papers/SigAsia23-IntrinsicCompositing.pdf) | [Video](https://www.youtube.com/watch?v=M9hCUTp8bo4) | [Supplementary](https://yaksoy.github.io/papers/SigAsia23-IntrinsicCompositing-Supp.pdf)
4 |
5 | We propose an illumination-aware image harmonization approach for in-the-wild imagery. Our method is formulated in the intrinsic image domain. We use off-the-shelf networks to generate albedo, shading and surface normals for the input composite and background image. We first harmonize the albedo of the background and foreground by predicting image editing parameters. Using normals and shading, we estimate a simple lighting model for the background illumination. With this lighting model, we render Lambertian shading for the foreground and refine it using a network trained on segmentation datasets via self-supervision. Compared to prior works, ours is the only method capable of modeling realistic lighting effects.
6 |
7 | [![Watch the video](./figures/IntrinsicCompositingVideo.jpg)](https://www.youtube.com/watch?v=M9hCUTp8bo4)
8 |
9 |
10 | ## Method
11 | Compositing is a crucial image editing task requiring realistic integration of objects into new backgrounds.
12 | Achieving a natural composition requires adjusting the appearance of the inserted object through a process called image harmonization.
13 | While existing literature addresses color harmonization, relighting, an equally vital aspect, is often overlooked due to the challenges in realistically adjusting object illumination in diverse environments.
14 |
15 | 
16 |
17 | In this project, we tackle image harmonization in the intrinsic domain, decomposing images into reflectance (albedo) and illumination (shading).
18 | We employ a two-step approach: first, harmonizing color in the albedo space, and then addressing the challenging relighting problem in the shading domain.
19 | Our goal is to generate realistic shading for the composited object, reflecting the new illumination environment.
20 |
21 | 
22 |
23 | More specifically, we first render an initial shading for the background and inserted object using the Lambertian model and surface normals.
24 | A re-shading network then refines this shading for the composited object in a self-supervised manner.
25 | Our method is able to generate novel reshadings of the foreground region that reflect the illumination conditions of the background scene.
26 |
27 | 
28 |
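Concretely, the initial shading is a dot product between per-pixel surface normals and a lighting vector made up of a 3D light direction plus an ambient term, which is the form the lighting coefficients take throughout this repository. A minimal sketch of that Lambertian model (the `lambertian_shading` helper below is illustrative, not part of the package):

```
import numpy as np

def lambertian_shading(normals, coeffs):
    # normals: (H, W, 3) unit surface normals
    # coeffs: 3 directional coefficients followed by 1 ambient term
    h, w, _ = normals.shape
    shd = normals.reshape(-1, 3) @ coeffs[:3] + coeffs[3]
    return shd.reshape(h, w).clip(0)

# toy example: a flat surface facing the camera, lit from the upper left
normals = np.zeros((4, 4, 3))
normals[..., 2] = 1.0
coeffs = np.array([-0.3, 0.4, 0.6, 0.2])  # (x, y, z) light direction + ambient
print(lambertian_shading(normals, coeffs))  # 0.8 everywhere
```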
29 | Our method outperforms prior works, producing realistic composite images that not only match color but also exhibit realistic illumination in diverse scenarios.
30 |
31 | 
32 |
33 | Our re-shading network learns to predict spatially-varying lighting effects in-context due to our self-supervised training approach.
34 |
35 | 
36 |
37 | ## Setup
38 | Depending on how you would like to use the code in this repository, there are two options for setting it up.
39 | In either case, you should first create a fresh virtual environment (`python3 -m venv intrinsic_env`) and activate it (`source intrinsic_env/bin/activate`).
40 |
41 | You can install this repository as a package using `pip`:
42 | ```
43 | git clone https://github.com/compphoto/IntrinsicCompositing
44 | cd IntrinsicCompositing
45 | pip install .
46 | ```
47 | If you want to make changes to the code and have them reflected when you import the package, use `pip install --editable .` instead.
48 | Alternatively, you can install the package without cloning the code using:
49 | ```
50 | pip install https://github.com/compphoto/IntrinsicCompositing/archive/main.zip
51 | ```
52 | This will allow you to import the repository as a Python package, and use our pipeline as part of your codebase. The pipeline has been tested with the following versions, but earlier versions should work as well:
53 | ```
54 | python==3.10
55 | torch==2.5.1
56 | opencv-python==4.10
57 | numpy==1.26.4
58 | ```
59 |
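Once installed, the albedo harmonization and re-shading stages can be imported directly from the package; `inference/inference.py` shows the full sequence of off-the-shelf models they are combined with. A minimal sketch of the package-level entry points (the calls below only load the models; the remaining inputs come from the other networks as in the inference script):

```
from intrinsic_compositing.albedo.pipeline import (
    load_albedo_harmonizer,
    harmonize_albedo
)
from intrinsic_compositing.shading.pipeline import (
    load_reshading_model,
    compute_reshading,
    get_light_coeffs
)

alb_model = load_albedo_harmonizer()
shd_model = load_reshading_model('further_trained')

# harmonize_albedo, get_light_coeffs and compute_reshading are then called
# with the composite, mask, shading, depth and normal estimates produced by
# the off-the-shelf models (see inference/inference.py for the full sequence)
```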
60 | ## Interface
61 |
62 | The best way to run our pipeline is by using our interactive interface. We provide some example backgrounds and foregrounds in `interface/examples`:
63 |
64 | ```
65 | $ cd interface
66 | $ python interface.py --bg examples/bgs/lamp.jpeg --fg examples/fgs/soap.png --mask examples/masks/soap.png
67 | ```
68 | The first time you run the interface, multiple pretrained checkpoints will be downloaded (the method makes use of multiple off-the-shelf models), which may take some time. Subsequent runs will use the cached weights, but a bit of preprocessing is still required each time the interface is started. Once the preprocessing is done, the interface window will appear and the input composite can be edited. After editing the composite, harmonizing only requires running our albedo and shading networks, which should take only a second or two. These are the keybinds for the interface:
69 |
70 | | Key | Action |
71 | |--|--|
72 | | r | run the harmonization of the current composite |
73 | | s | save inputs, outputs and intermediate images |
74 | | 1-5 | view various intermediate representations (shading, normals, etc.) |
75 | | scroll up/down | scale the foreground region up or down |
76 |
77 | The interface has been tested on an RTX 2060 with 8 GB of VRAM, which should be able to handle inference at a 1024-pixel resolution.
78 |
79 | ## Inference
80 |
81 | If you want to run our pipeline on pre-made composite images, you can use the script in the `inference` folder.
82 | This script will iterate through a set of composites and output our harmonized results:
83 | ```
84 | $ cd inference
85 | $ python inference.py --help
86 |
87 | usage: inference.py [-h] --input_dir INPUT_DIR --output_dir OUTPUT_DIR [--inference_size INFERENCE_SIZE] [--intermediate]
88 |
89 | optional arguments:
90 | -h, --help show this help message and exit
91 | --input_dir INPUT_DIR
92 | input directory to read input composites, bgs and masks
93 | --output_dir OUTPUT_DIR
94 | output directory to store harmonized composites
95 | --inference_size INFERENCE_SIZE
96 | size to perform inference (default 1024)
97 | --intermediate whether or not to save visualization of intermediate representations
98 | --reproduce_paper whether or not to use code and weights from the original paper implementation
99 |
100 | ```
101 | Here is how you can run the script on a set of example composites stored in `inference/examples`:
102 | ```
103 | $ python inference.py --input_dir examples/ --output_dir output/
104 | ```
105 | If you want to test your own examples, the script uses the following input directory structure:
106 | ```
107 | examples/
108 | ├── cone_chair
109 | │ ├── bg.jpeg
110 | │ ├── composite.png
111 | │ └── mask.png
112 | ├── lamp_candles
113 | │ ├── bg.jpeg
114 | │ ├── composite.png
115 | │ └── mask.png
116 | └── lamp_soap
117 | ├── bg.jpeg
118 | ├── composite.png
119 | └── mask.png
120 | ```
121 | Each directory contains a composite image, a corresponding mask for the composited region, and the background image without the composited object.
122 | Note that the background image is only used to compute the lighting direction, so it doesn't need to be exactly aligned with the composite image.
123 | In fact, it can be any image and the script will use it to estimate the illumination parameters used as part of our pipeline.
124 |
125 | The script expects the images to have the extensions shown above, with the bg and composite as three-channel images and the mask as a single channel.
126 | The script can easily be adjusted to fit whatever data format you're using.
127 |
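For reference, a small helper along these lines (hypothetical, not part of the repository) can be used to assemble a new scene folder in the expected layout:

```
from pathlib import Path
import numpy as np
import imageio

def make_scene(root, name, bg, composite, mask):
    # write one scene in the layout inference.py expects:
    # bg.jpeg and composite.png as HxWx3 uint8, mask.png as a single channel
    scene = Path(root) / name
    scene.mkdir(parents=True, exist_ok=True)
    imageio.imwrite(scene / 'bg.jpeg', bg.astype(np.uint8))
    imageio.imwrite(scene / 'composite.png', composite.astype(np.uint8))
    imageio.imwrite(scene / 'mask.png', mask.astype(np.uint8))
```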
128 | ## Note on Reproducibility
129 |
130 | The original albedo harmonization training and testing code assumed that the shading images were stored as 16-bit values, and normalized them to [0-1] accordingly. However, when generating results I was using 8-bit shading images. This meant that the albedo being fed to the network was incorrect (due to the low-contrast shading values). When I prepared the code for release, I fixed this bug without thinking about it, meaning the GitHub code does not have this issue. I believe the GitHub code is the more accurate implementation, since the albedo harmonization network receives the correct albedo as input. In order to maintain reproducibility, I've added a flag to the inference and interface scripts called `--reproduce_paper` that uses the logic and weights from the original implementation. Without this flag, the code will run correctly and use better weights for the reshading network. Here are the results you should see for each setting of this flag:
131 |
132 | | with `--reproduce_paper` | without `--reproduce_paper` |
133 | | ------------- | ------------- |
134 | |  |  |
135 |
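The difference comes down to how the stored shading was normalized; a toy illustration (not the actual training code) of pushing 8-bit values through the 16-bit normalization path:

```
import numpy as np

shd_8bit = np.array([32, 128, 224], dtype=np.uint8)  # toy 8-bit shading values

intended = shd_8bit / 255.0    # ~[0.13, 0.50, 0.88]     full-range shading
original = shd_8bit / 65535.0  # ~[0.0005, 0.002, 0.003] nearly flat shading
# the compressed, low-contrast shading distorts the albedo handed to the
# harmonization network, which is the behavior --reproduce_paper recreates
```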
136 | ## Citation
137 |
138 | ```
139 | @INPROCEEDINGS{careagaCompositing,
140 | author={Chris Careaga and S. Mahdi H. Miangoleh and Ya\u{g}{\i}z Aksoy},
141 | title={Intrinsic Harmonization for Illumination-Aware Compositing},
142 | booktitle={Proc. SIGGRAPH Asia},
143 | year={2023},
144 | }
145 | ```
146 |
147 | ## License
148 |
149 | This implementation is provided for academic use only. Please cite our paper if you use this code or any of the models.
150 |
151 | The methodology presented in this work is safeguarded under intellectual property protection. For inquiries regarding licensing opportunities, kindly reach out to SFU Technology Licensing Office <tlo_dir ατ sfu δøτ ca> and Dr. Yağız Aksoy <yagiz ατ sfu δøτ ca>.
152 |
--------------------------------------------------------------------------------
/figures/IntrinsicCompositingVideo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/IntrinsicCompositingVideo.jpg
--------------------------------------------------------------------------------
/figures/astronauts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/astronauts.png
--------------------------------------------------------------------------------
/figures/teaser2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/teaser2.jpg
--------------------------------------------------------------------------------
/figures/teaser_comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/teaser_comparison.jpg
--------------------------------------------------------------------------------
/figures/teaser_pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/teaser_pipeline.jpg
--------------------------------------------------------------------------------
/figures/user_study_comp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/user_study_comp.jpg
--------------------------------------------------------------------------------
/inference/examples/cone_chair/bg.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/cone_chair/bg.jpeg
--------------------------------------------------------------------------------
/inference/examples/cone_chair/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/cone_chair/composite.png
--------------------------------------------------------------------------------
/inference/examples/cone_chair/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/cone_chair/mask.png
--------------------------------------------------------------------------------
/inference/examples/lamp_candles/bg.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_candles/bg.jpeg
--------------------------------------------------------------------------------
/inference/examples/lamp_candles/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_candles/composite.png
--------------------------------------------------------------------------------
/inference/examples/lamp_candles/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_candles/mask.png
--------------------------------------------------------------------------------
/inference/examples/lamp_soap/bg.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_soap/bg.jpeg
--------------------------------------------------------------------------------
/inference/examples/lamp_soap/composite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_soap/composite.png
--------------------------------------------------------------------------------
/inference/examples/lamp_soap/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_soap/mask.png
--------------------------------------------------------------------------------
/inference/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from pathlib import Path
4 | import numpy as np
5 |
6 | from glob import glob
7 |
8 | from skimage.transform import resize
9 |
10 | from chrislib.general import (
11 | invert,
12 | uninvert,
13 | view,
14 | np_to_pil,
15 | to2np,
16 | add_chan,
17 | show,
18 | round_32,
19 | tile_imgs
20 | )
21 | from chrislib.data_util import load_image
22 | from chrislib.normal_util import get_omni_normals
23 |
24 | from boosted_depth.depth_util import create_depth_models, get_depth
25 |
26 | from intrinsic.model_util import load_models
27 | from intrinsic.pipeline import run_pipeline
28 |
29 | from intrinsic_compositing.shading.pipeline import (
30 | load_reshading_model,
31 | compute_reshading,
32 | generate_shd,
33 | get_light_coeffs
34 | )
35 |
36 | from intrinsic_compositing.albedo.pipeline import (
37 | load_albedo_harmonizer,
38 | harmonize_albedo
39 | )
40 |
41 | from omnidata_tools.model_util import load_omni_model
42 |
43 | def get_bbox(mask):
44 | rows = np.any(mask, axis=1)
45 | cols = np.any(mask, axis=0)
46 | rmin, rmax = np.where(rows)[0][[0, -1]]
47 | cmin, cmax = np.where(cols)[0][[0, -1]]
48 |
49 | return rmin, rmax, cmin, cmax
50 |
51 | def rescale(img, scale, r32=False):
52 | if scale == 1.0: return img
53 |
54 | h = img.shape[0]
55 | w = img.shape[1]
56 |
57 | if r32:
58 | img = resize(img, (round_32(h * scale), round_32(w * scale)))
59 | else:
60 | img = resize(img, (int(h * scale), int(w * scale)))
61 |
62 | return img
63 |
64 | def compute_composite_normals(img, msk, model, size):
65 |
66 | bin_msk = (msk > 0)
67 |
68 | bb = get_bbox(bin_msk)
69 | bb_h, bb_w = bb[1] - bb[0], bb[3] - bb[2]
70 |
71 | # create the crop around the object in the image to send through normal net
72 | img_crop = img[bb[0] : bb[1], bb[2] : bb[3], :]
73 |
74 | crop_scale = 1024 / max(bb_h, bb_w)
75 | img_crop = rescale(img_crop, crop_scale)
76 |
77 | # get normals of cropped and scaled object and resize back to original bbox size
78 | nrm_crop = get_omni_normals(model, img_crop)
79 | nrm_crop = resize(nrm_crop, (bb_h, bb_w))
80 |
81 | h, w, c = img.shape
82 | max_dim = max(h, w)
83 | if max_dim > size:
84 | scale = size / max_dim
85 | else:
86 | scale = 1.0
87 |
88 | # resize to the final output size as specified by input args
89 | out_img = rescale(img, scale, r32=True)
90 | out_msk = rescale(msk, scale, r32=True)
91 | out_bin_msk = (out_msk > 0)
92 |
93 | # compute normals for the entire composite image at its output size
94 | out_nrm_bg = get_omni_normals(model, out_img)
95 |
96 | # now the image is at a new size so the parameters of the object crop change.
97 | # in order to overlay the normals, we need to resize the crop to this new size
98 | out_bb = get_bbox(out_bin_msk)
99 | bb_h, bb_w = out_bb[1] - out_bb[0], out_bb[3] - out_bb[2]
100 |
101 | # now resize the normals of the crop to this size, and put them in empty image
102 | out_nrm_crop = resize(nrm_crop, (bb_h, bb_w))
103 | out_nrm_fg = np.zeros_like(out_img)
104 | out_nrm_fg[out_bb[0] : out_bb[1], out_bb[2] : out_bb[3], :] = out_nrm_crop
105 |
106 | # combine bg and fg normals with mask alphas
107 | out_nrm = (out_nrm_fg * out_msk[:, :, None]) + (out_nrm_bg * (1.0 - out_msk[:, :, None]))
108 | return out_nrm
109 |
110 | parser = argparse.ArgumentParser()
111 |
112 | parser.add_argument('--input_dir', type=str, required=True, help='input directory to read input composites, bgs and masks')
113 | parser.add_argument('--output_dir', type=str, required=True, help='output directory to store harmonized composites')
114 | parser.add_argument('--inference_size', type=int, default=1024, help='size to perform inference (default 1024)')
115 | parser.add_argument('--intermediate', action='store_true', help='whether or not to save visualization of intermediate representations')
116 | parser.add_argument('--reproduce_paper', action='store_true', help='whether or not to use code and weights from the original paper implementation')
117 |
118 | args = parser.parse_args()
119 |
120 | print('loading depth model')
121 | dpt_model = create_depth_models()
122 |
123 | print('loading normals model')
124 | nrm_model = load_omni_model()
125 |
126 | print('loading intrinsic decomposition model')
127 | int_model = load_models('paper_weights')
128 |
129 | print('loading albedo model')
130 | alb_model = load_albedo_harmonizer()
131 |
132 | print('loading reshading model')
133 | if args.reproduce_paper:
134 | shd_model = load_reshading_model('paper_weights')
135 | else:
136 | shd_model = load_reshading_model('further_trained')
137 |
138 |
139 | examples = glob(f'{args.input_dir}/*')
140 | print()
141 | print(f'found {len(examples)} scenes')
142 | print()
143 |
144 | if not os.path.exists(args.output_dir):
145 | os.makedirs(args.output_dir, exist_ok=True)
146 |
147 | for i, example_dir in enumerate(examples):
148 |
149 | bg_img = load_image(f'{example_dir}/bg.jpeg')
150 | comp_img = load_image(f'{example_dir}/composite.png')
151 | mask_img = load_image(f'{example_dir}/mask.png')
152 |
153 | scene_name = Path(example_dir).stem
154 |
155 | # to ensure that normals are globally accurate we compute them at
156 | # a resolution of 512 pixels, so resize our shading and image to compute
157 | # rescaled normals, then run the lighting model optimization
158 | bg_h, bg_w = bg_img.shape[:2]
159 | max_dim = max(bg_h, bg_w)
160 | scale = 512 / max_dim
161 |
162 | small_bg_img = rescale(bg_img, scale)
163 | small_bg_nrm = get_omni_normals(nrm_model, small_bg_img)
164 |
165 | result = run_pipeline(
166 | int_model,
167 | small_bg_img ** 2.2,
168 | resize_conf=0.0,
169 | maintain_size=True,
170 | linear=True
171 | )
172 |
173 | small_bg_shd = result['inv_shading'][:, :, None]
174 |
175 |
176 | coeffs, lgt_vis = get_light_coeffs(
177 | small_bg_shd[:, :, 0],
178 | small_bg_nrm,
179 | small_bg_img
180 | )
181 |
182 | # now we compute the normals of the entire composite image, we have some logic
183 | # to generate a detailed estimation of the foreground object by cropping and
184 | # resizing, we then overlay that onto the normals of the whole scene
185 | comp_nrm = compute_composite_normals(comp_img, mask_img, nrm_model, args.inference_size)
186 |
187 | # now compute depth and intrinsics at a specific resolution for the composite image
188 | # if the image is already smaller than the specified resolution, leave it
189 | h, w, c = comp_img.shape
190 |
191 | max_dim = max(h, w)
192 | if max_dim > args.inference_size:
193 | scale = args.inference_size / max_dim
194 | else:
195 | scale = 1.0
196 |
197 | # resize to specified size and round to 32 for network inference
198 | img = rescale(comp_img, scale, r32=True)
199 | msk = rescale(mask_img, scale, r32=True)
200 |
201 | depth = get_depth(img, dpt_model)
202 |
203 | result = run_pipeline(
204 | int_model,
205 | img ** 2.2,
206 | resize_conf=0.0,
207 | maintain_size=True,
208 | linear=True
209 | )
210 |
211 | inv_shd = result['inv_shading']
212 | # inv_shd = rescale(inv_shd, scale, r32=True)
213 |
214 | # compute the harmonized albedo, and the subsequent color harmonized image
215 | alb_harm = harmonize_albedo(img, msk, inv_shd, alb_model, reproduce_paper=args.reproduce_paper) ** 2.2
216 | harm_img = alb_harm * uninvert(inv_shd)[:, :, None]
217 |
218 | # run the reshading model using the various composited components,
219 | # and our lighting coefficients computed from the background
220 | comp_result = compute_reshading(
221 | harm_img,
222 | msk,
223 | inv_shd,
224 | depth,
225 | comp_nrm,
226 | alb_harm,
227 | coeffs,
228 | shd_model
229 | )
230 |
231 | if args.intermediate:
232 | tile_imgs([
233 | img,
234 | msk,
235 | view(alb_harm),
236 | 1-inv_shd,
237 | depth,
238 | comp_nrm,
239 | view(generate_shd(comp_nrm, coeffs, msk, viz=True)[1]),
240 | 1-invert(comp_result['reshading'])
241 | ], save=f'{args.output_dir}/{scene_name}_intermediate.jpeg', rescale=0.75)
242 |
243 | np_to_pil(comp_result['composite']).save(f'{args.output_dir}/{scene_name}.png')
244 |
245 | print(f'finished ({i+1}/{len(examples)}) - {scene_name}')
246 |
--------------------------------------------------------------------------------
/interface/examples/bgs/blue_chairs.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/blue_chairs.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/boxes.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/boxes.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/classroom.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/classroom.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/cone_org.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/cone_org.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/dim.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/dim.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/dock.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/dock.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/empty_room.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/empty_room.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/lamp.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/lamp.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/museum2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/museum2.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/museum3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/museum3.jpeg
--------------------------------------------------------------------------------
/interface/examples/bgs/pillar.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/pillar.jpeg
--------------------------------------------------------------------------------
/interface/examples/fgs/astro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/astro.png
--------------------------------------------------------------------------------
/interface/examples/fgs/figurine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/figurine.png
--------------------------------------------------------------------------------
/interface/examples/fgs/soap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/soap.png
--------------------------------------------------------------------------------
/interface/examples/fgs/white_bag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/white_bag.png
--------------------------------------------------------------------------------
/interface/examples/fgs/white_chair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/white_chair.png
--------------------------------------------------------------------------------
/interface/examples/fgs/white_pot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/white_pot.png
--------------------------------------------------------------------------------
/interface/examples/masks/astro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/astro.png
--------------------------------------------------------------------------------
/interface/examples/masks/figurine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/figurine.png
--------------------------------------------------------------------------------
/interface/examples/masks/soap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/soap.png
--------------------------------------------------------------------------------
/interface/examples/masks/white_bag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/white_bag.png
--------------------------------------------------------------------------------
/interface/examples/masks/white_chair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/white_chair.png
--------------------------------------------------------------------------------
/interface/examples/masks/white_pot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/white_pot.png
--------------------------------------------------------------------------------
/interface/interface.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 | import imageio
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | import tkinter as tk
10 | from tkinter import ttk
11 | from pathlib import Path
12 | from datetime import datetime
13 |
14 | from PIL import ImageTk, Image
15 | import numpy as np
16 |
17 | from skimage.transform import resize
18 |
19 | from chrislib.general import invert, uninvert, view, np_to_pil, to2np, add_chan
20 | from chrislib.data_util import load_image
21 | from chrislib.normal_util import get_omni_normals
22 |
23 | from boosted_depth.depth_util import create_depth_models, get_depth
24 |
25 | from intrinsic.model_util import load_models
26 | from intrinsic.pipeline import run_pipeline
27 |
28 | from intrinsic_compositing.shading.pipeline import (
29 | load_reshading_model,
30 | compute_reshading,
31 | generate_shd,
32 | get_light_coeffs
33 | )
34 |
35 | from intrinsic_compositing.albedo.pipeline import (
36 | load_albedo_harmonizer,
37 | harmonize_albedo
38 | )
39 |
40 | from omnidata_tools.model_util import load_omni_model
41 |
42 | def viz_coeffs(coeffs, size):
43 | half_sz = size // 2
44 | nrm_circ = draw_normal_circle(
45 | np.zeros((size, size, 3)),
46 | (half_sz, half_sz),
47 | half_sz
48 | )
49 |
50 | out_shd = (nrm_circ.reshape(-1, 3) @ coeffs[:3]) + coeffs[-1]
51 | out_shd = out_shd.reshape(size, size)
52 |
53 | lin = np.linspace(-1, 1, num=size)
54 | ys, xs = np.meshgrid(lin, lin)
55 |
56 | zs = np.sqrt((1.0 - (xs**2 + ys**2)).clip(0))
57 |
58 | out_shd[zs == 0] = 0
59 |
60 | return (out_shd.clip(1e-4) ** (1/2.2)).clip(0, 1)
61 |
62 | def draw_normal_circle(nrm, loc, rad):
63 | size = rad * 2
64 |
65 | lin = np.linspace(-1, 1, num=size)
66 | ys, xs = np.meshgrid(lin, lin)
67 |
68 | zs = np.sqrt((1.0 - (xs**2 + ys**2)).clip(0))
69 | valid = (zs != 0)
70 | normals = np.stack((ys[valid], -xs[valid], zs[valid]), 1)
71 |
72 | valid_mask = np.zeros((size, size))
73 | valid_mask[valid] = 1
74 |
75 | full_mask = np.zeros((nrm.shape[0], nrm.shape[1]))
76 | x = loc[0] - rad
77 | y = loc[1] - rad
78 | full_mask[y : y + size, x : x + size] = valid_mask
79 | # nrm[full_mask > 0] = (normals + 1.0) / 2.0
80 | nrm[full_mask > 0] = normals
81 |
82 | return nrm
83 |
84 | def get_bbox(mask):
85 | rows = np.any(mask, axis=1)
86 | cols = np.any(mask, axis=0)
87 | rmin, rmax = np.where(rows)[0][[0, -1]]
88 | cmin, cmax = np.where(cols)[0][[0, -1]]
89 |
90 | return rmin, rmax, cmin, cmax
91 |
92 | def rescale(img, scale):
93 | if scale == 1.0: return img
94 |
95 | h = img.shape[0]
96 | w = img.shape[1]
97 |
98 | img = resize(img, (int(h * scale), int(w * scale)))
99 | return img
100 |
101 | def composite_crop(img, loc, fg, mask):
102 | c_h, c_w, _ = fg.shape
103 |
104 | img = img.copy()
105 |
106 | img_crop = img[loc[0] : loc[0] + c_h, loc[1] : loc[1] + c_w, :]
107 | comp = (img_crop * (1.0 - mask)) + (fg * mask)
108 | img[loc[0] : loc[0] + c_h, loc[1] : loc[1] + c_w, :] = comp
109 |
110 | return img
111 |
112 | # composite the depth of a fragment, but scale it to match
113 | # the background depth wherever the fragment is placed (fake the depth)
114 | def composite_depth(img, loc, fg, mask):
115 | c_h, c_w = fg.shape[:2]
116 |
117 | # get the bottom-center depth of the bg
118 | bg_bc = loc[0] + c_h, loc[1] + (c_w // 2)
119 | bg_bc_val = img[bg_bc[0], bg_bc[1]].item()
120 |
121 | # get the bottom center depth of the fragment
122 | fg_bc = c_h - 1, (c_w // 2)
123 | fg_bc_val = fg[fg_bc[0], fg_bc[1]].item()
124 |
125 | # compute scale to match the fg values to bg
126 | scale = bg_bc_val / fg_bc_val
127 |
128 | img = img.copy()
129 |
130 | img_crop = img[loc[0] : loc[0] + c_h, loc[1] : loc[1] + c_w]
131 | comp = (img_crop * (1.0 - mask)) + (scale * fg * mask)
132 | img[loc[0] : loc[0] + c_h, loc[1] : loc[1] + c_w] = comp
133 |
134 | return img
135 |
136 |
137 | # DISP_SCALE = 0.5
138 | DISP_SCALE = 1.0
139 | MAX_BG_SZ = 1024
140 | LGT_VIZ_SZ = 144
141 |
142 | class App(tk.Tk):
143 | def __init__(self, args):
144 | super().__init__()
145 |
146 | self.configure(background='black')
147 |
148 | self.args = args
149 |
150 | loaded_fg = load_image(args.fg)
151 |
152 | self.bg_img = load_image(args.bg)[:, :, :3]
153 | self.fg_img = loaded_fg[:, :, :3]
154 |
155 | if args.mask is not None:
156 | self.mask_img = load_image(args.mask)
157 | else:
158 | if loaded_fg.shape[-1] != 4:
159 | print("expected foreground image to have an alpha channel since no mask was specified")
160 | exit()
161 |
162 | self.mask_img = self.fg_img[:, :, -1]
163 |
164 | if len(self.mask_img.shape) == 3:
165 | self.mask_img = self.mask_img[:, :, :1]
166 | else:
167 | self.mask_img = self.mask_img[:, :, np.newaxis]
168 |
169 |
170 | print('loading depth model')
171 | self.dpt_model = create_depth_models()
172 |
173 | print('loading normals model')
174 | self.nrm_model = load_omni_model()
175 |
176 | print('loading intrinsic decomposition model')
177 | self.int_model = load_models('paper_weights')
178 |
179 | print('loading albedo model')
180 | self.alb_model = load_albedo_harmonizer()
181 |
182 | print('loading reshading model')
183 | if self.args.reproduce_paper:
184 | self.shd_model = load_reshading_model('paper_weights')
185 | else:
186 | self.shd_model = load_reshading_model('further_trained')
187 |
188 | self.init_scene()
189 |
190 | self.bg_disp_w = int(self.bg_w * DISP_SCALE)
191 | self.bg_disp_h = int(self.bg_h * DISP_SCALE)
192 |
193 | win_w = (self.bg_disp_w * 2) + LGT_VIZ_SZ + 40
194 | win_h = self.bg_disp_h + 20
195 |
196 | # configure the root window
197 | self.title('compositing demo')
198 | self.geometry(f"{win_w}x{win_h}")
199 | # self.geometry(f"")
200 |
201 | self.l_frame = ttk.Frame(self, width=self.bg_disp_w, height=self.bg_disp_h)
202 | self.l_frame.pack()
203 | self.l_frame.place(x=10, y=10)
204 |
205 | self.r_frame = ttk.Frame(self, width=self.bg_disp_w, height=self.bg_disp_h)
206 | self.r_frame.pack()
207 | self.r_frame.place(x=self.bg_disp_w + 20, y=10)
208 |
209 | style = ttk.Style(self)
210 | style.configure("TFrame", background="black")
211 |
212 | self.lgt_frame = ttk.Frame(self, width=LGT_VIZ_SZ, height=self.bg_disp_h)
213 | self.lgt_frame.pack()
214 | self.lgt_frame.place(x=win_w - LGT_VIZ_SZ - 10, y=10)
215 |
216 | l_disp_img = rescale(self.l_img, DISP_SCALE)
217 | r_disp_img = rescale(self.r_img, DISP_SCALE)
218 | lgt_disp_img = viz_coeffs(self.coeffs, LGT_VIZ_SZ)
219 |
220 | self.l_photo = ImageTk.PhotoImage(np_to_pil(l_disp_img))
221 | self.r_photo = ImageTk.PhotoImage(np_to_pil(r_disp_img))
222 | self.lgt_photo = ImageTk.PhotoImage(np_to_pil(lgt_disp_img))
223 |
224 | self.l_label = ttk.Label(self.l_frame, image=self.l_photo)
225 | self.l_label.pack()
226 |
227 | self.r_label = ttk.Label(self.r_frame, image=self.r_photo)
228 | self.r_label.pack()
229 |
230 | self.lgt_label = ttk.Label(self.lgt_frame, image=self.lgt_photo)
231 | self.lgt_label.pack()
232 |
233 | self.bias_scale = ttk.Scale(self.lgt_frame, from_=0.1, to=1.0, orient=tk.HORIZONTAL, command=self.update_bias)
234 | self.bias_scale.pack(pady=5)
235 |
236 | style = ttk.Style(self)
237 | style.configure("White.TLabel", foreground="white", background='black')
238 |
239 | al = ttk.Label(self.lgt_frame, text="Ambient Strength", style="White.TLabel")
240 | al.pack(pady=5)
241 |
242 | self.dir_scale = ttk.Scale(self.lgt_frame, from_=0.1, to=2.0, orient=tk.HORIZONTAL, command=self.update_dir)
243 | self.dir_scale.pack(pady=5)
244 |
245 | dl = ttk.Label(self.lgt_frame, text="Directional Strength", style="White.TLabel")
246 | dl.pack(pady=5)
247 |
248 | dir_val = np.linalg.norm(self.coeffs[:3])
249 | bias_val = self.coeffs[-1]
250 |
251 | self.bias_scale.set(bias_val)
252 | self.dir_scale.set(dir_val)
253 |
254 |         self.bind('<Key>', self.key_pressed)
255 |         self.bind('<B1-Motion>', self.click_motion)
256 |         self.bind('<Button-4>', self.scrolled)
257 |         self.bind('<Button-5>', self.scrolled)
258 |         self.bind('<Button-1>', self.clicked)
259 |
260 |
261 | def update_left(self):
262 | # disp_img = rescale(self.l_img, DISP_SCALE)
263 | disp_img = self.l_img
264 |
265 | self.l_photo = ImageTk.PhotoImage(np_to_pil(disp_img))
266 | self.l_label.configure(image=self.l_photo)
267 | self.l_label.image = self.l_photo
268 |
269 | def update_right(self):
270 | disp_img = rescale(self.r_img, DISP_SCALE)
271 | self.r_photo = ImageTk.PhotoImage(np_to_pil(disp_img))
272 | self.r_label.configure(image=self.r_photo)
273 | self.r_label.image = self.r_photo
274 |
275 | def update_light(self):
276 |
277 | self.lgt_disp_img = viz_coeffs(self.coeffs, LGT_VIZ_SZ)
278 | self.lgt_photo = ImageTk.PhotoImage(np_to_pil(self.lgt_disp_img))
279 | self.lgt_label.configure(image=self.lgt_photo)
280 | self.lgt_label.image = self.lgt_photo
281 |
282 |
283 | def update_bias(self, val):
284 | # update the ambient light strength from slider value
285 | self.coeffs[-1] = float(val)
286 | self.update_light()
287 |
288 | def update_dir(self, val):
289 | # update the directional light source strength
290 |
291 | # normalize the current light source direction
292 | vec = self.coeffs[:3]
293 | vec /= np.linalg.norm(vec).clip(1e-3)
294 |
295 | # scale the unit length light direction by the slider value
296 | # then update the stored coeffs and UI accordingly
297 | vec *= float(val)
298 | self.coeffs[:3] = vec
299 | self.update_light()
300 |
301 | def init_scene(self):
302 | bg_h, bg_w, _ = self.bg_img.shape
303 |
304 | # resize the background image to be large side < some max value
305 | max_dim = max(bg_h, bg_w)
306 | scale = MAX_BG_SZ / max_dim
307 |
308 | # here is our bg image and size after resizing to max size
309 | self.bg_img = resize(self.bg_img, (int(bg_h * scale), int(bg_w * scale)))
310 | self.bg_h, self.bg_w, _ = self.bg_img.shape
311 |
312 | # compute normals and shading for background, and use them
313 | # to optimize for the lighting coefficients
314 | self.bg_nrm = get_omni_normals(self.nrm_model, self.bg_img)
315 | result = run_pipeline(
316 | self.int_model,
317 | self.bg_img ** 2.2,
318 | resize_conf=0.0,
319 | maintain_size=True,
320 | linear=True
321 | )
322 |
323 | self.bg_shd = result['inv_shading'][:, :, None]
324 | self.bg_alb = result['albedo']
325 |
326 | # to ensure that normals are globally accurate we compute them at
327 | # a resolution of 512 pixels, so resize our shading and image to compute
328 | # rescaled normals, then run the lighting model optimization
329 | max_dim = max(self.bg_h, self.bg_w)
330 | scale = 512 / max_dim
331 | small_bg_img = rescale(self.bg_img, scale)
332 | small_bg_nrm = get_omni_normals(self.nrm_model, small_bg_img)
333 | small_bg_shd = rescale(self.bg_shd, scale)
334 |
335 | self.orig_coeffs, self.lgt_vis = get_light_coeffs(
336 | small_bg_shd[:, :, 0],
337 | small_bg_nrm,
338 | small_bg_img
339 | )
340 |
341 | self.coeffs = self.orig_coeffs
342 |
343 | # first we want to reason about the fg image as an image fragment that we can
344 | # move, scale and composite, so we only really need to deal with the masked area
345 | bin_msk = (self.mask_img > 0)
346 |
347 | bb = get_bbox(bin_msk)
348 | bb_h, bb_w = bb[1] - bb[0], bb[3] - bb[2]
349 |
350 | # create the crop around the object in the image, this can be very large depending
351 | # on the image that has been chosen by the user, but we want to store it at 1024
352 | self.orig_fg_crop = self.fg_img[bb[0] : bb[1], bb[2] : bb[3], :].copy()
353 | self.orig_msk_crop = self.mask_img[bb[0] : bb[1], bb[2] : bb[3], :].copy()
354 |
355 | # this is the real_scale, maps the original image crop size to 1024
356 | max_dim = max(bb_h, bb_w)
357 | real_scale = MAX_BG_SZ / max_dim
358 |
359 | # this is the copy of the crop we keep to compute normals
360 | self.orig_fg_crop = rescale(self.orig_fg_crop, real_scale)
361 | self.orig_msk_crop = rescale(self.orig_msk_crop, real_scale)
362 |
363 | # now compute the display scale to show the fragment on the ui
364 | max_dim = max(self.orig_fg_crop.shape)
365 | disp_scale = (min(self.bg_h, self.bg_w) // 2) / max_dim
366 | self.frag_scale = disp_scale
367 | print('init frag_scale:', self.frag_scale)
368 |
369 | # these are the versions that the UI shows and what will be
370 | # used to create the composite images sent to the networks
371 | self.fg_crop = rescale(self.orig_fg_crop, self.frag_scale)
372 | self.msk_crop = rescale(self.orig_msk_crop, self.frag_scale)
373 |
374 | self.bb_h, self.bb_w, _ = self.fg_crop.shape
375 |
376 | # set the composite region to center of bg img
377 | self.loc_y = self.bg_h // 2
378 | self.loc_x = self.bg_w // 2
379 |
380 | # get top and left using location, and bounding box size
381 | top = self.loc_y - (self.bb_h // 2)
382 | left = self.loc_x - (self.bb_w // 2)
383 |
384 | # create the composite image using helper function and location
385 | init_cmp = composite_crop(
386 | self.bg_img,
387 | (top, left),
388 | self.fg_crop,
389 | self.msk_crop
390 | )
391 |
392 | # get normals just for the image fragment, it's best to send it through
393 | # cropped and scaled to 1024 in order to get the most details,
394 | # then we resize it to match the fragment size
395 | self.orig_fg_nrm = get_omni_normals(self.nrm_model, self.orig_fg_crop)
396 |
397 | # run the intrinsic pipeline to get foreground albedo and shading
398 | result = run_pipeline(
399 | self.int_model,
400 | self.orig_fg_crop ** 2.2,
401 | resize_conf=0.0,
402 | maintain_size=True,
403 | linear=True
404 | )
405 |
406 | self.orig_fg_shd = result['inv_shading'][:, :, None]
407 | self.orig_fg_alb = result['albedo']
408 |
409 | # to try to save memory I remove refs and empty cuda cache (might not even do anything)
410 | del self.nrm_model
411 | del self.int_model
412 | torch.cuda.empty_cache()
413 |
414 | # run depth model on bg and fg separately
415 | self.bg_dpt = get_depth(self.bg_img, self.dpt_model)[:, :, None]
416 | self.orig_fg_dpt = get_depth(self.orig_fg_crop, self.dpt_model)[:, :, None]
417 | del self.dpt_model
418 |
419 | # these are the versions of the bg, fg and msk that we use in the UI
420 | self.disp_bg_img = rescale(self.bg_img, DISP_SCALE)
421 | self.disp_fg_crop = rescale(self.fg_crop, DISP_SCALE)
422 | self.disp_msk_crop = rescale(self.msk_crop, DISP_SCALE)
423 |
424 | # initialize right side as initial composite, and left is just zeros
425 | self.l_img = init_cmp
426 | self.r_img = np.zeros_like(init_cmp)
427 |
428 | def scrolled(self, e):
429 |
430 | # if we scroll we want to scale the foreground fragment up or down
431 | if e.num == 5 and self.frag_scale > 0.05: # scroll down
432 | self.frag_scale -= 0.01
433 | if e.num == 4 and self.frag_scale < 1.0: # scroll up
434 | self.frag_scale += 0.01
435 |
436 | self.fg_crop = rescale(self.orig_fg_crop, self.frag_scale)
437 | self.msk_crop = rescale(self.orig_msk_crop, self.frag_scale)
438 |
439 | self.disp_fg_crop = rescale(self.fg_crop, DISP_SCALE)
440 | self.disp_msk_crop = rescale(self.msk_crop, DISP_SCALE)
441 |
442 | x = int(self.loc_x * DISP_SCALE)
443 | y = int(self.loc_y * DISP_SCALE)
444 |
445 | top = y - (self.disp_fg_crop.shape[0] // 2)
446 | left = x - (self.disp_fg_crop.shape[1] // 2)
447 |
448 | self.l_img = composite_crop(
449 | self.disp_bg_img,
450 | (top, left),
451 | self.disp_fg_crop,
452 | self.disp_msk_crop
453 | )
454 | self.update_left()
455 |
456 | def clicked(self, e):
457 |
458 | x, y = e.x, e.y
459 | radius = (LGT_VIZ_SZ // 2)
460 |
461 | if e.widget == self.lgt_label:
462 | # if the user clicked the light ball, compute the direction from mouse pos
463 | rel_x = (x - radius) / radius
464 | rel_y = (y - radius) / radius
465 |
466 | z = np.sqrt(1 - rel_x ** 2 - rel_y ** 2)
467 |
468 | # print('clicked the lighting viz:', rel_x, rel_y, z)
469 |
470 | # after converting the mouse pos to a normal direction on a unit sphere
471 | # we can create our 4D lighting coefficients using the slider values
472 | self.coeffs = np.array([0, 0, 0, float(self.bias_scale.get())])
473 | dir_vec = np.array([rel_x, -rel_y, z]) * float(self.dir_scale.get())
474 | self.coeffs[:3] = dir_vec
475 |
476 | self.update_light()
477 |
478 | def click_motion(self, e):
479 | x, y = e.x, e.y
480 |
481 | if e.widget == self.l_label:
482 | if (x <= self.bg_disp_w) and (y <= self.bg_disp_h):
483 |
484 | # we want to show the scaled version of the composite so that the UI
486 | # can be responsive, but save the coordinates properly so that
486 | # we can send the original size image through the network
487 | self.loc_y = int(y / DISP_SCALE)
488 | self.loc_x = int(x / DISP_SCALE)
489 |
490 | top = y - (self.disp_fg_crop.shape[0] // 2)
491 | left = x - (self.disp_fg_crop.shape[1] // 2)
492 |
493 | self.l_img = composite_crop(
494 | self.disp_bg_img,
495 | (top, left),
496 | self.disp_fg_crop,
497 | self.disp_msk_crop
498 | )
499 |
500 | self.update_left()
501 |
502 | def key_pressed(self, e):
503 | # run the harmonization
504 | if e.char == 'r':
505 |
506 | # create all the necessary inputs from the state of the interface
507 | fg_shd_res = rescale(self.orig_fg_shd, self.frag_scale)
508 | fg_nrm_res = rescale(self.orig_fg_nrm, self.frag_scale)
509 | fg_dpt_res = rescale(self.orig_fg_dpt, self.frag_scale)
510 |
511 | top = self.loc_y - (self.fg_crop.shape[0] // 2)
512 | left = self.loc_x - (self.fg_crop.shape[1] // 2)
513 |
514 | # create all the composite images to send to the pipeline
515 | self.comp_img = composite_crop(
516 | self.bg_img,
517 | (top, left),
518 | self.fg_crop,
519 | self.msk_crop
520 | )
521 |
522 | self.comp_shd = composite_crop(
523 | self.bg_shd,
524 | (top, left),
525 | fg_shd_res,
526 | self.msk_crop
527 | )
528 |
529 | self.comp_msk = composite_crop(
530 | np.zeros_like(self.bg_shd),
531 | (top, left),
532 | self.msk_crop,
533 | self.msk_crop
534 | )
535 |
536 | comp_nrm = composite_crop(
537 | self.bg_nrm,
538 | (top, left),
539 | fg_nrm_res,
540 | self.msk_crop
541 | )
542 |
543 | self.comp_dpt = composite_depth(
544 | self.bg_dpt,
545 | (top, left),
546 | fg_dpt_res,
547 | self.msk_crop
548 | )
549 |
550 | # the albedo comes out gamma corrected so make it linear
551 | self.alb_harm = harmonize_albedo(
552 | self.comp_img,
553 | self.comp_msk,
554 | self.comp_shd,
555 | self.alb_model,
556 | reproduce_paper=self.args.reproduce_paper
557 | ) ** 2.2
558 |
559 | self.orig_alb = (self.comp_img ** 2.2) / uninvert(self.comp_shd)
560 | harm_img = self.alb_harm * uninvert(self.comp_shd)
561 |
562 | # run the reshading model using the various composited components,
563 | # and our lighting coefficients from the user interface
564 | self.result = compute_reshading(
565 | harm_img,
566 | self.comp_msk,
567 | self.comp_shd,
568 | self.comp_dpt,
569 | comp_nrm,
570 | self.alb_harm,
571 | self.coeffs,
572 | self.shd_model
573 | )
574 |
575 | self.r_img = self.result['composite']
576 | self.update_right()
577 |
578 | if e.char == '1':
579 | self.r_img = self.result['reshading']
580 | self.update_right()
581 |
582 | if e.char == '2':
583 | self.r_img = self.result['init_shading']
584 | self.update_right()
585 |
586 | if e.char == '3':
587 | self.r_img = self.result['normals']
588 | self.update_right()
589 |
590 | if e.char == '4':
591 | self.r_img = self.comp_shd[:, :, 0]
592 | self.update_right()
593 |
594 | if e.char == '5':
595 | self.r_img = self.alb_harm
596 | self.update_right()
597 |
598 | if e.char == '6':
599 | self.r_img = self.comp_dpt[:, :, 0]
600 | self.update_right()
601 |
602 | if e.char == 's':
603 | # save all components
604 |
605 | # orig_shd from intrinsic pipeline is linear and inverse
606 | orig_shd = add_chan(uninvert(self.comp_shd))
607 |
608 | # reshading coming from compositing pipeline is linear but not inverse
609 | reshading = add_chan(self.result['reshading'])
610 |
611 | imageio.imwrite('output/orig_shd.exr', orig_shd)
612 | imageio.imwrite('output/orig_shd.png', orig_shd)
613 | imageio.imwrite('output/orig_alb.exr', self.orig_alb)
614 | imageio.imwrite('output/orig_alb.png', self.orig_alb)
615 | imageio.imwrite('output/orig_img.exr', self.comp_img ** 2.2)
616 | imageio.imwrite('output/orig_img.png', self.comp_img ** 2.2)
617 |
618 | imageio.imwrite('output/harm_alb.exr', self.alb_harm)
619 | imageio.imwrite('output/harm_alb.png', self.alb_harm)
620 | imageio.imwrite('output/reshading.exr', reshading)
621 | imageio.imwrite('output/reshading.png', reshading)
622 | imageio.imwrite('output/final.exr', self.result['composite'] ** 2.2)
623 | imageio.imwrite('output/final.png', self.result['composite'] ** 2.2)
624 |
625 | imageio.imwrite('output/normals.exr', self.result['normals'])
626 | imageio.imwrite('output/light.exr', self.lgt_disp_img)
627 |
628 | if e.char == 'w':
629 | # write all the different components as pngs
630 |
631 | # orig_shd from intrinsic pipeline is linear and inverse
632 | orig_shd = add_chan(uninvert(self.comp_shd))
633 |
634 | # reshading coming from compositing pipeline is linear but not inverse
635 | reshading = add_chan(self.result['reshading'])
636 | lambertian = add_chan(self.result['init_shading'])
637 | mask = add_chan(self.comp_msk)
638 |
639 | fg_name = Path(self.args.fg).stem
640 | bg_name = Path(self.args.bg).stem
641 | ts = int(datetime.utcnow().timestamp())
642 |
643 | save_dir = f'{fg_name}_{bg_name}_{ts}'
644 | os.makedirs(f'output/{save_dir}')
645 |
646 | np_to_pil(view(orig_shd)).save(f'output/{save_dir}/orig_shd.png')
647 | np_to_pil(view(lambertian)).save(f'output/{save_dir}/lamb_shd.png')
648 | np_to_pil(view(self.orig_alb)).save(f'output/{save_dir}/orig_alb.png')
649 | np_to_pil(self.comp_img).save(f'output/{save_dir}/orig_img.png')
650 |
651 | np_to_pil(view(self.alb_harm)).save(f'output/{save_dir}/harm_alb.png')
652 | np_to_pil(view(reshading)).save(f'output/{save_dir}/reshading.png')
653 | np_to_pil(self.result['composite']).save(f'output/{save_dir}/final.png')
654 |
655 | np_to_pil(self.result['normals']).save(f'output/{save_dir}/normals.png')
656 | np_to_pil(self.lgt_disp_img).save(f'output/{save_dir}/light.png')
657 |
658 | np_to_pil(mask).save(f'output/{save_dir}/mask.png')
659 |
660 | _, bg_lamb_shd = generate_shd(self.bg_nrm, self.coeffs, np.ones(self.bg_nrm.shape[:2]), viz=True)
661 | np_to_pil(add_chan(view(bg_lamb_shd))).save(f'output/{save_dir}/bg_lamb_shd.png')
662 |
663 |
664 | if __name__ == "__main__":
665 |     parser = argparse.ArgumentParser()
666 | 
667 |     parser.add_argument('--bg', type=str, required=True)
668 |     parser.add_argument('--fg', type=str, required=True)
669 |     parser.add_argument('--mask', type=str, default=None)
670 |     parser.add_argument('--reproduce_paper', action='store_true', help='whether or not to use the code and weights of the original paper implementation')
671 | 
672 |     args = parser.parse_args()
673 | 
674 |     app = App(args)
675 |     app.mainloop()
676 |
--------------------------------------------------------------------------------
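
The `clicked` handler above turns a mouse position on the light ball into a direction on the unit sphere and packs it, together with the ambient slider, into 4D lighting coefficients `[dx, dy, dz, ambient]`. As a rough illustration of that convention (a hedged sketch, not the repo's `compute_reshading`/`generate_shd` code, which lives under `intrinsic_compositing/shading`), a Lambertian shading map can be derived from a normal image like this:

import numpy as np

def lambertian_shading(normals, coeffs):
    # normals: H x W x 3 unit normals, coeffs: [dx, dy, dz, ambient] (illustrative only)
    direction, ambient = coeffs[:3], coeffs[3]
    diffuse = np.clip(normals @ direction, 0.0, None)  # per-pixel dot product, clamped at zero
    return diffuse + ambient

# a flat, camera-facing surface lit mostly from the front
normals = np.zeros((4, 4, 3))
normals[..., 2] = 1.0
coeffs = np.array([0.0, 0.0, 0.5, 0.2])  # direction scaled by the slider, plus ambient bias
print(lambertian_shading(normals, coeffs))  # constant 0.7 everywhere
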
/intrinsic_compositing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/__init__.py
--------------------------------------------------------------------------------
/intrinsic_compositing/albedo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/__init__.py
--------------------------------------------------------------------------------
/intrinsic_compositing/albedo/model/MiDaS/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/model/MiDaS/__init__.py
--------------------------------------------------------------------------------
/intrinsic_compositing/albedo/model/MiDaS/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/model/MiDaS/midas/__init__.py
--------------------------------------------------------------------------------
/intrinsic_compositing/albedo/model/MiDaS/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
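
`BaseModel.load` accepts either a bare state dict or a full training checkpoint: if the loaded dict contains an "optimizer" key it unwraps the weights from `parameters["model"]`. A minimal sketch of saving both shapes that this loader accepts (the file names and the `nn.Linear` stand-in are hypothetical):

import torch
import torch.nn as nn

# stand-ins for a real network and optimizer, for illustration only
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

# training checkpoint: BaseModel.load() detects the "optimizer" key and reads parameters["model"]
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "ckpt.pt")

# bare state dict: loaded as-is, since there is no "optimizer" key to trigger unwrapping
torch.save(model.state_dict(), "weights.pt")
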
/intrinsic_compositing/albedo/model/MiDaS/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 |         ) # ViT-B/16 + ResNet-50 hybrid backbone
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 |         scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | print(f"Backbone '{backbone}' not implemented")
44 | assert False
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
343 |
--------------------------------------------------------------------------------
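
A small sketch of `_make_scratch` from the file above, showing how `expand=True` widens the four reassembly convolutions to 1x, 2x, 4x and 8x the requested feature count. It assumes the package is installed so the module path below resolves, and that `timm` is available since `blocks.py` imports `.vit`:

import torch
from intrinsic_compositing.albedo.model.MiDaS.midas.blocks import _make_scratch

# channel widths as used for the vitb_rn50_384 backbone, with expansion enabled
scratch = _make_scratch([256, 512, 768, 768], 256, expand=True)

x = torch.rand(1, 768, 24, 24)        # a fake layer-4 feature map
print(scratch.layer4_rn(x).shape)     # torch.Size([1, 2048, 24, 24]) -> 256 * 8 channels
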
/intrinsic_compositing/albedo/model/MiDaS/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 |             False, # Set to True if you want to train from scratch: this loads ImageNet weights for the backbone
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 | x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
--------------------------------------------------------------------------------
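
A hedged usage sketch for `DPTDepthModel`: the encoder is created with `use_pretrained=False`, so `timm` only builds the architecture and no ImageNet weights are downloaded; pass `path=` to load a trained checkpoint instead. This assumes a `timm` version compatible with this vendored MiDaS code (it must still provide `vit_base_resnet50_384`):

import torch
from intrinsic_compositing.albedo.model.MiDaS.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(backbone="vitb_rn50_384", non_negative=True).eval()

with torch.no_grad():
    depth = model(torch.rand(1, 3, 384, 384))  # random weights here, so the output is meaningless
print(depth.shape)                             # torch.Size([1, 384, 384]): one inverse-depth map per image
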
/intrinsic_compositing/albedo/model/MiDaS/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 |             non_negative (bool, optional): Clamp the output to be non-negative. Defaults to True. (The encoder is fixed to resnext101_wsl.)
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/intrinsic_compositing/albedo/model/MiDaS/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 |             features (int, optional): Number of features. Defaults to 64.
23 |             backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 | x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
--------------------------------------------------------------------------------
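
`fuse_model` walks `named_modules()` and fuses Conv2d -> BatchNorm2d (-> ReLU) runs in place via `torch.quantization.fuse_modules`, which requires the network to be in eval mode. A toy sketch (not one of the MiDaS networks; assumes the package and `timm` are importable):

import torch.nn as nn
from intrinsic_compositing.albedo.model.MiDaS.midas.midas_net_custom import fuse_model

toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()).eval()
fuse_model(toy)   # folds the BatchNorm into the Conv and fuses the ReLU into it
print(toy)        # the fused ConvReLU2d sits in slot 0; slots 1 and 2 become Identity
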
/intrinsic_compositing/albedo/model/MiDaS/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 |     """Resize the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.single),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 |                 "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
82 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.single),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 |     """Normalize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.single)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.single)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.single)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.single)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
--------------------------------------------------------------------------------
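
These transforms operate on a plain dict with an "image" key (plus optional "mask"/"disparity"/"depth" keys), so they compose with `torchvision.transforms.Compose`. A sketch of a typical preprocessing chain; the mean/std, the multiple-of-32 constraint and the "minimal" resize method are the usual MiDaS-style settings and are assumptions here, not values taken from this repo:

import numpy as np
from torchvision.transforms import Compose
from intrinsic_compositing.albedo.model.MiDaS.midas.transforms import (
    Resize, NormalizeImage, PrepareForNet,
)

transform = Compose([
    Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="minimal"),
    NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    PrepareForNet(),
])

img = np.random.rand(480, 640, 3).astype(np.float32)  # H x W x 3 image in [0, 1]
sample = transform({"image": img})
print(sample["image"].shape)                           # (3, 384, 512): channels-first, sides multiples of 32
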
/intrinsic_compositing/albedo/model/MiDaS/midas/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | import types
5 | import math
6 | import torch.nn.functional as F
7 |
8 |
9 | class Slice(nn.Module):
10 | def __init__(self, start_index=1):
11 | super(Slice, self).__init__()
12 | self.start_index = start_index
13 |
14 | def forward(self, x):
15 | return x[:, self.start_index :]
16 |
17 |
18 | class AddReadout(nn.Module):
19 | def __init__(self, start_index=1):
20 | super(AddReadout, self).__init__()
21 | self.start_index = start_index
22 |
23 | def forward(self, x):
24 | if self.start_index == 2:
25 | readout = (x[:, 0] + x[:, 1]) / 2
26 | else:
27 | readout = x[:, 0]
28 | return x[:, self.start_index :] + readout.unsqueeze(1)
29 |
30 |
31 | class ProjectReadout(nn.Module):
32 | def __init__(self, in_features, start_index=1):
33 | super(ProjectReadout, self).__init__()
34 | self.start_index = start_index
35 |
36 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37 |
38 | def forward(self, x):
39 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40 | features = torch.cat((x[:, self.start_index :], readout), -1)
41 |
42 | return self.project(features)
43 |
44 |
45 | class Transpose(nn.Module):
46 | def __init__(self, dim0, dim1):
47 | super(Transpose, self).__init__()
48 | self.dim0 = dim0
49 | self.dim1 = dim1
50 |
51 | def forward(self, x):
52 | x = x.transpose(self.dim0, self.dim1)
53 | return x
54 |
55 |
56 | def forward_vit(pretrained, x):
57 | b, c, h, w = x.shape
58 |
59 | glob = pretrained.model.forward_flex(x)
60 |
61 | layer_1 = pretrained.activations["1"]
62 | layer_2 = pretrained.activations["2"]
63 | layer_3 = pretrained.activations["3"]
64 | layer_4 = pretrained.activations["4"]
65 |
66 | layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67 | layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68 | layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69 | layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70 |
71 | unflatten = nn.Sequential(
72 | nn.Unflatten(
73 | 2,
74 | torch.Size(
75 | [
76 | h // pretrained.model.patch_size[1],
77 | w // pretrained.model.patch_size[0],
78 | ]
79 | ),
80 | )
81 | )
82 |
83 | if layer_1.ndim == 3:
84 | layer_1 = unflatten(layer_1)
85 | if layer_2.ndim == 3:
86 | layer_2 = unflatten(layer_2)
87 | if layer_3.ndim == 3:
88 | layer_3 = unflatten(layer_3)
89 | if layer_4.ndim == 3:
90 | layer_4 = unflatten(layer_4)
91 |
92 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96 |
97 | return layer_1, layer_2, layer_3, layer_4
98 |
99 |
100 | def _resize_pos_embed(self, posemb, gs_h, gs_w):
101 | posemb_tok, posemb_grid = (
102 | posemb[:, : self.start_index],
103 | posemb[0, self.start_index :],
104 | )
105 |
106 | gs_old = int(math.sqrt(len(posemb_grid)))
107 |
108 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111 |
112 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113 |
114 | return posemb
115 |
116 |
117 | def forward_flex(self, x):
118 | b, c, h, w = x.shape
119 |
120 | pos_embed = self._resize_pos_embed(
121 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122 | )
123 |
124 | B = x.shape[0]
125 |
126 | if hasattr(self.patch_embed, "backbone"):
127 | x = self.patch_embed.backbone(x)
128 | if isinstance(x, (list, tuple)):
129 | x = x[-1] # last feature if backbone outputs list/tuple of features
130 |
131 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132 |
133 | if getattr(self, "dist_token", None) is not None:
134 | cls_tokens = self.cls_token.expand(
135 | B, -1, -1
136 | ) # stole cls_tokens impl from Phil Wang, thanks
137 | dist_token = self.dist_token.expand(B, -1, -1)
138 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
139 | else:
140 | cls_tokens = self.cls_token.expand(
141 | B, -1, -1
142 | ) # stole cls_tokens impl from Phil Wang, thanks
143 | x = torch.cat((cls_tokens, x), dim=1)
144 |
145 | x = x + pos_embed
146 | x = self.pos_drop(x)
147 |
148 | for blk in self.blocks:
149 | x = blk(x)
150 |
151 | x = self.norm(x)
152 |
153 | return x
154 |
155 |
156 | activations = {}
157 |
158 |
159 | def get_activation(name):
160 | def hook(model, input, output):
161 | activations[name] = output
162 |
163 | return hook
164 |
165 |
166 | def get_readout_oper(vit_features, features, use_readout, start_index=1):
167 | if use_readout == "ignore":
168 | readout_oper = [Slice(start_index)] * len(features)
169 | elif use_readout == "add":
170 | readout_oper = [AddReadout(start_index)] * len(features)
171 | elif use_readout == "project":
172 | readout_oper = [
173 | ProjectReadout(vit_features, start_index) for out_feat in features
174 | ]
175 | else:
176 | assert (
177 | False
178 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179 |
180 | return readout_oper
181 |
182 |
183 | def _make_vit_b16_backbone(
184 | model,
185 | features=[96, 192, 384, 768],
186 | size=[384, 384],
187 | hooks=[2, 5, 8, 11],
188 | vit_features=768,
189 | use_readout="ignore",
190 | start_index=1,
191 | ):
192 | pretrained = nn.Module()
193 |
194 | pretrained.model = model
195 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199 |
200 | pretrained.activations = activations
201 |
202 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203 |
204 | # 32, 48, 136, 384
205 | pretrained.act_postprocess1 = nn.Sequential(
206 | readout_oper[0],
207 | Transpose(1, 2),
208 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209 | nn.Conv2d(
210 | in_channels=vit_features,
211 | out_channels=features[0],
212 | kernel_size=1,
213 | stride=1,
214 | padding=0,
215 | ),
216 | nn.ConvTranspose2d(
217 | in_channels=features[0],
218 | out_channels=features[0],
219 | kernel_size=4,
220 | stride=4,
221 | padding=0,
222 | bias=True,
223 | dilation=1,
224 | groups=1,
225 | ),
226 | )
227 |
228 | pretrained.act_postprocess2 = nn.Sequential(
229 | readout_oper[1],
230 | Transpose(1, 2),
231 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232 | nn.Conv2d(
233 | in_channels=vit_features,
234 | out_channels=features[1],
235 | kernel_size=1,
236 | stride=1,
237 | padding=0,
238 | ),
239 | nn.ConvTranspose2d(
240 | in_channels=features[1],
241 | out_channels=features[1],
242 | kernel_size=2,
243 | stride=2,
244 | padding=0,
245 | bias=True,
246 | dilation=1,
247 | groups=1,
248 | ),
249 | )
250 |
251 | pretrained.act_postprocess3 = nn.Sequential(
252 | readout_oper[2],
253 | Transpose(1, 2),
254 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255 | nn.Conv2d(
256 | in_channels=vit_features,
257 | out_channels=features[2],
258 | kernel_size=1,
259 | stride=1,
260 | padding=0,
261 | ),
262 | )
263 |
264 | pretrained.act_postprocess4 = nn.Sequential(
265 | readout_oper[3],
266 | Transpose(1, 2),
267 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268 | nn.Conv2d(
269 | in_channels=vit_features,
270 | out_channels=features[3],
271 | kernel_size=1,
272 | stride=1,
273 | padding=0,
274 | ),
275 | nn.Conv2d(
276 | in_channels=features[3],
277 | out_channels=features[3],
278 | kernel_size=3,
279 | stride=2,
280 | padding=1,
281 | ),
282 | )
283 |
284 | pretrained.model.start_index = start_index
285 | pretrained.model.patch_size = [16, 16]
286 |
287 | # We inject this function into the VisionTransformer instances so that
288 | # we can use it with interpolated position embeddings without modifying the library source.
289 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290 | pretrained.model._resize_pos_embed = types.MethodType(
291 | _resize_pos_embed, pretrained.model
292 | )
293 |
294 | return pretrained
295 |
296 |
297 | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299 |
300 | hooks = [5, 11, 17, 23] if hooks == None else hooks
301 | return _make_vit_b16_backbone(
302 | model,
303 | features=[256, 512, 1024, 1024],
304 | hooks=hooks,
305 | vit_features=1024,
306 | use_readout=use_readout,
307 | )
308 |
309 |
310 | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312 |
313 | hooks = [2, 5, 8, 11] if hooks == None else hooks
314 | return _make_vit_b16_backbone(
315 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316 | )
317 |
318 |
319 | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321 |
322 | hooks = [2, 5, 8, 11] if hooks == None else hooks
323 | return _make_vit_b16_backbone(
324 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325 | )
326 |
327 |
328 | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329 | model = timm.create_model(
330 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331 | )
332 |
333 | hooks = [2, 5, 8, 11] if hooks == None else hooks
334 | return _make_vit_b16_backbone(
335 | model,
336 | features=[96, 192, 384, 768],
337 | hooks=hooks,
338 | use_readout=use_readout,
339 | start_index=2,
340 | )
341 |
342 |
343 | def _make_vit_b_rn50_backbone(
344 | model,
345 | features=[256, 512, 768, 768],
346 | size=[384, 384],
347 | hooks=[0, 1, 8, 11],
348 | vit_features=768,
349 | use_vit_only=False,
350 | use_readout="ignore",
351 | start_index=1,
352 | ):
353 | pretrained = nn.Module()
354 |
355 | pretrained.model = model
356 |
357 | if use_vit_only == True:
358 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360 | else:
361 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362 | get_activation("1")
363 | )
364 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365 | get_activation("2")
366 | )
367 |
368 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370 |
371 | pretrained.activations = activations
372 |
373 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374 |
375 | if use_vit_only == True:
376 | pretrained.act_postprocess1 = nn.Sequential(
377 | readout_oper[0],
378 | Transpose(1, 2),
379 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380 | nn.Conv2d(
381 | in_channels=vit_features,
382 | out_channels=features[0],
383 | kernel_size=1,
384 | stride=1,
385 | padding=0,
386 | ),
387 | nn.ConvTranspose2d(
388 | in_channels=features[0],
389 | out_channels=features[0],
390 | kernel_size=4,
391 | stride=4,
392 | padding=0,
393 | bias=True,
394 | dilation=1,
395 | groups=1,
396 | ),
397 | )
398 |
399 | pretrained.act_postprocess2 = nn.Sequential(
400 | readout_oper[1],
401 | Transpose(1, 2),
402 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403 | nn.Conv2d(
404 | in_channels=vit_features,
405 | out_channels=features[1],
406 | kernel_size=1,
407 | stride=1,
408 | padding=0,
409 | ),
410 | nn.ConvTranspose2d(
411 | in_channels=features[1],
412 | out_channels=features[1],
413 | kernel_size=2,
414 | stride=2,
415 | padding=0,
416 | bias=True,
417 | dilation=1,
418 | groups=1,
419 | ),
420 | )
421 | else:
422 | pretrained.act_postprocess1 = nn.Sequential(
423 | nn.Identity(), nn.Identity(), nn.Identity()
424 | )
425 | pretrained.act_postprocess2 = nn.Sequential(
426 | nn.Identity(), nn.Identity(), nn.Identity()
427 | )
428 |
429 | pretrained.act_postprocess3 = nn.Sequential(
430 | readout_oper[2],
431 | Transpose(1, 2),
432 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433 | nn.Conv2d(
434 | in_channels=vit_features,
435 | out_channels=features[2],
436 | kernel_size=1,
437 | stride=1,
438 | padding=0,
439 | ),
440 | )
441 |
442 | pretrained.act_postprocess4 = nn.Sequential(
443 | readout_oper[3],
444 | Transpose(1, 2),
445 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446 | nn.Conv2d(
447 | in_channels=vit_features,
448 | out_channels=features[3],
449 | kernel_size=1,
450 | stride=1,
451 | padding=0,
452 | ),
453 | nn.Conv2d(
454 | in_channels=features[3],
455 | out_channels=features[3],
456 | kernel_size=3,
457 | stride=2,
458 | padding=1,
459 | ),
460 | )
461 |
462 | pretrained.model.start_index = start_index
463 | pretrained.model.patch_size = [16, 16]
464 |
465 | # We inject this function into the VisionTransformer instances so that
466 | # we can use it with interpolated position embeddings without modifying the library source.
467 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468 |
469 | # We inject this function into the VisionTransformer instances so that
470 | # we can use it with interpolated position embeddings without modifying the library source.
471 | pretrained.model._resize_pos_embed = types.MethodType(
472 | _resize_pos_embed, pretrained.model
473 | )
474 |
475 | return pretrained
476 |
477 |
478 | def _make_pretrained_vitb_rn50_384(
479 | pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480 | ):
481 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482 |
483 | hooks = [0, 1, 8, 11] if hooks == None else hooks
484 | return _make_vit_b_rn50_backbone(
485 | model,
486 | features=[256, 512, 768, 768],
487 | size=[384, 384],
488 | hooks=hooks,
489 | use_vit_only=use_vit_only,
490 | use_readout=use_readout,
491 | )
492 |
--------------------------------------------------------------------------------
/intrinsic_compositing/albedo/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/model/__init__.py
--------------------------------------------------------------------------------
/intrinsic_compositing/albedo/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | from .pix2pix.models.networks import NLayerDiscriminator,get_norm_layer
5 | from ..utils.utils import normalize
6 |
7 | class VOTEGAN(nn.Module):
8 | def __init__(self, args):
9 | super(VOTEGAN, self).__init__()
10 | n_ic = 4
11 | if args.crop_size == 384:
12 | n_f = 22*22
13 | else:
14 | raise NotImplementedError
15 |
16 |         ndf = 64  # number of discriminator filters in the first conv layer
17 | norm_layer = get_norm_layer(norm_type='instance')
18 |         self.model = nn.Sequential(NLayerDiscriminator(n_ic, ndf, n_layers=4, norm_layer=norm_layer), # 4-layer PatchGAN (22x22 patch map for 384x384 inputs)
19 | nn.Flatten(start_dim=1),
20 | nn.Linear(n_f,32),
21 | nn.Linear(32,1),
22 | nn.Sigmoid()
23 | )
24 |
25 |
26 | def forward(self,x):
27 | feat = self.model(normalize(x))
28 | return feat
29 |
--------------------------------------------------------------------------------
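
`VOTEGAN` scores a 4-channel crop (RGB albedo plus mask) with a 4-layer PatchGAN, flattens the resulting 22x22 patch map and maps it to a single realism score in [0, 1]. A sketch with a minimal stand-in for `args` (only `crop_size` is read); the weights here are randomly initialized, so the score is meaningless until the trainer loads the realism checkpoint:

import torch
from types import SimpleNamespace
from intrinsic_compositing.albedo.model.discriminator import VOTEGAN

net = VOTEGAN(SimpleNamespace(crop_size=384)).eval()

albedo_and_mask = torch.rand(1, 4, 384, 384)   # RGB albedo + binary mask, stacked channel-wise
with torch.no_grad():
    score = net(albedo_and_mask)
print(score.shape, float(score))               # torch.Size([1, 1]), a value in [0, 1]
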
/intrinsic_compositing/albedo/model/editingnetwork_trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from ..utils.networkutils import init_net, loadmodelweights
7 | from ..utils.edits import apply_colorcurve, apply_exposure, apply_saturation, apply_whitebalancing, get_edits
8 |
9 | from .parametermodel import ParametersRegressor
10 | from .discriminator import VOTEGAN
11 |
12 |
13 | class EditingNetworkTrainer:
14 | def __init__(self, args):
15 |
16 | self.args = args
17 | self.edits = get_edits(args.nops)
18 | self.device = torch.device('cuda:{}'.format(args.gpu_ids[0])) if args.gpu_ids else torch.device('cpu')
19 | self.model_names = ['Parameters']
20 |
21 | self.net_Parameters = init_net(ParametersRegressor(args),args.gpu_ids)
22 | self.net_Parameters.train()
23 |
24 | # Initial Editing Network weights
25 | if args.checkpoint_load_path is not None:
26 | loadmodelweights(self.net_Parameters,args.checkpoint_load_path, self.device)
27 |
28 | if 'realism' in args.edit_loss:
29 | self.net_D = init_net(VOTEGAN(args), args.gpu_ids)
30 | self.set_requires_grad(self.net_D, False)
31 | # Load the realism network weights
32 | loadmodelweights(self.net_D,args.realism_model_weight_path, self.device)
33 | self.net_D.eval()
34 |
35 | # Set the optimizers
36 | self.optimizer_Parameters = torch.optim.Adam(self.net_Parameters.parameters(), lr=args.lr_editnet)
37 |
38 | # Set the mode for each network
39 |
40 | # Set the loss functions
41 | self.criterion_L2 = torch.nn.MSELoss()
42 |
43 | # Set the needed constants and parameters
44 | self.logs = []
45 |
46 | # self.all_permutations = torch.tensor([
47 | # [0,1,2,3],[0,2,1,3],[0,3,1,2],[0,1,3,2],[0,2,3,1],[0,3,2,1],
48 | # [1,0,2,3],[1,2,0,3],[1,3,0,2],[1,0,3,2],[1,2,3,0],[1,3,2,0],
49 | # [2,0,1,3],[2,1,0,3],[2,3,0,1],[2,0,3,1],[2,1,3,0],[2,3,1,0],
50 | # [3,0,1,2],[3,1,0,2],[3,2,0,1],[3,0,2,1],[3,1,2,0],[3,2,1,0]
51 | # ]).float().to(self.device)
52 |
53 |
54 | def setEval(self):
55 | self.net_Parameters.eval()
56 |
57 | def setTrain(self):
58 | self.net_Parameters.train()
59 |
60 | def setinput(self, input, mergebatch=1):
61 | self.srgb = input['srgb'].to(self.device)
62 | self.albedo = input['albedo'].to(self.device)
63 | self.shading = input['shading'].to(self.device)
64 | self.mask = input['mask'].to(self.device)
65 |
66 | if mergebatch > 1:
67 | self.srgb = torch.reshape(self.srgb, (self.args.batch_size, 3, self.args.crop_size, self.args.crop_size))
68 | self.albedo = torch.reshape(self.albedo, (self.args.batch_size, 3, self.args.crop_size, self.args.crop_size))
69 | self.shading = torch.reshape(self.shading, (self.args.batch_size, 1, self.args.crop_size, self.args.crop_size))
70 | self.mask = torch.reshape(self.mask, (self.args.batch_size, 1, self.args.crop_size, self.args.crop_size))
71 |
72 | albedo_edited = self.create_fake_edited(self.albedo)
73 | self.albedo_fake = (1 - self.mask) * self.albedo + self.mask * albedo_edited
74 | self.input = torch.cat((self.albedo_fake,self.mask),dim=1).to(self.device)
75 |
76 | # self.numelmask = torch.sum(self.mask,dim=[1,2,3])
77 | def setinput_HR(self, input):
78 | self.srgb = input['srgb'].to(self.device)
79 | self.albedo_fake = input['albedo'].to(self.device)
80 | self.albedo_full = input['albedo_full'].to(self.device)
81 | self.shading_full = input['shading_full'].to(self.device)
82 | self.mask_full = input['mask_full'].to(self.device)
83 | self.shading = input['shading'].to(self.device)
84 | self.mask = input['mask'].to(self.device)
85 |
86 | self.input = torch.cat((self.albedo_fake, self.mask),dim=1).to(self.device)
87 |
88 |
89 | def create_fake_edited(self,rgb):
90 | # Randomly choose an edit.
91 | edited = rgb.clone()
92 | ne = np.random.randint(0, 4)
93 | perm = torch.randperm(len(self.edits))
94 |
95 | args = self.args
96 | device = self.device
97 |
98 | for i in range(ne):
99 | edit_id = perm[i]
100 | if self.args.fake_gen_lowdev == 0:
101 | wb_param = torch.rand(args.batch_size, 3).to(device)*0.9 + 0.1
102 | colorcurve = torch.rand(args.batch_size, 24).to(device)*1.5 + 0.5
103 |
104 | sat_param = torch.rand(args.batch_size, 1)*2
105 | sat_param = sat_param.to(device)
106 |
107 | expos_param = torch.rand(args.batch_size, 1)*1.5 + 0.5
108 | expos_param = expos_param.to(device)
109 |
110 | blur_param = torch.rand(args.batch_size, 1)*5 + 0.0001 # to make sure the blur param is never exactly zero.
111 | blur_param = blur_param.to(device)
112 |
113 | sharp_param = torch.rand(args.batch_size, 1)*10 + 1
114 | sharp_param = sharp_param.to(device)
115 |
116 | else:
117 | wb_param = torch.rand(args.batch_size, 3).to(device)*0.2 + 0.5
118 | colorcurve = torch.rand(args.batch_size, 24).to(device)*1.5 + 0.5
119 |
120 | sat_param = torch.rand(args.batch_size, 1)*1 + 0.5
121 | sat_param = sat_param.to(device)
122 |
123 | expos_param = torch.rand(args.batch_size, 1)*1 + 0.5
124 | expos_param = expos_param.to(device)
125 |
126 | blur_param = torch.rand(args.batch_size, 1)*2.5 + 0.0001 # to make sure the blur param is never exactly zero.
127 | blur_param = blur_param.to(device)
128 |
129 | sharp_param = torch.rand(args.batch_size, 1)*5 + 1
130 | sharp_param = sharp_param.to(device)
131 |
132 | parameters = {
133 | 'whitebalancing':wb_param,
134 | 'colorcurve':colorcurve,
135 | 'saturation':sat_param,
136 | 'exposure':expos_param,
137 | 'blur':blur_param,
138 | 'sharpness':sharp_param
139 | }
140 |
141 |
142 | edited = torch.clamp(self.edits[edit_id.item()](edited,parameters),0,1)
143 |
144 | return edited.detach()
145 |
146 | def forward(self):
147 | permutation = torch.randperm(len(self.edits)).float().to(self.device)
148 | params_dic = self.net_Parameters(self.input, permutation.repeat(self.args.batch_size,1))
149 | # print(params_dic)
150 | self.logs.append(params_dic)
151 |
152 | # current_rgb = self.albedo_fake
153 | current_rgb = self.albedo_full
154 |
155 | for ed_in in range(self.args.nops):
156 | current_edited = torch.clamp(self.edits[permutation[ed_in].item()](current_rgb,params_dic),0,1)
157 | current_result = (1 - self.mask_full) * current_rgb + self.mask_full * current_edited
158 |
159 | current_rgb = current_result
160 |
161 | self.result = current_result
162 |
163 | self.result_albedo_srgb = self.result ** (2.2)
164 | self.result_rgb = self.result_albedo_srgb * self.shading_full
165 | self.result_srgb = torch.clamp(self.result_rgb ** (1/2.2),0,1)
166 |
167 |
168 |
169 | def computeloss_realism(self):
170 | if self.args.edit_loss == 'realism':
171 | after = torch.cat((self.result, self.mask), 1)
172 | after_D_value = self.net_D(after).squeeze(1)
173 | self.realism_change = 1 - after_D_value
174 | self.loss_realism = F.relu(self.realism_change - self.args.loss_relu_bias)
175 | self.loss_g = torch.mean(self.loss_realism)
176 |
177 | elif self.args.edit_loss == 'realism_relative':
178 | before = torch.cat((self.albedo, self.mask), 1)
179 | before_D_value = self.net_D(before).squeeze(1)
180 | after = torch.cat((self.result, self.mask), 1)
181 | after_D_value = self.net_D(after).squeeze(1)
182 |
183 | self.realism_change = before_D_value - after_D_value
184 | self.loss_realism = F.relu(self.realism_change - self.args.loss_relu_bias)
185 | self.loss_g = torch.mean(self.loss_realism)
186 |
187 | elif self.args.edit_loss == 'MSE':
188 | self.loss_L2 = self.criterion_L2(self.result, self.albedo)
189 | self.loss_g = torch.mean(self.loss_L2)
190 |
191 |
192 |
193 |
194 | def optimize_parameters(self):
195 | self.optimizer_Parameters.zero_grad()
196 | self.computeloss_realism()
197 | self.loss_g.backward()
198 | self.optimizer_Parameters.step()
199 |
200 | for name, p in self.net_Parameters.named_parameters():
201 | if p.grad is None:
202 | print(name)
203 |
204 |
205 | def set_requires_grad(self, nets, requires_grad=False):
206 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
207 | Parameters:
208 | nets (network list) -- a list of networks
209 | requires_grad (bool) -- whether the networks require gradients or not
210 | """
211 | if not isinstance(nets, list):
212 | nets = [nets]
213 | for net in nets:
214 | if net is not None:
215 | for param in net.parameters():
216 | param.requires_grad = requires_grad
217 |
218 | def savemodel(self,iteration, checkpointdir):
219 | for name in self.model_names:
220 | if isinstance(name, str):
221 | save_filename = '%s_net_%s.pth' % (iteration, name)
222 | save_path = os.path.join(checkpointdir, save_filename)
223 | net = getattr(self, 'net_' + name)
224 | if len(self.args.gpu_ids) > 0 and torch.cuda.is_available():
225 | torch.save(net.module.cpu().state_dict(), save_path)
226 | net.cuda(self.args.gpu_ids[0])
227 | else:
228 | torch.save(net.cpu().state_dict(), save_path)
229 |
--------------------------------------------------------------------------------
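For reference, the masked-edit compositing and the intrinsic recombination that forward() performs above can be summarized by the short sketch below. This is a minimal sketch with random stand-in tensors; edit_fn is a hypothetical placeholder for one of the parametric edits in self.edits, and nothing here is part of the repository.

import torch

def composite_edit(albedo, mask, edit_fn):
    # apply the edit everywhere, then keep it only inside the masked region
    edited = torch.clamp(edit_fn(albedo), 0, 1)
    return (1 - mask) * albedo + mask * edited

albedo  = torch.rand(1, 3, 384, 384)                  # gamma-encoded albedo in [0, 1]
mask    = (torch.rand(1, 1, 384, 384) > 0.5).float()  # foreground mask
shading = torch.rand(1, 1, 384, 384) * 2              # grayscale shading

harmonized = composite_edit(albedo, mask, lambda a: a * 0.8)  # e.g. a simple exposure change

# recombine with shading in linear space, then convert back to sRGB as forward() does
linear_albedo = harmonized ** 2.2
rgb  = linear_albedo * shading
srgb = torch.clamp(rgb ** (1 / 2.2), 0, 1)
print(srgb.shape)  # torch.Size([1, 3, 384, 384])
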
/intrinsic_compositing/albedo/model/parametermodel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .MiDaS.midas.blocks import _make_pretrained_efficientnet_lite3
6 | from ..utils.networkutils import Conv2dSameExport
7 | from ..utils.utils import normalize
8 |
9 |
10 | class ParametersRegressor(nn.Module):
11 | def __init__(self, args):
12 | super(ParametersRegressor,self).__init__()
13 |
14 | self.blurandsharpen = args.blursharpen
15 | self.nf = 384
16 | self.shared_dec_nfeat = 128
17 | self.perm_nfeat = 32
18 | self.nops = args.nops
19 |
20 | self.encoder = _make_pretrained_efficientnet_lite3(use_pretrained=True, exportable=True)
21 | self.encoder.layer1[0] = Conv2dSameExport(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
22 | self.encoder.layer4[1][0].bn3 = nn.Identity()
23 |
24 | if args.bn_momentum is not None:
25 | for module in self.encoder.modules():
26 | if isinstance(module, nn.BatchNorm2d):
27 | module.momentum = args.bn_momentum
28 |
29 | self.shared_decoder = nn.Sequential(
30 | nn.Linear(self.nf + self.perm_nfeat, self.shared_dec_nfeat),
31 | nn.LeakyReLU(0.1,False),
32 | nn.Linear(self.shared_dec_nfeat, self.shared_dec_nfeat),
33 | )
34 |
35 | self.WB_head = nn.Sequential(
36 | nn.Linear(self.shared_dec_nfeat, 3),
37 | nn.Sigmoid(),
38 | )
39 | self.ColorCurve_head = nn.Sequential(
40 | nn.Linear(self.shared_dec_nfeat, 24),
41 | nn.Sigmoid(),
42 | )
43 | self.Satur_head = nn.Sequential(
44 | nn.Linear(self.shared_dec_nfeat, 1),
45 | nn.Sigmoid(),
46 | )
47 | self.Expos_head = nn.Sequential(
48 | nn.Linear(self.shared_dec_nfeat, 1),
49 | nn.Sigmoid(),
50 | )
51 |
52 | self.perm_modulation = nn.Sequential(
53 | nn.Linear(self.nops,self.perm_nfeat)
54 | )
55 |
56 | if self.blurandsharpen:
57 | self.Sharp_head = nn.Sequential(
58 | nn.Linear(self.shared_dec_nfeat, 1),
59 | nn.Sigmoid(),
60 | )
61 |
62 | self.Blur_head = nn.Sequential(
63 | nn.Linear(self.shared_dec_nfeat, 1),
64 | nn.Sigmoid(),
65 | )
66 |
67 |
68 | self.globalavgpool = nn.AdaptiveAvgPool2d(output_size=1)
69 |
70 | def forward(self,x, permutation):
71 |
72 | img_feat = self.encoder.layer1(normalize(x))
73 | img_feat = self.encoder.layer2(img_feat)
74 | img_feat = self.encoder.layer3(img_feat)
75 | img_feat = self.encoder.layer4(img_feat)
76 |
77 | img_feat = self.globalavgpool(img_feat).squeeze(-1).squeeze(-1)
78 | perm_feat = self.perm_modulation(permutation)
79 |
80 | feat = torch.cat((img_feat,perm_feat),dim=1)
81 |
82 | feat = self.shared_decoder(feat)
83 |
84 | wb_param = self.WB_head(feat)*0.9 + 0.1
85 | colorcurve = self.ColorCurve_head(feat)*1.5 + 0.5
86 | sat_param = self.Satur_head(feat)*2
87 | expos_param = self.Expos_head(feat)*1.5 + 0.5
88 |
89 | if self.blurandsharpen:
90 | sharp_param = self.Sharp_head(feat)*10 + 1 # values smaller than one would generate a blurred image!
91 | blur_param = self.Blur_head(feat)*5 + 0.001 # cannot be zero -- the blur operation fails for a zero kernel
92 |
93 | result_dic = {
94 | 'whitebalancing':wb_param,
95 | 'colorcurve':colorcurve,
96 | 'saturation':sat_param,
97 | 'exposure':expos_param ,
98 | 'blur':blur_param,
99 | 'sharpness':sharp_param
100 | }
101 | else:
102 | result_dic = {
103 | 'whitebalancing':wb_param,
104 | 'colorcurve':colorcurve,
105 | 'saturation':sat_param,
106 | 'exposure':expos_param ,
107 | }
108 |
109 | return result_dic
110 |
--------------------------------------------------------------------------------
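A minimal instantiation sketch for ParametersRegressor follows. It assumes the intrinsic_compositing package is installed and that _make_pretrained_efficientnet_lite3 can fetch its backbone via torch.hub on first use; the Namespace fields mirror the defaults set in albedo/pipeline.py, and the input tensor is a random stand-in for a 384x384 albedo-plus-mask crop.

import torch
from argparse import Namespace

from intrinsic_compositing.albedo.model.parametermodel import ParametersRegressor

args = Namespace(blursharpen=0, nops=4, bn_momentum=0.01)
net = ParametersRegressor(args).eval()

x = torch.rand(1, 4, 384, 384)                         # albedo (3 ch) concatenated with the mask (1 ch)
perm = torch.randperm(args.nops).float().unsqueeze(0)  # (1, nops) vector encoding the edit order

with torch.no_grad():
    params = net(x, perm)

print(sorted(params))  # ['colorcurve', 'exposure', 'saturation', 'whitebalancing']
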
/intrinsic_compositing/albedo/model/pix2pix/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/model/pix2pix/__init__.py
--------------------------------------------------------------------------------
/intrinsic_compositing/albedo/model/pix2pix/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from .base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called DatasetNameModel() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function creates a model instance given the option <opt.model>.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
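The discovery logic above maps a --model flag to a module and a class purely by naming convention. The short sketch below replays that string matching for a hypothetical flag value without importing anything.

model_name = "pix2pix4depth"                          # hypothetical value of the --model flag
model_filename = "models." + model_name + "_model"    # module to import: models.pix2pix4depth_model
target_model_name = model_name.replace('_', '') + 'model'

# a class defined in that module is picked up if its lowercased name equals the target
print('Pix2Pix4DepthModel'.lower() == target_model_name.lower())  # True
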
/intrinsic_compositing/albedo/model/pix2pix/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from . import networks
6 |
7 |
8 | class BaseModel(ABC):
9 | """This class is an abstract base class (ABC) for models.
10 | To create a subclass, you need to implement the following five functions:
11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12 | -- <set_input>: unpack data from dataset and apply preprocessing.
13 | -- <forward>: produce intermediate results.
14 | -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16 | """
17 |
18 | def __init__(self, opt):
19 | """Initialize the BaseModel class.
20 |
21 | Parameters:
22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23 |
24 | When creating your custom class, you need to implement your own initialization.
25 | In this function, you should first call <BaseModel.__init__(self, opt)>.
26 | Then, you need to define four lists:
27 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
28 | -- self.model_names (str list): define networks used in our training.
29 | -- self.visual_names (str list): specify the images that you want to display and save.
30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31 | """
32 | self.opt = opt
33 | self.gpu_ids = opt.gpu_ids
34 | self.isTrain = opt.isTrain
35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38 | torch.backends.cudnn.benchmark = True
39 | self.loss_names = []
40 | self.model_names = []
41 | self.visual_names = []
42 | self.optimizers = []
43 | self.image_paths = []
44 | self.metric = 0 # used for learning rate policy 'plateau'
45 |
46 | @staticmethod
47 | def modify_commandline_options(parser, is_train):
48 | """Add new model-specific options, and rewrite default values for existing options.
49 |
50 | Parameters:
51 | parser -- original option parser
52 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
53 |
54 | Returns:
55 | the modified parser.
56 | """
57 | return parser
58 |
59 | @abstractmethod
60 | def set_input(self, input):
61 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
62 |
63 | Parameters:
64 | input (dict): includes the data itself and its metadata information.
65 | """
66 | pass
67 |
68 | @abstractmethod
69 | def forward(self):
70 | """Run forward pass; called by both functions and ."""
71 | pass
72 |
73 | @abstractmethod
74 | def optimize_parameters(self):
75 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
76 | pass
77 |
78 | def setup(self, opt):
79 | """Load and print networks; create schedulers
80 |
81 | Parameters:
82 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
83 | """
84 | if self.isTrain:
85 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
86 | if not self.isTrain or opt.continue_train:
87 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
88 | self.load_networks(load_suffix)
89 | self.print_networks(opt.verbose)
90 |
91 | def eval(self):
92 | """Make models eval mode during test time"""
93 | for name in self.model_names:
94 | if isinstance(name, str):
95 | net = getattr(self, 'net' + name)
96 | net.eval()
97 |
98 | def test(self):
99 | """Forward function used in test time.
100 |
101 | This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop.
102 | It also calls <compute_visuals> to produce additional visualization results.
103 | """
104 | with torch.no_grad():
105 | self.forward()
106 | self.compute_visuals()
107 |
108 | def compute_visuals(self):
109 | """Calculate additional output images for visdom and HTML visualization"""
110 | pass
111 |
112 | def get_image_paths(self):
113 | """ Return image paths that are used to load current data"""
114 | return self.image_paths
115 |
116 | def update_learning_rate(self):
117 | """Update learning rates for all the networks; called at the end of every epoch"""
118 | old_lr = self.optimizers[0].param_groups[0]['lr']
119 | for scheduler in self.schedulers:
120 | if self.opt.lr_policy == 'plateau':
121 | scheduler.step(self.metric)
122 | else:
123 | scheduler.step()
124 |
125 | lr = self.optimizers[0].param_groups[0]['lr']
126 | print('learning rate %.7f -> %.7f' % (old_lr, lr))
127 |
128 | def get_current_visuals(self):
129 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
130 | visual_ret = OrderedDict()
131 | for name in self.visual_names:
132 | if isinstance(name, str):
133 | visual_ret[name] = getattr(self, name)
134 | return visual_ret
135 |
136 | def get_current_losses(self):
137 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
138 | errors_ret = OrderedDict()
139 | for name in self.loss_names:
140 | if isinstance(name, str):
141 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
142 | return errors_ret
143 |
144 | def save_networks(self, epoch):
145 | """Save all the networks to the disk.
146 |
147 | Parameters:
148 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
149 | """
150 | for name in self.model_names:
151 | if isinstance(name, str):
152 | save_filename = '%s_net_%s.pth' % (epoch, name)
153 | save_path = os.path.join(self.save_dir, save_filename)
154 | net = getattr(self, 'net' + name)
155 |
156 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
157 | torch.save(net.module.cpu().state_dict(), save_path)
158 | net.cuda(self.gpu_ids[0])
159 | else:
160 | torch.save(net.cpu().state_dict(), save_path)
161 |
162 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
163 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
164 | key = keys[i]
165 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
166 | if module.__class__.__name__.startswith('InstanceNorm') and \
167 | (key == 'running_mean' or key == 'running_var'):
168 | if getattr(module, key) is None:
169 | state_dict.pop('.'.join(keys))
170 | if module.__class__.__name__.startswith('InstanceNorm') and \
171 | (key == 'num_batches_tracked'):
172 | state_dict.pop('.'.join(keys))
173 | else:
174 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
175 |
176 | def load_networks(self, epoch):
177 | """Load all the networks from the disk.
178 |
179 | Parameters:
180 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
181 | """
182 | for name in self.model_names:
183 | if isinstance(name, str):
184 | load_filename = '%s_net_%s.pth' % (epoch, name)
185 | load_path = os.path.join(self.save_dir, load_filename)
186 | net = getattr(self, 'net' + name)
187 | if isinstance(net, torch.nn.DataParallel):
188 | net = net.module
189 | print('loading the model from %s' % load_path)
190 | # if you are using PyTorch newer than 0.4 (e.g., built from
191 | # GitHub source), you can remove str() on self.device
192 | state_dict = torch.load(load_path, map_location=str(self.device))
193 | if hasattr(state_dict, '_metadata'):
194 | del state_dict._metadata
195 |
196 | # patch InstanceNorm checkpoints prior to 0.4
197 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
198 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
199 | net.load_state_dict(state_dict)
200 |
201 | def print_networks(self, verbose):
202 | """Print the total number of parameters in the network and (if verbose) network architecture
203 |
204 | Parameters:
205 | verbose (bool) -- if verbose: print the network architecture
206 | """
207 | print('---------- Networks initialized -------------')
208 | for name in self.model_names:
209 | if isinstance(name, str):
210 | net = getattr(self, 'net' + name)
211 | num_params = 0
212 | for param in net.parameters():
213 | num_params += param.numel()
214 | if verbose:
215 | print(net)
216 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
217 | print('-----------------------------------------------')
218 |
219 | def set_requires_grad(self, nets, requires_grad=False):
220 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
221 | Parameters:
222 | nets (network list) -- a list of networks
223 | requires_grad (bool) -- whether the networks require gradients or not
224 | """
225 | if not isinstance(nets, list):
226 | nets = [nets]
227 | for net in nets:
228 | if net is not None:
229 | for param in net.parameters():
230 | param.requires_grad = requires_grad
231 |
--------------------------------------------------------------------------------
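To illustrate the contract BaseModel imposes, here is a minimal, hypothetical subclass; it is not part of the repository. The Namespace only fills in the option fields that BaseModel.__init__ reads, and the toy 1x1 convolution stands in for a real generator.

import torch
import torch.nn as nn
import torch.nn.functional as F
from argparse import Namespace

from intrinsic_compositing.albedo.model.pix2pix.models.base_model import BaseModel

class ToyModel(BaseModel):  # hypothetical example, not part of the repository
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['l1']      # get_current_losses() will read self.loss_l1
        self.model_names = ['G']      # save/load/print helpers will look for self.netG
        self.visual_names = ['fake']
        self.netG = nn.Conv2d(3, 3, kernel_size=1).to(self.device)
        self.optimizers = [torch.optim.Adam(self.netG.parameters(), lr=2e-4)]

    def set_input(self, input):
        self.real = input['A'].to(self.device)

    def forward(self):
        self.fake = self.netG(self.real)

    def optimize_parameters(self):
        self.forward()
        self.loss_l1 = F.l1_loss(self.fake, self.real)
        self.optimizers[0].zero_grad()
        self.loss_l1.backward()
        self.optimizers[0].step()

opt = Namespace(gpu_ids=[], isTrain=True, checkpoints_dir='./checkpoints',
                name='toy', preprocess='resize_and_crop')
model = ToyModel(opt)
model.set_input({'A': torch.rand(1, 3, 64, 64)})
model.optimize_parameters()
print(model.get_current_losses())  # OrderedDict([('l1', ...)])
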
/intrinsic_compositing/albedo/model/pix2pix/models/networks.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import reduction
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | import functools
6 | from torch.optim import lr_scheduler
7 |
8 |
9 | ###############################################################################
10 | # Helper Functions
11 | ###############################################################################
12 |
13 |
14 | class Identity(nn.Module):
15 | def forward(self, x):
16 | return x
17 |
18 |
19 | def get_norm_layer(norm_type='instance'):
20 | """Return a normalization layer
21 |
22 | Parameters:
23 | norm_type (str) -- the name of the normalization layer: batch | instance | none
24 |
25 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
26 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
27 | """
28 | if norm_type == 'batch':
29 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
30 | elif norm_type == 'instance':
31 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
32 | elif norm_type == 'none':
33 | def norm_layer(x): return Identity()
34 | else:
35 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
36 | return norm_layer
37 |
38 |
39 | def get_scheduler(optimizer, opt):
40 | """Return a learning rate scheduler
41 |
42 | Parameters:
43 | optimizer -- the optimizer of the network
44 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
45 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
46 |
47 | For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
48 | and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
49 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
50 | See https://pytorch.org/docs/stable/optim.html for more details.
51 | """
52 | if opt.lr_policy == 'linear':
53 | def lambda_rule(epoch):
54 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
55 | return lr_l
56 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
57 | elif opt.lr_policy == 'step':
58 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
59 | elif opt.lr_policy == 'plateau':
60 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
61 | elif opt.lr_policy == 'cosine':
62 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
63 | else:
64 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
65 | return scheduler
66 |
67 |
68 | def init_weights(net, init_type='normal', init_gain=0.02):
69 | """Initialize network weights.
70 |
71 | Parameters:
72 | net (network) -- network to be initialized
73 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
74 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
75 |
76 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
77 | work better for some applications. Feel free to try yourself.
78 | """
79 | def init_func(m): # define the initialization function
80 | classname = m.__class__.__name__
81 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
82 | if init_type == 'normal':
83 | init.normal_(m.weight.data, 0.0, init_gain)
84 | elif init_type == 'xavier':
85 | init.xavier_normal_(m.weight.data, gain=init_gain)
86 | elif init_type == 'kaiming':
87 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
88 | elif init_type == 'orthogonal':
89 | init.orthogonal_(m.weight.data, gain=init_gain)
90 | else:
91 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
92 | if hasattr(m, 'bias') and m.bias is not None:
93 | init.constant_(m.bias.data, 0.0)
94 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
95 | init.normal_(m.weight.data, 1.0, init_gain)
96 | init.constant_(m.bias.data, 0.0)
97 |
98 | print('initialize network with %s' % init_type)
99 | net.apply(init_func) # apply the initialization function
100 |
101 |
102 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
103 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
104 | Parameters:
105 | net (network) -- the network to be initialized
106 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
107 | gain (float) -- scaling factor for normal, xavier and orthogonal.
108 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
109 |
110 | Return an initialized network.
111 | """
112 | if len(gpu_ids) > 0:
113 | assert(torch.cuda.is_available())
114 | net.to(gpu_ids[0])
115 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
116 | init_weights(net, init_type, init_gain=init_gain)
117 | return net
118 |
119 |
120 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
121 | """Create a generator
122 |
123 | Parameters:
124 | input_nc (int) -- the number of channels in input images
125 | output_nc (int) -- the number of channels in output images
126 | ngf (int) -- the number of filters in the last conv layer
127 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
128 | norm (str) -- the name of normalization layers used in the network: batch | instance | none
129 | use_dropout (bool) -- if use dropout layers.
130 | init_type (str) -- the name of our initialization method.
131 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
132 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
133 |
134 | Returns a generator
135 |
136 | Our current implementation provides two types of generators:
137 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
138 | The original U-Net paper: https://arxiv.org/abs/1505.04597
139 |
140 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
141 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
142 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
143 |
144 |
145 | The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
146 | """
147 | net = None
148 | norm_layer = get_norm_layer(norm_type=norm)
149 |
150 | if netG == 'resnet_9blocks':
151 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
152 | elif netG == 'resnet_6blocks':
153 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
154 | elif netG == 'unet_128':
155 | net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
156 | elif netG == 'unet_256':
157 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
158 | else:
159 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
160 | return init_net(net, init_type, init_gain, gpu_ids)
161 |
162 |
163 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
164 | """Create a discriminator
165 |
166 | Parameters:
167 | input_nc (int) -- the number of channels in input images
168 | ndf (int) -- the number of filters in the first conv layer
169 | netD (str) -- the architecture's name: basic | n_layers | pixel
170 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
171 | norm (str) -- the type of normalization layers used in the network.
172 | init_type (str) -- the name of the initialization method.
173 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
174 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
175 |
176 | Returns a discriminator
177 |
178 | Our current implementation provides three types of discriminators:
179 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
180 | It can classify whether 70×70 overlapping patches are real or fake.
181 | Such a patch-level discriminator architecture has fewer parameters
182 | than a full-image discriminator and can work on arbitrarily-sized images
183 | in a fully convolutional fashion.
184 |
185 | [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
186 | with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
187 |
188 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
189 | It encourages greater color diversity but has no effect on spatial statistics.
190 |
191 | The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
192 | """
193 | net = None
194 | norm_layer = get_norm_layer(norm_type=norm)
195 |
196 | if netD == 'basic': # default PatchGAN classifier
197 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
198 | elif netD == 'n_layers': # more options
199 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
200 | elif netD == 'pixel': # classify if each pixel is real or fake
201 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
202 | else:
203 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
204 | return init_net(net, init_type, init_gain, gpu_ids)
205 |
206 |
207 | ##############################################################################
208 | # Classes
209 | ##############################################################################
210 | class GANLoss(nn.Module):
211 | """Define different GAN objectives.
212 |
213 | The GANLoss class abstracts away the need to create the target label tensor
214 | that has the same size as the input.
215 | """
216 |
217 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
218 | """ Initialize the GANLoss class.
219 |
220 | Parameters:
221 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
222 | target_real_label (bool) - - label for a real image
223 | target_fake_label (bool) - - label of a fake image
224 |
225 | Note: Do not use sigmoid as the last layer of Discriminator.
226 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
227 | """
228 | super(GANLoss, self).__init__()
229 | self.register_buffer('real_label', torch.tensor(target_real_label))
230 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
231 | self.gan_mode = gan_mode
232 | if gan_mode == 'lsgan':
233 | self.loss = nn.MSELoss(reduction='none')
234 | elif gan_mode == 'vanilla':
235 | self.loss = nn.BCEWithLogitsLoss(reduction='none')
236 | elif gan_mode in ['wgangp']:
237 | self.loss = None
238 | else:
239 | raise NotImplementedError('gan mode %s not implemented' % gan_mode)
240 |
241 | def get_target_tensor(self, prediction, target_is_real):
242 | """Create label tensors with the same size as the input.
243 |
244 | Parameters:
245 | prediction (tensor) - - typically the prediction from a discriminator
246 | target_is_real (bool) - - if the ground truth label is for real images or fake images
247 |
248 | Returns:
249 | A label tensor filled with ground truth label, and with the size of the input
250 | """
251 |
252 | if target_is_real:
253 | target_tensor = self.real_label
254 | else:
255 | target_tensor = self.fake_label
256 | return target_tensor.expand_as(prediction)
257 |
258 | def __call__(self, prediction, target_is_real):
259 | """Calculate loss given Discriminator's output and grount truth labels.
260 |
261 | Parameters:
262 | prediction (tensor) - - typically the prediction output from a discriminator
263 | target_is_real (bool) - - if the ground truth label is for real images or fake images
264 |
265 | Returns:
266 | the calculated loss.
267 | """
268 | if self.gan_mode in ['lsgan', 'vanilla']:
269 | target_tensor = self.get_target_tensor(prediction, target_is_real)
270 | loss = self.loss(prediction, target_tensor)
271 | elif self.gan_mode == 'wgangp':
272 | if target_is_real:
273 | loss = -prediction.mean()
274 | else:
275 | loss = prediction.mean()
276 | return loss
277 |
278 |
279 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
280 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
281 |
282 | Arguments:
283 | netD (network) -- discriminator network
284 | real_data (tensor array) -- real images
285 | fake_data (tensor array) -- generated images from the generator
286 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
287 | type (str) -- if we mix real and fake data or not [real | fake | mixed].
288 | constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
289 | lambda_gp (float) -- weight for this loss
290 |
291 | Returns the gradient penalty loss
292 | """
293 | if lambda_gp > 0.0:
294 | if type == 'real': # either use real images, fake images, or a linear interpolation of two.
295 | interpolatesv = real_data
296 | elif type == 'fake':
297 | interpolatesv = fake_data
298 | elif type == 'mixed':
299 | alpha = torch.rand(real_data.shape[0], 1, device=device)
300 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
301 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
302 | else:
303 | raise NotImplementedError('{} not implemented'.format(type))
304 | interpolatesv.requires_grad_(True)
305 | disc_interpolates = netD(interpolatesv)
306 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
307 | grad_outputs=torch.ones(disc_interpolates.size()).to(device),
308 | create_graph=True, retain_graph=True, only_inputs=True)
309 | gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
310 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
311 | return gradient_penalty, gradients
312 | else:
313 | return 0.0, None
314 |
315 |
316 | class ResnetGenerator(nn.Module):
317 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
318 |
319 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
320 | """
321 |
322 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
323 | """Construct a Resnet-based generator
324 |
325 | Parameters:
326 | input_nc (int) -- the number of channels in input images
327 | output_nc (int) -- the number of channels in output images
328 | ngf (int) -- the number of filters in the last conv layer
329 | norm_layer -- normalization layer
330 | use_dropout (bool) -- if use dropout layers
331 | n_blocks (int) -- the number of ResNet blocks
332 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
333 | """
334 | assert(n_blocks >= 0)
335 | super(ResnetGenerator, self).__init__()
336 | if type(norm_layer) == functools.partial:
337 | use_bias = norm_layer.func == nn.InstanceNorm2d
338 | else:
339 | use_bias = norm_layer == nn.InstanceNorm2d
340 |
341 | model = [nn.ReflectionPad2d(3),
342 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
343 | norm_layer(ngf),
344 | nn.ReLU(True)]
345 |
346 | n_downsampling = 2
347 | for i in range(n_downsampling): # add downsampling layers
348 | mult = 2 ** i
349 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
350 | norm_layer(ngf * mult * 2),
351 | nn.ReLU(True)]
352 |
353 | mult = 2 ** n_downsampling
354 | for i in range(n_blocks): # add ResNet blocks
355 |
356 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
357 |
358 | for i in range(n_downsampling): # add upsampling layers
359 | mult = 2 ** (n_downsampling - i)
360 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
361 | kernel_size=3, stride=2,
362 | padding=1, output_padding=1,
363 | bias=use_bias),
364 | norm_layer(int(ngf * mult / 2)),
365 | nn.ReLU(True)]
366 | model += [nn.ReflectionPad2d(3)]
367 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
368 | model += [nn.Tanh()]
369 |
370 | self.model = nn.Sequential(*model)
371 |
372 | def forward(self, input):
373 | """Standard forward"""
374 | return self.model(input)
375 |
376 |
377 | class ResnetBlock(nn.Module):
378 | """Define a Resnet block"""
379 |
380 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
381 | """Initialize the Resnet block
382 |
383 | A resnet block is a conv block with skip connections
384 | We construct a conv block with build_conv_block function,
385 | and implement skip connections in the <forward> function.
386 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
387 | """
388 | super(ResnetBlock, self).__init__()
389 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
390 |
391 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
392 | """Construct a convolutional block.
393 |
394 | Parameters:
395 | dim (int) -- the number of channels in the conv layer.
396 | padding_type (str) -- the name of padding layer: reflect | replicate | zero
397 | norm_layer -- normalization layer
398 | use_dropout (bool) -- if use dropout layers.
399 | use_bias (bool) -- if the conv layer uses bias or not
400 |
401 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
402 | """
403 | conv_block = []
404 | p = 0
405 | if padding_type == 'reflect':
406 | conv_block += [nn.ReflectionPad2d(1)]
407 | elif padding_type == 'replicate':
408 | conv_block += [nn.ReplicationPad2d(1)]
409 | elif padding_type == 'zero':
410 | p = 1
411 | else:
412 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
413 |
414 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
415 | if use_dropout:
416 | conv_block += [nn.Dropout(0.5)]
417 |
418 | p = 0
419 | if padding_type == 'reflect':
420 | conv_block += [nn.ReflectionPad2d(1)]
421 | elif padding_type == 'replicate':
422 | conv_block += [nn.ReplicationPad2d(1)]
423 | elif padding_type == 'zero':
424 | p = 1
425 | else:
426 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
427 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
428 |
429 | return nn.Sequential(*conv_block)
430 |
431 | def forward(self, x):
432 | """Forward function (with skip connections)"""
433 | out = x + self.conv_block(x) # add skip connections
434 | return out
435 |
436 |
437 | class UnetGenerator(nn.Module):
438 | """Create a Unet-based generator"""
439 |
440 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
441 | """Construct a Unet generator
442 | Parameters:
443 | input_nc (int) -- the number of channels in input images
444 | output_nc (int) -- the number of channels in output images
445 | num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
446 | an image of size 128x128 will become of size 1x1 at the bottleneck
447 | ngf (int) -- the number of filters in the last conv layer
448 | norm_layer -- normalization layer
449 |
450 | We construct the U-Net from the innermost layer to the outermost layer.
451 | It is a recursive process.
452 | """
453 | super(UnetGenerator, self).__init__()
454 | # construct unet structure
455 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
456 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
457 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
458 | # gradually reduce the number of filters from ngf * 8 to ngf
459 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
460 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
461 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
462 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
463 |
464 | def forward(self, input):
465 | """Standard forward"""
466 | return self.model(input)
467 |
468 |
469 | class UnetSkipConnectionBlock(nn.Module):
470 | """Defines the Unet submodule with skip connection.
471 | X -------------------identity----------------------
472 | |-- downsampling -- |submodule| -- upsampling --|
473 | """
474 |
475 | def __init__(self, outer_nc, inner_nc, input_nc=None,
476 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
477 | """Construct a Unet submodule with skip connections.
478 |
479 | Parameters:
480 | outer_nc (int) -- the number of filters in the outer conv layer
481 | inner_nc (int) -- the number of filters in the inner conv layer
482 | input_nc (int) -- the number of channels in input images/features
483 | submodule (UnetSkipConnectionBlock) -- previously defined submodules
484 | outermost (bool) -- if this module is the outermost module
485 | innermost (bool) -- if this module is the innermost module
486 | norm_layer -- normalization layer
487 | use_dropout (bool) -- if use dropout layers.
488 | """
489 | super(UnetSkipConnectionBlock, self).__init__()
490 | self.outermost = outermost
491 | if type(norm_layer) == functools.partial:
492 | use_bias = norm_layer.func == nn.InstanceNorm2d
493 | else:
494 | use_bias = norm_layer == nn.InstanceNorm2d
495 | if input_nc is None:
496 | input_nc = outer_nc
497 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
498 | stride=2, padding=1, bias=use_bias)
499 | downrelu = nn.LeakyReLU(0.2, True)
500 | downnorm = norm_layer(inner_nc)
501 | uprelu = nn.ReLU(True)
502 | upnorm = norm_layer(outer_nc)
503 |
504 | if outermost:
505 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
506 | kernel_size=4, stride=2,
507 | padding=1)
508 | down = [downconv]
509 | up = [uprelu, upconv, nn.Tanh()]
510 | model = down + [submodule] + up
511 | elif innermost:
512 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
513 | kernel_size=4, stride=2,
514 | padding=1, bias=use_bias)
515 | down = [downrelu, downconv]
516 | up = [uprelu, upconv, upnorm]
517 | model = down + up
518 | else:
519 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
520 | kernel_size=4, stride=2,
521 | padding=1, bias=use_bias)
522 | down = [downrelu, downconv, downnorm]
523 | up = [uprelu, upconv, upnorm]
524 |
525 | if use_dropout:
526 | model = down + [submodule] + up + [nn.Dropout(0.5)]
527 | else:
528 | model = down + [submodule] + up
529 |
530 | self.model = nn.Sequential(*model)
531 |
532 | def forward(self, x):
533 | if self.outermost:
534 | return self.model(x)
535 | else: # add skip connections
536 | return torch.cat([x, self.model(x)], 1)
537 |
538 |
539 | class NLayerDiscriminator(nn.Module):
540 | """Defines a PatchGAN discriminator"""
541 |
542 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
543 | """Construct a PatchGAN discriminator
544 |
545 | Parameters:
546 | input_nc (int) -- the number of channels in input images
547 | ndf (int) -- the number of filters in the last conv layer
548 | n_layers (int) -- the number of conv layers in the discriminator
549 | norm_layer -- normalization layer
550 | """
551 | super(NLayerDiscriminator, self).__init__()
552 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
553 | use_bias = norm_layer.func == nn.InstanceNorm2d
554 | else:
555 | use_bias = norm_layer == nn.InstanceNorm2d
556 |
557 | kw = 4
558 | padw = 1
559 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
560 | nf_mult = 1
561 | nf_mult_prev = 1
562 | for n in range(1, n_layers): # gradually increase the number of filters
563 | nf_mult_prev = nf_mult
564 | nf_mult = min(2 ** n, 8)
565 | sequence += [
566 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
567 | norm_layer(ndf * nf_mult),
568 | nn.LeakyReLU(0.2, True)
569 | ]
570 |
571 | nf_mult_prev = nf_mult
572 | nf_mult = min(2 ** n_layers, 8)
573 | sequence += [
574 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
575 | norm_layer(ndf * nf_mult),
576 | nn.LeakyReLU(0.2, True)
577 | ]
578 |
579 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
580 | self.model = nn.Sequential(*sequence)
581 |
582 | def forward(self, input):
583 | """Standard forward."""
584 | return self.model(input)
585 |
586 |
587 | class PixelDiscriminator(nn.Module):
588 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
589 |
590 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
591 | """Construct a 1x1 PatchGAN discriminator
592 |
593 | Parameters:
594 | input_nc (int) -- the number of channels in input images
595 | ndf (int) -- the number of filters in the last conv layer
596 | norm_layer -- normalization layer
597 | """
598 | super(PixelDiscriminator, self).__init__()
599 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
600 | use_bias = norm_layer.func == nn.InstanceNorm2d
601 | else:
602 | use_bias = norm_layer == nn.InstanceNorm2d
603 |
604 | self.net = [
605 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
606 | nn.LeakyReLU(0.2, True),
607 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
608 | norm_layer(ndf * 2),
609 | nn.LeakyReLU(0.2, True),
610 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
611 |
612 | self.net = nn.Sequential(*self.net)
613 |
614 | def forward(self, input):
615 | """Standard forward."""
616 | return self.net(input)
617 |
--------------------------------------------------------------------------------
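As a quick sanity check of the factory functions above, the sketch below builds a CPU-only generator/discriminator pair and pushes random data through them; the shape comments follow from the unet_256 generator and the 3-layer PatchGAN, and the import path assumes the package is installed.

import torch

from intrinsic_compositing.albedo.model.pix2pix.models import networks

G = networks.define_G(input_nc=3, output_nc=3, ngf=64, netG='unet_256',
                      norm='instance', gpu_ids=[])
D = networks.define_D(input_nc=6, ndf=64, netD='basic', norm='instance', gpu_ids=[])

real = torch.rand(1, 3, 256, 256)
fake = G(real)                              # (1, 3, 256, 256), tanh output in [-1, 1]
pred = D(torch.cat([real, fake], dim=1))    # (1, 1, 30, 30) patch-wise realism map

criterion = networks.GANLoss('lsgan')
loss_D_fake = criterion(pred, target_is_real=False).mean()
print(fake.shape, pred.shape, float(loss_D_fake))
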
/intrinsic_compositing/albedo/pipeline.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import os
3 | from skimage.transform import resize
4 | import torch
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | from PIL import Image
8 | import torch.utils.data as data
9 | from argparse import Namespace
10 |
11 | from chrislib.general import np_to_pil
12 |
13 | from intrinsic_compositing.albedo.model.editingnetwork_trainer import EditingNetworkTrainer
14 |
15 | PAPER_WEIGHTS_URL = 'https://github.com/compphoto/IntrinsicCompositing/releases/download/1.0.0/albedo_paper_weights.pth'
16 | CACHE_PATH = torch.hub.get_dir()
17 |
18 |
19 | def get_transform(opt, grayscale=False, method=Image.BICUBIC, convert=True):
20 | transform_list = []
21 | if grayscale:
22 | transform_list.append(transforms.Grayscale(1))
23 | method=Image.BILINEAR
24 | if 'resize' in opt['preprocess']:
25 | osize = [opt['load_size'], opt['load_size']]
26 | transform_list.append(transforms.Resize(osize, method))
27 |
28 | if 'crop' in opt['preprocess']:
29 | transform_list.append(transforms.RandomCrop(opt['crop_size']))
30 |
31 | if not opt['no_flip']:
32 | transform_list.append(transforms.RandomHorizontalFlip())
33 |
34 | if convert:
35 | transform_list += [transforms.ToTensor()]
36 | # if grayscale:
37 | # transform_list += [transforms.Normalize((0.5,), (0.5,))]
38 | # else:
39 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
40 | return transforms.Compose(transform_list)
41 |
42 | def match_scalar(source, target, mask=None, min_percentile=0, max_percentile=100):
43 |
44 | if mask is None:
45 | # mask = np.ones((self.args.crop_size, self.args.crop_size), dtype=bool)
46 | mask = np.ones((384, 384), dtype=bool)
47 |
48 | target_masked = target[:,mask]
49 |
50 |
51 | # consider all values up to a percentile
52 | p_threshold_min = np.percentile(target_masked.reshape(-1), min_percentile)
53 | p_mask_min = np.greater_equal(target, p_threshold_min)
54 |
55 | p_threshold_max = np.percentile(target_masked.reshape(-1), max_percentile)
56 | p_mask_max = np.less_equal(target, p_threshold_max)
57 |
58 | p_mask = np.logical_and(p_mask_max, p_mask_min)
59 | mask = np.logical_and(p_mask, mask)
60 |
61 | flat_source = source[mask]
62 | flat_target = target[mask]
63 |
64 | scalar, _, _, _ = np.linalg.lstsq(
65 | flat_source.reshape(-1, 1), flat_target.reshape(-1, 1), rcond=None)
66 | source_scaled = source * scalar
67 | return source_scaled, scalar
68 |
69 | def prep_input(rgb_img, mask_img, shading_img, reproduce_paper=False):
70 | # this function takes the sRGB image (rgb_img), the mask,
71 | # and the shading as numpy arrays in [0, 1]
72 |
73 | opt = {}
74 | opt['load_size'] = 384
75 | opt['crop_size'] = 384
76 | opt['preprocess'] = 'resize'
77 | opt['no_flip'] = True
78 |
79 | rgb_transform = get_transform(opt, grayscale=False)
80 | mask_transform = get_transform(opt, grayscale=True)
81 | shading_transform = get_transform(opt, grayscale=False, method=Image.BILINEAR)
82 |
83 | mask_img_np = mask_img
84 | if len(mask_img_np.shape) == 3:
85 | mask_img_np = mask_img_np[:, :, 0]
86 |
87 | if len(shading_img.shape) == 3:
88 | shading_img = shading_img[:, :, 0]
89 |
90 | full_shd = ((1.0 / (shading_img)) - 1.0)
91 | full_msk = resize(mask_img_np, full_shd.shape)
92 | # full_msk = resize(np.array(mask_img) / 255., full_shd.shape)
93 | full_img = resize(rgb_img, full_shd.shape)
94 | full_alb = (full_img ** 2.2) / full_shd[:, :, None].clip(1e-4)
95 | full_alb = full_alb.clip(1e-4) ** (1/2.2)
96 |
97 | full_alb = torch.from_numpy(full_alb).permute(2, 0, 1)
98 | full_shd = torch.from_numpy(full_shd).unsqueeze(0)
99 | full_msk = torch.from_numpy(full_msk).unsqueeze(0)
100 |
101 | srgb = rgb_transform(np_to_pil(rgb_img))
102 | rgb_mask = np.stack([mask_img_np] * 3, -1)
103 | mask = mask_transform(np_to_pil(rgb_mask))
104 |
105 | if reproduce_paper:
106 | invshading = shading_transform(np_to_pil(shading_img)) / (2**16-1)
107 | else:
108 | invshading = shading_transform(np_to_pil(shading_img))
109 |
110 | shading = ((1.0 / invshading) - 1.0)
111 |
112 | ## compute albedo
113 | rgb = srgb ** 2.2
114 | albedo = rgb / shading
115 |
116 | ## min max normalize the albedo:
117 | # albedo = (albedo - albedo.min()) / (albedo.max() - albedo.min())
118 |
119 | ## match the albedo to the rgb
120 | albedo, scalar = match_scalar(albedo.numpy(),rgb.numpy())
121 | albedo = torch.from_numpy(albedo)
122 | albedo = albedo ** (1/2.2)
123 |
124 | shading = shading / scalar
125 | ## clip albedo to [0,1]
126 | albedo = torch.clamp(albedo,0,1)
127 |
128 |
129 | # all of these need to have a batch dimension as if they are coming from the dataloader
130 | return {
131 | 'srgb': srgb.unsqueeze(0),
132 | 'mask': mask.unsqueeze(0),
133 | 'albedo':albedo.unsqueeze(0),
134 | 'shading': shading.unsqueeze(0),
135 | 'albedo_full' : full_alb.unsqueeze(0),
136 | 'shading_full' : full_shd.unsqueeze(0),
137 | 'mask_full' : full_msk.unsqueeze(0)
138 | }
139 |
140 |
141 | def load_albedo_harmonizer():
142 |
143 | args = Namespace()
144 | args.nops = 4
145 | args.gpu_ids = [0]
146 | args.blursharpen = 0
147 | args.fake_gen_lowdev = 0
148 | args.bn_momentum = 0.01
149 | args.edit_loss = ''
150 | args.loss_relu_bias = 0
151 | args.crop_size = 384
152 | args.load_size = 384
153 | args.lr_d = 0.00001
154 | args.lr_editnet = 0.00001
155 | args.batch_size = 1
156 |
157 | args.checkpoint_load_path = f'{CACHE_PATH}/albedo_harmonization/albedo_paper_weights.pth'
158 |
159 | if not os.path.exists(args.checkpoint_load_path):
160 | os.makedirs(f'{CACHE_PATH}/albedo_harmonization', exist_ok=True)
161 | os.system(f'wget {PAPER_WEIGHTS_URL} -P {CACHE_PATH}/albedo_harmonization')
162 |
163 | trainer = EditingNetworkTrainer(args)
164 |
165 | return trainer
166 |
167 | def harmonize_albedo(img, shd, msk, trainer, reproduce_paper=False):
168 |
169 | trainer.setEval()
170 |
171 | data = prep_input(img, shd, msk, reproduce_paper=reproduce_paper)
172 |
173 | trainer.setinput_HR(data)
174 |
175 | with torch.no_grad():
176 | trainer.forward()
177 |
178 | albedo_out = trainer.result[0,...].cpu().detach().numpy().squeeze().transpose([1,2,0])
179 | result = albedo_out.clip(0, 1)
180 | return result
181 |
182 |
--------------------------------------------------------------------------------
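A minimal usage sketch for the two public helpers above. The file names are hypothetical, the inputs are assumed to be float arrays in [0, 1] prepared the same way as in inference/inference.py (not shown here), and the first call to load_albedo_harmonizer() downloads the paper weights into the torch hub cache; the package and its chrislib dependency must be installed.

import numpy as np
from skimage.io import imread

from intrinsic_compositing.albedo.pipeline import load_albedo_harmonizer, harmonize_albedo

# hypothetical input files; all arrays are floats in [0, 1]
img = imread('composite.png')[:, :, :3] / 255.0  # sRGB composite of foreground and background
shd = np.load('inv_shading.npy')                 # inverse shading estimate for the composite
msk = imread('mask.png') / 255.0                 # foreground mask

trainer = load_albedo_harmonizer()                 # builds the EditingNetworkTrainer with the paper weights
albedo = harmonize_albedo(img, shd, msk, trainer)  # 3-channel harmonized albedo, clipped to [0, 1]
print(albedo.shape)
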
/intrinsic_compositing/albedo/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/utils/__init__.py
--------------------------------------------------------------------------------
/intrinsic_compositing/albedo/utils/datautils.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as transforms
2 | from PIL import Image
3 | import torch.utils.data as data
4 |
5 |
6 | def get_transform(opt, grayscale=False, method=Image.BICUBIC, convert=True):
7 | transform_list = []
8 | if grayscale:
9 | transform_list.append(transforms.Grayscale(1))
10 | method=Image.BILINEAR
11 | if 'resize' in opt['preprocess']:
12 | osize = [opt['load_size'], opt['load_size']]
13 | transform_list.append(transforms.Resize(osize, method))
14 |
15 | if 'crop' in opt['preprocess']:
16 | transform_list.append(transforms.RandomCrop(opt['crop_size']))
17 |
18 | if not opt['no_flip']:
19 | transform_list.append(transforms.RandomHorizontalFlip())
20 |
21 | if convert:
22 | transform_list += [transforms.ToTensor()]
23 | # if grayscale:
24 | # transform_list += [transforms.Normalize((0.5,), (0.5,))]
25 | # else:
26 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
27 | return transforms.Compose(transform_list)
--------------------------------------------------------------------------------
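get_transform takes a plain dict of preprocessing options rather than an argparse namespace; below is a small usage sketch with hypothetical option values and file name.

from PIL import Image

from intrinsic_compositing.albedo.utils.datautils import get_transform

opt = {'preprocess': 'resize', 'load_size': 384, 'crop_size': 384, 'no_flip': True}

rgb_tf  = get_transform(opt, grayscale=False)  # Resize(384, bicubic) + ToTensor
mask_tf = get_transform(opt, grayscale=True)   # Grayscale(1) + Resize(384, bilinear) + ToTensor

img = Image.open('example.jpeg').convert('RGB')  # hypothetical input image
tensor = rgb_tf(img)                             # shape (3, 384, 384), values in [0, 1]
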
/intrinsic_compositing/albedo/utils/depthutils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | curr_path = "/localhome/smh31/Repositories/intrinsic_composite/realismnet/model/"
5 | # OUR
6 | from BoostingMonocularDepth.utils import ImageandPatchs, ImageDataset, generatemask, getGF_fromintegral, calculateprocessingres, rgb2gray,\
7 | applyGridpatch
8 |
9 | # MIDAS
10 | import BoostingMonocularDepth.midas.utils
11 | from BoostingMonocularDepth.midas.models.midas_net import MidasNet
12 | from BoostingMonocularDepth.midas.models.transforms import Resize, NormalizeImage, PrepareForNet
13 |
14 | # PIX2PIX : MERGE NET
15 | from BoostingMonocularDepth.pix2pix.options.test_options import TestOptions
16 | from BoostingMonocularDepth.pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
17 |
18 | import torch
19 | from torchvision.transforms import Compose
20 | from torchvision.transforms import transforms
21 |
22 | import time
23 | import os
24 | import cv2
25 | import numpy as np
26 | import argparse
27 | from argparse import Namespace
28 | import warnings
29 |
30 | whole_size_threshold = 3000 # R_max from the paper
31 | GPU_threshold = 1600 - 32 # Limit for the GPU (NVIDIA RTX 2080), can be adjusted
32 |
33 | def create_depth_models(device='cuda', midas_path=None, pix2pix_path=None):
34 |
35 | # opt = TestOptions().parse()
36 | opt = Namespace(Final=False, R0=False, R20=False, aspect_ratio=1.0, batch_size=1, checkpoints_dir=f'{curr_path}/BoostingMonocularDepth/pix2pix/checkpoints', colorize_results=False, crop_size=672, data_dir=None, dataroot=None, dataset_mode='depthmerge', depthNet=None, direction='AtoB', display_winsize=256, epoch='latest', eval=False, generatevideo=None, gpu_ids=[0], init_gain=0.02, init_type='normal', input_nc=2, isTrain=False, load_iter=0, load_size=672, max_dataset_size=10000, max_res=float('inf'), model='pix2pix4depth', n_layers_D=3, name='mergemodel', ndf=64, netD='basic', netG='unet_1024', net_receptive_field_size=None, ngf=64, no_dropout=False, no_flip=False, norm='none', num_test=50, num_threads=4, output_dir=None, output_nc=1, output_resolution=None, phase='test', pix2pixsize=None, preprocess='resize_and_crop', savecrops=None, savewholeest=None, serial_batches=False, suffix='', verbose=False)
37 | # opt = Namespace()
38 | # opt.gpu_ids = [0]
39 | # opt.isTrain = False
40 | # global pix2pixmodel
41 |
42 | pix2pixmodel = Pix2Pix4DepthModel(opt)
43 |
44 |     if pix2pix_path is None:
45 |         pix2pixmodel.save_dir = f'{curr_path}/BoostingMonocularDepth/pix2pix/checkpoints/mergemodel'
46 |     else:
47 |         pix2pixmodel.save_dir = pix2pix_path
48 |
49 | pix2pixmodel.load_networks('latest')
50 | pix2pixmodel.eval()
51 |
52 |     if midas_path is None:
53 | midas_model_path = f"{curr_path}/BoostingMonocularDepth/midas/model.pt"
54 | else:
55 | midas_model_path = midas_path
56 |
57 | # global midasmodel
58 | midasmodel = MidasNet(midas_model_path, non_negative=True)
59 | midasmodel.to(device)
60 | midasmodel.eval()
61 |
62 | return [pix2pixmodel, midasmodel]
63 |
64 |
65 | def get_depth(img, models, threshold=0.2):
66 |
67 | pix2pixmodel, midasmodel = models
68 |
69 |     # Generate the mask used to smoothly blend the local patch estimations into the base estimate.
70 | # It is arbitrarily large to avoid artifacts during rescaling for each crop.
71 | mask_org = generatemask((3000, 3000))
72 | mask = mask_org.copy()
73 |
74 |     # Value x of R_x defined in section 5 of the main paper.
75 | r_threshold_value = threshold
76 |
77 | # print("start processing")
78 |
79 | input_resolution = img.shape
80 |
81 | scale_threshold = 3 # Allows up-scaling with a scale up to 3
82 |
83 |     # Find the best input resolution R_x. The resolution search is described in section 5 (double estimation)
84 |     # of the main paper and section B of the supplementary material.
85 | whole_image_optimal_size, patch_scale = calculateprocessingres(img, 384,
86 | r_threshold_value, scale_threshold,
87 | whole_size_threshold)
88 |
89 | # print('\t wholeImage being processed in :', whole_image_optimal_size)
90 |
91 | # Generate the base estimate using the double estimation.
92 | whole_estimate = doubleestimate(img, 384, whole_image_optimal_size, 1024, pix2pixmodel, midasmodel)
93 | whole_estimate = cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
94 |
95 | return whole_estimate
96 |
97 |
98 | # Generate a double-input depth estimation
99 | def doubleestimate(img, size1, size2, pix2pixsize, pix2pixmodel, midasmodel):
100 | # Generate the low resolution estimation
101 | estimate1 = singleestimate(img, size1, midasmodel)
102 | # Resize to the inference size of merge network.
103 | estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
104 |
105 | # Generate the high resolution estimation
106 | estimate2 = singleestimate(img, size2, midasmodel)
107 | # Resize to the inference size of merge network.
108 | estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
109 |
110 | # Inference on the merge model
111 | pix2pixmodel.set_input(estimate1, estimate2)
112 | pix2pixmodel.test()
113 | visuals = pix2pixmodel.get_current_visuals()
114 | prediction_mapped = visuals['fake_B']
115 | prediction_mapped = (prediction_mapped+1)/2
116 | prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / (
117 | torch.max(prediction_mapped) - torch.min(prediction_mapped))
118 | prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
119 |
120 | return prediction_mapped
121 |
122 |
123 | # Generate a single-input depth estimation
124 | def singleestimate(img, msize, midasmodel):
125 | if msize > GPU_threshold:
126 | # print(" \t \t DEBUG| GPU THRESHOLD REACHED", msize, '--->', GPU_threshold)
127 | msize = GPU_threshold
128 |
129 | return estimatemidas(img, midasmodel, msize)
130 | # elif net_type == 1:
131 | # return estimatesrl(img, msize)
132 | # elif net_type == 2:
133 | # return estimateleres(img, msize)
134 |
135 |
136 | def estimatemidas(img, midasmodel, msize, device='cuda'):
137 |     # MiDaS v2 forward pass script adapted from https://github.com/intel-isl/MiDaS/tree/v2
138 |
139 | transform = Compose(
140 | [
141 | Resize(
142 | msize,
143 | msize,
144 | resize_target=None,
145 | keep_aspect_ratio=True,
146 | ensure_multiple_of=32,
147 | resize_method="upper_bound",
148 | image_interpolation_method=cv2.INTER_CUBIC,
149 | ),
150 | NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
151 | PrepareForNet(),
152 | ]
153 | )
154 |
155 | img_input = transform({"image": img})["image"]
156 |
157 | # Forward pass
158 | with torch.no_grad():
159 | sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
160 | prediction = midasmodel.forward(sample)
161 |
162 | prediction = prediction.squeeze().cpu().numpy()
163 | prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
164 |
165 | # Normalization
166 | depth_min = prediction.min()
167 | depth_max = prediction.max()
168 |
169 | if depth_max - depth_min > np.finfo("float").eps:
170 | prediction = (prediction - depth_min) / (depth_max - depth_min)
171 | else:
172 | prediction = 0
173 |
174 | return prediction
175 |
--------------------------------------------------------------------------------
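A rough usage sketch for the boosted-depth helpers above (illustrative, not part of the repository). It assumes the BoostingMonocularDepth code and weights referenced by `curr_path` are importable; the checkpoint paths below are placeholders.

import cv2
import numpy as np
from intrinsic_compositing.albedo.utils.depthutils import create_depth_models, get_depth

# Load an RGB image as float32 in [0, 1], as expected by the MiDaS transform.
img = cv2.cvtColor(cv2.imread('inference/examples/cone_chair/bg.jpeg'), cv2.COLOR_BGR2RGB)
img = img.astype(np.float32) / 255.0

# Placeholder checkpoint locations; if omitted, the hard-coded curr_path defaults are used.
models = create_depth_models(
    device='cuda',
    midas_path='/path/to/midas/model.pt',
    pix2pix_path='/path/to/pix2pix/checkpoints/mergemodel',
)

depth = get_depth(img, models)   # float array in [0, 1] with the input's H x W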
/intrinsic_compositing/albedo/utils/edits.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from kornia.color import rgb_to_hsv, hsv_to_rgb
4 |
5 | import torchvision.transforms.functional as F
6 | # from ..argumentsparser import args
7 |
8 | COLORCURVE_L = 8
9 |
10 | def apply_whitebalancing(input, parameters):
11 | param = parameters['whitebalancing']
12 | param = param / (param[:,1:2] + 1e-9)
13 | result = input / (param[:,:,None,None] + 1e-9)
14 | return result
15 |
16 | def apply_colorcurve(input, parameters):
17 | color_curve_param = torch.reshape(parameters['colorcurve'],(-1,3,COLORCURVE_L))
18 | color_curve_sum = torch.sum(color_curve_param,dim=[2])
19 | total_image = torch.zeros_like(input)
20 | for i in range(COLORCURVE_L):
21 | total_image += torch.clip(input * COLORCURVE_L - i, 0, 1) * color_curve_param[:,:,i][:,:,None,None]
22 | result = total_image / (color_curve_sum[:,:,None,None] + 1e-9)
23 | return result
24 |
25 | def apply_saturation(input, parameters):
26 | hsv = rgb_to_hsv(input)
27 | param = parameters['saturation'][:,:,None,None]
28 | s_new = hsv[:,1:2,:,:] * param
29 | hsv_new = hsv.clone()
30 | hsv_new[:,1:2,:,:] = s_new
31 | result = hsv_to_rgb(hsv_new)
32 | return result
33 |
34 | def apply_exposure(input, parameters):
35 | result = input * parameters['exposure'][:,:,None,None]
36 | return result
37 |
38 |
39 | def apply_blur(input, parameters):
40 | sigma = parameters['blur'][:,:,None,None]
41 | kernelsize = 2*torch.ceil(2*sigma)+1.
42 |
43 | result = torch.zeros_like(input)
44 | for B in range(input.shape[0]):
45 | kernelsize_ = (int(kernelsize[B].item()), int(kernelsize[B].item()))
46 | sigma_ = (sigma[B].item(), sigma[B].item())
47 | result[B,:,:,:] = F.gaussian_blur(input[B:B+1,:,:,:], kernelsize_, sigma_)
48 | return result
49 |
50 | def apply_sharpness(input, parameters):
51 | param = parameters['sharpness'][:,0]
52 | result = torch.zeros_like(input)
53 |
54 | for B in range(input.shape[0]):
55 | result[B,:,:,:] = F.adjust_sharpness(input[B:B+1,:,:,:], param[B])
56 | return result
57 |
58 | def get_edits(nops):
59 | if nops == 4:
60 | return {
61 | 0: apply_whitebalancing,
62 | 1: apply_colorcurve,
63 | 2: apply_saturation,
64 | 3: apply_exposure,
65 | }
66 | else:
67 | return {
68 | 0: apply_whitebalancing,
69 | 1: apply_colorcurve,
70 | 2: apply_saturation,
71 | 3: apply_exposure,
72 | 4: apply_blur,
73 | 5: apply_sharpness,
74 | }
75 |
--------------------------------------------------------------------------------
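A small sketch showing how the parametric edits above can be chained (illustrative only). The parameter shapes follow how each function indexes its entry in the `parameters` dict; the values are arbitrary.

import torch
from intrinsic_compositing.albedo.utils.edits import get_edits, COLORCURVE_L

edits = get_edits(nops=4)           # whitebalancing, colorcurve, saturation, exposure
image = torch.rand(1, 3, 256, 256)  # (B, C, H, W), values in [0, 1]

parameters = {
    'whitebalancing': torch.tensor([[1.0, 1.0, 1.0]]),  # (B, 3), per-channel gains
    'colorcurve': torch.ones(1, 3 * COLORCURVE_L),      # (B, 3 * 8) curve segments
    'saturation': torch.tensor([[1.1]]),                # (B, 1)
    'exposure': torch.tensor([[1.2]]),                  # (B, 1)
}

out = image
for idx in range(len(edits)):
    out = edits[idx](out, parameters)   # apply each edit in order
out = out.clamp(0, 1)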
/intrinsic_compositing/albedo/utils/networkutils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | def loadmodelweights(net, load_path, device):
7 | if isinstance(net, torch.nn.DataParallel):
8 | net = net.module
9 | print('loading the model from %s' % load_path)
10 | # if you are using PyTorch newer than 0.4 (e.g., built from
11 |     # GitHub source), you can remove str() on the device argument
12 | state_dict = torch.load(load_path, map_location=str(device))
13 | if hasattr(state_dict, '_metadata'):
14 | del state_dict._metadata
15 | net.load_state_dict(state_dict)
16 |
17 |
18 | def init_net(net, gpu_ids=[]):
19 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support);
20 | """
21 | if len(gpu_ids) > 0:
22 | assert(torch.cuda.is_available())
23 | net.to(gpu_ids[0])
24 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
25 | return net
26 |
27 | def _calc_same_pad(i, k, s, d):
28 | return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
29 |
30 |
31 | def _same_pad_arg(input_size, kernel_size, stride, dilation):
32 | ih, iw = input_size
33 | kh, kw = kernel_size
34 | pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
35 | pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
36 | return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
37 |
38 | class Conv2dSameExport(nn.Conv2d):
39 |
40 | """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
41 | """
42 |
43 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
44 | padding=0, dilation=1, groups=1, bias=True):
45 | super(Conv2dSameExport, self).__init__(
46 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
47 | self.pad = None
48 | self.pad_input_size = (0, 0)
49 |
50 | def forward(self, x):
51 | input_size = x.size()[-2:]
52 | if self.pad is None:
53 | pad_arg = _same_pad_arg(
54 | input_size, self.weight.size()[-2:], self.stride, self.dilation)
55 | self.pad = nn.ZeroPad2d(pad_arg)
56 | self.pad_input_size = input_size
57 | else:
58 | assert self.pad_input_size == input_size
59 |
60 | x = self.pad(x)
61 | return F.conv2d(
62 | x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
--------------------------------------------------------------------------------
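A brief sketch of the 'SAME'-padding wrapper above (illustrative only): with stride 1 the spatial size is preserved regardless of kernel size, and `init_net` optionally moves the module to the GPU and wraps it in DataParallel.

import torch
from intrinsic_compositing.albedo.utils.networkutils import Conv2dSameExport, init_net

conv = Conv2dSameExport(in_channels=3, out_channels=16, kernel_size=5, stride=1)
conv = init_net(conv, gpu_ids=[])   # CPU here; pass e.g. [0] to use a single GPU

x = torch.rand(1, 3, 224, 224)
y = conv(x)
print(y.shape)   # torch.Size([1, 16, 224, 224]) -- spatial size preserved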
/intrinsic_compositing/albedo/utils/utils.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | def normalize(x):
4 | return (x - 0.5) / 0.5
5 |
6 | def create_exp_name(args, prefix='RealismNet'):
7 |
8 | components = []
9 | components.append(f"{prefix}")
10 | components.append(f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}")
11 | components.append(f"{args.expdscp}")
12 | components.append(f"lrd_{args.lr_d}")
13 |
14 | name = "_".join(components)
15 | return name
16 |
17 | def create_exp_name_editnet(args, prefix='EditingNet'):
18 |
19 | components = []
20 | components.append(f"{prefix}")
21 | components.append(f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}")
22 | components.append(f"{args.expdscp}")
23 | components.append(f"lredit_{args.lr_editnet}")
24 | components.append(f"rlubias_{args.loss_relu_bias}")
25 | components.append(f"edtloss_{args.edit_loss}")
26 | components.append(f"fkgnlwdv_{args.fake_gen_lowdev}")
27 | components.append(f"blrshrpn_{args.blursharpen}")
28 |
29 | name = "_".join(components)
30 | return name
--------------------------------------------------------------------------------
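A tiny sketch of the experiment-naming helper above (illustrative only); the attributes on `args` mirror the ones the function reads and the values are hypothetical.

from argparse import Namespace
from intrinsic_compositing.albedo.utils.utils import create_exp_name

args = Namespace(expdscp='baseline', lr_d=2e-4)   # hypothetical training arguments
print(create_exp_name(args))
# e.g. RealismNet_01_31_2024_12_00_00_baseline_lrd_0.0002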
/intrinsic_compositing/shading/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/shading/__init__.py
--------------------------------------------------------------------------------
/intrinsic_compositing/shading/pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import Adam
3 |
4 | import numpy as np
5 |
6 | from skimage.transform import resize
7 |
8 | from chrislib.general import uninvert, invert, round_32, view
9 |
10 | from altered_midas.midas_net import MidasNet
11 |
12 | def load_reshading_model(path, device='cuda'):
13 |
14 | if path == 'paper_weights':
15 | state_dict = torch.hub.load_state_dict_from_url('https://github.com/compphoto/IntrinsicCompositing/releases/download/1.0.0/shading_paper_weights.pt', map_location=device, progress=True)
16 | elif path == 'further_trained':
17 | state_dict = torch.hub.load_state_dict_from_url('https://github.com/compphoto/IntrinsicCompositing/releases/download/1.0.0/further_trained.pt', map_location=device, progress=True)
18 | else:
19 | state_dict = torch.load(path)
20 |
21 | shd_model = MidasNet(input_channels=9)
22 | shd_model.load_state_dict(state_dict)
23 | shd_model = shd_model.eval()
24 | shd_model = shd_model.to(device)
25 |
26 | return shd_model
27 |
28 | def spherical2cart(r, theta, phi):
29 | return [
30 | r * torch.sin(theta) * torch.cos(phi),
31 | r * torch.sin(theta) * torch.sin(phi),
32 | r * torch.cos(theta)
33 | ]
34 |
35 | def run_optimization(params, A, b):
36 |
37 | optim = Adam([params], lr=0.01)
38 | prev_loss = 1000
39 |
40 | init_params = params.clone()
41 |
42 | for i in range(500):
43 | optim.zero_grad()
44 |
45 | x, y, z = spherical2cart(params[2], params[0], params[1])
46 |
47 | dir_shd = (A[:, 0] * x) + (A[:, 1] * y) + (A[:, 2] * z)
48 | pred_shd = dir_shd + params[3]
49 |
50 | loss = torch.nn.functional.mse_loss(pred_shd.reshape(-1), b)
51 |
52 | loss.backward()
53 |
54 | optim.step()
55 |
56 | # theta can range from 0 -> pi/2 (0 to 90 degrees)
57 | # phi can range from 0 -> 2pi (0 to 360 degrees)
58 | with torch.no_grad():
59 | if params[0] < 0:
60 | params[0] = 0
61 |
62 | if params[0] > np.pi / 2:
63 | params[0] = np.pi / 2
64 |
65 | if params[1] < 0:
66 | params[1] = 0
67 |
68 | if params[1] > 2 * np.pi:
69 | params[1] = 2 * np.pi
70 |
71 | if params[2] < 0:
72 | params[2] = 0
73 |
74 | if params[3] < 0.1:
75 | params[3] = 0.1
76 |
77 | delta = prev_loss - loss
78 |
79 | if delta < 0.0001:
80 | break
81 |
82 | prev_loss = loss
83 |
84 | return loss, params
85 |
86 | def test_init(params, A, b):
87 | x, y, z = spherical2cart(params[2], params[0], params[1])
88 |
89 | dir_shd = (A[:, 0] * x) + (A[:, 1] * y) + (A[:, 2] * z)
90 | pred_shd = dir_shd + params[3]
91 |
92 | loss = torch.nn.functional.mse_loss(pred_shd.reshape(-1), b)
93 | return loss
94 |
95 | def get_light_coeffs(shd, nrm, img, mask=None, bias=True):
96 | img = resize(img, shd.shape)
97 |
98 | reg_shd = uninvert(shd)
99 | valid = (img.mean(-1) > 0.05) * (img.mean(-1) < 0.95)
100 |
101 | if mask is not None:
102 | valid *= (mask == 0)
103 |
104 | nrm = (nrm * 2.0) - 1.0
105 |
106 | A = nrm[valid == 1]
107 | # A = nrm.reshape(-1, 3)
108 | A /= np.linalg.norm(A, axis=1, keepdims=True)
109 |
110 | b = reg_shd[valid == 1]
111 | # b = reg_shd.reshape(-1)
112 |
113 |     # parameters are theta, phi, radius (r), and bias (c)
114 | A = torch.from_numpy(A)
115 | b = torch.from_numpy(b)
116 |
117 | min_init = 1000
118 | for t in np.arange(0, np.pi/2, 0.1):
119 | for p in np.arange(0, 2*np.pi, 0.25):
120 | params = torch.nn.Parameter(torch.tensor([t, p, 1, 0.5]))
121 | init_loss = test_init(params, A, b)
122 |
123 | if init_loss < min_init:
124 | best_init = params
125 | min_init = init_loss
126 | # print('new min:', min_init)
127 |
128 | loss, params = run_optimization(best_init, A, b)
129 |
130 | nrm_vis = nrm.copy()
131 | nrm_vis = draw_normal_circle(nrm_vis, (50, 50), 40)
132 |
133 | x, y, z = spherical2cart(params[2], params[0], params[1])
134 |
135 | coeffs = torch.tensor([x, y, z]).reshape(3, 1).detach().numpy()
136 | out_shd = (nrm_vis.reshape(-1, 3) @ coeffs) + params[3].item()
137 |
138 | coeffs = np.array([x.item(), y.item(), z.item(), params[3].item()])
139 |
140 | return coeffs, out_shd.reshape(shd.shape)
141 |
142 | def draw_normal_circle(nrm, loc, rad):
143 | size = rad * 2
144 |
145 | lin = np.linspace(-1, 1, num=size)
146 | ys, xs = np.meshgrid(lin, lin)
147 |
148 | zs = np.sqrt((1.0 - (xs**2 + ys**2)).clip(0))
149 | valid = (zs != 0)
150 | normals = np.stack((ys[valid], -xs[valid], zs[valid]), 1)
151 |
152 | valid_mask = np.zeros((size, size))
153 | valid_mask[valid] = 1
154 |
155 | full_mask = np.zeros((nrm.shape[0], nrm.shape[1]))
156 | x = loc[0] - rad
157 | y = loc[1] - rad
158 | full_mask[y : y + size, x : x + size] = valid_mask
159 | # nrm[full_mask > 0] = (normals + 1.0) / 2.0
160 | nrm[full_mask > 0] = normals
161 |
162 | return nrm
163 |
164 | def generate_shd(nrm, coeffs, msk, bias=True, viz=False):
165 |
166 | # if viz:
167 | # nrm = draw_normal_circle(nrm.copy(), (50, 50), 40)
168 |
169 | nrm = (nrm * 2.0) - 1.0
170 |
171 | A = nrm.reshape(-1, 3)
172 | A /= np.linalg.norm(A, axis=1, keepdims=True)
173 |
174 | A_fg = nrm[msk == 1]
175 | A_fg /= np.linalg.norm(A_fg, axis=1, keepdims=True)
176 |
177 | if bias:
178 | A = np.concatenate((A, np.ones((A.shape[0], 1))), 1)
179 | A_fg = np.concatenate((A_fg, np.ones((A_fg.shape[0], 1))), 1)
180 |
181 | inf_shd = (A_fg @ coeffs)
182 | inf_shd = inf_shd.clip(0) + 0.2
183 |
184 | if viz:
185 | shd_viz = (A @ coeffs).reshape(nrm.shape[:2])
186 | shd_viz = shd_viz.clip(0) + 0.2
187 | return inf_shd, shd_viz
188 |
189 |
190 | return inf_shd
191 |
192 | def compute_reshading(orig, msk, inv_shd, depth, normals, alb, coeffs, model):
193 |
194 | # expects no channel dim on msk, shd and depth
195 | if len(inv_shd.shape) == 3:
196 | inv_shd = inv_shd[:, :, 0]
197 |
198 | if len(msk.shape) == 3:
199 | msk = msk[:, :, 0]
200 |
201 | if len(depth.shape) == 3:
202 | depth = depth[:, :, 0]
203 |
204 | h, w, _ = orig.shape
205 |
206 | # max_dim = max(h, w)
207 | # if max_dim > 1024:
208 | # scale = 1024 / max_dim
209 | # else:
210 | # scale = 1.0
211 |
212 | orig = resize(orig, (round_32(h), round_32(w)))
213 | alb = resize(alb, (round_32(h), round_32(w)))
214 | msk = resize(msk, (round_32(h), round_32(w)))
215 | inv_shd = resize(inv_shd, (round_32(h), round_32(w)))
216 | dpt = resize(depth, (round_32(h), round_32(w)))
217 | nrm = resize(normals, (round_32(h), round_32(w)))
218 | msk = msk.astype(np.single)
219 |
220 | hard_msk = (msk > 0.5)
221 |
222 | reg_shd = uninvert(inv_shd)
223 | img = (alb * reg_shd[:, :, None]).clip(0, 1)
224 |
225 | orig_alb = orig / reg_shd[:, :, None].clip(1e-4)
226 |
227 | bad_shd_np = reg_shd.copy()
228 | inf_shd = generate_shd(nrm, coeffs, hard_msk)
229 | bad_shd_np[hard_msk == 1] = inf_shd
230 |
231 | bad_img_np = alb * bad_shd_np[:, :, None]
232 |
233 | sem_msk = torch.from_numpy(msk).unsqueeze(0)
234 | bad_img = torch.from_numpy(bad_img_np).permute(2, 0, 1)
235 | bad_shd = torch.from_numpy(invert(bad_shd_np)).unsqueeze(0)
236 | in_nrm = torch.from_numpy(nrm).permute(2, 0, 1)
237 | in_dpt = torch.from_numpy(dpt).unsqueeze(0)
238 | # inp = torch.cat((sem_msk, bad_img, bad_shd), dim=0).unsqueeze(0)
239 | inp = torch.cat((sem_msk, bad_img, bad_shd, in_nrm, in_dpt), dim=0).unsqueeze(0)
240 | inp = inp.cuda()
241 |
242 | with torch.no_grad():
243 | out = model(inp).squeeze()
244 |
245 | fin_shd = out.detach().cpu().numpy()
246 | fin_shd = uninvert(fin_shd)
247 | fin_img = alb * fin_shd[:, :, None]
248 |
249 | normals = resize(nrm, (h, w))
250 | fin_shd = resize(fin_shd, (h, w))
251 | fin_img = resize(fin_img, (h, w))
252 | bad_shd_np = resize(bad_shd_np, (h, w))
253 |
254 | result = {}
255 | result['reshading'] = fin_shd
256 | result['init_shading'] = bad_shd_np
257 | result['composite'] = (fin_img ** (1/2.2)).clip(0, 1)
258 | result['normals'] = normals
259 |
260 | return result
261 |
--------------------------------------------------------------------------------
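A rough end-to-end sketch for the reshading pipeline above (illustrative, not the repository's inference script). The intrinsic decomposition, normals, and depth are assumed to be precomputed elsewhere (e.g. as in inference/inference.py); random arrays stand in for them here, and a CUDA device is assumed.

import numpy as np
from intrinsic_compositing.shading.pipeline import (
    load_reshading_model, get_light_coeffs, compute_reshading,
)

model = load_reshading_model('paper_weights', device='cuda')

# Placeholder inputs -- in practice these come from the estimation networks used
# by inference/inference.py (intrinsic decomposition, normals, depth).
H, W = 384, 512
comp = np.random.rand(H, W, 3).astype(np.float32)   # composite image in [0, 1]
alb = np.random.rand(H, W, 3).astype(np.float32)    # albedo of the composite
inv_shd = np.random.rand(H, W).astype(np.float32)   # inverse shading of the composite
nrm = np.random.rand(H, W, 3).astype(np.float32)    # surface normals mapped to [0, 1]
dpt = np.random.rand(H, W).astype(np.float32)       # depth map
mask = np.zeros((H, W), dtype=np.float32)           # 1 on the inserted foreground object
mask[100:200, 150:300] = 1.0

# Fit a simple directional light (theta, phi, radius, bias) to the background region.
coeffs, shading_vis = get_light_coeffs(inv_shd, nrm, comp, mask=mask)

# Re-shade the composite so the foreground matches the estimated illumination.
result = compute_reshading(comp, mask, inv_shd, dpt, nrm, alb, coeffs, model)
harmonized = result['composite']   # (H, W, 3), gamma-corrected and clipped to [0, 1]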
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | setuptools.setup(
3 | name="intrinsic_compositing",
4 | version="0.0.1",
5 | author="Chris Careaga",
6 | author_email="chris_careaga@sfu.ca",
7 |     description='a package containing the code for the paper "Intrinsic Harmonization for Illumination-Aware Compositing"',
8 | url="",
9 | packages=setuptools.find_packages(),
10 | license="",
11 | python_requires=">3.6",
12 | install_requires=[
13 | 'altered_midas @ git+https://github.com/CCareaga/MiDaS@master',
14 | 'chrislib @ git+https://github.com/CCareaga/chrislib@main',
15 | 'omnidata_tools @ git+https://github.com/CCareaga/omnidata@main',
16 | 'boosted_depth @ git+https://github.com/CCareaga/BoostingMonocularDepth@main',
17 | 'intrinsic @ git+https://github.com/compphoto/intrinsic@d9741e99b2997e679c4055e7e1f773498b791288'
18 | ]
19 | )
20 |
--------------------------------------------------------------------------------