├── .gitignore ├── LICENSE ├── README.md ├── figures ├── IntrinsicCompositingVideo.jpg ├── astronauts.png ├── teaser2.jpg ├── teaser_comparison.jpg ├── teaser_pipeline.jpg └── user_study_comp.jpg ├── inference ├── examples │ ├── cone_chair │ │ ├── bg.jpeg │ │ ├── composite.png │ │ └── mask.png │ ├── lamp_candles │ │ ├── bg.jpeg │ │ ├── composite.png │ │ └── mask.png │ └── lamp_soap │ │ ├── bg.jpeg │ │ ├── composite.png │ │ └── mask.png └── inference.py ├── interface ├── examples │ ├── bgs │ │ ├── blue_chairs.jpeg │ │ ├── boxes.jpeg │ │ ├── classroom.jpeg │ │ ├── cone_org.jpeg │ │ ├── dim.jpeg │ │ ├── dock.jpeg │ │ ├── empty_room.jpeg │ │ ├── lamp.jpeg │ │ ├── museum2.jpeg │ │ ├── museum3.jpeg │ │ └── pillar.jpeg │ ├── fgs │ │ ├── astro.png │ │ ├── figurine.png │ │ ├── soap.png │ │ ├── white_bag.png │ │ ├── white_chair.png │ │ └── white_pot.png │ └── masks │ │ ├── astro.png │ │ ├── figurine.png │ │ ├── soap.png │ │ ├── white_bag.png │ │ ├── white_chair.png │ │ └── white_pot.png └── interface.py ├── intrinsic_compositing ├── __init__.py ├── albedo │ ├── __init__.py │ ├── model │ │ ├── MiDaS │ │ │ ├── __init__.py │ │ │ └── midas │ │ │ │ ├── __init__.py │ │ │ │ ├── base_model.py │ │ │ │ ├── blocks.py │ │ │ │ ├── dpt_depth.py │ │ │ │ ├── midas_net.py │ │ │ │ ├── midas_net_custom.py │ │ │ │ ├── transforms.py │ │ │ │ └── vit.py │ │ ├── __init__.py │ │ ├── discriminator.py │ │ ├── editingnetwork_trainer.py │ │ ├── parametermodel.py │ │ └── pix2pix │ │ │ ├── __init__.py │ │ │ └── models │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ └── networks.py │ ├── pipeline.py │ └── utils │ │ ├── __init__.py │ │ ├── datautils.py │ │ ├── depthutils.py │ │ ├── edits.py │ │ ├── networkutils.py │ │ └── utils.py └── shading │ ├── __init__.py │ └── pipeline.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 
8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023, Chris Careaga, S. Mahdi H. 
Miangoleh, Yağız Aksoy, Computational Photography Laboratory. All rights reserved. 2 | 3 | This software is for academic use only. A redistribution of this 4 | software, with or without modifications, has to be for academic 5 | use only, while giving the appropriate credit to the original 6 | authors of the software. The methods implemented as a part of 7 | this software may be covered under patents or patent applications. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED 10 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 11 | FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR 12 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 13 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 14 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 15 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 16 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 17 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Intrinsic Harmonization for Illumination-Aware Compositing 2 | Code for the paper: Intrinsic Harmonization for Illumination-Aware Compositing, [Chris Careaga](https://ccareaga.github.io), [S. Mahdi H. Miangoleh](https://miangoleh.github.io), [Yağız Aksoy](https://yaksoy.github.io), Proc.
SIGGRAPH Asia, 2023 3 | ### [Project Page](https://yaksoy.github.io/intrinsicCompositing) | [Paper](https://yaksoy.github.io/papers/SigAsia23-IntrinsicCompositing.pdf) | [Video](https://www.youtube.com/watch?v=M9hCUTp8bo4) | [Supplementary](https://yaksoy.github.io/papers/SigAsia23-IntrinsicCompositing-Supp.pdf) 4 | 5 | We propose an illumination-aware image harmonization approach for in-the-wild imagery. Our method is formulated in the intrinsic image domain. We use off-the-shelf networks to generate albedo, shading and surface normals for the input composite and background image. We first harmonize the albedo of the background and foreground by predicting image editing parameters. Using normals and shading we estimate a simple lighting model for the background illumination. With this lighting model, we render Lambertian shading for the foreground and refine it using a network trained on segmentation datasets via self-supervision. When compared to prior works we are the only method that is capable of modeling realistic lighting effects. 6 | 7 | [![YouTube Video](./figures/IntrinsicCompositingVideo.jpg)](https://www.youtube.com/watch?v=M9hCUTp8bo4) 8 | 9 | 10 | ## Method 11 | Compositing is a crucial image editing task requiring realistic integration of objects into new backgrounds. 12 | Achieving a natural composition requires adjusting the appearance of the inserted object through a process called image harmonization. 13 | While existing literature addresses color harmonization, relighting, an equally vital aspect, is often overlooked due to the challenges in realistically adjusting object illumination in diverse environments. 14 | 15 | ![](./figures/teaser_comparison.jpg) 16 | 17 | In this project, we tackle image harmonization in the intrinsic domain, decomposing images into reflectance (albedo) and illumination (shading). 
18 | We employ a two-step approach: first, harmonizing color in the albedo space, and then addressing the challenging relighting problem in the shading domain. 19 | Our goal is to generate realistic shading for the composited object, reflecting the new illumination environment. 20 | 21 | ![](./figures/teaser_pipeline.jpg) 22 | 23 | More specifically, we first render an initial shading using the Lambertian model and surface normals for the background and the inserted object. 24 | A re-shading network then refines this shading for the composited object in a self-supervised manner. 25 | Our method is able to generate novel reshadings of the foreground region that reflect the illumination conditions of the background scene. 26 | 27 | ![](./figures/teaser2.jpg) 28 | 29 | Our method outperforms prior works, producing realistic composite images that not only match color but also exhibit realistic illumination in diverse scenarios. 30 | 31 | ![](./figures/user_study_comp.jpg) 32 | 33 | Our re-shading network learns to predict spatially-varying lighting effects in-context due to our self-supervised training approach. 34 | 35 | ![](./figures/astronauts.png) 36 | 37 | ## Setup 38 | Depending on how you would like to use the code in this repository, there are two options to set up the code. 39 | In either case, you should first create a fresh virtual environment (`python3 -m venv intrinsic_env`) and activate it (`source intrinsic_env/bin/activate`). 40 | 41 | You can install this repository as a package using `pip`: 42 | ``` 43 | git clone https://github.com/compphoto/IntrinsicCompositing 44 | cd IntrinsicCompositing 45 | pip install .
46 | ``` 47 | If you want to make changes to the code and have them reflected when you import the package, use `pip install --editable .` instead. 48 | Alternatively, you can install the package without cloning the repository using: 49 | ``` 50 | pip install https://github.com/compphoto/IntrinsicCompositing/archive/main.zip 51 | ``` 52 | This will allow you to import the repository as a Python package, and use our pipeline as part of your codebase. The pipeline has been tested with the following versions, but earlier versions should work as well: 53 | ``` 54 | python==3.10 55 | torch==2.5.1 56 | opencv-python==4.10 57 | numpy==1.26.4 58 | ``` 59 | 60 | ## Interface 61 | 62 | The best way to run our pipeline is by using our interactive interface. We provide some example backgrounds and foregrounds in `interface/examples`: 63 | 64 | ``` 65 | $ cd interface 66 | $ python interface.py --bg examples/bgs/lamp.jpeg --fg examples/fgs/soap.png --mask examples/masks/soap.png 67 | ``` 68 | The first time you run the interface, multiple pretrained checkpoints will be downloaded (the method makes use of multiple off-the-shelf models), which may take some time. Subsequent runs will use the cached weights, but there is still a bit of preprocessing that is required when the interface is started. Once the preprocessing is done, the interface window will appear and the input composite can be edited. After editing the composite, harmonizing only requires running our albedo and shading networks, which should only take a second or two. These are the keybinds for the interface: 69 | 70 | | Key | Action | 71 | |--|--| 72 | | r | run the harmonization of the current composite | 73 | | s | save inputs, outputs and intermediate images | 74 | |1-5 | view various intermediate representations (shading, normals, etc.) | 75 | |scroll up/down | scale foreground region up or down | 76 | 77 | The interface has been tested on an RTX 2060 with 8 GB of VRAM, which should be able to handle inference at a 1024-pixel resolution.
78 | 79 | ## Inference 80 | 81 | If you want to run our pipeline on pre-made composite images, you can use the script in the `inference` folder. 82 | This script will iterate through a set of composites and output our harmonized result: 83 | ``` 84 | $ cd inference 85 | $ python inference.py --help 86 | 87 | usage: inference.py [-h] --input_dir INPUT_DIR --output_dir OUTPUT_DIR [--inference_size INFERENCE_SIZE] [--intermediate] [--reproduce_paper] 88 | 89 | optional arguments: 90 | -h, --help show this help message and exit 91 | --input_dir INPUT_DIR 92 | input directory to read input composites, bgs and masks 93 | --output_dir OUTPUT_DIR 94 | output directory to store harmonized composites 95 | --inference_size INFERENCE_SIZE 96 | size to perform inference (default 1024) 97 | --intermediate whether or not to save visualization of intermediate representations 98 | --reproduce_paper whether or not to use code and weights from the original paper implementation 99 | 100 | ``` 101 | Here is how you can run the script on a set of example composites stored in `inference/examples`: 102 | ``` 103 | $ python inference.py --input_dir examples/ --output_dir output/ 104 | ``` 105 | If you want to test your own examples, the script uses the following input directory structure: 106 | ``` 107 | examples/ 108 | ├── cone_chair 109 | │   ├── bg.jpeg 110 | │   ├── composite.png 111 | │   └── mask.png 112 | ├── lamp_candles 113 | │   ├── bg.jpeg 114 | │   ├── composite.png 115 | │   └── mask.png 116 | └── lamp_soap 117 | ├── bg.jpeg 118 | ├── composite.png 119 | └── mask.png 120 | ``` 121 | Each directory contains a composite image, a corresponding mask for the composited region, and the background image without the composited object. 122 | Note the background image is only used to compute the lighting direction, so it doesn't need to be exactly aligned with the composite image.
123 | In fact, it can be any image and the script will use it to estimate the illumination parameters used as part of our pipeline. 124 | 125 | The script expects the images to have the extensions shown above, and for the bg and composite to be three channels while the mask is one channel. 126 | The script can easily be adjusted to fit whatever data format you're using. 127 | 128 | ## Note on Reproducibility 129 | 130 | The original albedo harmonization training and testing code assumed that the shading images were stored as 16-bit values, and normalized them to [0-1] accordingly. But when generating results, I was using 8-bit shading images. This meant that the albedo being fed to the network was incorrect (due to the low-contrast shading values). When I prepared the code for release, I fixed this bug without thinking about it, meaning the GitHub code does not have this issue. I believe the GitHub code is a more accurate implementation since the albedo harmonization network is receiving the correct albedo as input. In order to maintain reproducibility, I've added a flag to the inference and interface scripts called `--reproduce_paper` that will use the logic and weights from the original implementation. Without this flag, the code will run correctly and use better weights for the reshading network. Here are the results you should see for each setting of this flag: 131 | 132 | | with `--reproduce_paper` | without `--reproduce_paper` | 133 | | ------------- | ------------- | 134 | | ![cone_chair](https://github.com/compphoto/IntrinsicCompositing/assets/3434597/b23c22dc-75c2-4e46-ba1f-54d7a137cacc) | ![cone_chair](https://github.com/compphoto/IntrinsicCompositing/assets/3434597/15ab7c12-527e-4d38-83d7-5cdaa4e67da3) | 135 | 136 | ## Citation 137 | 138 | ``` 139 | @INPROCEEDINGS{careagaCompositing, 140 | author={Chris Careaga and S. Mahdi H.
Miangoleh and Ya\u{g}{\i}z Aksoy}, 141 | title={Intrinsic Harmonization for Illumination-Aware Compositing}, 142 | booktitle={Proc. SIGGRAPH Asia}, 143 | year={2023}, 144 | } 145 | ``` 146 | 147 | ## License 148 | 149 | This implementation is provided for academic use only. Please cite our paper if you use this code or any of the models. 150 | 151 | The methodology presented in this work is safeguarded under intellectual property protection. For inquiries regarding licensing opportunities, kindly reach out to SFU Technology Licensing Office <tlo_dir ατ sfu δøτ ca> and Dr. Yağız Aksoy <yagiz ατ sfu δøτ ca>. 152 | -------------------------------------------------------------------------------- /figures/IntrinsicCompositingVideo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/IntrinsicCompositingVideo.jpg -------------------------------------------------------------------------------- /figures/astronauts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/astronauts.png -------------------------------------------------------------------------------- /figures/teaser2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/teaser2.jpg -------------------------------------------------------------------------------- /figures/teaser_comparison.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/teaser_comparison.jpg 
-------------------------------------------------------------------------------- /figures/teaser_pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/teaser_pipeline.jpg -------------------------------------------------------------------------------- /figures/user_study_comp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/figures/user_study_comp.jpg -------------------------------------------------------------------------------- /inference/examples/cone_chair/bg.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/cone_chair/bg.jpeg -------------------------------------------------------------------------------- /inference/examples/cone_chair/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/cone_chair/composite.png -------------------------------------------------------------------------------- /inference/examples/cone_chair/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/cone_chair/mask.png -------------------------------------------------------------------------------- /inference/examples/lamp_candles/bg.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_candles/bg.jpeg -------------------------------------------------------------------------------- /inference/examples/lamp_candles/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_candles/composite.png -------------------------------------------------------------------------------- /inference/examples/lamp_candles/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_candles/mask.png -------------------------------------------------------------------------------- /inference/examples/lamp_soap/bg.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_soap/bg.jpeg -------------------------------------------------------------------------------- /inference/examples/lamp_soap/composite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_soap/composite.png -------------------------------------------------------------------------------- /inference/examples/lamp_soap/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/inference/examples/lamp_soap/mask.png -------------------------------------------------------------------------------- 
/inference/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | import numpy as np 5 | 6 | from glob import glob 7 | 8 | from skimage.transform import resize 9 | 10 | from chrislib.general import ( 11 | invert, 12 | uninvert, 13 | view, 14 | np_to_pil, 15 | to2np, 16 | add_chan, 17 | show, 18 | round_32, 19 | tile_imgs 20 | ) 21 | from chrislib.data_util import load_image 22 | from chrislib.normal_util import get_omni_normals 23 | 24 | from boosted_depth.depth_util import create_depth_models, get_depth 25 | 26 | from intrinsic.model_util import load_models 27 | from intrinsic.pipeline import run_pipeline 28 | 29 | from intrinsic_compositing.shading.pipeline import ( 30 | load_reshading_model, 31 | compute_reshading, 32 | generate_shd, 33 | get_light_coeffs 34 | ) 35 | 36 | from intrinsic_compositing.albedo.pipeline import ( 37 | load_albedo_harmonizer, 38 | harmonize_albedo 39 | ) 40 | 41 | from omnidata_tools.model_util import load_omni_model 42 | 43 | def get_bbox(mask): 44 | rows = np.any(mask, axis=1) 45 | cols = np.any(mask, axis=0) 46 | rmin, rmax = np.where(rows)[0][[0, -1]] 47 | cmin, cmax = np.where(cols)[0][[0, -1]] 48 | 49 | return rmin, rmax, cmin, cmax 50 | 51 | def rescale(img, scale, r32=False): 52 | if scale == 1.0: return img 53 | 54 | h = img.shape[0] 55 | w = img.shape[1] 56 | 57 | if r32: 58 | img = resize(img, (round_32(h * scale), round_32(w * scale))) 59 | else: 60 | img = resize(img, (int(h * scale), int(w * scale))) 61 | 62 | return img 63 | 64 | def compute_composite_normals(img, msk, model, size): 65 | 66 | bin_msk = (msk > 0) 67 | 68 | bb = get_bbox(bin_msk) 69 | bb_h, bb_w = bb[1] - bb[0], bb[3] - bb[2] 70 | 71 | # create the crop around the object in the image to send through normal net 72 | img_crop = img[bb[0] : bb[1], bb[2] : bb[3], :] 73 | 74 | crop_scale = 1024 / max(bb_h, bb_w) 75 | img_crop = rescale(img_crop, crop_scale) 
76 | 77 | # get normals of cropped and scaled object and resize back to original bbox size 78 | nrm_crop = get_omni_normals(model, img_crop) 79 | nrm_crop = resize(nrm_crop, (bb_h, bb_w)) 80 | 81 | h, w, c = img.shape 82 | max_dim = max(h, w) 83 | if max_dim > size: 84 | scale = size / max_dim 85 | else: 86 | scale = 1.0 87 | 88 | # resize to the final output size as specified by input args 89 | out_img = rescale(img, scale, r32=True) 90 | out_msk = rescale(msk, scale, r32=True) 91 | out_bin_msk = (out_msk > 0) 92 | 93 | # compute normals for the entire composite image at it's output size 94 | out_nrm_bg = get_omni_normals(model, out_img) 95 | 96 | # now the image is at a new size so the parameters of the object crop change. 97 | # in order to overlay the normals, we need to resize the crop to this new size 98 | out_bb = get_bbox(out_bin_msk) 99 | bb_h, bb_w = out_bb[1] - out_bb[0], out_bb[3] - out_bb[2] 100 | 101 | # now resize the normals of the crop to this size, and put them in empty image 102 | out_nrm_crop = resize(nrm_crop, (bb_h, bb_w)) 103 | out_nrm_fg = np.zeros_like(out_img) 104 | out_nrm_fg[out_bb[0] : out_bb[1], out_bb[2] : out_bb[3], :] = out_nrm_crop 105 | 106 | # combine bg and fg normals with mask alphas 107 | out_nrm = (out_nrm_fg * out_msk[:, :, None]) + (out_nrm_bg * (1.0 - out_msk[:, :, None])) 108 | return out_nrm 109 | 110 | parser = argparse.ArgumentParser() 111 | 112 | parser.add_argument('--input_dir', type=str, required=True, help='input directory to read input composites, bgs and masks') 113 | parser.add_argument('--output_dir', type=str, required=True, help='output directory to store harmonized composites') 114 | parser.add_argument('--inference_size', type=int, default=1024, help='size to perform inference (default 1024)') 115 | parser.add_argument('--intermediate', action='store_true', help='whether or not to save visualization of intermediate representations') 116 | parser.add_argument('--reproduce_paper', action='store_true', 
help='whether or not to use code and weights from the original paper implementation') 117 | 118 | args = parser.parse_args() 119 | 120 | print('loading depth model') 121 | dpt_model = create_depth_models() 122 | 123 | print('loading normals model') 124 | nrm_model = load_omni_model() 125 | 126 | print('loading intrinsic decomposition model') 127 | int_model = load_models('paper_weights') 128 | 129 | print('loading albedo model') 130 | alb_model = load_albedo_harmonizer() 131 | 132 | print('loading reshading model') 133 | if args.reproduce_paper: 134 | shd_model = load_reshading_model('paper_weights') 135 | else: 136 | shd_model = load_reshading_model('further_trained') 137 | 138 | 139 | examples = glob(f'{args.input_dir}/*') 140 | print() 141 | print(f'found {len(examples)} scenes') 142 | print() 143 | 144 | if not os.path.exists(args.output_dir): 145 | os.makedirs(args.output_dir, exist_ok=True) 146 | 147 | for i, example_dir in enumerate(examples): 148 | 149 | bg_img = load_image(f'{example_dir}/bg.jpeg') 150 | comp_img = load_image(f'{example_dir}/composite.png') 151 | mask_img = load_image(f'{example_dir}/mask.png') 152 | 153 | scene_name = Path(example_dir).stem 154 | 155 | # to ensure that normals are globally accurate we compute them at 156 | # a resolution of 512 pixels, so resize our shading and image to compute 157 | # rescaled normals, then run the lighting model optimization 158 | bg_h, bg_w = bg_img.shape[:2] 159 | max_dim = max(bg_h, bg_w) 160 | scale = 512 / max_dim 161 | 162 | small_bg_img = rescale(bg_img, scale) 163 | small_bg_nrm = get_omni_normals(nrm_model, small_bg_img) 164 | 165 | result = run_pipeline( 166 | int_model, 167 | small_bg_img ** 2.2, 168 | resize_conf=0.0, 169 | maintain_size=True, 170 | linear=True 171 | ) 172 | 173 | small_bg_shd = result['inv_shading'][:, :, None] 174 | 175 | 176 | coeffs, lgt_vis = get_light_coeffs( 177 | small_bg_shd[:, :, 0], 178 | small_bg_nrm, 179 | small_bg_img 180 | ) 181 | 182 | # now we compute the 
normals of the entire composite image, we have some logic 183 | # to generate a detailed estimation of the foreground object by cropping and 184 | # resizing, we then overlay that onto the normals of the whole scene 185 | comp_nrm = compute_composite_normals(comp_img, mask_img, nrm_model, args.inference_size) 186 | 187 | # now compute depth and intrinsics at a specific resolution for the composite image 188 | # if the image is already smaller than the specified resolution, leave it 189 | h, w, c = comp_img.shape 190 | 191 | max_dim = max(h, w) 192 | if max_dim > args.inference_size: 193 | scale = args.inference_size / max_dim 194 | else: 195 | scale = 1.0 196 | 197 | # resize to specified size and round to 32 for network inference 198 | img = rescale(comp_img, scale, r32=True) 199 | msk = rescale(mask_img, scale, r32=True) 200 | 201 | depth = get_depth(img, dpt_model) 202 | 203 | result = run_pipeline( 204 | int_model, 205 | img ** 2.2, 206 | resize_conf=0.0, 207 | maintain_size=True, 208 | linear=True 209 | ) 210 | 211 | inv_shd = result['inv_shading'] 212 | # inv_shd = rescale(inv_shd, scale, r32=True) 213 | 214 | # compute the harmonized albedo, and the subsequent color harmonized image 215 | alb_harm = harmonize_albedo(img, msk, inv_shd, alb_model, reproduce_paper=args.reproduce_paper) ** 2.2 216 | harm_img = alb_harm * uninvert(inv_shd)[:, :, None] 217 | 218 | # run the reshading model using the various composited components, 219 | # and our lighting coefficients computed from the background 220 | comp_result = compute_reshading( 221 | harm_img, 222 | msk, 223 | inv_shd, 224 | depth, 225 | comp_nrm, 226 | alb_harm, 227 | coeffs, 228 | shd_model 229 | ) 230 | 231 | if args.intermediate: 232 | tile_imgs([ 233 | img, 234 | msk, 235 | view(alb_harm), 236 | 1-inv_shd, 237 | depth, 238 | comp_nrm, 239 | view(generate_shd(comp_nrm, coeffs, msk, viz=True)[1]), 240 | 1-invert(comp_result['reshading']) 241 | ], save=f'{args.output_dir}/{scene_name}_intermediate.jpeg', 
rescale=0.75) 242 | 243 | np_to_pil(comp_result['composite']).save(f'{args.output_dir}/{scene_name}.png') 244 | 245 | print(f'finished ({i+1}/{len(examples)}) - {scene_name}') 246 | -------------------------------------------------------------------------------- /interface/examples/bgs/blue_chairs.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/blue_chairs.jpeg -------------------------------------------------------------------------------- /interface/examples/bgs/boxes.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/boxes.jpeg -------------------------------------------------------------------------------- /interface/examples/bgs/classroom.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/classroom.jpeg -------------------------------------------------------------------------------- /interface/examples/bgs/cone_org.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/cone_org.jpeg -------------------------------------------------------------------------------- /interface/examples/bgs/dim.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/dim.jpeg -------------------------------------------------------------------------------- 
/interface/examples/bgs/dock.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/dock.jpeg -------------------------------------------------------------------------------- /interface/examples/bgs/empty_room.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/empty_room.jpeg -------------------------------------------------------------------------------- /interface/examples/bgs/lamp.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/lamp.jpeg -------------------------------------------------------------------------------- /interface/examples/bgs/museum2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/museum2.jpeg -------------------------------------------------------------------------------- /interface/examples/bgs/museum3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/museum3.jpeg -------------------------------------------------------------------------------- /interface/examples/bgs/pillar.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/bgs/pillar.jpeg 
-------------------------------------------------------------------------------- /interface/examples/fgs/astro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/astro.png -------------------------------------------------------------------------------- /interface/examples/fgs/figurine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/figurine.png -------------------------------------------------------------------------------- /interface/examples/fgs/soap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/soap.png -------------------------------------------------------------------------------- /interface/examples/fgs/white_bag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/white_bag.png -------------------------------------------------------------------------------- /interface/examples/fgs/white_chair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/white_chair.png -------------------------------------------------------------------------------- /interface/examples/fgs/white_pot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/fgs/white_pot.png -------------------------------------------------------------------------------- /interface/examples/masks/astro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/astro.png -------------------------------------------------------------------------------- /interface/examples/masks/figurine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/figurine.png -------------------------------------------------------------------------------- /interface/examples/masks/soap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/soap.png -------------------------------------------------------------------------------- /interface/examples/masks/white_bag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/white_bag.png -------------------------------------------------------------------------------- /interface/examples/masks/white_chair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/white_chair.png -------------------------------------------------------------------------------- /interface/examples/masks/white_pot.png: 
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/interface/examples/masks/white_pot.png
# --------------------------------------------------------------------------------
# /interface/interface.py:
# --------------------------------------------------------------------------------
"""Interactive Tk demo for intrinsic image compositing.

The user drags (left mouse) and scales (scroll wheel) a foreground fragment
over a background, picks a light direction by clicking the sphere widget,
and tunes ambient/directional strength with sliders.

Keys: 'r' runs albedo harmonization + reshading; '1'-'6' show intermediate
results; 's' dumps components as .exr/.png into output/; 'w' writes .png
components into a new timestamped folder under output/.
"""
import sys
import os
import argparse
import imageio

import torch
import torch.nn as nn

import tkinter as tk
from tkinter import ttk
from pathlib import Path
from datetime import datetime

from PIL import ImageTk, Image
import numpy as np

from skimage.transform import resize

from chrislib.general import invert, uninvert, view, np_to_pil, to2np, add_chan
from chrislib.data_util import load_image
from chrislib.normal_util import get_omni_normals

from boosted_depth.depth_util import create_depth_models, get_depth

from intrinsic.model_util import load_models
from intrinsic.pipeline import run_pipeline

from intrinsic_compositing.shading.pipeline import (
    load_reshading_model,
    compute_reshading,
    generate_shd,
    get_light_coeffs
)

from intrinsic_compositing.albedo.pipeline import (
    load_albedo_harmonizer,
    harmonize_albedo
)

from omnidata_tools.model_util import load_omni_model


def viz_coeffs(coeffs, size):
    """Render lighting coefficients as a shaded sphere image.

    Args:
        coeffs: 4 values; the first three are a light vector that is dotted
            with the sphere normals, the last is added as an ambient term.
        size: side length of the square output image in pixels.

    Returns:
        (size, size) array, gamma encoded with 1/2.2 and clipped to [0, 1];
        pixels outside the sphere are 0.
    """
    half_sz = size // 2
    nrm_circ = draw_normal_circle(
        np.zeros((size, size, 3)),
        (half_sz, half_sz),
        half_sz
    )

    # per-pixel n . l + ambient
    out_shd = (nrm_circ.reshape(-1, 3) @ coeffs[:3]) + coeffs[-1]
    out_shd = out_shd.reshape(size, size)

    lin = np.linspace(-1, 1, num=size)
    ys, xs = np.meshgrid(lin, lin)

    zs = np.sqrt((1.0 - (xs**2 + ys**2)).clip(0))

    # black out everything outside the sphere silhouette
    out_shd[zs == 0] = 0

    return (out_shd.clip(1e-4) ** (1/2.2)).clip(0, 1)


def draw_normal_circle(nrm, loc, rad):
    """Write unit-sphere normals into `nrm` (in place) for a disc of radius
    `rad` centred at `loc` = (x, y); pixels outside the disc are untouched.

    Returns the same `nrm` array.
    """
    size = rad * 2

    lin = np.linspace(-1, 1, num=size)
    ys, xs = np.meshgrid(lin, lin)

    zs = np.sqrt((1.0 - (xs**2 + ys**2)).clip(0))
    valid = (zs != 0)
    normals = np.stack((ys[valid], -xs[valid], zs[valid]), 1)

    valid_mask = np.zeros((size, size))
    valid_mask[valid] = 1

    full_mask = np.zeros((nrm.shape[0], nrm.shape[1]))
    x = loc[0] - rad
    y = loc[1] - rad
    full_mask[y : y + size, x : x + size] = valid_mask
    # boolean indexing walks both masks in row-major order, so the disc
    # pixels line up with `normals`
    nrm[full_mask > 0] = normals

    return nrm


def get_bbox(mask):
    """Return (rmin, rmax, cmin, cmax) of the nonzero region of `mask`.

    The max indices are inclusive; slice with `rmax + 1` / `cmax + 1`.
    Raises IndexError when `mask` has no nonzero entries.
    """
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

    return rmin, rmax, cmin, cmax


def rescale(img, scale):
    """Resize `img` by `scale` (truncating to whole pixels); no-op at 1.0."""
    if scale == 1.0:
        return img

    h = img.shape[0]
    w = img.shape[1]

    img = resize(img, (int(h * scale), int(w * scale)))
    return img


def _overlap(img_shape, loc, frag_shape):
    """Return (img_slices, frag_slices) for the part of a fragment placed with
    its top-left at `loc` = (top, left) that falls inside an image of shape
    `img_shape`, or None when the two do not overlap."""
    top, left = loc
    frag_h, frag_w = frag_shape[:2]
    img_h, img_w = img_shape[:2]

    y0, x0 = max(top, 0), max(left, 0)
    y1, x1 = min(top + frag_h, img_h), min(left + frag_w, img_w)
    if y0 >= y1 or x0 >= x1:
        return None

    img_sl = (slice(y0, y1), slice(x0, x1))
    frag_sl = (slice(y0 - top, y1 - top), slice(x0 - left, x1 - left))
    return img_sl, frag_sl


def composite_crop(img, loc, fg, mask):
    """Alpha-blend `fg` over a copy of `img` with its top-left at `loc`.

    `mask` weights the foreground. Parts of the fragment that fall outside
    the image are clipped, so a fragment dragged past a border no longer
    raises a shape-mismatch error. Returns the new image; `img` is not
    modified.
    """
    img = img.copy()

    region = _overlap(img.shape, loc, fg.shape)
    if region is None:
        return img
    img_sl, frag_sl = region

    frag_msk = mask[frag_sl]
    img[img_sl] = (img[img_sl] * (1.0 - frag_msk)) + (fg[frag_sl] * frag_msk)

    return img


# composite the depth of a fragment but try to match
# the depth wherever the fragment is placed (fake the depth)
def composite_depth(img, loc, fg, mask):
    """Composite a fragment's depth into a copy of the background depth.

    The fragment depth is scaled so that its bottom-centre value matches
    the background value just below the fragment's bottom-centre. That
    sample is clamped to the image, so fragments touching the bottom edge
    do not raise IndexError. Returns the new depth map.
    """
    c_h, c_w = fg.shape[:2]
    img_h, img_w = img.shape[:2]

    # bottom-center depth of the bg (one row below the fragment), clamped
    bg_row = min(max(loc[0] + c_h, 0), img_h - 1)
    bg_col = min(max(loc[1] + (c_w // 2), 0), img_w - 1)
    bg_bc_val = img[bg_row, bg_col].item()

    # bottom-center depth of the fragment
    fg_bc_val = fg[c_h - 1, c_w // 2].item()

    # scale to match the fg values to the bg; avoid dividing by zero
    scale = bg_bc_val / fg_bc_val if fg_bc_val != 0 else 1.0

    img = img.copy()

    region = _overlap(img.shape, loc, fg.shape)
    if region is None:
        return img
    img_sl, frag_sl = region

    frag_msk = mask[frag_sl]
    img[img_sl] = (img[img_sl] * (1.0 - frag_msk)) + (scale * fg[frag_sl] * frag_msk)

    return img


# DISP_SCALE = 0.5
DISP_SCALE = 1.0
MAX_BG_SZ = 1024
LGT_VIZ_SZ = 144


class App(tk.Tk):
    """Compositing window.

    Layout: the editable composite on the left, the pipeline output on the
    right, and the lighting sphere with its sliders on the far right.
    """

    def __init__(self, args):
        super().__init__()

        self.configure(background='black')

        self.args = args

        # set by the 'r' key; the view/save shortcuts require it
        self.result = None

        loaded_fg = load_image(args.fg)

        self.bg_img = load_image(args.bg)[:, :, :3]
        self.fg_img = loaded_fg[:, :, :3]

        if args.mask is not None:
            self.mask_img = load_image(args.mask)
        else:
            if loaded_fg.ndim != 3 or loaded_fg.shape[-1] != 4:
                print("expected foreground image to have an alpha channel since no mask was specified")
                sys.exit(1)

            # read alpha from the loaded image: self.fg_img was already cut
            # down to RGB, so its last channel would be blue, not alpha
            self.mask_img = loaded_fg[:, :, -1]

        # keep exactly one trailing mask channel: (H, W, 1)
        if len(self.mask_img.shape) == 3:
            self.mask_img = self.mask_img[:, :, :1]
        else:
            self.mask_img = self.mask_img[:, :, np.newaxis]

        print('loading depth model')
        self.dpt_model = create_depth_models()

        print('loading normals model')
        self.nrm_model = load_omni_model()

        print('loading intrinsic decomposition model')
        self.int_model = load_models('paper_weights')

        print('loading albedo model')
        self.alb_model = load_albedo_harmonizer()

        print('loading reshading model')
        if self.args.reproduce_paper:
            self.shd_model = load_reshading_model('paper_weights')
        else:
            self.shd_model = load_reshading_model('further_trained')

        self.init_scene()

        self.bg_disp_w = int(self.bg_w * DISP_SCALE)
        self.bg_disp_h = int(self.bg_h * DISP_SCALE)

        win_w = (self.bg_disp_w * 2) + LGT_VIZ_SZ + 40
        win_h = self.bg_disp_h + 20

        # configure the root window
        self.title('compositing demo')
        self.geometry(f"{win_w}x{win_h}")

        self.l_frame = ttk.Frame(self, width=self.bg_disp_w, height=self.bg_disp_h)
        self.l_frame.pack()
        self.l_frame.place(x=10, y=10)

        self.r_frame = ttk.Frame(self, width=self.bg_disp_w, height=self.bg_disp_h)
        self.r_frame.pack()
        self.r_frame.place(x=self.bg_disp_w + 20, y=10)

        style = ttk.Style(self)
        style.configure("TFrame", background="black")

        self.lgt_frame = ttk.Frame(self, width=LGT_VIZ_SZ, height=self.bg_disp_h)
        self.lgt_frame.pack()
        self.lgt_frame.place(x=win_w - LGT_VIZ_SZ - 10, y=10)

        l_disp_img = rescale(self.l_img, DISP_SCALE)
        r_disp_img = rescale(self.r_img, DISP_SCALE)
        # stored on self so the 's'/'w' shortcuts always have a light image
        self.lgt_disp_img = viz_coeffs(self.coeffs, LGT_VIZ_SZ)

        self.l_photo = ImageTk.PhotoImage(np_to_pil(l_disp_img))
        self.r_photo = ImageTk.PhotoImage(np_to_pil(r_disp_img))
        self.lgt_photo = ImageTk.PhotoImage(np_to_pil(self.lgt_disp_img))

        self.l_label = ttk.Label(self.l_frame, image=self.l_photo)
        self.l_label.pack()

        self.r_label = ttk.Label(self.r_frame, image=self.r_photo)
        self.r_label.pack()

        self.lgt_label = ttk.Label(self.lgt_frame, image=self.lgt_photo)
        self.lgt_label.pack()

        self.bias_scale = ttk.Scale(self.lgt_frame, from_=0.1, to=1.0, orient=tk.HORIZONTAL, command=self.update_bias)
        self.bias_scale.pack(pady=5)

        style = ttk.Style(self)
        style.configure("White.TLabel", foreground="white", background='black')

        al = ttk.Label(self.lgt_frame, text="Ambient Strength", style="White.TLabel")
        al.pack(pady=5)

        self.dir_scale = ttk.Scale(self.lgt_frame, from_=0.1, to=2.0, orient=tk.HORIZONTAL, command=self.update_dir)
        self.dir_scale.pack(pady=5)

        dl = ttk.Label(self.lgt_frame, text="Directional Strength", style="White.TLabel")
        dl.pack(pady=5)

        # initialise the sliders from the estimated lighting
        dir_val = np.linalg.norm(self.coeffs[:3])
        bias_val = self.coeffs[-1]

        self.bias_scale.set(bias_val)
        self.dir_scale.set(dir_val)

        # Tk event sequences: any key, left-drag, X11 wheel up/down, left-click
        self.bind('<Key>', self.key_pressed)
        self.bind('<B1-Motion>', self.click_motion)
        self.bind('<Button-4>', self.scrolled)
        self.bind('<Button-5>', self.scrolled)
        self.bind('<Button-1>', self.clicked)

    def update_left(self):
        """Refresh the left panel; self.l_img is already at display scale."""
        disp_img = self.l_img

        self.l_photo = ImageTk.PhotoImage(np_to_pil(disp_img))
        self.l_label.configure(image=self.l_photo)
        # keep a reference so Tk does not garbage-collect the photo
        self.l_label.image = self.l_photo

    def update_right(self):
        """Refresh the right panel from self.r_img (rescaled for display)."""
        disp_img = rescale(self.r_img, DISP_SCALE)
        self.r_photo = ImageTk.PhotoImage(np_to_pil(disp_img))
        self.r_label.configure(image=self.r_photo)
        self.r_label.image = self.r_photo

    def update_light(self):
        """Re-render the lighting sphere from the current coefficients."""
        self.lgt_disp_img = viz_coeffs(self.coeffs, LGT_VIZ_SZ)
        self.lgt_photo = ImageTk.PhotoImage(np_to_pil(self.lgt_disp_img))
        self.lgt_label.configure(image=self.lgt_photo)
        self.lgt_label.image = self.lgt_photo

    def update_bias(self, val):
        """Slider callback: set the ambient light strength."""
        self.coeffs[-1] = float(val)
        self.update_light()

    def update_dir(self, val):
        """Slider callback: rescale the light vector to length `val`."""
        # normalize the current light direction (clip avoids div-by-zero),
        # then scale the unit vector by the slider value
        vec = self.coeffs[:3]
        vec /= np.linalg.norm(vec).clip(1e-3)
        vec *= float(val)
        self.coeffs[:3] = vec
        self.update_light()

    def init_scene(self):
        """Run all per-scene estimation and build the initial composite.

        Steps: resize the background to MAX_BG_SZ on its long side; estimate
        bg normals, shading, albedo and lighting coefficients; crop and
        rescale the masked foreground; estimate fg normals, shading, albedo
        and depth for bg/fg. The estimation models are freed afterwards.
        """
        bg_h, bg_w, _ = self.bg_img.shape

        # resize the background image so its large side is MAX_BG_SZ
        max_dim = max(bg_h, bg_w)
        scale = MAX_BG_SZ / max_dim

        # here is our bg image and size after resizing to max size
        self.bg_img = resize(self.bg_img, (int(bg_h * scale), int(bg_w * scale)))
        self.bg_h, self.bg_w, _ = self.bg_img.shape

        # compute normals and shading for background, and use them
        # to optimize for the lighting coefficients
        self.bg_nrm = get_omni_normals(self.nrm_model, self.bg_img)
        result = run_pipeline(
            self.int_model,
            self.bg_img ** 2.2,
            resize_conf=0.0,
            maintain_size=True,
            linear=True
        )

        self.bg_shd = result['inv_shading'][:, :, None]
        self.bg_alb = result['albedo']

        # to ensure that normals are globally accurate we compute them at
        # a resolution of 512 pixels, so resize our shading and image to compute
        # rescaled normals, then run the lighting model optimization
        max_dim = max(self.bg_h, self.bg_w)
        scale = 512 / max_dim
        small_bg_img = rescale(self.bg_img, scale)
        small_bg_nrm = get_omni_normals(self.nrm_model, small_bg_img)
        small_bg_shd = rescale(self.bg_shd, scale)

        self.orig_coeffs, self.lgt_vis = get_light_coeffs(
            small_bg_shd[:, :, 0],
            small_bg_nrm,
            small_bg_img
        )

        # the sliders edit self.coeffs in place; copy so orig_coeffs keeps
        # the estimated lighting
        self.coeffs = self.orig_coeffs.copy()

        # first we want to reason about the fg image as an image fragment that we can
        # move, scale and composite, so we only really need to deal with the masked area
        bin_msk = (self.mask_img > 0)
        if not bin_msk.any():
            print('the foreground mask is empty, nothing to composite')
            sys.exit(1)

        # get_bbox returns inclusive max indices, hence the +1 when slicing
        rmin, rmax, cmin, cmax = get_bbox(bin_msk)
        bb_h, bb_w = (rmax + 1) - rmin, (cmax + 1) - cmin

        # create the crop around the object in the image, this can be very large depending
        # on the image that has been chosen by the user, but we want to store it at 1024
        self.orig_fg_crop = self.fg_img[rmin : rmax + 1, cmin : cmax + 1, :].copy()
        self.orig_msk_crop = self.mask_img[rmin : rmax + 1, cmin : cmax + 1, :].copy()

        # this is the real_scale, maps the original image crop size to 1024
        max_dim = max(bb_h, bb_w)
        real_scale = MAX_BG_SZ / max_dim

        # this is the copy of the crop we keep to compute normals
        self.orig_fg_crop = rescale(self.orig_fg_crop, real_scale)
        self.orig_msk_crop = rescale(self.orig_msk_crop, real_scale)

        # now compute the display scale to show the fragment on the ui
        max_dim = max(self.orig_fg_crop.shape[:2])
        disp_scale = (min(self.bg_h, self.bg_w) // 2) / max_dim
        self.frag_scale = disp_scale
        print('init frag_scale:', self.frag_scale)

        # these are the versions that the UI shows and what will be
        # used to create the composite images sent to the networks
        self.fg_crop = rescale(self.orig_fg_crop, self.frag_scale)
        self.msk_crop = rescale(self.orig_msk_crop, self.frag_scale)

        self.bb_h, self.bb_w, _ = self.fg_crop.shape

        # set the composite region to center of bg img
        self.loc_y = self.bg_h // 2
        self.loc_x = self.bg_w // 2

        # get top and left using location, and bounding box size
        top = self.loc_y - (self.bb_h // 2)
        left = self.loc_x - (self.bb_w // 2)

        # create the composite image using helper function and location
        init_cmp = composite_crop(
            self.bg_img,
            (top, left),
            self.fg_crop,
            self.msk_crop
        )

        # get normals just for the image fragment, it's best to send it through
        # cropped and scaled to 1024 in order to get the most details,
        # then we resize it to match the fragment size
        self.orig_fg_nrm = get_omni_normals(self.nrm_model, self.orig_fg_crop)

        # run the intrinsic pipeline to get foreground albedo and shading
        result = run_pipeline(
            self.int_model,
            self.orig_fg_crop ** 2.2,
            resize_conf=0.0,
            maintain_size=True,
            linear=True
        )

        self.orig_fg_shd = result['inv_shading'][:, :, None]
        self.orig_fg_alb = result['albedo']

        # drop model refs and empty the cuda cache to try to save memory
        del self.nrm_model
        del self.int_model
        torch.cuda.empty_cache()

        # run depth model on bg and fg separately
        self.bg_dpt = get_depth(self.bg_img, self.dpt_model)[:, :, None]
        self.orig_fg_dpt = get_depth(self.orig_fg_crop, self.dpt_model)[:, :, None]
        del self.dpt_model

        # these are the versions of the bg, fg and msk that we use in the UI
        self.disp_bg_img = rescale(self.bg_img, DISP_SCALE)
        self.disp_fg_crop = rescale(self.fg_crop, DISP_SCALE)
        self.disp_msk_crop = rescale(self.msk_crop, DISP_SCALE)

        # left shows the initial composite, right starts out black
        self.l_img = init_cmp
        self.r_img = np.zeros_like(init_cmp)

    def scrolled(self, e):
        """Wheel callback: shrink (button 5) or grow (button 4) the fragment
        by 0.01 within roughly [0.05, 1.0], then redraw the left panel."""
        if e.num == 5 and self.frag_scale > 0.05:  # scroll down
            self.frag_scale -= 0.01
        if e.num == 4 and self.frag_scale < 1.0:  # scroll up
            self.frag_scale += 0.01

        self.fg_crop = rescale(self.orig_fg_crop, self.frag_scale)
        self.msk_crop = rescale(self.orig_msk_crop, self.frag_scale)

        self.disp_fg_crop = rescale(self.fg_crop, DISP_SCALE)
        self.disp_msk_crop = rescale(self.msk_crop, DISP_SCALE)

        x = int(self.loc_x * DISP_SCALE)
        y = int(self.loc_y * DISP_SCALE)

        top = y - (self.disp_fg_crop.shape[0] // 2)
        left = x - (self.disp_fg_crop.shape[1] // 2)

        self.l_img = composite_crop(
            self.disp_bg_img,
            (top, left),
            self.disp_fg_crop,
            self.disp_msk_crop
        )
        self.update_left()

    def clicked(self, e):
        """Left-click callback: a click on the light sphere sets the light
        direction; the slider values set its strengths."""
        x, y = e.x, e.y
        radius = (LGT_VIZ_SZ // 2)

        if e.widget == self.lgt_label:
            # mouse position relative to the sphere centre, in [-1, 1]
            rel_x = (x - radius) / radius
            rel_y = (y - radius) / radius

            # a click outside the sphere would give sqrt of a negative
            # number (NaN coefficients); project it onto the silhouette
            r_sq = rel_x ** 2 + rel_y ** 2
            if r_sq > 1.0:
                r = np.sqrt(r_sq)
                rel_x, rel_y = rel_x / r, rel_y / r

            z = np.sqrt(max(1.0 - rel_x ** 2 - rel_y ** 2, 0.0))

            # build the 4D lighting coefficients from the direction and sliders
            self.coeffs = np.array([0, 0, 0, float(self.bias_scale.get())])
            dir_vec = np.array([rel_x, -rel_y, z]) * float(self.dir_scale.get())
            self.coeffs[:3] = dir_vec

            self.update_light()

    def click_motion(self, e):
        """Left-drag callback: move the fragment within the left panel."""
        x, y = e.x, e.y

        if e.widget == self.l_label:
            if (x <= self.bg_disp_w) and (y <= self.bg_disp_h):

                # show the display-scaled composite for responsiveness, but
                # store full-resolution coordinates for the network inputs
                self.loc_y = int(y / DISP_SCALE)
                self.loc_x = int(x / DISP_SCALE)

                top = y - (self.disp_fg_crop.shape[0] // 2)
                left = x - (self.disp_fg_crop.shape[1] // 2)

                self.l_img = composite_crop(
                    self.disp_bg_img,
                    (top, left),
                    self.disp_fg_crop,
                    self.disp_msk_crop
                )

                self.update_left()

    def key_pressed(self, e):
        """Keyboard callback.

        r: run harmonization + reshading and show the composite
        1-6: show reshading / lambertian shading / normals / composite
             inverse shading / harmonized albedo / composite depth
        s: dump components as .exr and .png into output/
        w: write components as .png into output/<fg>_<bg>_<timestamp>/
        """
        key = e.char

        if key == 'r':
            self._run_harmonization()
            return

        views = {
            '1': lambda: self.result['reshading'],
            '2': lambda: self.result['init_shading'],
            '3': lambda: self.result['normals'],
            '4': lambda: self.comp_shd[:, :, 0],
            '5': lambda: self.alb_harm,
            '6': lambda: self.comp_dpt[:, :, 0],
        }

        if key not in views and key not in ('s', 'w'):
            return

        # every other shortcut reads state that only 'r' creates
        if self.result is None:
            print("press 'r' to run the harmonization first")
            return

        if key == 's':
            self._save_components()
        elif key == 'w':
            self._write_pngs()
        else:
            self.r_img = views[key]()
            self.update_right()

    def _run_harmonization(self):
        # build composites from the UI state, harmonize albedo, reshade
        fg_shd_res = rescale(self.orig_fg_shd, self.frag_scale)
        fg_nrm_res = rescale(self.orig_fg_nrm, self.frag_scale)
        fg_dpt_res = rescale(self.orig_fg_dpt, self.frag_scale)

        top = self.loc_y - (self.fg_crop.shape[0] // 2)
        left = self.loc_x - (self.fg_crop.shape[1] // 2)

        # create all the composite images to send to the pipeline
        self.comp_img = composite_crop(
            self.bg_img,
            (top, left),
            self.fg_crop,
            self.msk_crop
        )

        self.comp_shd = composite_crop(
            self.bg_shd,
            (top, left),
            fg_shd_res,
            self.msk_crop
        )

        self.comp_msk = composite_crop(
            np.zeros_like(self.bg_shd),
            (top, left),
            self.msk_crop,
            self.msk_crop
        )

        comp_nrm = composite_crop(
            self.bg_nrm,
            (top, left),
            fg_nrm_res,
            self.msk_crop
        )

        self.comp_dpt = composite_depth(
            self.bg_dpt,
            (top, left),
            fg_dpt_res,
            self.msk_crop
        )

        # the albedo comes out gamma corrected so make it linear
        self.alb_harm = harmonize_albedo(
            self.comp_img,
            self.comp_msk,
            self.comp_shd,
            self.alb_model,
            reproduce_paper=self.args.reproduce_paper
        ) ** 2.2

        self.orig_alb = (self.comp_img ** 2.2) / uninvert(self.comp_shd)
        harm_img = self.alb_harm * uninvert(self.comp_shd)

        # run the reshading model using the various composited components,
        # and our lighting coefficients from the user interface
        self.result = compute_reshading(
            harm_img,
            self.comp_msk,
            self.comp_shd,
            self.comp_dpt,
            comp_nrm,
            self.alb_harm,
            self.coeffs,
            self.shd_model
        )

        self.r_img = self.result['composite']
        self.update_right()

    def _save_components(self):
        # dump all components as .exr and .png into output/
        os.makedirs('output', exist_ok=True)

        # orig_shd from intrinsic pipeline is linear and inverse
        orig_shd = add_chan(uninvert(self.comp_shd))

        # reshading coming from compositing pipeline is linear but not inverse
        reshading = add_chan(self.result['reshading'])

        imageio.imwrite('output/orig_shd.exr', orig_shd)
        imageio.imwrite('output/orig_shd.png', orig_shd)
        imageio.imwrite('output/orig_alb.exr', self.orig_alb)
        imageio.imwrite('output/orig_alb.png', self.orig_alb)
        imageio.imwrite('output/orig_img.exr', self.comp_img ** 2.2)
        imageio.imwrite('output/orig_img.png', self.comp_img ** 2.2)

        imageio.imwrite('output/harm_alb.exr', self.alb_harm)
        imageio.imwrite('output/harm_alb.png', self.alb_harm)
        imageio.imwrite('output/reshading.exr', reshading)
        imageio.imwrite('output/reshading.png', reshading)
        imageio.imwrite('output/final.exr', self.result['composite'] ** 2.2)
        imageio.imwrite('output/final.png', self.result['composite'] ** 2.2)

        imageio.imwrite('output/normals.exr', self.result['normals'])
        imageio.imwrite('output/light.exr', self.lgt_disp_img)

    def _write_pngs(self):
        # write all the components as pngs into a new timestamped folder

        # orig_shd from intrinsic pipeline is linear and inverse
        orig_shd = add_chan(uninvert(self.comp_shd))

        # reshading coming from compositing pipeline is linear but not inverse
        reshading = add_chan(self.result['reshading'])
        lambertian = add_chan(self.result['init_shading'])
        mask = add_chan(self.comp_msk)

        fg_name = Path(self.args.fg).stem
        bg_name = Path(self.args.bg).stem
        # datetime.now().timestamp() is the true epoch time; utcnow() gives a
        # naive datetime that .timestamp() would misread as local time
        ts = int(datetime.now().timestamp())

        save_dir = f'{fg_name}_{bg_name}_{ts}'
        os.makedirs(f'output/{save_dir}')

        np_to_pil(view(orig_shd)).save(f'output/{save_dir}/orig_shd.png')
        np_to_pil(view(lambertian)).save(f'output/{save_dir}/lamb_shd.png')
        np_to_pil(view(self.orig_alb)).save(f'output/{save_dir}/orig_alb.png')
        np_to_pil(self.comp_img).save(f'output/{save_dir}/orig_img.png')

        np_to_pil(view(self.alb_harm)).save(f'output/{save_dir}/harm_alb.png')
        np_to_pil(view(reshading)).save(f'output/{save_dir}/reshading.png')
        np_to_pil(self.result['composite']).save(f'output/{save_dir}/final.png')

        np_to_pil(self.result['normals']).save(f'output/{save_dir}/normals.png')
        np_to_pil(self.lgt_disp_img).save(f'output/{save_dir}/light.png')

        np_to_pil(mask).save(f'output/{save_dir}/mask.png')

        _, bg_lamb_shd = generate_shd(self.bg_nrm, self.coeffs, np.ones(self.bg_nrm.shape[:2]), viz=True)
        np_to_pil(add_chan(view(bg_lamb_shd))).save(f'output/{save_dir}/bg_lamb_shd.png')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--bg', type=str, required=True)
    parser.add_argument('--fg', type=str, required=True)
    parser.add_argument('--mask', type=str, default=None)
    parser.add_argument('--reproduce_paper', action='store_true', help='whether or not use the code and weights of the original paper implementation')

    args = parser.parse_args()

    app = App(args)
    app.mainloop()

# --------------------------------------------------------------------------------
# /intrinsic_compositing/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/__init__.py
# --------------------------------------------------------------------------------
# /intrinsic_compositing/albedo/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/__init__.py
# --------------------------------------------------------------------------------
# /intrinsic_compositing/albedo/model/MiDaS/__init__.py:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/model/MiDaS/__init__.py -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/MiDaS/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/model/MiDaS/midas/__init__.py -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/MiDaS/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/MiDaS/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 
| elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn 
= nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 
140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 
233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 
297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/MiDaS/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 
30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | 
Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/MiDaS/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/MiDaS/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 
23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 
| self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = 
previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/MiDaS/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.single), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 
63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is 
lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if 
"disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.single), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.single) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.single) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.single) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.single) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/MiDaS/midas/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as 
F 7 | 8 | 9 | class Slice(nn.Module): 10 | def __init__(self, start_index=1): 11 | super(Slice, self).__init__() 12 | self.start_index = start_index 13 | 14 | def forward(self, x): 15 | return x[:, self.start_index :] 16 | 17 | 18 | class AddReadout(nn.Module): 19 | def __init__(self, start_index=1): 20 | super(AddReadout, self).__init__() 21 | self.start_index = start_index 22 | 23 | def forward(self, x): 24 | if self.start_index == 2: 25 | readout = (x[:, 0] + x[:, 1]) / 2 26 | else: 27 | readout = x[:, 0] 28 | return x[:, self.start_index :] + readout.unsqueeze(1) 29 | 30 | 31 | class ProjectReadout(nn.Module): 32 | def __init__(self, in_features, start_index=1): 33 | super(ProjectReadout, self).__init__() 34 | self.start_index = start_index 35 | 36 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 37 | 38 | def forward(self, x): 39 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 40 | features = torch.cat((x[:, self.start_index :], readout), -1) 41 | 42 | return self.project(features) 43 | 44 | 45 | class Transpose(nn.Module): 46 | def __init__(self, dim0, dim1): 47 | super(Transpose, self).__init__() 48 | self.dim0 = dim0 49 | self.dim1 = dim1 50 | 51 | def forward(self, x): 52 | x = x.transpose(self.dim0, self.dim1) 53 | return x 54 | 55 | 56 | def forward_vit(pretrained, x): 57 | b, c, h, w = x.shape 58 | 59 | glob = pretrained.model.forward_flex(x) 60 | 61 | layer_1 = pretrained.activations["1"] 62 | layer_2 = pretrained.activations["2"] 63 | layer_3 = pretrained.activations["3"] 64 | layer_4 = pretrained.activations["4"] 65 | 66 | layer_1 = pretrained.act_postprocess1[0:2](layer_1) 67 | layer_2 = pretrained.act_postprocess2[0:2](layer_2) 68 | layer_3 = pretrained.act_postprocess3[0:2](layer_3) 69 | layer_4 = pretrained.act_postprocess4[0:2](layer_4) 70 | 71 | unflatten = nn.Sequential( 72 | nn.Unflatten( 73 | 2, 74 | torch.Size( 75 | [ 76 | h // pretrained.model.patch_size[1], 77 | w // 
pretrained.model.patch_size[0], 78 | ] 79 | ), 80 | ) 81 | ) 82 | 83 | if layer_1.ndim == 3: 84 | layer_1 = unflatten(layer_1) 85 | if layer_2.ndim == 3: 86 | layer_2 = unflatten(layer_2) 87 | if layer_3.ndim == 3: 88 | layer_3 = unflatten(layer_3) 89 | if layer_4.ndim == 3: 90 | layer_4 = unflatten(layer_4) 91 | 92 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) 93 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) 94 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) 95 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) 96 | 97 | return layer_1, layer_2, layer_3, layer_4 98 | 99 | 100 | def _resize_pos_embed(self, posemb, gs_h, gs_w): 101 | posemb_tok, posemb_grid = ( 102 | posemb[:, : self.start_index], 103 | posemb[0, self.start_index :], 104 | ) 105 | 106 | gs_old = int(math.sqrt(len(posemb_grid))) 107 | 108 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 109 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 110 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 111 | 112 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 113 | 114 | return posemb 115 | 116 | 117 | def forward_flex(self, x): 118 | b, c, h, w = x.shape 119 | 120 | pos_embed = self._resize_pos_embed( 121 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] 122 | ) 123 | 124 | B = x.shape[0] 125 | 126 | if hasattr(self.patch_embed, "backbone"): 127 | x = self.patch_embed.backbone(x) 128 | if isinstance(x, (list, tuple)): 129 | x = x[-1] # last feature if backbone outputs list/tuple of features 130 | 131 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) 132 | 133 | if getattr(self, "dist_token", None) is not None: 134 | cls_tokens = self.cls_token.expand( 135 | B, -1, -1 136 | ) # stole cls_tokens impl from Phil Wang, thanks 137 | dist_token = 
self.dist_token.expand(B, -1, -1) 138 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 139 | else: 140 | cls_tokens = self.cls_token.expand( 141 | B, -1, -1 142 | ) # stole cls_tokens impl from Phil Wang, thanks 143 | x = torch.cat((cls_tokens, x), dim=1) 144 | 145 | x = x + pos_embed 146 | x = self.pos_drop(x) 147 | 148 | for blk in self.blocks: 149 | x = blk(x) 150 | 151 | x = self.norm(x) 152 | 153 | return x 154 | 155 | 156 | activations = {} 157 | 158 | 159 | def get_activation(name): 160 | def hook(model, input, output): 161 | activations[name] = output 162 | 163 | return hook 164 | 165 | 166 | def get_readout_oper(vit_features, features, use_readout, start_index=1): 167 | if use_readout == "ignore": 168 | readout_oper = [Slice(start_index)] * len(features) 169 | elif use_readout == "add": 170 | readout_oper = [AddReadout(start_index)] * len(features) 171 | elif use_readout == "project": 172 | readout_oper = [ 173 | ProjectReadout(vit_features, start_index) for out_feat in features 174 | ] 175 | else: 176 | assert ( 177 | False 178 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" 179 | 180 | return readout_oper 181 | 182 | 183 | def _make_vit_b16_backbone( 184 | model, 185 | features=[96, 192, 384, 768], 186 | size=[384, 384], 187 | hooks=[2, 5, 8, 11], 188 | vit_features=768, 189 | use_readout="ignore", 190 | start_index=1, 191 | ): 192 | pretrained = nn.Module() 193 | 194 | pretrained.model = model 195 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 196 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 197 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 198 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 199 | 200 | pretrained.activations = activations 201 | 202 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 203 | 204 | # 32, 48, 136, 384 205 | 
pretrained.act_postprocess1 = nn.Sequential( 206 | readout_oper[0], 207 | Transpose(1, 2), 208 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 209 | nn.Conv2d( 210 | in_channels=vit_features, 211 | out_channels=features[0], 212 | kernel_size=1, 213 | stride=1, 214 | padding=0, 215 | ), 216 | nn.ConvTranspose2d( 217 | in_channels=features[0], 218 | out_channels=features[0], 219 | kernel_size=4, 220 | stride=4, 221 | padding=0, 222 | bias=True, 223 | dilation=1, 224 | groups=1, 225 | ), 226 | ) 227 | 228 | pretrained.act_postprocess2 = nn.Sequential( 229 | readout_oper[1], 230 | Transpose(1, 2), 231 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 232 | nn.Conv2d( 233 | in_channels=vit_features, 234 | out_channels=features[1], 235 | kernel_size=1, 236 | stride=1, 237 | padding=0, 238 | ), 239 | nn.ConvTranspose2d( 240 | in_channels=features[1], 241 | out_channels=features[1], 242 | kernel_size=2, 243 | stride=2, 244 | padding=0, 245 | bias=True, 246 | dilation=1, 247 | groups=1, 248 | ), 249 | ) 250 | 251 | pretrained.act_postprocess3 = nn.Sequential( 252 | readout_oper[2], 253 | Transpose(1, 2), 254 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 255 | nn.Conv2d( 256 | in_channels=vit_features, 257 | out_channels=features[2], 258 | kernel_size=1, 259 | stride=1, 260 | padding=0, 261 | ), 262 | ) 263 | 264 | pretrained.act_postprocess4 = nn.Sequential( 265 | readout_oper[3], 266 | Transpose(1, 2), 267 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 268 | nn.Conv2d( 269 | in_channels=vit_features, 270 | out_channels=features[3], 271 | kernel_size=1, 272 | stride=1, 273 | padding=0, 274 | ), 275 | nn.Conv2d( 276 | in_channels=features[3], 277 | out_channels=features[3], 278 | kernel_size=3, 279 | stride=2, 280 | padding=1, 281 | ), 282 | ) 283 | 284 | pretrained.model.start_index = start_index 285 | pretrained.model.patch_size = [16, 16] 286 | 287 | # We inject this function into the VisionTransformer instances 
so that 288 | # we can use it with interpolated position embeddings without modifying the library source. 289 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 290 | pretrained.model._resize_pos_embed = types.MethodType( 291 | _resize_pos_embed, pretrained.model 292 | ) 293 | 294 | return pretrained 295 | 296 | 297 | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): 298 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) 299 | 300 | hooks = [5, 11, 17, 23] if hooks == None else hooks 301 | return _make_vit_b16_backbone( 302 | model, 303 | features=[256, 512, 1024, 1024], 304 | hooks=hooks, 305 | vit_features=1024, 306 | use_readout=use_readout, 307 | ) 308 | 309 | 310 | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): 311 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) 312 | 313 | hooks = [2, 5, 8, 11] if hooks == None else hooks 314 | return _make_vit_b16_backbone( 315 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout 316 | ) 317 | 318 | 319 | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): 320 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) 321 | 322 | hooks = [2, 5, 8, 11] if hooks == None else hooks 323 | return _make_vit_b16_backbone( 324 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout 325 | ) 326 | 327 | 328 | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): 329 | model = timm.create_model( 330 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained 331 | ) 332 | 333 | hooks = [2, 5, 8, 11] if hooks == None else hooks 334 | return _make_vit_b16_backbone( 335 | model, 336 | features=[96, 192, 384, 768], 337 | hooks=hooks, 338 | use_readout=use_readout, 339 | start_index=2, 340 | ) 341 | 342 | 343 | def _make_vit_b_rn50_backbone( 344 | model, 345 | 
features=[256, 512, 768, 768], 346 | size=[384, 384], 347 | hooks=[0, 1, 8, 11], 348 | vit_features=768, 349 | use_vit_only=False, 350 | use_readout="ignore", 351 | start_index=1, 352 | ): 353 | pretrained = nn.Module() 354 | 355 | pretrained.model = model 356 | 357 | if use_vit_only == True: 358 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 359 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 360 | else: 361 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( 362 | get_activation("1") 363 | ) 364 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( 365 | get_activation("2") 366 | ) 367 | 368 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 369 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 370 | 371 | pretrained.activations = activations 372 | 373 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 374 | 375 | if use_vit_only == True: 376 | pretrained.act_postprocess1 = nn.Sequential( 377 | readout_oper[0], 378 | Transpose(1, 2), 379 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 380 | nn.Conv2d( 381 | in_channels=vit_features, 382 | out_channels=features[0], 383 | kernel_size=1, 384 | stride=1, 385 | padding=0, 386 | ), 387 | nn.ConvTranspose2d( 388 | in_channels=features[0], 389 | out_channels=features[0], 390 | kernel_size=4, 391 | stride=4, 392 | padding=0, 393 | bias=True, 394 | dilation=1, 395 | groups=1, 396 | ), 397 | ) 398 | 399 | pretrained.act_postprocess2 = nn.Sequential( 400 | readout_oper[1], 401 | Transpose(1, 2), 402 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 403 | nn.Conv2d( 404 | in_channels=vit_features, 405 | out_channels=features[1], 406 | kernel_size=1, 407 | stride=1, 408 | padding=0, 409 | ), 410 | nn.ConvTranspose2d( 411 | in_channels=features[1], 412 | out_channels=features[1], 413 | kernel_size=2, 414 | 
stride=2, 415 | padding=0, 416 | bias=True, 417 | dilation=1, 418 | groups=1, 419 | ), 420 | ) 421 | else: 422 | pretrained.act_postprocess1 = nn.Sequential( 423 | nn.Identity(), nn.Identity(), nn.Identity() 424 | ) 425 | pretrained.act_postprocess2 = nn.Sequential( 426 | nn.Identity(), nn.Identity(), nn.Identity() 427 | ) 428 | 429 | pretrained.act_postprocess3 = nn.Sequential( 430 | readout_oper[2], 431 | Transpose(1, 2), 432 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 433 | nn.Conv2d( 434 | in_channels=vit_features, 435 | out_channels=features[2], 436 | kernel_size=1, 437 | stride=1, 438 | padding=0, 439 | ), 440 | ) 441 | 442 | pretrained.act_postprocess4 = nn.Sequential( 443 | readout_oper[3], 444 | Transpose(1, 2), 445 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 446 | nn.Conv2d( 447 | in_channels=vit_features, 448 | out_channels=features[3], 449 | kernel_size=1, 450 | stride=1, 451 | padding=0, 452 | ), 453 | nn.Conv2d( 454 | in_channels=features[3], 455 | out_channels=features[3], 456 | kernel_size=3, 457 | stride=2, 458 | padding=1, 459 | ), 460 | ) 461 | 462 | pretrained.model.start_index = start_index 463 | pretrained.model.patch_size = [16, 16] 464 | 465 | # We inject this function into the VisionTransformer instances so that 466 | # we can use it with interpolated position embeddings without modifying the library source. 467 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 468 | 469 | # We inject this function into the VisionTransformer instances so that 470 | # we can use it with interpolated position embeddings without modifying the library source. 
471 | pretrained.model._resize_pos_embed = types.MethodType( 472 | _resize_pos_embed, pretrained.model 473 | ) 474 | 475 | return pretrained 476 | 477 | 478 | def _make_pretrained_vitb_rn50_384( 479 | pretrained, use_readout="ignore", hooks=None, use_vit_only=False 480 | ): 481 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) 482 | 483 | hooks = [0, 1, 8, 11] if hooks == None else hooks 484 | return _make_vit_b_rn50_backbone( 485 | model, 486 | features=[256, 512, 768, 768], 487 | size=[384, 384], 488 | hooks=hooks, 489 | use_vit_only=use_vit_only, 490 | use_readout=use_readout, 491 | ) 492 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/model/__init__.py -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | from .pix2pix.models.networks import NLayerDiscriminator,get_norm_layer 5 | from ..utils.utils import normalize 6 | 7 | class VOTEGAN(nn.Module): 8 | def __init__(self, args): 9 | super(VOTEGAN, self).__init__() 10 | n_ic = 4 11 | if args.crop_size == 384: 12 | n_f = 22*22 13 | else: 14 | raise NotImplementedError 15 | 16 | ndf = 64 # of discrim filters in the first conv layer 17 | norm_layer = get_norm_layer(norm_type='instance') 18 | self.model = nn.Sequential(NLayerDiscriminator(n_ic, ndf, n_layers=4, norm_layer=norm_layer), # 70*70 patchgan for 256*256 19 | nn.Flatten(start_dim=1), 20 | nn.Linear(n_f,32), 21 | nn.Linear(32,1), 22 | nn.Sigmoid() 23 | ) 24 | 25 | 26 | def forward(self,x): 27 | feat = self.model(normalize(x)) 28 | return feat 
29 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/editingnetwork_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from ..utils.networkutils import init_net, loadmodelweights 7 | from ..utils.edits import apply_colorcurve, apply_exposure, apply_saturation, apply_whitebalancing, get_edits 8 | 9 | from .parametermodel import ParametersRegressor 10 | from .discriminator import VOTEGAN 11 | 12 | 13 | class EditingNetworkTrainer: 14 | def __init__(self, args): 15 | 16 | self.args = args 17 | self.edits = get_edits(args.nops) 18 | self.device = torch.device('cuda:{}'.format(args.gpu_ids[0])) if args.gpu_ids else torch.device('cpu') 19 | self.model_names = ['Parameters'] 20 | 21 | self.net_Parameters = init_net(ParametersRegressor(args),args.gpu_ids) 22 | self.net_Parameters.train() 23 | 24 | # Initial Editing Network weights 25 | if args.checkpoint_load_path is not None: 26 | loadmodelweights(self.net_Parameters,args.checkpoint_load_path, self.device) 27 | 28 | if 'realism' in args.edit_loss: 29 | self.net_D = init_net(VOTEGAN(args), args.gpu_ids) 30 | self.set_requires_grad(self.net_D, False) 31 | # Load the realism network weights 32 | loadmodelweights(self.net_D,args.realism_model_weight_path, self.device) 33 | self.net_D.eval() 34 | 35 | # Set the optimizers 36 | self.optimizer_Parameters = torch.optim.Adam(self.net_Parameters.parameters(), lr=args.lr_editnet) 37 | 38 | # Set the mode for each network 39 | 40 | # Set the loss functions 41 | self.criterion_L2 = torch.nn.MSELoss() 42 | 43 | # Set the needed constants and parameters 44 | self.logs = [] 45 | 46 | # self.all_permutations = torch.tensor([ 47 | # [0,1,2,3],[0,2,1,3],[0,3,1,2],[0,1,3,2],[0,2,3,1],[0,3,2,1], 48 | # [1,0,2,3],[1,2,0,3],[1,3,0,2],[1,0,3,2],[1,2,3,0],[1,3,2,0], 49 | # 
[2,0,1,3],[2,1,0,3],[2,3,0,1],[2,0,3,1],[2,1,3,0],[2,3,1,0], 50 | # [3,0,1,2],[3,1,0,2],[3,2,0,1],[3,0,2,1],[3,1,2,0],[3,2,1,0] 51 | # ]).float().to(self.device) 52 | 53 | 54 | def setEval(self): 55 | self.net_Parameters.eval() 56 | 57 | def setTrain(self): 58 | self.net_Parameters.train() 59 | 60 | def setinput(self, input, mergebatch=1): 61 | self.srgb = input['srgb'].to(self.device) 62 | self.albedo = input['albedo'].to(self.device) 63 | self.shading = input['shading'].to(self.device) 64 | self.mask = input['mask'].to(self.device) 65 | 66 | if mergebatch > 1: 67 | self.srgb = torch.reshape(self.srgb, (self.args.batch_size, 3, self.args.crop_size, self.args.crop_size)) 68 | self.albedo = torch.reshape(self.albedo, (self.args.batch_size, 3, self.args.crop_size, self.args.crop_size)) 69 | self.shading = torch.reshape(self.shading, (self.args.batch_size, 1, self.args.crop_size, self.args.crop_size)) 70 | self.mask = torch.reshape(self.mask, (self.args.batch_size, 1, self.args.crop_size, self.args.crop_size)) 71 | 72 | albedo_edited = self.create_fake_edited(self.albedo) 73 | self.albedo_fake = (1 - self.mask) * self.albedo + self.mask * albedo_edited 74 | self.input = torch.cat((self.albedo_fake,self.mask),dim=1).to(self.device) 75 | 76 | # self.numelmask = torch.sum(self.mask,dim=[1,2,3]) 77 | def setinput_HR(self, input): 78 | self.srgb = input['srgb'].to(self.device) 79 | self.albedo_fake = input['albedo'].to(self.device) 80 | self.albedo_full = input['albedo_full'].to(self.device) 81 | self.shading_full = input['shading_full'].to(self.device) 82 | self.mask_full = input['mask_full'].to(self.device) 83 | self.shading = input['shading'].to(self.device) 84 | self.mask = input['mask'].to(self.device) 85 | 86 | self.input = torch.cat((self.albedo_fake, self.mask),dim=1).to(self.device) 87 | 88 | 89 | def create_fake_edited(self,rgb): 90 | # Randomly choose an edit. 
91 | edited = rgb.clone() 92 | ne = np.random.randint(0, 4) 93 | perm = torch.randperm(len(self.edits)) 94 | 95 | args = self.args 96 | device = self.device 97 | 98 | for i in range(ne): 99 | edit_id = perm[i] 100 | if self.args.fake_gen_lowdev == 0: 101 | wb_param = torch.rand(args.batch_size, 3).to(device)*0.9 + 0.1 102 | colorcurve = torch.rand(args.batch_size, 24).to(device)*1.5 + 0.5 103 | 104 | sat_param = torch.rand(args.batch_size, 1)*2 105 | sat_param = sat_param.to(device) 106 | 107 | expos_param = torch.rand(args.batch_size, 1)*1.5 + 0.5 108 | expos_param = expos_param.to(device) 109 | 110 | blur_param = torch.rand(args.batch_size, 1)*5 + 0.0001 # to make sure the blur param is never exactly zero. 111 | blur_param = blur_param.to(device) 112 | 113 | sharp_param = torch.rand(args.batch_size, 1)*10 + 1 114 | sharp_param = sharp_param.to(device) 115 | 116 | else: 117 | wb_param = torch.rand(args.batch_size, 3).to(device)*0.2 + 0.5 118 | colorcurve = torch.rand(args.batch_size, 24).to(device)*1.5 + 0.5 119 | 120 | sat_param = torch.rand(args.batch_size, 1)*1 + 0.5 121 | sat_param = sat_param.to(device) 122 | 123 | expos_param = torch.rand(args.batch_size, 1)*1 + 0.5 124 | expos_param = expos_param.to(device) 125 | 126 | blur_param = torch.rand(args.batch_size, 1)*2.5 + 0.0001 # to make sure the blur param is never exactly zero. 
127 | blur_param = blur_param.to(device) 128 | 129 | sharp_param = torch.rand(args.batch_size, 1)*5 + 1 130 | sharp_param = sharp_param.to(device) 131 | 132 | parameters = { 133 | 'whitebalancing':wb_param, 134 | 'colorcurve':colorcurve, 135 | 'saturation':sat_param, 136 | 'exposure':expos_param, 137 | 'blur':blur_param, 138 | 'sharpness':sharp_param 139 | } 140 | 141 | 142 | edited = torch.clamp(self.edits[edit_id.item()](edited,parameters),0,1) 143 | 144 | return edited.detach() 145 | 146 | def forward(self): 147 | permutation = torch.randperm(len(self.edits)).float().to(self.device) 148 | params_dic = self.net_Parameters(self.input, permutation.repeat(self.args.batch_size,1)) 149 | # print(params_dic) 150 | self.logs.append(params_dic) 151 | 152 | # current_rgb = self.albedo_fake 153 | current_rgb = self.albedo_full 154 | 155 | for ed_in in range(self.args.nops): 156 | current_edited = torch.clamp(self.edits[permutation[ed_in].item()](current_rgb,params_dic),0,1) 157 | current_result = (1 - self.mask_full) * current_rgb + self.mask_full * current_edited 158 | 159 | current_rgb = current_result 160 | 161 | self.result = current_result 162 | 163 | self.result_albedo_srgb = self.result ** (2.2) 164 | self.result_rgb = self.result_albedo_srgb * self.shading_full 165 | self.result_srgb = torch.clamp(self.result_rgb ** (1/2.2),0,1) 166 | 167 | 168 | 169 | def computeloss_realism(self): 170 | if self.args.edit_loss == 'realism': 171 | after = torch.cat((self.result, self.mask), 1) 172 | after_D_value = self.net_D(after).squeeze(1) 173 | self.realism_change = 1 - after_D_value 174 | self.loss_realism = F.relu(self.realism_change - self.args.loss_relu_bias) 175 | self.loss_g = torch.mean(self.loss_realism) 176 | 177 | elif self.args.edit_loss == 'realism_relative': 178 | before = torch.cat((self.albedo, self.mask), 1) 179 | before_D_value = self.net_D(before).squeeze(1) 180 | after = torch.cat((self.result, self.mask), 1) 181 | after_D_value = 
self.net_D(after).squeeze(1) 182 | 183 | self.realism_change = before_D_value - after_D_value 184 | self.loss_realism = F.relu(self.realism_change - self.args.loss_relu_bias) 185 | self.loss_g = torch.mean(self.loss_realism) 186 | 187 | elif self.args.edit_loss == 'MSE': 188 | self.loss_L2 = self.criterion_L2(self.result, self.albedo) 189 | self.loss_g = torch.mean(self.loss_L2) 190 | 191 | 192 | 193 | 194 | def optimize_parameters(self): 195 | self.optimizer_Parameters.zero_grad() 196 | self.computeloss_realism() 197 | self.loss_g.backward() 198 | self.optimizer_Parameters.step() 199 | 200 | for name, p in self.net_Parameters.named_parameters(): 201 | if p.grad is None: 202 | print(name) 203 | 204 | 205 | def set_requires_grad(self, nets, requires_grad=False): 206 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 207 | Parameters: 208 | nets (network list) -- a list of networks 209 | requires_grad (bool) -- whether the networks require gradients or not 210 | """ 211 | if not isinstance(nets, list): 212 | nets = [nets] 213 | for net in nets: 214 | if net is not None: 215 | for param in net.parameters(): 216 | param.requires_grad = requires_grad 217 | 218 | def savemodel(self,iteration, checkpointdir): 219 | for name in self.model_names: 220 | if isinstance(name, str): 221 | save_filename = '%s_net_%s.pth' % (iteration, name) 222 | save_path = os.path.join(checkpointdir, save_filename) 223 | net = getattr(self, 'net_' + name) 224 | if len(self.args.gpu_ids) > 0 and torch.cuda.is_available(): 225 | torch.save(net.module.cpu().state_dict(), save_path) 226 | net.cuda(self.args.gpu_ids[0]) 227 | else: 228 | torch.save(net.cpu().state_dict(), save_path) 229 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/parametermodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | 5 | from .MiDaS.midas.blocks import _make_pretrained_efficientnet_lite3 6 | from ..utils.networkutils import Conv2dSameExport 7 | from ..utils.utils import normalize 8 | 9 | 10 | class ParametersRegressor(nn.Module): 11 | def __init__(self, args): 12 | super(ParametersRegressor,self).__init__() 13 | 14 | self.blurandsharpen = args.blursharpen 15 | self.nf = 384 16 | self.shared_dec_nfeat = 128 17 | self.perm_nfeat = 32 18 | self.nops = args.nops 19 | 20 | self.encoder = _make_pretrained_efficientnet_lite3(use_pretrained=True, exportable=True) 21 | self.encoder.layer1[0] = Conv2dSameExport(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False) 22 | self.encoder.layer4[1][0].bn3 = nn.Identity() 23 | 24 | if args.bn_momentum is not None: 25 | for module in self.encoder.modules(): 26 | if isinstance(module, nn.BatchNorm2d): 27 | module.momentum = args.bn_momentum 28 | 29 | self.shared_decoder = nn.Sequential( 30 | nn.Linear(self.nf + self.perm_nfeat, self.shared_dec_nfeat), 31 | nn.LeakyReLU(0.1,False), 32 | nn.Linear(self.shared_dec_nfeat, self.shared_dec_nfeat), 33 | ) 34 | 35 | self.WB_head = nn.Sequential( 36 | nn.Linear(self.shared_dec_nfeat, 3), 37 | nn.Sigmoid(), 38 | ) 39 | self.ColorCurve_head = nn.Sequential( 40 | nn.Linear(self.shared_dec_nfeat, 24), 41 | nn.Sigmoid(), 42 | ) 43 | self.Satur_head = nn.Sequential( 44 | nn.Linear(self.shared_dec_nfeat, 1), 45 | nn.Sigmoid(), 46 | ) 47 | self.Expos_head = nn.Sequential( 48 | nn.Linear(self.shared_dec_nfeat, 1), 49 | nn.Sigmoid(), 50 | ) 51 | 52 | self.perm_modulation = nn.Sequential( 53 | nn.Linear(self.nops,self.perm_nfeat) 54 | ) 55 | 56 | if self.blurandsharpen: 57 | self.Sharp_head = nn.Sequential( 58 | nn.Linear(self.shared_dec_nfeat, 1), 59 | nn.Sigmoid(), 60 | ) 61 | 62 | self.Blur_head = nn.Sequential( 63 | nn.Linear(self.shared_dec_nfeat, 1), 64 | nn.Sigmoid(), 65 | ) 66 | 67 | 68 | self.globalavgpool = nn.AdaptiveAvgPool2d(output_size=1) 69 | 70 | def forward(self,x, 
permutation): 71 | 72 | img_feat = self.encoder.layer1(normalize(x)) 73 | img_feat = self.encoder.layer2(img_feat) 74 | img_feat = self.encoder.layer3(img_feat) 75 | img_feat = self.encoder.layer4(img_feat) 76 | 77 | img_feat = self.globalavgpool(img_feat).squeeze(-1).squeeze(-1) 78 | perm_feat = self.perm_modulation(permutation) 79 | 80 | feat = torch.cat((img_feat,perm_feat),dim=1) 81 | 82 | feat = self.shared_decoder(feat) 83 | 84 | wb_param = self.WB_head(feat)*0.9 + 0.1 85 | colorcurve = self.ColorCurve_head(feat)*1.5 + 0.5 86 | sat_param = self.Satur_head(feat)*2 87 | expos_param = self.Expos_head(feat)*1.5 + 0.5 88 | 89 | if self.blurandsharpen: 90 | sharp_param = self.Sharp_head(feat)*10 + 1 # smaller than one with generate blured image! 91 | blur_param = self.Blur_head(feat)*5 + 0.001 # cannot be zero -- blur operation fails for zero kernel 92 | 93 | result_dic = { 94 | 'whitebalancing':wb_param, 95 | 'colorcurve':colorcurve, 96 | 'saturation':sat_param, 97 | 'exposure':expos_param , 98 | 'blur':blur_param, 99 | 'sharpness':sharp_param 100 | } 101 | else: 102 | result_dic = { 103 | 'whitebalancing':wb_param, 104 | 'colorcurve':colorcurve, 105 | 'saturation':sat_param, 106 | 'exposure':expos_param , 107 | } 108 | 109 | return result_dic 110 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/pix2pix/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/model/pix2pix/__init__.py -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/pix2pix/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 
2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from .base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." 
+ model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/pix2pix/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . import networks 6 | 7 | 8 | class BaseModel(ABC): 9 | """This class is an abstract base class (ABC) for models. 10 | To create a subclass, you need to implement the following five functions: 11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 12 | -- : unpack data from dataset and apply preprocessing. 13 | -- : produce intermediate results. 
14 | -- : calculate losses, gradients, and update network weights. 15 | -- : (optionally) add model-specific options and set default options. 16 | """ 17 | 18 | def __init__(self, opt): 19 | """Initialize the BaseModel class. 20 | 21 | Parameters: 22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 23 | 24 | When creating your custom class, you need to implement your own initialization. 25 | In this function, you should first call 26 | Then, you need to define four lists: 27 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 28 | -- self.model_names (str list): define networks used in our training. 29 | -- self.visual_names (str list): specify the images that you want to display and save. 30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 31 | """ 32 | self.opt = opt 33 | self.gpu_ids = opt.gpu_ids 34 | self.isTrain = opt.isTrain 35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 36 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 37 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. 38 | torch.backends.cudnn.benchmark = True 39 | self.loss_names = [] 40 | self.model_names = [] 41 | self.visual_names = [] 42 | self.optimizers = [] 43 | self.image_paths = [] 44 | self.metric = 0 # used for learning rate policy 'plateau' 45 | 46 | @staticmethod 47 | def modify_commandline_options(parser, is_train): 48 | """Add new model-specific options, and rewrite default values for existing options. 
49 | 50 | Parameters: 51 | parser -- original option parser 52 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 53 | 54 | Returns: 55 | the modified parser. 56 | """ 57 | return parser 58 | 59 | @abstractmethod 60 | def set_input(self, input): 61 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 62 | 63 | Parameters: 64 | input (dict): includes the data itself and its metadata information. 65 | """ 66 | pass 67 | 68 | @abstractmethod 69 | def forward(self): 70 | """Run forward pass; called by both functions and .""" 71 | pass 72 | 73 | @abstractmethod 74 | def optimize_parameters(self): 75 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 76 | pass 77 | 78 | def setup(self, opt): 79 | """Load and print networks; create schedulers 80 | 81 | Parameters: 82 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 83 | """ 84 | if self.isTrain: 85 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 86 | if not self.isTrain or opt.continue_train: 87 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 88 | self.load_networks(load_suffix) 89 | self.print_networks(opt.verbose) 90 | 91 | def eval(self): 92 | """Make models eval mode during test time""" 93 | for name in self.model_names: 94 | if isinstance(name, str): 95 | net = getattr(self, 'net' + name) 96 | net.eval() 97 | 98 | def test(self): 99 | """Forward function used in test time. 
100 | 101 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 102 | It also calls to produce additional visualization results 103 | """ 104 | with torch.no_grad(): 105 | self.forward() 106 | self.compute_visuals() 107 | 108 | def compute_visuals(self): 109 | """Calculate additional output images for visdom and HTML visualization""" 110 | pass 111 | 112 | def get_image_paths(self): 113 | """ Return image paths that are used to load current data""" 114 | return self.image_paths 115 | 116 | def update_learning_rate(self): 117 | """Update learning rates for all the networks; called at the end of every epoch""" 118 | old_lr = self.optimizers[0].param_groups[0]['lr'] 119 | for scheduler in self.schedulers: 120 | if self.opt.lr_policy == 'plateau': 121 | scheduler.step(self.metric) 122 | else: 123 | scheduler.step() 124 | 125 | lr = self.optimizers[0].param_groups[0]['lr'] 126 | print('learning rate %.7f -> %.7f' % (old_lr, lr)) 127 | 128 | def get_current_visuals(self): 129 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 130 | visual_ret = OrderedDict() 131 | for name in self.visual_names: 132 | if isinstance(name, str): 133 | visual_ret[name] = getattr(self, name) 134 | return visual_ret 135 | 136 | def get_current_losses(self): 137 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 138 | errors_ret = OrderedDict() 139 | for name in self.loss_names: 140 | if isinstance(name, str): 141 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 142 | return errors_ret 143 | 144 | def save_networks(self, epoch): 145 | """Save all the networks to the disk. 
146 | 147 | Parameters: 148 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 149 | """ 150 | for name in self.model_names: 151 | if isinstance(name, str): 152 | save_filename = '%s_net_%s.pth' % (epoch, name) 153 | save_path = os.path.join(self.save_dir, save_filename) 154 | net = getattr(self, 'net' + name) 155 | 156 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 157 | torch.save(net.module.cpu().state_dict(), save_path) 158 | net.cuda(self.gpu_ids[0]) 159 | else: 160 | torch.save(net.cpu().state_dict(), save_path) 161 | 162 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 163 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 164 | key = keys[i] 165 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 166 | if module.__class__.__name__.startswith('InstanceNorm') and \ 167 | (key == 'running_mean' or key == 'running_var'): 168 | if getattr(module, key) is None: 169 | state_dict.pop('.'.join(keys)) 170 | if module.__class__.__name__.startswith('InstanceNorm') and \ 171 | (key == 'num_batches_tracked'): 172 | state_dict.pop('.'.join(keys)) 173 | else: 174 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 175 | 176 | def load_networks(self, epoch): 177 | """Load all the networks from the disk. 
178 | 179 | Parameters: 180 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 181 | """ 182 | for name in self.model_names: 183 | if isinstance(name, str): 184 | load_filename = '%s_net_%s.pth' % (epoch, name) 185 | load_path = os.path.join(self.save_dir, load_filename) 186 | net = getattr(self, 'net' + name) 187 | if isinstance(net, torch.nn.DataParallel): 188 | net = net.module 189 | print('loading the model from %s' % load_path) 190 | # if you are using PyTorch newer than 0.4 (e.g., built from 191 | # GitHub source), you can remove str() on self.device 192 | state_dict = torch.load(load_path, map_location=str(self.device)) 193 | if hasattr(state_dict, '_metadata'): 194 | del state_dict._metadata 195 | 196 | # patch InstanceNorm checkpoints prior to 0.4 197 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 198 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 199 | net.load_state_dict(state_dict) 200 | 201 | def print_networks(self, verbose): 202 | """Print the total number of parameters in the network and (if verbose) network architecture 203 | 204 | Parameters: 205 | verbose (bool) -- if verbose: print the network architecture 206 | """ 207 | print('---------- Networks initialized -------------') 208 | for name in self.model_names: 209 | if isinstance(name, str): 210 | net = getattr(self, 'net' + name) 211 | num_params = 0 212 | for param in net.parameters(): 213 | num_params += param.numel() 214 | if verbose: 215 | print(net) 216 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 217 | print('-----------------------------------------------') 218 | 219 | def set_requires_grad(self, nets, requires_grad=False): 220 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 221 | Parameters: 222 | nets (network list) -- a list of networks 223 | requires_grad (bool) -- whether the networks require gradients or 
not 224 | """ 225 | if not isinstance(nets, list): 226 | nets = [nets] 227 | for net in nets: 228 | if net is not None: 229 | for param in net.parameters(): 230 | param.requires_grad = requires_grad 231 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/model/pix2pix/models/networks.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import reduction 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import functools 6 | from torch.optim import lr_scheduler 7 | 8 | 9 | ############################################################################### 10 | # Helper Functions 11 | ############################################################################### 12 | 13 | 14 | class Identity(nn.Module): 15 | def forward(self, x): 16 | return x 17 | 18 | 19 | def get_norm_layer(norm_type='instance'): 20 | """Return a normalization layer 21 | 22 | Parameters: 23 | norm_type (str) -- the name of the normalization layer: batch | instance | none 24 | 25 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 26 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 
27 | """ 28 | if norm_type == 'batch': 29 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 30 | elif norm_type == 'instance': 31 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 32 | elif norm_type == 'none': 33 | def norm_layer(x): return Identity() 34 | else: 35 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 36 | return norm_layer 37 | 38 | 39 | def get_scheduler(optimizer, opt): 40 | """Return a learning rate scheduler 41 | 42 | Parameters: 43 | optimizer -- the optimizer of the network 44 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  45 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 46 | 47 | For 'linear', we keep the same learning rate for the first epochs 48 | and linearly decay the rate to zero over the next epochs. 49 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 50 | See https://pytorch.org/docs/stable/optim.html for more details. 
51 | """ 52 | if opt.lr_policy == 'linear': 53 | def lambda_rule(epoch): 54 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 55 | return lr_l 56 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 57 | elif opt.lr_policy == 'step': 58 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 59 | elif opt.lr_policy == 'plateau': 60 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 61 | elif opt.lr_policy == 'cosine': 62 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 63 | else: 64 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 65 | return scheduler 66 | 67 | 68 | def init_weights(net, init_type='normal', init_gain=0.02): 69 | """Initialize network weights. 70 | 71 | Parameters: 72 | net (network) -- network to be initialized 73 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 74 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 75 | 76 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 77 | work better for some applications. Feel free to try yourself. 
78 | """ 79 | def init_func(m): # define the initialization function 80 | classname = m.__class__.__name__ 81 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 82 | if init_type == 'normal': 83 | init.normal_(m.weight.data, 0.0, init_gain) 84 | elif init_type == 'xavier': 85 | init.xavier_normal_(m.weight.data, gain=init_gain) 86 | elif init_type == 'kaiming': 87 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 88 | elif init_type == 'orthogonal': 89 | init.orthogonal_(m.weight.data, gain=init_gain) 90 | else: 91 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 92 | if hasattr(m, 'bias') and m.bias is not None: 93 | init.constant_(m.bias.data, 0.0) 94 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 95 | init.normal_(m.weight.data, 1.0, init_gain) 96 | init.constant_(m.bias.data, 0.0) 97 | 98 | print('initialize network with %s' % init_type) 99 | net.apply(init_func) # apply the initialization function 100 | 101 | 102 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 103 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 104 | Parameters: 105 | net (network) -- the network to be initialized 106 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 107 | gain (float) -- scaling factor for normal, xavier and orthogonal. 108 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 109 | 110 | Return an initialized network. 
111 | """ 112 | if len(gpu_ids) > 0: 113 | assert(torch.cuda.is_available()) 114 | net.to(gpu_ids[0]) 115 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 116 | init_weights(net, init_type, init_gain=init_gain) 117 | return net 118 | 119 | 120 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): 121 | """Create a generator 122 | 123 | Parameters: 124 | input_nc (int) -- the number of channels in input images 125 | output_nc (int) -- the number of channels in output images 126 | ngf (int) -- the number of filters in the last conv layer 127 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 128 | norm (str) -- the name of normalization layers used in the network: batch | instance | none 129 | use_dropout (bool) -- if use dropout layers. 130 | init_type (str) -- the name of our initialization method. 131 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 132 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 133 | 134 | Returns a generator 135 | 136 | Our current implementation provides two types of generators: 137 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) 138 | The original U-Net paper: https://arxiv.org/abs/1505.04597 139 | 140 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) 141 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. 142 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). 143 | 144 | 145 | The generator has been initialized by . It uses RELU for non-linearity. 
146 | """ 147 | net = None 148 | norm_layer = get_norm_layer(norm_type=norm) 149 | 150 | if netG == 'resnet_9blocks': 151 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) 152 | elif netG == 'resnet_6blocks': 153 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) 154 | elif netG == 'unet_128': 155 | net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 156 | elif netG == 'unet_256': 157 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 158 | else: 159 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 160 | return init_net(net, init_type, init_gain, gpu_ids) 161 | 162 | 163 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): 164 | """Create a discriminator 165 | 166 | Parameters: 167 | input_nc (int) -- the number of channels in input images 168 | ndf (int) -- the number of filters in the first conv layer 169 | netD (str) -- the architecture's name: basic | n_layers | pixel 170 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' 171 | norm (str) -- the type of normalization layers used in the network. 172 | init_type (str) -- the name of the initialization method. 173 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 174 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 175 | 176 | Returns a discriminator 177 | 178 | Our current implementation provides three types of discriminators: 179 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper. 180 | It can classify whether 70×70 overlapping patches are real or fake. 
181 | Such a patch-level discriminator architecture has fewer parameters 182 | than a full-image discriminator and can work on arbitrarily-sized images 183 | in a fully convolutional fashion. 184 | 185 | [n_layers]: With this mode, you can specify the number of conv layers in the discriminator 186 | with the parameter (default=3 as used in [basic] (PatchGAN).) 187 | 188 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. 189 | It encourages greater color diversity but has no effect on spatial statistics. 190 | 191 | The discriminator has been initialized by . It uses Leakly RELU for non-linearity. 192 | """ 193 | net = None 194 | norm_layer = get_norm_layer(norm_type=norm) 195 | 196 | if netD == 'basic': # default PatchGAN classifier 197 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) 198 | elif netD == 'n_layers': # more options 199 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) 200 | elif netD == 'pixel': # classify if each pixel is real or fake 201 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) 202 | else: 203 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) 204 | return init_net(net, init_type, init_gain, gpu_ids) 205 | 206 | 207 | ############################################################################## 208 | # Classes 209 | ############################################################################## 210 | class GANLoss(nn.Module): 211 | """Define different GAN objectives. 212 | 213 | The GANLoss class abstracts away the need to create the target label tensor 214 | that has the same size as the input. 215 | """ 216 | 217 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 218 | """ Initialize the GANLoss class. 219 | 220 | Parameters: 221 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 
222 | target_real_label (bool) - - label for a real image 223 | target_fake_label (bool) - - label of a fake image 224 | 225 | Note: Do not use sigmoid as the last layer of Discriminator. 226 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 227 | """ 228 | super(GANLoss, self).__init__() 229 | self.register_buffer('real_label', torch.tensor(target_real_label)) 230 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 231 | self.gan_mode = gan_mode 232 | if gan_mode == 'lsgan': 233 | self.loss = nn.MSELoss(reduction='none') 234 | elif gan_mode == 'vanilla': 235 | self.loss = nn.BCEWithLogitsLoss(reduction='none') 236 | elif gan_mode in ['wgangp']: 237 | self.loss = None 238 | else: 239 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 240 | 241 | def get_target_tensor(self, prediction, target_is_real): 242 | """Create label tensors with the same size as the input. 243 | 244 | Parameters: 245 | prediction (tensor) - - tpyically the prediction from a discriminator 246 | target_is_real (bool) - - if the ground truth label is for real images or fake images 247 | 248 | Returns: 249 | A label tensor filled with ground truth label, and with the size of the input 250 | """ 251 | 252 | if target_is_real: 253 | target_tensor = self.real_label 254 | else: 255 | target_tensor = self.fake_label 256 | return target_tensor.expand_as(prediction) 257 | 258 | def __call__(self, prediction, target_is_real): 259 | """Calculate loss given Discriminator's output and grount truth labels. 260 | 261 | Parameters: 262 | prediction (tensor) - - tpyically the prediction output from a discriminator 263 | target_is_real (bool) - - if the ground truth label is for real images or fake images 264 | 265 | Returns: 266 | the calculated loss. 
267 | """ 268 | if self.gan_mode in ['lsgan', 'vanilla']: 269 | target_tensor = self.get_target_tensor(prediction, target_is_real) 270 | loss = self.loss(prediction, target_tensor) 271 | elif self.gan_mode == 'wgangp': 272 | if target_is_real: 273 | loss = -prediction.mean() 274 | else: 275 | loss = prediction.mean() 276 | return loss 277 | 278 | 279 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 280 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 281 | 282 | Arguments: 283 | netD (network) -- discriminator network 284 | real_data (tensor array) -- real images 285 | fake_data (tensor array) -- generated images from the generator 286 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 287 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 288 | constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 289 | lambda_gp (float) -- weight for this loss 290 | 291 | Returns the gradient penalty loss 292 | """ 293 | if lambda_gp > 0.0: 294 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 
295 | interpolatesv = real_data 296 | elif type == 'fake': 297 | interpolatesv = fake_data 298 | elif type == 'mixed': 299 | alpha = torch.rand(real_data.shape[0], 1, device=device) 300 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) 301 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 302 | else: 303 | raise NotImplementedError('{} not implemented'.format(type)) 304 | interpolatesv.requires_grad_(True) 305 | disc_interpolates = netD(interpolatesv) 306 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 307 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 308 | create_graph=True, retain_graph=True, only_inputs=True) 309 | gradients = gradients[0].view(real_data.size(0), -1) # flat the data 310 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps 311 | return gradient_penalty, gradients 312 | else: 313 | return 0.0, None 314 | 315 | 316 | class ResnetGenerator(nn.Module): 317 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
318 | 319 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) 320 | """ 321 | 322 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 323 | """Construct a Resnet-based generator 324 | 325 | Parameters: 326 | input_nc (int) -- the number of channels in input images 327 | output_nc (int) -- the number of channels in output images 328 | ngf (int) -- the number of filters in the last conv layer 329 | norm_layer -- normalization layer 330 | use_dropout (bool) -- if use dropout layers 331 | n_blocks (int) -- the number of ResNet blocks 332 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 333 | """ 334 | assert(n_blocks >= 0) 335 | super(ResnetGenerator, self).__init__() 336 | if type(norm_layer) == functools.partial: 337 | use_bias = norm_layer.func == nn.InstanceNorm2d 338 | else: 339 | use_bias = norm_layer == nn.InstanceNorm2d 340 | 341 | model = [nn.ReflectionPad2d(3), 342 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 343 | norm_layer(ngf), 344 | nn.ReLU(True)] 345 | 346 | n_downsampling = 2 347 | for i in range(n_downsampling): # add downsampling layers 348 | mult = 2 ** i 349 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 350 | norm_layer(ngf * mult * 2), 351 | nn.ReLU(True)] 352 | 353 | mult = 2 ** n_downsampling 354 | for i in range(n_blocks): # add ResNet blocks 355 | 356 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 357 | 358 | for i in range(n_downsampling): # add upsampling layers 359 | mult = 2 ** (n_downsampling - i) 360 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 361 | kernel_size=3, stride=2, 362 | padding=1, output_padding=1, 363 | bias=use_bias), 364 | norm_layer(int(ngf * 
mult / 2)), 365 | nn.ReLU(True)] 366 | model += [nn.ReflectionPad2d(3)] 367 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 368 | model += [nn.Tanh()] 369 | 370 | self.model = nn.Sequential(*model) 371 | 372 | def forward(self, input): 373 | """Standard forward""" 374 | return self.model(input) 375 | 376 | 377 | class ResnetBlock(nn.Module): 378 | """Define a Resnet block""" 379 | 380 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 381 | """Initialize the Resnet block 382 | 383 | A resnet block is a conv block with skip connections 384 | We construct a conv block with build_conv_block function, 385 | and implement skip connections in function. 386 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 387 | """ 388 | super(ResnetBlock, self).__init__() 389 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 390 | 391 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 392 | """Construct a convolutional block. 393 | 394 | Parameters: 395 | dim (int) -- the number of channels in the conv layer. 396 | padding_type (str) -- the name of padding layer: reflect | replicate | zero 397 | norm_layer -- normalization layer 398 | use_dropout (bool) -- if use dropout layers. 
399 | use_bias (bool) -- if the conv layer uses bias or not 400 | 401 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) 402 | """ 403 | conv_block = [] 404 | p = 0 405 | if padding_type == 'reflect': 406 | conv_block += [nn.ReflectionPad2d(1)] 407 | elif padding_type == 'replicate': 408 | conv_block += [nn.ReplicationPad2d(1)] 409 | elif padding_type == 'zero': 410 | p = 1 411 | else: 412 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 413 | 414 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 415 | if use_dropout: 416 | conv_block += [nn.Dropout(0.5)] 417 | 418 | p = 0 419 | if padding_type == 'reflect': 420 | conv_block += [nn.ReflectionPad2d(1)] 421 | elif padding_type == 'replicate': 422 | conv_block += [nn.ReplicationPad2d(1)] 423 | elif padding_type == 'zero': 424 | p = 1 425 | else: 426 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 427 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 428 | 429 | return nn.Sequential(*conv_block) 430 | 431 | def forward(self, x): 432 | """Forward function (with skip connections)""" 433 | out = x + self.conv_block(x) # add skip connections 434 | return out 435 | 436 | 437 | class UnetGenerator(nn.Module): 438 | """Create a Unet-based generator""" 439 | 440 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): 441 | """Construct a Unet generator 442 | Parameters: 443 | input_nc (int) -- the number of channels in input images 444 | output_nc (int) -- the number of channels in output images 445 | num_downs (int) -- the number of downsamplings in UNet. 
For example, # if |num_downs| == 7, 446 | image of size 128x128 will become of size 1x1 # at the bottleneck 447 | ngf (int) -- the number of filters in the last conv layer 448 | norm_layer -- normalization layer 449 | 450 | We construct the U-Net from the innermost layer to the outermost layer. 451 | It is a recursive process. 452 | """ 453 | super(UnetGenerator, self).__init__() 454 | # construct unet structure 455 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer 456 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 457 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 458 | # gradually reduce the number of filters from ngf * 8 to ngf 459 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 460 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 461 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 462 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer 463 | 464 | def forward(self, input): 465 | """Standard forward""" 466 | return self.model(input) 467 | 468 | 469 | class UnetSkipConnectionBlock(nn.Module): 470 | """Defines the Unet submodule with skip connection. 471 | X -------------------identity---------------------- 472 | |-- downsampling -- |submodule| -- upsampling --| 473 | """ 474 | 475 | def __init__(self, outer_nc, inner_nc, input_nc=None, 476 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 477 | """Construct a Unet submodule with skip connections. 
478 | 479 | Parameters: 480 | outer_nc (int) -- the number of filters in the outer conv layer 481 | inner_nc (int) -- the number of filters in the inner conv layer 482 | input_nc (int) -- the number of channels in input images/features 483 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 484 | outermost (bool) -- if this module is the outermost module 485 | innermost (bool) -- if this module is the innermost module 486 | norm_layer -- normalization layer 487 | use_dropout (bool) -- if use dropout layers. 488 | """ 489 | super(UnetSkipConnectionBlock, self).__init__() 490 | self.outermost = outermost 491 | if type(norm_layer) == functools.partial: 492 | use_bias = norm_layer.func == nn.InstanceNorm2d 493 | else: 494 | use_bias = norm_layer == nn.InstanceNorm2d 495 | if input_nc is None: 496 | input_nc = outer_nc 497 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 498 | stride=2, padding=1, bias=use_bias) 499 | downrelu = nn.LeakyReLU(0.2, True) 500 | downnorm = norm_layer(inner_nc) 501 | uprelu = nn.ReLU(True) 502 | upnorm = norm_layer(outer_nc) 503 | 504 | if outermost: 505 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 506 | kernel_size=4, stride=2, 507 | padding=1) 508 | down = [downconv] 509 | up = [uprelu, upconv, nn.Tanh()] 510 | model = down + [submodule] + up 511 | elif innermost: 512 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 513 | kernel_size=4, stride=2, 514 | padding=1, bias=use_bias) 515 | down = [downrelu, downconv] 516 | up = [uprelu, upconv, upnorm] 517 | model = down + up 518 | else: 519 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 520 | kernel_size=4, stride=2, 521 | padding=1, bias=use_bias) 522 | down = [downrelu, downconv, downnorm] 523 | up = [uprelu, upconv, upnorm] 524 | 525 | if use_dropout: 526 | model = down + [submodule] + up + [nn.Dropout(0.5)] 527 | else: 528 | model = down + [submodule] + up 529 | 530 | self.model = nn.Sequential(*model) 531 | 532 | def forward(self, x): 533 | if 
self.outermost: 534 | return self.model(x) 535 | else: # add skip connections 536 | return torch.cat([x, self.model(x)], 1) 537 | 538 | 539 | class NLayerDiscriminator(nn.Module): 540 | """Defines a PatchGAN discriminator""" 541 | 542 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 543 | """Construct a PatchGAN discriminator 544 | 545 | Parameters: 546 | input_nc (int) -- the number of channels in input images 547 | ndf (int) -- the number of filters in the last conv layer 548 | n_layers (int) -- the number of conv layers in the discriminator 549 | norm_layer -- normalization layer 550 | """ 551 | super(NLayerDiscriminator, self).__init__() 552 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 553 | use_bias = norm_layer.func == nn.InstanceNorm2d 554 | else: 555 | use_bias = norm_layer == nn.InstanceNorm2d 556 | 557 | kw = 4 558 | padw = 1 559 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 560 | nf_mult = 1 561 | nf_mult_prev = 1 562 | for n in range(1, n_layers): # gradually increase the number of filters 563 | nf_mult_prev = nf_mult 564 | nf_mult = min(2 ** n, 8) 565 | sequence += [ 566 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 567 | norm_layer(ndf * nf_mult), 568 | nn.LeakyReLU(0.2, True) 569 | ] 570 | 571 | nf_mult_prev = nf_mult 572 | nf_mult = min(2 ** n_layers, 8) 573 | sequence += [ 574 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 575 | norm_layer(ndf * nf_mult), 576 | nn.LeakyReLU(0.2, True) 577 | ] 578 | 579 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 580 | self.model = nn.Sequential(*sequence) 581 | 582 | def forward(self, input): 583 | """Standard forward.""" 584 | return self.model(input) 585 | 586 | 587 | class 
PixelDiscriminator(nn.Module): 588 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" 589 | 590 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): 591 | """Construct a 1x1 PatchGAN discriminator 592 | 593 | Parameters: 594 | input_nc (int) -- the number of channels in input images 595 | ndf (int) -- the number of filters in the last conv layer 596 | norm_layer -- normalization layer 597 | """ 598 | super(PixelDiscriminator, self).__init__() 599 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 600 | use_bias = norm_layer.func == nn.InstanceNorm2d 601 | else: 602 | use_bias = norm_layer == nn.InstanceNorm2d 603 | 604 | self.net = [ 605 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 606 | nn.LeakyReLU(0.2, True), 607 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 608 | norm_layer(ndf * 2), 609 | nn.LeakyReLU(0.2, True), 610 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 611 | 612 | self.net = nn.Sequential(*self.net) 613 | 614 | def forward(self, input): 615 | """Standard forward.""" 616 | return self.net(input) 617 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/pipeline.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import os 3 | from skimage.transform import resize 4 | import torch 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | import torch.utils.data as data 9 | from argparse import Namespace 10 | 11 | from chrislib.general import np_to_pil 12 | 13 | from intrinsic_compositing.albedo.model.editingnetwork_trainer import EditingNetworkTrainer 14 | 15 | PAPER_WEIGHTS_URL = 'https://github.com/compphoto/IntrinsicCompositing/releases/download/1.0.0/albedo_paper_weights.pth' 16 | CACHE_PATH = torch.hub.get_dir() 17 | 18 | 19 | def 
get_transform(opt, grayscale=False, method=Image.BICUBIC, convert=True): 20 | transform_list = [] 21 | if grayscale: 22 | transform_list.append(transforms.Grayscale(1)) 23 | method=Image.BILINEAR 24 | if 'resize' in opt['preprocess']: 25 | osize = [opt['load_size'], opt['load_size']] 26 | transform_list.append(transforms.Resize(osize, method)) 27 | 28 | if 'crop' in opt['preprocess']: 29 | transform_list.append(transforms.RandomCrop(opt['crop_size'])) 30 | 31 | if not opt['no_flip']: 32 | transform_list.append(transforms.RandomHorizontalFlip()) 33 | 34 | if convert: 35 | transform_list += [transforms.ToTensor()] 36 | # if grayscale: 37 | # transform_list += [transforms.Normalize((0.5,), (0.5,))] 38 | # else: 39 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 40 | return transforms.Compose(transform_list) 41 | 42 | def match_scalar(source, target, mask=None, min_percentile=0, max_percentile=100): 43 | 44 | if mask is None: 45 | # mask = np.ones((self.args.crop_size, self.args.crop_size), dtype=bool) 46 | mask = np.ones((384, 384), dtype=bool) 47 | 48 | target_masked = target[:,mask] 49 | 50 | 51 | # consider all values up to a percentile 52 | p_threshold_min = np.percentile(target_masked.reshape(-1), min_percentile) 53 | p_mask_min = np.greater_equal(target, p_threshold_min) 54 | 55 | p_threshold_max = np.percentile(target_masked.reshape(-1), max_percentile) 56 | p_mask_max = np.less_equal(target, p_threshold_max) 57 | 58 | p_mask = np.logical_and(p_mask_max, p_mask_min) 59 | mask = np.logical_and(p_mask, mask) 60 | 61 | flat_source = source[mask] 62 | flat_target = target[mask] 63 | 64 | scalar, _, _, _ = np.linalg.lstsq( 65 | flat_source.reshape(-1, 1), flat_target.reshape(-1, 1), rcond=None) 66 | source_scaled = source * scalar 67 | return source_scaled, scalar 68 | 69 | def prep_input(rgb_img, mask_img, shading_img, reproduce_paper=False): 70 | # this function takes the srgb image (rgb_img) the mask 71 | # and the shading as numpy 
arrays between [0-1] 72 | 73 | opt = {} 74 | opt['load_size'] = 384 75 | opt['crop_size'] = 384 76 | opt['preprocess'] = 'resize' 77 | opt['no_flip'] = True 78 | 79 | rgb_transform = get_transform(opt, grayscale=False) 80 | mask_transform = get_transform(opt, grayscale=True) 81 | shading_transform = get_transform(opt, grayscale=False, method=Image.BILINEAR) 82 | 83 | mask_img_np = mask_img 84 | if len(mask_img_np.shape) == 3: 85 | mask_img_np = mask_img_np[:, :, 0] 86 | 87 | if len(shading_img.shape) == 3: 88 | shading_img = shading_img[:, :, 0] 89 | 90 | full_shd = ((1.0 / (shading_img)) - 1.0) 91 | full_msk = resize(mask_img_np, full_shd.shape) 92 | # full_msk = resize(np.array(mask_img) / 255., full_shd.shape) 93 | full_img = resize(rgb_img, full_shd.shape) 94 | full_alb = (full_img ** 2.2) / full_shd[:, :, None].clip(1e-4) 95 | full_alb = full_alb.clip(1e-4) ** (1/2.2) 96 | 97 | full_alb = torch.from_numpy(full_alb).permute(2, 0, 1) 98 | full_shd = torch.from_numpy(full_shd).unsqueeze(0) 99 | full_msk = torch.from_numpy(full_msk).unsqueeze(0) 100 | 101 | srgb = rgb_transform(np_to_pil(rgb_img)) 102 | rgb_mask = np.stack([mask_img_np] * 3, -1) 103 | mask = mask_transform(np_to_pil(rgb_mask)) 104 | 105 | if reproduce_paper: 106 | invshading = shading_transform(np_to_pil(shading_img)) / (2**16-1) 107 | else: 108 | invshading = shading_transform(np_to_pil(shading_img)) 109 | 110 | shading = ((1.0 / invshading) - 1.0) 111 | 112 | ## compute albedo 113 | rgb = srgb ** 2.2 114 | albedo = rgb / shading 115 | 116 | ## min max normalize the albedo: 117 | # albedo = (albedo - albedo.min()) / (albedo.max() - albedo.min()) 118 | 119 | ## match the albedo to the rgb 120 | albedo, scalar = match_scalar(albedo.numpy(),rgb.numpy()) 121 | albedo = torch.from_numpy(albedo) 122 | albedo = albedo ** (1/2.2) 123 | 124 | shading = shading / scalar 125 | ## clip albedo to [0,1] 126 | albedo = torch.clamp(albedo,0,1) 127 | 128 | 129 | # all of these need to have a batch dimension as if 
they are coming from the dataloader 130 | return { 131 | 'srgb': srgb.unsqueeze(0), 132 | 'mask': mask.unsqueeze(0), 133 | 'albedo':albedo.unsqueeze(0), 134 | 'shading': shading.unsqueeze(0), 135 | 'albedo_full' : full_alb.unsqueeze(0), 136 | 'shading_full' : full_shd.unsqueeze(0), 137 | 'mask_full' : full_msk.unsqueeze(0) 138 | } 139 | 140 | 141 | def load_albedo_harmonizer(): 142 | 143 | args = Namespace() 144 | args.nops = 4 145 | args.gpu_ids = [0] 146 | args.blursharpen = 0 147 | args.fake_gen_lowdev = 0 148 | args.bn_momentum = 0.01 149 | args.edit_loss = '' 150 | args.loss_relu_bias = 0 151 | args.crop_size = 384 152 | args.load_size = 384 153 | args.lr_d = 0.00001 154 | args.lr_editnet = 0.00001 155 | args.batch_size = 1 156 | 157 | args.checkpoint_load_path = f'{CACHE_PATH}/albedo_harmonization/albedo_paper_weights.pth' 158 | 159 | if not os.path.exists(args.checkpoint_load_path): 160 | os.makedirs(f'{CACHE_PATH}/albedo_harmonization', exist_ok=True) 161 | os.system(f'wget {PAPER_WEIGHTS_URL} -P {CACHE_PATH}/albedo_harmonization') 162 | 163 | trainer = EditingNetworkTrainer(args) 164 | 165 | return trainer 166 | 167 | def harmonize_albedo(img, shd, msk, trainer, reproduce_paper=False): 168 | 169 | trainer.setEval() 170 | 171 | data = prep_input(img, shd, msk, reproduce_paper=reproduce_paper) 172 | 173 | trainer.setinput_HR(data) 174 | 175 | with torch.no_grad(): 176 | trainer.forward() 177 | 178 | albedo_out = trainer.result[0,...].cpu().detach().numpy().squeeze().transpose([1,2,0]) 179 | result = albedo_out.clip(0, 1) 180 | return result 181 | 182 | -------------------------------------------------------------------------------- /intrinsic_compositing/albedo/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/albedo/utils/__init__.py 
# --------------------------------------------------------------------------------
# /intrinsic_compositing/albedo/utils/datautils.py:
# --------------------------------------------------------------------------------
import torchvision.transforms as transforms
from PIL import Image
import torch.utils.data as data


def get_transform(opt, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build a torchvision preprocessing pipeline from an options dict.

    Args:
        opt: dict with keys 'preprocess' (checked for the substrings 'resize'
            and 'crop'), 'load_size', 'crop_size' and 'no_flip'.
        grayscale: prepend a single-channel Grayscale conversion. This also
            forces bilinear resampling, overriding ``method``.
        method: PIL resampling filter handed to ``transforms.Resize``.
        convert: append ``transforms.ToTensor()`` as the last step.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
        method = Image.BILINEAR
    if 'resize' in opt['preprocess']:
        # square output: load_size x load_size
        osize = [opt['load_size'], opt['load_size']]
        transform_list.append(transforms.Resize(osize, method))

    if 'crop' in opt['preprocess']:
        transform_list.append(transforms.RandomCrop(opt['crop_size']))

    if not opt['no_flip']:
        transform_list.append(transforms.RandomHorizontalFlip())

    if convert:
        # Deliberately no Normalize step: tensors are returned exactly as
        # produced by ToTensor (no shift to [-1, 1]).
        transform_list += [transforms.ToTensor()]
    return transforms.Compose(transform_list)
# --------------------------------------------------------------------------------
# /intrinsic_compositing/albedo/utils/depthutils.py:
# --------------------------------------------------------------------------------
import sys
import os

# NOTE(review): absolute path from the original author's machine. The default
# checkpoint locations in create_depth_models are built from it, so on other
# machines pass midas_path / pix2pix_path explicitly.
curr_path = "/localhome/smh31/Repositories/intrinsic_composite/realismnet/model/"
# OUR
from BoostingMonocularDepth.utils import ImageandPatchs, ImageDataset, generatemask, getGF_fromintegral, calculateprocessingres, rgb2gray, \
    applyGridpatch

# MIDAS
import BoostingMonocularDepth.midas.utils
from BoostingMonocularDepth.midas.models.midas_net import MidasNet
from BoostingMonocularDepth.midas.models.transforms import Resize, NormalizeImage, PrepareForNet

# PIX2PIX : MERGE NET
from BoostingMonocularDepth.pix2pix.options.test_options import TestOptions
from BoostingMonocularDepth.pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel

import torch
from torchvision.transforms import Compose
from torchvision.transforms import transforms

import time
import os
import cv2
import numpy as np
import argparse
from argparse import Namespace
import warnings

whole_size_threshold = 3000  # R_max from the paper
GPU_threshold = 1600 - 32  # Limit for the GPU (NVIDIA RTX 2080), can be adjusted


def create_depth_models(device='cuda', midas_path=None, pix2pix_path=None):
    """Construct the BoostingMonocularDepth merge network and the MiDaS model.

    Args:
        device: device the MiDaS model is moved to (not applied to pix2pix).
        midas_path: MiDaS weight file; defaults to a file under ``curr_path``.
        pix2pix_path: directory used as the merge network's ``save_dir``,
            from which its 'latest' checkpoint is loaded; defaults to a
            directory under ``curr_path``.

    Returns:
        ``[pix2pixmodel, midasmodel]``, both switched to eval mode.
    """
    # Hard-coded option set used in place of TestOptions().parse(), so no
    # command-line parsing happens.
    opt = Namespace(
        Final=False, R0=False, R20=False, aspect_ratio=1.0, batch_size=1,
        checkpoints_dir=f'{curr_path}/BoostingMonocularDepth/pix2pix/checkpoints',
        colorize_results=False, crop_size=672, data_dir=None, dataroot=None,
        dataset_mode='depthmerge', depthNet=None, direction='AtoB',
        display_winsize=256, epoch='latest', eval=False, generatevideo=None,
        gpu_ids=[0], init_gain=0.02, init_type='normal', input_nc=2,
        isTrain=False, load_iter=0, load_size=672, max_dataset_size=10000,
        max_res=float('inf'), model='pix2pix4depth', n_layers_D=3,
        name='mergemodel', ndf=64, netD='basic', netG='unet_1024',
        net_receptive_field_size=None, ngf=64, no_dropout=False, no_flip=False,
        norm='none', num_test=50, num_threads=4, output_dir=None, output_nc=1,
        output_resolution=None, phase='test', pix2pixsize=None,
        preprocess='resize_and_crop', savecrops=None, savewholeest=None,
        serial_batches=False, suffix='', verbose=False,
    )

    pix2pixmodel = Pix2Pix4DepthModel(opt)

    if pix2pix_path is None:
        pix2pixmodel.save_dir = f'{curr_path}/BoostingMonocularDepth/pix2pix/checkpoints/mergemodel'
    else:
        pix2pixmodel.save_dir = pix2pix_path

    pix2pixmodel.load_networks('latest')
    pix2pixmodel.eval()

    if midas_path is None:
        midas_model_path = f"{curr_path}/BoostingMonocularDepth/midas/model.pt"
    else:
        midas_model_path = midas_path

    midasmodel = MidasNet(midas_model_path, non_negative=True)
    midasmodel.to(device)
    midasmodel.eval()

    return [pix2pixmodel, midasmodel]


def get_depth(img, models, threshold=0.2):
    """Estimate depth for ``img`` using the whole-image double estimation.

    Args:
        img: input image array; ``img.shape[:2]`` is taken as (height, width).
        models: ``[pix2pixmodel, midasmodel]`` as returned by create_depth_models.
        threshold: value x of R_x (section 5 of the paper), forwarded to
            calculateprocessingres.

    Returns:
        The merged estimate from doubleestimate, resized (bicubic) back to
        the input height and width.
    """
    pix2pixmodel, midasmodel = models

    # Only the whole-image estimate is computed here (no local patches are
    # merged), so no patch-blending mask is needed.

    # Value x of R_x defined in the section 5 of the main paper.
    r_threshold_value = threshold

    input_resolution = img.shape

    scale_threshold = 3  # Allows up-scaling with a scale up to 3

    # Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the
    # supplementary material.
    whole_image_optimal_size, patch_scale = calculateprocessingres(img, 384,
                                                                   r_threshold_value, scale_threshold,
                                                                   whole_size_threshold)

    # Generate the base estimate using the double estimation.
    whole_estimate = doubleestimate(img, 384, whole_image_optimal_size, 1024, pix2pixmodel, midasmodel)
    whole_estimate = cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)

    return whole_estimate


def doubleestimate(img, size1, size2, pix2pixsize, pix2pixmodel, midasmodel):
    """Generate a double-input depth estimation.

    MiDaS is run at ``size1`` (low resolution) and ``size2`` (high
    resolution). Both results are resized to ``pix2pixsize`` x ``pix2pixsize``
    and merged by the pix2pix network.

    Returns:
        The merged prediction, mapped from [-1, 1] to [0, 1] and then min-max
        normalized, as a numpy array.
    """
    # Generate the low resolution estimation
    estimate1 = singleestimate(img, size1, midasmodel)
    # Resize to the inference size of merge network.
    estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)

    # Generate the high resolution estimation
    estimate2 = singleestimate(img, size2, midasmodel)
    # Resize to the inference size of merge network.
    estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)

    # Inference on the merge model
    pix2pixmodel.set_input(estimate1, estimate2)
    pix2pixmodel.test()
    visuals = pix2pixmodel.get_current_visuals()
    prediction_mapped = visuals['fake_B']
    prediction_mapped = (prediction_mapped + 1) / 2
    prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / (
        torch.max(prediction_mapped) - torch.min(prediction_mapped))
    prediction_mapped = prediction_mapped.squeeze().cpu().numpy()

    return prediction_mapped


def singleestimate(img, msize, midasmodel):
    """Generate a single-input MiDaS estimate, capping ``msize`` at GPU_threshold."""
    if msize > GPU_threshold:
        msize = GPU_threshold

    return estimatemidas(img, midasmodel, msize)


def estimatemidas(img, midasmodel, msize, device='cuda'):
    """Run MiDaS on ``img`` and return a [0, 1]-normalized prediction.

    The input is resized by MiDaS's Resize transform: upper bound ``msize``,
    aspect ratio kept, sides a multiple of 32. It is then normalized with
    the ImageNet mean/std.

    Returns:
        The prediction resized (bicubic) to ``img``'s height and width and
        min-max normalized. If the prediction is constant, an all-zero array
        of the same shape is returned.
    """
    # MiDas -v2 forward pass script adapted from https://github.com/intel-isl/MiDaS/tree/v2

    transform = Compose(
        [
            Resize(
                msize,
                msize,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="upper_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ]
    )

    img_input = transform({"image": img})["image"]

    # Forward pass
    with torch.no_grad():
        sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
        prediction = midasmodel.forward(sample)

    prediction = prediction.squeeze().cpu().numpy()
    prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)

    # Normalization
    depth_min = prediction.min()
    depth_max = prediction.max()

    if depth_max - depth_min > np.finfo("float").eps:
        prediction = (prediction - depth_min) / (depth_max - depth_min)
    else:
        # Keep returning an array (not a bare scalar): doubleestimate passes
        # this result straight to cv2.resize.
        prediction = np.zeros_like(prediction)

    return prediction
# --------------------------------------------------------------------------------
# /intrinsic_compositing/albedo/utils/edits.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from kornia.color import rgb_to_hsv, hsv_to_rgb

import torchvision.transforms.functional as F

# Number of linear segments in the per-channel color curve.
COLORCURVE_L = 8


def apply_whitebalancing(input, parameters):
    """Divide each channel by a per-image gain from ``parameters['whitebalancing']``.

    Gains are first normalized by the gain at channel index 1, so that
    channel is left (almost) unchanged; 1e-9 guards the divisions.
    Indexing implies gains of shape (B, C) and ``input`` of shape (B, C, H, W).
    """
    param = parameters['whitebalancing']
    param = param / (param[:, 1:2] + 1e-9)
    result = input / (param[:, :, None, None] + 1e-9)
    return result


def apply_colorcurve(input, parameters):
    """Apply a piecewise-linear tone curve per channel.

    ``parameters['colorcurve']`` is reshaped to (B, 3, COLORCURVE_L) segment
    weights. Each segment contributes ``clip(x * L - i, 0, 1) * w_i``, and the
    sum is divided by the total weight (+1e-9).
    """
    color_curve_param = torch.reshape(parameters['colorcurve'], (-1, 3, COLORCURVE_L))
    color_curve_sum = torch.sum(color_curve_param, dim=[2])
    total_image = torch.zeros_like(input)
    for i in range(COLORCURVE_L):
        total_image += torch.clip(input * COLORCURVE_L - i, 0, 1) * color_curve_param[:, :, i][:, :, None, None]
    result = total_image / (color_curve_sum[:, :, None, None] + 1e-9)
    return result


def apply_saturation(input, parameters):
    """Scale the HSV saturation channel by ``parameters['saturation']``."""
    hsv = rgb_to_hsv(input)
    param = parameters['saturation'][:, :, None, None]
    s_new = hsv[:, 1:2, :, :] * param
    hsv_new = hsv.clone()
    hsv_new[:, 1:2, :, :] = s_new
    result = hsv_to_rgb(hsv_new)
    return result


def apply_exposure(input, parameters):
    """Multiply the image by the per-image gain ``parameters['exposure']``."""
    result = input * parameters['exposure'][:, :, None, None]
    return result


def apply_blur(input, parameters):
    """Gaussian-blur each sample with its own sigma from ``parameters['blur']``.

    The kernel size is ``2 * ceil(2 * sigma) + 1`` (always odd).
    """
    sigma = parameters['blur'][:, :, None, None]
    kernelsize = 2 * torch.ceil(2 * sigma) + 1.

    result = torch.zeros_like(input)
    for B in range(input.shape[0]):
        kernelsize_ = (int(kernelsize[B].item()), int(kernelsize[B].item()))
        sigma_ = (sigma[B].item(), sigma[B].item())
        result[B, :, :, :] = F.gaussian_blur(input[B:B + 1, :, :, :], kernelsize_, sigma_)
    return result


def apply_sharpness(input, parameters):
    """Adjust sharpness of each sample by its factor in ``parameters['sharpness']``."""
    param = parameters['sharpness'][:, 0]
    result = torch.zeros_like(input)

    for B in range(input.shape[0]):
        result[B, :, :, :] = F.adjust_sharpness(input[B:B + 1, :, :, :], param[B])
    return result


def get_edits(nops):
    """Map operation index -> edit function.

    ``nops == 4`` gives the four color edits; any other value also adds
    blur (4) and sharpness (5).
    """
    if nops == 4:
        return {
            0: apply_whitebalancing,
            1: apply_colorcurve,
            2: apply_saturation,
            3: apply_exposure,
        }
    else:
        return {
            0: apply_whitebalancing,
            1: apply_colorcurve,
            2: apply_saturation,
            3: apply_exposure,
            4: apply_blur,
            5: apply_sharpness,
        }
# --------------------------------------------------------------------------------
# /intrinsic_compositing/albedo/utils/networkutils.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def loadmodelweights(net, load_path, device):
    """Load a state dict from ``load_path`` into ``net``.

    A ``torch.nn.DataParallel`` wrapper is unwrapped first. ``device`` is
    passed (as a string) to ``torch.load``'s ``map_location``.
    """
    if isinstance(net, torch.nn.DataParallel):
        net = net.module
    print('loading the model from %s' % load_path)
    # if you are using PyTorch newer than 0.4 (e.g., built from
    # GitHub source), you can remove str() on self.device
    state_dict = torch.load(load_path, map_location=str(device))
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    net.load_state_dict(state_dict)


def init_net(net, gpu_ids=None):
    """Register a network on CPU/GPU (with multi-GPU support).

    Args:
        net: module to place.
        gpu_ids: GPU indices. ``None`` or empty leaves ``net`` untouched.

    Returns:
        ``net`` itself when no GPUs are given; otherwise ``net`` moved to
        ``gpu_ids[0]`` and wrapped in ``torch.nn.DataParallel``.

    Raises:
        RuntimeError: if GPUs are requested but CUDA is not available.
    """
    if gpu_ids is not None and len(gpu_ids) > 0:
        if not torch.cuda.is_available():
            raise RuntimeError('init_net: gpu_ids=%s requested but CUDA is not available' % (gpu_ids,))
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    return net


def _calc_same_pad(i, k, s, d):
    """Total padding along one axis for TF-style 'SAME' output size.

    i: input size, k: kernel size, s: stride, d: dilation.
    """
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)


def _same_pad_arg(input_size, kernel_size, stride, dilation):
    """Return ``[left, right, top, bottom]`` padding for ``nn.ZeroPad2d``.

    Any odd leftover pixel goes to the right/bottom side.
    """
    ih, iw = input_size
    kh, kw = kernel_size
    pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
    pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
    return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]


class Conv2dSameExport(nn.Conv2d):
    """ONNX export friendly TensorFlow-like 'SAME' convolution wrapper for 2D convolutions.

    The ``padding`` argument is accepted for signature compatibility but the
    underlying conv is built with padding 0. Explicit zero padding is
    computed from the first input's spatial size and cached; later inputs
    must have that same size.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dSameExport, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.pad = None
        self.pad_input_size = (0, 0)

    def forward(self, x):
        input_size = x.size()[-2:]
        if self.pad is None:
            pad_arg = _same_pad_arg(
                input_size, self.weight.size()[-2:], self.stride, self.dilation)
            self.pad = nn.ZeroPad2d(pad_arg)
            self.pad_input_size = input_size
        else:
            assert self.pad_input_size == input_size

        x = self.pad(x)
        return F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
# --------------------------------------------------------------------------------
# /intrinsic_compositing/albedo/utils/utils.py:
# --------------------------------------------------------------------------------
from datetime import datetime


def normalize(x):
    """Affinely map ``x`` so that 0 -> -1, 0.5 -> 0 and 1 -> 1."""
    return (x - 0.5) / 0.5


def create_exp_name(args, prefix='RealismNet'):
    """Build a run name from prefix, timestamp, args.expdscp and args.lr_d."""
    components = []
    components.append(f"{prefix}")
    components.append(f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}")
    components.append(f"{args.expdscp}")
    components.append(f"lrd_{args.lr_d}")

    name = "_".join(components)
    return name


def create_exp_name_editnet(args, prefix='EditingNet'):
    """Build an editing-network run name from prefix, timestamp and key hyperparameters."""
    components = []
    components.append(f"{prefix}")
    components.append(f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}")
    components.append(f"{args.expdscp}")
    components.append(f"lredit_{args.lr_editnet}")
    components.append(f"rlubias_{args.loss_relu_bias}")
    components.append(f"edtloss_{args.edit_loss}")
    components.append(f"fkgnlwdv_{args.fake_gen_lowdev}")
    components.append(f"blrshrpn_{args.blursharpen}")

    name = "_".join(components)
    return name
# --------------------------------------------------------------------------------
# /intrinsic_compositing/shading/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/compphoto/IntrinsicCompositing/ba50caeb9eaf2acb66739be25b4a22b12a3be7ca/intrinsic_compositing/shading/__init__.py
# --------------------------------------------------------------------------------
# /intrinsic_compositing/shading/pipeline.py:
# --------------------------------------------------------------------------------
import torch
from torch.optim import Adam

import numpy as np

from skimage.transform import resize

from chrislib.general import uninvert, invert, round_32, view

from altered_midas.midas_net import MidasNet


def load_reshading_model(path, device='cuda'):
    """Load the 9-input-channel reshading MidasNet in eval mode on ``device``.

    Args:
        path: 'paper_weights' or 'further_trained' to download the released
            weights, otherwise a local state-dict file.
        device: target device; also used as ``map_location`` when loading.
    """
    if path == 'paper_weights':
        state_dict = torch.hub.load_state_dict_from_url('https://github.com/compphoto/IntrinsicCompositing/releases/download/1.0.0/shading_paper_weights.pt', map_location=device, progress=True)
    elif path == 'further_trained':
        state_dict = torch.hub.load_state_dict_from_url('https://github.com/compphoto/IntrinsicCompositing/releases/download/1.0.0/further_trained.pt', map_location=device, progress=True)
    else:
        # Match the hub branches: remap storages so a checkpoint saved on a
        # GPU can be loaded on a machine without one.
        state_dict = torch.load(path, map_location=device)

    shd_model = MidasNet(input_channels=9)
    shd_model.load_state_dict(state_dict)
    shd_model = shd_model.eval()
    shd_model = shd_model.to(device)

    return shd_model


def spherical2cart(r, theta, phi):
    """Convert spherical (r, polar theta from +z, azimuth phi) to [x, y, z]."""
    return [
        r * torch.sin(theta) * torch.cos(phi),
        r * torch.sin(theta) * torch.sin(phi),
        r * torch.cos(theta)
    ]


def run_optimization(params, A, b):
    """Refine a directional light + bias fit with Adam.

    Args:
        params: 4-element Parameter [theta, phi, r, bias], updated in place.
        A: per-pixel normals, one row per pixel (x, y, z in columns 0-2).
        b: target shading values, one per row of ``A``.

    Returns:
        ``(loss, params)``: the last MSE loss tensor and the same ``params``.
        Stops after 500 steps or when the loss improves by less than 1e-4.
    """
    optim = Adam([params], lr=0.01)
    prev_loss = 1000

    for _ in range(500):
        optim.zero_grad()

        x, y, z = spherical2cart(params[2], params[0], params[1])

        dir_shd = (A[:, 0] * x) + (A[:, 1] * y) + (A[:, 2] * z)
        pred_shd = dir_shd + params[3]

        loss = torch.nn.functional.mse_loss(pred_shd.reshape(-1), b)

        loss.backward()

        optim.step()

        # Project back onto the feasible set:
        # theta can range from 0 -> pi/2 (0 to 90 degrees)
        # phi can range from 0 -> 2pi (0 to 360 degrees)
        # r >= 0 and bias >= 0.1
        with torch.no_grad():
            if params[0] < 0:
                params[0] = 0

            if params[0] > np.pi / 2:
                params[0] = np.pi / 2

            if params[1] < 0:
                params[1] = 0

            if params[1] > 2 * np.pi:
                params[1] = 2 * np.pi

            if params[2] < 0:
                params[2] = 0

            if params[3] < 0.1:
                params[3] = 0.1

        delta = prev_loss - loss

        if delta < 0.0001:
            break

        # Detached copy: same value and dtype, without holding the graph.
        prev_loss = loss.detach()

    return loss, params


def test_init(params, A, b):
    """Return the MSE of the shading predicted by ``params`` against ``b``."""
    x, y, z = spherical2cart(params[2], params[0], params[1])

    dir_shd = (A[:, 0] * x) + (A[:, 1] * y) + (A[:, 2] * z)
    pred_shd = dir_shd + params[3]

    loss = torch.nn.functional.mse_loss(pred_shd.reshape(-1), b)
    return loss


def get_light_coeffs(shd, nrm, img, mask=None, bias=True):
    """Fit a directional light plus constant term to observed shading.

    Args:
        shd: inverse shading map (converted with ``uninvert``); its shape is
            used as the working resolution.
        nrm: normal map, mapped here with ``nrm * 2 - 1`` (presumably stored
            in [0, 1] -- confirm with caller).
        img: image; resized to ``shd.shape``. Only pixels whose channel mean
            lies in (0.05, 0.95) are used.
        mask: optional; pixels where ``mask != 0`` are excluded from the fit.
        bias: unused; kept for interface compatibility.

    Returns:
        coeffs: ``np.array([x, y, z, bias])`` of the fitted light.
        out_shd: shading of the normals (with a reference normal sphere drawn
            at (50, 50), radius 40) under the fitted light, shaped like ``shd``.

    Raises:
        ValueError: if no grid initialization yields a comparable loss
            (e.g. no valid pixels, so every loss is NaN).
    """
    img = resize(img, shd.shape)

    reg_shd = uninvert(shd)
    valid = (img.mean(-1) > 0.05) * (img.mean(-1) < 0.95)

    if mask is not None:
        valid *= (mask == 0)

    nrm = (nrm * 2.0) - 1.0

    A = nrm[valid == 1]
    A /= np.linalg.norm(A, axis=1, keepdims=True)

    b = reg_shd[valid == 1]

    # parameters are theta, phi, r (light magnitude) and bias (c)
    A = torch.from_numpy(A)
    b = torch.from_numpy(b)

    # Coarse grid search over the light direction to seed the optimizer.
    best_init = None
    min_init = float('inf')
    for t in np.arange(0, np.pi / 2, 0.1):
        for p in np.arange(0, 2 * np.pi, 0.25):
            params = torch.nn.Parameter(torch.tensor([t, p, 1, 0.5]))
            init_loss = test_init(params, A, b)

            if init_loss < min_init:
                best_init = params
                min_init = init_loss

    if best_init is None:
        raise ValueError('get_light_coeffs: no initialization produced a finite loss '
                         '(are there any valid pixels?)')

    loss, params = run_optimization(best_init, A, b)

    nrm_vis = nrm.copy()
    nrm_vis = draw_normal_circle(nrm_vis, (50, 50), 40)

    x, y, z = spherical2cart(params[2], params[0], params[1])

    coeffs = torch.tensor([x, y, z]).reshape(3, 1).detach().numpy()
    out_shd = (nrm_vis.reshape(-1, 3) @ coeffs) + params[3].item()

    coeffs = np.array([x.item(), y.item(), z.item(), params[3].item()])

    return coeffs, out_shd.reshape(shd.shape)


def draw_normal_circle(nrm, loc, rad):
    """Paint a hemisphere of unit normals into ``nrm`` in place and return it.

    The disc has radius ``rad``; ``loc`` is its center as (column, row).
    Normals are written in [-1, 1] form.
    """
    size = rad * 2

    lin = np.linspace(-1, 1, num=size)
    ys, xs = np.meshgrid(lin, lin)

    zs = np.sqrt((1.0 - (xs ** 2 + ys ** 2)).clip(0))
    valid = (zs != 0)
    normals = np.stack((ys[valid], -xs[valid], zs[valid]), 1)

    valid_mask = np.zeros((size, size))
    valid_mask[valid] = 1

    full_mask = np.zeros((nrm.shape[0], nrm.shape[1]))
    x = loc[0] - rad
    y = loc[1] - rad
    full_mask[y: y + size, x: x + size] = valid_mask
    nrm[full_mask > 0] = normals

    return nrm


def generate_shd(nrm, coeffs, msk, bias=True, viz=False):
    """Shade the foreground (``msk == 1``) normals with the light ``coeffs``.

    Normals are mapped with ``nrm * 2 - 1`` and normalized. When ``bias`` is
    set, a constant column is appended, so ``coeffs`` should then hold 4
    values. Shading is clipped at 0 and offset by 0.2.

    Returns:
        Foreground shading (one value per masked pixel). With ``viz``, also
        the full-image shading map of shape ``nrm.shape[:2]``.
    """
    nrm = (nrm * 2.0) - 1.0

    # A is a view of the local nrm, so this also normalizes nrm in place.
    A = nrm.reshape(-1, 3)
    A /= np.linalg.norm(A, axis=1, keepdims=True)

    A_fg = nrm[msk == 1]
    A_fg /= np.linalg.norm(A_fg, axis=1, keepdims=True)

    if bias:
        A = np.concatenate((A, np.ones((A.shape[0], 1))), 1)
        A_fg = np.concatenate((A_fg, np.ones((A_fg.shape[0], 1))), 1)

    inf_shd = (A_fg @ coeffs)
    inf_shd = inf_shd.clip(0) + 0.2

    if viz:
        shd_viz = (A @ coeffs).reshape(nrm.shape[:2])
        shd_viz = shd_viz.clip(0) + 0.2
        return inf_shd, shd_viz

    return inf_shd


def compute_reshading(orig, msk, inv_shd, depth, normals, alb, coeffs, model):
    """Predict refined shading for the masked region and recomposite.

    Args:
        orig: original image; only its (height, width, channels) shape is used.
        msk: object mask; channel 0 is used if it has 3 dims.
        inv_shd: inverse shading; channel 0 is used if it has 3 dims.
        depth: depth map; channel 0 is used if it has 3 dims.
        normals: normal map, converted by generate_shd via ``* 2 - 1``.
        alb: albedo image.
        coeffs: light coefficients passed to generate_shd.
        model: reshading network; the input is sent to the device of its
            parameters.

    Returns:
        dict with 'reshading', 'init_shading', 'composite' (gamma 1/2.2,
        clipped to [0, 1]) and 'normals', all resized back to (h, w).
    """
    # expects no channel dim on msk, shd and depth
    if len(inv_shd.shape) == 3:
        inv_shd = inv_shd[:, :, 0]

    if len(msk.shape) == 3:
        msk = msk[:, :, 0]

    if len(depth.shape) == 3:
        depth = depth[:, :, 0]

    h, w, _ = orig.shape

    # Work at a resolution whose sides are rounded with round_32.
    proc_shape = (round_32(h), round_32(w))
    alb = resize(alb, proc_shape)
    msk = resize(msk, proc_shape).astype(np.single)
    inv_shd = resize(inv_shd, proc_shape)
    dpt = resize(depth, proc_shape)
    nrm = resize(normals, proc_shape)

    hard_msk = (msk > 0.5)

    reg_shd = uninvert(inv_shd)

    # Initial guess: keep the background shading and replace the
    # foreground with the shading predicted from the fitted light.
    bad_shd_np = reg_shd.copy()
    inf_shd = generate_shd(nrm, coeffs, hard_msk)
    bad_shd_np[hard_msk == 1] = inf_shd

    bad_img_np = alb * bad_shd_np[:, :, None]

    # Network input channels, in order: mask, image, inverse shading,
    # normals, depth.
    sem_msk = torch.from_numpy(msk).unsqueeze(0)
    bad_img = torch.from_numpy(bad_img_np).permute(2, 0, 1)
    bad_shd = torch.from_numpy(invert(bad_shd_np)).unsqueeze(0)
    in_nrm = torch.from_numpy(nrm).permute(2, 0, 1)
    in_dpt = torch.from_numpy(dpt).unsqueeze(0)
    inp = torch.cat((sem_msk, bad_img, bad_shd, in_nrm, in_dpt), dim=0).unsqueeze(0)
    inp = inp.to(next(model.parameters()).device)

    with torch.no_grad():
        out = model(inp).squeeze()

    fin_shd = out.detach().cpu().numpy()
    fin_shd = uninvert(fin_shd)
    fin_img = alb * fin_shd[:, :, None]

    normals = resize(nrm, (h, w))
    fin_shd = resize(fin_shd, (h, w))
    fin_img = resize(fin_img, (h, w))
    bad_shd_np = resize(bad_shd_np, (h, w))

    result = {}
    result['reshading'] = fin_shd
    result['init_shading'] = bad_shd_np
    result['composite'] = (fin_img ** (1 / 2.2)).clip(0, 1)
    result['normals'] = normals

    return result
# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
import setuptools
setuptools.setup(
    name="intrinsic_compositing",
    version="0.0.1",
    author="Chris Careaga",
    author_email="chris_careaga@sfu.ca",
    description='a package containing the code for the paper "Intrinsic Harmonization for Illumination-Aware Compositing"',
    url="",
    packages=setuptools.find_packages(),
    license="",
    python_requires=">3.6",
    install_requires=[
        'altered_midas @ git+https://github.com/CCareaga/MiDaS@master',
        'chrislib @ git+https://github.com/CCareaga/chrislib@main',
        'omnidata_tools @ git+https://github.com/CCareaga/omnidata@main',
        'boosted_depth @ git+https://github.com/CCareaga/BoostingMonocularDepth@main',
        'intrinsic @ git+https://github.com/compphoto/intrinsic@d9741e99b2997e679c4055e7e1f773498b791288'
    ]
)
# --------------------------------------------------------------------------------