├── .github
│   └── workflows
│       └── github-actions-demo.yml
├── .gitignore
├── LICENSE.md
├── README.md
├── __pycache__
│   ├── main.cpython-37.pyc
│   ├── utils.cpython-36.pyc
│   ├── utils.cpython-37.pyc
│   └── utils.cpython-38.pyc
├── checkpoints
│   └── config.yml
├── cog.yaml
├── config.yml.example
├── examples
│   └── my_small_data
│       ├── 0090.jpg
│       ├── 0095.jpg
│       ├── 0376.jpg
│       ├── 0378.jpg
│       ├── 1023_256.jpg
│       ├── 1055_256.jpg
│       ├── 141_256.jpg
│       ├── 277_256.jpg
│       ├── 292_256.jpg
│       ├── 313_256.jpg
│       ├── 386_256.jpg
│       ├── 856_256.jpg
│       ├── 985_256.jpg
│       ├── IMG_4164.jpg
│       ├── IMG_4467.jpg
│       ├── IMG_4804.jpg
│       └── places2_02.png
├── main.py
├── predict.py
├── requirements.txt
├── scripts
│   ├── download_model.sh
│   ├── fid_score.py
│   ├── flist.py
│   ├── inception.py
│   └── metrics.py
├── segmentation_classes.txt
├── setup.cfg
├── src
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── config.cpython-36.pyc
│   │   ├── config.cpython-37.pyc
│   │   ├── dataset.cpython-36.pyc
│   │   ├── dataset.cpython-37.pyc
│   │   ├── edge_connect.cpython-36.pyc
│   │   ├── edge_connect.cpython-37.pyc
│   │   ├── loss.cpython-36.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── metrics.cpython-36.pyc
│   │   ├── metrics.cpython-37.pyc
│   │   ├── models.cpython-36.pyc
│   │   ├── models.cpython-37.pyc
│   │   ├── networks.cpython-36.pyc
│   │   ├── networks.cpython-37.pyc
│   │   ├── segmentor_fcn.cpython-37.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── utils.cpython-37.pyc
│   ├── config.py
│   ├── dataset.py
│   ├── edge_connect.py
│   ├── loss.py
│   ├── models.py
│   ├── networks.py
│   ├── segmentor_fcn.py
│   └── utils.py
└── test.py
/.github/workflows/github-actions-demo.yml:
--------------------------------------------------------------------------------
1 | name: GitHub Actions start and python test
2 | on:
3 | # push:
4 | # branches: [ master ]
5 | pull_request:
6 | branches: [ master ]
7 |
8 | jobs:
9 | intro:
10 | name: first-interaction
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - name: First interaction
15 | uses: actions/first-interaction@v1.1.0
16 | with:
17 | # Token for the repository. Can be passed in using {{ secrets.GITHUB_TOKEN }}
18 | repo-token: ${{ secrets.GITHUB_TOKEN }}
19 | # Comment to post on an individual's first issue
20 | issue-message: "first issue"
21 | # Comment to post on an individual's first pull request
22 | pr-message: "welcome to my repo"
23 |
24 |
25 | python-testing:
26 | name: testing
27 | runs-on: ubuntu-latest
28 | strategy:
29 | fail-fast: false
30 | matrix:
31 | python-version: [3.7, 3.8]
32 |
33 | steps:
34 | - uses: actions/checkout@v2
35 | - name: Set up Python ${{ matrix.python-version }}
36 | uses: actions/setup-python@v2
37 | with:
38 | python-version: ${{ matrix.python-version }}
39 | - name: Install dependencies
40 | run: |
41 | python -m pip install --upgrade pip
42 | python -m pip install flake8 pytest
43 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
44 | - name: Lint with flake8
45 | run: |
46 | DIFF=$(git diff --name-only --diff-filter=b $(git merge-base HEAD $BRANCH))
47 | # stop the build if there are Python syntax errors or undefined names
48 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
49 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
50 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
51 |
52 | # - name: Test with pytest
53 | # run: |
54 | # pytest
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 | *.pyc
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ## creative commons
2 |
3 | # Attribution-NonCommercial 4.0 International
4 |
5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
6 |
7 | ### Using Creative Commons Public Licenses
8 |
9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
10 |
11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
12 |
13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
14 |
15 | ## Creative Commons Attribution-NonCommercial 4.0 International Public License
16 |
17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
18 |
19 | ### Section 1 – Definitions.
20 |
21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
22 |
23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
24 |
25 | c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
26 |
27 | d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
28 |
29 | e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
30 |
31 | f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
32 |
33 | g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
34 |
35 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
36 |
37 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
38 |
39 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
40 |
41 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
42 |
43 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
44 |
45 | ### Section 2 – Scope.
46 |
47 | a. ___License grant.___
48 |
49 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
50 |
51 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
52 |
53 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
54 |
55 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
56 |
57 | 3. __Term.__ The term of this Public License is specified in Section 6(a).
58 |
59 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
60 |
61 | 5. __Downstream recipients.__
62 |
63 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
64 |
65 | B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
66 |
67 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
68 |
69 | b. ___Other rights.___
70 |
71 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
72 |
73 | 2. Patent and trademark rights are not licensed under this Public License.
74 |
75 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
76 |
77 | ### Section 3 – License Conditions.
78 |
79 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
80 |
81 | a. ___Attribution.___
82 |
83 | 1. If You Share the Licensed Material (including in modified form), You must:
84 |
85 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
86 |
87 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
88 |
89 | ii. a copyright notice;
90 |
91 | iii. a notice that refers to this Public License;
92 |
93 | iv. a notice that refers to the disclaimer of warranties;
94 |
95 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
96 |
97 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
98 |
99 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
100 |
101 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
102 |
103 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
104 |
105 | 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
106 |
107 | ### Section 4 – Sui Generis Database Rights.
108 |
109 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
110 |
111 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
112 |
113 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
114 |
115 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
116 |
117 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
118 |
119 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
120 |
121 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
122 |
123 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
124 |
125 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
126 |
127 | ### Section 6 – Term and Termination.
128 |
129 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
130 |
131 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
132 |
133 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
134 |
135 | 2. upon express reinstatement by the Licensor.
136 |
137 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
138 |
139 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
140 |
141 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
142 |
143 | ### Section 7 – Other Terms and Conditions.
144 |
145 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
146 |
147 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
148 |
149 | ### Section 8 – Interpretation.
150 |
151 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
152 |
153 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
154 |
155 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
156 |
157 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
158 |
159 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
160 | >
161 | > Creative Commons may be contacted at creativecommons.org
162 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Automated-Objects-Removal-Inpainter
2 |
3 | [Demo and Docker image on Replicate](https://replicate.com/sujaykhandekar/object-removal)
4 |
5 |
6 |
7 | Automated Object Remover Inpainter is a project that combines semantic segmentation and the EdgeConnect architecture, with minor changes, in order to remove specified objects from photos. The semantic segmentation code is adapted from PyTorch/torchvision, whereas the EdgeConnect code is adapted from [https://github.com/knazeri/edge-connect](https://github.com/knazeri/edge-connect).
8 |
9 | This project can remove objects from a list of 20 different classes. It can be used as a photo-editing tool as well as for data augmentation.
10 |
11 | Python 3.8.5 and PyTorch 1.5.1 have been used in this project.
12 |
13 | ## How does it work?
14 |
15 |
16 |
17 | A semantic segmentation model (DeepLabV3 or FCN, both with a ResNet-101 backbone) is combined with EdgeConnect. The pre-trained segmentation network generates a mask around each detected object, and its output is fed to the EdgeConnect network together with the input image with the masked region removed. EdgeConnect uses a two-stage adversarial architecture: an edge generator followed by an image completion network. The EdgeConnect paper can be found [here](https://arxiv.org/abs/1901.00212) and the code in this [repo](https://github.com/knazeri/edge-connect).
18 |
19 |
20 |
21 |
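As a rough illustration of the first stage, here is a minimal sketch (not this project's actual code, which lives in src/segmentor_fcn.py; it assumes torchvision's pretrained DeepLabV3 ResNet-101) of how a binary mask for a target class can be produced before being handed to EdgeConnect:

```
# Minimal sketch of the segmentation stage (assumes a pretrained torchvision
# DeepLabV3; the project's actual wrapper is src/segmentor_fcn.py).
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet101

model = deeplabv3_resnet101(pretrained=True).eval()

preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

image = Image.open('examples/my_small_data/0090.jpg').convert('RGB')
batch = preprocess(image).unsqueeze(0)

with torch.no_grad():
    scores = model(batch)['out'][0]        # (21, 256, 256) per-class scores
labels = scores.argmax(0).cpu().numpy()    # per-pixel class index

# Class 15 is 'person' in the 20-class list (see segmentation_classes.txt).
# The white region of this mask is what gets blanked out and inpainted.
mask = (labels == 15).astype(np.uint8) * 255
Image.fromarray(mask).save('person_mask.png')
```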
22 | ## Prerequisite
23 | * python 3
24 | * pytorch >= 1.0.1
25 | * NVIDIA GPU + CUDA cuDNN (optional)
26 |
27 | ## Installation
28 | * clone this repo
29 | ```
30 | git clone https://github.com/sujaykhandekar/Automated-objects-removal-inpainter.git
31 | cd Automated-objects-removal-inpainter
32 | ```
33 | or alternatively download the zip file.
34 | * install pytorch with this command
35 | ```
36 | conda install pytorch==1.5.1 torchvision==0.6.1 -c pytorch
37 | ```
38 | * install other python requirements using this command
39 | ```
40 | pip install -r requirements.txt
41 | ```
42 | * Download one of the three pretrained EdgeConnect models and copy it into the ./checkpoints directory:
43 | [Places2](https://drive.google.com/drive/folders/1qjieeThyse_iwJ0gZoJr-FERkgD5sm4y?usp=sharing) (option 1)
44 | [CelebA](https://drive.google.com/drive/folders/1nkLOhzWL-w2euo0U6amhz7HVzqNC5rqb) (option 2)
45 | [Paris-street-view](https://drive.google.com/drive/folders/1cGwDaZqDcqYU7kDuEbMXa9TP3uDJRBR1) (option 3)
46 |
47 | or alternatively you can use this command:
48 | ```
49 | bash ./scripts/download_model.sh
50 | ```
51 |
52 | ## Prediction/Test
53 | For a quick prediction you can run this command. If you don't have a CUDA-capable GPU, run the second command below instead.
54 | ```
55 | python test.py --input ./examples/my_small_data --output ./checkpoints/resultsfinal --remove 3 15
56 | ```
57 | It will take the sample images in the ./examples/my_small_data directory and produce results in the ./checkpoints/resultsfinal directory. You can replace these input/output directories with your own.
58 | The numbers after --remove specify the objects to be removed from the images. The above command will remove 3 (bird) and 15 (person) from the images. Check segmentation_classes.txt for all removal options along with their numbers.
59 |
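For instance, to remove dogs and cats instead, you would pass their indices (12 and 8 in the obj2idx mapping used by predict.py; confirm against segmentation_classes.txt on your copy):
```
python test.py --input ./examples/my_small_data --output ./checkpoints/resultsfinal --remove 12 8
```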
60 | Output images will all be 256x256. It takes around 10 minutes for 1000 images on an NVIDIA GeForce GTX 1650.
61 |
62 | For better quality but slower runtime you can use this command:
63 | ```
64 | python test.py --input ./examples/my_small_data --output ./checkpoints/resultsfinal --remove 3 15 --cpu yes
65 | ```
66 | It will run the segmentation model on the CPU, which is about 5 times slower than on the GPU (default).
67 | For other options, including a different segmentation model and EdgeConnect parameters, make the corresponding modifications in the ./checkpoints/config.yml file.
68 |
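For example, switching the segmentation backbone from DeepLabV3 to FCN only needs one key changed (same key and comment as in the checkpoints/config.yml shipped with this repo):
```
SEG_NETWORK: 1 # 0:DeepLabV3 resnet 101 segmentation , 1: FCN resnet 101 segmentation
```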
69 | ## Training
70 | To train your own segmentation model, you can refer to this [repo](https://github.com/CSAILVision/semantic-segmentation-pytorch) and replace ./src/segmentor_fcn.py with your model.
71 |
72 | To train an EdgeConnect model, please refer to the original [EdgeConnect repo](https://github.com/knazeri/edge-connect). After training, you can copy your model weights into ./checkpoints/.
73 |
74 | ## Some results
75 |
76 |
77 | ## Next Steps
78 | * The pretrained EdgeConnect models used in this project are trained on 256x256 images. To make output images the same size as the input, two approaches can be used. You can train your own EdgeConnect model on bigger images, or you can create 256x256 subimages for every object detected in the image and then merge them back together after passing through EdgeConnect to reconstruct the original-sized image (see the sketch after this list). A similar approach has been used in this [repo](https://github.com/javirk/Person_remover).
79 | * To detect objects not present in the segmentation classes, you can train your own segmentation model or use pretrained segmentation models from this [repo](https://github.com/CSAILVision/semantic-segmentation-pytorch), which has 150 different categories available.
80 | * It is also possible to combine OpenCV's feature matching and edge prediction from EdgeConnect to highlight and create a mask for relevant objects based on a single mask created by the user. I may try this part myself.
81 |
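The crop-and-merge idea from the first point above can be sketched as follows (a hypothetical helper, not part of this repo; inpaint_fn stands in for any fixed-size inpainting call such as an EdgeConnect forward pass):

```
# Hypothetical sketch of the subimage idea: cut a 256x256 window around the
# object mask, inpaint only that window, then paste it back into the
# full-resolution image. Assumes the image is at least 256x256 and that
# inpaint_fn returns an array with the same shape as the crop it receives.
import numpy as np

def inpaint_via_crop(image, mask, inpaint_fn, size=256):
    """image: (H, W, 3) uint8, mask: (H, W) bool array marking one object."""
    ys, xs = np.where(mask)
    if len(ys) == 0:
        return image                      # nothing to remove
    cy, cx = int(ys.mean()), int(xs.mean())
    h, w = mask.shape
    top = min(max(cy - size // 2, 0), max(h - size, 0))
    left = min(max(cx - size // 2, 0), max(w - size, 0))
    crop = image[top:top + size, left:left + size]
    crop_mask = mask[top:top + size, left:left + size]
    restored = inpaint_fn(crop, crop_mask)
    out = image.copy()
    out[top:top + size, left:left + size] = restored
    return out
```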
82 | ## License
83 | Licensed under a [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/) license.
84 |
85 | Except where otherwise noted, this content is published under a [CC BY-NC](https://creativecommons.org/licenses/by-nc/4.0/) license, which means that you can copy, remix, transform and build upon the content as long as you do not use the material for commercial purposes, give appropriate credit, and provide a link to the license.
86 |
87 | ## Citation
88 | ```
89 | @inproceedings{nazeri2019edgeconnect,
90 | title={EdgeConnect: Generative Image Inpainting with Adversarial Edge Learning},
91 | author={Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran},
92 | journal={arXiv preprint},
93 | year={2019},
94 | }
95 |
96 | @InProceedings{Nazeri_2019_ICCV,
97 | title = {EdgeConnect: Structure Guided Image Inpainting using Edge Prediction},
98 | author = {Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran},
99 | booktitle = {The IEEE International Conference on Computer Vision (ICCV) Workshops},
100 | month = {Oct},
101 | year = {2019}
102 | }
103 | ```
104 |
--------------------------------------------------------------------------------
/__pycache__/main.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/__pycache__/main.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/checkpoints/config.yml:
--------------------------------------------------------------------------------
1 | MODE: 1 # 1: train, 2: test, 3: eval
2 | MODEL: 1 # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model
3 | MASK: 3 # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
4 | EDGE: 1 # 1: canny, 2: external
5 | NMS: 1 # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny
6 | SEED: 10 # random seed
7 | GPU: [0] # list of gpu ids
8 | DEBUG: 0 # turns on debugging mode
9 | VERBOSE: 0 # turns on verbose mode in the output console
10 |
11 | TRAIN_FLIST: ./datasets/places2_train.flist
12 | VAL_FLIST: ./datasets/places2_val.flist
13 | TEST_FLIST: ./datasets/places2_test.flist
14 |
15 | TRAIN_EDGE_FLIST: ./datasets/places2_edges_train.flist
16 | VAL_EDGE_FLIST: ./datasets/places2_edges_val.flist
17 | TEST_EDGE_FLIST: ./datasets/places2_edges_test.flist
18 |
19 | TRAIN_MASK_FLIST: ./datasets/masks_train.flist
20 | VAL_MASK_FLIST: ./datasets/masks_val.flist
21 | TEST_MASK_FLIST: ./datasets/masks_test.flist
22 |
23 | LR: 0.0001 # learning rate
24 | D2G_LR: 0.1 # discriminator/generator learning rate ratio
25 | BETA1: 0.0 # adam optimizer beta1
26 | BETA2: 0.9 # adam optimizer beta2
27 | BATCH_SIZE: 8 # input batch size for training
28 | INPUT_SIZE: 256 # input image size for training (0 for original size)
29 | SIGMA: 2 # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)
30 | MAX_ITERS: 2e6 # maximum number of iterations to train the model
31 |
32 | EDGE_THRESHOLD: 0.5 # edge detection threshold
33 | L1_LOSS_WEIGHT: 1 # l1 loss weight
34 | FM_LOSS_WEIGHT: 10 # feature-matching loss weight
35 | STYLE_LOSS_WEIGHT: 250 # style loss weight
36 | CONTENT_LOSS_WEIGHT: 0.1 # perceptual loss weight
37 | INPAINT_ADV_LOSS_WEIGHT: 0.1 # adversarial loss weight
38 |
39 | GAN_LOSS: nsgan # nsgan | lsgan | hinge
40 | GAN_POOL_SIZE: 0 # fake images pool size
41 |
42 | SAVE_INTERVAL: 1000 # how many iterations to wait before saving model (0: never)
43 | SAMPLE_INTERVAL: 1000 # how many iterations to wait before sampling (0: never)
44 | SAMPLE_SIZE: 12 # number of images to sample
45 | EVAL_INTERVAL: 0 # how many iterations to wait before model evaluation (0: never)
46 | LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never)
47 | SEG_NETWORK: 0 # 0:DeepLabV3 resnet 101 segmentation , 1: FCN resnet 101 segmentation
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | # set to true if your model requires a GPU
6 | gpu: true
7 |
8 | # a list of ubuntu apt packages to install
9 | # system_packages:
10 | # - "libgl1-mesa-glx"
11 | # - "libglib2.0-0"
12 |
13 | # python version in the form '3.8' or '3.8.12'
14 | python_version: "3.8"
15 |
16 | # a list of packages in the format ==
17 | python_packages:
18 | - "ipython==7.33.0"
19 | # - "numpy==1.19.4"
20 | # - "torch==1.11.0"
21 | # - "torchvision==0.12.0"
22 | - "torch==1.5.1"
23 | - "torchvision==0.6.1"
24 | - "Pillow==8.0.1"
25 |
26 | - "numpy==1.19.1"
27 | - "scipy==1.5.2"
28 | - "future==0.18.2"
29 | - "matplotlib==3.3.0"
30 | # - "pillow==6.2.0"
31 | - "opencv-python==4.3.0.36"
32 | - "scikit-image==0.17.2"
33 | - "pyaml==20.4.0"
34 |
35 | # commands run after the environment is setup
36 | run:
37 | - "apt-get update && apt-get install -y cmake"
38 | # - "sudo apt-get install python python-pip build-essential cmake"
39 | # - "pip install tqdm gdown kornia scipy opencv-python dlib moviepy lpips aubio ninja"
40 |
41 | # predict.py defines how predictions are run on your model
42 | predict: "predict.py:Predictor"
--------------------------------------------------------------------------------
/config.yml.example:
--------------------------------------------------------------------------------
1 | MODE: 1 # 1: train, 2: test, 3: eval
2 | MODEL: 1 # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model
3 | MASK: 3 # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
4 | EDGE: 1 # 1: canny, 2: external
5 | NMS: 1 # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny
6 | SEED: 10 # random seed
7 | GPU: [0] # list of gpu ids
8 | DEBUG: 0 # turns on debugging mode
9 | VERBOSE: 0 # turns on verbose mode in the output console
10 |
11 | TRAIN_FLIST: ./datasets/places2_train.flist
12 | VAL_FLIST: ./datasets/places2_val.flist
13 | TEST_FLIST: ./datasets/places2_test.flist
14 |
15 | TRAIN_EDGE_FLIST: ./datasets/places2_edges_train.flist
16 | VAL_EDGE_FLIST: ./datasets/places2_edges_val.flist
17 | TEST_EDGE_FLIST: ./datasets/places2_edges_test.flist
18 |
19 | TRAIN_MASK_FLIST: ./datasets/masks_train.flist
20 | VAL_MASK_FLIST: ./datasets/masks_val.flist
21 | TEST_MASK_FLIST: ./datasets/masks_test.flist
22 |
23 | LR: 0.0001 # learning rate
24 | D2G_LR: 0.1 # discriminator/generator learning rate ratio
25 | BETA1: 0.0 # adam optimizer beta1
26 | BETA2: 0.9 # adam optimizer beta2
27 | BATCH_SIZE: 8 # input batch size for training
28 | INPUT_SIZE: 256 # input image size for training (0 for original size)
29 | SIGMA: 2 # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)
30 | MAX_ITERS: 2e6 # maximum number of iterations to train the model
31 |
32 | EDGE_THRESHOLD: 0.5 # edge detection threshold
33 | L1_LOSS_WEIGHT: 1 # l1 loss weight
34 | FM_LOSS_WEIGHT: 10 # feature-matching loss weight
35 | STYLE_LOSS_WEIGHT: 250 # style loss weight
36 | CONTENT_LOSS_WEIGHT: 0.1 # perceptual loss weight
37 | INPAINT_ADV_LOSS_WEIGHT: 0.1 # adversarial loss weight
38 |
39 | GAN_LOSS: nsgan # nsgan | lsgan | hinge
40 | GAN_POOL_SIZE: 0 # fake images pool size
41 |
42 | SAVE_INTERVAL: 1000 # how many iterations to wait before saving model (0: never)
43 | SAMPLE_INTERVAL: 1000 # how many iterations to wait before sampling (0: never)
44 | SAMPLE_SIZE: 12 # number of images to sample
45 | EVAL_INTERVAL: 0 # how many iterations to wait before model evaluation (0: never)
46 | LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never)
--------------------------------------------------------------------------------
/examples/my_small_data/0090.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/0090.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/0095.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/0095.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/0376.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/0376.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/0378.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/0378.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/1023_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/1023_256.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/1055_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/1055_256.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/141_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/141_256.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/277_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/277_256.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/292_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/292_256.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/313_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/313_256.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/386_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/386_256.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/856_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/856_256.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/985_256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/985_256.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/IMG_4164.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/IMG_4164.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/IMG_4467.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/IMG_4467.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/IMG_4804.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/IMG_4804.jpg
--------------------------------------------------------------------------------
/examples/my_small_data/places2_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/examples/my_small_data/places2_02.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import numpy as np
5 | import torch
6 | import argparse
7 | from shutil import copyfile
8 | from src.config import Config
9 | from src.edge_connect import EdgeConnect
10 |
11 |
12 | def main(mode=None):
13 | r"""starts the model
14 |
15 | """
16 |
17 | config = load_config(mode)
18 |
19 |
20 | # CUDA visible devices
21 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU)
22 |
23 |
24 | # init device
25 | if torch.cuda.is_available():
26 | config.DEVICE = torch.device("cuda")
27 | torch.backends.cudnn.benchmark = True # cudnn auto-tuner
28 | else:
29 | config.DEVICE = torch.device("cpu")
30 |
31 |
32 |
33 | # set cv2 running threads to 1 (prevents deadlocks with pytorch dataloader)
34 | cv2.setNumThreads(0)
35 |
36 |
37 | # initialize random seed
38 | torch.manual_seed(config.SEED)
39 | torch.cuda.manual_seed_all(config.SEED)
40 | np.random.seed(config.SEED)
41 | random.seed(config.SEED)
42 |
43 |
44 |
45 | # build the model and initialize
46 | model = EdgeConnect(config)
47 | model.load()
48 |
49 |
50 |
51 | # model test
52 | print('\nstart testing...\n')
53 | model.test()
54 |
55 |
56 |
57 | def load_config(mode=None):
58 | r"""loads model config
59 |
60 | """
61 |
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--path', '--checkpoints', type=str, default='./checkpoints', help='model checkpoints path (default: ./checkpoints)')
64 | parser.add_argument('--model', type=int, choices=[1, 2, 3, 4], help='1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model')
65 |
66 | # test mode
67 | parser.add_argument('--input', type=str, help='path to the input images directory or an input image')
68 | parser.add_argument('--edge', type=str, help='path to the edges directory or an edge file')
69 | parser.add_argument('--output', type=str, help='path to the output directory')
70 | parser.add_argument('--remove', nargs= '*' ,type=int, help='objects to remove')
71 | parser.add_argument('--cpu', type=str, help='machine to run segmentation model on')
72 | args = parser.parse_args()
73 |
74 | # if checkpoints path is not given
75 | if args.path is None:
76 | args.path='./checkpoints'
77 | config_path = os.path.join(args.path, 'config.yml')
78 |
79 | # create checkpoints path if it doesn't exist
80 | if not os.path.exists(args.path):
81 | os.makedirs(args.path)
82 |
83 | # copy config template if it doesn't exist
84 | if not os.path.exists(config_path):
85 | copyfile('./config.yml.example', config_path)
86 |
87 | # load config file
88 | config = Config(config_path)
89 |
90 |
91 | # test mode
92 | config.MODE = 2
93 | config.MODEL = args.model if args.model is not None else 3
94 | config.OBJECTS = args.remove if args.remove is not None else [3,15]
95 | config.SEG_DEVICE = 'cpu' if args.cpu is not None else 'cuda'
96 | config.INPUT_SIZE = 256
97 | if args.input is not None:
98 | config.TEST_FLIST = args.input
99 |
100 | if args.edge is not None:
101 | config.TEST_EDGE_FLIST = args.edge
102 | if args.output is not None:
103 | config.RESULTS = args.output
104 | else:
105 | if not os.path.exists('./results_images'):
106 | os.makedirs('./results_images')
107 | config.RESULTS = './results_images'
108 |
109 |
110 |
111 |
112 |
113 | return config
114 |
115 |
116 | if __name__ == "__main__":
117 | main()
118 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | from shutil import copyfile
5 |
6 | import cv2
7 | import numpy as np
8 | import torch
9 | import torchvision
10 | import torchvision.transforms
11 | from PIL import Image
12 |
13 |
14 | # cog
15 | from cog import BasePredictor, Input, Path
16 |
17 | from src.config import Config
18 | from src.edge_connect import EdgeConnect
19 |
20 | import tempfile
21 |
22 |
23 | # Maps object to index
24 | obj2idx = {
25 | "Background":0, "Aeroplane":1, "bicycle":2, "bird":3, "boat":4, "bottle":5, "bus":6, "car":7, "cat":8,
26 | "chair":9, "cow":10, "dining table":11, "dog":12, "horse":13, "motorbike":14, "person":15,
27 | "potted plant":16, "sheep":17, "sofa":18, "train":19, "tv/monitor":20
28 | }
29 |
30 |
31 |
32 | def load_config(mode=None, objects_to_remove=None):
33 | print('Object(s) to remove:', objects_to_remove)
34 |
35 |
36 | # load config file
37 | path = "./checkpoints"
38 | config_path = os.path.join(path, "config.yml")
39 |
40 | # create checkpoints path if it doesn't exist
41 | if not os.path.exists(path):
42 | os.makedirs(path)
43 |
44 | # copy config template if it doesn't exist
45 | if not os.path.exists(config_path):
46 | copyfile("./config.yml.example", config_path)
47 |
48 | # load config file
49 | config = Config(config_path)
50 |
51 | # test mode
52 | config.MODE = mode
53 | config.MODEL = 3
54 | config.OBJECTS = objects_to_remove
55 | config.SEG_DEVICE = "cuda"
56 | config.INPUT_SIZE = 256
57 |
58 | # outputs
59 | if not os.path.exists("./results_images"):
60 | os.makedirs("./results_images")
61 | config.RESULTS = "./results_images"
62 | return config
63 |
64 |
65 | # Instantiate Cog Predictor
66 | class Predictor(BasePredictor):
67 | def setup(self):
68 |
69 | # Select torch device
70 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71 |
72 | def predict(self,
73 | image_path: Path = Input(description="Input image (ideally a square image)"),
74 | objects_to_remove: str = Input(description="Object(s) to remove (separate with comma, e.g. car,cat,bird). See full list of names at https://github.com/sujaykhandekar/Automated-objects-removal-inpainter/blob/master/segmentation_classes.txt", default='person,car'),
75 |
76 | ) -> Path:
77 |
78 | # format input image
79 | image_path = str(image_path)
80 | image = Image.open(image_path).convert('RGB')
81 | image.save(image_path) # resave formatted image
82 |
83 | # parse objects to remove
84 | objects_to_remove = objects_to_remove.split(',')
85 | objects_to_remove = [obj2idx[x] for x in objects_to_remove]
86 |
87 | mode = 2 # 1: train, 2: test, 3: eval
88 | self.config = load_config(mode, objects_to_remove=objects_to_remove)
89 | # set cv2 running threads to 1 (prevents deadlocks with pytorch dataloader)
90 | cv2.setNumThreads(0)
91 |
92 | # initialize random seed
93 | torch.manual_seed(self.config.SEED)
94 | torch.cuda.manual_seed_all(self.config.SEED)
95 | np.random.seed(self.config.SEED)
96 | random.seed(self.config.SEED)
97 |
98 | # save to path
99 | self.config.TEST_FLIST = image_path
100 |
101 | # build the model and initialize
102 | model = EdgeConnect(self.config)
103 | model.load()
104 |
105 | # model test
106 | output_image = model.test()
107 | output_image = output_image.cpu().numpy()
108 | output_image = Image.fromarray(np.uint8(output_image)).convert('RGB')
109 |
110 | # save output image as Cog Path object
111 | output_path = Path(tempfile.mkdtemp()) / "output.png"
112 | output_image.save(output_path)
113 | print(output_path)
114 | return output_path
115 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy ~= 1.19.1
2 | scipy ~= 1.5.2
3 | future ~= 0.18.2
4 | matplotlib ~= 3.3.0
5 | pillow >= 6.2.0
6 | opencv-python ~= 4.3.0.36
7 | scikit-image ~= 0.17.2
8 | pyaml ~= 20.4.0
9 |
--------------------------------------------------------------------------------
/scripts/download_model.sh:
--------------------------------------------------------------------------------
1 | DIR='./checkpoints'
2 | URL='https://drive.google.com/uc?export=download&id=1IrlFQGTpdQYdPeZIEgGUaSFpbYtNpekA'
3 |
4 | echo "Downloading pre-trained models..."
5 | mkdir -p $DIR
6 | FILE="$(curl -sc /tmp/gcokie "${URL}" | grep -o '="uc-name.*' | sed 's/.*">//;s/<.a> .*//')"
7 | curl -Lb /tmp/gcokie "${URL}&confirm=$(awk '/_warning_/ {print $NF}' /tmp/gcokie)" -o "$DIR/${FILE}"
8 |
9 | echo "Extracting pre-trained models..."
10 | cd $DIR
11 | unzip $FILE
12 | rm $FILE
13 |
14 | echo "Download success."
--------------------------------------------------------------------------------
/scripts/fid_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 | When run as a stand-alone program, it compares the distribution of
7 | images that are stored as PNG/JPEG at a specified location with a
8 | distribution given by summary statistics (in pickle format).
9 | The FID is calculated by assuming that X_1 and X_2 are the activations of
10 | the pool_3 layer of the inception net for generated samples and real world
11 | samples respectively.
12 | See --help to see further details.
13 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
14 | of Tensorflow
15 | Copyright 2018 Institute of Bioinformatics, JKU Linz
16 | Licensed under the Apache License, Version 2.0 (the "License");
17 | you may not use this file except in compliance with the License.
18 | You may obtain a copy of the License at
19 | http://www.apache.org/licenses/LICENSE-2.0
20 | Unless required by applicable law or agreed to in writing, software
21 | distributed under the License is distributed on an "AS IS" BASIS,
22 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | See the License for the specific language governing permissions and
24 | limitations under the License.
25 | """
26 | import os
27 | import pathlib
28 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
29 |
30 | import torch
31 | import numpy as np
32 | from imageio import imread  # scipy.misc.imread was removed in newer SciPy releases (the pinned scipy ~= 1.5 no longer provides it)
33 | from scipy import linalg
34 | from torch.autograd import Variable
35 | from torch.nn.functional import adaptive_avg_pool2d
36 |
37 | from inception import InceptionV3
38 |
39 |
40 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
41 | parser.add_argument('--path', type=str, nargs=2, help=('Path to the generated images or to .npz statistic files'))
42 | parser.add_argument('--batch-size', type=int, default=64, help='Batch size to use')
43 | parser.add_argument('--dims', type=int, default=2048, choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), help=('Dimensionality of Inception features to use. By default, uses pool3 features'))
44 | parser.add_argument('-c', '--gpu', default='', type=str, help='GPU to use (leave blank for CPU only)')
45 |
46 |
47 | def get_activations(images, model, batch_size=64, dims=2048,
48 | cuda=False, verbose=False):
49 | """Calculates the activations of the pool_3 layer for all images.
50 | Params:
51 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
52 | must lie between 0 and 1.
53 | -- model : Instance of inception model
54 | -- batch_size : the images numpy array is split into batches with
55 | batch size batch_size. A reasonable batch size depends
56 | on the hardware.
57 | -- dims : Dimensionality of features returned by Inception
58 | -- cuda : If set to True, use GPU
59 | -- verbose : If set to True and parameter out_step is given, the number
60 | of calculated batches is reported.
61 | Returns:
62 | -- A numpy array of dimension (num images, dims) that contains the
63 | activations of the given tensor when feeding inception with the
64 | query tensor.
65 | """
66 | model.eval()
67 |
68 | d0 = images.shape[0]
69 | if batch_size > d0:
70 | print(('Warning: batch size is bigger than the data size. '
71 | 'Setting batch size to data size'))
72 | batch_size = d0
73 |
74 | n_batches = d0 // batch_size
75 | n_used_imgs = n_batches * batch_size
76 |
77 | pred_arr = np.empty((n_used_imgs, dims))
78 | for i in range(n_batches):
79 | if verbose:
80 | print('\rPropagating batch %d/%d' % (i + 1, n_batches),
81 | end='', flush=True)
82 | start = i * batch_size
83 | end = start + batch_size
84 |
85 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
86 | batch = Variable(batch, volatile=True)
87 | if cuda:
88 | batch = batch.cuda()
89 |
90 | pred = model(batch)[0]
91 |
92 | # If model output is not scalar, apply global spatial average pooling.
93 | # This happens if you choose a dimensionality not equal 2048.
94 | if pred.shape[2] != 1 or pred.shape[3] != 1:
95 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
96 |
97 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
98 |
99 | if verbose:
100 | print(' done')
101 |
102 | return pred_arr
103 |
104 |
105 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
106 | """Numpy implementation of the Frechet Distance.
107 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
108 | and X_2 ~ N(mu_2, C_2) is
109 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
110 | Stable version by Dougal J. Sutherland.
111 | Params:
112 | -- mu1 : Numpy array containing the activations of a layer of the
113 | inception net (like returned by the function 'get_predictions')
114 | for generated samples.
115 | -- mu2 : The sample mean over activations, precalculated on a
116 | representative data set.
117 | -- sigma1: The covariance matrix over activations for generated samples.
118 | -- sigma2: The covariance matrix over activations, precalculated on a
119 | representative data set.
120 | Returns:
121 | -- : The Frechet Distance.
122 | """
123 |
124 | mu1 = np.atleast_1d(mu1)
125 | mu2 = np.atleast_1d(mu2)
126 |
127 | sigma1 = np.atleast_2d(sigma1)
128 | sigma2 = np.atleast_2d(sigma2)
129 |
130 | assert mu1.shape == mu2.shape, \
131 | 'Training and test mean vectors have different lengths'
132 | assert sigma1.shape == sigma2.shape, \
133 | 'Training and test covariances have different dimensions'
134 |
135 | diff = mu1 - mu2
136 |
137 | # Product might be almost singular
138 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
139 | if not np.isfinite(covmean).all():
140 | msg = ('fid calculation produces singular product; '
141 | 'adding %s to diagonal of cov estimates') % eps
142 | print(msg)
143 | offset = np.eye(sigma1.shape[0]) * eps
144 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
145 |
146 | # Numerical error might give slight imaginary component
147 | if np.iscomplexobj(covmean):
148 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
149 | m = np.max(np.abs(covmean.imag))
150 | raise ValueError('Imaginary component {}'.format(m))
151 | covmean = covmean.real
152 |
153 | tr_covmean = np.trace(covmean)
154 |
155 | return (diff.dot(diff) + np.trace(sigma1) +
156 | np.trace(sigma2) - 2 * tr_covmean)
157 |
158 |
159 | def calculate_activation_statistics(images, model, batch_size=64,
160 | dims=2048, cuda=False, verbose=False):
161 | """Calculation of the statistics used by the FID.
162 | Params:
163 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
164 | must lie between 0 and 1.
165 | -- model : Instance of inception model
166 | -- batch_size : The images numpy array is split into batches with
167 | batch size batch_size. A reasonable batch size
168 | depends on the hardware.
169 | -- dims : Dimensionality of features returned by Inception
170 | -- cuda : If set to True, use GPU
171 |     -- verbose     : If set to True, progress over the processed batches is
172 |                      reported.
173 | Returns:
174 | -- mu : The mean over samples of the activations of the pool_3 layer of
175 | the inception model.
176 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
177 | the inception model.
178 | """
179 | act = get_activations(images, model, batch_size, dims, cuda, verbose)
180 | mu = np.mean(act, axis=0)
181 | sigma = np.cov(act, rowvar=False)
182 | return mu, sigma
183 |
184 |
185 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
186 | npz_file = os.path.join(path, 'statistics.npz')
187 | if os.path.exists(npz_file):
188 | f = np.load(npz_file)
189 | m, s = f['mu'][:], f['sigma'][:]
190 | f.close()
191 | else:
192 | path = pathlib.Path(path)
193 | files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
194 |
195 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
196 |
197 | # Bring images to shape (B, 3, H, W)
198 | imgs = imgs.transpose((0, 3, 1, 2))
199 |
200 | # Rescale images to be between 0 and 1
201 | imgs /= 255
202 |
203 | m, s = calculate_activation_statistics(imgs, model, batch_size, dims, cuda)
204 | np.savez(npz_file, mu=m, sigma=s)
205 |
206 | return m, s
207 |
208 |
209 | def calculate_fid_given_paths(paths, batch_size, cuda, dims):
210 | """Calculates the FID of two paths"""
211 | for p in paths:
212 | if not os.path.exists(p):
213 | raise RuntimeError('Invalid path: %s' % p)
214 |
215 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
216 |
217 | model = InceptionV3([block_idx])
218 | if cuda:
219 | model.cuda()
220 |
221 | print('calculate path1 statistics...')
222 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, cuda)
223 | print('calculate path2 statistics...')
224 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, cuda)
225 | print('calculate frechet distance...')
226 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
227 |
228 | return fid_value
229 |
230 |
231 | if __name__ == '__main__':
232 | args = parser.parse_args()
233 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
234 |
235 | fid_value = calculate_fid_given_paths(args.path,
236 | args.batch_size,
237 | args.gpu != '',
238 | args.dims)
239 | print('FID: ', round(fid_value, 4))
240 |
--------------------------------------------------------------------------------
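A minimal usage sketch for the FID helpers above (the directory paths are placeholders; this assumes scripts/ is on the import path and that both directories contain same-sized .jpg/.png images):

    from fid_score import calculate_fid_given_paths

    # compare ground-truth images against inpainted results; a statistics.npz
    # cache is written into each directory on the first run
    fid = calculate_fid_given_paths(
        ['data/ground_truth', 'results/inpainted'],  # placeholder directories
        batch_size=32,
        cuda=False,      # set True to run InceptionV3 on the GPU
        dims=2048)       # pool_3 features, see InceptionV3.BLOCK_INDEX_BY_DIM
    print('FID:', round(fid, 4))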
/scripts/flist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--path', type=str, help='path to the dataset')
7 | parser.add_argument('--output', type=str, help='path to the file list')
8 | args = parser.parse_args()
9 |
10 | ext = {'.JPG', '.JPEG', '.PNG', '.TIF', '.TIFF'}
11 |
12 | images = []
13 | for root, dirs, files in os.walk(args.path):
14 | print('loading ' + root)
15 | for file in files:
16 | if os.path.splitext(file)[1].upper() in ext:
17 | images.append(os.path.join(root, file))
18 |
19 | images = sorted(images)
20 | np.savetxt(args.output, images, fmt='%s')
--------------------------------------------------------------------------------
/scripts/inception.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from torchvision import models
4 |
5 |
6 | class InceptionV3(nn.Module):
7 | """Pretrained InceptionV3 network returning feature maps"""
8 |
9 | # Index of default block of inception to return,
10 | # corresponds to output of final average pooling
11 | DEFAULT_BLOCK_INDEX = 3
12 |
13 | # Maps feature dimensionality to their output blocks indices
14 | BLOCK_INDEX_BY_DIM = {
15 | 64: 0, # First max pooling features
16 |         192: 1,   # Second max pooling features
17 | 768: 2, # Pre-aux classifier features
18 | 2048: 3 # Final average pooling features
19 | }
20 |
21 | def __init__(self,
22 | output_blocks=[DEFAULT_BLOCK_INDEX],
23 | resize_input=True,
24 | normalize_input=True,
25 | requires_grad=False):
26 | """Build pretrained InceptionV3
27 | Parameters
28 | ----------
29 | output_blocks : list of int
30 | Indices of blocks to return features of. Possible values are:
31 | - 0: corresponds to output of first max pooling
32 | - 1: corresponds to output of second max pooling
33 | - 2: corresponds to output which is fed to aux classifier
34 | - 3: corresponds to output of final average pooling
35 | resize_input : bool
36 | If true, bilinearly resizes input to width and height 299 before
37 | feeding input to model. As the network without fully connected
38 | layers is fully convolutional, it should be able to handle inputs
39 | of arbitrary size, so resizing might not be strictly needed
40 | normalize_input : bool
41 | If true, normalizes the input to the statistics the pretrained
42 | Inception network expects
43 | requires_grad : bool
44 | If true, parameters of the model require gradient. Possibly useful
45 | for finetuning the network
46 | """
47 | super(InceptionV3, self).__init__()
48 |
49 | self.resize_input = resize_input
50 | self.normalize_input = normalize_input
51 | self.output_blocks = sorted(output_blocks)
52 | self.last_needed_block = max(output_blocks)
53 |
54 | assert self.last_needed_block <= 3, \
55 | 'Last possible output block index is 3'
56 |
57 | self.blocks = nn.ModuleList()
58 |
59 | inception = models.inception_v3(pretrained=True)
60 |
61 | # Block 0: input to maxpool1
62 | block0 = [
63 | inception.Conv2d_1a_3x3,
64 | inception.Conv2d_2a_3x3,
65 | inception.Conv2d_2b_3x3,
66 | nn.MaxPool2d(kernel_size=3, stride=2)
67 | ]
68 | self.blocks.append(nn.Sequential(*block0))
69 |
70 | # Block 1: maxpool1 to maxpool2
71 | if self.last_needed_block >= 1:
72 | block1 = [
73 | inception.Conv2d_3b_1x1,
74 | inception.Conv2d_4a_3x3,
75 | nn.MaxPool2d(kernel_size=3, stride=2)
76 | ]
77 | self.blocks.append(nn.Sequential(*block1))
78 |
79 | # Block 2: maxpool2 to aux classifier
80 | if self.last_needed_block >= 2:
81 | block2 = [
82 | inception.Mixed_5b,
83 | inception.Mixed_5c,
84 | inception.Mixed_5d,
85 | inception.Mixed_6a,
86 | inception.Mixed_6b,
87 | inception.Mixed_6c,
88 | inception.Mixed_6d,
89 | inception.Mixed_6e,
90 | ]
91 | self.blocks.append(nn.Sequential(*block2))
92 |
93 | # Block 3: aux classifier to final avgpool
94 | if self.last_needed_block >= 3:
95 | block3 = [
96 | inception.Mixed_7a,
97 | inception.Mixed_7b,
98 | inception.Mixed_7c,
99 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
100 | ]
101 | self.blocks.append(nn.Sequential(*block3))
102 |
103 | for param in self.parameters():
104 | param.requires_grad = requires_grad
105 |
106 | def forward(self, inp):
107 | """Get Inception feature maps
108 | Parameters
109 | ----------
110 | inp : torch.autograd.Variable
111 | Input tensor of shape Bx3xHxW. Values are expected to be in
112 | range (0, 1)
113 | Returns
114 | -------
115 | List of torch.autograd.Variable, corresponding to the selected output
116 | block, sorted ascending by index
117 | """
118 | outp = []
119 | x = inp
120 |
121 | if self.resize_input:
122 | x = F.upsample(x, size=(299, 299), mode='bilinear')
123 |
124 | if self.normalize_input:
125 | x = x.clone()
126 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
127 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
128 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
129 |
130 | for idx, block in enumerate(self.blocks):
131 | x = block(x)
132 | if idx in self.output_blocks:
133 | outp.append(x)
134 |
135 | if idx == self.last_needed_block:
136 | break
137 |
138 | return outp
139 |
--------------------------------------------------------------------------------
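A small shape-check sketch for the InceptionV3 wrapper above (assumes torchvision can load the pretrained weights; the input is a random tensor and only illustrates the expected layout):

    import torch
    from inception import InceptionV3

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # final average pooling block
    model = InceptionV3([block_idx]).eval()

    x = torch.rand(4, 3, 256, 256)        # 4 RGB images with values in [0, 1]
    with torch.no_grad():
        features = model(x)[0]            # -> shape (4, 2048, 1, 1)
    print(features.reshape(4, -1).shape)  # (4, 2048), as consumed by get_activations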
/scripts/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import matplotlib.pyplot as plt
4 |
5 | from glob import glob
6 | from ntpath import basename
7 | from scipy.misc import imread
8 | from skimage.measure import compare_ssim
9 | from skimage.measure import compare_psnr
10 | from skimage.color import rgb2gray
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description='script to compute all statistics')
15 | parser.add_argument('--data-path', help='Path to ground truth data', type=str)
16 | parser.add_argument('--output-path', help='Path to output data', type=str)
17 | parser.add_argument('--debug', default=0, help='Debug', type=int)
18 | args = parser.parse_args()
19 | return args
20 |
21 |
22 | def compare_mae(img_true, img_test):
23 | img_true = img_true.astype(np.float32)
24 | img_test = img_test.astype(np.float32)
25 | return np.sum(np.abs(img_true - img_test)) / np.sum(img_true + img_test)
26 |
27 |
28 | args = parse_args()
29 | for arg in vars(args):
30 | print('[%s] =' % arg, getattr(args, arg))
31 |
32 | path_true = args.data_path
33 | path_pred = args.output_path
34 |
35 | psnr = []
36 | ssim = []
37 | mae = []
38 | names = []
39 | index = 1
40 |
41 | files = list(glob(path_true + '/*.jpg')) + list(glob(path_true + '/*.png'))
42 | for fn in sorted(files):
43 | name = basename(str(fn))
44 | names.append(name)
45 |
46 | img_gt = (imread(str(fn)) / 255.0).astype(np.float32)
47 | img_pred = (imread(path_pred + '/' + basename(str(fn))) / 255.0).astype(np.float32)
48 |
49 | img_gt = rgb2gray(img_gt)
50 | img_pred = rgb2gray(img_pred)
51 |
52 | if args.debug != 0:
53 |         plt.subplot(121)
54 |         plt.imshow(img_gt)
55 |         plt.title('Ground truth')
56 |         plt.subplot(122)
57 |         plt.imshow(img_pred)
58 |         plt.title('Output')
59 | plt.show()
60 |
61 | psnr.append(compare_psnr(img_gt, img_pred, data_range=1))
62 | ssim.append(compare_ssim(img_gt, img_pred, data_range=1, win_size=51))
63 | mae.append(compare_mae(img_gt, img_pred))
64 | if np.mod(index, 100) == 0:
65 | print(
66 | str(index) + ' images processed',
67 | "PSNR: %.4f" % round(np.mean(psnr), 4),
68 | "SSIM: %.4f" % round(np.mean(ssim), 4),
69 | "MAE: %.4f" % round(np.mean(mae), 4),
70 | )
71 | index += 1
72 |
73 | np.savez(args.output_path + '/metrics.npz', psnr=psnr, ssim=ssim, mae=mae, names=names)
74 | print(
75 | "PSNR: %.4f" % round(np.mean(psnr), 4),
76 | "PSNR Variance: %.4f" % round(np.var(psnr), 4),
77 | "SSIM: %.4f" % round(np.mean(ssim), 4),
78 | "SSIM Variance: %.4f" % round(np.var(ssim), 4),
79 | "MAE: %.4f" % round(np.mean(mae), 4),
80 | "MAE Variance: %.4f" % round(np.var(mae), 4)
81 | )
82 |
--------------------------------------------------------------------------------
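Note that compare_mae above normalizes by the sum of both images rather than by the pixel count; a toy example of the resulting value:

    import numpy as np

    a = np.array([[0.0, 1.0], [1.0, 0.0]])   # ground truth (toy values)
    b = np.array([[0.0, 0.5], [1.0, 0.0]])   # prediction
    # sum(|a - b|) / sum(a + b) = 0.5 / 3.5 ≈ 0.1429
    print(np.sum(np.abs(a - b)) / np.sum(a + b))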
/segmentation_classes.txt:
--------------------------------------------------------------------------------
1 | 0: background
2 | 1: aeroplane
3 | 2: bicycle
4 | 3: bird
5 | 4: boat
6 | 5: bottle
7 | 6: bus
8 | 7: car
9 | 8: cat
10 | 9: chair
11 | 10: cow
12 | 11: dining table
13 | 12: dog
14 | 13: horse
15 | 14: motorbike
16 | 15: person
17 | 16: potted plant
18 | 17: sheep
19 | 18: sofa
20 | 19: train
21 | 20: tv/monitor
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [pycodestyle]
2 | ignore = E303
3 | max-line-length = 200
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # empty
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/edge_connect.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/edge_connect.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/edge_connect.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/edge_connect.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/metrics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/metrics.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/networks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/networks.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/segmentor_fcn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/segmentor_fcn.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sujaykhandekar/Automated-objects-removal-inpainter/73fa9c42967016d544fc02ffc538e0d25d4e5071/src/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 | class Config(dict):
5 | def __init__(self, config_path):
6 | with open(config_path, 'r') as f:
7 | self._yaml = f.read()
8 | self._dict = yaml.safe_load(self._yaml)
9 | self._dict['PATH'] = os.path.dirname(config_path)
10 |
11 | def __getattr__(self, name):
12 | if self._dict.get(name) is not None:
13 | return self._dict[name]
14 |
15 | if DEFAULT_CONFIG.get(name) is not None:
16 | return DEFAULT_CONFIG[name]
17 |
18 | return None
19 |
20 | def print(self):
21 | print('Model configurations:')
22 | print('---------------------------------')
23 | print(self._yaml)
24 | print('')
25 | print('---------------------------------')
26 | print('')
27 |
28 |
29 | DEFAULT_CONFIG = {
30 | 'MODE': 1, # 1: train, 2: test, 3: eval
31 | 'MODEL': 1, # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model
32 | 'MASK': 3, # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
33 | 'EDGE': 1, # 1: canny, 2: external
34 | 'NMS': 1, # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny
35 | 'SEED': 10, # random seed
36 | 'GPU': [0], # list of gpu ids
37 | 'DEBUG': 0, # turns on debugging mode
38 | 'VERBOSE': 0, # turns on verbose mode in the output console
39 |
40 | 'LR': 0.0001, # learning rate
41 | 'D2G_LR': 0.1, # discriminator/generator learning rate ratio
42 | 'BETA1': 0.0, # adam optimizer beta1
43 | 'BETA2': 0.9, # adam optimizer beta2
44 | 'BATCH_SIZE': 8, # input batch size for training
45 | 'INPUT_SIZE': 256, # input image size for training 0 for original size
46 | 'SIGMA': 2, # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)
47 | 'MAX_ITERS': 2e6, # maximum number of iterations to train the model
48 |
49 | 'EDGE_THRESHOLD': 0.5, # edge detection threshold
50 | 'L1_LOSS_WEIGHT': 1, # l1 loss weight
51 | 'FM_LOSS_WEIGHT': 10, # feature-matching loss weight
52 | 'STYLE_LOSS_WEIGHT': 1, # style loss weight
53 | 'CONTENT_LOSS_WEIGHT': 1, # perceptual loss weight
54 | 'INPAINT_ADV_LOSS_WEIGHT': 0.01,# adversarial loss weight
55 |
56 | 'GAN_LOSS': 'nsgan', # nsgan | lsgan | hinge
57 | 'GAN_POOL_SIZE': 0, # fake images pool size
58 |
59 | 'SAVE_INTERVAL': 1000, # how many iterations to wait before saving model (0: never)
60 | 'SAMPLE_INTERVAL': 1000, # how many iterations to wait before sampling (0: never)
61 | 'SAMPLE_SIZE': 12, # number of images to sample
62 | 'EVAL_INTERVAL': 0, # how many iterations to wait before model evaluation (0: never)
63 | 'LOG_INTERVAL': 10, # how many iterations to wait before logging training status (0: never)
64 | 'SEG_NETWORK': 0, # 0:DeepLabV3 resnet 101 segmentation , 1: FCN resnet 101 segmentation
65 | }
66 |
--------------------------------------------------------------------------------
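A brief sketch of how Config falls back to DEFAULT_CONFIG (the path below is a placeholder for a YAML file such as config.yml.example):

    from src.config import Config

    config = Config('checkpoints/config.yml')   # placeholder path
    config.print()             # dumps the raw YAML
    print(config.MODE)         # value from the YAML file if the key is present
    print(config.SEG_NETWORK)  # falls back to DEFAULT_CONFIG when the key is absent
    print(config.PATH)         # directory of the config file, set in __init__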
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import random
5 | import numpy as np
6 | import torchvision.transforms.functional as F
7 | from torch.utils.data import DataLoader
8 | from PIL import Image
9 | from imageio import imread
10 | from skimage.feature import canny
11 | from skimage.color import rgb2gray, gray2rgb
12 | from .utils import create_mask
13 | import cv2
14 | from .segmentor_fcn import segmentor,fill_gaps
15 |
16 |
17 | class Dataset(torch.utils.data.Dataset):
18 | def __init__(self, config, flist, edge_flist, augment=True, training=True):
19 | super(Dataset, self).__init__()
20 | self.augment = augment
21 | self.training = training
22 | self.data = self.load_flist(flist)
23 | self.edge_data = self.load_flist(edge_flist)
24 |
25 | self.input_size = config.INPUT_SIZE
26 | self.sigma = config.SIGMA
27 | self.edge = config.EDGE
28 | self.mask = config.MASK
29 | self.nms = config.NMS
30 | self.device = config.SEG_DEVICE
31 | self.objects = config.OBJECTS
32 | self.segment_net = config.SEG_NETWORK
33 |         # in test mode, there is a one-to-one relationship between mask and image;
34 |         # masks are loaded non-randomly
35 |
36 |
37 | def __len__(self):
38 | return len(self.data)
39 |
40 | def __getitem__(self, index):
41 | try:
42 | item = self.load_item(index)
43 |         except Exception:
44 | print('loading error: ' + self.data[index])
45 | item = self.load_item(0)
46 |
47 | return item
48 |
49 | def load_name(self, index):
50 | name = self.data[index]
51 | return os.path.basename(name)
52 |
53 | def load_size(self, index):
54 | img = Image.open(self.data[index])
55 | width,height=img.size
56 | return width,height
57 |
58 |
59 | def load_item(self, index):
60 |
61 | size = self.input_size
62 |
63 | # load image
64 | img = Image.open(self.data[index])
65 |
66 |
67 |
68 |
69 |
70 | # gray to rgb
71 | if img.mode !='RGB':
72 | img = gray2rgb(np.array(img))
73 | img=Image.fromarray(img)
74 |
75 | # resize/crop if needed
76 | img,mask=segmentor(self.segment_net,img,self.device,self.objects)
77 | img = Image.fromarray(img)
78 | img = np.array(img.resize((size, size), Image.ANTIALIAS))
79 |
80 |
81 | # create grayscale image
82 | img_gray = rgb2gray(np.array(img))
83 |
84 |
85 |
86 |
87 | # load mask
88 | mask = Image.fromarray(mask)
89 | mask = np.array(mask.resize((size, size), Image.ANTIALIAS))
90 | idx=(mask>0)
91 | mask[idx]=255
92 | #kernel = np.ones((5, 5), np.uint8)
93 | #opening = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
94 | #closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
95 | mask=np.apply_along_axis(fill_gaps, 1, mask) #horizontal padding
96 | mask=np.apply_along_axis(fill_gaps, 0, mask) #vertical padding
97 |
98 |
99 |
100 |
101 | # load edge
102 | edge = self.load_edge(img_gray, index, mask)
103 |
104 | # augment data
105 | if self.augment and np.random.binomial(1, 0.5) > 0:
106 | img = img[:, ::-1, ...]
107 | img_gray = img_gray[:, ::-1, ...]
108 | edge = edge[:, ::-1, ...]
109 | mask = mask[:, ::-1, ...]
110 |
111 | return self.to_tensor(img), self.to_tensor(img_gray), self.to_tensor(edge), self.to_tensor(mask)
112 |
113 | def load_edge(self, img, index, mask):
114 | sigma = self.sigma
115 |
116 |         # in test mode images are masked (with masked regions);
117 |         # the 'mask' parameter prevents canny from detecting edges inside the masked regions
118 |         mask = None if self.training else (1 - mask / 255).astype(bool)
119 |
120 | # canny
121 | if self.edge == 1:
122 | # no edge
123 | if sigma == -1:
124 |                 return np.zeros(img.shape).astype(float)
125 |
126 | # random sigma
127 | if sigma == 0:
128 | sigma = random.randint(1, 4)
129 |
130 |             return canny(img, sigma=sigma, mask=mask).astype(float)
131 |
132 | # external
133 | else:
134 | imgh, imgw = img.shape[0:2]
135 | edge = imread(self.edge_data[index])
136 | edge = self.resized(edge, imgh, imgw)
137 |
138 | # non-max suppression
139 | if self.nms == 1:
140 | edge = edge * canny(img, sigma=sigma, mask=mask)
141 |
142 | return edge
143 |
144 |
145 | def to_tensor(self, img):
146 | img = Image.fromarray(img)
147 | img_t = F.to_tensor(img).float()
148 | return img_t
149 |
150 |
151 | def load_flist(self, flist):
152 | if isinstance(flist, list):
153 | return flist
154 |
155 | # flist: image file path, image directory path, text file flist path
156 | if isinstance(flist, str):
157 | if os.path.isdir(flist):
158 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
159 | flist.sort()
160 | return flist
161 |
162 | if os.path.isfile(flist):
163 | try:
164 |                     return np.genfromtxt(flist, dtype=str, encoding='utf-8')
165 | except:
166 | return [flist]
167 |
168 | return []
169 |
170 | def create_iterator(self, batch_size):
171 | while True:
172 | sample_loader = DataLoader(
173 | dataset=self,
174 | batch_size=batch_size,
175 | drop_last=True
176 | )
177 |
178 | for item in sample_loader:
179 | yield item
180 |
--------------------------------------------------------------------------------
/src/edge_connect.py:
--------------------------------------------------------------------------------
1 | '''
2 | The EdgeConnect code in this module is adapted from this repo:
3 | https://github.com/knazeri/edge-connect
4 | '''
5 |
6 |
7 |
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch.utils.data import DataLoader
12 | from .dataset import Dataset
13 | from .models import EdgeModel, InpaintingModel
14 | from .utils import Progbar, create_dir, stitch_images, imsave
15 | from PIL import Image
16 | from torchvision import transforms
17 |
18 | class EdgeConnect():
19 | def __init__(self, config):
20 | self.config = config
21 |
22 | if config.MODEL == 1:
23 | model_name = 'edge'
24 | elif config.MODEL == 2:
25 | model_name = 'inpaint'
26 | elif config.MODEL == 3:
27 | model_name = 'edge_inpaint'
28 | elif config.MODEL == 4:
29 | model_name = 'joint'
30 |
31 | self.debug = False
32 | self.model_name = model_name
33 | self.edge_model = EdgeModel(config).to(config.DEVICE)
34 | self.inpaint_model = InpaintingModel(config).to(config.DEVICE)
35 |
36 |
37 | # test mode
38 | self.test_dataset = Dataset(config, config.TEST_FLIST, config.TEST_EDGE_FLIST, augment=False, training=False)
39 |
40 | self.samples_path = os.path.join(config.PATH, 'samples')
41 | self.results_path = os.path.join(config.PATH, 'results')
42 |
43 | if config.RESULTS is not None:
44 | self.results_path = os.path.join(config.RESULTS)
45 |
46 | if config.DEBUG is not None and config.DEBUG != 0:
47 | self.debug = True
48 |
49 | self.log_file = os.path.join(config.PATH, 'log_' + model_name + '.dat')
50 |
51 | def load(self):
52 | if self.config.MODEL == 1:
53 | self.edge_model.load()
54 |
55 | elif self.config.MODEL == 2:
56 | self.inpaint_model.load()
57 |
58 | else:
59 | self.edge_model.load()
60 | self.inpaint_model.load()
61 |
62 | def save(self):
63 | if self.config.MODEL == 1:
64 | self.edge_model.save()
65 |
66 | elif self.config.MODEL == 2 or self.config.MODEL == 3:
67 | self.inpaint_model.save()
68 |
69 | else:
70 | self.edge_model.save()
71 | self.inpaint_model.save()
72 |
73 |
74 | def test(self):
75 | self.edge_model.eval()
76 | self.inpaint_model.eval()
77 |
78 | model = self.config.MODEL
79 | create_dir(self.results_path)
80 |
81 | test_loader = DataLoader(
82 | dataset=self.test_dataset,
83 | batch_size=1,
84 | )
85 |
86 | index = 0
87 | for items in test_loader:
88 | name = self.test_dataset.load_name(index)
89 |
90 | images, images_gray, edges, masks = self.cuda(*items)
91 | index += 1
92 |
93 | # edge model
94 | if model == 1:
95 | outputs = self.edge_model(images_gray, edges, masks)
96 | outputs_merged = (outputs * masks) + (edges * (1 - masks))
97 |
98 | # inpaint model
99 | elif model == 2:
100 | outputs = self.inpaint_model(images, edges, masks)
101 | outputs_merged = (outputs * masks) + (images * (1 - masks))
102 |
103 | # inpaint with edge model / joint model
104 | else:
105 | edges = self.edge_model(images_gray, edges, masks).detach()
106 | outputs = self.inpaint_model(images, edges, masks)
107 | outputs_merged = (outputs * masks) + (images * (1 - masks))
108 |
109 | output = self.postprocess(outputs_merged)[0]
110 | path = os.path.join(self.results_path, name)
111 | print(index, name)
112 |
113 | imsave(output, path)
114 |
115 | if self.debug:
116 | edges = self.postprocess(1 - edges)[0]
117 | masked = self.postprocess(images * (1 - masks) + masks)[0]
118 | fname, fext = name.split('.')
119 |
120 | imsave(edges, os.path.join(self.results_path, fname + '_edge.' + fext))
121 | imsave(masked, os.path.join(self.results_path, fname + '_masked.' + fext))
122 |
123 | print('\nEnd test....')
124 | return output
125 |
126 |
127 | def log(self, logs):
128 | with open(self.log_file, 'a') as f:
129 | f.write('%s\n' % ' '.join([str(item[1]) for item in logs]))
130 |
131 | def cuda(self, *args):
132 | return (item.to(self.config.DEVICE) for item in args)
133 |
134 | def postprocess(self, img):
135 | # [0, 1] => [0, 255]
136 | img = img * 255.0
137 | img = img.permute(0, 2, 3, 1)
138 | return img.int()
139 |
--------------------------------------------------------------------------------
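A stand-alone illustration of the tensor-layout conversion done by postprocess above (random values stand in for network output):

    import torch

    out = torch.rand(1, 3, 256, 256)          # NCHW float output in [0, 1]
    img = (out * 255.0).permute(0, 2, 3, 1)   # -> NHWC
    img = img.int()                           # integer pixel values for imsave
    print(img.shape)                          # torch.Size([1, 256, 256, 3])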
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 |
5 |
6 | class AdversarialLoss(nn.Module):
7 | r"""
8 | Adversarial loss
9 | https://arxiv.org/abs/1711.10337
10 | """
11 |
12 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
13 | r"""
14 | type = nsgan | lsgan | hinge
15 | """
16 | super(AdversarialLoss, self).__init__()
17 |
18 | self.type = type
19 | self.register_buffer('real_label', torch.tensor(target_real_label))
20 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
21 |
22 | if type == 'nsgan':
23 | self.criterion = nn.BCELoss()
24 |
25 | elif type == 'lsgan':
26 | self.criterion = nn.MSELoss()
27 |
28 | elif type == 'hinge':
29 | self.criterion = nn.ReLU()
30 |
31 | def __call__(self, outputs, is_real, is_disc=None):
32 | if self.type == 'hinge':
33 | if is_disc:
34 | if is_real:
35 | outputs = -outputs
36 | return self.criterion(1 + outputs).mean()
37 | else:
38 | return (-outputs).mean()
39 |
40 | else:
41 | labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
42 | loss = self.criterion(outputs, labels)
43 | return loss
44 |
45 |
46 | class StyleLoss(nn.Module):
47 | r"""
48 |     Style loss, VGG-based (computed from Gram matrices of VGG feature maps)
49 | https://arxiv.org/abs/1603.08155
50 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
51 | """
52 |
53 | def __init__(self):
54 | super(StyleLoss, self).__init__()
55 | self.add_module('vgg', VGG19())
56 | self.criterion = torch.nn.L1Loss()
57 |
58 | def compute_gram(self, x):
59 | b, ch, h, w = x.size()
60 | f = x.view(b, ch, w * h)
61 | f_T = f.transpose(1, 2)
62 | G = f.bmm(f_T) / (h * w * ch)
63 |
64 | return G
65 |
66 | def __call__(self, x, y):
67 | # Compute features
68 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
69 |
70 | # Compute loss
71 | style_loss = 0.0
72 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
73 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
74 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
75 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
76 |
77 | return style_loss
78 |
79 |
80 |
81 | class PerceptualLoss(nn.Module):
82 | r"""
83 | Perceptual loss, VGG-based
84 | https://arxiv.org/abs/1603.08155
85 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
86 | """
87 |
88 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
89 | super(PerceptualLoss, self).__init__()
90 | self.add_module('vgg', VGG19())
91 | self.criterion = torch.nn.L1Loss()
92 | self.weights = weights
93 |
94 | def __call__(self, x, y):
95 | # Compute features
96 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
97 |
98 | content_loss = 0.0
99 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
100 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
101 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
102 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
103 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
104 |
105 |
106 | return content_loss
107 |
108 |
109 |
110 | class VGG19(torch.nn.Module):
111 | def __init__(self):
112 | super(VGG19, self).__init__()
113 | features = models.vgg19(pretrained=True).features
114 | self.relu1_1 = torch.nn.Sequential()
115 | self.relu1_2 = torch.nn.Sequential()
116 |
117 | self.relu2_1 = torch.nn.Sequential()
118 | self.relu2_2 = torch.nn.Sequential()
119 |
120 | self.relu3_1 = torch.nn.Sequential()
121 | self.relu3_2 = torch.nn.Sequential()
122 | self.relu3_3 = torch.nn.Sequential()
123 | self.relu3_4 = torch.nn.Sequential()
124 |
125 | self.relu4_1 = torch.nn.Sequential()
126 | self.relu4_2 = torch.nn.Sequential()
127 | self.relu4_3 = torch.nn.Sequential()
128 | self.relu4_4 = torch.nn.Sequential()
129 |
130 | self.relu5_1 = torch.nn.Sequential()
131 | self.relu5_2 = torch.nn.Sequential()
132 | self.relu5_3 = torch.nn.Sequential()
133 | self.relu5_4 = torch.nn.Sequential()
134 |
135 | for x in range(2):
136 | self.relu1_1.add_module(str(x), features[x])
137 |
138 | for x in range(2, 4):
139 | self.relu1_2.add_module(str(x), features[x])
140 |
141 | for x in range(4, 7):
142 | self.relu2_1.add_module(str(x), features[x])
143 |
144 | for x in range(7, 9):
145 | self.relu2_2.add_module(str(x), features[x])
146 |
147 | for x in range(9, 12):
148 | self.relu3_1.add_module(str(x), features[x])
149 |
150 | for x in range(12, 14):
151 | self.relu3_2.add_module(str(x), features[x])
152 |
153 | for x in range(14, 16):
154 | self.relu3_3.add_module(str(x), features[x])
155 |
156 | for x in range(16, 18):
157 | self.relu3_4.add_module(str(x), features[x])
158 |
159 | for x in range(18, 21):
160 | self.relu4_1.add_module(str(x), features[x])
161 |
162 | for x in range(21, 23):
163 | self.relu4_2.add_module(str(x), features[x])
164 |
165 | for x in range(23, 25):
166 | self.relu4_3.add_module(str(x), features[x])
167 |
168 | for x in range(25, 27):
169 | self.relu4_4.add_module(str(x), features[x])
170 |
171 | for x in range(27, 30):
172 | self.relu5_1.add_module(str(x), features[x])
173 |
174 | for x in range(30, 32):
175 | self.relu5_2.add_module(str(x), features[x])
176 |
177 | for x in range(32, 34):
178 | self.relu5_3.add_module(str(x), features[x])
179 |
180 | for x in range(34, 36):
181 | self.relu5_4.add_module(str(x), features[x])
182 |
183 | # don't need the gradients, just want the features
184 | for param in self.parameters():
185 | param.requires_grad = False
186 |
187 | def forward(self, x):
188 | relu1_1 = self.relu1_1(x)
189 | relu1_2 = self.relu1_2(relu1_1)
190 |
191 | relu2_1 = self.relu2_1(relu1_2)
192 | relu2_2 = self.relu2_2(relu2_1)
193 |
194 | relu3_1 = self.relu3_1(relu2_2)
195 | relu3_2 = self.relu3_2(relu3_1)
196 | relu3_3 = self.relu3_3(relu3_2)
197 | relu3_4 = self.relu3_4(relu3_3)
198 |
199 | relu4_1 = self.relu4_1(relu3_4)
200 | relu4_2 = self.relu4_2(relu4_1)
201 | relu4_3 = self.relu4_3(relu4_2)
202 | relu4_4 = self.relu4_4(relu4_3)
203 |
204 | relu5_1 = self.relu5_1(relu4_4)
205 | relu5_2 = self.relu5_2(relu5_1)
206 | relu5_3 = self.relu5_3(relu5_2)
207 | relu5_4 = self.relu5_4(relu5_3)
208 |
209 | out = {
210 | 'relu1_1': relu1_1,
211 | 'relu1_2': relu1_2,
212 |
213 | 'relu2_1': relu2_1,
214 | 'relu2_2': relu2_2,
215 |
216 | 'relu3_1': relu3_1,
217 | 'relu3_2': relu3_2,
218 | 'relu3_3': relu3_3,
219 | 'relu3_4': relu3_4,
220 |
221 | 'relu4_1': relu4_1,
222 | 'relu4_2': relu4_2,
223 | 'relu4_3': relu4_3,
224 | 'relu4_4': relu4_4,
225 |
226 | 'relu5_1': relu5_1,
227 | 'relu5_2': relu5_2,
228 | 'relu5_3': relu5_3,
229 | 'relu5_4': relu5_4,
230 | }
231 | return out
232 |
--------------------------------------------------------------------------------
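A stand-alone sketch of the Gram-matrix computation that StyleLoss.compute_gram applies to the VGG feature maps (random tensors stand in for VGG features, so no pretrained weights are needed):

    import torch

    x = torch.rand(2, 8, 16, 16)                  # (batch, channels, h, w)
    b, ch, h, w = x.size()
    f = x.view(b, ch, w * h)
    gram = f.bmm(f.transpose(1, 2)) / (h * w * ch)
    print(gram.shape)                             # torch.Size([2, 8, 8])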
/src/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from .networks import InpaintGenerator, EdgeGenerator, Discriminator
6 | from .loss import AdversarialLoss, PerceptualLoss, StyleLoss
7 |
8 |
9 | class BaseModel(nn.Module):
10 | def __init__(self, name, config):
11 | super(BaseModel, self).__init__()
12 |
13 | self.name = name
14 | self.config = config
15 | self.iteration = 0
16 |
17 | self.gen_weights_path = os.path.join(config.PATH, name + '_gen.pth')
18 | self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth')
19 |
20 | def load(self):
21 | if os.path.exists(self.gen_weights_path):
22 | print('Loading %s generator...' % self.name)
23 |
24 | if torch.cuda.is_available():
25 | data = torch.load(self.gen_weights_path)
26 | else:
27 | data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage)
28 |
29 | self.generator.load_state_dict(data['generator'])
30 | self.iteration = data['iteration']
31 |
32 | # load discriminator only when training
33 | if self.config.MODE == 1 and os.path.exists(self.dis_weights_path):
34 | print('Loading %s discriminator...' % self.name)
35 |
36 | if torch.cuda.is_available():
37 | data = torch.load(self.dis_weights_path)
38 | else:
39 | data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage)
40 |
41 | self.discriminator.load_state_dict(data['discriminator'])
42 |
43 | def save(self):
44 | print('\nsaving %s...\n' % self.name)
45 | torch.save({
46 | 'iteration': self.iteration,
47 | 'generator': self.generator.state_dict()
48 | }, self.gen_weights_path)
49 |
50 | torch.save({
51 | 'discriminator': self.discriminator.state_dict()
52 | }, self.dis_weights_path)
53 |
54 |
55 | class EdgeModel(BaseModel):
56 | def __init__(self, config):
57 | super(EdgeModel, self).__init__('EdgeModel', config)
58 |
59 | # generator input: [grayscale(1) + edge(1) + mask(1)]
60 | # discriminator input: (grayscale(1) + edge(1))
61 | generator = EdgeGenerator(use_spectral_norm=True)
62 | discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge')
63 | if len(config.GPU) > 1:
64 | generator = nn.DataParallel(generator, config.GPU)
65 | discriminator = nn.DataParallel(discriminator, config.GPU)
66 | l1_loss = nn.L1Loss()
67 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)
68 |
69 | self.add_module('generator', generator)
70 | self.add_module('discriminator', discriminator)
71 |
72 | self.add_module('l1_loss', l1_loss)
73 | self.add_module('adversarial_loss', adversarial_loss)
74 |
75 | self.gen_optimizer = optim.Adam(
76 | params=generator.parameters(),
77 | lr=float(config.LR),
78 | betas=(config.BETA1, config.BETA2)
79 | )
80 |
81 | self.dis_optimizer = optim.Adam(
82 | params=discriminator.parameters(),
83 | lr=float(config.LR) * float(config.D2G_LR),
84 | betas=(config.BETA1, config.BETA2)
85 | )
86 |
87 | def process(self, images, edges, masks):
88 | self.iteration += 1
89 |
90 |
91 | # zero optimizers
92 | self.gen_optimizer.zero_grad()
93 | self.dis_optimizer.zero_grad()
94 |
95 |
96 | # process outputs
97 | outputs = self(images, edges, masks)
98 | gen_loss = 0
99 | dis_loss = 0
100 |
101 |
102 | # discriminator loss
103 | dis_input_real = torch.cat((images, edges), dim=1)
104 | dis_input_fake = torch.cat((images, outputs.detach()), dim=1)
105 | dis_real, dis_real_feat = self.discriminator(dis_input_real) # in: (grayscale(1) + edge(1))
106 | dis_fake, dis_fake_feat = self.discriminator(dis_input_fake) # in: (grayscale(1) + edge(1))
107 | dis_real_loss = self.adversarial_loss(dis_real, True, True)
108 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
109 | dis_loss += (dis_real_loss + dis_fake_loss) / 2
110 |
111 |
112 | # generator adversarial loss
113 | gen_input_fake = torch.cat((images, outputs), dim=1)
114 | gen_fake, gen_fake_feat = self.discriminator(gen_input_fake) # in: (grayscale(1) + edge(1))
115 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
116 | gen_loss += gen_gan_loss
117 |
118 |
119 | # generator feature matching loss
120 | gen_fm_loss = 0
121 | for i in range(len(dis_real_feat)):
122 | gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach())
123 | gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT
124 | gen_loss += gen_fm_loss
125 |
126 |
127 | # create logs
128 | logs = [
129 | ("l_d1", dis_loss.item()),
130 | ("l_g1", gen_gan_loss.item()),
131 | ("l_fm", gen_fm_loss.item()),
132 | ]
133 |
134 | return outputs, gen_loss, dis_loss, logs
135 |
136 | def forward(self, images, edges, masks):
137 | edges_masked = (edges * (1 - masks))
138 | images_masked = (images * (1 - masks)) + masks
139 | inputs = torch.cat((images_masked, edges_masked, masks), dim=1)
140 | outputs = self.generator(inputs) # in: [grayscale(1) + edge(1) + mask(1)]
141 | return outputs
142 |
143 | def backward(self, gen_loss=None, dis_loss=None):
144 | if dis_loss is not None:
145 | dis_loss.backward()
146 | self.dis_optimizer.step()
147 |
148 | if gen_loss is not None:
149 | gen_loss.backward()
150 | self.gen_optimizer.step()
151 |
152 |
153 | class InpaintingModel(BaseModel):
154 | def __init__(self, config):
155 | super(InpaintingModel, self).__init__('InpaintingModel', config)
156 |
157 | # generator input: [rgb(3) + edge(1)]
158 | # discriminator input: [rgb(3)]
159 | generator = InpaintGenerator()
160 | discriminator = Discriminator(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge')
161 | if len(config.GPU) > 1:
162 | generator = nn.DataParallel(generator, config.GPU)
163 | discriminator = nn.DataParallel(discriminator , config.GPU)
164 |
165 | l1_loss = nn.L1Loss()
166 | perceptual_loss = PerceptualLoss()
167 | style_loss = StyleLoss()
168 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)
169 |
170 | self.add_module('generator', generator)
171 | self.add_module('discriminator', discriminator)
172 |
173 | self.add_module('l1_loss', l1_loss)
174 | self.add_module('perceptual_loss', perceptual_loss)
175 | self.add_module('style_loss', style_loss)
176 | self.add_module('adversarial_loss', adversarial_loss)
177 |
178 | self.gen_optimizer = optim.Adam(
179 | params=generator.parameters(),
180 | lr=float(config.LR),
181 | betas=(config.BETA1, config.BETA2)
182 | )
183 |
184 | self.dis_optimizer = optim.Adam(
185 | params=discriminator.parameters(),
186 | lr=float(config.LR) * float(config.D2G_LR),
187 | betas=(config.BETA1, config.BETA2)
188 | )
189 |
190 | def process(self, images, edges, masks):
191 | self.iteration += 1
192 |
193 | # zero optimizers
194 | self.gen_optimizer.zero_grad()
195 | self.dis_optimizer.zero_grad()
196 |
197 |
198 | # process outputs
199 | outputs = self(images, edges, masks)
200 | gen_loss = 0
201 | dis_loss = 0
202 |
203 |
204 | # discriminator loss
205 | dis_input_real = images
206 | dis_input_fake = outputs.detach()
207 | dis_real, _ = self.discriminator(dis_input_real) # in: [rgb(3)]
208 | dis_fake, _ = self.discriminator(dis_input_fake) # in: [rgb(3)]
209 | dis_real_loss = self.adversarial_loss(dis_real, True, True)
210 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
211 | dis_loss += (dis_real_loss + dis_fake_loss) / 2
212 |
213 |
214 | # generator adversarial loss
215 | gen_input_fake = outputs
216 | gen_fake, _ = self.discriminator(gen_input_fake) # in: [rgb(3)]
217 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT
218 | gen_loss += gen_gan_loss
219 |
220 |
221 | # generator l1 loss
222 | gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks)
223 | gen_loss += gen_l1_loss
224 |
225 |
226 | # generator perceptual loss
227 | gen_content_loss = self.perceptual_loss(outputs, images)
228 | gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
229 | gen_loss += gen_content_loss
230 |
231 |
232 | # generator style loss
233 | gen_style_loss = self.style_loss(outputs * masks, images * masks)
234 | gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
235 | gen_loss += gen_style_loss
236 |
237 |
238 | # create logs
239 | logs = [
240 | ("l_d2", dis_loss.item()),
241 | ("l_g2", gen_gan_loss.item()),
242 | ("l_l1", gen_l1_loss.item()),
243 | ("l_per", gen_content_loss.item()),
244 | ("l_sty", gen_style_loss.item()),
245 | ]
246 |
247 | return outputs, gen_loss, dis_loss, logs
248 |
249 | def forward(self, images, edges, masks):
250 | images_masked = (images * (1 - masks).float()) + masks
251 | inputs = torch.cat((images_masked, edges), dim=1)
252 | outputs = self.generator(inputs) # in: [rgb(3) + edge(1)]
253 | return outputs
254 |
255 | def backward(self, gen_loss=None, dis_loss=None):
256 | dis_loss.backward()
257 | self.dis_optimizer.step()
258 |
259 | gen_loss.backward()
260 | self.gen_optimizer.step()
261 |
--------------------------------------------------------------------------------
/src/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BaseNetwork(nn.Module):
6 | def __init__(self):
7 | super(BaseNetwork, self).__init__()
8 |
9 | def init_weights(self, init_type='normal', gain=0.02):
10 | '''
11 | initialize network's weights
12 | init_type: normal | xavier | kaiming | orthogonal
13 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
14 | '''
15 |
16 | def init_func(m):
17 | classname = m.__class__.__name__
18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
19 | if init_type == 'normal':
20 | nn.init.normal_(m.weight.data, 0.0, gain)
21 | elif init_type == 'xavier':
22 | nn.init.xavier_normal_(m.weight.data, gain=gain)
23 | elif init_type == 'kaiming':
24 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
25 | elif init_type == 'orthogonal':
26 | nn.init.orthogonal_(m.weight.data, gain=gain)
27 |
28 | if hasattr(m, 'bias') and m.bias is not None:
29 | nn.init.constant_(m.bias.data, 0.0)
30 |
31 | elif classname.find('BatchNorm2d') != -1:
32 | nn.init.normal_(m.weight.data, 1.0, gain)
33 | nn.init.constant_(m.bias.data, 0.0)
34 |
35 | self.apply(init_func)
36 |
37 |
38 | class InpaintGenerator(BaseNetwork):
39 | def __init__(self, residual_blocks=8, init_weights=True):
40 | super(InpaintGenerator, self).__init__()
41 |
42 | self.encoder = nn.Sequential(
43 | nn.ReflectionPad2d(3),
44 | nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7, padding=0),
45 | nn.InstanceNorm2d(64, track_running_stats=False),
46 | nn.ReLU(True),
47 |
48 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
49 | nn.InstanceNorm2d(128, track_running_stats=False),
50 | nn.ReLU(True),
51 |
52 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
53 | nn.InstanceNorm2d(256, track_running_stats=False),
54 | nn.ReLU(True)
55 | )
56 |
57 | blocks = []
58 | for _ in range(residual_blocks):
59 | block = ResnetBlock(256, 2)
60 | blocks.append(block)
61 |
62 | self.middle = nn.Sequential(*blocks)
63 |
64 | self.decoder = nn.Sequential(
65 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
66 | nn.InstanceNorm2d(128, track_running_stats=False),
67 | nn.ReLU(True),
68 |
69 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
70 | nn.InstanceNorm2d(64, track_running_stats=False),
71 | nn.ReLU(True),
72 |
73 | nn.ReflectionPad2d(3),
74 | nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, padding=0),
75 | )
76 |
77 | if init_weights:
78 | self.init_weights()
79 |
80 | def forward(self, x):
81 | x = self.encoder(x)
82 | x = self.middle(x)
83 | x = self.decoder(x)
84 | x = (torch.tanh(x) + 1) / 2
85 |
86 | return x
87 |
88 |
89 | class EdgeGenerator(BaseNetwork):
90 | def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True):
91 | super(EdgeGenerator, self).__init__()
92 |
93 | self.encoder = nn.Sequential(
94 | nn.ReflectionPad2d(3),
95 | spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
96 | nn.InstanceNorm2d(64, track_running_stats=False),
97 | nn.ReLU(True),
98 |
99 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
100 | nn.InstanceNorm2d(128, track_running_stats=False),
101 | nn.ReLU(True),
102 |
103 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
104 | nn.InstanceNorm2d(256, track_running_stats=False),
105 | nn.ReLU(True)
106 | )
107 |
108 | blocks = []
109 | for _ in range(residual_blocks):
110 | block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm)
111 | blocks.append(block)
112 |
113 | self.middle = nn.Sequential(*blocks)
114 |
115 | self.decoder = nn.Sequential(
116 | spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
117 | nn.InstanceNorm2d(128, track_running_stats=False),
118 | nn.ReLU(True),
119 |
120 | spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm),
121 | nn.InstanceNorm2d(64, track_running_stats=False),
122 | nn.ReLU(True),
123 |
124 | nn.ReflectionPad2d(3),
125 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0),
126 | )
127 |
128 | if init_weights:
129 | self.init_weights()
130 |
131 | def forward(self, x):
132 | x = self.encoder(x)
133 | x = self.middle(x)
134 | x = self.decoder(x)
135 | x = torch.sigmoid(x)
136 | return x
137 |
138 |
139 | class Discriminator(BaseNetwork):
140 | def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True):
141 | super(Discriminator, self).__init__()
142 | self.use_sigmoid = use_sigmoid
143 |
144 | self.conv1 = self.features = nn.Sequential(
145 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
146 | nn.LeakyReLU(0.2, inplace=True),
147 | )
148 |
149 | self.conv2 = nn.Sequential(
150 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
151 | nn.LeakyReLU(0.2, inplace=True),
152 | )
153 |
154 | self.conv3 = nn.Sequential(
155 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
156 | nn.LeakyReLU(0.2, inplace=True),
157 | )
158 |
159 | self.conv4 = nn.Sequential(
160 | spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
161 | nn.LeakyReLU(0.2, inplace=True),
162 | )
163 |
164 | self.conv5 = nn.Sequential(
165 | spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
166 | )
167 |
168 | if init_weights:
169 | self.init_weights()
170 |
171 | def forward(self, x):
172 | conv1 = self.conv1(x)
173 | conv2 = self.conv2(conv1)
174 | conv3 = self.conv3(conv2)
175 | conv4 = self.conv4(conv3)
176 | conv5 = self.conv5(conv4)
177 |
178 | outputs = conv5
179 | if self.use_sigmoid:
180 | outputs = torch.sigmoid(conv5)
181 |
182 | return outputs, [conv1, conv2, conv3, conv4, conv5]
183 |
184 |
185 | class ResnetBlock(nn.Module):
186 | def __init__(self, dim, dilation=1, use_spectral_norm=False):
187 | super(ResnetBlock, self).__init__()
188 | self.conv_block = nn.Sequential(
189 | nn.ReflectionPad2d(dilation),
190 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
191 | nn.InstanceNorm2d(dim, track_running_stats=False),
192 | nn.ReLU(True),
193 |
194 | nn.ReflectionPad2d(1),
195 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
196 | nn.InstanceNorm2d(dim, track_running_stats=False),
197 | )
198 |
199 | def forward(self, x):
200 | out = x + self.conv_block(x)
201 |
202 | # Remove ReLU at the end of the residual block
203 | # http://torch.ch/blog/2016/02/04/resnets.html
204 |
205 | return out
206 |
207 |
208 | def spectral_norm(module, mode=True):
209 | if mode:
210 | return nn.utils.spectral_norm(module)
211 |
212 | return module
213 |
--------------------------------------------------------------------------------
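A quick shape check for the generators and discriminator above using random tensors (no checkpoints required):

    import torch
    from src.networks import InpaintGenerator, EdgeGenerator, Discriminator

    inpaint_gen = InpaintGenerator()
    x = torch.rand(1, 4, 256, 256)       # masked RGB (3) + edge map (1)
    print(inpaint_gen(x).shape)          # torch.Size([1, 3, 256, 256])

    edge_gen = EdgeGenerator()
    y = torch.rand(1, 3, 256, 256)       # grayscale (1) + edge (1) + mask (1)
    print(edge_gen(y).shape)             # torch.Size([1, 1, 256, 256])

    disc = Discriminator(in_channels=3)
    scores, feats = disc(torch.rand(1, 3, 256, 256))
    print(scores.shape, len(feats))      # patch scores and 5 intermediate feature maps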
/src/segmentor_fcn.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | from PIL import Image
3 | import torchvision.transforms as T
4 | import matplotlib.pyplot as plt
5 | import torch
6 | import numpy as np
7 | from imageio import imread
8 | from skimage.color import rgb2gray, gray2rgb
9 | import cv2
10 |
11 | fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()
12 | dlab = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
13 |
14 | def decode_segmap(image,objects,nc=21):
15 |
16 | r = np.zeros_like(image).astype(np.uint8)
17 | for l in objects:
18 | idx = image == l
19 |         r[idx] = 255  # fill r with 255 wherever the pixel belongs to one of the requested object classes
20 | return np.array(r)
21 |
22 |
23 | def fill_gaps(values):
24 |     searchval = [255, 0, 255]      # 1-pixel hole in a 1-D slice of the mask
25 |     searchval2 = [255, 0, 0, 255]  # 2-pixel hole
26 |     idx = np.array(np.where((values[:-2] == searchval[0]) & (values[1:-1] == searchval[1]) & (values[2:] == searchval[2]))) + 1   # position of the single 0
27 |     idx2 = np.array(np.where((values[:-3] == searchval2[0]) & (values[1:-2] == searchval2[1]) & (values[2:-1] == searchval2[2]) & (values[3:] == searchval2[3]))) + 1  # first of the two 0s
28 |     idx3 = idx2 + 1  # second of the two 0s
29 |     new = idx.tolist() + idx2.tolist() + idx3.tolist()
30 |     newlist = [item for items in new for item in items]
31 |     values[newlist] = 255  # close the holes
32 |     return values
33 |
34 | def fill_gaps2(values):
35 |     searchval = [0, 255]   # 0 -> 255 transition at the edge of the mask
36 |     searchval2 = [255, 0]  # note: defined but not used below
37 |     idx = np.array(np.where((values[:-1] == searchval[0]) & (values[1:] == searchval[1])))       # the 0 just before each edge
38 |     idx2 = np.array(np.where((values[:-1] == searchval[0]) & (values[1:] == searchval[1]))) + 1  # the 255 at each edge
39 |
40 |     new = idx.tolist() + idx2.tolist()
41 |     newlist = [item for items in new for item in items]
42 |     values[newlist] = 255  # grow the mask by one pixel at each transition
43 |     return values
44 |
45 |
46 | def remove_patch_og(real_img, mask):
47 |     og_data = real_img.copy()
48 |     idx = mask == 255  # cut the masked region out of the real image
49 |     og_data[idx] = 255
50 |     return og_data
51 |
52 |
53 |
54 |
55 | def segmentor(seg_net,img,dev,objects):
56 | #plt.imshow(img); plt.show()
57 |     if seg_net == 1:
58 |         net = fcn
59 |     else:
60 |         net = dlab
61 | if dev == 'cuda':
62 | trf = T.Compose([T.Resize(400),
63 | #T.CenterCrop(224),
64 | T.ToTensor(),
65 | T.Normalize(mean = [0.485, 0.456, 0.406],
66 | std = [0.229, 0.224, 0.225])])
67 | else:
68 | trf = T.Compose([T.Resize(680),
69 | #T.CenterCrop(224),
70 | T.ToTensor(),
71 | T.Normalize(mean = [0.485, 0.456, 0.406],
72 | std = [0.229, 0.224, 0.225])])
73 | inp = trf(img).unsqueeze(0).to(dev)
74 | out = net.to(dev)(inp)['out']
75 | om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
76 |     mask = decode_segmap(om, objects)
77 |     height, width = mask.shape
78 |     img = np.array(img.resize((width, height), Image.LANCZOS))  # ANTIALIAS was removed in newer Pillow; LANCZOS is the same filter
79 |
80 |
81 |     og_img = remove_patch_og(img, mask)
82 |     #plt.imshow(mask); plt.show()
83 |     return og_img, mask
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
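A hedged usage sketch (not part of the repository) showing how segmentor() above might be driven from a small script. The image path is one of the files under examples/my_small_data, and class id 15 is the Pascal VOC "person" class (see segmentation_classes.txt for the full mapping). Note that importing src.segmentor_fcn builds both pretrained models at module level, so the first run downloads their weights.

import torch
import cv2
from PIL import Image

from src.segmentor_fcn import segmentor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
img = Image.open('examples/my_small_data/0090.jpg').convert('RGB')

# seg_net=1 selects the FCN model, any other value the DeepLabV3 model.
og_img, mask = segmentor(seg_net=2, img=img, dev=device, objects=[15])

cv2.imwrite('mask.png', mask)                                   # 255 where the selected object was found
cv2.imwrite('masked.png', cv2.cvtColor(og_img, cv2.COLOR_RGB2BGR))  # image with the object cut out

--------------------------------------------------------------------------------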
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 |
9 |
10 | def create_dir(dir):
11 | if not os.path.exists(dir):
12 | os.makedirs(dir)
13 |
14 |
15 | def create_mask(width, height, mask_width, mask_height, x=None, y=None):
16 | mask = np.zeros((height, width))
17 | mask_x = x if x is not None else random.randint(0, width - mask_width)
18 | mask_y = y if y is not None else random.randint(0, height - mask_height)
19 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
20 | return mask
21 |
22 |
23 | def stitch_images(inputs, *outputs, img_per_row=2):
24 | gap = 5
25 | columns = len(outputs) + 1
26 |
27 | width, height = inputs[0][:, :, 0].shape
28 | img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row)))
29 | images = [inputs, *outputs]
30 |
31 | for ix in range(len(inputs)):
32 | xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap
33 | yoffset = int(ix / img_per_row) * height
34 |
35 | for cat in range(len(images)):
36 | im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze()
37 | im = Image.fromarray(im)
38 | img.paste(im, (xoffset + cat * width, yoffset))
39 |
40 | return img
41 |
42 |
43 | def imshow(img, title=''):
44 | fig = plt.gcf()
45 |     fig.canvas.manager.set_window_title(title)  # canvas.set_window_title was removed in newer matplotlib
46 | plt.axis('off')
47 | plt.imshow(img, interpolation='none')
48 | plt.show()
49 |
50 |
51 | def imsave(img, path):
52 | im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze())
53 |
54 | im.save(path)
55 |
56 |
57 | class Progbar(object):
58 | """Displays a progress bar.
59 |
60 | Arguments:
61 | target: Total number of steps expected, None if unknown.
62 | width: Progress bar width on screen.
63 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
64 | stateful_metrics: Iterable of string names of metrics that
65 | should *not* be averaged over time. Metrics in this list
66 | will be displayed as-is. All others will be averaged
67 | by the progbar before display.
68 | interval: Minimum visual progress update interval (in seconds).
69 | """
70 |
71 | def __init__(self, target, width=25, verbose=1, interval=0.05,
72 | stateful_metrics=None):
73 | self.target = target
74 | self.width = width
75 | self.verbose = verbose
76 | self.interval = interval
77 | if stateful_metrics:
78 | self.stateful_metrics = set(stateful_metrics)
79 | else:
80 | self.stateful_metrics = set()
81 |
82 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
83 | sys.stdout.isatty()) or
84 | 'ipykernel' in sys.modules or
85 | 'posix' in sys.modules)
86 | self._total_width = 0
87 | self._seen_so_far = 0
88 | # We use a dict + list to avoid garbage collection
89 | # issues found in OrderedDict
90 | self._values = {}
91 | self._values_order = []
92 | self._start = time.time()
93 | self._last_update = 0
94 |
95 | def update(self, current, values=None):
96 | """Updates the progress bar.
97 |
98 | Arguments:
99 | current: Index of current step.
100 | values: List of tuples:
101 | `(name, value_for_last_step)`.
102 | If `name` is in `stateful_metrics`,
103 | `value_for_last_step` will be displayed as-is.
104 | Else, an average of the metric over time will be displayed.
105 | """
106 | values = values or []
107 | for k, v in values:
108 | if k not in self._values_order:
109 | self._values_order.append(k)
110 | if k not in self.stateful_metrics:
111 | if k not in self._values:
112 | self._values[k] = [v * (current - self._seen_so_far),
113 | current - self._seen_so_far]
114 | else:
115 | self._values[k][0] += v * (current - self._seen_so_far)
116 | self._values[k][1] += (current - self._seen_so_far)
117 | else:
118 | self._values[k] = v
119 | self._seen_so_far = current
120 |
121 | now = time.time()
122 | info = ' - %.0fs' % (now - self._start)
123 | if self.verbose == 1:
124 | if (now - self._last_update < self.interval and
125 | self.target is not None and current < self.target):
126 | return
127 |
128 | prev_total_width = self._total_width
129 | if self._dynamic_display:
130 | sys.stdout.write('\b' * prev_total_width)
131 | sys.stdout.write('\r')
132 | else:
133 | sys.stdout.write('\n')
134 |
135 | if self.target is not None:
136 | numdigits = int(np.floor(np.log10(self.target))) + 1
137 | barstr = '%%%dd/%d [' % (numdigits, self.target)
138 | bar = barstr % current
139 | prog = float(current) / self.target
140 | prog_width = int(self.width * prog)
141 | if prog_width > 0:
142 | bar += ('=' * (prog_width - 1))
143 | if current < self.target:
144 | bar += '>'
145 | else:
146 | bar += '='
147 | bar += ('.' * (self.width - prog_width))
148 | bar += ']'
149 | else:
150 | bar = '%7d/Unknown' % current
151 |
152 | self._total_width = len(bar)
153 | sys.stdout.write(bar)
154 |
155 | if current:
156 | time_per_unit = (now - self._start) / current
157 | else:
158 | time_per_unit = 0
159 | if self.target is not None and current < self.target:
160 | eta = time_per_unit * (self.target - current)
161 | if eta > 3600:
162 | eta_format = '%d:%02d:%02d' % (eta // 3600,
163 | (eta % 3600) // 60,
164 | eta % 60)
165 | elif eta > 60:
166 | eta_format = '%d:%02d' % (eta // 60, eta % 60)
167 | else:
168 | eta_format = '%ds' % eta
169 |
170 | info = ' - ETA: %s' % eta_format
171 | else:
172 | if time_per_unit >= 1:
173 | info += ' %.0fs/step' % time_per_unit
174 | elif time_per_unit >= 1e-3:
175 | info += ' %.0fms/step' % (time_per_unit * 1e3)
176 | else:
177 | info += ' %.0fus/step' % (time_per_unit * 1e6)
178 |
179 | for k in self._values_order:
180 | info += ' - %s:' % k
181 | if isinstance(self._values[k], list):
182 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
183 | if abs(avg) > 1e-3:
184 | info += ' %.4f' % avg
185 | else:
186 | info += ' %.4e' % avg
187 | else:
188 | info += ' %s' % self._values[k]
189 |
190 | self._total_width += len(info)
191 | if prev_total_width > self._total_width:
192 | info += (' ' * (prev_total_width - self._total_width))
193 |
194 | if self.target is not None and current >= self.target:
195 | info += '\n'
196 |
197 | sys.stdout.write(info)
198 | sys.stdout.flush()
199 |
200 | elif self.verbose == 2:
201 | if self.target is None or current >= self.target:
202 | for k in self._values_order:
203 | info += ' - %s:' % k
204 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
205 | if avg > 1e-3:
206 | info += ' %.4f' % avg
207 | else:
208 | info += ' %.4e' % avg
209 | info += '\n'
210 |
211 | sys.stdout.write(info)
212 | sys.stdout.flush()
213 |
214 | self._last_update = now
215 |
216 | def add(self, n, values=None):
217 | self.update(self._seen_so_far + n, values)
218 |
--------------------------------------------------------------------------------
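A small illustrative sketch (not part of the repository) of the helpers above: create_mask draws a rectangular hole at a random position, and Progbar is a Keras-style progress bar for reporting per-step metrics. It assumes the repository root is on PYTHONPATH.

import time
from src.utils import Progbar, create_mask

# A 64x64 hole placed at a random position inside a 256x256 canvas (1 = hole, 0 = keep).
mask = create_mask(256, 256, 64, 64)
print(mask.shape, mask.sum())   # (256, 256) 4096.0

# 'epoch' is stateful (shown as-is); 'loss' is averaged over the steps seen so far.
steps = 10
bar = Progbar(steps, stateful_metrics=['epoch'])
for step in range(1, steps + 1):
    time.sleep(0.1)   # stand-in for one training iteration
    bar.update(step, values=[('epoch', 1), ('loss', 1.0 / step)])

--------------------------------------------------------------------------------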
/test.py:
--------------------------------------------------------------------------------
1 | from main import main
2 | main(mode=2)
--------------------------------------------------------------------------------