├── .circleci └── config.yml ├── .github └── workflows │ ├── ci.yml │ └── main.yml ├── .travis.yml ├── Aptfile ├── Dockerfile ├── LICENSE ├── Procfile ├── README.md ├── app.py ├── azure-pipelines.yml ├── circleci └── config.yml ├── image_save └── 下载.png ├── inference_onnx.py ├── input └── 下载.png ├── output └── 下载.png ├── packages.txt ├── pretrained └── modnet.onnx ├── requirements.txt ├── run.sh ├── setup.sh ├── src ├── __init__.py ├── image.png ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── mobilenetv2.py │ │ └── wrapper.py │ ├── modnet.py │ └── onnx_modnet.py └── trainer.py └── streamlit_app.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | machine: true 5 | steps: 6 | - checkout 7 | 8 | # build image 9 | - run: | 10 | docker info 11 | docker build -t mod-matting . 12 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | autogreen: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Clone repository 13 | uses: actions/checkout@v3 14 | 15 | - name: Auto green 16 | run: | 17 | git config --local user.email "literature1204@gmail.com" 18 | git config --local user.name "LiteraturePro" 19 | git remote set-url origin https://${{ github.actor }}:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }} 20 | git pull --rebase 21 | git commit --allow-empty -m "a commit a day keeps your girlfriend away" 22 | git push 23 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: Build Docker images 4 | 5 | # Controls when the action will run. 6 | on: 7 | # Triggers the workflow on push or pull request events but only for the main branch 8 | push: 9 | branches: [ main ] 10 | pull_request: 11 | branches: [ main ] 12 | schedule: # 定时任务 13 | - cron: "0 0 * * *" # 每天 0 点跑 => 东八区 8点 14 | # Allows you to run this workflow manually from the Actions tab 15 | workflow_dispatch: 16 | 17 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 18 | jobs: 19 | # This workflow contains a single job called "build" 20 | build: 21 | # The type of runner that the job will run on 22 | runs-on: ubuntu-latest 23 | 24 | # Steps represent a sequence of tasks that will be executed as part of the job 25 | steps: 26 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 27 | - uses: actions/checkout@v2 28 | 29 | # Runs a single command using the runners shell 30 | - name: Run Build Docker images 31 | run: docker build -t mod-matting . 32 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | services: 3 | - docker 4 | script: 5 | - docker build -t modnet-matting:latest . 
6 | -------------------------------------------------------------------------------- /Aptfile: -------------------------------------------------------------------------------- 1 | libgl1-mesa-glx 2 | libglib2.0-dev 3 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use the official lightweight Python image. 2 | # https://hub.docker.com/_/python 3 | FROM python:3.10-slim 4 | 5 | # Copy local code to the container image. 6 | ENV APP_HOME /app 7 | WORKDIR $APP_HOME 8 | COPY . ./ 9 | 10 | RUN apt update && apt install -y \ 11 | libgl1-mesa-glx \ 12 | libglib2.0-dev 13 | 14 | # Install production dependencies. 15 | RUN pip install --upgrade pip 16 | 17 | RUN pip install -r requirements.txt 18 | 19 | # CPU 20 | #RUN pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 21 | RUN pip install torch==1.13.0+cpu torchvision==0.14.0+cpu torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cpu 22 | 23 | # Run the web service on container startup. Here we use the gunicorn 24 | # webserver, with one worker process and 8 threads. 25 | # For environments with multiple CPU cores, increase the number of workers 26 | # to be equal to the cores available. 27 | CMD exec gunicorn --bind 0.0.0.0:8080 --workers 1 --threads 8 --timeout 0 app:app 28 | 29 | # CMD exec gunicorn --bind 0.0.0.0:$PORT --workers 1 --threads 8 --timeout 0 app:app 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Procfile: -------------------------------------------------------------------------------- 1 | web: sh setup.sh && streamlit run streamlit_app.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MODNet 2 | 3 | ## Version Web for MODNet-model Human Matting 4 | 5 | > Convert images on web ! 
6 |
7 | The web app is deployed on Heroku here - https://modnet.herokuapp.com/ (abandoned)
8 |
9 | The web app is deployed on IBM Cloud Foundry here - https://modnet.mybluemix.net/ (abandoned)
10 |
11 | The web app is deployed on Streamlit Cloud here - https://modnet.streamlit.app/
12 | ![](https://pcdn.wxiou.cn/20210309203225.png)
13 |
14 | ---
15 |
16 | [![DigitalOcean Referral Badge](https://web-platforms.sfo2.digitaloceanspaces.com/WWW/Badge%202.svg)](https://www.digitalocean.com/?refcode=7765f4d750f0&utm_campaign=Referral_Invite&utm_medium=Referral_Program&utm_source=badge)
17 | ## Docker version API for MODNet-model Human Matting
18 |
19 | > Convert images via the API!
20 |
21 | The API is deployed on Divio (production) here - https://modnet.us.aldryn.io/
22 |
23 | The API is deployed on Divio (test) here - https://modnet-stage.us.aldryn.io/
24 |
25 | The API is deployed on Heroku here - https://modnet-demo.herokuapp.com/ (abandoned)
26 |
27 | The API is deployed on Aliyun Serverless here - http://modnet.ovzv.cn/ (abandoned)
28 |
29 | The API is deployed on AWS Lambda here -
30 |
31 | ![](https://pcdn.wxiou.cn/20210309204034.png)
32 |
33 | ---
34 |
35 | ## What is this?
36 |
37 | - Original project: [MODNet](https://github.com/ZHKKKe/MODNet)
38 | - The project this version builds on: [MODNet Onnx](https://github.com/manthan3C273/MODNet/)
39 |
40 | ## Update
41 | [An associated application that uses this model](https://github.com/LiteraturePro/Wx-Photo/)
42 |
43 | ### Explanation
44 | MODNet-model human matting (see the pictures below)
45 |
46 | ![](https://pcdn.wxiou.cn/20210221141938.png)
47 | ![](https://pcdn.wxiou.cn/20210301145423.png)
48 |
49 |
50 | > This project packages the matting program implemented with the MODNet algorithm as a Docker image that provides an API service. If you are not familiar with MODNet, please read the original author's repository first. The focus here is on using Docker to build MODNet into a callable API. Of course, you can also run `app.py` directly as a Flask app; Docker is simply used to avoid environment configuration errors.
51 | > MODNet can run on either GPU or CPU, and this project supports both.
52 |
53 |
54 | ## Compiled project on [hub.docker.com](https://hub.docker.com/)
55 |
56 | - [Normal version](https://hub.docker.com/layers/literature/modnet-matting/latest/images/sha256-65e14b60a5c155eec1d3607806456d5a269a169f7c4fdd5c760846fc0b0c3eb4?context=repo)
57 | - [Heroku version](https://hub.docker.com/layers/literature/modnet-matting/heroku/images/sha256-c3465a45ed6655969851f5e7fb5438c7837063b6143164672fded4cbf1a0e4f2?context=repo)
58 | - [Aliyun version](https://hub.docker.com/layers/literature/modnet-matting/sf/images/sha256-ec3423318458b00d342950a5c40061c16636f5875319bf33b6afe86b65389a51?context=repo)
59 | ## Build
60 | > Make sure you have `docker` installed
61 |
62 | 1. Clone the MODNet repository:
63 | ```
64 | git clone https://github.com/LiteraturePro/MODNet.git
65 | cd MODNet
66 | ```
67 | 2. Build the image:
68 | ```
69 | docker build -t mod-matting .
70 | ```
71 | - I also provide build variants for `Heroku` and Aliyun; just replace the last `CMD` line of the Dockerfile with the appropriate one below:
72 | - For general use
73 | ```
74 | CMD exec gunicorn --bind 0.0.0.0:8080 --workers 1 --threads 8 --timeout 0 app:app
75 | ```
76 | - For Heroku
77 | ```
78 | CMD exec gunicorn --bind 0.0.0.0:$PORT --workers 1 --threads 8 --timeout 0 app:app
79 | ```
80 | - For Aliyun Serverless
81 | - Alibaba Cloud Function Compute requires the service to run on port 9000, so two places need to be changed. The first is the `CMD` line:
82 | ```
83 | CMD exec gunicorn --bind 0.0.0.0:9000 --workers 1 --threads 8 --timeout 0 app:app
84 | ```
85 | - The other change is in `app.py`: change the port `8080` to `9000`.
86 |
87 | 3. Run the image (you can specify the host port yourself):
88 | ```
89 | docker run -p 8080:8080 mod-matting
90 | ```
91 | ## Install
92 | > Make sure you have `docker` installed
93 |
94 | I have already built and published the image, so you can install it directly. The commands are as follows (you can specify the host port yourself):
95 | - For general use
96 | ```
97 | docker pull literature/modnet-matting:latest
98 | docker run -p 8080:8080 literature/modnet-matting:latest
99 | ```
100 | - For Heroku
101 | ```
102 | docker pull literature/modnet-matting:heroku
103 | ```
104 | [See the tutorial for installing a container application on Heroku](https://github.com/LiteraturePro/Cartoonize#using-heroku)
105 |
106 | - For Aliyun Serverless
107 | ```
108 | docker pull literature/modnet-matting:sf
109 | ```
110 | [See the tutorial for installing a container application on Aliyun Serverless](https://github.com/LiteraturePro/Cartoonize#using-aliyun-severless)
111 |
112 | Now the service is running, but only on a local port. If you want to call it from the outside, you need to put a reverse proxy in front of it and map the service to your own domain name.
113 |
114 |
115 | ## Use
116 | > The calls shown below go through the proxy I set up. To call the service yourself, you need to set up your own proxy.
117 |
118 | - Provided that `docker` is installed and the service is deployed correctly, the endpoint accepts both `GET` and `POST` requests. The actual behavior is shown below (a minimal client sketch is given at the end of this README).
119 | - `Interface`: `http://your domain/api` or `http://127.0.0.1:8080/api`
120 | - `Parameter`: `image`, whose value is a picture file
121 | - `Return value`: the base64-encoded data of the processed image
122 | ![](https://pcdn.wxiou.cn/20210221141131.png)
123 | ![](https://pcdn.wxiou.cn/20210221141230.png)
124 |
125 | ## Other
126 | Thanks to the original author and the author of the revised version for their work. If you like this project, please give it a `star`.
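For reference, here is a minimal Python client for the `/api` endpoint described in the Use section above. This is only an illustrative sketch, not part of the project: it assumes the service is running locally on port 8080 (adjust the URL for your own deployment) and uses the `requests` package, which is not listed in `requirements.txt`.

```
# Illustrative client sketch (assumes the service is reachable at localhost:8080).
import base64
import requests

API_URL = "http://127.0.0.1:8080/api"  # replace with your own domain if proxied

# send the sample image shipped with this repository as the "image" form field
with open("input/下载.png", "rb") as f:
    response = requests.post(API_URL, files={"image": f})
response.raise_for_status()

# the response body is the base64-encoded PNG with the background removed
with open("without_background.png", "wb") as out:
    out.write(base64.b64decode(response.content))
```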
127 |
128 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import uuid
4 | import sys
5 | import cv2
6 | import base64
7 | import logging
8 | import numpy as np
9 | from PIL import Image
10 | from io import BytesIO
11 | from flask import Flask, render_template, make_response, flash
12 | import flask
13 |
14 | import inference_onnx
15 |
16 | import sentry_sdk
17 | from flask import Flask
18 | from sentry_sdk.integrations.flask import FlaskIntegration
19 |
20 | sentry_sdk.init(
21 |     dsn="https://1bd65a3f51934290bed6ed507cd2f22c@o513531.ingest.sentry.io/6715509",
22 |     integrations=[
23 |         FlaskIntegration(),
24 |     ],
25 |
26 |     # Set traces_sample_rate to 1.0 to capture 100%
27 |     # of transactions for performance monitoring.
28 |     # We recommend adjusting this value in production.
29 |     traces_sample_rate=1.0
30 | )
31 |
32 |
33 | app = Flask(__name__)
34 |
35 | #run_with_ngrok(app) #starts ngrok when the app is run
36 |
37 | def convert_bytes_to_image(img_name,img_bytes):
38 |     # wrap the raw bytes in a byte stream
39 |     bytes_stream = BytesIO(img_bytes)
40 |     # read the image from the stream
41 |     roiimg = Image.open(bytes_stream)
42 |     img_path = os.path.join('./input', img_name + ".jpg")
43 |     imgByteArr = BytesIO()  # initialize an empty byte stream
44 |     roiimg.save(imgByteArr, format='PNG')  # save the image into the empty byte stream in PNG format
45 |     imgByteArr = imgByteArr.getvalue()  # get the full contents regardless of the stream position; the type changes from a stream to bytes
46 |     with open(img_path,'wb') as f:
47 |         f.write(imgByteArr)
48 |
49 |     return img_path
50 |
51 | @app.route('/')
52 | @app.route('/api', methods=["POST", "GET"])
53 | def api():
54 |     try:
55 |         img = flask.request.files["image"].read()
56 |         img_name = str(uuid.uuid4())
57 |         input_img_path = convert_bytes_to_image(img_name,img)
58 |         output_img_path = os.path.join('./output', img_name + ".png")
59 |         image_save = os.path.join('./image_save', img_name + ".png")
60 |         inference_onnx.main(input_img_path,output_img_path)
61 |         image = Image.open(input_img_path)
62 |         matte = Image.open(output_img_path)
63 |         image.putalpha(matte)
64 |         image.save(image_save)
65 |         with open(image_save, 'rb') as f:
66 |             res = base64.b64encode(f.read())
67 |         if os.path.exists(input_img_path):  # if the file exists, remove it
68 |             os.remove(input_img_path)
69 |         else:
70 |             logging.error('no such file')  # otherwise log that the file does not exist
71 |         if os.path.exists(output_img_path):  # if the file exists, remove it
72 |             os.remove(output_img_path)
73 |         else:
74 |             logging.error('no such file')  # otherwise log that the file does not exist
75 |         if os.path.exists(image_save):  # if the file exists, remove it
76 |             os.remove(image_save)
77 |         else:
78 |             logging.error('no such file')  # otherwise log that the file does not exist
79 |         return res
80 |     except Exception as e:
81 |         logging.error(e)
82 |         return "Error occurred, please check the log output!"
83 | 84 | 85 | if __name__ == "__main__": 86 | app.run(debug=False, host='0.0.0.0', port=int(os.environ.get('PORT', 8080))) 87 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | # Docker 2 | # Build a Docker image 3 | # https://docs.microsoft.com/azure/devops/pipelines/languages/docker 4 | 5 | trigger: 6 | - main 7 | 8 | resources: 9 | - repo: self 10 | 11 | variables: 12 | tag: '$(Build.BuildId)' 13 | 14 | stages: 15 | - stage: Build 16 | displayName: Build image 17 | jobs: 18 | - job: Build 19 | displayName: Build 20 | pool: 21 | vmImage: ubuntu-latest 22 | steps: 23 | - task: Docker@2 24 | displayName: Build an image 25 | inputs: 26 | command: build 27 | dockerfile: '$(Build.SourcesDirectory)/Dockerfile' 28 | tags: | 29 | $(tag) 30 | -------------------------------------------------------------------------------- /circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | machine: true 5 | steps: 6 | - checkout 7 | 8 | # build image 9 | - run: | 10 | docker info 11 | docker build -t mod-matting . 12 | -------------------------------------------------------------------------------- /image_save/下载.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiteraturePro/MODNet/701a6b85e507306e308b8f8a9879142d33f65121/image_save/下载.png -------------------------------------------------------------------------------- /inference_onnx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference with onnxruntime 3 | 4 | Arguments: 5 | --image-path --> path to single input image 6 | --output-path --> paht to save generated matte 7 | --model-path --> path to onnx model file 8 | 9 | example: 10 | python inference_onnx.py \ 11 | --image-path=demo.jpg \ 12 | --output-path=matte.png \ 13 | --model-path=modnet.onnx 14 | 15 | Optional: 16 | Generate transparent image without background 17 | """ 18 | import os 19 | import argparse 20 | import cv2 21 | import numpy as np 22 | import onnx 23 | import onnxruntime 24 | from onnx import helper 25 | from PIL import Image 26 | 27 | def main(image_path, output_path): 28 | # define cmd arguments 29 | model_path = './pretrained/modnet.onnx' 30 | 31 | # check input arguments 32 | if not os.path.exists(image_path): 33 | print('Cannot find input path: {0}'.format(image_path)) 34 | exit() 35 | if not os.path.exists(model_path): 36 | print('Cannot find model path: {0}'.format(model_path)) 37 | exit() 38 | 39 | ref_size = 512 40 | 41 | # Get x_scale_factor & y_scale_factor to resize image 42 | def get_scale_factor(im_h, im_w, ref_size): 43 | 44 | if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size: 45 | if im_w >= im_h: 46 | im_rh = ref_size 47 | im_rw = int(im_w / im_h * ref_size) 48 | elif im_w < im_h: 49 | im_rw = ref_size 50 | im_rh = int(im_h / im_w * ref_size) 51 | else: 52 | im_rh = im_h 53 | im_rw = im_w 54 | 55 | im_rw = im_rw - im_rw % 32 56 | im_rh = im_rh - im_rh % 32 57 | 58 | x_scale_factor = im_rw / im_w 59 | y_scale_factor = im_rh / im_h 60 | 61 | return x_scale_factor, y_scale_factor 62 | 63 | ############################################## 64 | # Main Inference part 65 | ############################################## 66 | 67 | # read image 68 | im = cv2.imread(image_path) 69 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 70 | 
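    # Worked example (illustrative, added for clarity): get_scale_factor above only
    # rescales the image when it is entirely smaller or entirely larger than ref_size;
    # either way, both target sides are rounded down to multiples of 32 (presumably
    # because the backbone downsamples by a factor of 32). For a 1080x1920 input:
    # im_w >= im_h, so im_rh = 512 and im_rw = int(1920 / 1080 * 512) = 910; rounding
    # down to multiples of 32 gives 896 x 512, i.e. x_scale_factor = 896 / 1920 ~ 0.467
    # and y_scale_factor = 512 / 1080 ~ 0.474.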
71 |     # unify image channels to 3
72 |     if len(im.shape) == 2:
73 |         im = im[:, :, None]
74 |     if im.shape[2] == 1:
75 |         im = np.repeat(im, 3, axis=2)
76 |     elif im.shape[2] == 4:
77 |         im = im[:, :, 0:3]
78 |
79 |     # normalize values to scale them between -1 and 1
80 |     im = (im - 127.5) / 127.5
81 |
82 |     im_h, im_w, im_c = im.shape
83 |     x, y = get_scale_factor(im_h, im_w, ref_size)
84 |
85 |     # resize image
86 |     im = cv2.resize(im, None, fx = x, fy = y, interpolation = cv2.INTER_AREA)
87 |
88 |     # prepare input shape
89 |     im = np.transpose(im)
90 |     im = np.swapaxes(im, 1, 2)
91 |     im = np.expand_dims(im, axis = 0).astype('float32')
92 |
93 |     # Initialize session and get prediction
94 |     session = onnxruntime.InferenceSession(model_path, None)
95 |     input_name = session.get_inputs()[0].name
96 |     output_name = session.get_outputs()[0].name
97 |     result = session.run([output_name], {input_name: im})
98 |
99 |     # refine matte
100 |     matte = (np.squeeze(result[0]) * 255).astype('uint8')
101 |     matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation = cv2.INTER_AREA)
102 |
103 |     cv2.imwrite(output_path, matte)
104 |
105 |     ##############################################
106 |     # Optional - save png image without background
107 |     ##############################################
108 |
109 |     # im_PIL = Image.open(args.image_path)
110 |     # matte = Image.fromarray(matte)
111 |     # im_PIL.putalpha(matte) # add alpha channel to keep transparency
112 |     # im_PIL.save('without_background.png')
113 | if __name__ == '__main__':
114 |     # minimal CLI matching the arguments described in the module docstring
115 |     parser = argparse.ArgumentParser()
116 |     parser.add_argument('--image-path', type=str, required=True, help='path to the input image')
117 |     parser.add_argument('--output-path', type=str, required=True, help='path to save the generated matte')
118 |     args = parser.parse_args()
119 |     main(args.image_path, args.output_path)
--------------------------------------------------------------------------------
/input/下载.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiteraturePro/MODNet/701a6b85e507306e308b8f8a9879142d33f65121/input/下载.png
--------------------------------------------------------------------------------
/output/下载.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiteraturePro/MODNet/701a6b85e507306e308b8f8a9879142d33f65121/output/下载.png
--------------------------------------------------------------------------------
/packages.txt:
--------------------------------------------------------------------------------
1 | libgl1-mesa-glx
2 | libglib2.0-dev
3 |
--------------------------------------------------------------------------------
/pretrained/modnet.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiteraturePro/MODNet/701a6b85e507306e308b8f8a9879142d33f65121/pretrained/modnet.onnx
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Flask
2 | gunicorn
3 | onnx
4 | onnxruntime
5 | opencv-python
6 | Pillow
7 | tqdm
8 | streamlit
9 | sentry-sdk[flask]
10 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | streamlit run streamlit_app.py --server.port 7860
2 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ~/.streamlit/
2 |
3 | echo "\
4 | [server]\n\
5 | headless = true\n\
6 | port = $PORT\n\
7 | enableCORS = false\n\
8 | \n\
9 | " > ~/.streamlit/config.toml
10 |
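The commented-out "Optional" block in `inference_onnx.py` above and the handler in `app.py` both composite the predicted matte back onto the input image. A minimal standalone sketch of that flow is shown below; it is illustrative only and assumes it is run from the repository root with `pretrained/modnet.onnx` in place (the file names are the sample images shipped with the repository).

```
# Illustrative sketch: generate a matte and composite a transparent PNG locally.
from PIL import Image

import inference_onnx

input_path = "input/下载.png"    # sample image shipped with the repository
matte_path = "output/下载.png"   # where the grayscale matte will be written

# run the ONNX model; writes the matte to matte_path
inference_onnx.main(input_path, matte_path)

# use the matte as an alpha channel to cut out the foreground
image = Image.open(input_path)
matte = Image.open(matte_path)
image.putalpha(matte)
image.save("image_save/下载.png")
```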
-------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiteraturePro/MODNet/701a6b85e507306e308b8f8a9879142d33f65121/src/__init__.py -------------------------------------------------------------------------------- /src/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiteraturePro/MODNet/701a6b85e507306e308b8f8a9879142d33f65121/src/image.png -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiteraturePro/MODNet/701a6b85e507306e308b8f8a9879142d33f65121/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .wrapper import * 2 | 3 | 4 | #------------------------------------------------------------------------------ 5 | # Replaceable Backbones 6 | #------------------------------------------------------------------------------ 7 | 8 | SUPPORTED_BACKBONES = { 9 | 'mobilenetv2': MobileNetV2Backbone, 10 | } 11 | -------------------------------------------------------------------------------- /src/models/backbones/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch""" 2 | 3 | import math 4 | import json 5 | from functools import reduce 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | #------------------------------------------------------------------------------ 12 | # Useful functions 13 | #------------------------------------------------------------------------------ 14 | 15 | def _make_divisible(v, divisor, min_value=None): 16 | if min_value is None: 17 | min_value = divisor 18 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 19 | # Make sure that round down does not go down by more than 10%. 
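    # Worked example (illustrative, added for clarity): with a width multiplier
    # alpha = 0.35 applied to 32 channels, v = 11.2 and divisor = 8 round down to
    # new_v = 8; since 8 < 0.9 * 11.2 = 10.08, the divisor is added back below and
    # the function returns 16.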
20 | if new_v < 0.9 * v: 21 | new_v += divisor 22 | return new_v 23 | 24 | 25 | def conv_bn(inp, oup, stride): 26 | return nn.Sequential( 27 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 28 | nn.BatchNorm2d(oup), 29 | nn.ReLU6(inplace=True) 30 | ) 31 | 32 | 33 | def conv_1x1_bn(inp, oup): 34 | return nn.Sequential( 35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | nn.ReLU6(inplace=True) 38 | ) 39 | 40 | 41 | #------------------------------------------------------------------------------ 42 | # Class of Inverted Residual block 43 | #------------------------------------------------------------------------------ 44 | 45 | class InvertedResidual(nn.Module): 46 | def __init__(self, inp, oup, stride, expansion, dilation=1): 47 | super(InvertedResidual, self).__init__() 48 | self.stride = stride 49 | assert stride in [1, 2] 50 | 51 | hidden_dim = round(inp * expansion) 52 | self.use_res_connect = self.stride == 1 and inp == oup 53 | 54 | if expansion == 1: 55 | self.conv = nn.Sequential( 56 | # dw 57 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), 58 | nn.BatchNorm2d(hidden_dim), 59 | nn.ReLU6(inplace=True), 60 | # pw-linear 61 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 62 | nn.BatchNorm2d(oup), 63 | ) 64 | else: 65 | self.conv = nn.Sequential( 66 | # pw 67 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 68 | nn.BatchNorm2d(hidden_dim), 69 | nn.ReLU6(inplace=True), 70 | # dw 71 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), 72 | nn.BatchNorm2d(hidden_dim), 73 | nn.ReLU6(inplace=True), 74 | # pw-linear 75 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 76 | nn.BatchNorm2d(oup), 77 | ) 78 | 79 | def forward(self, x): 80 | if self.use_res_connect: 81 | return x + self.conv(x) 82 | else: 83 | return self.conv(x) 84 | 85 | 86 | #------------------------------------------------------------------------------ 87 | # Class of MobileNetV2 88 | #------------------------------------------------------------------------------ 89 | 90 | class MobileNetV2(nn.Module): 91 | def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000): 92 | super(MobileNetV2, self).__init__() 93 | self.in_channels = in_channels 94 | self.num_classes = num_classes 95 | input_channel = 32 96 | last_channel = 1280 97 | interverted_residual_setting = [ 98 | # t, c, n, s 99 | [1 , 16, 1, 1], 100 | [expansion, 24, 2, 2], 101 | [expansion, 32, 3, 2], 102 | [expansion, 64, 4, 2], 103 | [expansion, 96, 3, 1], 104 | [expansion, 160, 3, 2], 105 | [expansion, 320, 1, 1], 106 | ] 107 | 108 | # building first layer 109 | input_channel = _make_divisible(input_channel*alpha, 8) 110 | self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel 111 | self.features = [conv_bn(self.in_channels, input_channel, 2)] 112 | 113 | # building inverted residual blocks 114 | for t, c, n, s in interverted_residual_setting: 115 | output_channel = _make_divisible(int(c*alpha), 8) 116 | for i in range(n): 117 | if i == 0: 118 | self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t)) 119 | else: 120 | self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t)) 121 | input_channel = output_channel 122 | 123 | # building last several layers 124 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 125 | 126 | # make it nn.Sequential 127 | self.features = nn.Sequential(*self.features) 128 | 129 | # 
building classifier 130 | if self.num_classes is not None: 131 | self.classifier = nn.Sequential( 132 | nn.Dropout(0.2), 133 | nn.Linear(self.last_channel, num_classes), 134 | ) 135 | 136 | # Initialize weights 137 | self._init_weights() 138 | 139 | def forward(self, x, feature_names=None): 140 | # Stage1 141 | x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x) 142 | # Stage2 143 | x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x) 144 | # Stage3 145 | x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x) 146 | # Stage4 147 | x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x) 148 | # Stage5 149 | x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x) 150 | 151 | # Classification 152 | if self.num_classes is not None: 153 | x = x.mean(dim=(2,3)) 154 | x = self.classifier(x) 155 | 156 | # Output 157 | return x 158 | 159 | def _load_pretrained_model(self, pretrained_file): 160 | pretrain_dict = torch.load(pretrained_file, map_location='cpu') 161 | model_dict = {} 162 | state_dict = self.state_dict() 163 | print("[MobileNetV2] Loading pretrained model...") 164 | for k, v in pretrain_dict.items(): 165 | if k in state_dict: 166 | model_dict[k] = v 167 | else: 168 | print(k, "is ignored") 169 | state_dict.update(model_dict) 170 | self.load_state_dict(state_dict) 171 | 172 | def _init_weights(self): 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 176 | m.weight.data.normal_(0, math.sqrt(2. / n)) 177 | if m.bias is not None: 178 | m.bias.data.zero_() 179 | elif isinstance(m, nn.BatchNorm2d): 180 | m.weight.data.fill_(1) 181 | m.bias.data.zero_() 182 | elif isinstance(m, nn.Linear): 183 | n = m.weight.size(1) 184 | m.weight.data.normal_(0, 0.01) 185 | m.bias.data.zero_() 186 | -------------------------------------------------------------------------------- /src/models/backbones/wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import reduce 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .mobilenetv2 import MobileNetV2 8 | 9 | 10 | class BaseBackbone(nn.Module): 11 | """ Superclass of Replaceable Backbone Model for Semantic Estimation 12 | """ 13 | 14 | def __init__(self, in_channels): 15 | super(BaseBackbone, self).__init__() 16 | self.in_channels = in_channels 17 | 18 | self.model = None 19 | self.enc_channels = [] 20 | 21 | def forward(self, x): 22 | raise NotImplementedError 23 | 24 | def load_pretrained_ckpt(self): 25 | raise NotImplementedError 26 | 27 | 28 | class MobileNetV2Backbone(BaseBackbone): 29 | """ MobileNetV2 Backbone 30 | """ 31 | 32 | def __init__(self, in_channels): 33 | super(MobileNetV2Backbone, self).__init__(in_channels) 34 | 35 | self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None) 36 | self.enc_channels = [16, 24, 32, 96, 1280] 37 | 38 | def forward(self, x): 39 | x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) 40 | enc2x = x 41 | x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) 42 | enc4x = x 43 | x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) 44 | enc8x = x 45 | x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) 46 | enc16x = x 47 | x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) 48 | enc32x = x 49 | return [enc2x, enc4x, enc8x, enc16x, enc32x] 50 | 51 | def 
load_pretrained_ckpt(self): 52 | # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch 53 | ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt' 54 | if not os.path.exists(ckpt_path): 55 | print('cannot find the pretrained mobilenetv2 backbone') 56 | exit() 57 | 58 | ckpt = torch.load(ckpt_path) 59 | self.model.load_state_dict(ckpt) 60 | -------------------------------------------------------------------------------- /src/models/modnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .backbones import SUPPORTED_BACKBONES 6 | 7 | 8 | #------------------------------------------------------------------------------ 9 | # MODNet Basic Modules 10 | #------------------------------------------------------------------------------ 11 | 12 | class IBNorm(nn.Module): 13 | """ Combine Instance Norm and Batch Norm into One Layer 14 | """ 15 | 16 | def __init__(self, in_channels): 17 | super(IBNorm, self).__init__() 18 | in_channels = in_channels 19 | self.bnorm_channels = int(in_channels / 2) 20 | self.inorm_channels = in_channels - self.bnorm_channels 21 | 22 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) 23 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) 24 | 25 | def forward(self, x): 26 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) 27 | in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) 28 | 29 | return torch.cat((bn_x, in_x), 1) 30 | 31 | 32 | class Conv2dIBNormRelu(nn.Module): 33 | """ Convolution + IBNorm + ReLu 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, kernel_size, 37 | stride=1, padding=0, dilation=1, groups=1, bias=True, 38 | with_ibn=True, with_relu=True): 39 | super(Conv2dIBNormRelu, self).__init__() 40 | 41 | layers = [ 42 | nn.Conv2d(in_channels, out_channels, kernel_size, 43 | stride=stride, padding=padding, dilation=dilation, 44 | groups=groups, bias=bias) 45 | ] 46 | 47 | if with_ibn: 48 | layers.append(IBNorm(out_channels)) 49 | if with_relu: 50 | layers.append(nn.ReLU(inplace=True)) 51 | 52 | self.layers = nn.Sequential(*layers) 53 | 54 | def forward(self, x): 55 | return self.layers(x) 56 | 57 | 58 | class SEBlock(nn.Module): 59 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf 60 | """ 61 | 62 | def __init__(self, in_channels, out_channels, reduction=1): 63 | super(SEBlock, self).__init__() 64 | self.pool = nn.AdaptiveAvgPool2d(1) 65 | self.fc = nn.Sequential( 66 | nn.Linear(in_channels, int(in_channels // reduction), bias=False), 67 | nn.ReLU(inplace=True), 68 | nn.Linear(int(in_channels // reduction), out_channels, bias=False), 69 | nn.Sigmoid() 70 | ) 71 | 72 | def forward(self, x): 73 | b, c, _, _ = x.size() 74 | w = self.pool(x).view(b, c) 75 | w = self.fc(w).view(b, c, 1, 1) 76 | 77 | return x * w.expand_as(x) 78 | 79 | 80 | #------------------------------------------------------------------------------ 81 | # MODNet Branches 82 | #------------------------------------------------------------------------------ 83 | 84 | class LRBranch(nn.Module): 85 | """ Low Resolution Branch of MODNet 86 | """ 87 | 88 | def __init__(self, backbone): 89 | super(LRBranch, self).__init__() 90 | 91 | enc_channels = backbone.enc_channels 92 | 93 | self.backbone = backbone 94 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) 95 | self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, 
stride=1, padding=2) 96 | self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) 97 | self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False) 98 | 99 | def forward(self, img, inference): 100 | enc_features = self.backbone.forward(img) 101 | enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] 102 | 103 | enc32x = self.se_block(enc32x) 104 | lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) 105 | lr16x = self.conv_lr16x(lr16x) 106 | lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False) 107 | lr8x = self.conv_lr8x(lr8x) 108 | 109 | pred_semantic = None 110 | if not inference: 111 | lr = self.conv_lr(lr8x) 112 | pred_semantic = torch.sigmoid(lr) 113 | 114 | return pred_semantic, lr8x, [enc2x, enc4x] 115 | 116 | 117 | class HRBranch(nn.Module): 118 | """ High Resolution Branch of MODNet 119 | """ 120 | 121 | def __init__(self, hr_channels, enc_channels): 122 | super(HRBranch, self).__init__() 123 | 124 | self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) 125 | self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) 126 | 127 | self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) 128 | self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) 129 | 130 | self.conv_hr4x = nn.Sequential( 131 | Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), 132 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 133 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 134 | ) 135 | 136 | self.conv_hr2x = nn.Sequential( 137 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 138 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 139 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 140 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 141 | ) 142 | 143 | self.conv_hr = nn.Sequential( 144 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), 145 | Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False), 146 | ) 147 | 148 | def forward(self, img, enc2x, enc4x, lr8x, inference): 149 | img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False) 150 | img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False) 151 | 152 | enc2x = self.tohr_enc2x(enc2x) 153 | hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) 154 | 155 | enc4x = self.tohr_enc4x(enc4x) 156 | hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) 157 | 158 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 159 | hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) 160 | 161 | hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False) 162 | hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) 163 | 164 | pred_detail = None 165 | if not inference: 166 | hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False) 167 | hr = self.conv_hr(torch.cat((hr, img), dim=1)) 168 | pred_detail = torch.sigmoid(hr) 169 | 170 | return pred_detail, hr2x 171 | 172 | 173 | class FusionBranch(nn.Module): 174 | """ Fusion Branch of MODNet 175 | """ 176 | 177 | def __init__(self, hr_channels, 
enc_channels): 178 | super(FusionBranch, self).__init__() 179 | self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) 180 | 181 | self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) 182 | self.conv_f = nn.Sequential( 183 | Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), 184 | Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), 185 | ) 186 | 187 | def forward(self, img, lr8x, hr2x): 188 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 189 | lr4x = self.conv_lr4x(lr4x) 190 | lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) 191 | 192 | f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) 193 | f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False) 194 | f = self.conv_f(torch.cat((f, img), dim=1)) 195 | pred_matte = torch.sigmoid(f) 196 | 197 | return pred_matte 198 | 199 | 200 | #------------------------------------------------------------------------------ 201 | # MODNet 202 | #------------------------------------------------------------------------------ 203 | 204 | class MODNet(nn.Module): 205 | """ Architecture of MODNet 206 | """ 207 | 208 | def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True): 209 | super(MODNet, self).__init__() 210 | 211 | self.in_channels = in_channels 212 | self.hr_channels = hr_channels 213 | self.backbone_arch = backbone_arch 214 | self.backbone_pretrained = backbone_pretrained 215 | 216 | self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels) 217 | 218 | self.lr_branch = LRBranch(self.backbone) 219 | self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) 220 | self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) 221 | 222 | for m in self.modules(): 223 | if isinstance(m, nn.Conv2d): 224 | self._init_conv(m) 225 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 226 | self._init_norm(m) 227 | 228 | if self.backbone_pretrained: 229 | self.backbone.load_pretrained_ckpt() 230 | 231 | def forward(self, img, inference): 232 | pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference) 233 | pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference) 234 | pred_matte = self.f_branch(img, lr8x, hr2x) 235 | 236 | return pred_semantic, pred_detail, pred_matte 237 | 238 | def freeze_norm(self): 239 | norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] 240 | for m in self.modules(): 241 | for n in norm_types: 242 | if isinstance(m, n): 243 | m.eval() 244 | continue 245 | 246 | def _init_conv(self, conv): 247 | nn.init.kaiming_uniform_( 248 | conv.weight, a=0, mode='fan_in', nonlinearity='relu') 249 | if conv.bias is not None: 250 | nn.init.constant_(conv.bias, 0) 251 | 252 | def _init_norm(self, norm): 253 | if norm.weight is not None: 254 | nn.init.constant_(norm.weight, 1) 255 | nn.init.constant_(norm.bias, 0) 256 | -------------------------------------------------------------------------------- /src/models/onnx_modnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is a modified version of the original file modnet.py without 3 | "pred_semantic" and "pred_details" as these both returns None when "inference = True" 4 | 5 | And it does not contain "inference" argument which will make it easier to 6 | convert checkpoint into onnx model. 
7 | 8 | Refer: 'demo/image_matting/inference_with_ONNX/export_modnet_onnx.py' to export model. 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from .backbones import SUPPORTED_BACKBONES 16 | 17 | 18 | #------------------------------------------------------------------------------ 19 | # MODNet Basic Modules 20 | #------------------------------------------------------------------------------ 21 | 22 | class IBNorm(nn.Module): 23 | """ Combine Instance Norm and Batch Norm into One Layer 24 | """ 25 | 26 | def __init__(self, in_channels): 27 | super(IBNorm, self).__init__() 28 | in_channels = in_channels 29 | self.bnorm_channels = int(in_channels / 2) 30 | self.inorm_channels = in_channels - self.bnorm_channels 31 | 32 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) 33 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) 34 | 35 | def forward(self, x): 36 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) 37 | in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) 38 | 39 | return torch.cat((bn_x, in_x), 1) 40 | 41 | 42 | class Conv2dIBNormRelu(nn.Module): 43 | """ Convolution + IBNorm + ReLu 44 | """ 45 | 46 | def __init__(self, in_channels, out_channels, kernel_size, 47 | stride=1, padding=0, dilation=1, groups=1, bias=True, 48 | with_ibn=True, with_relu=True): 49 | super(Conv2dIBNormRelu, self).__init__() 50 | 51 | layers = [ 52 | nn.Conv2d(in_channels, out_channels, kernel_size, 53 | stride=stride, padding=padding, dilation=dilation, 54 | groups=groups, bias=bias) 55 | ] 56 | 57 | if with_ibn: 58 | layers.append(IBNorm(out_channels)) 59 | if with_relu: 60 | layers.append(nn.ReLU(inplace=True)) 61 | 62 | self.layers = nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | return self.layers(x) 66 | 67 | 68 | class SEBlock(nn.Module): 69 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf 70 | """ 71 | 72 | def __init__(self, in_channels, out_channels, reduction=1): 73 | super(SEBlock, self).__init__() 74 | self.pool = nn.AdaptiveAvgPool2d(1) 75 | self.fc = nn.Sequential( 76 | nn.Linear(in_channels, int(in_channels // reduction), bias=False), 77 | nn.ReLU(inplace=True), 78 | nn.Linear(int(in_channels // reduction), out_channels, bias=False), 79 | nn.Sigmoid() 80 | ) 81 | 82 | def forward(self, x): 83 | b, c, _, _ = x.size() 84 | w = self.pool(x).view(b, c) 85 | w = self.fc(w).view(b, c, 1, 1) 86 | 87 | return x * w.expand_as(x) 88 | 89 | 90 | #------------------------------------------------------------------------------ 91 | # MODNet Branches 92 | #------------------------------------------------------------------------------ 93 | 94 | class LRBranch(nn.Module): 95 | """ Low Resolution Branch of MODNet 96 | """ 97 | 98 | def __init__(self, backbone): 99 | super(LRBranch, self).__init__() 100 | 101 | enc_channels = backbone.enc_channels 102 | 103 | self.backbone = backbone 104 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) 105 | self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2) 106 | self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) 107 | self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False) 108 | 109 | def forward(self, img): 110 | enc_features = self.backbone.forward(img) 111 | enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] 112 | 113 | enc32x = 
self.se_block(enc32x) 114 | lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) 115 | lr16x = self.conv_lr16x(lr16x) 116 | lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False) 117 | lr8x = self.conv_lr8x(lr8x) 118 | 119 | return lr8x, [enc2x, enc4x] 120 | 121 | 122 | class HRBranch(nn.Module): 123 | """ High Resolution Branch of MODNet 124 | """ 125 | 126 | def __init__(self, hr_channels, enc_channels): 127 | super(HRBranch, self).__init__() 128 | 129 | self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) 130 | self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) 131 | 132 | self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) 133 | self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) 134 | 135 | self.conv_hr4x = nn.Sequential( 136 | Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), 137 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 138 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 139 | ) 140 | 141 | self.conv_hr2x = nn.Sequential( 142 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 143 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 144 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 145 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 146 | ) 147 | 148 | self.conv_hr = nn.Sequential( 149 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), 150 | Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False), 151 | ) 152 | 153 | def forward(self, img, enc2x, enc4x, lr8x): 154 | img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False) 155 | img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False) 156 | 157 | enc2x = self.tohr_enc2x(enc2x) 158 | hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) 159 | 160 | enc4x = self.tohr_enc4x(enc4x) 161 | hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) 162 | 163 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 164 | hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) 165 | 166 | hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False) 167 | hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) 168 | 169 | return hr2x 170 | 171 | 172 | class FusionBranch(nn.Module): 173 | """ Fusion Branch of MODNet 174 | """ 175 | 176 | def __init__(self, hr_channels, enc_channels): 177 | super(FusionBranch, self).__init__() 178 | self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) 179 | 180 | self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) 181 | self.conv_f = nn.Sequential( 182 | Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), 183 | Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), 184 | ) 185 | 186 | def forward(self, img, lr8x, hr2x): 187 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 188 | lr4x = self.conv_lr4x(lr4x) 189 | lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) 190 | 191 | f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) 192 | f = F.interpolate(f2x, 
scale_factor=2, mode='bilinear', align_corners=False) 193 | f = self.conv_f(torch.cat((f, img), dim=1)) 194 | pred_matte = torch.sigmoid(f) 195 | 196 | return pred_matte 197 | 198 | 199 | #------------------------------------------------------------------------------ 200 | # MODNet 201 | #------------------------------------------------------------------------------ 202 | 203 | class MODNet(nn.Module): 204 | """ Architecture of MODNet 205 | """ 206 | 207 | def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True): 208 | super(MODNet, self).__init__() 209 | 210 | self.in_channels = in_channels 211 | self.hr_channels = hr_channels 212 | self.backbone_arch = backbone_arch 213 | self.backbone_pretrained = backbone_pretrained 214 | 215 | self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels) 216 | 217 | self.lr_branch = LRBranch(self.backbone) 218 | self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) 219 | self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) 220 | 221 | for m in self.modules(): 222 | if isinstance(m, nn.Conv2d): 223 | self._init_conv(m) 224 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 225 | self._init_norm(m) 226 | 227 | if self.backbone_pretrained: 228 | self.backbone.load_pretrained_ckpt() 229 | 230 | def forward(self, img): 231 | lr8x, [enc2x, enc4x] = self.lr_branch(img) 232 | hr2x = self.hr_branch(img, enc2x, enc4x, lr8x) 233 | pred_matte = self.f_branch(img, lr8x, hr2x) 234 | 235 | return pred_matte 236 | 237 | def freeze_norm(self): 238 | norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] 239 | for m in self.modules(): 240 | for n in norm_types: 241 | if isinstance(m, n): 242 | m.eval() 243 | continue 244 | 245 | def _init_conv(self, conv): 246 | nn.init.kaiming_uniform_( 247 | conv.weight, a=0, mode='fan_in', nonlinearity='relu') 248 | if conv.bias is not None: 249 | nn.init.constant_(conv.bias, 0) 250 | 251 | def _init_norm(self, norm): 252 | if norm.weight is not None: 253 | nn.init.constant_(norm.weight, 1) 254 | nn.init.constant_(norm.bias, 0) 255 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import scipy 3 | import numpy as np 4 | from scipy.ndimage import grey_dilation, grey_erosion 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | __all__ = [ 12 | 'supervised_training_iter', 13 | 'soc_adaptation_iter', 14 | ] 15 | 16 | 17 | # ---------------------------------------------------------------------------------- 18 | # Tool Classes/Functions 19 | # ---------------------------------------------------------------------------------- 20 | 21 | class GaussianBlurLayer(nn.Module): 22 | """ Add Gaussian blur to a 4D tensor 23 | This layer takes a 4D tensor of {N, C, H, W} as input. 24 | The Gaussian blur is applied to each of the given channels (C) separately.
25 | """ 26 | 27 | def __init__(self, channels, kernel_size): 28 | """ 29 | Arguments: 30 | channels (int): Number of channels of the input tensor 31 | kernel_size (int): Size of the kernel used in blurring 32 | """ 33 | 34 | super(GaussianBlurLayer, self).__init__() 35 | self.channels = channels 36 | self.kernel_size = kernel_size 37 | assert self.kernel_size % 2 != 0 38 | 39 | self.op = nn.Sequential( 40 | nn.ReflectionPad2d(math.floor(self.kernel_size / 2)), 41 | nn.Conv2d(channels, channels, self.kernel_size, 42 | stride=1, padding=0, bias=False, groups=channels) 43 | ) 44 | 45 | self._init_kernel() 46 | 47 | def forward(self, x): 48 | """ 49 | Arguments: 50 | x (torch.Tensor): input 4D tensor 51 | Returns: 52 | torch.Tensor: Blurred version of the input 53 | """ 54 | 55 | if not len(list(x.shape)) == 4: 56 | print('\'GaussianBlurLayer\' requires a 4D tensor as input\n') 57 | exit() 58 | elif not x.shape[1] == self.channels: 59 | print('In \'GaussianBlurLayer\', the required channel ({0}) is ' 60 | 'not the same as input ({1})\n'.format(self.channels, x.shape[1])) 61 | exit() 62 | 63 | return self.op(x) 64 | 65 | def _init_kernel(self): 66 | sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8 67 | 68 | n = np.zeros((self.kernel_size, self.kernel_size)) 69 | i = math.floor(self.kernel_size / 2) 70 | n[i, i] = 1 71 | kernel = scipy.ndimage.gaussian_filter(n, sigma) 72 | 73 | for name, param in self.named_parameters(): 74 | param.data.copy_(torch.from_numpy(kernel)) 75 | 76 | # ---------------------------------------------------------------------------------- 77 | 78 | 79 | # ---------------------------------------------------------------------------------- 80 | # MODNet Training Functions 81 | # ---------------------------------------------------------------------------------- 82 | 83 | blurer = GaussianBlurLayer(1, 3).cuda() 84 | 85 | 86 | def supervised_training_iter( 87 | modnet, optimizer, image, trimap, gt_matte, 88 | semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0): 89 | """ Supervised training iteration of MODNet 90 | This function trains MODNet for one iteration on a labeled dataset.
91 | 92 | Arguments: 93 | modnet (torch.nn.Module): instance of MODNet 94 | optimizer (torch.optim.Optimizer): optimizer for supervised training 95 | image (torch.autograd.Variable): input RGB image 96 | its pixel values should be normalized 97 | trimap (torch.autograd.Variable): trimap used to calculate the losses 98 | its pixel values can be 0, 0.5, or 1 99 | (foreground=1, background=0, unknown=0.5) 100 | gt_matte (torch.autograd.Variable): ground truth alpha matte 101 | its pixel values are between [0, 1] 102 | semantic_scale (float): scale of the semantic loss 103 | NOTE: please adjust according to your dataset 104 | detail_scale (float): scale of the detail loss 105 | NOTE: please adjust according to your dataset 106 | matte_scale (float): scale of the matte loss 107 | NOTE: please adjust according to your dataset 108 | 109 | Returns: 110 | semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch] 111 | detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch] 112 | matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch] 113 | 114 | Example: 115 | import torch 116 | from src.models.modnet import MODNet 117 | from src.trainer import supervised_training_iter 118 | 119 | bs = 16 # batch size 120 | lr = 0.01 # learning rate 121 | epochs = 40 # total epochs 122 | 123 | modnet = torch.nn.DataParallel(MODNet()).cuda() 124 | optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9) 125 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1) 126 | 127 | dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function 128 | 129 | for epoch in range(0, epochs): 130 | for idx, (image, trimap, gt_matte) in enumerate(dataloader): 131 | semantic_loss, detail_loss, matte_loss = \ 132 | supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) 133 | lr_scheduler.step() 134 | """ 135 | 136 | global blurer 137 | 138 | # set the model to train mode and clear the optimizer 139 | modnet.train() 140 | optimizer.zero_grad() 141 | 142 | # forward the model 143 | pred_semantic, pred_detail, pred_matte = modnet(image, False) 144 | 145 | # calculate the boundary mask from the trimap 146 | boundaries = (trimap < 0.5) + (trimap > 0.5) 147 | 148 | # calculate the semantic loss 149 | gt_semantic = F.interpolate(gt_matte, scale_factor=1/16, mode='bilinear') 150 | gt_semantic = blurer(gt_semantic) 151 | semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic)) 152 | semantic_loss = semantic_scale * semantic_loss 153 | 154 | # calculate the detail loss 155 | pred_boundary_detail = torch.where(boundaries, trimap, pred_detail) 156 | gt_detail = torch.where(boundaries, trimap, gt_matte) 157 | detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail)) 158 | detail_loss = detail_scale * detail_loss 159 | 160 | # calculate the matte loss 161 | pred_boundary_matte = torch.where(boundaries, trimap, pred_matte) 162 | matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte) 163 | matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \ 164 | + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte) 165 | matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss) 166 | matte_loss = matte_scale * matte_loss 167 | 168 | # calculate the final loss, backward the loss, and update the model 169 | loss = semantic_loss + detail_loss + matte_loss 170 | loss.backward() 171 |
optimizer.step() 172 | 173 | # return the individual losses for logging and testing 174 | return semantic_loss, detail_loss, matte_loss 175 | 176 | 177 | def soc_adaptation_iter( 178 | modnet, backup_modnet, optimizer, image, 179 | soc_semantic_scale=100.0, soc_detail_scale=1.0): 180 | """ Self-Supervised sub-objective consistency (SOC) adaptation iteration of MODNet 181 | This function fine-tunes MODNet for one iteration on an unlabeled dataset. 182 | Note that SOC can only fine-tune a converged MODNet, i.e., a MODNet that has already been 183 | trained on a labeled dataset. 184 | 185 | Arguments: 186 | modnet (torch.nn.Module): instance of MODNet 187 | backup_modnet (torch.nn.Module): backup of the trained MODNet 188 | optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC 189 | image (torch.autograd.Variable): input RGB image 190 | its pixel values should be normalized 191 | soc_semantic_scale (float): scale of the SOC semantic loss 192 | NOTE: please adjust according to your dataset 193 | soc_detail_scale (float): scale of the SOC detail loss 194 | NOTE: please adjust according to your dataset 195 | 196 | Returns: 197 | soc_semantic_loss (torch.Tensor): loss of the semantic SOC 198 | soc_detail_loss (torch.Tensor): loss of the detail SOC 199 | 200 | Example: 201 | import copy 202 | import torch 203 | from src.models.modnet import MODNet 204 | from src.trainer import soc_adaptation_iter 205 | 206 | bs = 1 # batch size 207 | lr = 0.00001 # learning rate 208 | epochs = 10 # total epochs 209 | 210 | modnet = torch.nn.DataParallel(MODNet()).cuda() 211 | modnet = LOAD_TRAINED_CKPT() # NOTE: please finish this function 212 | 213 | optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99)) 214 | dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function 215 | 216 | for epoch in range(0, epochs): 217 | backup_modnet = copy.deepcopy(modnet) 218 | for idx, (image) in enumerate(dataloader): 219 | soc_semantic_loss, soc_detail_loss = \ 220 | soc_adaptation_iter(modnet, backup_modnet, optimizer, image) 221 | """ 222 | 223 | global blurer 224 | 225 | # set the backup model to eval mode 226 | backup_modnet.eval() 227 | 228 | # set the main model to train mode and freeze its norm layers 229 | modnet.train() 230 | modnet.module.freeze_norm() 231 | 232 | # clear the optimizer 233 | optimizer.zero_grad() 234 | 235 | # forward the main model 236 | pred_semantic, pred_detail, pred_matte = modnet(image, False) 237 | 238 | # forward the backup model 239 | with torch.no_grad(): 240 | _, pred_backup_detail, pred_backup_matte = backup_modnet(image, False) 241 | 242 | # calculate the boundary mask from `pred_matte` and `pred_semantic` 243 | pred_matte_fg = (pred_matte.detach() > 0.1).float() 244 | pred_semantic_fg = (pred_semantic.detach() > 0.1).float() 245 | pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear') 246 | pred_fg = pred_matte_fg * pred_semantic_fg 247 | 248 | n, c, h, w = pred_matte.shape 249 | np_pred_fg = pred_fg.data.cpu().numpy() 250 | np_boundaries = np.zeros([n, c, h, w]) 251 | for sdx in range(0, n): 252 | sample_np_boundaries = np_boundaries[sdx, 0, ...] 253 | sample_np_pred_fg = np_pred_fg[sdx, 0, ...] 254 | 255 | side = int((h + w) / 2 * 0.05) 256 | dilated = grey_dilation(sample_np_pred_fg, size=(side, side)) 257 | eroded = grey_erosion(sample_np_pred_fg, size=(side, side)) 258 | 259 | sample_np_boundaries[np.where(dilated - eroded != 0)] = 1 260 | np_boundaries[sdx, 0, ...]
= sample_np_boundaries 261 | 262 | boundaries = torch.tensor(np_boundaries).float().cuda() 263 | 264 | # sub-objectives consistency between `pred_semantic` and `pred_matte` 265 | # generate pseudo ground truth for `pred_semantic` 266 | downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1/16, mode='bilinear')) 267 | pseudo_gt_semantic = downsampled_pred_matte.detach() 268 | pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float() 269 | 270 | # generate pseudo ground truth for `pred_matte` 271 | pseudo_gt_matte = pred_semantic.detach() 272 | pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float() 273 | 274 | # calculate the SOC semantic loss 275 | soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte, pseudo_gt_matte) 276 | soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss) 277 | 278 | # NOTE: using the formulas in our paper to calculate the following losses gives similar results 279 | # sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only) 280 | backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail) 281 | backup_detail_loss = torch.sum(backup_detail_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3)) 282 | backup_detail_loss = torch.mean(backup_detail_loss) 283 | 284 | # sub-objectives consistency between `pred_matte` and `pred_backup_matte` (on boundaries only) 285 | backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte) 286 | backup_matte_loss = torch.sum(backup_matte_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3)) 287 | backup_matte_loss = torch.mean(backup_matte_loss) 288 | 289 | soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss) 290 | 291 | # calculate the final loss, backward the loss, and update the model 292 | loss = soc_semantic_loss + soc_detail_loss 293 | 294 | loss.backward() 295 | optimizer.step() 296 | 297 | return soc_semantic_loss, soc_detail_loss 298 | 299 | # ---------------------------------------------------------------------------------- 300 | -------------------------------------------------------------------------------- /streamlit_app.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import streamlit as st 3 | from PIL import Image, ImageOps 4 | import time 5 | import os 6 | from io import BytesIO 7 | import uuid 8 | import inference_onnx 9 | import shutil 10 | 11 | #shutil.rmtree('input') 12 | #os.mkdir('input') 13 | 14 | def convert_bytes_to_image(img__name,img_bytes): 15 | # convert the uploaded bytes result into a byte stream 16 | bytes_stream = img_bytes 17 | # BytesIO(img_bytes) 18 | # read it as an image 19 | if bytes_stream is None: 20 | pass 21 | else: 22 | # roiimg = Image.open(bytes_stream) 23 | roiimg = bytes_stream 24 | img_path = os.path.join('./input', img__name + ".jpg") 25 | imgByteArr = BytesIO() # initialize an empty byte stream 26 | roiimg.save(imgByteArr, format='PNG') # save our image into the empty byte stream in 'PNG' format 27 | imgByteArr = imgByteArr.getvalue() # ignore the stream position and grab the full contents, turning the IO stream into bytes 28 | with open(img_path,'wb') as f: 29 | f.write(imgByteArr) 30 | return img_path 31 | 32 | 33 | st.write(""" 34 | # One click human matting 🍄 35 | """ 36 | ) 37 | st.markdown('[![Github](http://jaywcjlove.github.io/sb/github/green-star.svg)](https://github.com/LiteraturePro/MODNet)') 38 | st.write("This is a simple portrait background removal web application, implemented using the MODNet algorithm.") 39 | imagetest = Image.open('./src/image.png') 40 | st.image(imagetest) 41 | file
= st.file_uploader("", type=["jpg","jpeg", "png", "webp"]) 42 | 43 | 44 | if file is None: 45 | st.text("You haven't uploaded an image file.🎓") 46 | else: 47 | imaget = Image.open(file) 48 | # st.write(image) 49 | img_name = str(uuid.uuid4()) 50 | input_img_path = convert_bytes_to_image(img_name,imaget) 51 | output_img_path = os.path.join('./output', img_name + ".png") 52 | image_save = os.path.join('./image_save', img_name + ".png") 53 | 54 | st.image(imaget,caption='Original Image', use_column_width=True) 55 | # images = Image.open("C:\\Users\\literature\\Desktop\\测试图像\\imgs.png") 56 | # st.image(images, caption='Sunrise by the mountains', use_column_width=True) 57 | 58 | if st.button("Start"): 59 | # If the user uploads an image 60 | if imaget is not None: 61 | # Opening our image 62 | # image = Image.open(images) 63 | st.text("Please wait...") 64 | my_bar = st.progress(0) 65 | inference_onnx.main(input_img_path,output_img_path) 66 | image = Image.open(input_img_path) 67 | matte = Image.open(output_img_path) 68 | image.putalpha(matte) 69 | image.save(image_save) 70 | 71 | for percent_complete in range(100): 72 | time.sleep(0.1) 73 | my_bar.progress(percent_complete + 1) 74 | st.success('Done! 🚀') 75 | st.balloons() 76 | if os.path.exists(input_img_path): # if the file exists, delete it 77 | os.remove(input_img_path) 78 | else: 79 | st.error('no such file') # otherwise report that the file does not exist 80 | 81 | st.image(image_save, caption='Processed Image', use_column_width=True) 82 | st.subheader('Tips: Right click to save the picture!') 83 | if os.path.exists(output_img_path): # if the file exists, delete it 84 | os.remove(output_img_path) 85 | else: 86 | st.error('no such file') # otherwise report that the file does not exist 87 | if os.path.exists(image_save): # if the file exists, delete it 88 | os.remove(image_save) 89 | else: 90 | st.error('no such file') # otherwise report that the file does not exist 91 | else: 92 | st.slider("Can you please upload an image 🙇🏽‍♂️") 93 | --------------------------------------------------------------------------------
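Usage note: the MODNet class in src/models/modnet.py above can be run directly in PyTorch for matting. The sketch below is a minimal example and not part of the repository; the checkpoint path, the 512x512 input size, and the normalization constants are illustrative assumptions rather than values taken from this code.

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from src.models.modnet import MODNet

# backbone_pretrained=False skips load_pretrained_ckpt(), which would otherwise
# look for a backbone checkpoint on disk
modnet = MODNet(backbone_pretrained=False)
state_dict = torch.load('path/to/modnet_checkpoint.ckpt', map_location='cpu')  # hypothetical path
modnet.load_state_dict(state_dict)  # keys carry a 'module.' prefix if the checkpoint was saved from nn.DataParallel
modnet.eval()

to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # assumed normalization
])
img = to_tensor(Image.open('input/photo.jpg').convert('RGB')).unsqueeze(0)
# spatial sizes divisible by 32 keep the low- and high-resolution branches aligned
img = F.interpolate(img, size=(512, 512), mode='area')

with torch.no_grad():
    pred_matte = modnet(img)  # forward() above returns only the matte, shape (1, 1, 512, 512), values in [0, 1]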
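The GaussianBlurLayer in src/trainer.py blurs each channel of an {N, C, H, W} tensor independently. A minimal sanity check, assuming a CUDA machine (importing src.trainer also instantiates the module-level blurer with .cuda()); the tensor shape is arbitrary:

import torch
from src.trainer import GaussianBlurLayer  # note: this import builds the module-level blurer on the GPU

blur = GaussianBlurLayer(channels=1, kernel_size=3)  # kernel_size must be odd (asserted in __init__)
x = torch.rand(2, 1, 64, 64)                         # {N, C, H, W}
y = blur(x)                                          # blurred per channel; same shape as the input
print(y.shape)                                       # torch.Size([2, 1, 64, 64])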
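streamlit_app.py above is started the usual Streamlit way (streamlit run streamlit_app.py). It delegates matting to inference_onnx.main(input_path, output_path) and then composites the returned matte as an alpha channel. A minimal sketch of the same ONNX path outside the UI, with illustrative file names:

from PIL import Image
import inference_onnx

# run the ONNX matting model; signature taken from the call in streamlit_app.py, paths are illustrative
inference_onnx.main('input/photo.jpg', 'output/matte.png')

# composite the matte as an alpha channel, as the app does
image = Image.open('input/photo.jpg')
matte = Image.open('output/matte.png')
image.putalpha(matte)
image.save('output/cutout.png')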