├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── bgRemover ├── __init__.py ├── asgi.py ├── settings.py ├── urls.py └── wsgi.py ├── libs ├── __init__.py ├── basnet.py ├── networks.py ├── postprocessing.py ├── preprocessing.py ├── strings.py └── u2net.py ├── manage.py ├── passenger_wsgi.py ├── removerML ├── __init__.py ├── admin.py ├── apps.py ├── migrations │ └── __init__.py ├── models.py ├── remover.py ├── static │ └── css │ │ └── styles.css ├── templates │ └── removerML │ │ └── index.html ├── tests.py ├── urls.py └── views.py ├── requirements.txt ├── setup.sh ├── setup └── download.py └── uploads └── .gitignore /.gitignore: -------------------------------------------------------------------------------- 1 | ### Django ### 2 | *.log 3 | *.pot 4 | *.pyc 5 | __pycache__/ 6 | local_settings.py 7 | db.sqlite3 8 | db.sqlite3-journal 9 | media 10 | models 11 | venv 12 | uploads 13 | 14 | # If your build process includes running collectstatic, then you probably don't need or want to include staticfiles/ 15 | # in your Git repository. Update and uncomment the following line accordingly. 16 | # /staticfiles/ 17 | 18 | ### Django.Python Stack ### 19 | # Byte-compiled / optimized / DLL files 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | 71 | # Translations 72 | *.mo 73 | 74 | # Django stuff: 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 
54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at farjaalahmad@gmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # django_bgRemoverML 2 | 3 | A machine learning project integrated with Django that removes the background from images.
4 | 5 | ## Installation: 6 | 7 | - git clone https://github.com/FarjaalAhmad/django_bgRemoverML 8 | - cd django_bgRemoverML 9 | - python3 -m pip install -r requirements.txt 10 | - bash setup.sh 11 | - python3 manage.py migrate 12 | - python3 manage.py runserver 13 | 14 | ### Supported OS: 15 | 16 | - Linux 17 | 18 | ### API Usage: 19 | 20 | Make a POST request to http://localhost:8000/upload with the following parameter (see the example request below). 21 | image=[BASE64 ENCODED IMAGE HERE] 22 | 23 | ### Bugs: 24 | 25 | - If you find any bugs, feel free to open an issue. 26 | 27 | ### Deployment: 28 | [![DigitalOcean Referral Badge](https://web-platforms.sfo2.digitaloceanspaces.com/WWW/Badge%203.svg)](https://www.digitalocean.com/?refcode=42d61c4435ff&utm_campaign=Referral_Invite&utm_medium=Referral_Program&utm_source=badge) 29 | 30 | You can register here and get $100 of free credit for 2 months. 31 | 32 | ### Contribution: 33 | 34 | - If you want to contribute to this project, feel free to open a pull request. 35 | 36 | 37 | 38 |
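#### Example request (sketch)

A minimal Python sketch of the API call described above, assuming a local development server on port 8000. The response handling is an assumption, since the response format is not documented in this README; adjust it to whatever the endpoint actually returns.

```python
# Hypothetical client for the /upload endpoint described above.
# Assumes the server responds with the processed image bytes.
import base64

import requests

# Base64-encode the input image, as the API expects.
with open("input.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("ascii")

resp = requests.post("http://localhost:8000/upload", data={"image": encoded})
resp.raise_for_status()

# Save the returned image (assumed to be the response body).
with open("no_background.png", "wb") as f:
    f.write(resp.content)
```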
39 | 40 | ### Sorry guys, I'm not able to give any time to this project, so I'm archiving it. If anybody wants to continue developing or contributing, you can mail me to become a collaborator, or fork it and continue on your own. (Don't forget to send me your repo link; I'll mention it in my README.md so that people can follow your new one.) 41 | -------------------------------------------------------------------------------- /bgRemover/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FarjaalAhmad/django_bgRemoverML/787737269eb7724481d667d30c6f502812759037/bgRemover/__init__.py -------------------------------------------------------------------------------- /bgRemover/asgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ASGI config for bgRemover project. 3 | 4 | It exposes the ASGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.asgi import get_asgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'bgRemover.settings') 15 | 16 | application = get_asgi_application() 17 | -------------------------------------------------------------------------------- /bgRemover/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for bgRemover project. 3 | 4 | Generated by 'django-admin startproject' using Django 3.0.1. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/3.0/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/3.0/ref/settings/ 11 | """ 12 | 13 | import os 14 | 15 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 16 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/3.0/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = '(e4!6(njxfbhxn*k6u%z5il!cc=(e#)4$aaca^b4n8hiuedu7b' 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production!
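# A hedged sketch, not part of the original project: in production you would
# typically pull these values from the environment rather than hard-coding them,
# for example:
#   SECRET_KEY = os.environ['DJANGO_SECRET_KEY']
#   DEBUG = os.environ.get('DJANGO_DEBUG', '') == '1'
# and list your real domain(s) in ALLOWED_HOSTS below.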
26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = [] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | 'removerML.apps.RemovermlConfig', 35 | 'django.contrib.admin', 36 | 'django.contrib.auth', 37 | 'django.contrib.contenttypes', 38 | 'django.contrib.sessions', 39 | 'django.contrib.messages', 40 | 'django.contrib.staticfiles', 41 | ] 42 | 43 | MIDDLEWARE = [ 44 | 'django.middleware.security.SecurityMiddleware', 45 | 'django.contrib.sessions.middleware.SessionMiddleware', 46 | 'django.middleware.common.CommonMiddleware', 47 | # 'django.middleware.csrf.CsrfViewMiddleware', 48 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 49 | 'django.contrib.messages.middleware.MessageMiddleware', 50 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 51 | ] 52 | 53 | ROOT_URLCONF = 'bgRemover.urls' 54 | 55 | TEMPLATES = [ 56 | { 57 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 58 | 'DIRS': [], 59 | 'APP_DIRS': True, 60 | 'OPTIONS': { 61 | 'context_processors': [ 62 | 'django.template.context_processors.debug', 63 | 'django.template.context_processors.request', 64 | 'django.contrib.auth.context_processors.auth', 65 | 'django.contrib.messages.context_processors.messages', 66 | ], 67 | }, 68 | }, 69 | ] 70 | 71 | WSGI_APPLICATION = 'bgRemover.wsgi.application' 72 | 73 | 74 | # Database 75 | # https://docs.djangoproject.com/en/3.0/ref/settings/#databases 76 | 77 | DATABASES = { 78 | 'default': { 79 | 'ENGINE': 'django.db.backends.sqlite3', 80 | 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), 81 | } 82 | } 83 | 84 | 85 | # Password validation 86 | # https://docs.djangoproject.com/en/3.0/ref/settings/#auth-password-validators 87 | 88 | AUTH_PASSWORD_VALIDATORS = [ 89 | { 90 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 91 | }, 92 | { 93 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 94 | }, 95 | { 96 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 97 | }, 98 | { 99 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 100 | }, 101 | ] 102 | 103 | 104 | # Internationalization 105 | # https://docs.djangoproject.com/en/3.0/topics/i18n/ 106 | 107 | LANGUAGE_CODE = 'en-us' 108 | 109 | TIME_ZONE = 'UTC' 110 | 111 | USE_I18N = True 112 | 113 | USE_L10N = True 114 | 115 | USE_TZ = True 116 | 117 | 118 | # Static files (CSS, JavaScript, Images) 119 | # https://docs.djangoproject.com/en/3.0/howto/static-files/ 120 | 121 | STATIC_URL = '/static/' 122 | 123 | MEDIA_URL = '/uploads/' 124 | MEDIA_ROOT = os.path.join(BASE_DIR, 'uploads') 125 | -------------------------------------------------------------------------------- /bgRemover/urls.py: -------------------------------------------------------------------------------- 1 | """bgRemover URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/3.0/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.urls import include, path 14 | 2. 
Add a URL to urlpatterns: path('blog/', include('blog.urls')) 15 | """ 16 | from django.contrib import admin 17 | from django.urls import path, include 18 | 19 | urlpatterns = [ 20 | # path('admin/', admin.site.urls), 21 | path("", include("removerML.urls")) 22 | ] 23 | -------------------------------------------------------------------------------- /bgRemover/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for bgRemover project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/3.0/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'bgRemover.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FarjaalAhmad/django_bgRemoverML/787737269eb7724481d667d30c6f502812759037/libs/__init__.py -------------------------------------------------------------------------------- /libs/basnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | 6 | ## code from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 7 | # __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | # 'resnet152', 'ResNet34P','ResNet50S','ResNet50P','ResNet101P'] 9 | # 10 | # resnet18_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet18-5c106cde.pth' 11 | # resnet34_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet34-333f7ec4.pth' 12 | # resnet50_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet50-19c8e357.pth' 13 | # resnet101_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet101-5d3b4d8f.pth' 14 | # 15 | # model_urls = { 16 | # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | # } 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | "3x3 convolution with padding" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class BasicBlockDe(nn.Module): 62 
| expansion = 1 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(BasicBlockDe, self).__init__() 66 | 67 | self.convRes = conv3x3(inplanes, planes, stride) 68 | self.bnRes = nn.BatchNorm2d(planes) 69 | self.reluRes = nn.ReLU(inplace=True) 70 | 71 | self.conv1 = conv3x3(inplanes, planes, stride) 72 | self.bn1 = nn.BatchNorm2d(planes) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.conv2 = conv3x3(planes, planes) 75 | self.bn2 = nn.BatchNorm2d(planes) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = self.convRes(x) 81 | residual = self.bnRes(residual) 82 | residual = self.reluRes(residual) 83 | 84 | out = self.conv1(x) 85 | out = self.bn1(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv2(out) 89 | out = self.bn2(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class Bottleneck(nn.Module): 101 | expansion = 4 102 | 103 | def __init__(self, inplanes, planes, stride=1, downsample=None): 104 | super(Bottleneck, self).__init__() 105 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 106 | self.bn1 = nn.BatchNorm2d(planes) 107 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 108 | padding=1, bias=False) 109 | self.bn2 = nn.BatchNorm2d(planes) 110 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 111 | self.bn3 = nn.BatchNorm2d(planes * 4) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.downsample = downsample 114 | self.stride = stride 115 | 116 | def forward(self, x): 117 | residual = x 118 | 119 | out = self.conv1(x) 120 | out = self.bn1(out) 121 | out = self.relu(out) 122 | 123 | out = self.conv2(out) 124 | out = self.bn2(out) 125 | out = self.relu(out) 126 | 127 | out = self.conv3(out) 128 | out = self.bn3(out) 129 | 130 | if self.downsample is not None: 131 | residual = self.downsample(x) 132 | 133 | out += residual 134 | out = self.relu(out) 135 | 136 | return out 137 | 138 | 139 | class RefUnet(nn.Module): 140 | def __init__(self, in_ch, inc_ch): 141 | super(RefUnet, self).__init__() 142 | 143 | self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1) 144 | 145 | self.conv1 = nn.Conv2d(inc_ch, 64, 3, padding=1) 146 | self.bn1 = nn.BatchNorm2d(64) 147 | self.relu1 = nn.ReLU(inplace=True) 148 | 149 | self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True) 150 | 151 | self.conv2 = nn.Conv2d(64, 64, 3, padding=1) 152 | self.bn2 = nn.BatchNorm2d(64) 153 | self.relu2 = nn.ReLU(inplace=True) 154 | 155 | self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True) 156 | 157 | self.conv3 = nn.Conv2d(64, 64, 3, padding=1) 158 | self.bn3 = nn.BatchNorm2d(64) 159 | self.relu3 = nn.ReLU(inplace=True) 160 | 161 | self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True) 162 | 163 | self.conv4 = nn.Conv2d(64, 64, 3, padding=1) 164 | self.bn4 = nn.BatchNorm2d(64) 165 | self.relu4 = nn.ReLU(inplace=True) 166 | 167 | self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True) 168 | 169 | 170 | self.conv5 = nn.Conv2d(64, 64, 3, padding=1) 171 | self.bn5 = nn.BatchNorm2d(64) 172 | self.relu5 = nn.ReLU(inplace=True) 173 | 174 | 175 | self.conv_d4 = nn.Conv2d(128, 64, 3, padding=1) 176 | self.bn_d4 = nn.BatchNorm2d(64) 177 | self.relu_d4 = nn.ReLU(inplace=True) 178 | 179 | self.conv_d3 = nn.Conv2d(128, 64, 3, padding=1) 180 | self.bn_d3 = nn.BatchNorm2d(64) 181 | self.relu_d3 = nn.ReLU(inplace=True) 182 | 183 | self.conv_d2 = nn.Conv2d(128, 64, 3, padding=1) 184 | 
self.bn_d2 = nn.BatchNorm2d(64) 185 | self.relu_d2 = nn.ReLU(inplace=True) 186 | 187 | self.conv_d1 = nn.Conv2d(128, 64, 3, padding=1) 188 | self.bn_d1 = nn.BatchNorm2d(64) 189 | self.relu_d1 = nn.ReLU(inplace=True) 190 | 191 | self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1) 192 | 193 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 194 | 195 | def forward(self, x): 196 | hx = x 197 | hx = self.conv0(hx) 198 | 199 | hx1 = self.relu1(self.bn1(self.conv1(hx))) 200 | hx = self.pool1(hx1) 201 | 202 | hx2 = self.relu2(self.bn2(self.conv2(hx))) 203 | hx = self.pool2(hx2) 204 | 205 | hx3 = self.relu3(self.bn3(self.conv3(hx))) 206 | hx = self.pool3(hx3) 207 | 208 | hx4 = self.relu4(self.bn4(self.conv4(hx))) 209 | hx = self.pool4(hx4) 210 | 211 | hx5 = self.relu5(self.bn5(self.conv5(hx))) 212 | 213 | hx = self.upscore2(hx5) 214 | 215 | d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx, hx4), 1)))) 216 | hx = self.upscore2(d4) 217 | 218 | d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx, hx3), 1)))) 219 | hx = self.upscore2(d3) 220 | 221 | d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx, hx2), 1)))) 222 | hx = self.upscore2(d2) 223 | 224 | d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx, hx1), 1)))) 225 | 226 | residual = self.conv_d0(d1) 227 | 228 | return x + residual 229 | 230 | 231 | class BASNet(nn.Module): 232 | def __init__(self, n_channels, n_classes): 233 | super(BASNet, self).__init__() 234 | 235 | resnet = models.resnet34(pretrained=True) 236 | 237 | # -------------Encoder-------------- 238 | 239 | self.inconv = nn.Conv2d(n_channels, 64, 3, padding=1) 240 | self.inbn = nn.BatchNorm2d(64) 241 | self.inrelu = nn.ReLU(inplace=True) 242 | 243 | # stage 1 244 | self.encoder1 = resnet.layer1 # 224 245 | # stage 2 246 | self.encoder2 = resnet.layer2 # 112 247 | # stage 3 248 | self.encoder3 = resnet.layer3 # 56 249 | # stage 4 250 | self.encoder4 = resnet.layer4 # 28 251 | 252 | self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True) 253 | 254 | # stage 5 255 | self.resb5_1 = BasicBlock(512, 512) 256 | self.resb5_2 = BasicBlock(512, 512) 257 | self.resb5_3 = BasicBlock(512, 512) # 14 258 | 259 | self.pool5 = nn.MaxPool2d(2, 2, ceil_mode=True) 260 | 261 | # stage 6 262 | self.resb6_1 = BasicBlock(512, 512) 263 | self.resb6_2 = BasicBlock(512, 512) 264 | self.resb6_3 = BasicBlock(512, 512) # 7 265 | 266 | # -------------Bridge-------------- 267 | 268 | # stage Bridge 269 | self.convbg_1 = nn.Conv2d(512, 512, 3, dilation=2, padding=2) # 7 270 | self.bnbg_1 = nn.BatchNorm2d(512) 271 | self.relubg_1 = nn.ReLU(inplace=True) 272 | self.convbg_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2) 273 | self.bnbg_m = nn.BatchNorm2d(512) 274 | self.relubg_m = nn.ReLU(inplace=True) 275 | self.convbg_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2) 276 | self.bnbg_2 = nn.BatchNorm2d(512) 277 | self.relubg_2 = nn.ReLU(inplace=True) 278 | 279 | # -------------Decoder-------------- 280 | 281 | # stage 6d 282 | self.conv6d_1 = nn.Conv2d(1024, 512, 3, padding=1) # 16 283 | self.bn6d_1 = nn.BatchNorm2d(512) 284 | self.relu6d_1 = nn.ReLU(inplace=True) 285 | 286 | self.conv6d_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2) ### 287 | self.bn6d_m = nn.BatchNorm2d(512) 288 | self.relu6d_m = nn.ReLU(inplace=True) 289 | 290 | self.conv6d_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2) 291 | self.bn6d_2 = nn.BatchNorm2d(512) 292 | self.relu6d_2 = nn.ReLU(inplace=True) 293 | 294 | # stage 5d 295 | self.conv5d_1 = nn.Conv2d(1024, 512, 3, padding=1) # 16 296 | 
self.bn5d_1 = nn.BatchNorm2d(512) 297 | self.relu5d_1 = nn.ReLU(inplace=True) 298 | 299 | self.conv5d_m = nn.Conv2d(512, 512, 3, padding=1) ### 300 | self.bn5d_m = nn.BatchNorm2d(512) 301 | self.relu5d_m = nn.ReLU(inplace=True) 302 | 303 | self.conv5d_2 = nn.Conv2d(512, 512, 3, padding=1) 304 | self.bn5d_2 = nn.BatchNorm2d(512) 305 | self.relu5d_2 = nn.ReLU(inplace=True) 306 | 307 | # stage 4d 308 | self.conv4d_1 = nn.Conv2d(1024, 512, 3, padding=1) # 32 309 | self.bn4d_1 = nn.BatchNorm2d(512) 310 | self.relu4d_1 = nn.ReLU(inplace=True) 311 | 312 | self.conv4d_m = nn.Conv2d(512, 512, 3, padding=1) ### 313 | self.bn4d_m = nn.BatchNorm2d(512) 314 | self.relu4d_m = nn.ReLU(inplace=True) 315 | 316 | self.conv4d_2 = nn.Conv2d(512, 256, 3, padding=1) 317 | self.bn4d_2 = nn.BatchNorm2d(256) 318 | self.relu4d_2 = nn.ReLU(inplace=True) 319 | 320 | # stage 3d 321 | self.conv3d_1 = nn.Conv2d(512, 256, 3, padding=1) # 64 322 | self.bn3d_1 = nn.BatchNorm2d(256) 323 | self.relu3d_1 = nn.ReLU(inplace=True) 324 | 325 | self.conv3d_m = nn.Conv2d(256, 256, 3, padding=1) ### 326 | self.bn3d_m = nn.BatchNorm2d(256) 327 | self.relu3d_m = nn.ReLU(inplace=True) 328 | 329 | self.conv3d_2 = nn.Conv2d(256, 128, 3, padding=1) 330 | self.bn3d_2 = nn.BatchNorm2d(128) 331 | self.relu3d_2 = nn.ReLU(inplace=True) 332 | 333 | # stage 2d 334 | 335 | self.conv2d_1 = nn.Conv2d(256, 128, 3, padding=1) # 128 336 | self.bn2d_1 = nn.BatchNorm2d(128) 337 | self.relu2d_1 = nn.ReLU(inplace=True) 338 | 339 | self.conv2d_m = nn.Conv2d(128, 128, 3, padding=1) ### 340 | self.bn2d_m = nn.BatchNorm2d(128) 341 | self.relu2d_m = nn.ReLU(inplace=True) 342 | 343 | self.conv2d_2 = nn.Conv2d(128, 64, 3, padding=1) 344 | self.bn2d_2 = nn.BatchNorm2d(64) 345 | self.relu2d_2 = nn.ReLU(inplace=True) 346 | 347 | # stage 1d 348 | self.conv1d_1 = nn.Conv2d(128, 64, 3, padding=1) # 256 349 | self.bn1d_1 = nn.BatchNorm2d(64) 350 | self.relu1d_1 = nn.ReLU(inplace=True) 351 | 352 | self.conv1d_m = nn.Conv2d(64, 64, 3, padding=1) ### 353 | self.bn1d_m = nn.BatchNorm2d(64) 354 | self.relu1d_m = nn.ReLU(inplace=True) 355 | 356 | self.conv1d_2 = nn.Conv2d(64, 64, 3, padding=1) 357 | self.bn1d_2 = nn.BatchNorm2d(64) 358 | self.relu1d_2 = nn.ReLU(inplace=True) 359 | 360 | # -------------Bilinear Upsampling-------------- 361 | self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=False) ### 362 | self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False) 363 | self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False) 364 | self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False) 365 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 366 | 367 | # -------------Side Output-------------- 368 | self.outconvb = nn.Conv2d(512, 1, 3, padding=1) 369 | self.outconv6 = nn.Conv2d(512, 1, 3, padding=1) 370 | self.outconv5 = nn.Conv2d(512, 1, 3, padding=1) 371 | self.outconv4 = nn.Conv2d(256, 1, 3, padding=1) 372 | self.outconv3 = nn.Conv2d(128, 1, 3, padding=1) 373 | self.outconv2 = nn.Conv2d(64, 1, 3, padding=1) 374 | self.outconv1 = nn.Conv2d(64, 1, 3, padding=1) 375 | 376 | # -------------Refine Module------------- 377 | self.refunet = RefUnet(1, 64) 378 | 379 | def forward(self, x): 380 | hx = x 381 | 382 | # -------------Encoder------------- 383 | hx = self.inconv(hx) 384 | hx = self.inbn(hx) 385 | hx = self.inrelu(hx) 386 | 387 | h1 = self.encoder1(hx) # 256 388 | h2 = self.encoder2(h1) # 128 389 | h3 = self.encoder3(h2) # 64 390 | h4 = 
self.encoder4(h3) # 32 391 | 392 | hx = self.pool4(h4) # 16 393 | 394 | hx = self.resb5_1(hx) 395 | hx = self.resb5_2(hx) 396 | h5 = self.resb5_3(hx) 397 | 398 | hx = self.pool5(h5) # 8 399 | 400 | hx = self.resb6_1(hx) 401 | hx = self.resb6_2(hx) 402 | h6 = self.resb6_3(hx) 403 | 404 | #-------------Bridge------------- 405 | hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6))) # 8 406 | hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx))) 407 | hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx))) 408 | 409 | # -------------Decoder------------- 410 | 411 | hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg, h6), 1)))) 412 | hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx))) 413 | hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx))) 414 | 415 | hx = self.upscore2(hd6) # 8 -> 16 416 | 417 | hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx, h5), 1)))) 418 | hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx))) 419 | hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx))) 420 | 421 | hx = self.upscore2(hd5) # 16 -> 32 422 | 423 | hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx, h4), 1)))) 424 | hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx))) 425 | hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx))) 426 | 427 | hx = self.upscore2(hd4) # 32 -> 64 428 | 429 | hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx, h3), 1)))) 430 | hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx))) 431 | hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx))) 432 | 433 | hx = self.upscore2(hd3) # 64 -> 128 434 | 435 | hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx, h2), 1)))) 436 | hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx))) 437 | hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx))) 438 | 439 | hx = self.upscore2(hd2) # 128 -> 256 440 | 441 | hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx, h1), 1)))) 442 | hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx))) 443 | hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx))) 444 | 445 | # -------------Side Output------------- 446 | db = self.outconvb(hbg) 447 | db = self.upscore6(db) # 8->256 448 | 449 | d6 = self.outconv6(hd6) 450 | d6 = self.upscore6(d6) # 8->256 451 | 452 | d5 = self.outconv5(hd5) 453 | d5 = self.upscore5(d5) # 16->256 454 | 455 | d4 = self.outconv4(hd4) 456 | d4 = self.upscore4(d4) # 32->256 457 | 458 | d3 = self.outconv3(hd3) 459 | d3 = self.upscore3(d3) # 64->256 460 | 461 | d2 = self.outconv2(hd2) 462 | d2 = self.upscore2(d2) # 128->256 463 | 464 | d1 = self.outconv1(hd1) # 256 465 | 466 | # -------------Refine Module------------- 467 | dout = self.refunet(d1) # 256 468 | 469 | return torch.sigmoid(dout), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid( 470 | d4), torch.sigmoid(d5), torch.sigmoid( 471 | d6), torch.sigmoid(db) 472 | -------------------------------------------------------------------------------- /libs/networks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from skimage import io, transform 8 | 9 | from libs import strings 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def model_detect(model_name): 15 | """Detects which model to use and returns its object""" 16 | models_names = strings.MODELS_NAMES 17 | if model_name in models_names: 18 | if model_name == "xception_model" or model_name == "mobile_net_model": 19 | return TFSegmentation(model_name) 20 | elif "u2net" in model_name: 21 | return U2NET(model_name) 
22 | elif "basnet" == model_name: 23 | return BasNet(model_name) 24 | else: 25 | return False 26 | else: 27 | return False 28 | 29 | 30 | class U2NET: 31 | """U^2-Net model interface""" 32 | 33 | def __init__(self, name="u2net"): 34 | import torch 35 | from torch.autograd import Variable 36 | from libs.u2net import U2NET as U2NET_DEEP 37 | from libs.u2net import U2NETP as U2NETP_DEEP 38 | self.Variable = Variable 39 | self.torch = torch 40 | self.U2NET_DEEP = U2NET_DEEP 41 | self.U2NETP_DEEP = U2NETP_DEEP 42 | 43 | if name == 'u2net': # Load model 44 | logger.debug("Loading a U2NET model (176.6 mb) with better quality but slower processing.") 45 | net = self.U2NET_DEEP() 46 | elif name == 'u2netp': 47 | logger.debug("Loading a U2NETp model (4 mb) with lower quality but fast processing.") 48 | net = self.U2NETP_DEEP() 49 | else: 50 | raise Exception("Unknown u2net model!") 51 | try: 52 | if self.torch.cuda.is_available(): 53 | net.load_state_dict(self.torch.load(os.path.join("models", name, name + '.pth'))) 54 | net.cuda() 55 | else: 56 | net.load_state_dict(self.torch.load(os.path.join("models", name, name + '.pth'), map_location="cpu")) 57 | except FileNotFoundError: 58 | raise FileNotFoundError("No pre-trained model found! Run setup.sh or setup.bat to download it!") 59 | net.eval() 60 | self.__net__ = net # Define model 61 | 62 | def process_image(self, data, preprocessing=None, postprocessing=None): 63 | """ 64 | Removes background from image and returns PIL RGBA Image. 65 | :param data: Path to image or PIL image 66 | :param preprocessing: Image Pre-Processing Algorithm Class (optional) 67 | :param postprocessing: Image Post-Processing Algorithm Class (optional) 68 | :return: PIL RGBA Image. If an error reading the image is detected, returns False. 69 | """ 70 | if isinstance(data, str): 71 | logger.debug("Load image: {}".format(data)) 72 | image, org_image = self.__load_image__(data) # Load image 73 | if image is False or org_image is False: 74 | return False 75 | if preprocessing: # If an algorithm that preprocesses is specified, 76 | # then this algorithm should immediately remove the background 77 | image = preprocessing.run(self, image, org_image) 78 | else: 79 | image = self.__get_output__(image, org_image) # If this is not, then just remove the background 80 | if postprocessing: # If a postprocessing algorithm is specified, we send it an image without a background 81 | image = postprocessing.run(self, image, org_image) 82 | return image 83 | 84 | def __get_output__(self, image, org_image): 85 | """ 86 | Returns output from a neural network 87 | :param image: Prepared Image 88 | :param org_image: Original pil image 89 | :return: Image without background 90 | """ 91 | start_time = time.time() # Time counter 92 | image = image.type(self.torch.FloatTensor) 93 | if self.torch.cuda.is_available(): 94 | image = self.Variable(image.cuda()) 95 | else: 96 | image = self.Variable(image) 97 | mask, d2, d3, d4, d5, d6, d7 = self.__net__(image) # Predict mask 98 | logger.debug("Mask prediction completed") 99 | # Normalization 100 | logger.debug("Mask normalization") 101 | mask = mask[:, 0, :, :] 102 | mask = self.__normalize__(mask) 103 | # Prepare mask 104 | logger.debug("Prepare mask") 105 | mask = self.__prepare_mask__(mask, org_image.size) 106 | # Apply mask to image 107 | logger.debug("Apply mask to image") 108 | empty = Image.new("RGBA", org_image.size) 109 | image = Image.composite(org_image, empty, mask) 110 | logger.debug("Finished! 
Time spent: {}".format(time.time() - start_time)) 111 | return image 112 | 113 | def __load_image__(self, data): 114 | """ 115 | Loads an image file for other processing 116 | :param data: Path to image file or PIL image 117 | :return: image tensor, original pil image 118 | """ 119 | image_size = 320 # Size of the input and output image for the model 120 | if isinstance(data, str): 121 | try: 122 | image = io.imread(data) # Load image if there is a path 123 | except IOError: 124 | logger.error('Cannot retrieve image. Please check file: ' + data) 125 | return False, False 126 | pil_image = Image.fromarray(image) 127 | else: 128 | image = np.array(data) # Convert PIL image to numpy arr 129 | pil_image = data 130 | image = transform.resize(image, (image_size, image_size), mode='constant') # Resize image 131 | image = self.__ndrarray2tensor__(image) # Convert image from numpy arr to tensor 132 | return image, pil_image 133 | 134 | def __ndrarray2tensor__(self, image: np.ndarray): 135 | """ 136 | Converts a NumPy array to a tensor 137 | :param image: Image numpy array 138 | :return: Image tensor 139 | """ 140 | tmp_img = np.zeros((image.shape[0], image.shape[1], 3)) 141 | image /= np.max(image) 142 | if image.shape[2] == 1: 143 | tmp_img[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 144 | tmp_img[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229 145 | tmp_img[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229 146 | else: 147 | tmp_img[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 148 | tmp_img[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224 149 | tmp_img[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225 150 | tmp_img = tmp_img.transpose((2, 0, 1)) 151 | tmp_img = np.expand_dims(tmp_img, 0) 152 | return self.torch.from_numpy(tmp_img) 153 | 154 | def __normalize__(self, predicted): 155 | """Normalize the predicted map""" 156 | ma = self.torch.max(predicted) 157 | mi = self.torch.min(predicted) 158 | out = (predicted - mi) / (ma - mi) 159 | return out 160 | 161 | @staticmethod 162 | def __prepare_mask__(predict, image_size): 163 | """Prepares mask""" 164 | predict = predict.squeeze() 165 | predict_np = predict.cpu().data.numpy() 166 | mask = Image.fromarray(predict_np * 255).convert("L") 167 | mask = mask.resize(image_size, resample=Image.BILINEAR) 168 | return mask 169 | 170 | 171 | class BasNet: 172 | """BasNet model interface""" 173 | 174 | def __init__(self, name="basnet"): 175 | import torch 176 | from torch.autograd import Variable 177 | from libs.basnet import BASNet as BASNet_DEEP 178 | 179 | self.Variable = Variable 180 | self.torch = torch 181 | self.BASNet_DEEP = BASNet_DEEP 182 | 183 | if name == 'basnet': # Load model 184 | logger.debug("Loading a BASNet model.") 185 | net = self.BASNet_DEEP(3, 1) 186 | else: 187 | raise Exception("Unknown BASNet model") 188 | try: 189 | if self.torch.cuda.is_available(): 190 | net.load_state_dict(self.torch.load(os.path.join("models", name, name + '.pth'))) 191 | net.cuda() 192 | else: 193 | net.load_state_dict(self.torch.load(os.path.join("models", name, name + '.pth'), map_location="cpu")) 194 | except FileNotFoundError: 195 | raise FileNotFoundError("No pre-trained model found! Run setup.sh or setup.bat to download it!") 196 | net.eval() 197 | self.__net__ = net # Define model 198 | 199 | def process_image(self, data, preprocessing=None, postprocessing=None): 200 | """ 201 | Removes background from image and returns PIL RGBA Image. 
202 | :param data: Path to image or PIL image 203 | :param preprocessing: Image Pre-Processing Algorithm Class (optional) 204 | :param postprocessing: Image Post-Processing Algorithm Class (optional) 205 | :return: PIL RGBA Image. If an error reading the image is detected, returns False. 206 | """ 207 | if isinstance(data, str): 208 | logger.debug("Load image: {}".format(data)) 209 | image, orig_image = self.__load_image__(data) # Load image 210 | if image is False or orig_image is False: 211 | return False 212 | if preprocessing: # If an algorithm that preprocesses is specified, 213 | # then this algorithm should immediately remove the background 214 | image = preprocessing.run(self, image, orig_image) 215 | else: 216 | image = self.__get_output__(image, orig_image) # If this is not, then just remove the background 217 | if postprocessing: # If a postprocessing algorithm is specified, we send it an image without a background 218 | image = postprocessing.run(self, image, orig_image) 219 | return image 220 | 221 | def __get_output__(self, image, org_image): 222 | """ 223 | Returns output from a neural network 224 | :param image: Prepared Image 225 | :param org_image: Original pil image 226 | :return: Image without background 227 | """ 228 | start_time = time.time() # Time counter 229 | image = image.type(self.torch.FloatTensor) 230 | if self.torch.cuda.is_available(): 231 | image = self.Variable(image.cuda()) 232 | else: 233 | image = self.Variable(image) 234 | mask, d2, d3, d4, d5, d6, d7, d8 = self.__net__(image) # Predict mask 235 | logger.debug("Mask prediction completed") 236 | # Normalization 237 | logger.debug("Mask normalization") 238 | mask = mask[:, 0, :, :] 239 | mask = self.__normalize__(mask) 240 | # Prepare mask 241 | logger.debug("Prepare mask") 242 | mask = self.__prepare_mask__(mask, org_image.size) 243 | # Apply mask to image 244 | logger.debug("Apply mask to image") 245 | empty = Image.new("RGBA", org_image.size) 246 | image = Image.composite(org_image, empty, mask) 247 | logger.debug("Finished! Time spent: {}".format(time.time() - start_time)) 248 | return image 249 | 250 | def __load_image__(self, data): 251 | """ 252 | Loads an image file for other processing 253 | :param data: Path to image file or PIL image 254 | :return: image tensor, Original Pil Image 255 | """ 256 | image_size = 256 # Size of the input and output image for the model 257 | if isinstance(data, str): 258 | try: 259 | image = io.imread(data) # Load image if there is a path 260 | except IOError: 261 | logger.error('Cannot retrieve image. 
Please check file: ' + data) 262 | return False, False 263 | pil_image = Image.fromarray(image) 264 | else: 265 | image = np.array(data) # Convert PIL image to numpy arr 266 | pil_image = data 267 | image = transform.resize(image, (image_size, image_size), mode='constant') # Resize image 268 | image = self.__ndrarray2tensor__(image) # Convert image from numpy arr to tensor 269 | return image, pil_image 270 | 271 | def __ndrarray2tensor__(self, image: np.ndarray): 272 | """ 273 | Converts a NumPy array to a tensor 274 | :param image: Image numpy array 275 | :return: Image tensor 276 | """ 277 | tmp_img = np.zeros((image.shape[0], image.shape[1], 3)) 278 | image /= np.max(image) 279 | if image.shape[2] == 1: 280 | tmp_img[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 281 | tmp_img[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229 282 | tmp_img[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229 283 | else: 284 | tmp_img[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 285 | tmp_img[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224 286 | tmp_img[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225 287 | tmp_img = tmp_img.transpose((2, 0, 1)) 288 | tmp_img = np.expand_dims(tmp_img, 0) 289 | return self.torch.from_numpy(tmp_img) 290 | 291 | def __normalize__(self, predicted): 292 | """Normalize the predicted map""" 293 | ma = self.torch.max(predicted) 294 | mi = self.torch.min(predicted) 295 | out = (predicted - mi) / (ma - mi) 296 | return out 297 | 298 | @staticmethod 299 | def __prepare_mask__(predict, image_size): 300 | """Prepares mask""" 301 | predict = predict.squeeze() 302 | predict_np = predict.cpu().data.numpy() 303 | mask = Image.fromarray(predict_np * 255).convert("L") 304 | mask = mask.resize(image_size, resample=Image.BILINEAR) 305 | return mask 306 | 307 | 308 | class TFSegmentation(object): 309 | """Class to load Deeplabv3 model and run inference.""" 310 | def __init__(self, model_type): 311 | """Creates and loads pretrained deeplab model.""" 312 | import scipy.ndimage as ndi 313 | import tensorflow as tf 314 | self.tf = tf 315 | self.ndi = ndi 316 | 317 | # Environment init 318 | self.INPUT_TENSOR_NAME = 'ImageTensor:0' 319 | self.OUTPUT_TENSOR_NAME = 'SemanticPredictions:0' 320 | self.INPUT_SIZE = 513 321 | self.FROZEN_GRAPH_NAME = 'frozen_inference_graph' 322 | # Start load process 323 | self.graph = self.tf.Graph() 324 | try: 325 | graph_def = self.tf.compat.v1.GraphDef.FromString(open(os.path.join("models", model_type, "model", 326 | "frozen_inference_graph.pb"), 327 | "rb").read()) 328 | except FileNotFoundError: 329 | raise FileNotFoundError("No pre-trained model found! Run setup.sh or setup.bat to download it!") 330 | logger.warning("Loading a DeepLab model ({})! " 331 | "This is an outdated model with poorer image quality and processing time." 332 | "Better use the U2NET model instead of this one!".format(model_type)) 333 | if graph_def is None: 334 | raise RuntimeError('Cannot find inference graph in tar archive.') 335 | with self.graph.as_default(): 336 | self.tf.import_graph_def(graph_def, name='') 337 | self.sess = self.tf.compat.v1.Session(graph=self.graph) 338 | 339 | @staticmethod 340 | def __load_image__(data): 341 | """ 342 | Loads an image file for other processing 343 | :param data: Path to image file or PIL image 344 | :return: Pil Image, Pil Image 345 | """ 346 | if isinstance(data, str): 347 | try: 348 | image = Image.open(data) # Load image if there is a path 349 | except IOError: 350 | logger.error('Cannot retrieve image. 
Please check file: ' + data) 351 | return False 352 | else: 353 | image = data 354 | return image, image 355 | 356 | def process_image(self, data, preprocessing=None, postprocessing=None): 357 | """ 358 | Removes background from image and returns PIL RGBA Image. 359 | :param data: Path to image or PIL image 360 | :param preprocessing: Image Pre-Processing Algorithm Class (optional) 361 | :param postprocessing: Image Post-Processing Algorithm Class (optional) 362 | :return: PIL RGBA Image. If an error reading the image is detected, returns False. 363 | """ 364 | if isinstance(data, str): 365 | logger.debug("Load image: {}".format(data)) 366 | image, org_image = self.__load_image__(data) # Load image 367 | if image is False or org_image is False: 368 | return False 369 | if preprocessing: # If an algorithm that preprocesses is specified, 370 | # then this algorithm should immediately remove the background 371 | image = preprocessing.run(self, image, org_image) 372 | else: 373 | image = self.__get_output__(image, org_image) # If this is not, then just remove the background 374 | if postprocessing: # If a postprocessing algorithm is specified, we send it an image without a background 375 | image = postprocessing.run(self, image, org_image) 376 | return image 377 | 378 | def __get_output__(self, image, _=None): 379 | """ 380 | Returns output from a neural network 381 | :param image: Prepared Image 382 | :param _: Not used argument for compatibility with pre-processing module 383 | :return: Image without background 384 | """ 385 | start_time = time.time() # Time counter 386 | seg_map = self.__predict__(image) 387 | logger.debug('Finished mask creation') 388 | image = image.convert('RGB') 389 | logger.debug("Mask overlay completed") 390 | image = self.__draw_segment__(image, seg_map) 391 | logger.debug("Finished! Time spent: {}".format(time.time() - start_time)) 392 | return image 393 | 394 | def __predict__(self, image): 395 | """Image processing.""" 396 | # Get image size 397 | width, height = image.size 398 | # Calculate scale value 399 | resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height) 400 | # Calculate future image size 401 | target_size = (int(resize_ratio * width), int(resize_ratio * height)) 402 | # Resize image 403 | resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS) 404 | # Send image to model 405 | batch_seg_map = self.sess.run( 406 | self.OUTPUT_TENSOR_NAME, 407 | feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]}) 408 | # Get model output 409 | seg_map = batch_seg_map[0] 410 | # Get new image size and original image size 411 | width, height = resized_image.size 412 | width2, height2 = image.size 413 | # Calculate scale 414 | scale_w = width2 / width 415 | scale_h = height2 / height 416 | # Zoom numpy array for original image 417 | seg_map = self.ndi.zoom(seg_map, (scale_h, scale_w)) 418 | return seg_map 419 | 420 | @staticmethod 421 | def __draw_segment__(image, alpha_channel): 422 | """Postprocessing. 
423 |         # Get image size
424 |         width, height = image.size
425 |         # Create an empty numpy array
426 |         dummy_img = np.zeros([height, width, 4], dtype=np.uint8)
427 |         # Create alpha layer from model output
428 |         for x in range(width):
429 |             for y in range(height):
430 |                 color = alpha_channel[y, x]
431 |                 (r, g, b) = image.getpixel((x, y))
432 |                 if color == 0:
433 |                     dummy_img[y, x, 3] = 0  # Background pixel: fully transparent
434 |                 else:
435 |                     dummy_img[y, x] = [r, g, b, 255]  # Object pixel: fully opaque
436 |         # Restore image object from numpy array
437 |         img = Image.fromarray(dummy_img)
438 |         return img
439 | 
--------------------------------------------------------------------------------
/libs/postprocessing.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from PIL import Image
3 | from libs.strings import POSTPROCESS_METHODS
4 | 
5 | logger = logging.getLogger(__name__)
6 | 
7 | 
8 | def method_detect(method: str):
9 |     """Detects which method to use and returns its object"""
10 |     if method in POSTPROCESS_METHODS:
11 |         if method == "rtb-bnb":
12 |             return RemovingTooTransparentBordersHardAndBlurringHardBorders()
13 |         elif method == "rtb-bnb2":
14 |             return RemovingTooTransparentBordersHardAndBlurringHardBordersTwo()
15 |         else:
16 |             return None
17 |     else:
18 |         return False
19 | 
20 | 
21 | class RemovingTooTransparentBordersHardAndBlurringHardBordersTwo:
22 |     """
23 |     This is the class for the image post-processing algorithm.
24 |     It improves the boundaries of the image obtained from the neural network
25 |     by removing too-transparent pixels and smoothing the hard borders that remain,
26 |     and it runs this procedure twice, sending the image through the network again in between.
27 |     """
28 | 
29 |     def __init__(self):
30 |         import cv2
31 |         import skimage
32 |         import numpy as np
33 |         self.cv2 = cv2
34 |         self.skimage = skimage
35 |         self.np = np
36 | 
37 |         self.model = None
38 |         self.prep_image = None
39 |         self.orig_image = None
40 | 
41 |     @staticmethod
42 |     def __extact_alpha_channel__(image):
43 |         """
44 |         Extracts alpha channel from RGBA image
45 |         :param image: RGBA pil image
46 |         :return: RGB Pil image
47 |         """
48 |         # Extract just the alpha channel
49 |         alpha = image.split()[-1]
50 |         # Create a new image with an opaque black background
51 |         bg = Image.new("RGBA", image.size, (0, 0, 0, 255))
52 |         # Copy the alpha channel to the new image using itself as the mask
53 |         bg.paste(alpha, mask=alpha)
54 |         return bg.convert("RGB")
55 | 
56 |     def __blur_edges__(self, imaged):
57 |         """
58 |         Blurs the edges of the image
59 |         :param imaged: RGBA Pil image
60 |         :return: RGBA PIL image
61 |         """
62 |         image = self.np.array(imaged)
63 |         image = self.cv2.cvtColor(image, self.cv2.COLOR_RGBA2BGRA)
64 |         # extract alpha channel
65 |         a = image[:, :, 3]
66 |         # blur alpha channel
67 |         ab = self.cv2.GaussianBlur(a, (0, 0), sigmaX=2, sigmaY=2, borderType=self.cv2.BORDER_DEFAULT)
68 |         # stretch the alpha so that 255 stays 255 and values at or below 140 become 0
69 |         aa = self.skimage.exposure.rescale_intensity(ab, in_range=(140, 255), out_range=(0, 255))
70 |         # replace alpha channel in input with new alpha channel
71 |         out = image.copy()
72 |         out[:, :, 3] = aa
73 |         image = self.cv2.cvtColor(out, self.cv2.COLOR_BGRA2RGBA)
74 |         return Image.fromarray(image)
75 | 
76 |     def __remove_too_transparent_borders__(self, mask, tranp_val=31):
77 |         """
78 |         Marks all pixels in the mask with a transparency greater than tranp_val as opaque,
79 |         and pixels with transparency less than or equal to tranp_val as fully transparent.
80 |         :param tranp_val: Integer transparency threshold (0-255).
81 |         :return: Processed mask
82 |         """
83 |         mask = self.np.array(mask.convert("L"))
84 |         height, weight = mask.shape
85 |         for h in range(height):
86 |             for w in range(weight):
87 |                 val = mask[h, w]
88 |                 if val > tranp_val:
89 |                     mask[h, w] = 255
90 |                 else:
91 |                     mask[h, w] = 0
92 |         return Image.fromarray(mask)
93 | 
94 |     def run(self, model, image, orig_image):
95 |         """
96 |         Runs an image post-processing algorithm to improve background removal quality.
97 |         :param model: The class of the neural network used to remove the background.
98 |         :param image: Image without background
99 |         :param orig_image: Source image
100 |         """
101 |         mask = self.__remove_too_transparent_borders__(self.__extact_alpha_channel__(image))
102 |         empty = Image.new("RGBA", orig_image.size)
103 |         image = Image.composite(orig_image, empty, mask)
104 |         image = self.__blur_edges__(image)
105 | 
106 |         image = model.process_image(image)  # Second pass through the neural network
107 | 
108 |         mask = self.__remove_too_transparent_borders__(self.__extact_alpha_channel__(image))
109 |         empty = Image.new("RGBA", orig_image.size)
110 |         image = Image.composite(orig_image, empty, mask)
111 |         image = self.__blur_edges__(image)
112 |         return image
113 | 
114 | 
115 | class RemovingTooTransparentBordersHardAndBlurringHardBorders:
116 |     """
117 |     This is the class for the image post-processing algorithm.
118 |     This algorithm improves the boundaries of the image obtained from the neural network.
119 |     It is based on the principle of removing too transparent pixels
120 |     and smoothing the borders after removing too transparent pixels.
121 |     Unlike the "rtb-bnb2" variant above, the procedure runs only once:
122 |     the image from the neural network is cleaned and blurred
123 |     and returned to the user directly, without a second pass through the network.
124 |     This method gives a good result in combination with u2net without any preprocessing methods.
125 |     """
126 | 
127 |     def __init__(self):
128 |         import cv2
129 |         import skimage
130 |         import numpy as np
131 |         self.cv2 = cv2
132 |         self.skimage = skimage
133 |         self.np = np
134 | 
135 |         self.model = None
136 |         self.prep_image = None
137 |         self.orig_image = None
138 | 
139 |     @staticmethod
140 |     def __extact_alpha_channel__(image):
141 |         """
142 |         Extracts alpha channel from RGBA image
143 |         :param image: RGBA pil image
144 |         :return: RGB Pil image
145 |         """
146 |         # Extract just the alpha channel
147 |         alpha = image.split()[-1]
148 |         # Create a new image with an opaque black background
149 |         bg = Image.new("RGBA", image.size, (0, 0, 0, 255))
150 |         # Copy the alpha channel to the new image using itself as the mask
151 |         bg.paste(alpha, mask=alpha)
152 |         return bg.convert("RGB")
153 | 
154 |     def __blur_edges__(self, imaged):
155 |         """
156 |         Blurs the edges of the image
157 |         :param imaged: RGBA Pil image
158 |         :return: RGBA PIL image
159 |         """
160 |         image = self.np.array(imaged)
161 |         image = self.cv2.cvtColor(image, self.cv2.COLOR_RGBA2BGRA)
162 |         # extract alpha channel
163 |         a = image[:, :, 3]
164 |         # blur alpha channel
165 |         ab = self.cv2.GaussianBlur(a, (0, 0), sigmaX=2, sigmaY=2, borderType=self.cv2.BORDER_DEFAULT)
166 |         # stretch the alpha so that 255 stays 255 and values at or below 140 become 0
167 |         # noinspection PyUnresolvedReferences
168 |         aa = self.skimage.exposure.rescale_intensity(ab, in_range=(140, 255), out_range=(0, 255))
169 |         # replace alpha channel in input with new alpha channel
170 |         out = image.copy()
171 |         out[:, :, 3] = aa
172 |         image = self.cv2.cvtColor(out, self.cv2.COLOR_BGRA2RGBA)
173 |         return Image.fromarray(image)
174 | 
175 |     def __remove_too_transparent_borders__(self, mask, tranp_val=31):
176 |         """
177 |         Marks all pixels in the mask with a transparency greater than tranp_val as opaque,
178 |         and pixels with transparency less than or equal to tranp_val as fully transparent.
179 |         :param tranp_val: Integer transparency threshold (0-255).
180 |         :return: Processed mask
181 |         """
182 |         mask = self.np.array(mask.convert("L"))
183 |         height, weight = mask.shape
184 |         for h in range(height):
185 |             for w in range(weight):
186 |                 val = mask[h, w]
187 |                 if val > tranp_val:
188 |                     mask[h, w] = 255
189 |                 else:
190 |                     mask[h, w] = 0
191 |         return Image.fromarray(mask)
192 | 
193 |     def run(self, _, image, orig_image):
194 |         """
195 |         Runs an image post-processing algorithm to improve background removal quality.
196 |         :param _: The class of the neural network used to remove the background (unused here).
197 |         :param image: Image without background
198 |         :param orig_image: Source image
199 |         """
200 |         mask = self.__remove_too_transparent_borders__(self.__extact_alpha_channel__(image))
201 |         empty = Image.new("RGBA", orig_image.size)
202 |         image = Image.composite(orig_image, empty, mask)
203 |         image = self.__blur_edges__(image)
204 |         return image
205 | 
--------------------------------------------------------------------------------
/libs/preprocessing.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | 
4 | import numpy as np
5 | from PIL import Image
6 | 
7 | from libs.strings import PREPROCESS_METHODS
8 | 
9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | def method_detect(method: str):
13 |     """Detects which method to use and returns its object"""
14 |     if method in PREPROCESS_METHODS:
15 |         if method == "bbmd-maskrcnn":
16 |             return BoundingBoxDetectionWithMaskMaskRcnn()
17 |         elif method == "bbd-fastrcnn":
18 |             return BoundingBoxDetectionFastRcnn()
19 |         else:
20 |             return None
21 |     else:
22 |         return False
23 | 
24 | 
25 | class BoundingBoxDetectionFastRcnn:
26 |     """
27 |     Class for the image preprocessing method.
28 |     This image pre-processing technique uses two neural networks (the selected
29 |     background-removal model and Faster R-CNN) to first detect the boundaries of objects
30 |     in a photograph, cut them out, remove the background from each object in turn,
31 |     and then reassemble the entire image from the separate parts.
32 |     """
33 | 
34 |     def __init__(self):
35 |         self.__fast_rcnn__ = FastRcnn()
36 |         self.model = None
37 |         self.prep_image = None
38 |         self.orig_image = None
39 | 
40 |     @staticmethod
41 |     def trans_paste(bg_img, fg_img, box=(0, 0)):
42 |         """
43 |         Inserts an image into another image while maintaining transparency.
44 |         :param bg_img: Background pil image
45 |         :param fg_img: Foreground pil image
46 |         :param box: Bounding box
47 |         :return: Pil Image
48 |         """
49 |         fg_img_trans = Image.new("RGBA", bg_img.size)
50 |         fg_img_trans.paste(fg_img, box, mask=fg_img)
51 |         new_img = Image.alpha_composite(bg_img, fg_img_trans)
52 |         return new_img
53 | 
54 |     @staticmethod
55 |     def __orig_object_border__(border, orig_image, resized_image, indent=16):
56 |         """
57 |         Rescales the bounding box of an object
58 |         :param indent: The boundary of the object will expand by this value.
59 |         :param border: array consisting of the coordinates of the boundaries of the object
60 |         :param orig_image: original pil image
61 |         :param resized_image: resized image ndarray
62 |         :return: tuple consisting of the coordinates of the boundaries of the object
63 |         """
64 |         x_factor = resized_image.shape[1] / orig_image.size[0]
65 |         y_factor = resized_image.shape[0] / orig_image.size[1]
66 |         xmin, ymin, xmax, ymax = [int(x) for x in border]
67 |         if ymin < 0:
68 |             ymin = 0
69 |         if ymax > resized_image.shape[0]:
70 |             ymax = resized_image.shape[0]
71 |         if xmax > resized_image.shape[1]:
72 |             xmax = resized_image.shape[1]
73 |         if xmin < 0:
74 |             xmin = 0
75 |         if x_factor == 0:
76 |             x_factor = 1
77 |         if y_factor == 0:
78 |             y_factor = 1
79 |         border = (int(xmin / x_factor) - indent,
80 |                   int(ymin / y_factor) - indent, int(xmax / x_factor) + indent, int(ymax / y_factor) + indent)
81 |         return border
82 | 
83 |     def run(self, model, prep_image, orig_image):
84 |         """
85 |         Runs an image preprocessing algorithm to improve background removal quality.
86 |         :param model: The class of the neural network used to remove the background.
87 | :param prep_image: Prepared for the neural network image 88 | :param orig_image: Source image 89 | :returns: Image without background 90 | """ 91 | _, resized_image, results = self.__fast_rcnn__.process_image(orig_image) 92 | 93 | classes = self.__fast_rcnn__.class_names 94 | bboxes = results['bboxes'] 95 | ids = results['ids'] 96 | scores = results['scores'] 97 | 98 | object_num = len(bboxes) # We get the number of all objects in the photo 99 | 100 | if object_num < 1: # If there are no objects, or they are not found, 101 | # we try to remove the background using standard tools 102 | return model.__get_output__(prep_image, orig_image) 103 | else: 104 | # Check that all arrays match each other in size 105 | if ids is not None and not len(bboxes) == len(ids): 106 | return model.__get_output__(prep_image, 107 | orig_image) # we try to remove the background using standard tools 108 | if scores is not None and not len(bboxes) == len(scores): 109 | return model.__get_output__(prep_image, orig_image) 110 | # we try to remove the background using standard tools 111 | objects = [] 112 | for i, bbox in enumerate(bboxes): 113 | if scores is not None and scores.flat[i] < 0.5: 114 | continue 115 | if ids is not None and ids.flat[i] < 0: 116 | continue 117 | object_cls_id = int(ids.flat[i]) if ids is not None else -1 118 | if classes is not None and object_cls_id < len(classes): 119 | object_label = classes[object_cls_id] 120 | else: 121 | object_label = str(object_cls_id) if object_cls_id >= 0 else '' 122 | object_border = self.__orig_object_border__(bbox, orig_image, resized_image) 123 | objects.append([object_label, object_border]) 124 | if objects: 125 | if len(objects) == 1: 126 | return model.__get_output__(prep_image, orig_image) 127 | # we try to remove the background using standard tools 128 | else: 129 | obj_images = [] 130 | for obj in objects: 131 | border = obj[1] 132 | obj_crop = orig_image.crop(border) 133 | # TODO: make a special algorithm to improve the removal of background from images with people. 134 | if obj[0] == "person": 135 | obj_img = model.process_image(obj_crop) 136 | else: 137 | obj_img = model.process_image(obj_crop) 138 | obj_images.append([obj_img, obj]) 139 | image = Image.new("RGBA", orig_image.size) 140 | for obj in obj_images: 141 | image = self.trans_paste(image, obj[0], obj[1][1]) 142 | return image 143 | else: 144 | return model.__get_output__(prep_image, orig_image) 145 | 146 | 147 | class BoundingBoxDetectionWithMaskMaskRcnn: 148 | """ 149 | Class for the image preprocessing method. 150 | This image pre-processing technique uses two neural networks 151 | to first detect the boundaries and masks of objects in a photograph, 152 | cut them out, expand the masks by a certain number of pixels, 153 | apply them and remove the background from each object in turn 154 | and subsequently collect the entire image from separate parts 155 | """ 156 | 157 | def __init__(self): 158 | self.__mask_rcnn__ = MaskRcnn() 159 | self.model = None 160 | self.prep_image = None 161 | self.orig_image = None 162 | 163 | @staticmethod 164 | def __mask_extend__(mask, indent=10): 165 | """ 166 | Extends the mask of an object. 167 | :param mask: 8-bit ndarray mask 168 | :param indent: Indent on which to expand the mask 169 | :return: extended 8-bit mask ndarray 170 | """ 171 | # TODO: Rewrite this function. 
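        # One possible rewrite (an untested sketch, not the shipped code):
        # horizontal binary dilation with OpenCV grows the mask by `indent`
        # pixels to the left and right in a single vectorized call, e.g.:
        #     kernel = np.ones((1, 2 * indent + 1), dtype=np.uint8)
        #     mask = cv2.dilate(mask.astype(np.uint8), kernel)
        # (assumes `import cv2`; the row-by-row loop below only widens the
        # left/right edges of each mask run instead)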
172 | height, weight = mask.shape 173 | old_val = 0 174 | for h in range(height): 175 | for w in range(weight): 176 | val = mask[h, w] 177 | if val == 1 and old_val == 0: 178 | for i in range(1, indent + 1): 179 | if w - i > 0: 180 | mask[h, w - i] = 1 181 | old_val = val 182 | elif val == 0 and old_val == 1: 183 | if weight - w >= indent: 184 | for i in range(0, indent): 185 | mask[h, w + i] = 1 186 | else: 187 | for i in range(0, weight - w): 188 | mask[h, w + i] = 1 189 | old_val = val 190 | break 191 | return mask 192 | 193 | @staticmethod 194 | def trans_paste(bg_img, fg_img, box=(0, 0)): 195 | """ 196 | Inserts an image into another image while maintaining transparency. 197 | :param bg_img: Background pil image 198 | :param fg_img: Foreground pil image 199 | :param box: Bounding box 200 | :return: Pil Image 201 | """ 202 | fg_img_trans = Image.new("RGBA", bg_img.size) 203 | fg_img_trans.paste(fg_img, box, mask=fg_img) 204 | new_img = Image.alpha_composite(bg_img, fg_img_trans) 205 | return new_img 206 | 207 | @staticmethod 208 | def __orig_object_border__(border, orig_image, resized_image, indent=16): 209 | """ 210 | Rescales the bounding box of an object 211 | :param indent: The boundary of the object will expand by this value. 212 | :param border: array consisting of the coordinates of the boundaries of the object 213 | :param orig_image: original pil image 214 | :param resized_image: resized image ndarray 215 | :return: tuple consisting of the coordinates of the boundaries of the object 216 | """ 217 | x_factor = resized_image.shape[1] / orig_image.size[0] 218 | y_factor = resized_image.shape[0] / orig_image.size[1] 219 | xmin, ymin, xmax, ymax = [int(x) for x in border] 220 | if ymin < 0: 221 | ymin = 0 222 | if ymax > resized_image.shape[0]: 223 | ymax = resized_image.shape[0] 224 | if xmax > resized_image.shape[1]: 225 | xmax = resized_image.shape[1] 226 | if xmin < 0: 227 | xmin = 0 228 | if x_factor == 0: 229 | x_factor = 1 230 | if y_factor == 0: 231 | y_factor = 1 232 | border = (int(xmin / x_factor) - indent, 233 | int(ymin / y_factor) - indent, 234 | int(xmax / x_factor) + indent, 235 | int(ymax / y_factor) + indent) 236 | return border 237 | 238 | @staticmethod 239 | def __apply_mask__(image, mask): 240 | """ 241 | Applies a mask to an image. 242 | :param image: Pil image 243 | :param mask: 8 bit Mask ndarray 244 | :return: Pil Image 245 | """ 246 | image = np.array(image) 247 | image[:, :, 0] = np.where( 248 | mask == 0, 249 | 255, 250 | image[:, :, 0] 251 | ) 252 | image[:, :, 1] = np.where( 253 | mask == 0, 254 | 255, 255 | image[:, :, 1] 256 | ) 257 | image[:, :, 2] = np.where( 258 | mask == 0, 259 | 255, 260 | image[:, :, 2] 261 | ) 262 | return Image.fromarray(image) 263 | 264 | def run(self, model, prep_image, orig_image): 265 | """ 266 | Runs an image preprocessing algorithm to improve background removal quality. 267 | :param model: The class of the neural network used to remove the background. 
268 | :param prep_image: Prepared for the neural network image 269 | :param orig_image: Source image 270 | :return: Image without background 271 | """ 272 | _, resized_image, results = self.__mask_rcnn__.process_image(orig_image) 273 | 274 | classes = self.__mask_rcnn__.class_names 275 | bboxes = results['bboxes'] 276 | masks = results['masks'] 277 | ids = results['ids'] 278 | scores = results['scores'] 279 | 280 | object_num = len(bboxes) # We get the number of all objects in the photo 281 | 282 | if object_num < 1: # If there are no objects, or they are not found, 283 | # we try to remove the background using standard tools 284 | return model.__get_output__(prep_image, orig_image) 285 | else: 286 | # Check that all arrays match each other in size 287 | if ids is not None and not len(bboxes) == len(ids): 288 | return model.__get_output__(prep_image, 289 | orig_image) # we try to remove the background using standard tools 290 | if scores is not None and not len(bboxes) == len(scores): 291 | return model.__get_output__(prep_image, orig_image) 292 | # we try to remove the background using standard tools 293 | objects = [] 294 | for i, bbox in enumerate(bboxes): 295 | if scores is not None and scores.flat[i] < 0.5: 296 | continue 297 | if ids is not None and ids.flat[i] < 0: 298 | continue 299 | object_cls_id = int(ids.flat[i]) if ids is not None else -1 300 | if classes is not None and object_cls_id < len(classes): 301 | object_label = classes[object_cls_id] 302 | else: 303 | object_label = str(object_cls_id) if object_cls_id >= 0 else '' 304 | object_border = self.__orig_object_border__(bbox, orig_image, resized_image) 305 | object_mask = masks[i, :, :] 306 | objects.append([object_label, object_border, object_mask]) 307 | if objects: 308 | if len(objects) == 1: 309 | return model.__get_output__(prep_image, orig_image) 310 | # we try to remove the background using standard tools 311 | else: 312 | obj_images = [] 313 | for obj in objects: 314 | extended_mask = self.__mask_extend__(obj[2]) 315 | obj_masked = self.__apply_mask__(orig_image, extended_mask) 316 | 317 | border = obj[1] 318 | obj_crop_masked = obj_masked.crop(border) 319 | # TODO: make a special algorithm to improve the removal of background from images with people. 320 | if obj[0] == "person": 321 | obj_img = model.process_image(obj_crop_masked) 322 | else: 323 | obj_img = model.process_image(obj_crop_masked) 324 | obj_images.append([obj_img, obj]) 325 | image = Image.new("RGBA", orig_image.size) 326 | for obj in obj_images: 327 | image = self.trans_paste(image, obj[0], obj[1][1]) 328 | return image 329 | else: 330 | return model.__get_output__(prep_image, orig_image) 331 | 332 | 333 | class FastRcnn: 334 | """ 335 | Fast Rcnn Neural Network to detect objects in the photo. 336 | """ 337 | 338 | def __init__(self): 339 | from gluoncv import model_zoo, data 340 | from mxnet import nd 341 | self.model_zoo = model_zoo 342 | self.data = data 343 | self.nd = nd 344 | logger.debug("Loading Fast RCNN neural network") 345 | self.__net__ = self.model_zoo.get_model('faster_rcnn_resnet50_v1b_voc', 346 | pretrained=True) # Download the pre-trained model, if one is missing. 
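        # Note: gluoncv caches downloaded weights on disk (by default under
        # ~/.mxnet/models; the location can be moved with the MXNET_HOME
        # environment variable), so the download above only runs the first time.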
347 |         # noinspection PyUnresolvedReferences
348 |         self.class_names = self.__net__.classes
349 | 
350 |     def __load_image__(self, data_input):
351 |         """
352 |         Loads an image file for further processing
353 |         :param data_input: Path to image file or PIL image
354 |         :return: neural network input, original image ndarray, resized image ndarray
355 |         """
356 |         if isinstance(data_input, str):
357 |             try:
358 |                 data_input = Image.open(data_input)
359 |                 # Fix https://github.com/OPHoperHPO/image-background-remove-tool/issues/19
360 |                 data_input = data_input.convert("RGB")
361 |                 image = np.array(data_input)  # Convert PIL image to numpy arr
362 |             except IOError:
363 |                 logger.error('Cannot retrieve image. Please check file: ' + data_input)
364 |                 return False, False, False
365 |         else:
366 |             # Fix https://github.com/OPHoperHPO/image-background-remove-tool/issues/19
367 |             data_input = data_input.convert("RGB")
368 |             image = np.array(data_input)  # Convert PIL image to numpy arr
369 |         x, resized_image = self.data.transforms.presets.rcnn.transform_test(self.nd.array(image))
370 |         return x, image, resized_image
371 | 
372 |     def process_image(self, image):
373 |         """
374 |         Detects objects in the photo and returns their names and borders.
375 |         :param image: Path to image or PIL image.
376 |         :return: original image ndarray, resized image ndarray, dict(ids, scores, bboxes)
377 |         """
378 |         start_time = time.time()  # Time counter
379 |         x, image, resized_image = self.__load_image__(image)
380 |         ids, scores, bboxes = [xx[0].asnumpy() for xx in self.__net__(x)]
381 |         logger.debug("Finished! Time spent: {}".format(time.time() - start_time))
382 |         return image, resized_image, {"ids": ids, "scores": scores, "bboxes": bboxes}
383 | 
384 | 
385 | class MaskRcnn:
386 |     """
387 |     Mask R-CNN neural network to detect objects in the photo.
388 |     """
389 | 
390 |     def __init__(self):
391 |         from gluoncv import model_zoo, utils, data
392 |         from mxnet import nd
393 |         self.model_zoo = model_zoo
394 |         self.utils = utils
395 |         self.data = data
396 |         self.nd = nd
397 |         logger.debug("Loading Mask RCNN neural network")
398 |         self.__net__ = self.model_zoo.get_model('mask_rcnn_resnet50_v1b_coco',
399 |                                                 pretrained=True)  # Download the pre-trained model, if one is missing.
400 |         # noinspection PyUnresolvedReferences
401 |         self.class_names = self.__net__.classes
402 | 
403 |     def __load_image__(self, data_input):
404 |         """
405 |         Loads an image file for further processing
406 |         :param data_input: Path to image file or PIL image
407 |         :return: neural network input, original image ndarray, resized image ndarray
408 |         """
409 |         if isinstance(data_input, str):
410 |             try:
411 |                 data_input = Image.open(data_input)
412 |                 # Fix https://github.com/OPHoperHPO/image-background-remove-tool/issues/19
413 |                 data_input = data_input.convert("RGB")
414 |                 image = np.array(data_input)  # Convert PIL image to numpy arr
415 |             except IOError:
416 |                 logger.error('Cannot retrieve image. Please check file: ' + data_input)
417 |                 return False, False, False
418 |         else:
419 |             # Fix https://github.com/OPHoperHPO/image-background-remove-tool/issues/19
420 |             data_input = data_input.convert("RGB")
421 |             image = np.array(data_input)  # Convert PIL image to numpy arr
422 |         x, resized_image = self.data.transforms.presets.rcnn.transform_test(self.nd.array(image))
423 |         return x, image, resized_image
424 | 
425 |     def process_image(self, image):
426 |         """
427 |         Detects objects in the photo and returns their names, borders and a mask of poor quality.
428 |         :param image: Path to image or PIL image.
429 |         :return: original image ndarray, resized image ndarray, dict(ids, scores, bboxes, masks)
430 |         """
431 |         start_time = time.time()  # Time counter
432 |         x, image, resized_image = self.__load_image__(image)
433 |         ids, scores, bboxes, masks = [xx[0].asnumpy() for xx in self.__net__(x)]
434 |         masks, _ = self.utils.viz.expand_mask(masks, bboxes, (image.shape[1], image.shape[0]), scores)
435 |         logger.debug("Finished! Time spent: {}".format(time.time() - start_time))
436 |         return image, resized_image, {"ids": ids, "scores": scores, "bboxes": bboxes,
437 |                                       "masks": masks}
--------------------------------------------------------------------------------
/libs/strings.py:
--------------------------------------------------------------------------------
1 | NAME = "Image Background Remove Tool"
2 | MODELS_NAMES = ["u2net", "basnet", "u2netp", "xception_model", "mobile_net_model"]
3 | PREPROCESS_METHODS = ["bbd-fastrcnn", "bbmd-maskrcnn", "None"]
4 | POSTPROCESS_METHODS = ["rtb-bnb", "rtb-bnb2", "No"]
5 | DESCRIPTION = "A tool to remove the background from images using neural networks"
6 | LICENSE = "Apache License 2.0"
7 | ARGS_HELP = """
8 | {}
9 | {}
10 | License: {}
11 | Running the script:
12 | python3 main.py -i <input_path> -o <output_path> -m <model_name> -prep <preprocessing_method> -postp <postprocessing_method>
13 | Explanation of args:
14 | -i - path to input file or dir.
15 | -o - path to output file or dir.
16 | -prep - Preprocessing method. Can be {}. `bbd-fastrcnn` is recommended.
17 | -postp - Postprocessing method. Can be {}. `rtb-bnb` is recommended.
18 | -m - model to use. Can be {}. `u2net` is recommended.
19 | DeepLab models (xception_model or mobile_net_model) are outdated
20 | and designed to remove the background from PORTRAIT photos or PHOTOS WITH ANIMALS!
21 | """.format(NAME, DESCRIPTION, LICENSE,
22 |            ' or '.join(PREPROCESS_METHODS),
23 |            ' or '.join(POSTPROCESS_METHODS),
24 |            ' or '.join(MODELS_NAMES))
--------------------------------------------------------------------------------
/libs/u2net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class REBNCONV(nn.Module):
7 |     def __init__(self, in_ch=3, out_ch=3, dirate=1):
8 |         super(REBNCONV, self).__init__()
9 | 
10 |         self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
11 |         self.bn_s1 = nn.BatchNorm2d(out_ch)
12 |         self.relu_s1 = nn.ReLU(inplace=True)
13 | 
14 |     def forward(self, x):
15 |         hx = x
16 |         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17 | 
18 |         return xout
19 | 
20 | 
21 | # upsample tensor 'src' to have the same spatial size with tensor 'tar'
22 | def _upsample_like(src, tar):
23 |     src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
24 | 
25 |     return src
26 | 
27 | 
28 | # RSU-7
29 | class RSU7(nn.Module):  # UNet07DRES(nn.Module):
30 | 
31 |     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
32 |         super(RSU7, self).__init__()
33 | 
34 |         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
35 | 
36 |         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
37 |         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
38 | 
39 |         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
40 |         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
41 | 
42 |         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
43 |         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
44 | 
45 |         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
46 |         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
47 | 
48 |         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, 
dirate=1) 49 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 50 | 51 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) 52 | 53 | self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) 54 | 55 | self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 56 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 57 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 58 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 59 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 60 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 61 | 62 | def forward(self, x): 63 | hx = x 64 | hxin = self.rebnconvin(hx) 65 | 66 | hx1 = self.rebnconv1(hxin) 67 | hx = self.pool1(hx1) 68 | 69 | hx2 = self.rebnconv2(hx) 70 | hx = self.pool2(hx2) 71 | 72 | hx3 = self.rebnconv3(hx) 73 | hx = self.pool3(hx3) 74 | 75 | hx4 = self.rebnconv4(hx) 76 | hx = self.pool4(hx4) 77 | 78 | hx5 = self.rebnconv5(hx) 79 | hx = self.pool5(hx5) 80 | 81 | hx6 = self.rebnconv6(hx) 82 | 83 | hx7 = self.rebnconv7(hx6) 84 | 85 | hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) 86 | hx6dup = _upsample_like(hx6d, hx5) 87 | 88 | hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) 89 | hx5dup = _upsample_like(hx5d, hx4) 90 | 91 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 92 | hx4dup = _upsample_like(hx4d, hx3) 93 | 94 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 95 | hx3dup = _upsample_like(hx3d, hx2) 96 | 97 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 98 | hx2dup = _upsample_like(hx2d, hx1) 99 | 100 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 101 | 102 | return hx1d + hxin 103 | 104 | 105 | ### RSU-6 ### 106 | class RSU6(nn.Module): # UNet06DRES(nn.Module): 107 | 108 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 109 | super(RSU6, self).__init__() 110 | 111 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 112 | 113 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 114 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 115 | 116 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 117 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 118 | 119 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 120 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 121 | 122 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 123 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 124 | 125 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 126 | 127 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) 128 | 129 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 130 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 131 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 132 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 133 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 134 | 135 | def forward(self, x): 136 | hx = x 137 | 138 | hxin = self.rebnconvin(hx) 139 | 140 | hx1 = self.rebnconv1(hxin) 141 | hx = self.pool1(hx1) 142 | 143 | hx2 = self.rebnconv2(hx) 144 | hx = self.pool2(hx2) 145 | 146 | hx3 = self.rebnconv3(hx) 147 | hx = self.pool3(hx3) 148 | 149 | hx4 = self.rebnconv4(hx) 150 | hx = self.pool4(hx4) 151 | 152 | hx5 = self.rebnconv5(hx) 153 | 154 | hx6 = self.rebnconv6(hx5) 155 | 156 | hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) 157 | hx5dup = _upsample_like(hx5d, hx4) 158 | 159 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 160 | hx4dup = _upsample_like(hx4d, hx3) 161 | 162 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 163 | hx3dup = _upsample_like(hx3d, hx2) 
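        # Decoder path of the RSU block: each stage concatenates the upsampled
        # deeper feature map with the same-resolution encoder feature map
        # (a U-Net style skip connection) before the next REBNCONV refines it.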
164 | 165 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 166 | hx2dup = _upsample_like(hx2d, hx1) 167 | 168 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 169 | 170 | return hx1d + hxin 171 | 172 | 173 | ### RSU-5 ### 174 | class RSU5(nn.Module): # UNet05DRES(nn.Module): 175 | 176 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 177 | super(RSU5, self).__init__() 178 | 179 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 180 | 181 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 182 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 183 | 184 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 185 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 186 | 187 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 188 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 189 | 190 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 191 | 192 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) 193 | 194 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 195 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 196 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 197 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 198 | 199 | def forward(self, x): 200 | hx = x 201 | 202 | hxin = self.rebnconvin(hx) 203 | 204 | hx1 = self.rebnconv1(hxin) 205 | hx = self.pool1(hx1) 206 | 207 | hx2 = self.rebnconv2(hx) 208 | hx = self.pool2(hx2) 209 | 210 | hx3 = self.rebnconv3(hx) 211 | hx = self.pool3(hx3) 212 | 213 | hx4 = self.rebnconv4(hx) 214 | 215 | hx5 = self.rebnconv5(hx4) 216 | 217 | hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) 218 | hx4dup = _upsample_like(hx4d, hx3) 219 | 220 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 221 | hx3dup = _upsample_like(hx3d, hx2) 222 | 223 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 224 | hx2dup = _upsample_like(hx2d, hx1) 225 | 226 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 227 | 228 | return hx1d + hxin 229 | 230 | 231 | ### RSU-4 ### 232 | class RSU4(nn.Module): # UNet04DRES(nn.Module): 233 | 234 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 235 | super(RSU4, self).__init__() 236 | 237 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 238 | 239 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 240 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 241 | 242 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 243 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 244 | 245 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 246 | 247 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) 248 | 249 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 250 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 251 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 252 | 253 | def forward(self, x): 254 | hx = x 255 | 256 | hxin = self.rebnconvin(hx) 257 | 258 | hx1 = self.rebnconv1(hxin) 259 | hx = self.pool1(hx1) 260 | 261 | hx2 = self.rebnconv2(hx) 262 | hx = self.pool2(hx2) 263 | 264 | hx3 = self.rebnconv3(hx) 265 | 266 | hx4 = self.rebnconv4(hx3) 267 | 268 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 269 | hx3dup = _upsample_like(hx3d, hx2) 270 | 271 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 272 | hx2dup = _upsample_like(hx2d, hx1) 273 | 274 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 275 | 276 | return hx1d + hxin 277 | 278 | 279 | ### RSU-4F ### 280 | class RSU4F(nn.Module): # UNet04FRES(nn.Module): 281 | 282 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 283 | super(RSU4F, 
self).__init__() 284 | 285 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 286 | 287 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 288 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) 289 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) 290 | 291 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) 292 | 293 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) 294 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) 295 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 296 | 297 | def forward(self, x): 298 | hx = x 299 | 300 | hxin = self.rebnconvin(hx) 301 | 302 | hx1 = self.rebnconv1(hxin) 303 | hx2 = self.rebnconv2(hx1) 304 | hx3 = self.rebnconv3(hx2) 305 | 306 | hx4 = self.rebnconv4(hx3) 307 | 308 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 309 | hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) 310 | hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) 311 | 312 | return hx1d + hxin 313 | 314 | 315 | ##### U^2-Net #### 316 | class U2NET(nn.Module): 317 | 318 | def __init__(self, in_ch=3, out_ch=1): 319 | super(U2NET, self).__init__() 320 | 321 | self.stage1 = RSU7(in_ch, 32, 64) 322 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 323 | 324 | self.stage2 = RSU6(64, 32, 128) 325 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 326 | 327 | self.stage3 = RSU5(128, 64, 256) 328 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 329 | 330 | self.stage4 = RSU4(256, 128, 512) 331 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 332 | 333 | self.stage5 = RSU4F(512, 256, 512) 334 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 335 | 336 | self.stage6 = RSU4F(512, 256, 512) 337 | 338 | # decoder 339 | self.stage5d = RSU4F(1024, 256, 512) 340 | self.stage4d = RSU4(1024, 128, 256) 341 | self.stage3d = RSU5(512, 64, 128) 342 | self.stage2d = RSU6(256, 32, 64) 343 | self.stage1d = RSU7(128, 16, 64) 344 | 345 | self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) 346 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 347 | self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) 348 | self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) 349 | self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) 350 | self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) 351 | 352 | self.outconv = nn.Conv2d(6, out_ch, 1) 353 | 354 | def forward(self, x): 355 | hx = x 356 | 357 | # stage 1 358 | hx1 = self.stage1(hx) 359 | hx = self.pool12(hx1) 360 | 361 | # stage 2 362 | hx2 = self.stage2(hx) 363 | hx = self.pool23(hx2) 364 | 365 | # stage 3 366 | hx3 = self.stage3(hx) 367 | hx = self.pool34(hx3) 368 | 369 | # stage 4 370 | hx4 = self.stage4(hx) 371 | hx = self.pool45(hx4) 372 | 373 | # stage 5 374 | hx5 = self.stage5(hx) 375 | hx = self.pool56(hx5) 376 | 377 | # stage 6 378 | hx6 = self.stage6(hx) 379 | hx6up = _upsample_like(hx6, hx5) 380 | 381 | # -------------------- decoder -------------------- 382 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 383 | hx5dup = _upsample_like(hx5d, hx4) 384 | 385 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 386 | hx4dup = _upsample_like(hx4d, hx3) 387 | 388 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 389 | hx3dup = _upsample_like(hx3d, hx2) 390 | 391 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 392 | hx2dup = _upsample_like(hx2d, hx1) 393 | 394 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 395 | 396 | # side output 397 | d1 = self.side1(hx1d) 398 | 399 | d2 = self.side2(hx2d) 400 | d2 = _upsample_like(d2, d1) 401 | 402 | d3 = self.side3(hx3d) 403 | d3 = _upsample_like(d3, d1) 404 | 
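        # Each side output d1..d6 is a one-channel saliency map predicted at a
        # different decoder depth; every map is upsampled to d1's resolution and
        # all six are fused into d0 by self.outconv below (deep supervision: the
        # U^2-Net training scheme applies a loss to d0 and to each side output).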
405 | d4 = self.side4(hx4d) 406 | d4 = _upsample_like(d4, d1) 407 | 408 | d5 = self.side5(hx5d) 409 | d5 = _upsample_like(d5, d1) 410 | 411 | d6 = self.side6(hx6) 412 | d6 = _upsample_like(d6, d1) 413 | 414 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 415 | 416 | return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid( 417 | d4), torch.sigmoid(d5), torch.sigmoid(d6) 418 | 419 | 420 | ### U^2-Net small ### 421 | class U2NETP(nn.Module): 422 | 423 | def __init__(self, in_ch=3, out_ch=1): 424 | super(U2NETP, self).__init__() 425 | 426 | self.stage1 = RSU7(in_ch, 16, 64) 427 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 428 | 429 | self.stage2 = RSU6(64, 16, 64) 430 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 431 | 432 | self.stage3 = RSU5(64, 16, 64) 433 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 434 | 435 | self.stage4 = RSU4(64, 16, 64) 436 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 437 | 438 | self.stage5 = RSU4F(64, 16, 64) 439 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 440 | 441 | self.stage6 = RSU4F(64, 16, 64) 442 | 443 | # decoder 444 | self.stage5d = RSU4F(128, 16, 64) 445 | self.stage4d = RSU4(128, 16, 64) 446 | self.stage3d = RSU5(128, 16, 64) 447 | self.stage2d = RSU6(128, 16, 64) 448 | self.stage1d = RSU7(128, 16, 64) 449 | 450 | self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) 451 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 452 | self.side3 = nn.Conv2d(64, out_ch, 3, padding=1) 453 | self.side4 = nn.Conv2d(64, out_ch, 3, padding=1) 454 | self.side5 = nn.Conv2d(64, out_ch, 3, padding=1) 455 | self.side6 = nn.Conv2d(64, out_ch, 3, padding=1) 456 | 457 | self.outconv = nn.Conv2d(6, out_ch, 1) 458 | 459 | def forward(self, x): 460 | hx = x 461 | 462 | # stage 1 463 | hx1 = self.stage1(hx) 464 | hx = self.pool12(hx1) 465 | 466 | # stage 2 467 | hx2 = self.stage2(hx) 468 | hx = self.pool23(hx2) 469 | 470 | # stage 3 471 | hx3 = self.stage3(hx) 472 | hx = self.pool34(hx3) 473 | 474 | # stage 4 475 | hx4 = self.stage4(hx) 476 | hx = self.pool45(hx4) 477 | 478 | # stage 5 479 | hx5 = self.stage5(hx) 480 | hx = self.pool56(hx5) 481 | 482 | # stage 6 483 | hx6 = self.stage6(hx) 484 | hx6up = _upsample_like(hx6, hx5) 485 | 486 | # decoder 487 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 488 | hx5dup = _upsample_like(hx5d, hx4) 489 | 490 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 491 | hx4dup = _upsample_like(hx4d, hx3) 492 | 493 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 494 | hx3dup = _upsample_like(hx3d, hx2) 495 | 496 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 497 | hx2dup = _upsample_like(hx2d, hx1) 498 | 499 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 500 | 501 | # side output 502 | d1 = self.side1(hx1d) 503 | 504 | d2 = self.side2(hx2d) 505 | d2 = _upsample_like(d2, d1) 506 | 507 | d3 = self.side3(hx3d) 508 | d3 = _upsample_like(d3, d1) 509 | 510 | d4 = self.side4(hx4d) 511 | d4 = _upsample_like(d4, d1) 512 | 513 | d5 = self.side5(hx5d) 514 | d5 = _upsample_like(d5, d1) 515 | 516 | d6 = self.side6(hx6) 517 | d6 = _upsample_like(d6, d1) 518 | 519 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 520 | 521 | return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid( 522 | d4), torch.sigmoid(d5), torch.sigmoid(d6) 523 | -------------------------------------------------------------------------------- /manage.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | import os 4 | import sys 5 | 6 | 7 | def main(): 8 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'bgRemover.settings') 9 | try: 10 | from django.core.management import execute_from_command_line 11 | except ImportError as exc: 12 | raise ImportError( 13 | "Couldn't import Django. Are you sure it's installed and " 14 | "available on your PYTHONPATH environment variable? Did you " 15 | "forget to activate a virtual environment?" 16 | ) from exc 17 | execute_from_command_line(sys.argv) 18 | 19 | 20 | if __name__ == '__main__': 21 | main() 22 | -------------------------------------------------------------------------------- /passenger_wsgi.py: -------------------------------------------------------------------------------- 1 | import bgRemover.wsgi 2 | application = bgRemover.wsgi.application 3 | -------------------------------------------------------------------------------- /removerML/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FarjaalAhmad/django_bgRemoverML/787737269eb7724481d667d30c6f502812759037/removerML/__init__.py -------------------------------------------------------------------------------- /removerML/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | # Register your models here. 4 | -------------------------------------------------------------------------------- /removerML/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class RemovermlConfig(AppConfig): 5 | name = 'removerML' 6 | -------------------------------------------------------------------------------- /removerML/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FarjaalAhmad/django_bgRemoverML/787737269eb7724481d667d30c6f502812759037/removerML/migrations/__init__.py -------------------------------------------------------------------------------- /removerML/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | # Create your models here. 
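# This app stores uploads on disk via FileSystemStorage (see views.py) and
# defines no database models. A hypothetical model for tracking processed
# images could look like this (illustrative sketch only, not shipped code):
#
#     class ProcessedImage(models.Model):
#         source = models.FileField(upload_to="uploads/")
#         result_path = models.CharField(max_length=255)
#         created = models.DateTimeField(auto_now_add=True)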
4 | 
--------------------------------------------------------------------------------
/removerML/remover.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import gc
4 | import tqdm
5 | import logging
6 | from libs.strings import *
7 | from libs.networks import model_detect
8 | import libs.preprocessing as preprocessing
9 | import libs.postprocessing as postprocessing
10 | 
11 | logging.basicConfig(level=logging.ERROR)
12 | logger = logging.getLogger(__name__)
13 | 
14 | 
15 | def __work_mode__(path: str):
16 |     """Determines the desired mode of operation"""
17 |     if os.path.isfile(path):  # Input is file
18 |         return "file"
19 |     if os.path.isdir(path):  # Input is dir
20 |         return "dir"
21 |     else:
22 |         return "no"
23 | 
24 | 
25 | def __save_image_file__(img, file_name, output_path, wmode):
26 |     """
27 |     Saves the PIL image to a file
28 |     :param img: PIL image
29 |     :param file_name: File name
30 |     :param output_path: Output path
31 |     :param wmode: Work mode
32 |     """
33 |     # create output directory if it doesn't exist
34 |     folder = os.path.dirname(output_path)
35 |     if folder != '':
36 |         os.makedirs(folder, exist_ok=True)
37 |     if wmode == "file":
38 |         file_name_out = os.path.basename(output_path)
39 |         if file_name_out == '':
40 |             # Change file extension to png
41 |             file_name = os.path.splitext(file_name)[0] + '.png'
42 |             # Save image
43 |             img.save(os.path.join(output_path, file_name))
44 |             gc.collect()
45 |         else:
46 |             try:
47 |                 # Save image
48 |                 img.save(output_path)
49 |                 gc.collect()
50 |             except OSError as e:
51 |                 if str(e) == "cannot write mode RGBA as JPEG":
52 |                     raise OSError("Error! "
53 |                                   "Please indicate the correct extension of the final file, for example: .png")
54 |                 else:
55 |                     raise e
56 |     else:
57 |         # Change file extension to png
58 |         file_name = os.path.splitext(file_name)[0] + '.png'
59 |         # Save image
60 |         img.save(os.path.join(output_path, file_name))
61 |         gc.collect()
62 | 
63 | 
64 | def process(input_path, output_path, model_name="u2net",
65 |             preprocessing_method_name="bbd-fastrcnn", postprocessing_method_name="rtb-bnb"):
66 |     """
67 |     Processes the file.
68 |     :param input_path: The path to the image / folder with the images to be processed.
69 |     :param output_path: The path to the save location.
70 |     :param model_name: Model to use.
71 |     :param preprocessing_method_name: Method for image pre-processing
72 |     :param postprocessing_method_name: Method for image post-processing
73 |     """
74 |     if input_path is None or output_path is None:
75 |         raise Exception(
76 |             "Bad parameters! Please specify input path and output path.")
77 | 
78 |     model = model_detect(model_name)  # Load model
79 |     if not model:
80 |         logger.warning("Warning! You specified an invalid model type. "
81 |                        "The model with the best processing quality (u2net) "
82 |                        "will be used for image processing instead.")
83 |         # If the model name is wrong, fall back to the model with the best quality.
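        # model_detect returns a falsy value for unknown model names, so the
        # fallback below silently reloads u2net instead of raising an error.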
84 |         model_name = "u2net"
85 |         model = model_detect(model_name)  # Load model
86 |     preprocessing_method = preprocessing.method_detect(
87 |         preprocessing_method_name)
88 |     postprocessing_method = postprocessing.method_detect(
89 |         postprocessing_method_name)
90 |     wmode = __work_mode__(input_path)  # Get work mode
91 |     if wmode == "file":  # File work mode
92 |         image = model.process_image(
93 |             input_path, preprocessing_method, postprocessing_method)
94 |         __save_image_file__(image, os.path.basename(
95 |             input_path), output_path, wmode)
96 |     elif wmode == "dir":  # Dir work mode
97 |         # Start process
98 |         files = os.listdir(input_path)
99 |         for file in tqdm.tqdm(files, ascii=True, desc='Remove Background', unit='image'):
100 |             file_path = os.path.join(input_path, file)
101 |             image = model.process_image(
102 |                 file_path, preprocessing_method, postprocessing_method)
103 |             __save_image_file__(image, file, output_path, wmode)
104 |     else:
105 |         raise Exception(
106 |             "Bad input parameter! Please indicate the correct path to the file or folder.")
--------------------------------------------------------------------------------
/removerML/static/css/styles.css:
--------------------------------------------------------------------------------
1 | @import url("https://fonts.googleapis.com/css2?family=DM+Sans:wght@500&family=Staatliches&display=swap");
2 | body {
3 |   background-color: #ffc857;
4 | }
5 | hr {
6 |   background-color: #ffc857;
7 | }
8 | .main-container {
9 |   margin: auto;
10 |   width: 50%;
11 |   height: 80%;
12 |   padding: 10px;
13 |   margin: 20px auto;
14 |   background-color: #323031;
15 |   border-radius: 3px;
16 |   box-shadow: 6px 2px 30px 0px rgba(0, 0, 0, 0.75);
17 |   font-family: "DM Sans", sans-serif;
18 |   word-wrap: break-word;
19 |   color: #ffc857;
20 | }
21 | 
22 | input[type="file"] {
23 |   font-family: "DM Sans", sans-serif;
24 |   word-wrap: break-word;
25 | }
26 | h1 {
27 |   font-size: 70px;
28 |   font-family: "Staatliches", cursive;
29 |   color: #ffc857;
30 | }
31 | 
32 | .upload-img-div {
33 |   padding: 5%;
34 | }
--------------------------------------------------------------------------------
/removerML/templates/removerML/index.html:
--------------------------------------------------------------------------------
1 | {% load static %}
2 | <!DOCTYPE html>
3 | <html lang="en">
4 | <head>
5 |     <meta charset="UTF-8">
6 |     <meta name="viewport" content="width=device-width, initial-scale=1.0">
7 |     <link rel="stylesheet" type="text/css" href="{% static 'css/styles.css' %}">
8 |     <title>BG Remover</title>
9 | </head>
10 | <body>
11 | 
12 | 
13 | <div class="main-container">
14 |     <div>
15 |         <h1>
16 |             BACKGROUND REMOVER
17 |         </h1>
18 |     </div>
19 |     <hr>
20 |     <div class="upload-img-div">
21 |         <form method="POST" enctype="multipart/form-data">
22 |             {% csrf_token %}
23 |             <input type="file" name="image" accept=".jpg,.jpeg,.png" required>
24 |             <input type="submit" value="Remove Background">
25 |         </form>
26 |     </div>
27 |     {% if image_path %}
28 |     <hr>
29 |     <div>
30 |         Your Image :
31 |         <img src="{{ image_path }}" alt="image here">
32 |     </div>
33 |     {% endif %}
34 | </div>
35 | </body>
36 | </html>
--------------------------------------------------------------------------------
/removerML/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 | 
3 | # Create your tests here.
4 | 
--------------------------------------------------------------------------------
/removerML/urls.py:
--------------------------------------------------------------------------------
1 | from django.conf.urls.static import static
2 | from django.conf import settings
3 | from django.urls import path
4 | from . import views
5 | 
6 | urlpatterns = [
7 |     # path('admin/', admin.site.urls),
8 |     path("", views.index, name="index"),
9 |     path("upload", views.data, name="data")
10 | ]
11 | urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
--------------------------------------------------------------------------------
/removerML/views.py:
--------------------------------------------------------------------------------
1 | from django.shortcuts import render, HttpResponse
2 | from django.core.files.storage import FileSystemStorage
3 | import os, datetime, base64
4 | from . import remover
5 | 
6 | extensions = ['.jpg', '.jpeg', '.png']
7 | 
8 | def index(request):
9 |     if request.method == 'POST' and request.FILES.get('image'):
10 |         image = request.FILES['image']
11 |         ext = os.path.splitext(image.name)[1]
12 |         if ext.lower() in extensions:
13 |             fs = FileSystemStorage()
14 |             filename = fs.save(image.name, image)
15 |             uploaded_file_url = fs.url(filename)
16 |             file_name = uploaded_file_url.split("/")[2]
17 | 
18 |             input_path = os.getcwd() + uploaded_file_url
19 |             output_path = os.getcwd() + "/uploads/" + file_name.split(".")[0] + "_processed.png"
20 | 
21 |             remover.process(input_path, output_path)
22 |             image_path = uploaded_file_url.split(".")[0] + "_processed.png"
23 |             return render(request, 'removerML/index.html', {"image_path": image_path})
24 |         else:
25 |             return HttpResponse("Only allowed extensions are {}".format(extensions))
26 |     return render(request, 'removerML/index.html')
27 | 
28 | def data(request):
29 |     if request.method == 'POST' and request.POST.get('image'):
30 |         image = request.POST['image']
31 |         data = base64.b64decode(image)
32 |         image_name = datetime.datetime.now().strftime("%Y%b%d%H%M%S%f") + ".jpg"
33 |         image_path = os.getcwd() + "/uploads/" + image_name
34 | 
35 |         with open(image_path, "wb") as f:
36 |             f.write(data)
37 | 
38 |         input_path = os.getcwd() + "/uploads/" + image_name
39 |         output_path = os.getcwd() + "/uploads/" + image_name.split(".")[0] + "_processed.png"
40 | 
41 |         remover.process(input_path, output_path)
42 |         image_path = "/uploads/" + image_name.split(".")[0] + "_processed.png"
43 |         return HttpResponse(request.get_host() + image_path)
44 |     return HttpResponse("Send a POST request with a base64-encoded 'image' field.")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.9.0
2 | astunparse==1.6.3
3 | cachetools==4.1.1
4 | certifi==2020.6.20
5 | chardet==3.0.4
6 | cycler==0.10.0
7 | decorator==4.4.2
8 | django==3.1.12
9 | filelock==3.0.12
10 | future==0.18.2
11 | gast==0.3.3
12 | gdown==3.11.1
13 | gluoncv==0.7.0
14 | google-auth==1.18.0
15 | google-auth-oauthlib==0.4.1
16 | google-pasta==0.2.0
17 | graphviz==0.8.4
18 | grpcio==1.30.0
19 | h5py==2.10.0
20 | idna==2.10
21 | imageio==2.9.0
22 | Keras-Preprocessing==1.1.2
23 | kiwisolver==1.2.0
24 | Markdown==3.2.2
25 | matplotlib==3.2.2
26 | mxnet==1.6.0
27 | networkx==2.4
28 | numpy==1.18.5
29 | oauthlib==3.1.0
30 | opencv-python==4.3.0.36
31 | opt-einsum==3.2.1
32 | Pillow==8.2.0
33 | portalocker==1.7.0
34 | protobuf==3.12.2
35 | pyasn1==0.4.8
36 | pyasn1-modules==0.2.8
37 | pyparsing==2.4.7
38 | PySocks==1.7.1
39 | python-dateutil==2.8.1
40 | PyWavelets==1.1.1
41 | requests==2.24.0
42 | requests-oauthlib==1.3.0
43 | rsa==4.7
44 | scikit-image==0.17.2
45 | scipy==1.4.1
46 | six==1.15.0
47 | tensorboard==2.2.2
48 | tensorboard-plugin-wit==1.7.0
49 | tensorflow==2.5.0
50 | tensorflow-estimator==2.2.0
51 | termcolor==1.1.0
52 | tifffile==2020.7.4
53 | torch==1.5.1
54 | tqdm==4.47.0
55 | urllib3==1.26.5
56 | Werkzeug==1.0.1
57 | wrapt==1.12.1
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | WORKING_DIR=$(dirname "$(readlink -f "$0")")
4 | echo "WORKING_DIR: ${WORKING_DIR}"
5 | 
6 | declare -a urls=(
7 |     "http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz"
8 |     "http://download.tensorflow.org/models/deeplabv3_pascal_train_aug_2018_01_04.tar.gz"
9 | )
10 | 
11 | for url in ${urls[@]}; do
12 |     FILENAME=$(basename ${url})
13 |     FILE_PATH="${WORKING_DIR}/${FILENAME}"
14 |     if [[ ! -f "${FILE_PATH}" ]]; then
15 |         wget ${url}
16 |     fi
17 | 
18 |     case ${FILENAME} in
19 |         "deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz" )
20 |             EXTRACT_PATH="${WORKING_DIR}/models/mobile_net_model/model"
21 |             ;;
22 |         "deeplabv3_pascal_train_aug_2018_01_04.tar.gz")
23 |             EXTRACT_PATH="${WORKING_DIR}/models/xception_model/model"
24 |             ;;
25 |     esac
26 | 
27 |     if [[ ! -d "${EXTRACT_PATH}" ]]; then
28 |         mkdir -p ${EXTRACT_PATH}
29 |     fi
30 |     tar xvzf "${FILE_PATH}" -C ${EXTRACT_PATH} --strip=1
31 |     rm ${FILE_PATH}
32 | done
33 | 
34 | PYTHON_PATH="$(which python3)"
35 | SETUP_DIR="${WORKING_DIR}/setup/"
36 | DOWNLOAD_SCRIPT_PY="${SETUP_DIR}/download.py"
37 | 
38 | # Install gdown if it is not installed; without this package available globally, the download step will fail
39 | ${PYTHON_PATH} -m pip install gdown
40 | 
41 | if [[ -f ${DOWNLOAD_SCRIPT_PY} ]]; then
42 |     ${PYTHON_PATH} ${DOWNLOAD_SCRIPT_PY}
43 | else
44 |     echo "${DOWNLOAD_SCRIPT_PY} not found!"
45 | fi
46 | 
47 | declare -A files
48 | files=(
49 |     ["basnet.pth"]="/models/basnet/"
50 |     ["u2net.pth"]="/models/u2net/"
51 |     ["u2netp.pth"]="/models/u2netp/"
52 | )
53 | 
54 | for file in ${!files[@]}; do
55 |     FILE_PATH="${WORKING_DIR}/${file}"
56 |     DESTINATION_DIR="${WORKING_DIR}/${files[${file}]}"
57 |     if [[ -f ${FILE_PATH} ]]; then
58 |         echo "Found file: ${file}"
59 | 
60 |         if [[ ! -d "${DESTINATION_DIR}" ]]; then
-d "${DESTINATION_DIR}" ]]; then 61 | echo "Created dir: ${DESTINATION_DIR}" 62 | mkdir -p "${DESTINATION_DIR}" 63 | fi 64 | if mv ${FILE_PATH} ${DESTINATION_DIR}; then 65 | echo "${FILE_PATH} moved to ${DESTINATION_DIR}" 66 | else 67 | echo "Error while move file: ${FILE_PATH} to ${DESTINATION_DIR}" 68 | fi 69 | fi 70 | done 71 | 72 | rm -rf "${SETUP_DIR}" 73 | -------------------------------------------------------------------------------- /setup/download.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | 3 | url = 'https://drive.google.com/uc?id=1s52ek_4YTDRt_EOkx1FS53u-vJa0c4nu' 4 | output = 'basnet.pth' 5 | gdown.download(url, output, quiet=False) 6 | gdown.cached_download(url, output, postprocess=gdown.extractall) 7 | 8 | url = 'https://drive.google.com/uc?id=1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy' 9 | output = 'u2netp.pth' 10 | gdown.download(url, output, quiet=False) 11 | gdown.cached_download(url, output, postprocess=gdown.extractall) 12 | 13 | url = 'https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ' 14 | output = 'u2net.pth' 15 | gdown.download(url, output, quiet=False) 16 | gdown.cached_download(url, output, postprocess=gdown.extractall) 17 | -------------------------------------------------------------------------------- /uploads/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FarjaalAhmad/django_bgRemoverML/787737269eb7724481d667d30c6f502812759037/uploads/.gitignore --------------------------------------------------------------------------------