├── .assets ├── 5.png ├── 6.png └── 7.png ├── .gitignore ├── BreezeStyleSheets ├── .gitignore ├── LICENSE.md ├── README.md ├── __init__.py ├── assets │ └── Breeze.gif ├── breeze.qrc ├── breeze_resources.py ├── dark.py ├── dark.qss ├── dark │ ├── branch_closed-on.svg │ ├── branch_closed.svg │ ├── branch_open-on.svg │ ├── branch_open.svg │ ├── checkbox_checked.svg │ ├── checkbox_checked_disabled.svg │ ├── checkbox_indeterminate.svg │ ├── checkbox_indeterminate_disabled.svg │ ├── checkbox_unchecked.svg │ ├── checkbox_unchecked_disabled.svg │ ├── close-hover.svg │ ├── close-pressed.svg │ ├── close.svg │ ├── down_arrow-hover.svg │ ├── down_arrow.svg │ ├── down_arrow_disabled.svg │ ├── hmovetoolbar.svg │ ├── hsepartoolbar.svg │ ├── left_arrow.svg │ ├── left_arrow_disabled.svg │ ├── radio_checked.svg │ ├── radio_checked_disabled.svg │ ├── radio_unchecked.svg │ ├── radio_unchecked_disabled.svg │ ├── right_arrow.svg │ ├── right_arrow_disabled.svg │ ├── sizegrip.svg │ ├── spinup_disabled.svg │ ├── stylesheet-branch-end-closed.svg │ ├── stylesheet-branch-end-open.svg │ ├── stylesheet-branch-end.svg │ ├── stylesheet-branch-more.svg │ ├── stylesheet-vline.svg │ ├── transparent.svg │ ├── undock-hover.svg │ ├── undock.svg │ ├── up_arrow-hover.svg │ ├── up_arrow.svg │ ├── up_arrow_disabled.svg │ ├── vmovetoolbar.svg │ └── vsepartoolbars.svg ├── example.py ├── light.py ├── light.qss ├── light │ ├── branch_closed-on.svg │ ├── branch_closed.svg │ ├── branch_open-on.svg │ ├── branch_open.svg │ ├── checkbox_checked-hover.svg │ ├── checkbox_checked.svg │ ├── checkbox_checked_disabled.svg │ ├── checkbox_indeterminate-hover.svg │ ├── checkbox_indeterminate.svg │ ├── checkbox_indeterminate_disabled.svg │ ├── checkbox_unchecked-hover.svg │ ├── checkbox_unchecked_disabled.svg │ ├── close-hover.svg │ ├── close-pressed.svg │ ├── close.svg │ ├── down_arrow-hover.svg │ ├── down_arrow.svg │ ├── down_arrow_disabled.svg │ ├── hmovetoolbar.svg │ ├── hsepartoolbar.svg │ ├── left_arrow.svg │ ├── left_arrow_disabled.svg │ ├── radio_checked-hover.svg │ ├── radio_checked.svg │ ├── radio_checked_disabled.svg │ ├── radio_unchecked-hover.svg │ ├── radio_unchecked_disabled.svg │ ├── right_arrow.svg │ ├── right_arrow_disabled.svg │ ├── sizegrip.svg │ ├── spinup_disabled.svg │ ├── stylesheet-branch-end-closed.svg │ ├── stylesheet-branch-end-open.svg │ ├── stylesheet-branch-end.svg │ ├── stylesheet-branch-more.svg │ ├── stylesheet-vline.svg │ ├── transparent.svg │ ├── undock-hover.svg │ ├── undock.svg │ ├── up_arrow-hover.svg │ ├── up_arrow.svg │ ├── up_arrow_disabled.svg │ ├── vmovetoolbar.svg │ └── vsepartoolbars.svg └── native.py ├── LICENSE ├── README.md ├── config ├── __init__.py ├── ade20k-hrnetv2.yaml ├── ade20k-mobilenetv2dilated-c1_deepsup.yaml ├── ade20k-resnet101-upernet.yaml ├── ade20k-resnet101dilated-ppm_deepsup.yaml ├── ade20k-resnet18dilated-ppm_deepsup.yaml ├── ade20k-resnet50-upernet.yaml ├── ade20k-resnet50dilated-ppm_deepsup.yaml └── defaults.py ├── data ├── color150.mat ├── object150_info.csv ├── training.odgt └── validation.odgt ├── dataset.py ├── gui.py ├── gui.ui ├── gui_main.py ├── lib ├── nn │ ├── __init__.py │ ├── modules │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ ├── tests │ │ │ ├── test_numeric_batchnorm.py │ │ │ └── test_sync_batchnorm.py │ │ └── unittest.py │ └── parallel │ │ ├── __init__.py │ │ └── data_parallel.py └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── dataloader.py │ ├── dataset.py │ ├── distributed.py │ └── sampler.py │ └── th.py ├── 
models ├── __init__.py ├── hrnet.py ├── mobilenet.py ├── models.py ├── models.py.bak ├── resnet.py ├── resnext.py └── utils.py ├── pics ├── 42_170616140840_1.jpg ├── 5DDED2D65C4EB2FAE3FC89CD64D376A5.jpg ├── free_stock_photo.jpg └── timg.jpg ├── requirement.txt ├── test.py └── utils.py /.assets/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/.assets/5.png -------------------------------------------------------------------------------- /.assets/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/.assets/6.png -------------------------------------------------------------------------------- /.assets/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/.assets/7.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | ckpt/ 4 | vis/ 5 | log/ 6 | pretrained/ 7 | .idea/ 8 | decoder_epoch*.pth 9 | encoder_epoch*.pth -------------------------------------------------------------------------------- /BreezeStyleSheets/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | -------------------------------------------------------------------------------- /BreezeStyleSheets/LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | ===================== 3 | 4 | Copyright © `<2013-2014>` `` 5 | Copyright © `<2015-2016>` `` 6 | 7 | Permission is hereby granted, free of charge, to any person 8 | obtaining a copy of this software and associated documentation 9 | files (the “Software”), to deal in the Software without 10 | restriction, including without limitation the rights to use, 11 | copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the 13 | Software is furnished to do so, subject to the following 14 | conditions: 15 | 16 | The above copyright notice and this permission notice shall be 17 | included in all copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, 20 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 21 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 22 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 23 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 24 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 25 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 26 | OTHER DEALINGS IN THE SOFTWARE. 27 | -------------------------------------------------------------------------------- /BreezeStyleSheets/README.md: -------------------------------------------------------------------------------- 1 | BreezeStyleSheets 2 | ================= 3 | 4 | Breeze and BreezeDark-like stylesheets for Qt Applications. 
5 | 6 | C++ Installation 7 | ================ 8 | 9 | Copy `breeze.qrc`, `dark.qss`, `light.qss` and the `dark` and `light` folders into your project directory and add the qrc file to your project file. 10 | 11 | For example: 12 | 13 | ```qmake 14 | TARGET = app 15 | SOURCES = main.cpp 16 | RESOURCES = breeze.qrc 17 | ``` 18 | 19 | To load the stylesheet in C++, load the file using QFile and read the data. For example, to load BreezeDark, run: 20 | 21 | ```cpp 22 | 23 | #include <QApplication> 24 | #include <QFile> 25 | #include <QTextStream> 26 | 27 | 28 | int main(int argc, char *argv[]) 29 | { 30 | QApplication app(argc, argv); 31 | 32 | // set stylesheet 33 | QFile file(":/dark.qss"); 34 | file.open(QFile::ReadOnly | QFile::Text); 35 | QTextStream stream(&file); 36 | app.setStyleSheet(stream.readAll()); 37 | 38 | // code goes here 39 | 40 | return app.exec(); 41 | } 42 | ``` 43 | 44 | PyQt5 Installation 45 | ================== 46 | 47 | To use the stylesheets with PyQt5, first compile the resource file with `pyrcc5 breeze.qrc -o breeze_resources.py`, then import the generated module. Afterwards, to load the stylesheet in Python, load the file using QFile and read the data. For example, to load BreezeDark, run: 48 | 49 | 50 | ```python 51 | import sys 52 | from PyQt5 import QtWidgets 53 | from PyQt5.QtCore import QFile, QTextStream 54 | import breeze_resources 55 | 56 | 57 | def main(): 58 | app = QtWidgets.QApplication(sys.argv) 59 | 60 | # set stylesheet 61 | file = QFile(":/dark.qss") 62 | file.open(QFile.ReadOnly | QFile.Text) 63 | stream = QTextStream(file) 64 | app.setStyleSheet(stream.readAll()) 65 | 66 | # code goes here 67 | 68 | app.exec_() 69 | if __name__ == "__main__": main() 70 | ``` 71 | 72 | License 73 | ======= 74 | 75 | MIT, see [license](/LICENSE.md). 76 | 77 | Example 78 | ======= 79 | 80 | **Breeze/BreezeDark** 81 | 82 | Example user interface using the Breeze and BreezeDark stylesheets side-by-side. 83 | 84 | ![BreezeDark](/assets/Breeze.gif) 85 | 86 | Acknowledgements 87 | ================ 88 | 89 | BreezeStyleSheets is a fork of [QDarkStyleSheet](https://github.com/ColinDuquesnoy/QDarkStyleSheet). 
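Switching stylesheets at runtime
================================

Because `QApplication.setStyleSheet` can be called at any time, the two themes can also be swapped while the application is running. The following is an editor's sketch rather than part of the upstream README; it assumes the compiled `breeze_resources` module from the PyQt5 section above, and the helper name `toggle_stylesheet` is illustrative:

```python
import sys

from PyQt5 import QtWidgets
from PyQt5.QtCore import QFile, QTextStream
import breeze_resources  # registers the :/dark.qss and :/light.qss resources


def toggle_stylesheet(app, path):
    """Read a stylesheet from a Qt resource path and apply it application-wide."""
    file = QFile(path)
    file.open(QFile.ReadOnly | QFile.Text)
    app.setStyleSheet(QTextStream(file).readAll())
    file.close()


if __name__ == "__main__":
    app = QtWidgets.QApplication(sys.argv)
    toggle_stylesheet(app, ":/dark.qss")  # a menu action could later call
    window = QtWidgets.QMainWindow()      # toggle_stylesheet(app, ":/light.qss")
    window.show()
    sys.exit(app.exec_())
```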
90 | 91 | Contact 92 | ======= 93 | 94 | Email: ahuszagh@gmail.com 95 | Twitter: KardOnIce 96 | 97 | -------------------------------------------------------------------------------- /BreezeStyleSheets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/BreezeStyleSheets/__init__.py -------------------------------------------------------------------------------- /BreezeStyleSheets/assets/Breeze.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/BreezeStyleSheets/assets/Breeze.gif -------------------------------------------------------------------------------- /BreezeStyleSheets/breeze.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | light/hmovetoolbar.svg 4 | light/vmovetoolbar.svg 5 | light/hsepartoolbar.svg 6 | light/vsepartoolbars.svg 7 | light/stylesheet-branch-end.svg 8 | light/stylesheet-branch-end-closed.svg 9 | light/stylesheet-branch-end-open.svg 10 | light/stylesheet-vline.svg 11 | light/stylesheet-branch-more.svg 12 | light/branch_closed.svg 13 | light/branch_closed-on.svg 14 | light/branch_open.svg 15 | light/branch_open-on.svg 16 | light/down_arrow.svg 17 | light/down_arrow_disabled.svg 18 | light/down_arrow-hover.svg 19 | light/left_arrow.svg 20 | light/left_arrow_disabled.svg 21 | light/right_arrow.svg 22 | light/right_arrow_disabled.svg 23 | light/up_arrow.svg 24 | light/up_arrow_disabled.svg 25 | light/up_arrow-hover.svg 26 | light/sizegrip.svg 27 | light/transparent.svg 28 | light/close.svg 29 | light/close-hover.svg 30 | light/close-pressed.svg 31 | light/undock.svg 32 | light/undock-hover.svg 33 | light/checkbox_checked-hover.svg 34 | light/checkbox_checked.svg 35 | light/checkbox_checked_disabled.svg 36 | light/checkbox_indeterminate.svg 37 | light/checkbox_indeterminate-hover.svg 38 | light/checkbox_indeterminate_disabled.svg 39 | light/checkbox_unchecked-hover.svg 40 | light/checkbox_unchecked_disabled.svg 41 | light/radio_checked-hover.svg 42 | light/radio_checked.svg 43 | light/radio_checked_disabled.svg 44 | light/radio_unchecked-hover.svg 45 | light/radio_unchecked_disabled.svg 46 | dark/hmovetoolbar.svg 47 | dark/vmovetoolbar.svg 48 | dark/hsepartoolbar.svg 49 | dark/vsepartoolbars.svg 50 | dark/stylesheet-branch-end.svg 51 | dark/stylesheet-branch-end-closed.svg 52 | dark/stylesheet-branch-end-open.svg 53 | dark/stylesheet-vline.svg 54 | dark/stylesheet-branch-more.svg 55 | dark/branch_closed.svg 56 | dark/branch_closed-on.svg 57 | dark/branch_open.svg 58 | dark/branch_open-on.svg 59 | dark/down_arrow.svg 60 | dark/down_arrow_disabled.svg 61 | dark/down_arrow-hover.svg 62 | dark/left_arrow.svg 63 | dark/left_arrow_disabled.svg 64 | dark/right_arrow.svg 65 | dark/right_arrow_disabled.svg 66 | dark/up_arrow.svg 67 | dark/up_arrow_disabled.svg 68 | dark/up_arrow-hover.svg 69 | dark/sizegrip.svg 70 | dark/transparent.svg 71 | dark/close.svg 72 | dark/close-hover.svg 73 | dark/close-pressed.svg 74 | dark/undock.svg 75 | dark/undock-hover.svg 76 | dark/checkbox_checked.svg 77 | dark/checkbox_checked_disabled.svg 78 | dark/checkbox_indeterminate.svg 79 | dark/checkbox_indeterminate_disabled.svg 80 | dark/checkbox_unchecked.svg 81 | dark/checkbox_unchecked_disabled.svg 82 | dark/radio_checked.svg 83 | 
dark/radio_checked_disabled.svg 84 | dark/radio_unchecked.svg 85 | dark/radio_unchecked_disabled.svg 86 | light.qss 87 | dark.qss 88 | 89 | 90 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) <2013-2014> 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in 15 | # all copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | # THE SOFTWARE. 24 | # 25 | """ 26 | A simple example of use. 27 | 28 | Load an ui made in QtDesigner and apply the DarkStyleSheet. 29 | 30 | 31 | Requirements: 32 | - Python 2 or Python 3 33 | - PyQt4 34 | 35 | .. note.. 
:: qdarkstyle does not have to be installed to run 36 | the example 37 | 38 | """ 39 | import logging 40 | import sys 41 | from PyQt5 import QtWidgets, QtCore 42 | from PyQt5.QtCore import QFile, QTextStream 43 | # make the example runnable without the need to install 44 | 45 | import example 46 | import breeze_resources 47 | 48 | def main(): 49 | """ 50 | Application entry point 51 | """ 52 | logging.basicConfig(level=logging.DEBUG) 53 | # create the application and the main window 54 | app = QtWidgets.QApplication(sys.argv) 55 | #app.setStyle(QtWidgets.QStyleFactory.create("fusion")) 56 | window = QtWidgets.QMainWindow() 57 | 58 | # setup ui 59 | ui = example.Ui_MainWindow() 60 | ui.setupUi(window) 61 | 62 | # setup stylesheet 63 | file = QFile(":/dark.qss") 64 | file.open(QFile.ReadOnly | QFile.Text) 65 | stream = QTextStream(file) 66 | app.setStyleSheet(stream.readAll()) 67 | 68 | # auto quit after 2s when testing on travis-ci 69 | if "--travis" in sys.argv: 70 | QtCore.QTimer.singleShot(2000, app.exit) 71 | 72 | # run 73 | window.show() 74 | app.exec_() 75 | 76 | 77 | if __name__ == "__main__": 78 | main() 79 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/branch_closed-on.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/branch_closed.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/branch_open-on.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/branch_open.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/checkbox_checked.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/checkbox_checked_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/checkbox_indeterminate.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/checkbox_indeterminate_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/checkbox_unchecked.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/checkbox_unchecked_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- 
/BreezeStyleSheets/dark/close-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/close-pressed.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/close.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/down_arrow-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/down_arrow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/down_arrow_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/hmovetoolbar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/hsepartoolbar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/left_arrow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/left_arrow_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/radio_checked.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/radio_checked_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/radio_unchecked.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/radio_unchecked_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/right_arrow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/right_arrow_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
-------------------------------------------------------------------------------- /BreezeStyleSheets/dark/sizegrip.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/spinup_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/stylesheet-branch-end-closed.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/stylesheet-branch-end-open.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/stylesheet-branch-end.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/stylesheet-branch-more.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/stylesheet-vline.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/transparent.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/undock-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/undock.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/up_arrow-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/up_arrow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/up_arrow_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/vmovetoolbar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /BreezeStyleSheets/dark/vsepartoolbars.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) <2013-2014> 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in 15 | # all copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | # THE SOFTWARE. 24 | # 25 | """ 26 | A simple example of use. 27 | 28 | Load an ui made in QtDesigner and apply the DarkStyleSheet. 29 | 30 | 31 | Requirements: 32 | - Python 2 or Python 3 33 | - PyQt4 34 | 35 | .. note.. :: qdarkstyle does not have to be installed to run 36 | the example 37 | 38 | """ 39 | import logging 40 | import sys 41 | from PyQt5 import QtWidgets, QtCore 42 | from PyQt5.QtCore import QFile, QTextStream 43 | # make the example runnable without the need to install 44 | 45 | import example 46 | import breeze_resources 47 | 48 | def main(): 49 | """ 50 | Application entry point 51 | """ 52 | logging.basicConfig(level=logging.DEBUG) 53 | # create the application and the main window 54 | app = QtWidgets.QApplication(sys.argv) 55 | #app.setStyle(QtWidgets.QStyleFactory.create("fusion")) 56 | window = QtWidgets.QMainWindow() 57 | 58 | # setup ui 59 | ui = example.Ui_MainWindow() 60 | ui.setupUi(window) 61 | ui.bt_delay_popup.addActions([ 62 | ui.actionAction, 63 | ui.actionAction_C 64 | ]) 65 | ui.bt_instant_popup.addActions([ 66 | ui.actionAction, 67 | ui.actionAction_C 68 | ]) 69 | ui.bt_menu_button_popup.addActions([ 70 | ui.actionAction, 71 | ui.actionAction_C 72 | ]) 73 | window.setWindowTitle("Breeze example") 74 | 75 | # tabify dock widgets to show bug #6 76 | window.tabifyDockWidget(ui.dockWidget1, ui.dockWidget2) 77 | 78 | # setup stylesheet 79 | file = QFile(":/light.qss") 80 | file.open(QFile.ReadOnly | QFile.Text) 81 | stream = QTextStream(file) 82 | app.setStyleSheet(stream.readAll()) 83 | 84 | # auto quit after 2s when testing on travis-ci 85 | if "--travis" in sys.argv: 86 | QtCore.QTimer.singleShot(2000, app.exit) 87 | 88 | # run 89 | window.show() 90 | app.exec_() 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/branch_closed-on.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/branch_closed.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
-------------------------------------------------------------------------------- /BreezeStyleSheets/light/branch_open-on.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/branch_open.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/checkbox_checked-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/checkbox_checked.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/checkbox_checked_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/checkbox_indeterminate-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/checkbox_indeterminate.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/checkbox_indeterminate_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/checkbox_unchecked-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/checkbox_unchecked_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/close-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/close-pressed.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/close.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/down_arrow-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/down_arrow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
-------------------------------------------------------------------------------- /BreezeStyleSheets/light/down_arrow_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/hmovetoolbar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/hsepartoolbar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/left_arrow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/left_arrow_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/radio_checked-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/radio_checked.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/radio_checked_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/radio_unchecked-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/radio_unchecked_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/right_arrow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/right_arrow_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/sizegrip.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/spinup_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/stylesheet-branch-end-closed.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- 
/BreezeStyleSheets/light/stylesheet-branch-end-open.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/stylesheet-branch-end.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/stylesheet-branch-more.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/stylesheet-vline.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/transparent.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/undock-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/undock.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/up_arrow-hover.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/up_arrow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/up_arrow_disabled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/vmovetoolbar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /BreezeStyleSheets/light/vsepartoolbars.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /BreezeStyleSheets/native.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) <2013-2014> 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in 15 | # all copies or substantial portions 
of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | # THE SOFTWARE. 24 | # 25 | """ 26 | A simple example of use. 27 | 28 | Load an ui made in QtDesigner and apply the DarkStyleSheet. 29 | 30 | 31 | Requirements: 32 | - Python 2 or Python 3 33 | - PyQt4 34 | 35 | .. note.. :: qdarkstyle does not have to be installed to run 36 | the example 37 | 38 | """ 39 | import logging 40 | import sys 41 | from PyQt5 import QtWidgets, QtCore 42 | # make the example runnable without the need to install 43 | 44 | import example 45 | 46 | 47 | def main(): 48 | """ 49 | Application entry point 50 | """ 51 | logging.basicConfig(level=logging.DEBUG) 52 | # create the application and the main window 53 | app = QtWidgets.QApplication(sys.argv) 54 | #app.setStyle(QtWidgets.QStyleFactory.create("fusion")) 55 | window = QtWidgets.QMainWindow() 56 | 57 | # setup ui 58 | ui = example.Ui_MainWindow() 59 | ui.setupUi(window) 60 | ui.bt_delay_popup.addActions([ 61 | ui.actionAction, 62 | ui.actionAction_C 63 | ]) 64 | ui.bt_instant_popup.addActions([ 65 | ui.actionAction, 66 | ui.actionAction_C 67 | ]) 68 | ui.bt_menu_button_popup.addActions([ 69 | ui.actionAction, 70 | ui.actionAction_C 71 | ]) 72 | window.setWindowTitle("Native example") 73 | 74 | # tabify dock widgets to show bug #6 75 | window.tabifyDockWidget(ui.dockWidget1, ui.dockWidget2) 76 | 77 | # auto quit after 2s when testing on travis-ci 78 | if "--travis" in sys.argv: 79 | QtCore.QTimer.singleShot(2000, app.exit) 80 | 81 | # run 82 | window.show() 83 | app.exec_() 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 tea321000 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # image_semantic_segmentation 2 | Image semantic segmentation software using PyQt as the GUI framework, supporting eight models including mobilenet, resnet50 and hrnet 3 | 4 | A demo video is available on [bilibili](https://www.bilibili.com/video/BV1VA411i73M/) 5 | 6 | Usage: 7 | 8 | 1. `git clone https://github.com/tea321000/image_semantic_segmentation.git` 9 | 10 | 2. Download the models from [Baidu Netdisk](https://pan.baidu.com/s/1xlROxeZ0EGqrn-4U2FgcqQ) (extraction code: b2gs) or from [MIT](http://sceneparsing.csail.mit.edu/model/pytorch/) and place them in the same folder as the program 11 | 12 | 3. `pip install -r requirement.txt` 13 | 14 | 4. Run the main GUI program with `python gui_main.py` 15 | 16 | Semantic segmentation result: 17 | ![](.assets/5.png) 18 | 19 | Hiding selected layers: 20 | ![](.assets/6.png) 21 | 22 | Keeping only selected layers: 23 | ![](.assets/7.png) 24 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg 2 | -------------------------------------------------------------------------------- /config/ade20k-hrnetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 32 9 | segm_downsampling_rate: 4 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "hrnetv2" 14 | arch_decoder: "c1" 15 | fc_dim: 720 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 30 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_30.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_30.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-hrnetv2-c1" 43 | -------------------------------------------------------------------------------- /config/ade20k-mobilenetv2dilated-c1_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "mobilenetv2dilated" 14 | arch_decoder: "c1_deepsup" 15 | fc_dim: 320 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 3 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-mobilenetv2dilated-c1_deepsup" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet101-upernet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 
450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 32 9 | segm_downsampling_rate: 4 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101" 14 | arch_decoder: "upernet" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 40 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_40.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_40.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet101-upernet" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet101dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet18dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet18dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 512 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet18dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet50-upernet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 32 9 | segm_downsampling_rate: 4 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50" 14 | arch_decoder: "upernet" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | 
num_epoch: 30 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_30.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_30.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50-upernet" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet50dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Config definition 5 | # ----------------------------------------------------------------------------- 6 | 7 | _C = CN() 8 | _C.DIR = "ckpt/ade20k-resnet50dilated-ppm_deepsup" 9 | 10 | # ----------------------------------------------------------------------------- 11 | # Dataset 12 | # ----------------------------------------------------------------------------- 13 | _C.DATASET = CN() 14 | _C.DATASET.root_dataset = "./data/" 15 | _C.DATASET.list_train = "./data/training.odgt" 16 | _C.DATASET.list_val = "./data/validation.odgt" 17 | _C.DATASET.num_class = 150 18 | # multiscale train/test, size of short edge (int or tuple) 19 | _C.DATASET.imgSizes = (300, 375, 450, 525, 600) 20 | # maximum input image size of long edge 21 | _C.DATASET.imgMaxSize = 1000 22 | # maxmimum downsampling rate of the network 23 | _C.DATASET.padding_constant = 8 24 | # downsampling rate of the segmentation label 25 | _C.DATASET.segm_downsampling_rate = 8 26 | # randomly horizontally flip images when train/test 27 | _C.DATASET.random_flip = True 28 | 29 | # ----------------------------------------------------------------------------- 30 | # Model 31 | # ----------------------------------------------------------------------------- 32 | _C.MODEL = CN() 33 | # architecture of net_encoder 34 | _C.MODEL.arch_encoder = "resnet50dilated" 35 | # architecture of net_decoder 36 | _C.MODEL.arch_decoder = "ppm_deepsup" 37 | # weights to finetune net_encoder 38 | _C.MODEL.weights_encoder = "" 39 | # weights to finetune net_decoder 40 | _C.MODEL.weights_decoder = "" 41 | # number of feature channels between encoder and decoder 42 | _C.MODEL.fc_dim = 2048 43 | 44 | # 
----------------------------------------------------------------------------- 45 | # Training 46 | # ----------------------------------------------------------------------------- 47 | _C.TRAIN = CN() 48 | _C.TRAIN.batch_size_per_gpu = 2 49 | # epochs to train for 50 | _C.TRAIN.num_epoch = 20 51 | # epoch to start training. useful if continue from a checkpoint 52 | _C.TRAIN.start_epoch = 0 53 | # iterations of each epoch (irrelevant to batch size) 54 | _C.TRAIN.epoch_iters = 5000 55 | 56 | _C.TRAIN.optim = "SGD" 57 | _C.TRAIN.lr_encoder = 0.02 58 | _C.TRAIN.lr_decoder = 0.02 59 | # power in poly to drop LR 60 | _C.TRAIN.lr_pow = 0.9 61 | # momentum for sgd, beta1 for adam 62 | _C.TRAIN.beta1 = 0.9 63 | # weights regularizer 64 | _C.TRAIN.weight_decay = 1e-4 65 | # the weighting of deep supervision loss 66 | _C.TRAIN.deep_sup_scale = 0.4 67 | # fix bn params, only under finetuning 68 | _C.TRAIN.fix_bn = False 69 | # number of data loading workers 70 | _C.TRAIN.workers = 16 71 | 72 | # frequency to display 73 | _C.TRAIN.disp_iter = 20 74 | # manual seed 75 | _C.TRAIN.seed = 304 76 | 77 | # ----------------------------------------------------------------------------- 78 | # Validation 79 | # ----------------------------------------------------------------------------- 80 | _C.VAL = CN() 81 | # currently only supports 1 82 | _C.VAL.batch_size = 1 83 | # output visualization during validation 84 | _C.VAL.visualize = False 85 | # the checkpoint to evaluate on 86 | _C.VAL.checkpoint = "epoch_20.pth" 87 | 88 | # ----------------------------------------------------------------------------- 89 | # Testing 90 | # ----------------------------------------------------------------------------- 91 | _C.TEST = CN() 92 | # currently only supports 1 93 | _C.TEST.batch_size = 1 94 | # the checkpoint to test on 95 | _C.TEST.checkpoint = "epoch_20.pth" 96 | # folder to output visualization results 97 | _C.TEST.result = "./" 98 | -------------------------------------------------------------------------------- /data/color150.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/data/color150.mat -------------------------------------------------------------------------------- /data/object150_info.csv: -------------------------------------------------------------------------------- 1 | Idx,Ratio,Train,Val,Stuff,Name 2 | 1,0.1576,11664,1172,1,wall 3 | 2,0.1072,6046,612,1,building;edifice 4 | 3,0.0878,8265,796,1,sky 5 | 4,0.0621,9336,917,1,floor;flooring 6 | 5,0.0480,6678,641,0,tree 7 | 6,0.0450,6604,643,1,ceiling 8 | 7,0.0398,4023,408,1,road;route 9 | 8,0.0231,1906,199,0,bed 10 | 9,0.0198,4688,460,0,windowpane;window 11 | 10,0.0183,2423,225,1,grass 12 | 11,0.0181,2874,294,0,cabinet 13 | 12,0.0166,3068,310,1,sidewalk;pavement 14 | 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul 15 | 14,0.0151,1804,190,1,earth;ground 16 | 15,0.0118,6666,796,0,door;double;door 17 | 16,0.0110,4269,411,0,table 18 | 17,0.0109,1691,160,1,mountain;mount 19 | 18,0.0104,3999,441,0,plant;flora;plant;life 20 | 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall 21 | 20,0.0103,3261,318,0,chair 22 | 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar 23 | 22,0.0074,709,75,1,water 24 | 23,0.0067,3296,315,0,painting;picture 25 | 24,0.0065,1191,106,0,sofa;couch;lounge 26 | 25,0.0061,1516,162,0,shelf 27 | 26,0.0060,667,69,1,house 28 | 27,0.0053,651,57,1,sea 29 | 
28,0.0052,1847,224,0,mirror 30 | 29,0.0046,1158,128,1,rug;carpet;carpeting 31 | 30,0.0044,480,44,1,field 32 | 31,0.0044,1172,98,0,armchair 33 | 32,0.0044,1292,184,0,seat 34 | 33,0.0033,1386,138,0,fence;fencing 35 | 34,0.0031,698,61,0,desk 36 | 35,0.0030,781,73,0,rock;stone 37 | 36,0.0027,380,43,0,wardrobe;closet;press 38 | 37,0.0026,3089,302,0,lamp 39 | 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub 40 | 39,0.0024,804,99,0,railing;rail 41 | 40,0.0023,1453,153,0,cushion 42 | 41,0.0023,411,37,0,base;pedestal;stand 43 | 42,0.0022,1440,162,0,box 44 | 43,0.0022,800,77,0,column;pillar 45 | 44,0.0020,2650,298,0,signboard;sign 46 | 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser 47 | 46,0.0019,367,36,0,counter 48 | 47,0.0018,311,30,1,sand 49 | 48,0.0018,1181,122,0,sink 50 | 49,0.0018,287,23,1,skyscraper 51 | 50,0.0018,468,38,0,fireplace;hearth;open;fireplace 52 | 51,0.0018,402,43,0,refrigerator;icebox 53 | 52,0.0018,130,12,1,grandstand;covered;stand 54 | 53,0.0018,561,64,1,path 55 | 54,0.0017,880,102,0,stairs;steps 56 | 55,0.0017,86,12,1,runway 57 | 56,0.0017,172,11,0,case;display;case;showcase;vitrine 58 | 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table 59 | 58,0.0017,930,109,0,pillow 60 | 59,0.0015,139,18,0,screen;door;screen 61 | 60,0.0015,564,52,1,stairway;staircase 62 | 61,0.0015,320,26,1,river 63 | 62,0.0015,261,29,1,bridge;span 64 | 63,0.0014,275,22,0,bookcase 65 | 64,0.0014,335,60,0,blind;screen 66 | 65,0.0014,792,75,0,coffee;table;cocktail;table 67 | 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne 68 | 67,0.0014,1309,138,0,flower 69 | 68,0.0013,1112,113,0,book 70 | 69,0.0013,266,27,1,hill 71 | 70,0.0013,659,66,0,bench 72 | 71,0.0012,331,31,0,countertop 73 | 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove 74 | 73,0.0012,369,36,0,palm;palm;tree 75 | 74,0.0012,144,9,0,kitchen;island 76 | 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system 77 | 76,0.0010,324,33,0,swivel;chair 78 | 77,0.0009,304,27,0,boat 79 | 78,0.0009,170,20,0,bar 80 | 79,0.0009,68,6,0,arcade;machine 81 | 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty 82 | 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle 83 | 82,0.0008,492,49,0,towel 84 | 83,0.0008,2510,269,0,light;light;source 85 | 84,0.0008,440,39,0,truck;motortruck 86 | 85,0.0008,147,18,1,tower 87 | 86,0.0008,583,56,0,chandelier;pendant;pendent 88 | 87,0.0007,533,61,0,awning;sunshade;sunblind 89 | 88,0.0007,1989,239,0,streetlight;street;lamp 90 | 89,0.0007,71,5,0,booth;cubicle;stall;kiosk 91 | 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box 92 | 91,0.0007,135,12,0,airplane;aeroplane;plane 93 | 92,0.0007,83,5,1,dirt;track 94 | 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes 95 | 94,0.0006,1003,104,0,pole 96 | 95,0.0006,182,12,1,land;ground;soil 97 | 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail 98 | 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway 99 | 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock 100 | 99,0.0006,965,114,0,bottle 101 | 100,0.0006,117,13,0,buffet;counter;sideboard 102 | 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card 103 | 102,0.0006,108,9,1,stage 104 | 103,0.0006,557,55,0,van 105 | 104,0.0006,52,4,0,ship 106 | 105,0.0005,99,5,0,fountain 107 | 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter 108 | 
107,0.0005,292,31,0,canopy 109 | 108,0.0005,77,9,0,washer;automatic;washer;washing;machine 110 | 109,0.0005,340,38,0,plaything;toy 111 | 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium 112 | 111,0.0005,465,49,0,stool 113 | 112,0.0005,50,4,0,barrel;cask 114 | 113,0.0005,622,75,0,basket;handbasket 115 | 114,0.0005,80,9,1,waterfall;falls 116 | 115,0.0005,59,3,0,tent;collapsible;shelter 117 | 116,0.0005,531,72,0,bag 118 | 117,0.0005,282,30,0,minibike;motorbike 119 | 118,0.0005,73,7,0,cradle 120 | 119,0.0005,435,44,0,oven 121 | 120,0.0005,136,25,0,ball 122 | 121,0.0005,116,24,0,food;solid;food 123 | 122,0.0004,266,31,0,step;stair 124 | 123,0.0004,58,12,0,tank;storage;tank 125 | 124,0.0004,418,83,0,trade;name;brand;name;brand;marque 126 | 125,0.0004,319,43,0,microwave;microwave;oven 127 | 126,0.0004,1193,139,0,pot;flowerpot 128 | 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna 129 | 128,0.0004,347,36,0,bicycle;bike;wheel;cycle 130 | 129,0.0004,52,5,1,lake 131 | 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine 132 | 131,0.0004,108,13,0,screen;silver;screen;projection;screen 133 | 132,0.0004,201,30,0,blanket;cover 134 | 133,0.0004,285,21,0,sculpture 135 | 134,0.0004,268,27,0,hood;exhaust;hood 136 | 135,0.0003,1020,108,0,sconce 137 | 136,0.0003,1282,122,0,vase 138 | 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight 139 | 138,0.0003,453,57,0,tray 140 | 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin 141 | 140,0.0003,397,44,0,fan 142 | 141,0.0003,92,8,1,pier;wharf;wharfage;dock 143 | 142,0.0003,228,18,0,crt;screen 144 | 143,0.0003,570,59,0,plate 145 | 144,0.0003,217,22,0,monitor;monitoring;device 146 | 145,0.0003,206,19,0,bulletin;board;notice;board 147 | 146,0.0003,130,14,0,shower 148 | 147,0.0003,178,28,0,radiator 149 | 148,0.0002,504,57,0,glass;drinking;glass 150 | 149,0.0002,775,96,0,clock 151 | 150,0.0002,421,56,0,flag 152 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from torchvision import transforms 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | def imresize(im, size, interp='bilinear'): 10 | if interp == 'nearest': 11 | resample = Image.NEAREST 12 | elif interp == 'bilinear': 13 | resample = Image.BILINEAR 14 | elif interp == 'bicubic': 15 | resample = Image.BICUBIC 16 | else: 17 | raise Exception('resample method undefined!') 18 | 19 | return im.resize(size, resample) 20 | 21 | 22 | class BaseDataset(torch.utils.data.Dataset): 23 | def __init__(self, odgt, opt, **kwargs): 24 | # parse options 25 | self.imgSizes = opt.imgSizes 26 | self.imgMaxSize = opt.imgMaxSize 27 | # max down sampling rate of network to avoid rounding during conv or pooling 28 | self.padding_constant = opt.padding_constant 29 | 30 | # parse the input list 31 | self.parse_input_list(odgt, **kwargs) 32 | 33 | # mean and std 34 | self.normalize = transforms.Normalize( 35 | mean=[0.485, 0.456, 0.406], 36 | std=[0.229, 0.224, 0.225]) 37 | 38 | def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1): 39 | if isinstance(odgt, list): 40 | self.list_sample = odgt 41 | elif isinstance(odgt, str): 42 | self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')] 43 | 44 | if max_sample > 0: 45 | self.list_sample = self.list_sample[0:max_sample] 46 | if start_idx >= 0 and end_idx >= 0: # divide file 
list
47 |             self.list_sample = self.list_sample[start_idx:end_idx]
48 |
49 |         self.num_sample = len(self.list_sample)
50 |         assert self.num_sample > 0
51 |         print('# samples: {}'.format(self.num_sample))
52 |
53 |     def img_transform(self, img):
54 |         # 0-255 to 0-1
55 |         img = np.float32(np.array(img)) / 255.
56 |         img = img.transpose((2, 0, 1))
57 |         img = self.normalize(torch.from_numpy(img.copy()))
58 |         return img
59 |
60 |     def segm_transform(self, segm):
61 |         # to tensor, -1 to 149
62 |         segm = torch.from_numpy(np.array(segm)).long() - 1
63 |         return segm
64 |
65 |     # Round x up to the nearest multiple of p, so that x' >= x
66 |     def round2nearest_multiple(self, x, p):
67 |         return ((x - 1) // p + 1) * p
68 |
69 |
70 | class TrainDataset(BaseDataset):
71 |     def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs):
72 |         super(TrainDataset, self).__init__(odgt, opt, **kwargs)
73 |         self.root_dataset = root_dataset
74 |         # down sampling rate of segm label
75 |         self.segm_downsampling_rate = opt.segm_downsampling_rate
76 |         self.batch_per_gpu = batch_per_gpu
77 |
78 |         # classify images into two classes: 1. h > w and 2. h <= w
79 |         self.batch_record_list = [[], []]
80 |
81 |         # override dataset length when training with batch_per_gpu > 1
82 |         self.cur_idx = 0
83 |         self.if_shuffled = False
84 |
85 |     def _get_sub_batch(self):
86 |         while True:
87 |             # get a sample record
88 |             this_sample = self.list_sample[self.cur_idx]
89 |             if this_sample['height'] > this_sample['width']:
90 |                 self.batch_record_list[0].append(this_sample)  # h > w, go to 1st class
91 |             else:
92 |                 self.batch_record_list[1].append(this_sample)  # h <= w, go to 2nd class
93 |
94 |             # update current sample pointer
95 |             self.cur_idx += 1
96 |             if self.cur_idx >= self.num_sample:
97 |                 self.cur_idx = 0
98 |                 np.random.shuffle(self.list_sample)
99 |
100 |             if len(self.batch_record_list[0]) == self.batch_per_gpu:
101 |                 batch_records = self.batch_record_list[0]
102 |                 self.batch_record_list[0] = []
103 |                 break
104 |             elif len(self.batch_record_list[1]) == self.batch_per_gpu:
105 |                 batch_records = self.batch_record_list[1]
106 |                 self.batch_record_list[1] = []
107 |                 break
108 |         return batch_records
109 |
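    # For intuition: round2nearest_multiple rounds x UP to a multiple of p (with p = 8 it
    # maps 33 -> 40, 40 -> 40, 41 -> 48), and _get_sub_batch buckets portrait (h > w) and
    # landscape (h <= w) records into separate sub-batches so the zero padding added in
    # __getitem__ stays small. A minimal sketch of the rounding identity:
    #
    #     x, p = 33, 8
    #     assert ((x - 1) // p + 1) * p == 40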
110 |     def __getitem__(self, index):
111 |         # NOTE: random shuffle on first access; shuffling in __init__ is useless
112 |         if not self.if_shuffled:
113 |             np.random.seed(index)
114 |             np.random.shuffle(self.list_sample)
115 |             self.if_shuffled = True
116 |
117 |         # get sub-batch candidates
118 |         batch_records = self._get_sub_batch()
119 |
120 |         # resize all images' short edges to the chosen size
121 |         if isinstance(self.imgSizes, list) or isinstance(self.imgSizes, tuple):
122 |             this_short_size = np.random.choice(self.imgSizes)
123 |         else:
124 |             this_short_size = self.imgSizes
125 |
126 |         # calculate the BATCH's height and width
127 |         # since we concatenate more than one sample, the batch's h and w must be at least as large as EACH sample's
128 |         batch_widths = np.zeros(self.batch_per_gpu, np.int32)
129 |         batch_heights = np.zeros(self.batch_per_gpu, np.int32)
130 |         for i in range(self.batch_per_gpu):
131 |             img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
132 |             this_scale = min(
133 |                 this_short_size / min(img_height, img_width), \
134 |                 self.imgMaxSize / max(img_height, img_width))
135 |             batch_widths[i] = img_width * this_scale
136 |             batch_heights[i] = img_height * this_scale
137 |
138 |         # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w'
139 |         batch_width = np.max(batch_widths)
140 |         batch_height = np.max(batch_heights)
141 |         batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant))
142 |         batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant))
143 |
144 |         assert self.padding_constant >= self.segm_downsampling_rate, \
145 |             'padding constant must be equal to or larger than the segm downsampling rate'
146 |         batch_images = torch.zeros(
147 |             self.batch_per_gpu, 3, batch_height, batch_width)
148 |         batch_segms = torch.zeros(
149 |             self.batch_per_gpu,
150 |             batch_height // self.segm_downsampling_rate,
151 |             batch_width // self.segm_downsampling_rate).long()
152 |
153 |         for i in range(self.batch_per_gpu):
154 |             this_record = batch_records[i]
155 |
156 |             # load image and label
157 |             image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
158 |             segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
159 |
160 |             img = Image.open(image_path).convert('RGB')
161 |             segm = Image.open(segm_path)
162 |             assert(segm.mode == "L")
163 |             assert(img.size[0] == segm.size[0])
164 |             assert(img.size[1] == segm.size[1])
165 |
166 |             # random horizontal flip
167 |             if np.random.choice([0, 1]):
168 |                 img = img.transpose(Image.FLIP_LEFT_RIGHT)
169 |                 segm = segm.transpose(Image.FLIP_LEFT_RIGHT)
170 |
171 |             # note that each sample within a mini batch has a different scale param
172 |             img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear')
173 |             segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest')
174 |
175 |             # further downsample seg label to avoid seg label misalignment
176 |             segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate)
177 |             segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate)
178 |             segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0)
179 |             segm_rounded.paste(segm, (0, 0))
180 |             segm = imresize(
181 |                 segm_rounded,
182 |                 (segm_rounded.size[0] // self.segm_downsampling_rate, \
183 |                 segm_rounded.size[1] // self.segm_downsampling_rate), \
184 |                 interp='nearest')
185 |
186 |             # image transform, to torch float tensor 3xHxW
187 |             img = self.img_transform(img)
188 |
189 |             # segm transform, to torch long tensor HxW
190 |             segm = self.segm_transform(segm)
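            # img is now a float tensor 3 x h_i x w_i and segm a long tensor of about
            # (h_i // segm_downsampling_rate) x (w_i // segm_downsampling_rate); both fit
            # inside the pre-allocated batch canvases, so the assignments below fill only
            # the top-left corner and leave the remainder as zero padding.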
191 |
192 |             # put into batch arrays
193 |             batch_images[i][:, :img.shape[1], :img.shape[2]] = img
194 |             batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm
195 |
196 |         output = dict()
197 |         output['img_data'] = batch_images
198 |         output['seg_label'] = batch_segms
199 |         return output
200 |
201 |     def __len__(self):
202 |         return int(1e10)  # a fake length: each loader maintains its own shuffled sample list
203 |         # return self.num_sample
204 |
205 |
206 | class ValDataset(BaseDataset):
207 |     def __init__(self, root_dataset, odgt, opt, **kwargs):
208 |         super(ValDataset, self).__init__(odgt, opt, **kwargs)
209 |         self.root_dataset = root_dataset
210 |
211 |     def __getitem__(self, index):
212 |         this_record = self.list_sample[index]
213 |         # load image and label
214 |         image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
215 |         segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
216 |         img = Image.open(image_path).convert('RGB')
217 |         segm = Image.open(segm_path)
218 |         assert(segm.mode == "L")
219 |         assert(img.size[0] == segm.size[0])
220 |         assert(img.size[1] == segm.size[1])
221 |
222 |         ori_width, ori_height = img.size
223 |
224 |         img_resized_list = []
225 |         for this_short_size in self.imgSizes:
226 |             # calculate target height and width
227 |             scale = min(this_short_size / float(min(ori_height, ori_width)),
228 |                         self.imgMaxSize / float(max(ori_height, ori_width)))
229 |             target_height, target_width = int(ori_height * scale), int(ori_width * scale)
230 |
231 |             # to avoid rounding in network
232 |             target_width = self.round2nearest_multiple(target_width, self.padding_constant)
233 |             target_height = self.round2nearest_multiple(target_height, self.padding_constant)
234 |
235 |             # resize images
236 |             img_resized = imresize(img, (target_width, target_height), interp='bilinear')
237 |
238 |             # image transform, to torch float tensor 3xHxW
239 |             img_resized = self.img_transform(img_resized)
240 |             img_resized = torch.unsqueeze(img_resized, 0)
241 |             img_resized_list.append(img_resized)
242 |
243 |         # segm transform, to torch long tensor HxW
244 |         segm = self.segm_transform(segm)
245 |         batch_segms = torch.unsqueeze(segm, 0)
246 |
247 |         output = dict()
248 |         output['img_ori'] = np.array(img)
249 |         output['img_data'] = [x.contiguous() for x in img_resized_list]
250 |         output['seg_label'] = batch_segms.contiguous()
251 |         output['info'] = this_record['fpath_img']
252 |         return output
253 |
254 |     def __len__(self):
255 |         return self.num_sample
256 |
257 |
258 | class TestDataset(BaseDataset):
259 |     def __init__(self, odgt, opt, **kwargs):
260 |         super(TestDataset, self).__init__(odgt, opt, **kwargs)
261 |
262 |     def __getitem__(self, index):
263 |         this_record = self.list_sample[index]
264 |         # load image
265 |         image_path = this_record['fpath_img']
266 |         img = Image.open(image_path).convert('RGB')
267 |
268 |         ori_width, ori_height = img.size
269 |
270 |         img_resized_list = []
271 |         for this_short_size in self.imgSizes:
272 |             # calculate target height and width
273 |             scale = min(this_short_size / float(min(ori_height, ori_width)),
274 |                         self.imgMaxSize / float(max(ori_height, ori_width)))
275 |             target_height, target_width = int(ori_height * scale), int(ori_width * scale)
276 |
277 |             # to avoid rounding in network
278 |             target_width = self.round2nearest_multiple(target_width, self.padding_constant)
279 |             target_height = self.round2nearest_multiple(target_height, self.padding_constant)
280 |
281 |             # resize images
282 |             img_resized = imresize(img,
(target_width, target_height), interp='bilinear') 283 | 284 | # image transform, to torch float tensor 3xHxW 285 | img_resized = self.img_transform(img_resized) 286 | img_resized = torch.unsqueeze(img_resized, 0) 287 | img_resized_list.append(img_resized) 288 | 289 | output = dict() 290 | output['img_ori'] = np.array(img) 291 | output['img_data'] = [x.contiguous() for x in img_resized_list] 292 | output['info'] = this_record['fpath_img'] 293 | return output 294 | 295 | def __len__(self): 296 | return self.num_sample 297 | -------------------------------------------------------------------------------- /gui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'gui4.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.0 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | 13 | 14 | class Ui_MainWindow(object): 15 | def setupUi(self, MainWindow): 16 | MainWindow.setObjectName("MainWindow") 17 | MainWindow.resize(800, 932) 18 | MainWindow.setMinimumSize(QtCore.QSize(800, 768)) 19 | font = QtGui.QFont() 20 | font.setFamily("微软雅黑") 21 | font.setPointSize(11) 22 | MainWindow.setFont(font) 23 | self.centralwidget = QtWidgets.QWidget(MainWindow) 24 | self.centralwidget.setObjectName("centralwidget") 25 | self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget) 26 | self.verticalLayout.setObjectName("verticalLayout") 27 | self.gridLayout = QtWidgets.QGridLayout() 28 | self.gridLayout.setObjectName("gridLayout") 29 | self.label_8 = QtWidgets.QLabel(self.centralwidget) 30 | self.label_8.setText("") 31 | self.label_8.setObjectName("label_8") 32 | self.gridLayout.addWidget(self.label_8, 6, 0, 1, 1) 33 | self.label_7 = QtWidgets.QLabel(self.centralwidget) 34 | self.label_7.setFont(font) 35 | self.label_7.setTextFormat(QtCore.Qt.PlainText) 36 | self.label_7.setAlignment(QtCore.Qt.AlignCenter) 37 | self.label_7.setObjectName("label_7") 38 | self.gridLayout.addWidget(self.label_7, 12, 0, 1, 1) 39 | self.seg_combo_box = QtWidgets.QComboBox(self.centralwidget) 40 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed) 41 | sizePolicy.setHorizontalStretch(10) 42 | sizePolicy.setVerticalStretch(0) 43 | sizePolicy.setHeightForWidth(self.seg_combo_box.sizePolicy().hasHeightForWidth()) 44 | self.seg_combo_box.setSizePolicy(sizePolicy) 45 | self.seg_combo_box.setToolTip("") 46 | self.seg_combo_box.setObjectName("seg_combo_box") 47 | self.seg_combo_box.setFont(font) 48 | self.gridLayout.addWidget(self.seg_combo_box, 12, 1, 1, 4) 49 | self.lineEdit = QtWidgets.QLineEdit(self.centralwidget) 50 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) 51 | sizePolicy.setHorizontalStretch(10) 52 | sizePolicy.setVerticalStretch(0) 53 | sizePolicy.setHeightForWidth(self.lineEdit.sizePolicy().hasHeightForWidth()) 54 | self.lineEdit.setSizePolicy(sizePolicy) 55 | self.lineEdit.setObjectName("lineEdit") 56 | self.gridLayout.addWidget(self.lineEdit, 0, 1, 1, 4) 57 | self.listWidget = QtWidgets.QListWidget(self.centralwidget) 58 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) 59 | sizePolicy.setHorizontalStretch(10) 60 | sizePolicy.setVerticalStretch(30) 61 | 
sizePolicy.setHeightForWidth(self.listWidget.sizePolicy().hasHeightForWidth()) 62 | self.listWidget.setSizePolicy(sizePolicy) 63 | self.listWidget.setObjectName("listWidget") 64 | self.gridLayout.addWidget(self.listWidget, 11, 0, 1, 2) 65 | self.label_5 = QtWidgets.QLabel(self.centralwidget) 66 | self.label_5.setFont(font) 67 | self.label_5.setAlignment(QtCore.Qt.AlignCenter) 68 | self.label_5.setObjectName("label_5") 69 | self.gridLayout.addWidget(self.label_5, 4, 0, 1, 1) 70 | self.output_empty_button = QtWidgets.QPushButton(self.centralwidget) 71 | self.output_empty_button.setObjectName("output_empty_button") 72 | self.output_empty_button.setFont(font) 73 | self.gridLayout.addWidget(self.output_empty_button, 4, 1, 1, 1) 74 | self.gpu_combo_box = QtWidgets.QComboBox(self.centralwidget) 75 | self.gpu_combo_box.setObjectName("gpu_combo_box") 76 | self.gpu_combo_box.addItem("") 77 | self.gpu_combo_box.addItem("") 78 | self.gpu_combo_box.setFont(font) 79 | self.gridLayout.addWidget(self.gpu_combo_box, 1, 1, 1, 1) 80 | self.save_seg_button = QtWidgets.QPushButton(self.centralwidget) 81 | self.save_seg_button.setEnabled(False) 82 | self.save_seg_button.setFont(font) 83 | self.save_seg_button.setObjectName("save_seg_button") 84 | self.gridLayout.addWidget(self.save_seg_button, 12, 5, 1, 1) 85 | self.progressBar = QtWidgets.QProgressBar(self.centralwidget) 86 | self.progressBar.setProperty("value", 0) 87 | self.progressBar.setObjectName("progressBar") 88 | self.gridLayout.addWidget(self.progressBar, 2, 0, 1, 5) 89 | self.label = QtWidgets.QLabel(self.centralwidget) 90 | self.label.setFont(font) 91 | self.label.setTextFormat(QtCore.Qt.PlainText) 92 | self.label.setAlignment(QtCore.Qt.AlignCenter) 93 | self.label.setObjectName("label") 94 | self.gridLayout.addWidget(self.label, 1, 0, 1, 1) 95 | self.seg_confirm_button = QtWidgets.QPushButton(self.centralwidget) 96 | self.seg_confirm_button.setFont(font) 97 | self.seg_confirm_button.setObjectName("seg_confirm_button") 98 | self.gridLayout.addWidget(self.seg_confirm_button, 2, 5, 1, 1) 99 | self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget) 100 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) 101 | sizePolicy.setHorizontalStretch(0) 102 | sizePolicy.setVerticalStretch(0) 103 | sizePolicy.setHeightForWidth(self.textBrowser.sizePolicy().hasHeightForWidth()) 104 | self.textBrowser.setSizePolicy(sizePolicy) 105 | self.textBrowser.setToolTip("") 106 | self.textBrowser.setObjectName("textBrowser") 107 | self.gridLayout.addWidget(self.textBrowser, 5, 0, 1, 6) 108 | self.label_4 = QtWidgets.QLabel(self.centralwidget) 109 | self.label_4.setFont(font) 110 | self.label_4.setTextFormat(QtCore.Qt.PlainText) 111 | self.label_4.setAlignment(QtCore.Qt.AlignCenter) 112 | self.label_4.setObjectName("label_4") 113 | self.gridLayout.addWidget(self.label_4, 1, 2, 1, 1) 114 | self.browser_button = QtWidgets.QPushButton(self.centralwidget) 115 | self.browser_button.setFont(font) 116 | self.browser_button.setObjectName("browser_button") 117 | self.gridLayout.addWidget(self.browser_button, 0, 5, 1, 1) 118 | self.display_button = QtWidgets.QPushButton(self.centralwidget) 119 | self.display_button.setEnabled(False) 120 | self.display_button.setObjectName("display_button") 121 | self.gridLayout.addWidget(self.display_button, 10, 0, 1, 1) 122 | self.model_combo_box = QtWidgets.QComboBox(self.centralwidget) 123 | self.model_combo_box.setToolTip("") 124 | self.model_combo_box.setObjectName("model_combo_box") 125 | 
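        # the empty addItem("") calls below create placeholder entries; their visible text
        # is filled in later by retranslateUi(), which keeps every user-facing string in one
        # place for Qt's translation tooling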
self.model_combo_box.addItem("") 126 | self.model_combo_box.addItem("") 127 | self.model_combo_box.addItem("") 128 | self.model_combo_box.addItem("") 129 | self.model_combo_box.addItem("") 130 | self.model_combo_box.addItem("") 131 | self.model_combo_box.addItem("") 132 | self.model_combo_box.addItem("") 133 | self.model_combo_box.setFont(font) 134 | self.gridLayout.addWidget(self.model_combo_box, 1, 3, 1, 3) 135 | self.show_layers = QtWidgets.QLabel(self.centralwidget) 136 | self.show_layers.setText("") 137 | self.show_layers.setObjectName("show_layers") 138 | self.gridLayout.addWidget(self.show_layers, 11, 2, 1, 4) 139 | self.open_pic_combo_box = QtWidgets.QComboBox(self.centralwidget) 140 | self.open_pic_combo_box.setObjectName("open_pic_combo_box") 141 | self.open_pic_combo_box.addItem("") 142 | self.open_pic_combo_box.addItem("") 143 | self.open_pic_combo_box.setFont(font) 144 | self.gridLayout.addWidget(self.open_pic_combo_box, 0, 0, 1, 1) 145 | self.hide_button = QtWidgets.QPushButton(self.centralwidget) 146 | self.hide_button.setEnabled(False) 147 | self.hide_button.setObjectName("hide_button") 148 | self.gridLayout.addWidget(self.hide_button, 10, 1, 1, 1) 149 | self.label_2 = QtWidgets.QLabel(self.centralwidget) 150 | self.label_2.setFont(font) 151 | self.label_2.setTextFormat(QtCore.Qt.PlainText) 152 | self.label_2.setAlignment(QtCore.Qt.AlignCenter) 153 | self.label_2.setObjectName("label_2") 154 | self.gridLayout.addWidget(self.label_2, 9, 0, 1, 2) 155 | self.label_3 = QtWidgets.QLabel(self.centralwidget) 156 | self.label_3.setFont(font) 157 | self.label_3.setTextFormat(QtCore.Qt.PlainText) 158 | self.label_3.setAlignment(QtCore.Qt.AlignCenter) 159 | self.label_3.setObjectName("label_3") 160 | self.gridLayout.addWidget(self.label_3, 9, 2, 2, 4) 161 | self.verticalLayout.addLayout(self.gridLayout) 162 | MainWindow.setCentralWidget(self.centralwidget) 163 | self.menubar = QtWidgets.QMenuBar(MainWindow) 164 | self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26)) 165 | self.menubar.setObjectName("menubar") 166 | self.menu = QtWidgets.QMenu(self.menubar) 167 | self.menu.setObjectName("menu") 168 | self.menu_2 = QtWidgets.QMenu(self.menubar) 169 | self.menu_2.setObjectName("menu_2") 170 | self.menu_3 = QtWidgets.QMenu(self.menubar) 171 | self.menu_3.setObjectName("menu_3") 172 | MainWindow.setMenuBar(self.menubar) 173 | self.statusbar = QtWidgets.QStatusBar(MainWindow) 174 | self.statusbar.setObjectName("statusbar") 175 | MainWindow.setStatusBar(self.statusbar) 176 | self.open_pic = QtWidgets.QAction(MainWindow) 177 | icon = QtGui.QIcon() 178 | icon.addPixmap(QtGui.QPixmap(":/dark/sizegrip.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off) 179 | self.open_pic.setIcon(icon) 180 | self.open_pic.setObjectName("open_pic") 181 | self.save_seg = QtWidgets.QAction(MainWindow) 182 | icon1 = QtGui.QIcon() 183 | icon1.addPixmap(QtGui.QPixmap(":/light/undock-hover.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off) 184 | self.save_seg.setIcon(icon1) 185 | self.save_seg.setObjectName("save_seg") 186 | self.exit = QtWidgets.QAction(MainWindow) 187 | icon2 = QtGui.QIcon() 188 | icon2.addPixmap(QtGui.QPixmap(":/dark/close-pressed.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off) 189 | self.exit.setIcon(icon2) 190 | self.exit.setObjectName("exit") 191 | self.action_dark = QtWidgets.QAction(MainWindow) 192 | self.action_dark.setObjectName("action_dark") 193 | self.action_light = QtWidgets.QAction(MainWindow) 194 | self.action_light.setObjectName("action_light") 195 | self.author = QtWidgets.QAction(MainWindow) 196 | 
self.author.setObjectName("author") 197 | self.open_pic_folder = QtWidgets.QAction(MainWindow) 198 | icon3 = QtGui.QIcon() 199 | icon3.addPixmap(QtGui.QPixmap(":/dark/vmovetoolbar.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off) 200 | self.open_pic_folder.setIcon(icon3) 201 | self.open_pic_folder.setObjectName("open_pic_folder") 202 | self.menu.addAction(self.open_pic) 203 | self.menu.addAction(self.open_pic_folder) 204 | self.menu.addAction(self.save_seg) 205 | self.menu.addSeparator() 206 | self.menu.addAction(self.exit) 207 | self.menu_2.addAction(self.action_dark) 208 | self.menu_2.addAction(self.action_light) 209 | self.menu_3.addAction(self.author) 210 | self.menubar.addAction(self.menu.menuAction()) 211 | self.menubar.addAction(self.menu_2.menuAction()) 212 | self.menubar.addAction(self.menu_3.menuAction()) 213 | 214 | self.retranslateUi(MainWindow) 215 | self.seg_combo_box.setCurrentIndex(-1) 216 | self.model_combo_box.setCurrentIndex(3) 217 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 218 | 219 | def retranslateUi(self, MainWindow): 220 | _translate = QtCore.QCoreApplication.translate 221 | MainWindow.setWindowTitle(_translate("MainWindow", "自动语义分割")) 222 | self.label_7.setText(_translate("MainWindow", "图片选择")) 223 | self.label_5.setText(_translate("MainWindow", "输出信息")) 224 | self.output_empty_button.setText(_translate("MainWindow", "清空")) 225 | self.gpu_combo_box.setToolTip(_translate("MainWindow", "CPU处理一幅图像较慢,视硬件不同大概需要20秒到1分钟,请耐心等待;\n" 226 | "使用GPU需安装CUDA环境否则仍然会用CPU进行处理,且显存小于8GB的显卡在处理某些图像时可能会爆显存而输出错误的分割结果,这种情况下请尝试使用CPU处理该图片")) 227 | self.gpu_combo_box.setItemText(0, _translate("MainWindow", "CPU")) 228 | self.gpu_combo_box.setItemText(1, _translate("MainWindow", "GPU")) 229 | self.save_seg_button.setText(_translate("MainWindow", "保存分割结果")) 230 | self.label.setText(_translate("MainWindow", "是否使用GPU\n" 231 | "(需安装CUDA)")) 232 | self.seg_confirm_button.setText(_translate("MainWindow", "进行分割")) 233 | self.label_4.setText(_translate("MainWindow", "语义分割模型\n" 234 | "(越往下精度越高 速度越慢)")) 235 | self.browser_button.setText(_translate("MainWindow", "浏览")) 236 | self.display_button.setText(_translate("MainWindow", "显示")) 237 | self.model_combo_box.setItemText(0, _translate("MainWindow", "ade20k-mobilenetv2dilated-c1_deepsup")) 238 | self.model_combo_box.setItemText(1, _translate("MainWindow", "ade20k-resnet18dilated-c1_deepsup")) 239 | self.model_combo_box.setItemText(2, _translate("MainWindow", "ade20k-resnet18dilated-ppm_deepsup")) 240 | self.model_combo_box.setItemText(3, _translate("MainWindow", "ade20k-resnet50dilated-ppm_deepsup")) 241 | self.model_combo_box.setItemText(4, _translate("MainWindow", "ade20k-resnet101dilated-ppm_deepsup")) 242 | self.model_combo_box.setItemText(5, _translate("MainWindow", "ade20k-resnet50-upernet")) 243 | self.model_combo_box.setItemText(6, _translate("MainWindow", "ade20k-resnet101-upernet")) 244 | self.model_combo_box.setItemText(7, _translate("MainWindow", "ade20k-hrnetv2")) 245 | self.open_pic_combo_box.setItemText(0, _translate("MainWindow", "打开图片")) 246 | self.open_pic_combo_box.setItemText(1, _translate("MainWindow", "打开图片文件夹")) 247 | self.hide_button.setText(_translate("MainWindow", "隐藏")) 248 | self.label_2.setText(_translate("MainWindow", "显示隐藏分割图层\n" 249 | "(可以用ctrl和shift多选)")) 250 | self.label_3.setText(_translate("MainWindow", "分割效果预览")) 251 | self.menu.setTitle(_translate("MainWindow", "菜单")) 252 | self.menu_2.setTitle(_translate("MainWindow", "主题")) 253 | self.menu_3.setTitle(_translate("MainWindow", "关于")) 254 | 
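        # pyuic5 routes each user-visible string through QCoreApplication.translate so a
        # QTranslator could swap languages at runtime; the pattern is simply
        # widget.setText(_translate("MainWindow", "text")), with "MainWindow" as the
        # translation context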
self.open_pic.setText(_translate("MainWindow", "打开图片"))
255 |         self.save_seg.setText(_translate("MainWindow", "保存分割结果"))
256 |         self.exit.setText(_translate("MainWindow", "退出"))
257 |         self.action_dark.setText(_translate("MainWindow", "dark"))
258 |         self.action_light.setText(_translate("MainWindow", "light"))
259 |         self.author.setText(_translate("MainWindow", "作者信息"))
260 |         self.open_pic_folder.setText(_translate("MainWindow", "打开图片文件夹"))
261 |
--------------------------------------------------------------------------------
/gui.ui:
--------------------------------------------------------------------------------
[Qt Designer definition of MainWindow; the XML markup was not preserved in this dump,
leaving only stray property values. gui.py above is the pyuic5-generated equivalent of
this file.]
--------------------------------------------------------------------------------
/gui_main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | from PIL import Image
6 | from PIL.ImageQt import ImageQt
7 | from PyQt5 import QtWidgets
8 | from PyQt5.QtCore import QFile, QTextStream, Qt, QTimer, pyqtSignal
9 | # importing breeze_resources below registers the compiled ':/dark.qss' and ':/light.qss' stylesheets
10 | from PyQt5.QtGui import QPixmap
11 | from PyQt5.QtWidgets import QFileDialog, qApp, QMessageBox, QAbstractItemView, QMainWindow, QApplication
12 | from BreezeStyleSheets import breeze_resources
13 |
14 | from test import arg_from_ui
15 | import gui
16 |
17 |
18 | class Steam:
19 |     def __init__(self, view):
20 |         self.view = view
21 |
22 |     def write(self, *args):
23 |         self.view.append(*args)
24 |
25 |     def flush(self):
26 |         return
27 |
28 |
29 | class Signals(object):
30 |     @staticmethod
31 |     def open_pic(ui, text):
32 |         options = QFileDialog.Options()
33 |         options |= QFileDialog.DontUseNativeDialog
34 |         try:
35 |             if text == '打开图片':
36 |                 fileName, _ = QFileDialog.getOpenFileName(None, "打开图片", "",
37 |                                                           "JPEG图片 (*.jpg);;PNG图片(*.png)", options=options)
38 |             elif text == '打开图片文件夹':
39 |                 fileName = QFileDialog.getExistingDirectory(None, "打开图片文件夹", "", options=options)
40 |             if fileName:
41 |                 print('选择路径:', fileName)
42 |                 ui.lineEdit.setText(fileName)
43 |         except ValueError:
44 |             print('没有选择文件')
45 |
46 |     @staticmethod
47 |     def stack_images(directory, layer_set):
48 |         if len(layer_set) == 0:
49 |             return None
50 |         for index, item in enumerate(layer_set):
51 |             if index == 0:
52 |                 back = Image.open(os.path.join(directory, item))
53 |                 continue
54 |             fore = Image.open(os.path.join(directory, item))
55 |             back = Image.alpha_composite(back, fore)
56 |         return ImageQt(back)
57 |
58 |     def save_segmentation(self, seg_dir, layer_set):
59 |         if hasattr(self, 'seg_dir'):
60 |             for index, item in enumerate(layer_set):
61 |                 if index == 0:
62 |                     back = Image.open(os.path.join(seg_dir, item))
63 |                     continue
64 |                 fore = Image.open(os.path.join(seg_dir, item))
65 |                 back = Image.alpha_composite(back, fore)
66 |             options = QFileDialog.Options()
67 |             options |= QFileDialog.DontUseNativeDialog
68 |             try:
69 |                 fileName, _ = QFileDialog.getSaveFileName(None, "保存当前显示分割结果", "",
70 |                                                           "PNG图片(*.png)", options=options)
71 |                 if fileName:
72 |                     if fileName.find('.png') == -1:
73 |                         back.save(fileName + '.png')
74 |                         print('保存至', fileName + '.png')
75 |                     else:
76 |                         back.save(fileName)
77 |                         print('保存至', fileName)
78 |             except ValueError:
79 |                 print('保存时出现未知错误')
80 |
81 |     @staticmethod
82 |     def show_author():
83 |         msg = QMessageBox()
84 |         msg.setIcon(QMessageBox.Information)
85 |         msg.setText("作者:Mr.ET")
86 |         msg.setInformativeText("作者邮箱:1900432020@email.szu.edu.cn")
87 |         msg.setWindowTitle("作者信息")
88 |         msg.setStandardButtons(QMessageBox.Ok)
89 |         msg.exec_()
90 |
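    # gpu_mode below is a class-level attribute: show_gpu_information() pops its warning
    # dialog only when the newly selected device differs from the one it last reported.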
91 |     gpu_mode = None
92 |     def show_gpu_information(self, mode):
93 |         if self.gpu_mode != mode:
94 |             self.gpu_mode = mode
95 |             msg = QMessageBox()
96 |             msg.setIcon(QMessageBox.Warning)
97 |             if mode == 'CPU':
98 |                 msg.setText("CPU处理一幅图像较慢,视硬件不同大概需要20秒到1分钟")
99 |                 msg.setInformativeText("由于处理较慢,执行过程中程序可能出现鼠标转圈的假死现象,请耐心等待")
100 |                 msg.setWindowTitle("使用CPU进行分割")
101 |             elif mode == 'GPU':
102 |                 msg.setText("使用GPU分割需安装CUDA环境(本程序使用CUDA10.1环境进行开发),否则仍然会用CPU进行处理,安装链接:https://developer.nvidia.com/cuda-downloads")
103 |                 msg.setInformativeText("显存小于8GB的显卡处理某些图片可能会爆显存,这种情况下建议使用CPU处理该图片")
104 |                 msg.setWindowTitle("使用GPU进行分割")
105 |             msg.setStandardButtons(QMessageBox.Ok)
106 |             msg.exec_()
107 |
108 |     def change_seg_folder(self, ui):
109 |         self.layer_set.clear()
110 |         for file in os.listdir(self.seg_dir):
111 |             if file.endswith(".png"):
112 |                 if file.find('seg') != -1 or file.find('org') != -1:
113 |                     continue
114 |                 self.layer_set.add(file)
115 |
116 |         ui.listWidget.clear()
117 |         ui.listWidget.addItems(list(self.layer_set))
118 |         self.refresh(ui)
119 |
120 |     def change_seg_combo_box(self, ui):
121 |         self.seg_dir = os.path.join(self.result, ui.seg_combo_box.currentText())
122 |         self.change_seg_folder(ui)
123 |
124 |     def refresh(self, ui, change=True):
125 |         if hasattr(self, 'seg_dir'):
126 |             if change:
127 |                 self.qim = self.stack_images(self.seg_dir, self.layer_set)
128 |             if self.qim:
129 |                 pix_map = QPixmap.fromImage(self.qim)
130 |                 scaled_pix_map = pix_map.scaled(ui.show_layers.size(), Qt.KeepAspectRatio)
131 |                 ui.show_layers.setPixmap(scaled_pix_map)
132 |             else:
133 |                 # clear the preview when there is no layer to show
134 |                 ui.show_layers.clear()
135 |
136 |     def change_layers(self, ui, operation):
137 |         if len(ui.listWidget.selectedItems()) == 0:
138 |             return
139 |         for item in ui.listWidget.selectedItems():
140 |             # print(item.text())
141 |             item_name = item.text()
142 |             # hide operation
143 |             if operation == 'hide':
144 |                 if item_name in self.layer_set:
145 |                     ui.listWidget.findItems(item_name, Qt.MatchExactly)[0].setForeground(Qt.gray)
146 |                     self.layer_set.remove(item_name)
147 |             # show operation
148 |             elif operation == 'show':
149 |                 if item_name not in self.layer_set:
150 |                     self.layer_set.add(item_name)
151 |                     if ui.mode == 'dark':
152 |                         ui.listWidget.findItems(item_name, Qt.MatchExactly)[0].setForeground(Qt.white)
153 |                     elif ui.mode == 'light':
154 |                         ui.listWidget.findItems(item_name, Qt.MatchExactly)[0].setForeground(Qt.black)
155 |         self.refresh(ui)
156 |
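    # seg_confirm() below gathers everything test.py needs from the UI state: the selected
    # model folder must hold encoder_epoch_<N>.pth / decoder_epoch_<N>.pth (with <N> the
    # final training epoch hard-coded per architecture in pth_dict) plus a matching
    # config/<model>.yaml. A sketch of the lookup for one illustrative choice:
    #
    #     model = 'ade20k-resnet50dilated-ppm_deepsup'
    #     epoch = 20  # pth_dict[model]
    #     encoder_path = os.path.join(model, 'encoder_epoch_' + str(epoch) + '.pth')
    #     # -> 'ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth'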
157 |     def seg_confirm(self, ui):
158 |         if ui.gpu_combo_box.currentText() == 'CPU':
159 |             gpu_flag = False
160 |             print('CPU处理一幅图像较慢,视硬件不同大概需要20秒到1分钟,请耐心等待')
161 |         elif ui.gpu_combo_box.currentText() == 'GPU':
162 |             gpu_flag = True
163 |             print('使用GPU需安装CUDA环境否则仍然会用CPU进行处理,且显存小于8GB的显卡处理某些图片可能会爆显存,爆显存的图片建议使用CPU处理该图片')
164 |         self.show_gpu_information(ui.gpu_combo_box.currentText())
165 |         imgs = ui.lineEdit.text()
166 |         if not os.path.exists(imgs):
167 |             print('没有找到' + str(ui.lineEdit.text()) + '图片或图片文件夹,请检查路径是否正确')
168 |             return
169 |         model = ui.model_combo_box.currentText()
170 |         if not os.path.exists(model):
171 |             print('在程序目录下没有找到' + str(model) + '文件夹,请检查是否已经下载选择的模型并放入程序目录下')
172 |             return
173 |         cfg_path = os.path.join('config', str(model) + '.yaml')
174 |         if not os.path.exists(cfg_path):
175 |             print('没有找到' + str(cfg_path) + '配置文件,请检查是否已经下载yaml配置文件并放入config目录下')
176 |             return
177 |         pth_dict = {'ade20k-hrnetv2': 30, 'ade20k-mobilenetv2dilated-c1_deepsup': 20,
178 |                     'ade20k-resnet18dilated-c1_deepsup': 20,
179 |                     'ade20k-resnet18dilated-ppm_deepsup': 20, 'ade20k-resnet50dilated-ppm_deepsup': 20,
180 |                     'ade20k-resnet50-upernet': 30, 'ade20k-resnet101dilated-ppm_deepsup': 25,
181 |                     'ade20k-resnet101-upernet': 50}
182 |         encoder_path = os.path.join(model, 'encoder_epoch_' + str(pth_dict[model]) + '.pth')
183 |         decoder_path = os.path.join(model, 'decoder_epoch_' + str(pth_dict[model]) + '.pth')
184 |         if not os.path.exists(encoder_path):
185 |             print('没有找到' + str(encoder_path) + 'pth文件,请检查是否已经下载选择的模型并放入程序目录下')
186 |             return
187 |         if not os.path.exists(decoder_path):
188 |             print('没有找到' + str(decoder_path) + 'pth文件,请检查是否已经下载选择的模型并放入程序目录下')
189 |             return
190 |
191 |         for file in os.listdir(model):
192 |             if file.endswith(".pth"):
193 |                 checkpoint = file[file.find('epoch'):]
194 |                 break
195 |         self.result = 'segmentation'
196 |         ui.progressBar.setValue(0)
197 |         arg_from_ui(imgs=imgs, progress=ui.progressBar, gpu_flag=gpu_flag, config_path=cfg_path,
198 |                     dir=model, checkpoint=checkpoint, result=self.result)
199 |
200 |         self.layer_set = set()
201 |         # if a folder was chosen, add its images to the selection combo box
202 |         if os.path.isdir(imgs):
203 |             ui.seg_combo_box.clear()
204 |             for file in os.listdir(imgs):
205 |                 if file.endswith(".png") or file.endswith(".jpg"):
206 |                     ui.seg_combo_box.addItem(os.path.splitext(file)[0])
207 |             self.seg_dir = os.path.join(self.result, os.path.splitext(os.path.basename(os.listdir(imgs)[0]))[0])
208 |
209 |         # if a single image was given, display it directly
210 |         elif os.path.isfile(imgs):
211 |             self.seg_dir = os.path.join(self.result, os.path.splitext(os.path.basename(imgs))[0])
212 |         self.change_seg_folder(ui)
213 |         ui.display_button.setEnabled(True)
214 |         ui.hide_button.setEnabled(True)
215 |         ui.save_seg_button.setEnabled(True)
216 |
217 |     def register_signal(self, app, window, ui):
218 |         sys.stdout = Steam(ui.textBrowser)
219 |         ui.browser_button.clicked.connect(lambda: self.open_pic(ui, ui.open_pic_combo_box.currentText()))
220 |         ui.open_pic.triggered.connect(lambda: self.open_pic(ui, '打开图片'))
221 |         ui.open_pic_folder.triggered.connect(lambda: self.open_pic(ui, '打开图片文件夹'))
222 |         ui.seg_confirm_button.clicked.connect(lambda: self.seg_confirm(ui))
223 |         ui.output_empty_button.clicked.connect(lambda: ui.textBrowser.setText(''))
224 |         ui.show_layers.setMinimumSize(1, 1)
225 |         ui.show_layers.setAlignment(Qt.AlignCenter)
226 |         ui.listWidget.setSelectionMode(QAbstractItemView.ExtendedSelection)
227 |         ui.display_button.clicked.connect(lambda: self.change_layers(ui, 'show'))
228 |         ui.hide_button.clicked.connect(lambda: self.change_layers(ui, 'hide'))
229 |         ui.save_seg_button.clicked.connect(lambda x: self.save_segmentation(self.seg_dir, self.layer_set))
230 |         ui.save_seg.triggered.connect(
231 |             lambda x: self.save_segmentation(self.seg_dir, self.layer_set) if hasattr(self, 'seg_dir') else x)
232 |         window.resized.connect(lambda: self.refresh(ui, False))
233 |         ui.seg_combo_box.currentTextChanged.connect(lambda: self.change_seg_combo_box(ui))
234 |         ui.action_dark.triggered.connect(lambda: set_theme(app, ui, 'dark'))
235 |         ui.action_light.triggered.connect(lambda: set_theme(app, ui, 'light'))
236 |         ui.author.triggered.connect(lambda: self.show_author())
237 |         ui.exit.triggered.connect(qApp.quit)
238 |
239 |
240 | class Window(QMainWindow):
241 |     resized = pyqtSignal()
242 |
243 |     def __init__(self, parent=None):
244 |         super(Window, self).__init__(parent=parent)
245 |
246 |     def resizeEvent(self, event):
247 |         self.resized.emit()
248 |         return super(Window, self).resizeEvent(event)
249 |
250 |
251 | def set_theme(app, ui, mode):
252 |     ui.mode = mode
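    # ':/dark.qss' and ':/light.qss' below are Qt resource paths; they resolve only
    # because gui_main imported BreezeStyleSheets.breeze_resources (compiled from
    # breeze.qrc). Without that import, QFile.open() would fail and the applied
    # stylesheet would come back empty.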
+ ".qss") 255 | file.open(QFile.ReadOnly | QFile.Text) 256 | stream = QTextStream(file) 257 | app.setStyleSheet(stream.readAll()) 258 | 259 | 260 | def main(): 261 | """ 262 | Application entry point 263 | """ 264 | # logging.basicConfig(level=logging.DEBUG) 265 | # create the application and the main window 266 | app = QApplication(sys.argv) 267 | # app.setStyle(QtWidgets.QStyleFactory.create("fusion")) 268 | # window = QtWidgets.QMainWindow() 269 | window = Window() 270 | 271 | # 设置UI界面 272 | ui = gui.Ui_MainWindow() 273 | ui.setupUi(window) 274 | # 设置默认主题为dark 275 | set_theme(app, ui, 'dark') 276 | 277 | # register multiple signals 278 | signals = Signals() 279 | signals.register_signal(app, window, ui) 280 | 281 | # auto quit after 2s when testing on travis-ci 282 | # if "--travis" in sys.argv: 283 | # QTimer.singleShot(2000, app.exit) 284 | 285 | # run 286 | window.show() 287 | # app.exec_() 288 | sys.exit(app.exec_()) 289 | 290 | 291 | if __name__ == "__main__": 292 | main() 293 | -------------------------------------------------------------------------------- /lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /lib/nn/modules/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 |     return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 |     return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 |     def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
40 |         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 |         self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 |         self._is_parallel = False
45 |         self._parallel_id = None
46 |         self._slave_pipe = None
47 |
48 |         # custom batch norm statistics
49 |         self._moving_average_fraction = 1. - momentum
50 |         self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
51 |         self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
52 |         self.register_buffer('_running_iter', torch.ones(1))
53 |         self._tmp_running_mean = self.running_mean.clone() * self._running_iter
54 |         self._tmp_running_var = self.running_var.clone() * self._running_iter
55 |
56 |     def forward(self, input):
57 |         # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
58 |         if not (self._is_parallel and self.training):
59 |             return F.batch_norm(
60 |                 input, self.running_mean, self.running_var, self.weight, self.bias,
61 |                 self.training, self.momentum, self.eps)
62 |
63 |         # Resize the input to (B, C, -1).
64 |         input_shape = input.size()
65 |         input = input.view(input.size(0), self.num_features, -1)
66 |
67 |         # Compute the sum and square-sum.
68 |         sum_size = input.size(0) * input.size(2)
69 |         input_sum = _sum_ft(input)
70 |         input_ssum = _sum_ft(input ** 2)
71 |
72 |         # Reduce-and-broadcast the statistics.
73 |         if self._parallel_id == 0:
74 |             mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
75 |         else:
76 |             mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
77 |
78 |         # Compute the output.
79 |         if self.affine:
80 |             # MJY:: Fuse the multiplication for speed.
81 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
82 |         else:
83 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
84 |
85 |         # Reshape it.
86 |         return output.view(input_shape)
87 |
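    # Synchronization protocol used by forward() above: every replica sends a
    # _ChildMessage(sum, ssum, sum_size) -- the master through run_master(), the others
    # through their SlavePipe -- and each receives back a _MasterMessage(mean, inv_std)
    # computed from the global statistics in _data_parallel_master() below.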
88 |     def __data_parallel_replicate__(self, ctx, copy_id):
89 |         self._is_parallel = True
90 |         self._parallel_id = copy_id
91 |
92 |         # parallel_id == 0 means master device.
93 |         if self._parallel_id == 0:
94 |             ctx.sync_master = self._sync_master
95 |         else:
96 |             self._slave_pipe = ctx.sync_master.register_slave(copy_id)
97 |
98 |     def _data_parallel_master(self, intermediates):
99 |         """Reduce the sum and square-sum, compute the statistics, and broadcast them."""
100 |         intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
101 |
102 |         to_reduce = [i[1][:2] for i in intermediates]
103 |         to_reduce = [j for i in to_reduce for j in i]  # flatten
104 |         target_gpus = [i[1].sum.get_device() for i in intermediates]
105 |
106 |         sum_size = sum([i[1].sum_size for i in intermediates])
107 |         sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
108 |
109 |         mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
110 |
111 |         broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
112 |
113 |         outputs = []
114 |         for i, rec in enumerate(intermediates):
115 |             outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
116 |
117 |         return outputs
118 |
119 |     def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
120 |         """return *dest* updated as `dest := dest*alpha + delta*beta + bias`"""
121 |         return dest * alpha + delta * beta + bias
122 |
123 |     def _compute_mean_std(self, sum_, ssum, size):
124 |         """Compute the mean and standard-deviation with sum and square-sum. This method
125 |         also maintains the moving average on the master device."""
126 |         assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
127 |         mean = sum_ / size
128 |         sumvar = ssum - sum_ * mean
129 |         unbias_var = sumvar / (size - 1)
130 |         bias_var = sumvar / size
131 |
132 |         self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
133 |         self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
134 |         self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
135 |
136 |         self.running_mean = self._tmp_running_mean / self._running_iter
137 |         self.running_var = self._tmp_running_var / self._running_iter
138 |
139 |         return mean, bias_var.clamp(self.eps) ** -0.5
140 |
141 |
142 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
143 |     r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
144 |     mini-batch.
145 |
146 |     .. math::
147 |
148 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
149 |
150 |     This module differs from the built-in PyTorch BatchNorm1d as the mean and
151 |     standard-deviation are reduced across all devices during training.
152 |
153 |     For example, when one uses `nn.DataParallel` to wrap the network during
154 |     training, PyTorch's implementation normalizes the tensor on each device using
155 |     the statistics only on that device, which accelerates the computation and
156 |     is also easy to implement, but the statistics might be inaccurate.
157 |     Instead, in this synchronized version, the statistics will be computed
158 |     over all training samples distributed on multiple devices.
159 |
160 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
161 |     as the built-in PyTorch implementation.
162 |
163 |     The mean and standard-deviation are calculated per-dimension over
164 |     the mini-batches and gamma and beta are learnable parameter vectors
165 |     of size C (where C is the input size).
166 |
167 |     During training, this layer keeps a running estimate of its computed mean
168 |     and variance. The running sum is kept with a default momentum of 0.1.
169 |
170 |     During evaluation, this running mean/variance is used for normalization.
171 |
172 |     Because the BatchNorm is done over the `C` dimension, computing statistics
173 |     on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
174 |
175 |     Args:
176 |         num_features: num_features from an expected input of size
177 |             `batch_size x num_features [x width]`
178 |         eps: a value added to the denominator for numerical stability.
179 |             Default: 1e-5
180 |         momentum: the value used for the running_mean and running_var
181 |             computation. Default: 0.1
182 |         affine: a boolean value that when set to ``True``, gives the layer learnable
183 |             affine parameters. Default: ``True``
184 |
185 |     Shape:
186 |         - Input: :math:`(N, C)` or :math:`(N, C, L)`
187 |         - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
188 |
189 |     Examples:
190 |         >>> # With Learnable Parameters
191 |         >>> m = SynchronizedBatchNorm1d(100)
192 |         >>> # Without Learnable Parameters
193 |         >>> m = SynchronizedBatchNorm1d(100, affine=False)
194 |         >>> input = torch.autograd.Variable(torch.randn(20, 100))
195 |         >>> output = m(input)
196 |     """
197 |
198 |     def _check_input_dim(self, input):
199 |         if input.dim() != 2 and input.dim() != 3:
200 |             raise ValueError('expected 2D or 3D input (got {}D input)'
201 |                              .format(input.dim()))
202 |         super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
203 |
204 |
205 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
206 |     r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
207 |     of 3d inputs
208 |
209 |     .. math::
210 |
211 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
212 |
213 |     This module differs from the built-in PyTorch BatchNorm2d as the mean and
214 |     standard-deviation are reduced across all devices during training.
215 |
216 |     For example, when one uses `nn.DataParallel` to wrap the network during
217 |     training, PyTorch's implementation normalizes the tensor on each device using
218 |     the statistics only on that device, which accelerates the computation and
219 |     is also easy to implement, but the statistics might be inaccurate.
220 |     Instead, in this synchronized version, the statistics will be computed
221 |     over all training samples distributed on multiple devices.
222 |
223 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
224 |     as the built-in PyTorch implementation.
225 |
226 |     The mean and standard-deviation are calculated per-dimension over
227 |     the mini-batches and gamma and beta are learnable parameter vectors
228 |     of size C (where C is the input size).
229 |
230 |     During training, this layer keeps a running estimate of its computed mean
231 |     and variance. The running sum is kept with a default momentum of 0.1.
232 |
233 |     During evaluation, this running mean/variance is used for normalization.
234 |
235 |     Because the BatchNorm is done over the `C` dimension, computing statistics
236 |     on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.
237 |
238 |     Args:
239 |         num_features: num_features from an expected input of
240 |             size batch_size x num_features x height x width
241 |         eps: a value added to the denominator for numerical stability.
242 |             Default: 1e-5
243 |         momentum: the value used for the running_mean and running_var
244 |             computation. Default: 0.1
245 |         affine: a boolean value that when set to ``True``, gives the layer learnable
246 |             affine parameters. Default: ``True``
247 |
248 |     Shape:
249 |         - Input: :math:`(N, C, H, W)`
250 |         - Output: :math:`(N, C, H, W)` (same shape as input)
251 |
252 |     Examples:
253 |         >>> # With Learnable Parameters
254 |         >>> m = SynchronizedBatchNorm2d(100)
255 |         >>> # Without Learnable Parameters
256 |         >>> m = SynchronizedBatchNorm2d(100, affine=False)
257 |         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
258 |         >>> output = m(input)
259 |     """
260 |
261 |     def _check_input_dim(self, input):
262 |         if input.dim() != 4:
263 |             raise ValueError('expected 4D input (got {}D input)'
264 |                              .format(input.dim()))
265 |         super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
266 |
267 |
268 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
269 |     r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
270 |     of 4d inputs
271 |
272 |     .. math::
273 |
274 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
275 |
276 |     This module differs from the built-in PyTorch BatchNorm3d as the mean and
277 |     standard-deviation are reduced across all devices during training.
278 |
279 |     For example, when one uses `nn.DataParallel` to wrap the network during
280 |     training, PyTorch's implementation normalizes the tensor on each device using
281 |     the statistics only on that device, which accelerates the computation and
282 |     is also easy to implement, but the statistics might be inaccurate.
283 |     Instead, in this synchronized version, the statistics will be computed
284 |     over all training samples distributed on multiple devices.
285 |
286 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
287 |     as the built-in PyTorch implementation.
288 |
289 |     The mean and standard-deviation are calculated per-dimension over
290 |     the mini-batches and gamma and beta are learnable parameter vectors
291 |     of size C (where C is the input size).
292 |
293 |     During training, this layer keeps a running estimate of its computed mean
294 |     and variance. The running sum is kept with a default momentum of 0.1.
295 |
296 |     During evaluation, this running mean/variance is used for normalization.
297 |
298 |     Because the BatchNorm is done over the `C` dimension, computing statistics
299 |     on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
300 |     or Spatio-temporal BatchNorm.
301 |
302 |     Args:
303 |         num_features: num_features from an expected input of
304 |             size batch_size x num_features x depth x height x width
305 |         eps: a value added to the denominator for numerical stability.
306 |             Default: 1e-5
307 |         momentum: the value used for the running_mean and running_var
308 |             computation. Default: 0.1
309 |         affine: a boolean value that when set to ``True``, gives the layer learnable
310 |             affine parameters. Default: ``True``
311 |
312 |     Shape:
313 |         - Input: :math:`(N, C, D, H, W)`
314 |         - Output: :math:`(N, C, D, H, W)` (same shape as input)
315 |
316 |     Examples:
317 |         >>> # With Learnable Parameters
318 |         >>> m = SynchronizedBatchNorm3d(100)
319 |         >>> # Without Learnable Parameters
320 |         >>> m = SynchronizedBatchNorm3d(100, affine=False)
321 |         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
322 |         >>> output = m(input)
323 |     """
324 |
325 |     def _check_input_dim(self, input):
326 |         if input.dim() != 5:
327 |             raise ValueError('expected 5D input (got {}D input)'
328 |                              .format(input.dim()))
329 |         super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
330 |
--------------------------------------------------------------------------------
/lib/nn/modules/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as a one-to-one pipe."""
20 |
21 |     def __init__(self):
22 |         self._result = None
23 |         self._lock = threading.Lock()
24 |         self._cond = threading.Condition(self._lock)
25 |
26 |     def put(self, result):
27 |         with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 |             self._result = result
30 |             self._cond.notify()
31 |
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 |
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 |
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 |
55 |
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 |
59 |     - During replication, as the data parallel will trigger a callback on each module, all slave devices should
60 |       call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be
62 |       collected, and passed to a registered callback.
63 |     - After receiving the messages, the master device should gather the information and determine the message to
64 |       be passed back to each slave device.
65 |     """
66 |
67 |     def __init__(self, master_callback):
68 |         """
69 |
70 |         Args:
71 |             master_callback: a callback to be invoked after having collected messages from slave devices.
72 |         """
73 |         self._master_callback = master_callback
74 |         self._queue = queue.Queue()
75 |         self._registry = collections.OrderedDict()
76 |         self._activated = False
77 |
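    # One round of the protocol, assuming two replicas (a sketch; the slave call runs in
    # its own thread):
    #
    #     pipe = master.register_slave(1)   # during replication, on device 1
    #     out0 = master.run_master(msg0)    # device 0 blocks, collecting all messages
    #     out1 = pipe.run_slave(msg1)       # device 1 sends msg1 and waits for its reply
    #
    # run_master() hands [(0, msg0), (1, msg1)] to master_callback and routes each result
    # back to the device that sent the corresponding message.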
86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages are first collected from each device (including the master device), and then 100 | a callback is invoked to compute the message to be sent back to each device 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belong to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be invoked before the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback.
53 | 54 | A replication callback `__data_parallel_replicate__` on each module will be invoked after the module is created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
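#
# (Editorial sketch, hedged: the commented snippet below is not part of this
# file; it illustrates how a module can opt into the replication callback
# defined in replicate.py above. The class name `_HookedModule` is
# hypothetical.)
#
#     import torch.nn as nn
#
#     class _HookedModule(nn.Module):
#         def __data_parallel_replicate__(self, ctx, copy_id):
#             # `ctx` is shared by all copies of this module across devices;
#             # copy_id 0 is the master copy, so it can stash shared state.
#             if copy_id == 0:
#                 ctx.master_copy = self
#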
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)  # BatchNorm1d: the input below is 2D (N, C) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
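#
# (Editorial note, hedged: the `handy_var` helpers in these test files compute
# the variance via the sum-of-squares identity
#     sum((a - mean)^2) = sum(a^2) - (sum(a))^2 / n.
# A tiny worked example: for a = [1., 2., 3.], as_sum = 14, asum = 6, n = 3,
# so sumvar = 14 - 36/3 = 2 and the unbiased variance is 2 / (3 - 1) = 1.0,
# matching torch.tensor([1., 2., 3.]).var().)
#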
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /lib/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol, rtol=rtol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | import collections 7 | from torch.nn.parallel._functions import Gather 8 | 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support.
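Example (an illustrative sketch; the dict layout and the names `p0`, `p1`
are hypothetical):
    > per_gpu_outputs = [{'pred': p0}, {'pred': p1}]  # one dict per GPU
    > merged = dict_gather(per_gpu_outputs, target_device=0)
    > # merged['pred'] is p0 and p1 gathered (concatenated along `dim`) on GPU 0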
31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /lib/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataset import Dataset, TensorDataset, ConcatDataset 3 | from .dataloader import DataLoader 4 | -------------------------------------------------------------------------------- /lib/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. 
All subclasses should override 12 | ``__len__``, which provides the size of the dataset, and ``__getitem__``, 13 | which supports integer indexing in the range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | 26 | class TensorDataset(Dataset): 27 | """Dataset wrapping data and target tensors. 28 | 29 | Each sample will be retrieved by indexing both tensors along the first 30 | dimension. 31 | 32 | Arguments: 33 | data_tensor (Tensor): contains sample data. 34 | target_tensor (Tensor): contains sample targets (labels). 35 | """ 36 | 37 | def __init__(self, data_tensor, target_tensor): 38 | assert data_tensor.size(0) == target_tensor.size(0) 39 | self.data_tensor = data_tensor 40 | self.target_tensor = target_tensor 41 | 42 | def __getitem__(self, index): 43 | return self.data_tensor[index], self.target_tensor[index] 44 | 45 | def __len__(self): 46 | return self.data_tensor.size(0) 47 | 48 | 49 | class ConcatDataset(Dataset): 50 | """ 51 | Dataset to concatenate multiple datasets. 52 | Purpose: useful for assembling different existing datasets, possibly 53 | large-scale ones, as the concatenation operation is done in an 54 | on-the-fly manner. 55 | 56 | Arguments: 57 | datasets (iterable): List of datasets to be concatenated 58 | """ 59 | 60 | @staticmethod 61 | def cumsum(sequence): 62 | r, s = [], 0 63 | for e in sequence: 64 | l = len(e) 65 | r.append(l + s) 66 | s += l 67 | return r 68 | 69 | def __init__(self, datasets): 70 | super(ConcatDataset, self).__init__() 71 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 72 | self.datasets = list(datasets) 73 | self.cumulative_sizes = self.cumsum(self.datasets) 74 | 75 | def __len__(self): 76 | return self.cumulative_sizes[-1] 77 | 78 | def __getitem__(self, idx): 79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 80 | if dataset_idx == 0: 81 | sample_idx = idx 82 | else: 83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 84 | return self.datasets[dataset_idx][sample_idx] 85 | 86 | @property 87 | def cummulative_sizes(self): 88 | warnings.warn("cummulative_sizes attribute is renamed to " 89 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 90 | return self.cumulative_sizes 91 | 92 | 93 | class Subset(Dataset): 94 | def __init__(self, dataset, indices): 95 | self.dataset = dataset 96 | self.indices = indices 97 | 98 | def __getitem__(self, idx): 99 | return self.dataset[self.indices[idx]] 100 | 101 | def __len__(self): 102 | return len(self.indices) 103 | 104 | 105 | def random_split(dataset, lengths): 106 | """ 107 | Randomly split a dataset into non-overlapping new datasets of given lengths. 108 | 109 | 110 | Arguments: 111 | dataset (Dataset): Dataset to be split 112 | lengths (iterable): lengths of splits to be produced 113 | """ 114 | if sum(lengths) != len(dataset): 115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 116 | 117 | indices = randperm(sum(lengths)) 118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 119 | -------------------------------------------------------------------------------- /lib/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from .sampler import Sampler 4
| from torch.distributed import get_world_size, get_rank 5 | 6 | 7 | class DistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = get_world_size() 28 | if rank is None: 29 | rank = get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = list(torch.randperm(len(self.dataset), generator=g)) 42 | 43 | # add extra samples to make it evenly divisible 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | offset = self.num_samples * self.rank 49 | indices = indices[offset:offset + self.num_samples] 50 | assert len(indices) == self.num_samples 51 | 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | -------------------------------------------------------------------------------- /lib/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterator. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement.
58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list): a list of weights, not necessarily summing up to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement. 80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch. 102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /lib/utils/th.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import collections 5 | 6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 7 | 8 | def as_variable(obj): 9 | if isinstance(obj, Variable): 10 | return obj 11 | if isinstance(obj, collections.Sequence): 12 | return [as_variable(v) for v in obj] 13 | elif isinstance(obj, collections.Mapping): 14 | return {k: as_variable(v) for k, v in obj.items()} 15 | else: 16 | return Variable(obj) 17 | 18 | def as_numpy(obj): 19 | if isinstance(obj, collections.Sequence): 20 | return [as_numpy(v) for v in obj] 21 | elif isinstance(obj, collections.Mapping): 22 | return {k: as_numpy(v) for k, v in obj.items()} 23 | elif isinstance(obj, Variable): 24 | return obj.data.cpu().numpy() 25 | elif torch.is_tensor(obj): 26 | return obj.cpu().numpy() 27 | else: 28 | return np.array(obj) 29 | 30 | def mark_volatile(obj): 31 | if torch.is_tensor(obj): 32 | obj = Variable(obj) 33 | if isinstance(obj, Variable): 34 |
obj.no_grad = True  # NOTE: `volatile` was removed in PyTorch 0.4; this attribute assignment has no autograd effect and is kept only so old callers do not break 35 | return obj 36 | elif isinstance(obj, collections.Mapping): 37 | return {k: mark_volatile(o) for k, o in obj.items()} 38 | elif isinstance(obj, collections.Sequence): 39 | return [mark_volatile(o) for o in obj] 40 | else: 41 | return obj 42 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ModelBuilder, SegmentationModule 2 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This MobileNetV2 implementation is modified from the following repository: 3 | https://github.com/tonylins/pytorch-mobilenet-v2 4 | """ 5 | 6 | import torch.nn as nn 7 | import math 8 | from .utils import load_url 9 | from lib.nn import SynchronizedBatchNorm2d 10 | 11 | BatchNorm2d = SynchronizedBatchNorm2d 12 | 13 | 14 | __all__ = ['mobilenetv2'] 15 | 16 | 17 | model_urls = { 18 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', 19 | } 20 | 21 | 22 | def conv_bn(inp, oup, stride): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 25 | BatchNorm2d(oup), 26 | nn.ReLU6(inplace=True) 27 | ) 28 | 29 | 30 | def conv_1x1_bn(inp, oup): 31 | return nn.Sequential( 32 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 33 | BatchNorm2d(oup), 34 | nn.ReLU6(inplace=True) 35 | ) 36 | 37 | 38 | class InvertedResidual(nn.Module): 39 | def __init__(self, inp, oup, stride, expand_ratio): 40 | super(InvertedResidual, self).__init__() 41 | self.stride = stride 42 | assert stride in [1, 2] 43 | 44 | hidden_dim = round(inp * expand_ratio) 45 | self.use_res_connect = self.stride == 1 and inp == oup 46 | 47 | if expand_ratio == 1: 48 | self.conv = nn.Sequential( 49 | # dw 50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 51 | BatchNorm2d(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # pw-linear 54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 55 | BatchNorm2d(oup), 56 | ) 57 | else: 58 | self.conv = nn.Sequential( 59 | # pw 60 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 61 | BatchNorm2d(hidden_dim), 62 | nn.ReLU6(inplace=True), 63 | # dw 64 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 65 | BatchNorm2d(hidden_dim), 66 | nn.ReLU6(inplace=True), 67 | # pw-linear 68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 69 | BatchNorm2d(oup), 70 | ) 71 | 72 | def forward(self, x): 73 | if self.use_res_connect: 74 | return x + self.conv(x) 75 | else: 76 | return self.conv(x) 77 | 78 | 79 | class MobileNetV2(nn.Module): 80 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 81 | super(MobileNetV2, self).__init__() 82 | block = InvertedResidual 83 | input_channel = 32 84 | last_channel = 1280 85 | inverted_residual_setting = [ 86 | # t: expansion factor, c: output channels, n: number of repeats, s: stride of the first block 87 | [1, 16, 1, 1], 88 | [6, 24, 2, 2], 89 | [6, 32, 3, 2], 90 | [6, 64, 4, 2], 91 | [6, 96, 3, 1], 92 | [6, 160, 3, 2], 93 | [6, 320, 1, 1], 94 | ] 95 | 96 | # building first layer 97 | assert input_size % 32 == 0 98 | input_channel = int(input_channel * width_mult) 99 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 100 | self.features = [conv_bn(3, input_channel, 2)] 101 | # building inverted residual blocks 102 | for t, c, n, s in inverted_residual_setting: 103 |
output_channel = int(c * width_mult) 104 | for i in range(n): 105 | if i == 0: 106 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 107 | else: 108 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 109 | input_channel = output_channel 110 | # building last several layers 111 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 112 | # make it nn.Sequential 113 | self.features = nn.Sequential(*self.features) 114 | 115 | # building classifier 116 | self.classifier = nn.Sequential( 117 | nn.Dropout(0.2), 118 | nn.Linear(self.last_channel, n_class), 119 | ) 120 | 121 | self._initialize_weights() 122 | 123 | def forward(self, x): 124 | x = self.features(x) 125 | x = x.mean(3).mean(2) 126 | x = self.classifier(x) 127 | return x 128 | 129 | def _initialize_weights(self): 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 133 | m.weight.data.normal_(0, math.sqrt(2. / n)) 134 | if m.bias is not None: 135 | m.bias.data.zero_() 136 | elif isinstance(m, BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | elif isinstance(m, nn.Linear): 140 | n = m.weight.size(1) 141 | m.weight.data.normal_(0, 0.01) 142 | m.bias.data.zero_() 143 | 144 | 145 | def mobilenetv2(pretrained=False, **kwargs): 146 | """Constructs a MobileNetV2 model. 147 | 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | """ 151 | model = MobileNetV2(n_class=1000, **kwargs) 152 | if pretrained: 153 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) 154 | return model 155 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | from .utils import load_url 4 | from lib.nn import SynchronizedBatchNorm2d 5 | BatchNorm2d = SynchronizedBatchNorm2d 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']
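# (Added note, hedged: unlike torchvision's ResNet, this variant uses a "deep
# stem" of three 3x3 convolutions in place of the single 7x7 convolution,
# which is why `self.inplanes` starts at 128 below; the pretrained weights
# referenced in `model_urls` appear to assume this stem.)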
9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 13 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 14 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | "3x3 convolution with padding" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, num_classes=1000): 98 | self.inplanes = 128 99 | super(ResNet, self).__init__() 100 | self.conv1 = conv3x3(3, 64, stride=2) 101 | self.bn1 = BatchNorm2d(64) 102 | self.relu1 = nn.ReLU(inplace=True) 103 | self.conv2 = conv3x3(64, 64) 104 | self.bn2 = BatchNorm2d(64) 105 | self.relu2 = nn.ReLU(inplace=True) 106 | self.conv3 = conv3x3(64, 128) 107 | self.bn3 = BatchNorm2d(128) 108 | self.relu3 = nn.ReLU(inplace=True) 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | 111 | self.layer1 = self._make_layer(block, 64, layers[0]) 112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 115 | self.avgpool = nn.AvgPool2d(7, stride=1) 116 | self.fc = nn.Linear(512 * block.expansion, num_classes) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 122 | elif isinstance(m, BatchNorm2d): 123 | m.weight.data.fill_(1) 124 | m.bias.data.zero_() 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | nn.Conv2d(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False), 132 | BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.relu1(self.bn1(self.conv1(x))) 145 | x = self.relu2(self.bn2(self.conv2(x))) 146 | x = self.relu3(self.bn3(self.conv3(x))) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | 154 | x = self.avgpool(x) 155 | x = x.view(x.size(0), -1) 156 | x = self.fc(x) 157 | 158 | return x 159 | 160 | def resnet18(pretrained=False, **kwargs): 161 | """Constructs a ResNet-18 model. 162 | 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | """ 166 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 167 | if pretrained: 168 | model.load_state_dict(load_url(model_urls['resnet18'])) 169 | return model 170 | 171 | ''' 172 | def resnet34(pretrained=False, **kwargs): 173 | """Constructs a ResNet-34 model. 174 | 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 179 | if pretrained: 180 | model.load_state_dict(load_url(model_urls['resnet34'])) 181 | return model 182 | ''' 183 | 184 | def resnet50(pretrained=False, **kwargs): 185 | """Constructs a ResNet-50 model. 186 | 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | """ 190 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 191 | if pretrained: 192 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 193 | return model 194 | 195 | 196 | def resnet101(pretrained=False, **kwargs): 197 | """Constructs a ResNet-101 model. 198 | 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | """ 202 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 203 | if pretrained: 204 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 205 | return model 206 | 207 | # def resnet152(pretrained=False, **kwargs): 208 | # """Constructs a ResNet-152 model. 
209 | # 210 | # Args: 211 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 212 | # """ 213 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 214 | # if pretrained: 215 | # model.load_state_dict(load_url(model_urls['resnet152'])) 216 | # return model 217 | -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | from .utils import load_url 4 | from lib.nn import SynchronizedBatchNorm2d 5 | BatchNorm2d = SynchronizedBatchNorm2d 6 | 7 | 8 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101 9 | 10 | 11 | model_urls = { 12 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', 13 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class GroupBottleneck(nn.Module): 24 | expansion = 2 25 | 26 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): 27 | super(GroupBottleneck, self).__init__() 28 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 29 | self.bn1 = BatchNorm2d(planes) 30 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 31 | padding=1, groups=groups, bias=False) 32 | self.bn2 = BatchNorm2d(planes) 33 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) 34 | self.bn3 = BatchNorm2d(planes * 2) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv3(out) 51 | out = self.bn3(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class ResNeXt(nn.Module): 63 | 64 | def __init__(self, block, layers, groups=32, num_classes=1000): 65 | self.inplanes = 128 66 | super(ResNeXt, self).__init__() 67 | self.conv1 = conv3x3(3, 64, stride=2) 68 | self.bn1 = BatchNorm2d(64) 69 | self.relu1 = nn.ReLU(inplace=True) 70 | self.conv2 = conv3x3(64, 64) 71 | self.bn2 = BatchNorm2d(64) 72 | self.relu2 = nn.ReLU(inplace=True) 73 | self.conv3 = conv3x3(64, 128) 74 | self.bn3 = BatchNorm2d(128) 75 | self.relu3 = nn.ReLU(inplace=True) 76 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 77 | 78 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) 79 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) 80 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) 81 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) 82 | self.avgpool = nn.AvgPool2d(7, stride=1) 83 | self.fc = nn.Linear(1024 * block.expansion, num_classes) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups 88 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 89 | elif isinstance(m, BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | def _make_layer(self, block, planes, blocks, stride=1, groups=1): 94 | downsample = None 95 | if stride != 1 or self.inplanes != planes * block.expansion: 96 | downsample = nn.Sequential( 97 | nn.Conv2d(self.inplanes, planes * block.expansion, 98 | kernel_size=1, stride=stride, bias=False), 99 | BatchNorm2d(planes * block.expansion), 100 | ) 101 | 102 | layers = [] 103 | layers.append(block(self.inplanes, planes, stride, groups, downsample)) 104 | self.inplanes = planes * block.expansion 105 | for i in range(1, blocks): 106 | layers.append(block(self.inplanes, planes, groups=groups)) 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | x = self.relu1(self.bn1(self.conv1(x))) 112 | x = self.relu2(self.bn2(self.conv2(x))) 113 | x = self.relu3(self.bn3(self.conv3(x))) 114 | x = self.maxpool(x) 115 | 116 | x = self.layer1(x) 117 | x = self.layer2(x) 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | 121 | x = self.avgpool(x) 122 | x = x.view(x.size(0), -1) 123 | x = self.fc(x) 124 | 125 | return x 126 | 127 | 128 | ''' 129 | def resnext50(pretrained=False, **kwargs): 130 | """Constructs a ResNeXt-50 model. 131 | 132 | Args: 133 | pretrained (bool): If True, returns a model pre-trained on ImageNet 134 | """ 135 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) 136 | if pretrained: 137 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False) 138 | return model 139 | ''' 140 | 141 | 142 | def resnext101(pretrained=False, **kwargs): 143 | """Constructs a ResNeXt-101 model. 144 | 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | """ 148 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) 149 | if pretrained: 150 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False) 151 | return model 152 | 153 | 154 | # def resnext152(pretrained=False, **kwargs): 155 | # """Constructs a ResNeXt-152 model.
156 | # 157 | # Args: 158 | # pretrained (bool): If True, returns a model pre-trained on Places 159 | # """ 160 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) 161 | # if pretrained: 162 | # model.load_state_dict(load_url(model_urls['resnext152'])) 163 | # return model 164 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | try: 4 | from urllib import urlretrieve 5 | except ImportError: 6 | from urllib.request import urlretrieve 7 | import torch 8 | 9 | 10 | def load_url(url, model_dir='./pretrained', map_location=None): 11 | if not os.path.exists(model_dir): 12 | os.makedirs(model_dir) 13 | filename = url.split('/')[-1] 14 | cached_file = os.path.join(model_dir, filename) 15 | if not os.path.exists(cached_file): 16 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 17 | urlretrieve(url, cached_file) 18 | return torch.load(cached_file, map_location=map_location) 19 | -------------------------------------------------------------------------------- /pics/42_170616140840_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/pics/42_170616140840_1.jpg -------------------------------------------------------------------------------- /pics/5DDED2D65C4EB2FAE3FC89CD64D376A5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/pics/5DDED2D65C4EB2FAE3FC89CD64D376A5.jpg -------------------------------------------------------------------------------- /pics/free_stock_photo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/pics/free_stock_photo.jpg -------------------------------------------------------------------------------- /pics/timg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tea321000/image_semantic_segmentation/80409858e8d44c1b24695035abf3a68fd83e9574/pics/timg.jpg -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.4.5.1 2 | future==0.18.2 3 | numpy==1.18.5 4 | Pillow==7.1.2 5 | PyQt5==5.15.0 6 | PyQt5-sip==12.8.0 7 | PyYAML==5.3.1 8 | scipy==1.4.1 9 | six==1.15.0 10 | torch==1.5.0+cu101 11 | torchvision==0.2.2.post3 12 | tqdm==4.46.1 13 | wincertstore==0.2 14 | yacs==0.1.7 15 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # System libs 2 | # import sys 3 | # 4 | # sys.path.append('/workspace/library') 5 | import os 6 | import argparse 7 | from distutils.version import LooseVersion 8 | # Numerical libs 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from scipy.io import loadmat 13 | import csv 14 | # Our libs 15 | from dataset import TestDataset 16 | from models import ModelBuilder, SegmentationModule 17 | from utils import colorEncode, find_recursive, setup_logger 18 | from lib.nn import 
user_scattered_collate, async_copy_to 19 | from lib.utils import as_numpy 20 | from PIL import Image 21 | from tqdm import tqdm 22 | from config import cfg 23 | 24 | 25 | colors = loadmat('data/color150.mat')['colors'] 26 | names = {} 27 | with open('data/object150_info.csv') as f: 28 | reader = csv.reader(f) 29 | next(reader) 30 | for row in reader: 31 | names[int(row[0])] = row[5].split(";")[0] 32 | 33 | 34 | def visualize_result(data, pred, cfg, args): 35 | (img, info) = data 36 | # print predictions in descending order 37 | pred = np.int32(pred) 38 | pixs = pred.size 39 | uniques, counts = np.unique(pred, return_counts=True) 40 | print("Predictions in [{}]:".format(info)) 41 | for idx in np.argsort(counts)[::-1]: 42 | name = names[uniques[idx] + 1] 43 | ratio = counts[idx] / pixs * 100 44 | if ratio > 0.1: 45 | print("  {}: {:.2f}%".format(name, ratio)) 46 | 47 | # colorize prediction 48 | 49 | pred_color = colorEncode(data, pred, colors, names, args.result).astype(np.uint8) 50 | 51 | # aggregate images and save 52 | im_vis = np.concatenate((img, pred_color), axis=1) 53 | 54 | img_name = info.split('/')[-1] 55 | Image.fromarray(im_vis).save( 56 | os.path.join(args.result, os.path.splitext(os.path.basename(img_name))[0], 'org&seg.png')) 57 | Image.fromarray(img).save( 58 | os.path.join(args.result, os.path.splitext(os.path.basename(img_name))[0], 'org.png')) 59 | Image.fromarray(pred_color).save( 60 | os.path.join(args.result, os.path.splitext(os.path.basename(img_name))[0], 'seg.png')) 61 | 62 | 63 | def test(segmentation_module, loader, gpu, gpu_flag, args, progress): 64 | segmentation_module.eval() 65 | pbar = tqdm(total=len(loader)) 66 | process_count = 0 67 | for batch_data in loader: 68 | # process data 69 | batch_data = batch_data[0] 70 | segSize = (batch_data['img_ori'].shape[0], 71 | batch_data['img_ori'].shape[1]) 72 | img_resized_list = batch_data['img_data'] 73 | 74 | with torch.no_grad(): 75 | scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1]) 76 | if gpu_flag: 77 | scores = async_copy_to(scores, gpu) 78 | 79 | for img in img_resized_list: 80 | feed_dict = batch_data.copy() 81 | 82 | feed_dict['img_data'] = img 83 | del feed_dict['img_ori'] 84 | del feed_dict['info'] 85 | if gpu_flag: 86 | feed_dict = async_copy_to(feed_dict, gpu) 87 | 88 | # forward pass 89 | try: 90 | pred_tmp = segmentation_module(feed_dict, segSize=segSize) 91 | scores = scores + pred_tmp / len(cfg.DATASET.imgSizes) 92 | except RuntimeError as e: 93 | print('A runtime error occurred. If it is CUDA OUT OF MEMORY, the GPU has run out of memory and a wrong segmentation result will be produced; please try processing this image on the CPU. Error message:', e) 94 | 95 | _, pred = torch.max(scores, dim=1) 96 | if gpu_flag: 97 | pred = as_numpy(pred.squeeze(0).cpu()) 98 | else: 99 | pred = as_numpy(pred.squeeze(0)) 100 | 101 | # visualization 102 | visualize_result( 103 | (batch_data['img_ori'], batch_data['info']), 104 | pred, 105 | cfg, 106 | args 107 | ) 108 | process_count += 1 109 | progress.setValue(int(process_count/len(loader)*100)) 110 | pbar.update(1) 111 | 112 | 113 | def main(cfg, gpu, args, progress): 114 | gpu_flag = args.gpu_flag 115 | if gpu_flag and torch.cuda.is_available(): 116 | torch.cuda.set_device(gpu) 117 | print('Using the GPU for semantic segmentation') 118 | else: 119 | print('GPU disabled or no CUDA environment installed; using the CPU for semantic segmentation') 120 | 121 | # Network Builders 122 | if gpu_flag: 123 | net_encoder = ModelBuilder.build_encoder( 124 | arch=cfg.MODEL.arch_encoder, 125 | fc_dim=cfg.MODEL.fc_dim, 126 | weights=cfg.MODEL.weights_encoder) 127 | net_decoder = ModelBuilder.build_decoder( 128 | arch=cfg.MODEL.arch_decoder,
129 | fc_dim=cfg.MODEL.fc_dim, 130 | num_class=cfg.DATASET.num_class, 131 | weights=cfg.MODEL.weights_decoder, 132 | use_softmax=True) 133 | else: 134 | net_encoder = ModelBuilder.build_encoder( 135 | arch=cfg.MODEL.arch_encoder, 136 | fc_dim=cfg.MODEL.fc_dim, 137 | weights=cfg.MODEL.weights_encoder, gpu_flag=False) 138 | net_decoder = ModelBuilder.build_decoder( 139 | arch=cfg.MODEL.arch_decoder, 140 | fc_dim=cfg.MODEL.fc_dim, 141 | num_class=cfg.DATASET.num_class, 142 | weights=cfg.MODEL.weights_decoder, 143 | use_softmax=True, gpu_flag=False) 144 | 145 | crit = nn.NLLLoss(ignore_index=-1) 146 | 147 | segmentation_module = SegmentationModule(net_encoder, net_decoder, crit) 148 | 149 | # Dataset and Loader 150 | dataset_test = TestDataset( 151 | cfg.list_test, 152 | cfg.DATASET) 153 | loader_test = torch.utils.data.DataLoader( 154 | dataset_test, 155 | batch_size=cfg.TEST.batch_size, 156 | shuffle=False, 157 | collate_fn=user_scattered_collate, 158 | num_workers=5, 159 | drop_last=True) 160 | if gpu_flag: 161 | segmentation_module.cuda() 162 | 163 | # Main loop 164 | test(segmentation_module, loader_test, gpu, gpu_flag, args, progress) 165 | 166 | print('Semantic segmentation finished!') 167 | 168 | 169 | def arg_from_ui(imgs, progress, gpu_flag=None, config_path=None, dir=None, checkpoint=None, result=None): 170 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \ 171 | 'PyTorch>=0.4.0 is required' 172 | # args = {'cfg': 'config/ade20k-resnet50dilated-ppm_deepsup.yaml', 'gpu': 0, 'opts': None, 'gpu_flag': False, 173 | # 'dir': 'ade20k-resnet50dilated-ppm_deepsup', 'result': 'segmentation', 'checkpoint': 'epoch_20.pth'} 174 | parser = argparse.ArgumentParser( 175 | description="PyTorch Semantic Segmentation Testing" 176 | ) 177 | parser.add_argument( 178 | "--imgs", 179 | default=imgs, 180 | type=str, 181 | help="an image path, or a directory name" 182 | ) 183 | parser.add_argument( 184 | "--config_path", 185 | default="config/ade20k-resnet50dilated-ppm_deepsup.yaml", 186 | metavar="FILE", 187 | help="path to config file", 188 | type=str, 189 | ) 190 | parser.add_argument( 191 | "--gpu", 192 | default=0, 193 | type=int, 194 | help="gpu id for evaluation" 195 | ) 196 | parser.add_argument( 197 | "opts", 198 | help="Modify config options using the command-line", 199 | default=None, 200 | nargs=argparse.REMAINDER, 201 | ) 202 | parser.add_argument( 203 | "--gpu_flag", 204 | help="enable or disable the GPU", 205 | default=True, 206 | nargs=argparse.REMAINDER, 207 | ) 208 | parser.add_argument( 209 | "--dir", 210 | help="model dir", 211 | default="ade20k-resnet50dilated-ppm_deepsup", 212 | nargs=argparse.REMAINDER, 213 | ) 214 | parser.add_argument( 215 | "--result", 216 | help="segmentation result dir", 217 | default="segmentation", 218 | nargs=argparse.REMAINDER, 219 | ) 220 | parser.add_argument( 221 | "--checkpoint", 222 | help="pretrained model checkpoint", 223 | default="epoch_20.pth", 224 | nargs=argparse.REMAINDER, 225 | ) 226 | args = parser.parse_args() 227 | if gpu_flag is not None: 228 | args.gpu_flag = gpu_flag 229 | if config_path: 230 | args.config_path = config_path 231 | if dir: 232 | args.dir = dir 233 | if checkpoint: 234 | args.checkpoint = checkpoint 235 | if result: 236 | args.result = result 237 | 238 | cfg.merge_from_file(args.config_path) 239 | cfg.merge_from_list(args.opts) 240 | # cfg.freeze() 241 | 242 | logger = setup_logger(distributed_rank=0) # TODO 243 | logger.info("Loaded configuration file {}".format(args.config_path)) 244 | logger.info("Running with
config:\n{}".format(cfg)) 245 | 246 | cfg.MODEL.arch_encoder = cfg.MODEL.arch_encoder.lower() 247 | cfg.MODEL.arch_decoder = cfg.MODEL.arch_decoder.lower() 248 | 249 | # absolute paths of model weights 250 | cfg.MODEL.weights_encoder = os.path.join( 251 | args.dir, 'encoder_' + args.checkpoint) 252 | cfg.MODEL.weights_decoder = os.path.join( 253 | args.dir, 'decoder_' + args.checkpoint) 254 | print(cfg.MODEL.weights_encoder) 255 | 256 | assert os.path.exists(cfg.MODEL.weights_encoder) and \ 257 | os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exist!" 258 | 259 | # generate testing image list 260 | print('-----imgs:', args.imgs) 261 | if os.path.isdir(args.imgs): 262 | imgs = find_recursive(args.imgs) 263 | else: 264 | imgs = [args.imgs] 265 | assert len(imgs), "imgs should be a path to image (.jpg) or directory." 266 | cfg.list_test = [{'fpath_img': x} for x in imgs] 267 | 268 | if not os.path.isdir(args.result): 269 | os.makedirs(args.result) 270 | 271 | main(cfg, args.gpu, args, progress) 272 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import re 5 | import functools 6 | import fnmatch 7 | import numpy as np 8 | from PIL import Image 9 | from config import cfg 10 | 11 | 12 | def setup_logger(distributed_rank=0, filename="log.txt"): 13 | logger = logging.getLogger("Logger") 14 | logger.setLevel(logging.DEBUG) 15 | # don't log results for the non-master process 16 | if distributed_rank > 0: 17 | return logger 18 | ch = logging.StreamHandler(stream=sys.stdout) 19 | ch.setLevel(logging.DEBUG) 20 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 21 | ch.setFormatter(logging.Formatter(fmt)) 22 | logger.addHandler(ch) 23 | 24 | return logger 25 | 26 | 27 | def find_recursive(root_dir, ext='.jpg'): 28 | files = [] 29 | for root, dirnames, filenames in os.walk(root_dir): 30 | for filename in fnmatch.filter(filenames, '*' + ext): 31 | files.append(os.path.join(root, filename)) 32 | return files 33 | 34 | 35 | class AverageMeter(object): 36 | """Computes and stores the average and current value""" 37 | def __init__(self): 38 | self.initialized = False 39 | self.val = None 40 | self.avg = None 41 | self.sum = None 42 | self.count = None 43 | 44 | def initialize(self, val, weight): 45 | self.val = val 46 | self.avg = val 47 | self.sum = val * weight 48 | self.count = weight 49 | self.initialized = True 50 | 51 | def update(self, val, weight=1): 52 | if not self.initialized: 53 | self.initialize(val, weight) 54 | else: 55 | self.add(val, weight) 56 | 57 | def add(self, val, weight): 58 | self.val = val 59 | self.sum += val * weight 60 | self.count += weight 61 | self.avg = self.sum / self.count 62 | 63 | def value(self): 64 | return self.val 65 | 66 | def average(self): 67 | return self.avg 68 | 69 | 70 | def unique(ar, return_index=False, return_inverse=False, return_counts=False): 71 | ar = np.asanyarray(ar).flatten() 72 | 73 | optional_indices = return_index or return_inverse 74 | optional_returns = optional_indices or return_counts 75 | 76 | if ar.size == 0: 77 | if not optional_returns: 78 | ret = ar 79 | else: 80 | ret = (ar,) 81 | if return_index: 82 | ret += (np.empty(0, np.bool),) 83 | if return_inverse: 84 | ret += (np.empty(0, np.bool),) 85 | if return_counts: 86 | ret += (np.empty(0, np.intp),) 87 | return ret 88 | if optional_indices: 89 | perm =
90 |         aux = ar[perm]
91 |     else:
92 |         ar.sort()
93 |         aux = ar
94 |     flag = np.concatenate(([True], aux[1:] != aux[:-1]))
95 |
96 |     if not optional_returns:
97 |         ret = aux[flag]
98 |     else:
99 |         ret = (aux[flag],)
100 |         if return_index:
101 |             ret += (perm[flag],)
102 |         if return_inverse:
103 |             iflag = np.cumsum(flag) - 1
104 |             inv_idx = np.empty(ar.shape, dtype=np.intp)
105 |             inv_idx[perm] = iflag
106 |             ret += (inv_idx,)
107 |         if return_counts:
108 |             idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
109 |             ret += (np.diff(idx),)
110 |     return ret
111 |
112 |
113 | def colorEncode(data, labelmap, colors, names, result_path, mode='RGB'):
114 |     (org, info) = data
115 |     # print(names)
116 |     img_name = info.split('/')[-1]
117 |     labelmap = labelmap.astype('int')
118 |     np.set_printoptions(threshold=np.inf)
119 |     # print(labelmap)
120 |     labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
121 |                             dtype=np.uint8)
122 |     labelmap_org = np.zeros((labelmap.shape[0], labelmap.shape[1], 4),
123 |                             dtype=np.uint8)
124 |     for label in unique(labelmap):
125 |         if label < 0:
126 |             continue
127 |         mask = (labelmap == label)[:, :, np.newaxis]
128 |         labelmap_rgb += mask * \
129 |             np.tile(colors[label],
130 |                     (labelmap.shape[0], labelmap.shape[1], 1))
131 |         alpha = 255 * np.ones((org.shape[0], org.shape[1], 1))
132 |         rgba = np.concatenate([org, alpha], axis=2).astype('uint8')
133 |         labelmap_org = mask * rgba
134 |         if not os.path.isdir(os.path.join(result_path, os.path.splitext(os.path.basename(img_name))[0])):
135 |             os.makedirs(os.path.join(result_path, os.path.splitext(os.path.basename(img_name))[0]))
136 |         Image.fromarray(labelmap_org).save(
137 |             os.path.join(result_path, os.path.splitext(os.path.basename(img_name))[0], names[label + 1] + '.png'))
138 |
139 |     if mode == 'BGR':
140 |         return labelmap_rgb[:, :, ::-1]
141 |     else:
142 |         return labelmap_rgb
143 |
144 |
145 | def accuracy(preds, label):
146 |     valid = (label >= 0)
147 |     acc_sum = (valid * (preds == label)).sum()
148 |     valid_sum = valid.sum()
149 |     acc = float(acc_sum) / (valid_sum + 1e-10)
150 |     return acc, valid_sum
151 |
152 |
153 | def intersectionAndUnion(imPred, imLab, numClass):
154 |     imPred = np.asarray(imPred).copy()
155 |     imLab = np.asarray(imLab).copy()
156 |
157 |     imPred += 1
158 |     imLab += 1
159 |     # Remove classes from unlabeled pixels in gt image.
160 |     # We should not penalize detections in unlabeled portions of the image.
161 |     imPred = imPred * (imLab > 0)
162 |
163 |     # Compute area intersection:
164 |     intersection = imPred * (imPred == imLab)
165 |     (area_intersection, _) = np.histogram(
166 |         intersection, bins=numClass, range=(1, numClass))
167 |
168 |     # Compute area union:
169 |     (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
170 |     (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
171 |     area_union = area_pred + area_lab - area_intersection
172 |
173 |     return (area_intersection, area_union)
174 |
175 |
176 | class NotSupportedCliException(Exception):
177 |     pass
178 |
179 |
180 | def process_range(xpu, inp):
181 |     start, end = map(int, inp)
182 |     if start > end:
183 |         end, start = start, end
184 |     return map(lambda x: '{}{}'.format(xpu, x), range(start, end + 1))
185 |
186 |
187 | REGEX = [
188 |     (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]),
189 |     (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]),
190 |     (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'),
191 |      functools.partial(process_range, 'gpu')),
192 |     (re.compile(r'^(\d+)-(\d+)$'),
193 |      functools.partial(process_range, 'gpu')),
194 | ]
195 |
196 |
197 | def parse_devices(input_devices):
198 |     """Parse a user's device string into the standard format,
199 |     e.g. ['gpu0', 'gpu1', ...].
200 |     """
201 |     ret = []
202 |     for d in input_devices.split(','):
203 |         for regex, func in REGEX:
204 |             m = regex.match(d.lower().strip())
205 |             if m:
206 |                 tmp = func(m.groups())
207 |                 # prevent duplicates
208 |                 for x in tmp:
209 |                     if x not in ret:
210 |                         ret.append(x)
211 |                 break
212 |         else:
213 |             raise NotSupportedCliException(
214 |                 'Cannot recognize device: "{}"'.format(d))
215 |     return ret
216 |
--------------------------------------------------------------------------------
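
A minimal usage sketch for parse_devices (utils.py); the device strings and the printed results below are illustrative assumptions, not files or outputs from the repository:

from utils import parse_devices

print(parse_devices('0'))          # ['gpu0']
print(parse_devices('gpu0,gpu1'))  # ['gpu0', 'gpu1']
print(parse_devices('gpu0-2'))     # ['gpu0', 'gpu1', 'gpu2'] (ranges expand via process_range)
print(parse_devices('3-1'))        # ['gpu1', 'gpu2', 'gpu3'] (reversed ranges are reordered)

Unrecognized tokens raise NotSupportedCliException, so a caller can validate a device field from the GUI in a single call.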
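
A worked example for intersectionAndUnion (utils.py), assuming the convention that label -1 marks unlabeled pixels (consistent with the +1 shift and the imLab > 0 mask inside the function); the arrays are made up for illustration:

import numpy as np
from utils import intersectionAndUnion

pred  = np.array([[0, 0], [1, 1]])   # predicted class per pixel
label = np.array([[0, 1], [1, -1]])  # ground truth; bottom-right pixel unlabeled

inter, union = intersectionAndUnion(pred, label, numClass=2)
# inter == [1, 1], union == [2, 2]  ->  per-class IoU == [0.5, 0.5]
iou = inter / (union + 1e-10)
print(iou)

The unlabeled pixel is zeroed out of imPred before the histograms, so it contributes to neither the intersection nor the union of any class.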
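
Finally, a hedged sketch of feeding colorEncode (utils.py). The colors palette and names table are assumed to come from data/color150.mat and data/object150_info.csv as in the upstream CSAIL semantic-segmentation demo; the loadmat key 'colors', the CSV column layout, and the dummy org/labelmap arrays are all assumptions, not confirmed by this repository:

import csv
import numpy as np
from scipy.io import loadmat
from utils import colorEncode

colors = loadmat('data/color150.mat')['colors']    # assumed (150, 3) uint8 palette
names = {}
with open('data/object150_info.csv') as f:
    reader = csv.reader(f)
    next(reader)                                   # skip the header row
    for row in reader:
        names[int(row[0])] = row[5].split(';')[0]  # first name of each class

org = np.zeros((4, 4, 3), dtype=np.uint8)          # stand-in for the original image
labelmap = np.zeros((4, 4), dtype=np.int64)        # stand-in prediction (all class 0)

pred_color = colorEncode((org, 'img.jpg'), labelmap, colors, names,
                         result_path='segmentation')

colorEncode returns the color-coded label map and, as a side effect, writes one RGBA cutout per detected class to result_path/<image name>/<class name>.png, which is what the GUI displays per class.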