├── .github └── workflows │ └── github-actions-demo.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── __pycache__ ├── main.cpython-37.pyc ├── utils.cpython-36.pyc ├── utils.cpython-37.pyc └── utils.cpython-38.pyc ├── checkpoints └── config.yml ├── cog.yaml ├── config.yml.example ├── examples └── my_small_data │ ├── 0090.jpg │ ├── 0095.jpg │ ├── 0376.jpg │ ├── 0378.jpg │ ├── 1023_256.jpg │ ├── 1055_256.jpg │ ├── 141_256.jpg │ ├── 277_256.jpg │ ├── 292_256.jpg │ ├── 313_256.jpg │ ├── 386_256.jpg │ ├── 856_256.jpg │ ├── 985_256.jpg │ ├── IMG_4164.jpg │ ├── IMG_4467.jpg │ ├── IMG_4804.jpg │ └── places2_02.png ├── main.py ├── predict.py ├── requirements.txt ├── scripts ├── download_model.sh ├── fid_score.py ├── flist.py ├── inception.py └── metrics.py ├── segmentation_classes.txt ├── setup.cfg ├── src ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── config.cpython-36.pyc │ ├── config.cpython-37.pyc │ ├── dataset.cpython-36.pyc │ ├── dataset.cpython-37.pyc │ ├── edge_connect.cpython-36.pyc │ ├── edge_connect.cpython-37.pyc │ ├── loss.cpython-36.pyc │ ├── loss.cpython-37.pyc │ ├── metrics.cpython-36.pyc │ ├── metrics.cpython-37.pyc │ ├── models.cpython-36.pyc │ ├── models.cpython-37.pyc │ ├── networks.cpython-36.pyc │ ├── networks.cpython-37.pyc │ ├── segmentor_fcn.cpython-37.pyc │ ├── utils.cpython-36.pyc │ └── utils.cpython-37.pyc ├── config.py ├── dataset.py ├── edge_connect.py ├── loss.py ├── models.py ├── networks.py ├── segmentor_fcn.py └── utils.py └── test.py /.github/workflows/github-actions-demo.yml: -------------------------------------------------------------------------------- 1 | name: GitHub Actions start and python test 2 | on: 3 | # push: 4 | # branches: [ master ] 5 | pull_request: 6 | branches: [ master ] 7 | 8 | jobs: 9 | intro: 10 | name: first-interaction 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: First interaction 15 | uses: actions/first-interaction@v1.1.0 16 | with: 17 | # Token for the repository. Can be passed in using {{ secrets.GITHUB_TOKEN }} 18 | repo-token: ${{ secrets.GITHUB_TOKEN }} 19 | # Comment to post on an individual's first issue 20 | issue-message: "first issue" 21 | # Comment to post on an individual's first pull request 22 | pr-message: "welcome to my repo" 23 | 24 | 25 | python-testing: 26 | name: testing 27 | runs-on: ubuntu-latest 28 | strategy: 29 | fail-fast: false 30 | matrix: 31 | python-version: [3.7, 3.8] 32 | 33 | steps: 34 | - uses: actions/checkout@v2 35 | - name: Set up Python ${{ matrix.python-version }} 36 | uses: actions/setup-python@v2 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | - name: Install dependencies 40 | run: | 41 | python -m pip install --upgrade pip 42 | python -m pip install flake8 pytest 43 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 44 | - name: Lint with flake8 45 | run: | 46 | DIFF=$(git diff --name-only --diff-filter=b $(git merge-base HEAD $BRANCH)) 47 | # stop the build if there are Python syntax errors or undefined names 48 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 49 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 50 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 51 | 52 | # - name: Test with pytest 53 | # run: | 54 | # pytest 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.pyc -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ## creative commons 2 | 3 | # Attribution-NonCommercial 4.0 International 4 | 5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 6 | 7 | ### Using Creative Commons Public Licenses 8 | 9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 10 | 11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 12 | 13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 
14 | 15 | ## Creative Commons Attribution-NonCommercial 4.0 International Public License 16 | 17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 18 | 19 | ### Section 1 – Definitions. 20 | 21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 22 | 23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 24 | 25 | c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 26 | 27 | d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 28 | 29 | e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 30 | 31 | f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 32 | 33 | g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 34 | 35 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 36 | 37 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 38 | 39 | j. 
__Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 40 | 41 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 42 | 43 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 44 | 45 | ### Section 2 – Scope. 46 | 47 | a. ___License grant.___ 48 | 49 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 50 | 51 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 52 | 53 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 54 | 55 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 56 | 57 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 58 | 59 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 60 | 61 | 5. __Downstream recipients.__ 62 | 63 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 64 | 65 | B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 66 | 67 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 68 | 69 | b. ___Other rights.___ 70 | 71 | 1. 
Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 72 | 73 | 2. Patent and trademark rights are not licensed under this Public License. 74 | 75 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 76 | 77 | ### Section 3 – License Conditions. 78 | 79 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 80 | 81 | a. ___Attribution.___ 82 | 83 | 1. If You Share the Licensed Material (including in modified form), You must: 84 | 85 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 86 | 87 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 88 | 89 | ii. a copyright notice; 90 | 91 | iii. a notice that refers to this Public License; 92 | 93 | iv. a notice that refers to the disclaimer of warranties; 94 | 95 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 96 | 97 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 98 | 99 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 100 | 101 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 102 | 103 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 104 | 105 | 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. 106 | 107 | ### Section 4 – Sui Generis Database Rights. 108 | 109 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 110 | 111 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 112 | 113 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and 114 | 115 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 
116 | 117 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 118 | 119 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 120 | 121 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 122 | 123 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 124 | 125 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 126 | 127 | ### Section 6 – Term and Termination. 128 | 129 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 130 | 131 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 132 | 133 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 134 | 135 | 2. upon express reinstatement by the Licensor. 136 | 137 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 138 | 139 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 140 | 141 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 142 | 143 | ### Section 7 – Other Terms and Conditions. 144 | 145 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 146 | 147 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 148 | 149 | ### Section 8 – Interpretation. 150 | 151 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 152 | 153 | b. 
To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 154 | 155 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 156 | 157 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 158 | 159 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 160 | > 161 | > Creative Commons may be contacted at creativecommons.org 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automated-Objects-Removal-Inpainter 2 | 3 | [Demo and Docker image on Replicate](https://replicate.com/sujaykhandekar/object-removal) 4 | 5 | 6 | 7 | Automated Objects Removal Inpainter is a project that combines semantic segmentation and the EdgeConnect architecture, with minor changes, to remove specified objects from photos. The semantic segmentation code is adapted from PyTorch/torchvision, whereas the EdgeConnect code is adapted from [https://github.com/knazeri/edge-connect](https://github.com/knazeri/edge-connect). 8 | 9 | This project can remove objects belonging to a list of 20 different classes. It can be used as a photo-editing tool as well as for data augmentation. 10 | 11 | Python 3.8.5 and PyTorch 1.5.1 were used in this project. 12 | 13 | ## How does it work? 14 | 15 | 16 | 17 | A semantic segmentation model (DeepLabV3 or FCN, each with a ResNet-101 backbone) is combined with EdgeConnect. The pre-trained segmentation network detects the objects to remove and generates a mask around them; this mask, together with the input image with the masked portion removed, is fed to the EdgeConnect network. EdgeConnect uses a two-stage adversarial architecture in which an edge generator is followed by an image completion network. The EdgeConnect paper can be found [here](https://arxiv.org/abs/1901.00212) and the code in this [repo](https://github.com/knazeri/edge-connect).
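Below is a minimal sketch of the segmentation-to-mask step described above. It is not the repository's actual code (see src/segmentor_fcn.py and src/edge_connect.py for the real implementation); it assumes torchvision's pretrained DeepLabV3-ResNet101 and the PASCAL VOC class indices used by `--remove` (e.g. 3 = bird, 15 = person):

```python
# Sketch only: produce a binary "remove this" mask with a pretrained
# DeepLabV3-ResNet101, the same kind of model this project wires into EdgeConnect.
import torch
import torchvision
from torchvision import transforms
from PIL import Image

REMOVE_IDS = {3, 15}  # PASCAL VOC indices: 3 = bird, 15 = person

model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True).eval()

preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open("examples/my_small_data/0090.jpg").convert("RGB")
x = preprocess(img).unsqueeze(0)          # (1, 3, 256, 256)

with torch.no_grad():
    scores = model(x)["out"][0]           # (21, 256, 256) per-class scores
labels = scores.argmax(0)                 # (256, 256) per-pixel class ids

# 1 where a to-be-removed object was detected, 0 elsewhere.
mask = torch.zeros_like(labels, dtype=torch.uint8)
for cls in REMOVE_IDS:
    mask[labels == cls] = 1

# The image with the masked region blanked out, plus `mask`, is what the
# EdgeConnect edge generator and image completion network receive.
masked_image = x[0] * (1 - mask.float())
```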
18 | 19 | 20 | 21 | 22 | ## Prerequisite 23 | * python 3 24 | * pytorch 1.0.1 or later 25 | * NVIDIA GPU + CUDA cuDNN (optional) 26 | 27 | ## Installation 28 | * clone this repo 29 | ``` 30 | git clone https://github.com/sujaykhandekar/Automated-objects-removal-inpainter.git 31 | cd Automated-objects-removal-inpainter 32 | ``` 33 | or alternatively download the zip file. 34 | * install pytorch with this command 35 | ``` 36 | conda install pytorch==1.5.1 torchvision==0.6.1 -c pytorch 37 | ``` 38 | * install the other python requirements using this command 39 | ``` 40 | pip install -r requirements.txt 41 | ``` 42 | * Download one of the three pretrained EdgeConnect models and copy it into the ./checkpoints directory 43 | [Places2](https://drive.google.com/drive/folders/1qjieeThyse_iwJ0gZoJr-FERkgD5sm4y?usp=sharing) (option 1) 44 | [CelebA](https://drive.google.com/drive/folders/1nkLOhzWL-w2euo0U6amhz7HVzqNC5rqb) (option 2) 45 | [Paris-street-view](https://drive.google.com/drive/folders/1cGwDaZqDcqYU7kDuEbMXa9TP3uDJRBR1) (option 3) 46 | 47 | or alternatively you can use this command: 48 | ``` 49 | bash ./scripts/download_model.sh 50 | ``` 51 | 52 | ## Prediction/Test 53 | For a quick prediction you can run this command. If you don't have CUDA/a GPU, please run the second command below instead. 54 | ``` 55 | python test.py --input ./examples/my_small_data --output ./checkpoints/resultsfinal --remove 3 15 56 | ``` 57 | It takes the sample images in the ./examples/my_small_data directory and writes the results to the ./checkpoints/resultsfinal directory. You can replace these input/output directories with your own. 58 | The numbers after --remove specify the objects to be removed from the images. The above command removes 3 (bird) and 15 (person). Check segmentation_classes.txt for all removal options along with their numbers. 59 | 60 | Output images will all be 256x256. It takes around 10 minutes for 1000 images on an NVIDIA GeForce GTX 1650. 61 | 62 | For better quality but a slower runtime you can use this command: 63 | ``` 64 | python test.py --input ./examples/my_small_data --output ./checkpoints/resultsfinal --remove 3 15 --cpu yes 65 | ``` 66 | It runs the segmentation model on the CPU, which is about 5 times slower than on the GPU (the default). 67 | For other options, including a different segmentation model and other EdgeConnect parameters, make the corresponding modifications in the ./checkpoints/config.yml file.
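The `--remove` indices follow the PASCAL VOC ordering listed in segmentation_classes.txt (and hard-coded in predict.py). The helper below is hypothetical and not part of the repository; it only illustrates how to turn class names into the numbers the flag expects:

```python
# Hypothetical convenience helper: map class names to the ids used by --remove.
# The mapping mirrors the PASCAL VOC ordering in predict.py / segmentation_classes.txt.
OBJ2IDX = {
    "background": 0, "aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
    "bottle": 5, "bus": 6, "car": 7, "cat": 8, "chair": 9, "cow": 10,
    "dining table": 11, "dog": 12, "horse": 13, "motorbike": 14, "person": 15,
    "potted plant": 16, "sheep": 17, "sofa": 18, "train": 19, "tv/monitor": 20,
}

def remove_flag(*names):
    """Build the value for --remove, e.g. remove_flag('bird', 'person') -> '3 15'."""
    return " ".join(str(OBJ2IDX[name.lower()]) for name in names)

if __name__ == "__main__":
    print("python test.py --input ./examples/my_small_data "
          "--output ./checkpoints/resultsfinal --remove " + remove_flag("bird", "person"))
```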
68 | 69 | ## Training 70 | To train your own segmentation model, refer to this [repo](https://github.com/CSAILVision/semantic-segmentation-pytorch) and replace ./src/segmentor_fcn.py with your model. 71 | 72 | To train the EdgeConnect model, please refer to the original [EdgeConnect repo](https://github.com/knazeri/edge-connect); after training you can copy your model weights into ./checkpoints/ 73 | 74 | ## Some results 75 | 76 | 77 | ## Next Steps 78 | * The pretrained EdgeConnect models used in this project are trained on 256x256 images. To make output images the same size as the input, two approaches can be used: you can train your own EdgeConnect model on bigger images, or you can create 256x256 sub-images for every object detected in the image and then merge them back together after passing them through EdgeConnect to reconstruct the original-sized image. A similar approach has been used in this [repo](https://github.com/javirk/Person_remover) 79 | * To detect objects not present in the segmentation classes, you can train your own segmentation model or use pretrained segmentation models from this [repo](https://github.com/CSAILVision/semantic-segmentation-pytorch), which has 150 different categories available. 80 | * It is also possible to combine OpenCV's feature matching with EdgeConnect's edge prediction to highlight and create masks for relevant objects based on a single mask created by the user. I may try this part myself. 81 | 82 | ## License 83 | Licensed under a [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/) license. 84 | 85 | Except where otherwise noted, this content is published under a [CC BY-NC](https://github.com/knazeri/edge-connect) license, which means that you can copy, remix, transform and build upon the content as long as you do not use the material for commercial purposes and give appropriate credit and provide a link to the license. 86 | 87 | ## Citation 88 | ``` 89 | @inproceedings{nazeri2019edgeconnect, 90 | title={EdgeConnect: Generative Image Inpainting with Adversarial Edge Learning}, 91 | author={Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran}, 92 | journal={arXiv preprint}, 93 | year={2019}, 94 | } 95 | 96 | @InProceedings{Nazeri_2019_ICCV, 97 | title = {EdgeConnect: Structure Guided Image Inpainting using Edge Prediction}, 98 | author = {Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran}, 99 | booktitle = {The IEEE International Conference on Computer Vision (ICCV) Workshops}, 100 | month = {Oct}, 101 | year = {2019} 102 | } 103 | ``` 104 | -------------------------------------------------------------------------------- /__pycache__/main.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/__pycache__/main.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /checkpoints/config.yml:
-------------------------------------------------------------------------------- 1 | MODE: 1 # 1: train, 2: test, 3: eval 2 | MODEL: 1 # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model 3 | MASK: 3 # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half) 4 | EDGE: 1 # 1: canny, 2: external 5 | NMS: 1 # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny 6 | SEED: 10 # random seed 7 | GPU: [0] # list of gpu ids 8 | DEBUG: 0 # turns on debugging mode 9 | VERBOSE: 0 # turns on verbose mode in the output console 10 | 11 | TRAIN_FLIST: ./datasets/places2_train.flist 12 | VAL_FLIST: ./datasets/places2_val.flist 13 | TEST_FLIST: ./datasets/places2_test.flist 14 | 15 | TRAIN_EDGE_FLIST: ./datasets/places2_edges_train.flist 16 | VAL_EDGE_FLIST: ./datasets/places2_edges_val.flist 17 | TEST_EDGE_FLIST: ./datasets/places2_edges_test.flist 18 | 19 | TRAIN_MASK_FLIST: ./datasets/masks_train.flist 20 | VAL_MASK_FLIST: ./datasets/masks_val.flist 21 | TEST_MASK_FLIST: ./datasets/masks_test.flist 22 | 23 | LR: 0.0001 # learning rate 24 | D2G_LR: 0.1 # discriminator/generator learning rate ratio 25 | BETA1: 0.0 # adam optimizer beta1 26 | BETA2: 0.9 # adam optimizer beta2 27 | BATCH_SIZE: 8 # input batch size for training 28 | INPUT_SIZE: 256 # input image size for training 0 for original size 29 | SIGMA: 2 # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge) 30 | MAX_ITERS: 2e6 # maximum number of iterations to train the model 31 | 32 | EDGE_THRESHOLD: 0.5 # edge detection threshold 33 | L1_LOSS_WEIGHT: 1 # l1 loss weight 34 | FM_LOSS_WEIGHT: 10 # feature-matching loss weight 35 | STYLE_LOSS_WEIGHT: 250 # style loss weight 36 | CONTENT_LOSS_WEIGHT: 0.1 # perceptual loss weight 37 | INPAINT_ADV_LOSS_WEIGHT: 0.1 # adversarial loss weight 38 | 39 | GAN_LOSS: nsgan # nsgan | lsgan | hinge 40 | GAN_POOL_SIZE: 0 # fake images pool size 41 | 42 | SAVE_INTERVAL: 1000 # how many iterations to wait before saving model (0: never) 43 | SAMPLE_INTERVAL: 1000 # how many iterations to wait before sampling (0: never) 44 | SAMPLE_SIZE: 12 # number of images to sample 45 | EVAL_INTERVAL: 0 # how many iterations to wait before model evaluation (0: never) 46 | LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never) 47 | SEG_NETWORK: 0 # 0:DeepLabV3 resnet 101 segmentation , 1: FCN resnet 101 segmentation -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | 8 | # a list of ubuntu apt packages to install 9 | # system_packages: 10 | # - "libgl1-mesa-glx" 11 | # - "libglib2.0-0" 12 | 13 | # python version in the form '3.8' or '3.8.12' 14 | python_version: "3.8" 15 | 16 | # a list of packages in the format == 17 | python_packages: 18 | - "ipython==7.33.0" 19 | # - "numpy==1.19.4" 20 | # - "torch==1.11.0" 21 | # - "torchvision==0.12.0" 22 | - "torch==1.5.1" 23 | - "torchvision==0.6.1" 24 | - "Pillow==8.0.1" 25 | 26 | - "numpy==1.19.1" 27 | - "scipy==1.5.2" 28 | - "future==0.18.2" 29 | - "matplotlib==3.3.0" 30 | # - "pillow==6.2.0" 31 | - "opencv-python==4.3.0.36" 32 | - "scikit-image==0.17.2" 33 | - "pyaml==20.4.0" 34 | 35 | # commands run 
after the environment is setup 36 | run: 37 | - "apt-get update && apt-get install -y cmake" 38 | # - "sudo apt-get install python python-pip build-essential cmake" 39 | # - "pip install tqdm gdown kornia scipy opencv-python dlib moviepy lpips aubio ninja" 40 | 41 | # predict.py defines how predictions are run on your model 42 | predict: "predict.py:Predictor" -------------------------------------------------------------------------------- /config.yml.example: -------------------------------------------------------------------------------- 1 | MODE: 1 # 1: train, 2: test, 3: eval 2 | MODEL: 1 # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model 3 | MASK: 3 # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half) 4 | EDGE: 1 # 1: canny, 2: external 5 | NMS: 1 # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny 6 | SEED: 10 # random seed 7 | GPU: [0] # list of gpu ids 8 | DEBUG: 0 # turns on debugging mode 9 | VERBOSE: 0 # turns on verbose mode in the output console 10 | 11 | TRAIN_FLIST: ./datasets/places2_train.flist 12 | VAL_FLIST: ./datasets/places2_val.flist 13 | TEST_FLIST: ./datasets/places2_test.flist 14 | 15 | TRAIN_EDGE_FLIST: ./datasets/places2_edges_train.flist 16 | VAL_EDGE_FLIST: ./datasets/places2_edges_val.flist 17 | TEST_EDGE_FLIST: ./datasets/places2_edges_test.flist 18 | 19 | TRAIN_MASK_FLIST: ./datasets/masks_train.flist 20 | VAL_MASK_FLIST: ./datasets/masks_val.flist 21 | TEST_MASK_FLIST: ./datasets/masks_test.flist 22 | 23 | LR: 0.0001 # learning rate 24 | D2G_LR: 0.1 # discriminator/generator learning rate ratio 25 | BETA1: 0.0 # adam optimizer beta1 26 | BETA2: 0.9 # adam optimizer beta2 27 | BATCH_SIZE: 8 # input batch size for training 28 | INPUT_SIZE: 256 # input image size for training 0 for original size 29 | SIGMA: 2 # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge) 30 | MAX_ITERS: 2e6 # maximum number of iterations to train the model 31 | 32 | EDGE_THRESHOLD: 0.5 # edge detection threshold 33 | L1_LOSS_WEIGHT: 1 # l1 loss weight 34 | FM_LOSS_WEIGHT: 10 # feature-matching loss weight 35 | STYLE_LOSS_WEIGHT: 250 # style loss weight 36 | CONTENT_LOSS_WEIGHT: 0.1 # perceptual loss weight 37 | INPAINT_ADV_LOSS_WEIGHT: 0.1 # adversarial loss weight 38 | 39 | GAN_LOSS: nsgan # nsgan | lsgan | hinge 40 | GAN_POOL_SIZE: 0 # fake images pool size 41 | 42 | SAVE_INTERVAL: 1000 # how many iterations to wait before saving model (0: never) 43 | SAMPLE_INTERVAL: 1000 # how many iterations to wait before sampling (0: never) 44 | SAMPLE_SIZE: 12 # number of images to sample 45 | EVAL_INTERVAL: 0 # how many iterations to wait before model evaluation (0: never) 46 | LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never) -------------------------------------------------------------------------------- /examples/my_small_data/0090.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/0090.jpg -------------------------------------------------------------------------------- /examples/my_small_data/0095.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/0095.jpg -------------------------------------------------------------------------------- /examples/my_small_data/0376.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/0376.jpg -------------------------------------------------------------------------------- /examples/my_small_data/0378.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/0378.jpg -------------------------------------------------------------------------------- /examples/my_small_data/1023_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/1023_256.jpg -------------------------------------------------------------------------------- /examples/my_small_data/1055_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/1055_256.jpg -------------------------------------------------------------------------------- /examples/my_small_data/141_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/141_256.jpg -------------------------------------------------------------------------------- /examples/my_small_data/277_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/277_256.jpg -------------------------------------------------------------------------------- /examples/my_small_data/292_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/292_256.jpg -------------------------------------------------------------------------------- /examples/my_small_data/313_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/313_256.jpg -------------------------------------------------------------------------------- /examples/my_small_data/386_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/386_256.jpg -------------------------------------------------------------------------------- /examples/my_small_data/856_256.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/856_256.jpg -------------------------------------------------------------------------------- /examples/my_small_data/985_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/985_256.jpg -------------------------------------------------------------------------------- /examples/my_small_data/IMG_4164.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/IMG_4164.jpg -------------------------------------------------------------------------------- /examples/my_small_data/IMG_4467.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/IMG_4467.jpg -------------------------------------------------------------------------------- /examples/my_small_data/IMG_4804.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/IMG_4804.jpg -------------------------------------------------------------------------------- /examples/my_small_data/places2_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/places2_02.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import numpy as np 5 | import torch 6 | import argparse 7 | from shutil import copyfile 8 | from src.config import Config 9 | from src.edge_connect import EdgeConnect 10 | 11 | 12 | def main(mode=None): 13 | r"""starts the model 14 | 15 | """ 16 | 17 | config = load_config(mode) 18 | 19 | 20 | # cuda visble devices 21 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU) 22 | 23 | 24 | # init device 25 | if torch.cuda.is_available(): 26 | config.DEVICE = torch.device("cuda") 27 | torch.backends.cudnn.benchmark = True # cudnn auto-tuner 28 | else: 29 | config.DEVICE = torch.device("cpu") 30 | 31 | 32 | 33 | # set cv2 running threads to 1 (prevents deadlocks with pytorch dataloader) 34 | cv2.setNumThreads(0) 35 | 36 | 37 | # initialize random seed 38 | torch.manual_seed(config.SEED) 39 | torch.cuda.manual_seed_all(config.SEED) 40 | np.random.seed(config.SEED) 41 | random.seed(config.SEED) 42 | 43 | 44 | 45 | # build the model and initialize 46 | model = EdgeConnect(config) 47 | model.load() 48 | 49 | 50 | 51 | # model test 52 | print('\nstart testing...\n') 53 | model.test() 54 | 55 | 56 | 57 | def load_config(mode=None): 58 | r"""loads model config 59 | 60 | """ 61 | 62 | parser = argparse.ArgumentParser() 63 | 
parser.add_argument('--path', '--checkpoints', type=str, default='./checkpoints', help='model checkpoints path (default: ./checkpoints)') 64 | parser.add_argument('--model', type=int, choices=[1, 2, 3, 4], help='1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model') 65 | 66 | # test mode 67 | parser.add_argument('--input', type=str, help='path to the input images directory or an input image') 68 | parser.add_argument('--edge', type=str, help='path to the edges directory or an edge file') 69 | parser.add_argument('--output', type=str, help='path to the output directory') 70 | parser.add_argument('--remove', nargs= '*' ,type=int, help='objects to remove') 71 | parser.add_argument('--cpu', type=str, help='machine to run segmentation model on') 72 | args = parser.parse_args() 73 | 74 | #if path for checkpoint not given 75 | if args.path is None: 76 | args.path='./checkpoints' 77 | config_path = os.path.join(args.path, 'config.yml') 78 | 79 | # create checkpoints path if does't exist 80 | if not os.path.exists(args.path): 81 | os.makedirs(args.path) 82 | 83 | # copy config template if does't exist 84 | if not os.path.exists(config_path): 85 | copyfile('./config.yml.example', config_path) 86 | 87 | # load config file 88 | config = Config(config_path) 89 | 90 | 91 | # test mode 92 | config.MODE = 2 93 | config.MODEL = args.model if args.model is not None else 3 94 | config.OBJECTS = args.remove if args.remove is not None else [3,15] 95 | config.SEG_DEVICE = 'cpu' if args.cpu is not None else 'cuda' 96 | config.INPUT_SIZE = 256 97 | if args.input is not None: 98 | config.TEST_FLIST = args.input 99 | 100 | if args.edge is not None: 101 | config.TEST_EDGE_FLIST = args.edge 102 | if args.output is not None: 103 | config.RESULTS = args.output 104 | else: 105 | if not os.path.exists('./results_images'): 106 | os.makedirs('./results_images') 107 | config.RESULTS = './results_images' 108 | 109 | 110 | 111 | 112 | 113 | return config 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from shutil import copyfile 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torchvision 10 | import torchvision.transforms 11 | from PIL import Image 12 | 13 | 14 | # cog 15 | from cog import BasePredictor, Input, Path 16 | 17 | from src.config import Config 18 | from src.edge_connect import EdgeConnect 19 | 20 | import tempfile 21 | 22 | 23 | # Maps object to index 24 | obj2idx = { 25 | "Background":0, "Aeroplane":1, "bicycle":2, "bird":3, "boat":4, "bottle":5, "bus":6, "car":7, "cat":8, 26 | "chair":9, "cow":10, "dining table":11, "dog":12, "horse":13, "motorbike":14, "person":15, 27 | "potted plant":16, "sheep":17, "sofa":18, "train":19, "tv/monitor":20 28 | } 29 | 30 | 31 | 32 | def load_config(mode=None, objects_to_remove=None): 33 | print('Object(s) to remove:', objects_to_remove) 34 | 35 | 36 | # load config file 37 | path = "./checkpoints" 38 | config_path = os.path.join(path, "config.yml") 39 | 40 | # create checkpoints path if does't exist 41 | if not os.path.exists(path): 42 | os.makedirs(path) 43 | 44 | # copy config template if does't exist 45 | if not os.path.exists(config_path): 46 | copyfile("./config.yml.example", config_path) 47 | 48 | # load config file 49 | config = Config(config_path) 50 | 51 | # test mode 52 | config.MODE = 
mode 53 | config.MODEL = 3 54 | config.OBJECTS = objects_to_remove 55 | config.SEG_DEVICE = "cuda" 56 | config.INPUT_SIZE = 256 57 | 58 | # outputs 59 | if not os.path.exists("./results_images"): 60 | os.makedirs("./results_images") 61 | config.RESULTS = "./results_images" 62 | return config 63 | 64 | 65 | # Instantiate Cog Predictor 66 | class Predictor(BasePredictor): 67 | def setup(self): 68 | 69 | # Select torch device 70 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 71 | 72 | def predict(self, 73 | image_path: Path = Input(description="Input image (ideally a square image)"), 74 | objects_to_remove: str = Input(description="Object(s) to remove (separate with comma, e.g. car,cat,bird). See full list of names at https://github.com/sujaykhandekar/Automated-objects-removal-inpainter/blob/master/segmentation_classes.txt", default='person,car'), 75 | 76 | ) -> Path: 77 | 78 | # format input image 79 | image_path = str(image_path) 80 | image = Image.open(image_path).convert('RGB') 81 | image.save(image_path) # resave formatted image 82 | 83 | # parse objects to remove 84 | objects_to_remove = objects_to_remove.split(',') 85 | objects_to_remove = [obj2idx[x] for x in objects_to_remove] 86 | 87 | mode = 2 # 1: train, 2: test, 3: eal 88 | self.config = load_config(mode, objects_to_remove=objects_to_remove) 89 | # set cv2 running threads to 1 (prevents deadlocks with pytorch dataloader) 90 | cv2.setNumThreads(0) 91 | 92 | # initialize random seed 93 | torch.manual_seed(self.config.SEED) 94 | torch.cuda.manual_seed_all(self.config.SEED) 95 | np.random.seed(self.config.SEED) 96 | random.seed(self.config.SEED) 97 | 98 | # save to path 99 | self.config.TEST_FLIST = image_path 100 | 101 | # build the model and initialize 102 | model = EdgeConnect(self.config) 103 | model.load() 104 | 105 | # model test 106 | output_image = model.test() 107 | output_image = output_image.cpu().numpy() 108 | output_image = Image.fromarray(np.uint8(output_image)).convert('RGB') 109 | 110 | # save output image as Cog Path object 111 | output_path = Path(tempfile.mkdtemp()) / "output.png" 112 | output_image.save(output_path) 113 | print(output_path) 114 | return output_path 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy ~= 1.19.1 2 | scipy ~= 1.5.2 3 | future ~= 0.18.2 4 | matplotlib ~= 3.3.0 5 | pillow >= 6.2.0 6 | opencv-python ~= 4.3.0.36 7 | scikit-image ~= 0.17.2 8 | pyaml ~= 20.4.0 9 | -------------------------------------------------------------------------------- /scripts/download_model.sh: -------------------------------------------------------------------------------- 1 | DIR='./checkpoints' 2 | URL='https://drive.google.com/uc?export=download&id=1IrlFQGTpdQYdPeZIEgGUaSFpbYtNpekA' 3 | 4 | echo "Downloading pre-trained models..." 5 | mkdir -p $DIR 6 | FILE="$(curl -sc /tmp/gcokie "${URL}" | grep -o '="uc-name.*' | sed 's/.*">//;s/<.a> .*//')" 7 | curl -Lb /tmp/gcokie "${URL}&confirm=$(awk '/_warning_/ {print $NF}' /tmp/gcokie)" -o "$DIR/${FILE}" 8 | 9 | echo "Extracting pre-trained models..." 10 | cd $DIR 11 | unzip $FILE 12 | rm $FILE 13 | 14 | echo "Download success." 
-------------------------------------------------------------------------------- /scripts/fid_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 6 | When run as a stand-alone program, it compares the distribution of 7 | images that are stored as PNG/JPEG at a specified location with a 8 | distribution given by summary statistics (in pickle format). 9 | The FID is calculated by assuming that X_1 and X_2 are the activations of 10 | the pool_3 layer of the inception net for generated samples and real world 11 | samples respectivly. 12 | See --help to see further details. 13 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 14 | of Tensorflow 15 | Copyright 2018 Institute of Bioinformatics, JKU Linz 16 | Licensed under the Apache License, Version 2.0 (the "License"); 17 | you may not use this file except in compliance with the License. 18 | You may obtain a copy of the License at 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | Unless required by applicable law or agreed to in writing, software 21 | distributed under the License is distributed on an "AS IS" BASIS, 22 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | See the License for the specific language governing permissions and 24 | limitations under the License. 25 | """ 26 | import os 27 | import pathlib 28 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 29 | 30 | import torch 31 | import numpy as np 32 | from scipy.misc import imread 33 | from scipy import linalg 34 | from torch.autograd import Variable 35 | from torch.nn.functional import adaptive_avg_pool2d 36 | 37 | from inception import InceptionV3 38 | 39 | 40 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 41 | parser.add_argument('--path', type=str, nargs=2, help=('Path to the generated images or to .npz statistic files')) 42 | parser.add_argument('--batch-size', type=int, default=64, help='Batch size to use') 43 | parser.add_argument('--dims', type=int, default=2048, choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), help=('Dimensionality of Inception features to use. By default, uses pool3 features')) 44 | parser.add_argument('-c', '--gpu', default='', type=str, help='GPU to use (leave blank for CPU only)') 45 | 46 | 47 | def get_activations(images, model, batch_size=64, dims=2048, 48 | cuda=False, verbose=False): 49 | """Calculates the activations of the pool_3 layer for all images. 50 | Params: 51 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 52 | must lie between 0 and 1. 53 | -- model : Instance of inception model 54 | -- batch_size : the images numpy array is split into batches with 55 | batch size batch_size. A reasonable batch size depends 56 | on the hardware. 57 | -- dims : Dimensionality of features returned by Inception 58 | -- cuda : If set to True, use GPU 59 | -- verbose : If set to True and parameter out_step is given, the number 60 | of calculated batches is reported. 61 | Returns: 62 | -- A numpy array of dimension (num images, dims) that contains the 63 | activations of the given tensor when feeding inception with the 64 | query tensor. 
65 | """ 66 | model.eval() 67 | 68 | d0 = images.shape[0] 69 | if batch_size > d0: 70 | print(('Warning: batch size is bigger than the data size. ' 71 | 'Setting batch size to data size')) 72 | batch_size = d0 73 | 74 | n_batches = d0 // batch_size 75 | n_used_imgs = n_batches * batch_size 76 | 77 | pred_arr = np.empty((n_used_imgs, dims)) 78 | for i in range(n_batches): 79 | if verbose: 80 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 81 | end='', flush=True) 82 | start = i * batch_size 83 | end = start + batch_size 84 | 85 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 86 | batch = Variable(batch, volatile=True) 87 | if cuda: 88 | batch = batch.cuda() 89 | 90 | pred = model(batch)[0] 91 | 92 | # If model output is not scalar, apply global spatial average pooling. 93 | # This happens if you choose a dimensionality not equal 2048. 94 | if pred.shape[2] != 1 or pred.shape[3] != 1: 95 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 96 | 97 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 98 | 99 | if verbose: 100 | print(' done') 101 | 102 | return pred_arr 103 | 104 | 105 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 106 | """Numpy implementation of the Frechet Distance. 107 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 108 | and X_2 ~ N(mu_2, C_2) is 109 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 110 | Stable version by Dougal J. Sutherland. 111 | Params: 112 | -- mu1 : Numpy array containing the activations of a layer of the 113 | inception net (like returned by the function 'get_predictions') 114 | for generated samples. 115 | -- mu2 : The sample mean over activations, precalculated on an 116 | representive data set. 117 | -- sigma1: The covariance matrix over activations for generated samples. 118 | -- sigma2: The covariance matrix over activations, precalculated on an 119 | representive data set. 120 | Returns: 121 | -- : The Frechet Distance. 122 | """ 123 | 124 | mu1 = np.atleast_1d(mu1) 125 | mu2 = np.atleast_1d(mu2) 126 | 127 | sigma1 = np.atleast_2d(sigma1) 128 | sigma2 = np.atleast_2d(sigma2) 129 | 130 | assert mu1.shape == mu2.shape, \ 131 | 'Training and test mean vectors have different lengths' 132 | assert sigma1.shape == sigma2.shape, \ 133 | 'Training and test covariances have different dimensions' 134 | 135 | diff = mu1 - mu2 136 | 137 | # Product might be almost singular 138 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 139 | if not np.isfinite(covmean).all(): 140 | msg = ('fid calculation produces singular product; ' 141 | 'adding %s to diagonal of cov estimates') % eps 142 | print(msg) 143 | offset = np.eye(sigma1.shape[0]) * eps 144 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 145 | 146 | # Numerical error might give slight imaginary component 147 | if np.iscomplexobj(covmean): 148 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 149 | m = np.max(np.abs(covmean.imag)) 150 | raise ValueError('Imaginary component {}'.format(m)) 151 | covmean = covmean.real 152 | 153 | tr_covmean = np.trace(covmean) 154 | 155 | return (diff.dot(diff) + np.trace(sigma1) + 156 | np.trace(sigma2) - 2 * tr_covmean) 157 | 158 | 159 | def calculate_activation_statistics(images, model, batch_size=64, 160 | dims=2048, cuda=False, verbose=False): 161 | """Calculation of the statistics used by the FID. 162 | Params: 163 | -- images : Numpy array of dimension (n_images, 3, hi, wi). 
The values 164 | must lie between 0 and 1. 165 | -- model : Instance of inception model 166 | -- batch_size : The images numpy array is split into batches with 167 | batch size batch_size. A reasonable batch size 168 | depends on the hardware. 169 | -- dims : Dimensionality of features returned by Inception 170 | -- cuda : If set to True, use GPU 171 | -- verbose : If set to True and parameter out_step is given, the 172 | number of calculated batches is reported. 173 | Returns: 174 | -- mu : The mean over samples of the activations of the pool_3 layer of 175 | the inception model. 176 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 177 | the inception model. 178 | """ 179 | act = get_activations(images, model, batch_size, dims, cuda, verbose) 180 | mu = np.mean(act, axis=0) 181 | sigma = np.cov(act, rowvar=False) 182 | return mu, sigma 183 | 184 | 185 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda): 186 | npz_file = os.path.join(path, 'statistics.npz') 187 | if os.path.exists(npz_file): 188 | f = np.load(npz_file) 189 | m, s = f['mu'][:], f['sigma'][:] 190 | f.close() 191 | else: 192 | path = pathlib.Path(path) 193 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 194 | 195 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 196 | 197 | # Bring images to shape (B, 3, H, W) 198 | imgs = imgs.transpose((0, 3, 1, 2)) 199 | 200 | # Rescale images to be between 0 and 1 201 | imgs /= 255 202 | 203 | m, s = calculate_activation_statistics(imgs, model, batch_size, dims, cuda) 204 | np.savez(npz_file, mu=m, sigma=s) 205 | 206 | return m, s 207 | 208 | 209 | def calculate_fid_given_paths(paths, batch_size, cuda, dims): 210 | """Calculates the FID of two paths""" 211 | for p in paths: 212 | if not os.path.exists(p): 213 | raise RuntimeError('Invalid path: %s' % p) 214 | 215 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 216 | 217 | model = InceptionV3([block_idx]) 218 | if cuda: 219 | model.cuda() 220 | 221 | print('calculate path1 statistics...') 222 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, cuda) 223 | print('calculate path2 statistics...') 224 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, cuda) 225 | print('calculate frechet distance...') 226 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 227 | 228 | return fid_value 229 | 230 | 231 | if __name__ == '__main__': 232 | args = parser.parse_args() 233 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 234 | 235 | fid_value = calculate_fid_given_paths(args.path, 236 | args.batch_size, 237 | args.gpu != '', 238 | args.dims) 239 | print('FID: ', round(fid_value, 4)) 240 | -------------------------------------------------------------------------------- /scripts/flist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--path', type=str, help='path to the dataset') 7 | parser.add_argument('--output', type=str, help='path to the file list') 8 | args = parser.parse_args() 9 | 10 | ext = {'.JPG', '.JPEG', '.PNG', '.TIF', 'TIFF'} 11 | 12 | images = [] 13 | for root, dirs, files in os.walk(args.path): 14 | print('loading ' + root) 15 | for file in files: 16 | if os.path.splitext(file)[1].upper() in ext: 17 | images.append(os.path.join(root, file)) 18 | 19 | images = sorted(images) 20 | np.savetxt(args.output, images, fmt='%s') 
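Taken together, the two scripts above form a small evaluation pipeline: flist.py walks --path and writes the sorted image paths to --output, one per line (note that the last entry in ext lacks a leading dot, so '.tiff' files are silently skipped), while fid_score.py compares two image folders. A hedged usage sketch with placeholder directory names, assuming it is run from inside scripts/ so that 'from inception import InceptionV3' resolves:

# Equivalent to: python fid_score.py --path real_dir fake_dir --batch-size 64
# 'real_dir' and 'fake_dir' are placeholders, not paths from this repository.
from fid_score import calculate_fid_given_paths

fid = calculate_fid_given_paths(['real_dir', 'fake_dir'],
                                batch_size=64, cuda=False, dims=2048)
print('FID:', round(fid, 4))

Because _compute_statistics_of_path caches the computed mean and covariance as statistics.npz inside each directory, repeated runs against the same folder reuse the saved statistics instead of re-extracting features.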
-------------------------------------------------------------------------------- /scripts/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling featurs 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | Parameters 28 | ---------- 29 | output_blocks : list of int 30 | Indices of blocks to return features of. Possible values are: 31 | - 0: corresponds to output of first max pooling 32 | - 1: corresponds to output of second max pooling 33 | - 2: corresponds to output which is fed to aux classifier 34 | - 3: corresponds to output of final average pooling 35 | resize_input : bool 36 | If true, bilinearly resizes input to width and height 299 before 37 | feeding input to model. As the network without fully connected 38 | layers is fully convolutional, it should be able to handle inputs 39 | of arbitrary size, so resizing might not be strictly needed 40 | normalize_input : bool 41 | If true, normalizes the input to the statistics the pretrained 42 | Inception network expects 43 | requires_grad : bool 44 | If true, parameters of the model require gradient. 
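The BLOCK_INDEX_BY_DIM table above is also what the --dims flag of scripts/fid_score.py selects from; more than one block can be requested at once, in which case forward returns one feature map per block, sorted by block index. A small illustrative sketch (assumes the class defined in this file):

# Illustrative: request the 768-d pre-aux and 2048-d pool3 features in one pass.
blocks = [InceptionV3.BLOCK_INDEX_BY_DIM[768], InceptionV3.BLOCK_INDEX_BY_DIM[2048]]
extractor = InceptionV3(output_blocks=blocks)   # forward will return blocks 2 and 3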
Possibly useful 45 | for finetuning the network 46 | """ 47 | super(InceptionV3, self).__init__() 48 | 49 | self.resize_input = resize_input 50 | self.normalize_input = normalize_input 51 | self.output_blocks = sorted(output_blocks) 52 | self.last_needed_block = max(output_blocks) 53 | 54 | assert self.last_needed_block <= 3, \ 55 | 'Last possible output block index is 3' 56 | 57 | self.blocks = nn.ModuleList() 58 | 59 | inception = models.inception_v3(pretrained=True) 60 | 61 | # Block 0: input to maxpool1 62 | block0 = [ 63 | inception.Conv2d_1a_3x3, 64 | inception.Conv2d_2a_3x3, 65 | inception.Conv2d_2b_3x3, 66 | nn.MaxPool2d(kernel_size=3, stride=2) 67 | ] 68 | self.blocks.append(nn.Sequential(*block0)) 69 | 70 | # Block 1: maxpool1 to maxpool2 71 | if self.last_needed_block >= 1: 72 | block1 = [ 73 | inception.Conv2d_3b_1x1, 74 | inception.Conv2d_4a_3x3, 75 | nn.MaxPool2d(kernel_size=3, stride=2) 76 | ] 77 | self.blocks.append(nn.Sequential(*block1)) 78 | 79 | # Block 2: maxpool2 to aux classifier 80 | if self.last_needed_block >= 2: 81 | block2 = [ 82 | inception.Mixed_5b, 83 | inception.Mixed_5c, 84 | inception.Mixed_5d, 85 | inception.Mixed_6a, 86 | inception.Mixed_6b, 87 | inception.Mixed_6c, 88 | inception.Mixed_6d, 89 | inception.Mixed_6e, 90 | ] 91 | self.blocks.append(nn.Sequential(*block2)) 92 | 93 | # Block 3: aux classifier to final avgpool 94 | if self.last_needed_block >= 3: 95 | block3 = [ 96 | inception.Mixed_7a, 97 | inception.Mixed_7b, 98 | inception.Mixed_7c, 99 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 100 | ] 101 | self.blocks.append(nn.Sequential(*block3)) 102 | 103 | for param in self.parameters(): 104 | param.requires_grad = requires_grad 105 | 106 | def forward(self, inp): 107 | """Get Inception feature maps 108 | Parameters 109 | ---------- 110 | inp : torch.autograd.Variable 111 | Input tensor of shape Bx3xHxW. 
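A quick sketch of running the wrapper above as a feature extractor (toy input; downloading the pretrained torchvision weights on first use is assumed):

import torch

# Toy sketch: pool3 features for a batch of 2 random images with values in (0, 1).
model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).eval()
x = torch.rand(2, 3, 256, 256)          # resized to 299x299 internally (resize_input=True)
with torch.no_grad():
    feats = model(x)[0]                 # -> torch.Size([2, 2048, 1, 1])
print(feats.shape)

On recent PyTorch versions F.upsample, used in forward below, is deprecated in favour of F.interpolate, but it still performs the same bilinear resize (with a deprecation warning).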
Values are expected to be in 112 | range (0, 1) 113 | Returns 114 | ------- 115 | List of torch.autograd.Variable, corresponding to the selected output 116 | block, sorted ascending by index 117 | """ 118 | outp = [] 119 | x = inp 120 | 121 | if self.resize_input: 122 | x = F.upsample(x, size=(299, 299), mode='bilinear') 123 | 124 | if self.normalize_input: 125 | x = x.clone() 126 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 127 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 128 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 129 | 130 | for idx, block in enumerate(self.blocks): 131 | x = block(x) 132 | if idx in self.output_blocks: 133 | outp.append(x) 134 | 135 | if idx == self.last_needed_block: 136 | break 137 | 138 | return outp 139 | -------------------------------------------------------------------------------- /scripts/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | 5 | from glob import glob 6 | from ntpath import basename 7 | from scipy.misc import imread 8 | from skimage.measure import compare_ssim 9 | from skimage.measure import compare_psnr 10 | from skimage.color import rgb2gray 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description='script to compute all statistics') 15 | parser.add_argument('--data-path', help='Path to ground truth data', type=str) 16 | parser.add_argument('--output-path', help='Path to output data', type=str) 17 | parser.add_argument('--debug', default=0, help='Debug', type=int) 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | def compare_mae(img_true, img_test): 23 | img_true = img_true.astype(np.float32) 24 | img_test = img_test.astype(np.float32) 25 | return np.sum(np.abs(img_true - img_test)) / np.sum(img_true + img_test) 26 | 27 | 28 | args = parse_args() 29 | for arg in vars(args): 30 | print('[%s] =' % arg, getattr(args, arg)) 31 | 32 | path_true = args.data_path 33 | path_pred = args.output_path 34 | 35 | psnr = [] 36 | ssim = [] 37 | mae = [] 38 | names = [] 39 | index = 1 40 | 41 | files = list(glob(path_true + '/*.jpg')) + list(glob(path_true + '/*.png')) 42 | for fn in sorted(files): 43 | name = basename(str(fn)) 44 | names.append(name) 45 | 46 | img_gt = (imread(str(fn)) / 255.0).astype(np.float32) 47 | img_pred = (imread(path_pred + '/' + basename(str(fn))) / 255.0).astype(np.float32) 48 | 49 | img_gt = rgb2gray(img_gt) 50 | img_pred = rgb2gray(img_pred) 51 | 52 | if args.debug != 0: 53 | plt.subplot('121') 54 | plt.imshow(img_gt) 55 | plt.title('Groud truth') 56 | plt.subplot('122') 57 | plt.imshow(img_pred) 58 | plt.title('Output') 59 | plt.show() 60 | 61 | psnr.append(compare_psnr(img_gt, img_pred, data_range=1)) 62 | ssim.append(compare_ssim(img_gt, img_pred, data_range=1, win_size=51)) 63 | mae.append(compare_mae(img_gt, img_pred)) 64 | if np.mod(index, 100) == 0: 65 | print( 66 | str(index) + ' images processed', 67 | "PSNR: %.4f" % round(np.mean(psnr), 4), 68 | "SSIM: %.4f" % round(np.mean(ssim), 4), 69 | "MAE: %.4f" % round(np.mean(mae), 4), 70 | ) 71 | index += 1 72 | 73 | np.savez(args.output_path + '/metrics.npz', psnr=psnr, ssim=ssim, mae=mae, names=names) 74 | print( 75 | "PSNR: %.4f" % round(np.mean(psnr), 4), 76 | "PSNR Variance: %.4f" % round(np.var(psnr), 4), 77 | "SSIM: %.4f" % round(np.mean(ssim), 4), 78 | "SSIM Variance: %.4f" % round(np.var(ssim), 4), 79 | "MAE: %.4f" % round(np.mean(mae), 4), 80 | "MAE Variance: %.4f" % 
round(np.var(mae), 4) 81 | ) 82 | -------------------------------------------------------------------------------- /segmentation_classes.txt: -------------------------------------------------------------------------------- 1 | 0: Background 2 | 1: Aeroplane 3 | 2: bicycle 4 | 3: bird 5 | 4: boat 6 | 5: bottle 7 | 6: bus 8 | 7: car 9 | 8: cat 10 | 9: chair 11 | 10: cow 12 | 11: dining table 13 | 12: dog 14 | 13: horse 15 | 14: motorbike 16 | 15: person 17 | 16: potted plant 18 | 17: sheep 19 | 18: sofa 20 | 19: train 21 | 20: tv/monitor -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | ignore = E303 3 | max-line-length = 200 -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # empty -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/edge_connect.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/edge_connect.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/edge_connect.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/edge_connect.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/networks.cpython-37.pyc 
-------------------------------------------------------------------------------- /src/__pycache__/segmentor_fcn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/segmentor_fcn.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | class Config(dict): 5 | def __init__(self, config_path): 6 | with open(config_path, 'r') as f: 7 | self._yaml = f.read() 8 | self._dict = yaml.safe_load(self._yaml) 9 | self._dict['PATH'] = os.path.dirname(config_path) 10 | 11 | def __getattr__(self, name): 12 | if self._dict.get(name) is not None: 13 | return self._dict[name] 14 | 15 | if DEFAULT_CONFIG.get(name) is not None: 16 | return DEFAULT_CONFIG[name] 17 | 18 | return None 19 | 20 | def print(self): 21 | print('Model configurations:') 22 | print('---------------------------------') 23 | print(self._yaml) 24 | print('') 25 | print('---------------------------------') 26 | print('') 27 | 28 | 29 | DEFAULT_CONFIG = { 30 | 'MODE': 1, # 1: train, 2: test, 3: eval 31 | 'MODEL': 1, # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model 32 | 'MASK': 3, # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half) 33 | 'EDGE': 1, # 1: canny, 2: external 34 | 'NMS': 1, # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny 35 | 'SEED': 10, # random seed 36 | 'GPU': [0], # list of gpu ids 37 | 'DEBUG': 0, # turns on debugging mode 38 | 'VERBOSE': 0, # turns on verbose mode in the output console 39 | 40 | 'LR': 0.0001, # learning rate 41 | 'D2G_LR': 0.1, # discriminator/generator learning rate ratio 42 | 'BETA1': 0.0, # adam optimizer beta1 43 | 'BETA2': 0.9, # adam optimizer beta2 44 | 'BATCH_SIZE': 8, # input batch size for training 45 | 'INPUT_SIZE': 256, # input image size for training 0 for original size 46 | 'SIGMA': 2, # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge) 47 | 'MAX_ITERS': 2e6, # maximum number of iterations to train the model 48 | 49 | 'EDGE_THRESHOLD': 0.5, # edge detection threshold 50 | 'L1_LOSS_WEIGHT': 1, # l1 loss weight 51 | 'FM_LOSS_WEIGHT': 10, # feature-matching loss weight 52 | 'STYLE_LOSS_WEIGHT': 1, # style loss weight 53 | 'CONTENT_LOSS_WEIGHT': 1, # perceptual loss weight 54 | 'INPAINT_ADV_LOSS_WEIGHT': 0.01,# adversarial loss weight 55 | 56 | 'GAN_LOSS': 'nsgan', # nsgan | lsgan | hinge 57 | 'GAN_POOL_SIZE': 0, # fake images pool size 58 | 59 | 
'SAVE_INTERVAL': 1000, # how many iterations to wait before saving model (0: never) 60 | 'SAMPLE_INTERVAL': 1000, # how many iterations to wait before sampling (0: never) 61 | 'SAMPLE_SIZE': 12, # number of images to sample 62 | 'EVAL_INTERVAL': 0, # how many iterations to wait before model evaluation (0: never) 63 | 'LOG_INTERVAL': 10, # how many iterations to wait before logging training status (0: never) 64 | 'SEG_NETWORK': 0, # 0:DeepLabV3 resnet 101 segmentation , 1: FCN resnet 101 segmentation 65 | } 66 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import random 5 | import numpy as np 6 | import torchvision.transforms.functional as F 7 | from torch.utils.data import DataLoader 8 | from PIL import Image 9 | from imageio import imread 10 | from skimage.feature import canny 11 | from skimage.color import rgb2gray, gray2rgb 12 | from .utils import create_mask 13 | import cv2 14 | from .segmentor_fcn import segmentor,fill_gaps 15 | 16 | 17 | class Dataset(torch.utils.data.Dataset): 18 | def __init__(self, config, flist, edge_flist, augment=True, training=True): 19 | super(Dataset, self).__init__() 20 | self.augment = augment 21 | self.training = training 22 | self.data = self.load_flist(flist) 23 | self.edge_data = self.load_flist(edge_flist) 24 | 25 | self.input_size = config.INPUT_SIZE 26 | self.sigma = config.SIGMA 27 | self.edge = config.EDGE 28 | self.mask = config.MASK 29 | self.nms = config.NMS 30 | self.device = config.SEG_DEVICE 31 | self.objects = config.OBJECTS 32 | self.segment_net = config.SEG_NETWORK 33 | # in test mode, there's a one-to-one relationship between mask and image 34 | # masks are loaded non random 35 | 36 | 37 | def __len__(self): 38 | return len(self.data) 39 | 40 | def __getitem__(self, index): 41 | try: 42 | item = self.load_item(index) 43 | except: 44 | print('loading error: ' + self.data[index]) 45 | item = self.load_item(0) 46 | 47 | return item 48 | 49 | def load_name(self, index): 50 | name = self.data[index] 51 | return os.path.basename(name) 52 | 53 | def load_size(self, index): 54 | img = Image.open(self.data[index]) 55 | width,height=img.size 56 | return width,height 57 | 58 | 59 | def load_item(self, index): 60 | 61 | size = self.input_size 62 | 63 | # load image 64 | img = Image.open(self.data[index]) 65 | 66 | 67 | 68 | 69 | 70 | # gray to rgb 71 | if img.mode !='RGB': 72 | img = gray2rgb(np.array(img)) 73 | img=Image.fromarray(img) 74 | 75 | # resize/crop if needed 76 | img,mask=segmentor(self.segment_net,img,self.device,self.objects) 77 | img = Image.fromarray(img) 78 | img = np.array(img.resize((size, size), Image.ANTIALIAS)) 79 | 80 | 81 | # create grayscale image 82 | img_gray = rgb2gray(np.array(img)) 83 | 84 | 85 | 86 | 87 | # load mask 88 | mask = Image.fromarray(mask) 89 | mask = np.array(mask.resize((size, size), Image.ANTIALIAS)) 90 | idx=(mask>0) 91 | mask[idx]=255 92 | #kernel = np.ones((5, 5), np.uint8) 93 | #opening = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) 94 | #closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel) 95 | mask=np.apply_along_axis(fill_gaps, 1, mask) #horizontal padding 96 | mask=np.apply_along_axis(fill_gaps, 0, mask) #vertical padding 97 | 98 | 99 | 100 | 101 | # load edge 102 | edge = self.load_edge(img_gray, index, mask) 103 | 104 | # augment data 105 | if self.augment and np.random.binomial(1, 0.5) > 0: 106 | img = 
img[:, ::-1, ...] 107 | img_gray = img_gray[:, ::-1, ...] 108 | edge = edge[:, ::-1, ...] 109 | mask = mask[:, ::-1, ...] 110 | 111 | return self.to_tensor(img), self.to_tensor(img_gray), self.to_tensor(edge), self.to_tensor(mask) 112 | 113 | def load_edge(self, img, index, mask): 114 | sigma = self.sigma 115 | 116 | # in test mode images are masked (with masked regions), 117 | # using 'mask' parameter prevents canny to detect edges for the masked regions 118 | mask = None if self.training else (1 - mask / 255).astype(np.bool) 119 | 120 | # canny 121 | if self.edge == 1: 122 | # no edge 123 | if sigma == -1: 124 | return np.zeros(img.shape).astype(np.float) 125 | 126 | # random sigma 127 | if sigma == 0: 128 | sigma = random.randint(1, 4) 129 | 130 | return canny(img, sigma=sigma, mask=mask).astype(np.float) 131 | 132 | # external 133 | else: 134 | imgh, imgw = img.shape[0:2] 135 | edge = imread(self.edge_data[index]) 136 | edge = self.resized(edge, imgh, imgw) 137 | 138 | # non-max suppression 139 | if self.nms == 1: 140 | edge = edge * canny(img, sigma=sigma, mask=mask) 141 | 142 | return edge 143 | 144 | 145 | def to_tensor(self, img): 146 | img = Image.fromarray(img) 147 | img_t = F.to_tensor(img).float() 148 | return img_t 149 | 150 | 151 | def load_flist(self, flist): 152 | if isinstance(flist, list): 153 | return flist 154 | 155 | # flist: image file path, image directory path, text file flist path 156 | if isinstance(flist, str): 157 | if os.path.isdir(flist): 158 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) 159 | flist.sort() 160 | return flist 161 | 162 | if os.path.isfile(flist): 163 | try: 164 | return np.genfromtxt(flist, dtype=np.str, encoding='utf-8') 165 | except: 166 | return [flist] 167 | 168 | return [] 169 | 170 | def create_iterator(self, batch_size): 171 | while True: 172 | sample_loader = DataLoader( 173 | dataset=self, 174 | batch_size=batch_size, 175 | drop_last=True 176 | ) 177 | 178 | for item in sample_loader: 179 | yield item 180 | -------------------------------------------------------------------------------- /src/edge_connect.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code of EdgeConnect is from this repo 3 | https://github.com/knazeri/edge-connect 4 | ''' 5 | 6 | 7 | 8 | import os 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from .dataset import Dataset 13 | from .models import EdgeModel, InpaintingModel 14 | from .utils import Progbar, create_dir, stitch_images, imsave 15 | from PIL import Image 16 | from torchvision import transforms 17 | 18 | class EdgeConnect(): 19 | def __init__(self, config): 20 | self.config = config 21 | 22 | if config.MODEL == 1: 23 | model_name = 'edge' 24 | elif config.MODEL == 2: 25 | model_name = 'inpaint' 26 | elif config.MODEL == 3: 27 | model_name = 'edge_inpaint' 28 | elif config.MODEL == 4: 29 | model_name = 'joint' 30 | 31 | self.debug = False 32 | self.model_name = model_name 33 | self.edge_model = EdgeModel(config).to(config.DEVICE) 34 | self.inpaint_model = InpaintingModel(config).to(config.DEVICE) 35 | 36 | 37 | # test mode 38 | self.test_dataset = Dataset(config, config.TEST_FLIST, config.TEST_EDGE_FLIST, augment=False, training=False) 39 | 40 | self.samples_path = os.path.join(config.PATH, 'samples') 41 | self.results_path = os.path.join(config.PATH, 'results') 42 | 43 | if config.RESULTS is not None: 44 | self.results_path = os.path.join(config.RESULTS) 45 | 46 | if config.DEBUG is 
not None and config.DEBUG != 0: 47 | self.debug = True 48 | 49 | self.log_file = os.path.join(config.PATH, 'log_' + model_name + '.dat') 50 | 51 | def load(self): 52 | if self.config.MODEL == 1: 53 | self.edge_model.load() 54 | 55 | elif self.config.MODEL == 2: 56 | self.inpaint_model.load() 57 | 58 | else: 59 | self.edge_model.load() 60 | self.inpaint_model.load() 61 | 62 | def save(self): 63 | if self.config.MODEL == 1: 64 | self.edge_model.save() 65 | 66 | elif self.config.MODEL == 2 or self.config.MODEL == 3: 67 | self.inpaint_model.save() 68 | 69 | else: 70 | self.edge_model.save() 71 | self.inpaint_model.save() 72 | 73 | 74 | def test(self): 75 | self.edge_model.eval() 76 | self.inpaint_model.eval() 77 | 78 | model = self.config.MODEL 79 | create_dir(self.results_path) 80 | 81 | test_loader = DataLoader( 82 | dataset=self.test_dataset, 83 | batch_size=1, 84 | ) 85 | 86 | index = 0 87 | for items in test_loader: 88 | name = self.test_dataset.load_name(index) 89 | 90 | images, images_gray, edges, masks = self.cuda(*items) 91 | index += 1 92 | 93 | # edge model 94 | if model == 1: 95 | outputs = self.edge_model(images_gray, edges, masks) 96 | outputs_merged = (outputs * masks) + (edges * (1 - masks)) 97 | 98 | # inpaint model 99 | elif model == 2: 100 | outputs = self.inpaint_model(images, edges, masks) 101 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 102 | 103 | # inpaint with edge model / joint model 104 | else: 105 | edges = self.edge_model(images_gray, edges, masks).detach() 106 | outputs = self.inpaint_model(images, edges, masks) 107 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 108 | 109 | output = self.postprocess(outputs_merged)[0] 110 | path = os.path.join(self.results_path, name) 111 | print(index, name) 112 | 113 | imsave(output, path) 114 | 115 | if self.debug: 116 | edges = self.postprocess(1 - edges)[0] 117 | masked = self.postprocess(images * (1 - masks) + masks)[0] 118 | fname, fext = name.split('.') 119 | 120 | imsave(edges, os.path.join(self.results_path, fname + '_edge.' + fext)) 121 | imsave(masked, os.path.join(self.results_path, fname + '_masked.' 
+ fext)) 122 | 123 | print('\nEnd test....') 124 | return output 125 | 126 | 127 | def log(self, logs): 128 | with open(self.log_file, 'a') as f: 129 | f.write('%s\n' % ' '.join([str(item[1]) for item in logs])) 130 | 131 | def cuda(self, *args): 132 | return (item.to(self.config.DEVICE) for item in args) 133 | 134 | def postprocess(self, img): 135 | # [0, 1] => [0, 255] 136 | img = img * 255.0 137 | img = img.permute(0, 2, 3, 1) 138 | return img.int() 139 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | class AdversarialLoss(nn.Module): 7 | r""" 8 | Adversarial loss 9 | https://arxiv.org/abs/1711.10337 10 | """ 11 | 12 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0): 13 | r""" 14 | type = nsgan | lsgan | hinge 15 | """ 16 | super(AdversarialLoss, self).__init__() 17 | 18 | self.type = type 19 | self.register_buffer('real_label', torch.tensor(target_real_label)) 20 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 21 | 22 | if type == 'nsgan': 23 | self.criterion = nn.BCELoss() 24 | 25 | elif type == 'lsgan': 26 | self.criterion = nn.MSELoss() 27 | 28 | elif type == 'hinge': 29 | self.criterion = nn.ReLU() 30 | 31 | def __call__(self, outputs, is_real, is_disc=None): 32 | if self.type == 'hinge': 33 | if is_disc: 34 | if is_real: 35 | outputs = -outputs 36 | return self.criterion(1 + outputs).mean() 37 | else: 38 | return (-outputs).mean() 39 | 40 | else: 41 | labels = (self.real_label if is_real else self.fake_label).expand_as(outputs) 42 | loss = self.criterion(outputs, labels) 43 | return loss 44 | 45 | 46 | class StyleLoss(nn.Module): 47 | r""" 48 | Perceptual loss, VGG-based 49 | https://arxiv.org/abs/1603.08155 50 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 51 | """ 52 | 53 | def __init__(self): 54 | super(StyleLoss, self).__init__() 55 | self.add_module('vgg', VGG19()) 56 | self.criterion = torch.nn.L1Loss() 57 | 58 | def compute_gram(self, x): 59 | b, ch, h, w = x.size() 60 | f = x.view(b, ch, w * h) 61 | f_T = f.transpose(1, 2) 62 | G = f.bmm(f_T) / (h * w * ch) 63 | 64 | return G 65 | 66 | def __call__(self, x, y): 67 | # Compute features 68 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 69 | 70 | # Compute loss 71 | style_loss = 0.0 72 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2'])) 73 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4'])) 74 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4'])) 75 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2'])) 76 | 77 | return style_loss 78 | 79 | 80 | 81 | class PerceptualLoss(nn.Module): 82 | r""" 83 | Perceptual loss, VGG-based 84 | https://arxiv.org/abs/1603.08155 85 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 86 | """ 87 | 88 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): 89 | super(PerceptualLoss, self).__init__() 90 | self.add_module('vgg', VGG19()) 91 | self.criterion = torch.nn.L1Loss() 92 | self.weights = weights 93 | 94 | def __call__(self, x, y): 95 | # Compute features 96 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 97 | 98 | content_loss = 0.0 99 | content_loss += self.weights[0] * 
self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) 100 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) 101 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) 102 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) 103 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) 104 | 105 | 106 | return content_loss 107 | 108 | 109 | 110 | class VGG19(torch.nn.Module): 111 | def __init__(self): 112 | super(VGG19, self).__init__() 113 | features = models.vgg19(pretrained=True).features 114 | self.relu1_1 = torch.nn.Sequential() 115 | self.relu1_2 = torch.nn.Sequential() 116 | 117 | self.relu2_1 = torch.nn.Sequential() 118 | self.relu2_2 = torch.nn.Sequential() 119 | 120 | self.relu3_1 = torch.nn.Sequential() 121 | self.relu3_2 = torch.nn.Sequential() 122 | self.relu3_3 = torch.nn.Sequential() 123 | self.relu3_4 = torch.nn.Sequential() 124 | 125 | self.relu4_1 = torch.nn.Sequential() 126 | self.relu4_2 = torch.nn.Sequential() 127 | self.relu4_3 = torch.nn.Sequential() 128 | self.relu4_4 = torch.nn.Sequential() 129 | 130 | self.relu5_1 = torch.nn.Sequential() 131 | self.relu5_2 = torch.nn.Sequential() 132 | self.relu5_3 = torch.nn.Sequential() 133 | self.relu5_4 = torch.nn.Sequential() 134 | 135 | for x in range(2): 136 | self.relu1_1.add_module(str(x), features[x]) 137 | 138 | for x in range(2, 4): 139 | self.relu1_2.add_module(str(x), features[x]) 140 | 141 | for x in range(4, 7): 142 | self.relu2_1.add_module(str(x), features[x]) 143 | 144 | for x in range(7, 9): 145 | self.relu2_2.add_module(str(x), features[x]) 146 | 147 | for x in range(9, 12): 148 | self.relu3_1.add_module(str(x), features[x]) 149 | 150 | for x in range(12, 14): 151 | self.relu3_2.add_module(str(x), features[x]) 152 | 153 | for x in range(14, 16): 154 | self.relu3_3.add_module(str(x), features[x]) 155 | 156 | for x in range(16, 18): 157 | self.relu3_4.add_module(str(x), features[x]) 158 | 159 | for x in range(18, 21): 160 | self.relu4_1.add_module(str(x), features[x]) 161 | 162 | for x in range(21, 23): 163 | self.relu4_2.add_module(str(x), features[x]) 164 | 165 | for x in range(23, 25): 166 | self.relu4_3.add_module(str(x), features[x]) 167 | 168 | for x in range(25, 27): 169 | self.relu4_4.add_module(str(x), features[x]) 170 | 171 | for x in range(27, 30): 172 | self.relu5_1.add_module(str(x), features[x]) 173 | 174 | for x in range(30, 32): 175 | self.relu5_2.add_module(str(x), features[x]) 176 | 177 | for x in range(32, 34): 178 | self.relu5_3.add_module(str(x), features[x]) 179 | 180 | for x in range(34, 36): 181 | self.relu5_4.add_module(str(x), features[x]) 182 | 183 | # don't need the gradients, just want the features 184 | for param in self.parameters(): 185 | param.requires_grad = False 186 | 187 | def forward(self, x): 188 | relu1_1 = self.relu1_1(x) 189 | relu1_2 = self.relu1_2(relu1_1) 190 | 191 | relu2_1 = self.relu2_1(relu1_2) 192 | relu2_2 = self.relu2_2(relu2_1) 193 | 194 | relu3_1 = self.relu3_1(relu2_2) 195 | relu3_2 = self.relu3_2(relu3_1) 196 | relu3_3 = self.relu3_3(relu3_2) 197 | relu3_4 = self.relu3_4(relu3_3) 198 | 199 | relu4_1 = self.relu4_1(relu3_4) 200 | relu4_2 = self.relu4_2(relu4_1) 201 | relu4_3 = self.relu4_3(relu4_2) 202 | relu4_4 = self.relu4_4(relu4_3) 203 | 204 | relu5_1 = self.relu5_1(relu4_4) 205 | relu5_2 = self.relu5_2(relu5_1) 206 | relu5_3 = self.relu5_3(relu5_2) 207 | relu5_4 = self.relu5_4(relu5_3) 208 | 209 | out = { 210 | 'relu1_1': 
relu1_1, 211 | 'relu1_2': relu1_2, 212 | 213 | 'relu2_1': relu2_1, 214 | 'relu2_2': relu2_2, 215 | 216 | 'relu3_1': relu3_1, 217 | 'relu3_2': relu3_2, 218 | 'relu3_3': relu3_3, 219 | 'relu3_4': relu3_4, 220 | 221 | 'relu4_1': relu4_1, 222 | 'relu4_2': relu4_2, 223 | 'relu4_3': relu4_3, 224 | 'relu4_4': relu4_4, 225 | 226 | 'relu5_1': relu5_1, 227 | 'relu5_2': relu5_2, 228 | 'relu5_3': relu5_3, 229 | 'relu5_4': relu5_4, 230 | } 231 | return out 232 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from .networks import InpaintGenerator, EdgeGenerator, Discriminator 6 | from .loss import AdversarialLoss, PerceptualLoss, StyleLoss 7 | 8 | 9 | class BaseModel(nn.Module): 10 | def __init__(self, name, config): 11 | super(BaseModel, self).__init__() 12 | 13 | self.name = name 14 | self.config = config 15 | self.iteration = 0 16 | 17 | self.gen_weights_path = os.path.join(config.PATH, name + '_gen.pth') 18 | self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth') 19 | 20 | def load(self): 21 | if os.path.exists(self.gen_weights_path): 22 | print('Loading %s generator...' % self.name) 23 | 24 | if torch.cuda.is_available(): 25 | data = torch.load(self.gen_weights_path) 26 | else: 27 | data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage) 28 | 29 | self.generator.load_state_dict(data['generator']) 30 | self.iteration = data['iteration'] 31 | 32 | # load discriminator only when training 33 | if self.config.MODE == 1 and os.path.exists(self.dis_weights_path): 34 | print('Loading %s discriminator...' 
% self.name) 35 | 36 | if torch.cuda.is_available(): 37 | data = torch.load(self.dis_weights_path) 38 | else: 39 | data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage) 40 | 41 | self.discriminator.load_state_dict(data['discriminator']) 42 | 43 | def save(self): 44 | print('\nsaving %s...\n' % self.name) 45 | torch.save({ 46 | 'iteration': self.iteration, 47 | 'generator': self.generator.state_dict() 48 | }, self.gen_weights_path) 49 | 50 | torch.save({ 51 | 'discriminator': self.discriminator.state_dict() 52 | }, self.dis_weights_path) 53 | 54 | 55 | class EdgeModel(BaseModel): 56 | def __init__(self, config): 57 | super(EdgeModel, self).__init__('EdgeModel', config) 58 | 59 | # generator input: [grayscale(1) + edge(1) + mask(1)] 60 | # discriminator input: (grayscale(1) + edge(1)) 61 | generator = EdgeGenerator(use_spectral_norm=True) 62 | discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge') 63 | if len(config.GPU) > 1: 64 | generator = nn.DataParallel(generator, config.GPU) 65 | discriminator = nn.DataParallel(discriminator, config.GPU) 66 | l1_loss = nn.L1Loss() 67 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS) 68 | 69 | self.add_module('generator', generator) 70 | self.add_module('discriminator', discriminator) 71 | 72 | self.add_module('l1_loss', l1_loss) 73 | self.add_module('adversarial_loss', adversarial_loss) 74 | 75 | self.gen_optimizer = optim.Adam( 76 | params=generator.parameters(), 77 | lr=float(config.LR), 78 | betas=(config.BETA1, config.BETA2) 79 | ) 80 | 81 | self.dis_optimizer = optim.Adam( 82 | params=discriminator.parameters(), 83 | lr=float(config.LR) * float(config.D2G_LR), 84 | betas=(config.BETA1, config.BETA2) 85 | ) 86 | 87 | def process(self, images, edges, masks): 88 | self.iteration += 1 89 | 90 | 91 | # zero optimizers 92 | self.gen_optimizer.zero_grad() 93 | self.dis_optimizer.zero_grad() 94 | 95 | 96 | # process outputs 97 | outputs = self(images, edges, masks) 98 | gen_loss = 0 99 | dis_loss = 0 100 | 101 | 102 | # discriminator loss 103 | dis_input_real = torch.cat((images, edges), dim=1) 104 | dis_input_fake = torch.cat((images, outputs.detach()), dim=1) 105 | dis_real, dis_real_feat = self.discriminator(dis_input_real) # in: (grayscale(1) + edge(1)) 106 | dis_fake, dis_fake_feat = self.discriminator(dis_input_fake) # in: (grayscale(1) + edge(1)) 107 | dis_real_loss = self.adversarial_loss(dis_real, True, True) 108 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True) 109 | dis_loss += (dis_real_loss + dis_fake_loss) / 2 110 | 111 | 112 | # generator adversarial loss 113 | gen_input_fake = torch.cat((images, outputs), dim=1) 114 | gen_fake, gen_fake_feat = self.discriminator(gen_input_fake) # in: (grayscale(1) + edge(1)) 115 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) 116 | gen_loss += gen_gan_loss 117 | 118 | 119 | # generator feature matching loss 120 | gen_fm_loss = 0 121 | for i in range(len(dis_real_feat)): 122 | gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach()) 123 | gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT 124 | gen_loss += gen_fm_loss 125 | 126 | 127 | # create logs 128 | logs = [ 129 | ("l_d1", dis_loss.item()), 130 | ("l_g1", gen_gan_loss.item()), 131 | ("l_fm", gen_fm_loss.item()), 132 | ] 133 | 134 | return outputs, gen_loss, dis_loss, logs 135 | 136 | def forward(self, images, edges, masks): 137 | edges_masked = (edges * (1 - masks)) 138 | images_masked = (images * (1 - masks)) + masks 139 | inputs = 
torch.cat((images_masked, edges_masked, masks), dim=1) 140 | outputs = self.generator(inputs) # in: [grayscale(1) + edge(1) + mask(1)] 141 | return outputs 142 | 143 | def backward(self, gen_loss=None, dis_loss=None): 144 | if dis_loss is not None: 145 | dis_loss.backward() 146 | self.dis_optimizer.step() 147 | 148 | if gen_loss is not None: 149 | gen_loss.backward() 150 | self.gen_optimizer.step() 151 | 152 | 153 | class InpaintingModel(BaseModel): 154 | def __init__(self, config): 155 | super(InpaintingModel, self).__init__('InpaintingModel', config) 156 | 157 | # generator input: [rgb(3) + edge(1)] 158 | # discriminator input: [rgb(3)] 159 | generator = InpaintGenerator() 160 | discriminator = Discriminator(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge') 161 | if len(config.GPU) > 1: 162 | generator = nn.DataParallel(generator, config.GPU) 163 | discriminator = nn.DataParallel(discriminator , config.GPU) 164 | 165 | l1_loss = nn.L1Loss() 166 | perceptual_loss = PerceptualLoss() 167 | style_loss = StyleLoss() 168 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS) 169 | 170 | self.add_module('generator', generator) 171 | self.add_module('discriminator', discriminator) 172 | 173 | self.add_module('l1_loss', l1_loss) 174 | self.add_module('perceptual_loss', perceptual_loss) 175 | self.add_module('style_loss', style_loss) 176 | self.add_module('adversarial_loss', adversarial_loss) 177 | 178 | self.gen_optimizer = optim.Adam( 179 | params=generator.parameters(), 180 | lr=float(config.LR), 181 | betas=(config.BETA1, config.BETA2) 182 | ) 183 | 184 | self.dis_optimizer = optim.Adam( 185 | params=discriminator.parameters(), 186 | lr=float(config.LR) * float(config.D2G_LR), 187 | betas=(config.BETA1, config.BETA2) 188 | ) 189 | 190 | def process(self, images, edges, masks): 191 | self.iteration += 1 192 | 193 | # zero optimizers 194 | self.gen_optimizer.zero_grad() 195 | self.dis_optimizer.zero_grad() 196 | 197 | 198 | # process outputs 199 | outputs = self(images, edges, masks) 200 | gen_loss = 0 201 | dis_loss = 0 202 | 203 | 204 | # discriminator loss 205 | dis_input_real = images 206 | dis_input_fake = outputs.detach() 207 | dis_real, _ = self.discriminator(dis_input_real) # in: [rgb(3)] 208 | dis_fake, _ = self.discriminator(dis_input_fake) # in: [rgb(3)] 209 | dis_real_loss = self.adversarial_loss(dis_real, True, True) 210 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True) 211 | dis_loss += (dis_real_loss + dis_fake_loss) / 2 212 | 213 | 214 | # generator adversarial loss 215 | gen_input_fake = outputs 216 | gen_fake, _ = self.discriminator(gen_input_fake) # in: [rgb(3)] 217 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT 218 | gen_loss += gen_gan_loss 219 | 220 | 221 | # generator l1 loss 222 | gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks) 223 | gen_loss += gen_l1_loss 224 | 225 | 226 | # generator perceptual loss 227 | gen_content_loss = self.perceptual_loss(outputs, images) 228 | gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT 229 | gen_loss += gen_content_loss 230 | 231 | 232 | # generator style loss 233 | gen_style_loss = self.style_loss(outputs * masks, images * masks) 234 | gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT 235 | gen_loss += gen_style_loss 236 | 237 | 238 | # create logs 239 | logs = [ 240 | ("l_d2", dis_loss.item()), 241 | ("l_g2", gen_gan_loss.item()), 242 | ("l_l1", gen_l1_loss.item()), 243 | ("l_per", 
gen_content_loss.item()), 244 | ("l_sty", gen_style_loss.item()), 245 | ] 246 | 247 | return outputs, gen_loss, dis_loss, logs 248 | 249 | def forward(self, images, edges, masks): 250 | images_masked = (images * (1 - masks).float()) + masks 251 | inputs = torch.cat((images_masked, edges), dim=1) 252 | outputs = self.generator(inputs) # in: [rgb(3) + edge(1)] 253 | return outputs 254 | 255 | def backward(self, gen_loss=None, dis_loss=None): 256 | dis_loss.backward() 257 | self.dis_optimizer.step() 258 | 259 | gen_loss.backward() 260 | self.gen_optimizer.step() 261 | -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseNetwork(nn.Module): 6 | def __init__(self): 7 | super(BaseNetwork, self).__init__() 8 | 9 | def init_weights(self, init_type='normal', gain=0.02): 10 | ''' 11 | initialize network's weights 12 | init_type: normal | xavier | kaiming | orthogonal 13 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 14 | ''' 15 | 16 | def init_func(m): 17 | classname = m.__class__.__name__ 18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 19 | if init_type == 'normal': 20 | nn.init.normal_(m.weight.data, 0.0, gain) 21 | elif init_type == 'xavier': 22 | nn.init.xavier_normal_(m.weight.data, gain=gain) 23 | elif init_type == 'kaiming': 24 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 25 | elif init_type == 'orthogonal': 26 | nn.init.orthogonal_(m.weight.data, gain=gain) 27 | 28 | if hasattr(m, 'bias') and m.bias is not None: 29 | nn.init.constant_(m.bias.data, 0.0) 30 | 31 | elif classname.find('BatchNorm2d') != -1: 32 | nn.init.normal_(m.weight.data, 1.0, gain) 33 | nn.init.constant_(m.bias.data, 0.0) 34 | 35 | self.apply(init_func) 36 | 37 | 38 | class InpaintGenerator(BaseNetwork): 39 | def __init__(self, residual_blocks=8, init_weights=True): 40 | super(InpaintGenerator, self).__init__() 41 | 42 | self.encoder = nn.Sequential( 43 | nn.ReflectionPad2d(3), 44 | nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7, padding=0), 45 | nn.InstanceNorm2d(64, track_running_stats=False), 46 | nn.ReLU(True), 47 | 48 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), 49 | nn.InstanceNorm2d(128, track_running_stats=False), 50 | nn.ReLU(True), 51 | 52 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), 53 | nn.InstanceNorm2d(256, track_running_stats=False), 54 | nn.ReLU(True) 55 | ) 56 | 57 | blocks = [] 58 | for _ in range(residual_blocks): 59 | block = ResnetBlock(256, 2) 60 | blocks.append(block) 61 | 62 | self.middle = nn.Sequential(*blocks) 63 | 64 | self.decoder = nn.Sequential( 65 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), 66 | nn.InstanceNorm2d(128, track_running_stats=False), 67 | nn.ReLU(True), 68 | 69 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), 70 | nn.InstanceNorm2d(64, track_running_stats=False), 71 | nn.ReLU(True), 72 | 73 | nn.ReflectionPad2d(3), 74 | nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, padding=0), 75 | ) 76 | 77 | if init_weights: 78 | self.init_weights() 79 | 80 | def forward(self, x): 81 | x = self.encoder(x) 82 | x = self.middle(x) 83 | x = self.decoder(x) 84 | x = (torch.tanh(x) 
+ 1) / 2 85 | 86 | return x 87 | 88 | 89 | class EdgeGenerator(BaseNetwork): 90 | def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True): 91 | super(EdgeGenerator, self).__init__() 92 | 93 | self.encoder = nn.Sequential( 94 | nn.ReflectionPad2d(3), 95 | spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm), 96 | nn.InstanceNorm2d(64, track_running_stats=False), 97 | nn.ReLU(True), 98 | 99 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), 100 | nn.InstanceNorm2d(128, track_running_stats=False), 101 | nn.ReLU(True), 102 | 103 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm), 104 | nn.InstanceNorm2d(256, track_running_stats=False), 105 | nn.ReLU(True) 106 | ) 107 | 108 | blocks = [] 109 | for _ in range(residual_blocks): 110 | block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm) 111 | blocks.append(block) 112 | 113 | self.middle = nn.Sequential(*blocks) 114 | 115 | self.decoder = nn.Sequential( 116 | spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), 117 | nn.InstanceNorm2d(128, track_running_stats=False), 118 | nn.ReLU(True), 119 | 120 | spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm), 121 | nn.InstanceNorm2d(64, track_running_stats=False), 122 | nn.ReLU(True), 123 | 124 | nn.ReflectionPad2d(3), 125 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0), 126 | ) 127 | 128 | if init_weights: 129 | self.init_weights() 130 | 131 | def forward(self, x): 132 | x = self.encoder(x) 133 | x = self.middle(x) 134 | x = self.decoder(x) 135 | x = torch.sigmoid(x) 136 | return x 137 | 138 | 139 | class Discriminator(BaseNetwork): 140 | def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True): 141 | super(Discriminator, self).__init__() 142 | self.use_sigmoid = use_sigmoid 143 | 144 | self.conv1 = self.features = nn.Sequential( 145 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 146 | nn.LeakyReLU(0.2, inplace=True), 147 | ) 148 | 149 | self.conv2 = nn.Sequential( 150 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 151 | nn.LeakyReLU(0.2, inplace=True), 152 | ) 153 | 154 | self.conv3 = nn.Sequential( 155 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 156 | nn.LeakyReLU(0.2, inplace=True), 157 | ) 158 | 159 | self.conv4 = nn.Sequential( 160 | spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), 161 | nn.LeakyReLU(0.2, inplace=True), 162 | ) 163 | 164 | self.conv5 = nn.Sequential( 165 | spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), 166 | ) 167 | 168 | if init_weights: 169 | self.init_weights() 170 | 171 | def forward(self, x): 172 | conv1 = self.conv1(x) 173 | conv2 = self.conv2(conv1) 174 | conv3 = self.conv3(conv2) 175 | conv4 = self.conv4(conv3) 176 | conv5 = self.conv5(conv4) 177 | 178 | outputs = conv5 179 | if 
self.use_sigmoid: 180 | outputs = torch.sigmoid(conv5) 181 | 182 | return outputs, [conv1, conv2, conv3, conv4, conv5] 183 | 184 | 185 | class ResnetBlock(nn.Module): 186 | def __init__(self, dim, dilation=1, use_spectral_norm=False): 187 | super(ResnetBlock, self).__init__() 188 | self.conv_block = nn.Sequential( 189 | nn.ReflectionPad2d(dilation), 190 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm), 191 | nn.InstanceNorm2d(dim, track_running_stats=False), 192 | nn.ReLU(True), 193 | 194 | nn.ReflectionPad2d(1), 195 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm), 196 | nn.InstanceNorm2d(dim, track_running_stats=False), 197 | ) 198 | 199 | def forward(self, x): 200 | out = x + self.conv_block(x) 201 | 202 | # Remove ReLU at the end of the residual block 203 | # http://torch.ch/blog/2016/02/04/resnets.html 204 | 205 | return out 206 | 207 | 208 | def spectral_norm(module, mode=True): 209 | if mode: 210 | return nn.utils.spectral_norm(module) 211 | 212 | return module 213 | -------------------------------------------------------------------------------- /src/segmentor_fcn.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | from PIL import Image 3 | import torchvision.transforms as T 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import numpy as np 7 | from imageio import imread 8 | from skimage.color import rgb2gray, gray2rgb 9 | import cv2 10 | 11 | fcn = models.segmentation.fcn_resnet101(pretrained=True).eval() 12 | dlab = models.segmentation.deeplabv3_resnet101(pretrained=1).eval() 13 | 14 | def decode_segmap(image,objects,nc=21): 15 | 16 | r = np.zeros_like(image).astype(np.uint8) 17 | for l in objects: 18 | idx = image == l 19 | r[idx] = 255#fill r with 255 wherever class is 1 and so on 20 | return np.array(r) 21 | 22 | 23 | def fill_gaps(values): 24 | searchval=[255,0,255] 25 | searchval2=[255,0,0,255] 26 | idx=(np.array(np.where((values[:-2]==searchval[0]) & (values[1:-1]==searchval[1]) & (values[2:]==searchval[2])))+1) 27 | idx2=(np.array(np.where((values[:-3]==searchval2[0]) & (values[1:-2]==searchval2[1]) & (values[2:-1]==searchval2[2]) & (values[3:]==searchval2[3])))+1) 28 | idx3=(idx2+1) 29 | new=idx.tolist()+idx2.tolist()+idx3.tolist() 30 | newlist = [item for items in new for item in items] 31 | values[newlist]=255 32 | return values 33 | 34 | def fill_gaps2(values): 35 | searchval=[0,255] 36 | searchval2=[255,0] 37 | idx=(np.array(np.where((values[:-1]==searchval[0]) & (values[1:]==searchval[1])))) 38 | idx2=(np.array(np.where((values[:-1]==searchval[0]) & (values[1:]==searchval[1])))+1) 39 | 40 | new=idx.tolist()+idx2.tolist() 41 | newlist = [item for items in new for item in items] 42 | values[newlist]=255 43 | return values 44 | 45 | 46 | def remove_patch_og(real_img,mask): 47 | og_data = real_img.copy() 48 | idx = mask == 255 ### cutting out mask part from real image here 49 | og_data[idx] =255 50 | return og_data 51 | 52 | 53 | 54 | 55 | def segmentor(seg_net,img,dev,objects): 56 | #plt.imshow(img); plt.show() 57 | if seg_net==1: 58 | net=fcn 59 | else: 60 | net=dlab 61 | if dev == 'cuda': 62 | trf = T.Compose([T.Resize(400), 63 | #T.CenterCrop(224), 64 | T.ToTensor(), 65 | T.Normalize(mean = [0.485, 0.456, 0.406], 66 | std = [0.229, 0.224, 0.225])]) 67 | else: 68 | trf = 
69 |                          #T.CenterCrop(224),
70 |                          T.ToTensor(),
71 |                          T.Normalize(mean=[0.485, 0.456, 0.406],
72 |                                      std=[0.229, 0.224, 0.225])])
73 |     inp = trf(img).unsqueeze(0).to(dev)  # preprocess: resize, to-tensor, ImageNet normalisation, add batch dim
74 |     out = net.to(dev)(inp)['out']  # per-class logits, shape (1, num_classes, H, W)
75 |     om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()  # per-pixel class index
76 |     mask = decode_segmap(om, objects)  # 255 where one of the requested classes was predicted
77 |     height, width = mask.shape
78 |     img = np.array(img.resize((width, height), Image.ANTIALIAS))  # resize the image to the mask resolution
79 | 
80 | 
81 |     og_img = remove_patch_og(img, mask)  # blank the detected objects out of the image
82 |     #plt.imshow(mask); plt.show()
83 |     return og_img, mask
84 | 
85 | 
86 | 
87 | 
88 | 
89 | 
90 | 
91 | 
92 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | 
9 | 
10 | def create_dir(dir):
11 |     if not os.path.exists(dir):
12 |         os.makedirs(dir)
13 | 
14 | 
15 | def create_mask(width, height, mask_width, mask_height, x=None, y=None):
16 |     mask = np.zeros((height, width))
17 |     mask_x = x if x is not None else random.randint(0, width - mask_width)
18 |     mask_y = y if y is not None else random.randint(0, height - mask_height)
19 |     mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
20 |     return mask
21 | 
22 | 
23 | def stitch_images(inputs, *outputs, img_per_row=2):
24 |     gap = 5
25 |     columns = len(outputs) + 1
26 | 
27 |     width, height = inputs[0][:, :, 0].shape  # note: numpy shape order is (rows, cols), so these are swapped unless the images are square
28 |     img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row)))
29 |     images = [inputs, *outputs]
30 | 
31 |     for ix in range(len(inputs)):
32 |         xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap
33 |         yoffset = int(ix / img_per_row) * height
34 | 
35 |         for cat in range(len(images)):
36 |             im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze()
37 |             im = Image.fromarray(im)
38 |             img.paste(im, (xoffset + cat * width, yoffset))
39 | 
40 |     return img
41 | 
42 | 
43 | def imshow(img, title=''):
44 |     fig = plt.gcf()
45 |     fig.canvas.set_window_title(title)
46 |     plt.axis('off')
47 |     plt.imshow(img, interpolation='none')
48 |     plt.show()
49 | 
50 | 
51 | def imsave(img, path):
52 |     im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze())
53 | 
54 |     im.save(path)
55 | 
56 | 
57 | class Progbar(object):
58 |     """Displays a progress bar.
59 | 
60 |     Arguments:
61 |         target: Total number of steps expected, None if unknown.
62 |         width: Progress bar width on screen.
63 |         verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
64 |         stateful_metrics: Iterable of string names of metrics that
65 |             should *not* be averaged over time. Metrics in this list
66 |             will be displayed as-is. All others will be averaged
67 |             by the progbar before display.
68 |         interval: Minimum visual progress update interval (in seconds).
69 | """ 70 | 71 | def __init__(self, target, width=25, verbose=1, interval=0.05, 72 | stateful_metrics=None): 73 | self.target = target 74 | self.width = width 75 | self.verbose = verbose 76 | self.interval = interval 77 | if stateful_metrics: 78 | self.stateful_metrics = set(stateful_metrics) 79 | else: 80 | self.stateful_metrics = set() 81 | 82 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 83 | sys.stdout.isatty()) or 84 | 'ipykernel' in sys.modules or 85 | 'posix' in sys.modules) 86 | self._total_width = 0 87 | self._seen_so_far = 0 88 | # We use a dict + list to avoid garbage collection 89 | # issues found in OrderedDict 90 | self._values = {} 91 | self._values_order = [] 92 | self._start = time.time() 93 | self._last_update = 0 94 | 95 | def update(self, current, values=None): 96 | """Updates the progress bar. 97 | 98 | Arguments: 99 | current: Index of current step. 100 | values: List of tuples: 101 | `(name, value_for_last_step)`. 102 | If `name` is in `stateful_metrics`, 103 | `value_for_last_step` will be displayed as-is. 104 | Else, an average of the metric over time will be displayed. 105 | """ 106 | values = values or [] 107 | for k, v in values: 108 | if k not in self._values_order: 109 | self._values_order.append(k) 110 | if k not in self.stateful_metrics: 111 | if k not in self._values: 112 | self._values[k] = [v * (current - self._seen_so_far), 113 | current - self._seen_so_far] 114 | else: 115 | self._values[k][0] += v * (current - self._seen_so_far) 116 | self._values[k][1] += (current - self._seen_so_far) 117 | else: 118 | self._values[k] = v 119 | self._seen_so_far = current 120 | 121 | now = time.time() 122 | info = ' - %.0fs' % (now - self._start) 123 | if self.verbose == 1: 124 | if (now - self._last_update < self.interval and 125 | self.target is not None and current < self.target): 126 | return 127 | 128 | prev_total_width = self._total_width 129 | if self._dynamic_display: 130 | sys.stdout.write('\b' * prev_total_width) 131 | sys.stdout.write('\r') 132 | else: 133 | sys.stdout.write('\n') 134 | 135 | if self.target is not None: 136 | numdigits = int(np.floor(np.log10(self.target))) + 1 137 | barstr = '%%%dd/%d [' % (numdigits, self.target) 138 | bar = barstr % current 139 | prog = float(current) / self.target 140 | prog_width = int(self.width * prog) 141 | if prog_width > 0: 142 | bar += ('=' * (prog_width - 1)) 143 | if current < self.target: 144 | bar += '>' 145 | else: 146 | bar += '=' 147 | bar += ('.' 
* (self.width - prog_width)) 148 | bar += ']' 149 | else: 150 | bar = '%7d/Unknown' % current 151 | 152 | self._total_width = len(bar) 153 | sys.stdout.write(bar) 154 | 155 | if current: 156 | time_per_unit = (now - self._start) / current 157 | else: 158 | time_per_unit = 0 159 | if self.target is not None and current < self.target: 160 | eta = time_per_unit * (self.target - current) 161 | if eta > 3600: 162 | eta_format = '%d:%02d:%02d' % (eta // 3600, 163 | (eta % 3600) // 60, 164 | eta % 60) 165 | elif eta > 60: 166 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 167 | else: 168 | eta_format = '%ds' % eta 169 | 170 | info = ' - ETA: %s' % eta_format 171 | else: 172 | if time_per_unit >= 1: 173 | info += ' %.0fs/step' % time_per_unit 174 | elif time_per_unit >= 1e-3: 175 | info += ' %.0fms/step' % (time_per_unit * 1e3) 176 | else: 177 | info += ' %.0fus/step' % (time_per_unit * 1e6) 178 | 179 | for k in self._values_order: 180 | info += ' - %s:' % k 181 | if isinstance(self._values[k], list): 182 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 183 | if abs(avg) > 1e-3: 184 | info += ' %.4f' % avg 185 | else: 186 | info += ' %.4e' % avg 187 | else: 188 | info += ' %s' % self._values[k] 189 | 190 | self._total_width += len(info) 191 | if prev_total_width > self._total_width: 192 | info += (' ' * (prev_total_width - self._total_width)) 193 | 194 | if self.target is not None and current >= self.target: 195 | info += '\n' 196 | 197 | sys.stdout.write(info) 198 | sys.stdout.flush() 199 | 200 | elif self.verbose == 2: 201 | if self.target is None or current >= self.target: 202 | for k in self._values_order: 203 | info += ' - %s:' % k 204 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 205 | if avg > 1e-3: 206 | info += ' %.4f' % avg 207 | else: 208 | info += ' %.4e' % avg 209 | info += '\n' 210 | 211 | sys.stdout.write(info) 212 | sys.stdout.flush() 213 | 214 | self._last_update = now 215 | 216 | def add(self, n, values=None): 217 | self.update(self._seen_so_far + n, values) 218 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from main import main 2 | main(mode=2) --------------------------------------------------------------------------------
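
A minimal usage sketch for src/segmentor_fcn.py (not code from the repository): segmentor(seg_net, img, dev, objects) runs FCN-ResNet101 (seg_net == 1) or DeepLabV3-ResNet101 (any other value) on a PIL image and returns the image with the selected objects blanked out plus the binary mask. The sketch assumes the repository root is on PYTHONPATH so that src.segmentor_fcn is importable, and assumes class index 15 corresponds to "person" in the 21-class VOC labelling used by these torchvision models.

    from PIL import Image
    from src.segmentor_fcn import segmentor  # importing builds both pretrained backbones

    img = Image.open('examples/my_small_data/0090.jpg').convert('RGB')
    # objects: list of segmentation class indices to remove (assumed: 15 = person)
    og_img, mask = segmentor(seg_net=1, img=img, dev='cpu', objects=[15])
    # og_img: input image with the masked pixels set to 255 (white)
    # mask:   uint8 array, 255 where one of the requested classes was predicted

Similarly, Progbar in src/utils.py mirrors the Keras progress bar: construct it with the total step count, then call add(n, values=[('name', value)]) after each step; metrics named in stateful_metrics are shown as-is, all others are averaged over the steps seen so far.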