├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── conda_env.yml ├── datasets └── sysu_mm01 │ ├── test_config.pth │ ├── train_config.pth │ └── val_config.pth └── src_py ├── dataset.py ├── doall.py ├── doall_test.py ├── log.py ├── model.py ├── model_resnet_pretrained.py ├── settings.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 
98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src_py/SYSU-MM01"] 2 | path = src_py/SYSU-MM01 3 | url = https://github.com/wuancong/SYSU-MM01 4 | [submodule ".\\src_py\\SYSU_MM01_pythoneval"] 5 | path = src_py/SYSU_MM01_pythoneval 6 | url = https://github.com/InnovArul/SYSU_MM01_pythoneval 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rgb_IR_personreid 2 | 3 | ## RGB-D ICCV-2017 paper 4 | 5 | Author's webpage for ICCV2017 paper **"RGB-Infrared Cross-Modality Person Re-Identification"**: 6 | [http://isee.sysu.edu.cn/project/RGBIRReID.htm](http://isee.sysu.edu.cn/project/RGBIRReID.htm) 7 | 8 | ## Dataset 9 | 10 | Download the dataset from the following link and extract contents to the folder `datasets/sysu_mm01`. 
11 | 12 | [Dropbox (from ICCV-2017)](https://www.dropbox.com/sh/v036mg1q4yg7awb/AABhxU-FJ4X2oyq7-Ts6bgD0a?dl=0 ) 13 | 14 | The folder contents of `datasets` folder will look like: 15 | 16 | ``` 17 | ./datasets 18 | - sysu_mm01 19 | -- cam1 20 | -- cam2 21 | -- cam3 22 | -- cam4 23 | -- cam5 24 | -- cam6 25 | -- exp 26 | 27 | ``` 28 | 29 | ## Training 30 | 31 | To train the model (`ResNet6` from `model.py`), go to `src_py` folder in command prompt and execute: 32 | 33 | ``` 34 | python doall.py 35 | ``` 36 | 37 | The training logs and models will be saved under `/scratch` folder. 38 | 39 | ## Testing 40 | 41 | To test the model, go to `src_py` folder and edit the file `doall_test.py` file to place the path of pretrained model: 42 | 43 | ``` 44 | pretrained_model = "../scratch/..." 45 | ``` 46 | Then execute the command below to test the model: 47 | ``` 48 | python doall_test.py 49 | ``` 50 | The features of all the images from different cameras will be stored with file names suffixed `_camX` where `X=cam number` under the corresponding `scratch` log folder. These features will be used by Matlab evaluation code to calculate the relevant metrics. 51 | 52 | ## Metrics calculation 53 | 54 | To evaluate and get the metrics, the authors have release a Matlab evaluation script in the github repository: [https://github.com/wuancong/SYSU-MM01/blob/master/evaluation](https://github.com/wuancong/SYSU-MM01/blob/master/evaluation). 55 | 56 | Given the features, the evaluation code in their repository calculates the rank-1, mAP metrics. Open the `demo.m` file inside `/src_py/SYSU-MM01/evaluation/` directory and give the relevant information regarding feature path, result folder, prefix of the model and execute the `demo.m` file. 57 | 58 | Alternatively, you can use python version of evaluation script from `/src_py/SYSU_MM01_pythoneval/`. 
Refer to the [README](https://github.com/InnovArul/SYSU_MM01_pythoneval/blob/master/README.md) file of `SYSU_MM01_pythoneval` project for more information on how to use the evaluation routine. It is as similar to Matlab evaluation, but implemented in python to have seemless interface with python scripts. -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: pytorch1.0 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _anaconda_depends=2019.03=py37_0 8 | - _libgcc_mutex=0.1=main 9 | - alabaster=0.7.12=py37_0 10 | - anaconda=custom=py37_1 11 | - anaconda-client=1.7.2=py37_0 12 | - anaconda-project=0.8.3=py_0 13 | - asn1crypto=1.0.1=py37_0 14 | - astroid=2.3.1=py37_0 15 | - astropy=3.2.2=py37h7b6447c_0 16 | - atomicwrites=1.3.0=py37_1 17 | - attrs=19.2.0=py_0 18 | - babel=2.7.0=py_0 19 | - backcall=0.1.0=py37_0 20 | - backports=1.0=py_2 21 | - backports.os=0.1.1=py37_0 22 | - backports.shutil_get_terminal_size=1.0.0=py37_2 23 | - beautifulsoup4=4.8.0=py37_0 24 | - bitarray=1.0.1=py37h7b6447c_0 25 | - bkcharts=0.2=py37_0 26 | - blas=1.0=mkl 27 | - bleach=3.1.0=py37_0 28 | - blosc=1.16.3=hd408876_0 29 | - bokeh=1.3.4=py37_0 30 | - boto=2.49.0=py37_0 31 | - bottleneck=1.2.1=py37h035aef0_1 32 | - bzip2=1.0.8=h7b6447c_0 33 | - ca-certificates=2019.10.16=0 34 | - cairo=1.14.12=h8948797_3 35 | - certifi=2019.9.11=py37_0 36 | - cffi=1.12.3=py37h2e261b9_0 37 | - chardet=3.0.4=py37_1003 38 | - click=7.0=py37_0 39 | - cloudpickle=1.2.2=py_0 40 | - clyent=1.2.2=py37_1 41 | - colorama=0.4.1=py37_0 42 | - contextlib2=0.6.0=py_0 43 | - cryptography=2.7=py37h1ba5d50_0 44 | - cudatoolkit=10.0.130=0 45 | - curl=7.65.3=hbc83047_0 46 | - cycler=0.10.0=py37_0 47 | - cython=0.29.13=py37he6710b0_0 48 | - cytoolz=0.10.0=py37h7b6447c_0 49 | - dask=2.5.2=py_0 50 | - dask-core=2.5.2=py_0 51 | - dbus=1.13.6=h746ee38_0 52 | - 
decorator=4.4.0=py37_1 53 | - defusedxml=0.6.0=py_0 54 | - distributed=2.5.2=py_0 55 | - docutils=0.15.2=py37_0 56 | - entrypoints=0.3=py37_0 57 | - et_xmlfile=1.0.1=py37_0 58 | - expat=2.2.6=he6710b0_0 59 | - fastcache=1.1.0=py37h7b6447c_0 60 | - filelock=3.0.12=py_0 61 | - flask=1.1.1=py_0 62 | - fontconfig=2.13.0=h9420a91_0 63 | - freetype=2.9.1=h8a8886c_1 64 | - fribidi=1.0.5=h7b6447c_0 65 | - fsspec=0.5.2=py_0 66 | - get_terminal_size=1.0.0=haa9412d_0 67 | - gevent=1.4.0=py37h7b6447c_0 68 | - glib=2.56.2=hd408876_0 69 | - glob2=0.7=py_0 70 | - gmp=6.1.2=h6c8ec71_1 71 | - gmpy2=2.0.8=py37h10f8cd9_2 72 | - graphite2=1.3.13=h23475e2_0 73 | - greenlet=0.4.15=py37h7b6447c_0 74 | - gst-plugins-base=1.14.0=hbbd80ab_1 75 | - gstreamer=1.14.0=hb453b48_1 76 | - h5py=2.9.0=py37h7918eee_0 77 | - harfbuzz=1.8.8=hffaf4a1_0 78 | - hdf5=1.10.4=hb1b8bf9_0 79 | - heapdict=1.0.1=py_0 80 | - html5lib=1.0.1=py37_0 81 | - icu=58.2=h9c2bf20_1 82 | - idna=2.8=py37_0 83 | - imageio=2.6.0=py37_0 84 | - imagesize=1.1.0=py37_0 85 | - importlib_metadata=0.23=py37_0 86 | - intel-openmp=2019.4=243 87 | - ipykernel=5.1.2=py37h39e3cac_0 88 | - ipython=7.8.0=py37h39e3cac_0 89 | - ipython_genutils=0.2.0=py37_0 90 | - ipywidgets=7.5.1=py_0 91 | - isort=4.3.21=py37_0 92 | - itsdangerous=1.1.0=py37_0 93 | - jbig=2.1=hdba287a_0 94 | - jdcal=1.4.1=py_0 95 | - jedi=0.15.1=py37_0 96 | - jeepney=0.4.1=py_0 97 | - jinja2=2.10.3=py_0 98 | - joblib=0.13.2=py37_0 99 | - jpeg=9b=h024ee3a_2 100 | - json5=0.8.5=py_0 101 | - jsonschema=3.0.2=py37_0 102 | - jupyter=1.0.0=py37_7 103 | - jupyter_client=5.3.3=py37_1 104 | - jupyter_console=6.0.0=py37_0 105 | - jupyter_core=4.5.0=py_0 106 | - jupyterlab=1.1.4=pyhf63ae98_0 107 | - jupyterlab_server=1.0.6=py_0 108 | - keyring=18.0.0=py37_0 109 | - kiwisolver=1.1.0=py37he6710b0_0 110 | - krb5=1.16.1=h173b8e3_7 111 | - lazy-object-proxy=1.4.2=py37h7b6447c_0 112 | - libarchive=3.3.3=h5d8350f_5 113 | - libcurl=7.65.3=h20c2e04_0 114 | - libedit=3.1.20181209=hc058e9b_0 115 
| - libffi=3.2.1=hd88cf55_4 116 | - libgcc-ng=9.1.0=hdf63c60_0 117 | - libgfortran-ng=7.3.0=hdf63c60_0 118 | - liblief=0.9.0=h7725739_2 119 | - libpng=1.6.37=hbc83047_0 120 | - libsodium=1.0.16=h1bed415_0 121 | - libssh2=1.8.2=h1ba5d50_0 122 | - libstdcxx-ng=9.1.0=hdf63c60_0 123 | - libtiff=4.0.10=h2733197_2 124 | - libtool=2.4.6=h7b6447c_5 125 | - libuuid=1.0.3=h1bed415_2 126 | - libxcb=1.13=h1bed415_1 127 | - libxml2=2.9.9=hea5a465_1 128 | - libxslt=1.1.33=h7d1a2b0_0 129 | - llvmlite=0.29.0=py37hd408876_0 130 | - locket=0.2.0=py37_1 131 | - lxml=4.4.1=py37hefd8a0e_0 132 | - lz4-c=1.8.1.2=h14c3975_0 133 | - lzo=2.10=h49e0be7_2 134 | - markupsafe=1.1.1=py37h7b6447c_0 135 | - matplotlib=3.1.1=py37h5429711_0 136 | - mccabe=0.6.1=py37_1 137 | - mistune=0.8.4=py37h7b6447c_0 138 | - mkl=2019.4=243 139 | - mkl-service=2.3.0=py37he904b0f_0 140 | - mkl_fft=1.0.14=py37ha843d7b_0 141 | - mkl_random=1.1.0=py37hd6b4f25_0 142 | - mock=3.0.5=py37_0 143 | - more-itertools=7.2.0=py37_0 144 | - mpc=1.1.0=h10f8cd9_1 145 | - mpfr=4.0.1=hdf1c602_3 146 | - mpmath=1.1.0=py37_0 147 | - msgpack-python=0.6.1=py37hfd86e86_1 148 | - multipledispatch=0.6.0=py37_0 149 | - nbconvert=5.6.0=py37_1 150 | - nbformat=4.4.0=py37_0 151 | - ncurses=6.1=he6710b0_1 152 | - networkx=2.3=py_0 153 | - ninja=1.9.0=py37hfd86e86_0 154 | - nltk=3.4.5=py37_0 155 | - nose=1.3.7=py37_2 156 | - notebook=6.0.1=py37_0 157 | - numba=0.45.1=py37h962f231_0 158 | - numexpr=2.7.0=py37h9e4a6bb_0 159 | - numpy=1.17.2=py37haad9e8e_0 160 | - numpy-base=1.17.2=py37hde5b4d6_0 161 | - numpydoc=0.9.1=py_0 162 | - olefile=0.46=py37_0 163 | - openpyxl=3.0.0=py_0 164 | - openssl=1.1.1d=h7b6447c_3 165 | - packaging=19.2=py_0 166 | - pandas=0.25.1=py37he6710b0_0 167 | - pandoc=2.2.3.2=0 168 | - pandocfilters=1.4.2=py37_1 169 | - pango=1.42.4=h049681c_0 170 | - parso=0.5.1=py_0 171 | - partd=1.0.0=py_0 172 | - patchelf=0.9=he6710b0_3 173 | - path.py=12.0.1=py_0 174 | - pathlib2=2.3.5=py37_0 175 | - patsy=0.5.1=py37_0 176 | - 
pcre=8.43=he6710b0_0 177 | - pep8=1.7.1=py37_0 178 | - pexpect=4.7.0=py37_0 179 | - pickleshare=0.7.5=py37_0 180 | - pillow=6.2.0=py37h34e0f95_0 181 | - pip=19.2.3=py37_0 182 | - pixman=0.38.0=h7b6447c_0 183 | - pkginfo=1.5.0.1=py37_0 184 | - pluggy=0.13.0=py37_0 185 | - ply=3.11=py37_0 186 | - prometheus_client=0.7.1=py_0 187 | - prompt_toolkit=2.0.10=py_0 188 | - psutil=5.6.3=py37h7b6447c_0 189 | - ptyprocess=0.6.0=py37_0 190 | - py=1.8.0=py37_0 191 | - py-lief=0.9.0=py37h7725739_2 192 | - pycodestyle=2.5.0=py37_0 193 | - pycosat=0.6.3=py37h14c3975_0 194 | - pycparser=2.19=py37_0 195 | - pycrypto=2.6.1=py37h14c3975_9 196 | - pycurl=7.43.0.3=py37h1ba5d50_0 197 | - pyflakes=2.1.1=py37_0 198 | - pygments=2.4.2=py_0 199 | - pylint=2.4.2=py37_0 200 | - pyodbc=4.0.27=py37he6710b0_0 201 | - pyopenssl=19.0.0=py37_0 202 | - pyparsing=2.4.2=py_0 203 | - pyqt=5.9.2=py37h05f1152_2 204 | - pyrsistent=0.15.4=py37h7b6447c_0 205 | - pysocks=1.7.1=py37_0 206 | - pytables=3.5.2=py37h71ec239_1 207 | - pytest=5.0.1=py37_0 208 | - pytest-arraydiff=0.3=py37h39e3cac_0 209 | - pytest-astropy=0.5.0=py37_0 210 | - pytest-doctestplus=0.4.0=py_0 211 | - pytest-openfiles=0.4.0=py_0 212 | - pytest-remotedata=0.3.2=py37_0 213 | - python=3.7.4=h265db76_1 214 | - python-dateutil=2.8.0=py37_0 215 | - python-libarchive-c=2.8=py37_13 216 | - pytorch=1.3.1=py3.7_cuda10.0.130_cudnn7.6.3_0 217 | - pytz=2019.3=py_0 218 | - pywavelets=1.0.3=py37hdd07704_1 219 | - pyyaml=5.1.2=py37h7b6447c_0 220 | - pyzmq=18.1.0=py37he6710b0_0 221 | - qt=5.9.7=h5867ecd_1 222 | - qtawesome=0.6.0=py_0 223 | - qtconsole=4.5.5=py_0 224 | - qtpy=1.9.0=py_0 225 | - readline=7.0=h7b6447c_5 226 | - requests=2.22.0=py37_0 227 | - ripgrep=0.10.0=hc07d326_0 228 | - rope=0.14.0=py_0 229 | - ruamel_yaml=0.15.46=py37h14c3975_0 230 | - scikit-image=0.15.0=py37he6710b0_0 231 | - scikit-learn=0.21.3=py37hd81dba3_0 232 | - scipy=1.3.1=py37h7c811a0_0 233 | - seaborn=0.9.0=py37_0 234 | - secretstorage=3.1.1=py37_0 235 | - 
send2trash=1.5.0=py37_0 236 | - setuptools=41.4.0=py37_0 237 | - simplegeneric=0.8.1=py37_2 238 | - singledispatch=3.4.0.3=py37_0 239 | - sip=4.19.8=py37hf484d3e_0 240 | - six=1.12.0=py37_0 241 | - snappy=1.1.7=hbae5bb6_3 242 | - snowballstemmer=2.0.0=py_0 243 | - sortedcollections=1.1.2=py37_0 244 | - sortedcontainers=2.1.0=py37_0 245 | - soupsieve=1.9.3=py37_0 246 | - sphinx=2.2.0=py_0 247 | - sphinxcontrib=1.0=py37_1 248 | - sphinxcontrib-applehelp=1.0.1=py_0 249 | - sphinxcontrib-devhelp=1.0.1=py_0 250 | - sphinxcontrib-htmlhelp=1.0.2=py_0 251 | - sphinxcontrib-jsmath=1.0.1=py_0 252 | - sphinxcontrib-qthelp=1.0.2=py_0 253 | - sphinxcontrib-serializinghtml=1.1.3=py_0 254 | - sphinxcontrib-websupport=1.1.2=py_0 255 | - spyder=3.3.6=py37_0 256 | - spyder-kernels=0.5.2=py37_0 257 | - sqlalchemy=1.3.9=py37h7b6447c_0 258 | - sqlite=3.30.0=h7b6447c_0 259 | - statsmodels=0.10.1=py37hdd07704_0 260 | - sympy=1.4=py37_0 261 | - tbb=2019.4=hfd86e86_0 262 | - tblib=1.4.0=py_0 263 | - terminado=0.8.2=py37_0 264 | - testpath=0.4.2=py37_0 265 | - tk=8.6.8=hbc83047_0 266 | - toolz=0.10.0=py_0 267 | - torchvision=0.4.2=py37_cu100 268 | - tornado=6.0.3=py37h7b6447c_0 269 | - tqdm=4.36.1=py_0 270 | - traitlets=4.3.3=py37_0 271 | - unicodecsv=0.14.1=py37_0 272 | - unixodbc=2.3.7=h14c3975_0 273 | - urllib3=1.24.2=py37_0 274 | - wcwidth=0.1.7=py37_0 275 | - webencodings=0.5.1=py37_1 276 | - werkzeug=0.16.0=py_0 277 | - wheel=0.33.6=py37_0 278 | - widgetsnbextension=3.5.1=py37_0 279 | - wrapt=1.11.2=py37h7b6447c_0 280 | - wurlitzer=1.0.3=py37_0 281 | - xlrd=1.2.0=py37_0 282 | - xlsxwriter=1.2.1=py_0 283 | - xlwt=1.3.0=py37_0 284 | - xz=5.2.4=h14c3975_4 285 | - yaml=0.1.7=had09818_2 286 | - zeromq=4.3.1=he6710b0_3 287 | - zict=1.0.0=py_0 288 | - zipp=0.6.0=py_0 289 | - zlib=1.2.11=h7b6447c_3 290 | - zstd=1.3.7=h0b5b093_0 291 | - pip: 292 | - crayon==0.4 293 | - docopt==0.6.2 294 | - jsonpatch==1.24 295 | - jsonpointer==2.0 296 | - ptpython==2.0.6 297 | - pycrayon==0.5 298 | - 
torchfile==0.1.0 299 | - torchnet==0.0.4 300 | - torchsummary==1.5.1 301 | - visdom==0.1.8.9 302 | - websocket-client==0.57.0 303 | 304 | -------------------------------------------------------------------------------- /datasets/sysu_mm01/test_config.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InnovArul/rgb_IR_personreid/450f1d04995263980f4fb0b098960e43911a0e18/datasets/sysu_mm01/test_config.pth -------------------------------------------------------------------------------- /datasets/sysu_mm01/train_config.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InnovArul/rgb_IR_personreid/450f1d04995263980f4fb0b098960e43911a0e18/datasets/sysu_mm01/train_config.pth -------------------------------------------------------------------------------- /datasets/sysu_mm01/val_config.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InnovArul/rgb_IR_personreid/450f1d04995263980f4fb0b098960e43911a0e18/datasets/sysu_mm01/val_config.pth -------------------------------------------------------------------------------- /src_py/dataset.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import numpy as np 4 | from skimage import io, transform, img_as_float 5 | import log 6 | import torch.utils.data as data 7 | from utils import * 8 | 9 | 10 | def read_ids(root_path, split_type): 11 | config_file_path = os.path.join(root_path, "exp", split_type + "_id.txt") 12 | with open(config_file_path, "r") as file: 13 | file_lines = file.readlines() 14 | 15 | # the file has only one line with ids 16 | id_line = file_lines[0] 17 | all_ids = ["%04d" % int(i) for i in id_line.split(",")] 18 | print(config_file_path + " : " + str(len(all_ids))) 19 | return all_ids 20 | 21 | 22 | def read_image_of_category(img_path, 
def read_image_of_category(img_path, category):
    """Load one image and normalize it to a single-channel (224, 224, 1) float image.

    IR images are stored with identical values in every channel, so one
    channel is kept; RGB images are converted to gray scale after resizing.
    Both modalities therefore end up with the same single-channel shape.
    """
    img = img_as_float(io.imread(img_path))

    # IR: all channels carry the same value, keep just one of them
    if category == "IR":
        img = img[:, :, 1]
        img = img[..., np.newaxis]

    img = transform.resize(img, output_shape=(224, 224))

    # RGB: reduce to gray scale so both modalities are single-channel
    if category == "rgb":
        img = rgb2gray(img)
        img = img[..., np.newaxis]

    return img


class cam_ID_folder:
    """One (camera, person-ID) folder of the SYSU-MM01 dataset layout.

    Wraps <root>/<cam_name>/<ID> and yields per-file instance tuples of
    (image-or-path, ID, category, cam_name).
    """

    def __init__(self, root_path, cam_name, ID, cam_config, is_read_image=False):
        # init the instance variables
        self.root_path = root_path
        self.cam_name = cam_name
        self.ID = ID
        self.folder_path = os.path.join(self.root_path, self.cam_name, self.ID)
        # cam_config maps camera name -> modality ("rgb" or "IR")
        self.cam_config = cam_config
        self.is_read_image = is_read_image

    def is_exists(self):
        """Return True if the (camera, ID) folder exists on disk."""
        return os.path.exists(self.folder_path)

    def read_image_file(self, img_path):
        """Return the loaded image for img_path, or the path itself when
        is_read_image is False (lazy mode: caller loads later).

        BUG FIX: the original referenced ``self.img_path`` — an attribute
        that is never set — so any call with is_read_image=True raised
        AttributeError. The method's ``img_path`` argument is used instead.
        """
        if self.is_read_image:
            return read_image_of_category(img_path, self.cam_config[self.cam_name])
        return img_path

    def get_file_instances(self):
        """Collect one instance tuple per file in this folder.

        Returns an empty list when the folder does not exist (not every
        person appears in every camera).
        """
        instances = []

        if self.is_exists():
            print(
                self.folder_path
                + " : "
                + str(len(os.listdir(self.folder_path)))
                + " files"
            )
            # one instance per image file in the directory
            for file in os.listdir(self.folder_path):
                filepath = os.path.join(self.folder_path, file)
                img = self.read_image_file(filepath)

                instances.append(
                    (img, self.ID, self.cam_config[self.cam_name], self.cam_name)
                )

        return instances


# convert rgb to gray image
def rgb2gray(rgb):
    """Luma conversion using ITU-R BT.601 weights; drops any alpha channel."""
    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
class Dataset(data.Dataset):
    """SYSU-MM01 training/validation dataset.

    Every sample is a single-channel image (gray scale for RGB cameras, one
    raw channel for IR cameras) zero-padded to two channels so that both
    modalities share one input shape: channel 0 carries RGB-derived data,
    channel 1 carries IR data.

    The flat instance list is cached to ``<root>/<config_name>_config.pth``
    so the directory walk happens only once per split.
    """

    def __init__(
        self, root_path, IDs, config_name, is_read_image=True, is_return_path=False
    ):
        self.root_path = root_path
        # modality of each SYSU-MM01 camera
        self.cam_config = {
            "cam1": "rgb",
            "cam2": "rgb",
            "cam3": "IR",
            "cam4": "rgb",
            "cam5": "rgb",
            "cam6": "IR",
        }

        self.IDs = IDs
        self.is_read_image = is_read_image
        # NOTE(review): accepted for interface compatibility; the original
        # silently dropped this flag — it is stored but still unused here.
        self.is_return_path = is_return_path
        self.config_name = config_name
        self.data_instances = self.read_data_instances()

        # person-ID string -> contiguous class index for the classifier
        # (loop variable renamed: the original shadowed the builtin `id`)
        self.IDs2Classes = {pid: index for index, pid in enumerate(self.IDs)}

    def read_data_instances(self):
        """Return the flat list of (img-or-path, ID, category, cam) tuples.

        Loads the cached config file if present; otherwise walks every
        (ID, camera) folder and saves the result for later runs.
        """
        data_instances = []

        config_file = os.path.join(self.root_path, self.config_name + "_config.pth")

        if os.path.exists(config_file):
            # reuse the cached directory listing
            print("existing config file " + config_file + " found!. Reading the file!")
            data_instances = torch.load(config_file)
        else:
            for ID in self.IDs:
                for cam_name in self.cam_config:
                    # collect every file of this (camera, ID) pair
                    folder = cam_ID_folder(
                        self.root_path, cam_name, ID, self.cam_config
                    )
                    data_instances += folder.get_file_instances()

            # cache for subsequent runs
            torch.save(data_instances, config_file)

        return data_instances

    def __len__(self):
        return len(self.data_instances)

    def do_random_translation(self, img):
        """Randomly translate img by up to +-10% of its width/height.

        Uses a similarity transform with identity scale/rotation, so only
        the translation component is applied (borders are filled by warp).
        """
        height, width, _ = img.shape
        width_range = 0.1 * width
        height_range = 0.1 * height

        translate_height = np.random.uniform(-height_range, height_range)
        translate_width = np.random.uniform(-width_range, width_range)

        sim_transform = transform.SimilarityTransform(
            scale=1, rotation=0, translation=(translate_width, translate_height)
        )

        return transform.warp(img, sim_transform)

    def pad_zeros_by_category(self, img, category):
        """Pad a 1-channel image to 2 channels according to its modality.

        RGB data goes to channel 0 (zeros appended after), IR data goes to
        channel 1 (zeros prepended), giving both modalities a fixed slot.
        """
        padding = np.zeros_like(img)
        if category == "rgb":
            img = np.concatenate((img, padding), axis=2)
        else:
            img = np.concatenate((padding, img), axis=2)
        return img

    def __getitem__(self, index):
        """Return (CHW float image, class index) for the given sample.

        NOTE(review): the random-translation augmentation is applied on
        every access, including validation — confirm this is intended.
        """
        img, ID, category, cam_name = self.data_instances[index]

        if self.is_read_image:
            img = read_image_of_category(img, category)

        # random data augmentation
        img = self.do_random_translation(img)

        # pad zeros according to modality
        img = self.pad_zeros_by_category(img, category)

        # HWC -> CHW for PyTorch
        return (
            img.transpose((2, 0, 1)),
            self.IDs2Classes[ID],
        )
( 193 | img.transpose((2, 0, 1)), 194 | self.IDs2Classes[ID], 195 | ) # for data.Dataset .transpose((2,0,1)) 196 | 197 | 198 | class TestDataset(Dataset): 199 | def __init__(self, root_path, IDs, config_name): 200 | # init the super class 201 | super().__init__(root_path, IDs, config_name) 202 | 203 | def read_data_instances(self): 204 | print("child class method") 205 | data_instances = {} 206 | for cam_name in self.cam_config.keys(): 207 | data_instances[cam_name] = {} 208 | 209 | # check if the config already exists 210 | config_file = os.path.join(self.root_path, self.config_name + "_config.pth") 211 | 212 | # check if the config file is already existing 213 | if os.path.exists(config_file): 214 | # load the existing config file 215 | print("existing config file " + config_file + " found!. Reading the file!") 216 | data_instances = torch.load(config_file) 217 | else: 218 | # for each of the ids 219 | for ID in self.IDs: 220 | # for each of the cameras 221 | for cam_name in self.cam_config.keys(): 222 | # get all the data instances 223 | folder = cam_ID_folder( 224 | self.root_path, cam_name, ID, self.cam_config 225 | ) 226 | current_folder_instances = folder.get_file_instances() 227 | current_folder_instances = self.order_file_names( 228 | current_folder_instances 229 | ) 230 | data_instances[cam_name][ID] = current_folder_instances 231 | 232 | # save the configuration 233 | torch.save(data_instances, config_file) 234 | 235 | return data_instances 236 | 237 | def get_cam_files_config(self): 238 | return self.data_instances 239 | 240 | def read_image_from_config(self, config): 241 | # config will contain 242 | # img path, ID, category(rgb or IR), cam name 243 | img = read_image_of_category(config[0], config[2]) 244 | img = self.pad_zeros_by_category(img, config[2]) 245 | 246 | return torch.from_numpy(img.transpose((2, 0, 1))) 247 | 248 | def order_file_names(self, instances): 249 | # create a hash with file name 250 | filenames_hash = {} 251 | for inst in 
# --- src_py/doall.py : training entry point ---
# FIX: `from torch.autograd import Variable` dropped -- Variable is deprecated
# since torch 0.4 and nothing in this file needs it any more.
import os, sys, torch
import numpy as np
from dataset import *
import settings, log
import torch.nn as nn, torch.utils.data as data
import torch.optim as optim
from utils import *
from tqdm import tqdm
from model import resnet6


force_new_model = True
pretrained_model = None
# init the settings
settings.init_settings(force_new_model, pretrained_model)
# init the log
log.init_logger(tensorboard=False)


def show_images(img):
    """Debug helper: display the two input channels of an encoded sample."""
    show_image(img[:, :, 0])
    show_image(img[:, :, 1])


def train(model, train_data, criterion, optimizer, epoch):
    """Run one epoch of supervised training and checkpoint the weights.

    Arguments:
        model -- network returning (features, class_logits)
        train_data -- DataLoader yielding (images, targets) batches
        criterion -- classification loss (CrossEntropyLoss)
        optimizer -- optimizer over model.parameters()
        epoch -- current epoch index (used for logging / checkpoint name)

    NOTE(review): relies on module globals `opt` and `logger` that are
    assigned in the __main__ block, so this only works when run as a script.
    """
    total_train_images = len(train_data)

    if opt["useGPU"]:
        model = model.cuda()

    # training mode (enables dropout / batch-norm statistics updates)
    model.train()

    logger.info("epoch # " + str(epoch))
    total_loss = 0

    for batch_index, (imgs, targets) in enumerate(tqdm(train_data)):
        if opt["useGPU"]:
            imgs = imgs.float().cuda()
            targets = targets.cuda()

        # FIX: tensors participate in autograd directly; the deprecated
        # Variable wrappers were removed with no behavior change.
        features, output = model(imgs)

        optimizer.zero_grad()
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if batch_index % 100 == 0:
            logger.info("total loss for current epoch so far : " + str(total_loss))

    logger.info("total loss for current epoch : " + str(total_loss))
    # FIX: save a CPU copy of the weights instead of moving the live model to
    # the CPU (the old code bounced the model between devices every epoch).
    cpu_state = {key: value.cpu() for key, value in model.state_dict().items()}
    torch.save(
        cpu_state,
        os.path.join(opt["save"], "deep_zero_model#" + str(epoch) + ".pth"),
    )


if __name__ == "__main__":
    opt = settings.opt
    logger = log.logger
    train_ids = read_ids(opt["dataroot"], "train")
    val_ids = read_ids(opt["dataroot"], "val")
    test_ids = read_ids(opt["dataroot"], "test")

    train_dataset = Dataset(opt["dataroot"], train_ids, "train")
    val_dataset = Dataset(opt["dataroot"], val_ids, "val")
    test_dataset = TestDataset(opt["dataroot"], test_ids, "test")

    train_dataloader = data.DataLoader(
        train_dataset, batch_size=opt["batch_size"], shuffle=True
    )

    model = get_model(
        arch=opt["arch"], num_classes=len(train_ids), pretrained_model=pretrained_model
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=opt["lr"],
        momentum=opt["momentum"],
        nesterov=opt["nesterov"],
        weight_decay=opt["weight_decay"],
    )

    # for each epoch, run the training
    for epoch in range(opt["epochs"]):
        train(model, train_dataloader, criterion, optimizer, epoch)


# --- src_py/doall_test.py : feature extraction entry point (header) ---
import os, sys, torch
import numpy as np
from torch.autograd import Variable
from dataset import *
import settings, log
import torch.nn as nn, torch.utils.data as data
import torch.optim as optim
from utils import *
from tqdm import tqdm
from model import resnet6
import scipy.io as sio
pretrained_model = "../scratch/sysu_mm01/deepzeropadding-14May2019-125214_deep-zero-padding/deep_zero_model#156.pth"
# init the settings (reuses the pretrained model's folder as the save dir)
settings.init_settings(False, pretrained_model)
# init the log
log.init_logger(tensorboard=False, prepend_text="test_")


def get_max_test_id(test_ids):
    """Return the largest test person ID as an int (IDs arrive as strings)."""
    return max(int(ID) for ID in test_ids)


def prepare_empty_matfile_config(max_test_id):
    """Allocate a MATLAB-style cell array with one (empty) slot per ID."""
    cam_features = np.empty(max_test_id, dtype=object)
    for i in range(len(cam_features)):
        cam_features[i] = []
    return cam_features


def test(model, test_dataset, test_ids):
    """Extract per-image features for every test ID of every camera.

    Writes one <prefix>_<cam>.mat file per camera for the official
    SYSU-MM01 evaluation.  Only test subjects up to max_test_id get
    features; other slots keep an empty array (shape (0, 0)), matching
    the original MATLAB feature files.
    """
    data_instances = test_dataset.get_cam_files_config()
    matfile_prefix = get_file_name(pretrained_model)
    testresults_dir = os.path.join(opt["save"], matfile_prefix)
    # FIX: makedirs(exist_ok=True) instead of exists()+mkdir (race-free,
    # creates missing parents)
    os.makedirs(testresults_dir, exist_ok=True)

    max_test_id = get_max_test_id(test_ids)
    model.eval()
    if opt["useGPU"]:
        model = model.cuda()

    for cam_name, id_contents in data_instances.items():
        # data_instances layout:
        # --cam1
        #   -- 0001
        #      -- 0001.jpg, 0002.jpg, ...
        matfile_path = os.path.join(
            testresults_dir, matfile_prefix + "_" + cam_name + ".mat"
        )
        print(cam_name)

        # prepare empty features for all person ids up to max_test_id
        # (cam3 contains 533 ids, but 333 is the max test id; non-test ids
        # within that range keep empty (0, 0) features)
        cam_features = prepare_empty_matfile_config(max_test_id)

        # FIX: run inference under no_grad -- the old code built an autograd
        # graph for every test image and wasted memory for nothing.
        with torch.no_grad():
            for id_, img_contents in tqdm(id_contents.items()):
                all_current_id_features = np.empty(shape=[0, 2048])
                for img_config in img_contents:
                    img = test_dataset.read_image_from_config(img_config)

                    # NOTE(review): the CPU path never unsqueezes a batch dim;
                    # kept as-is since useGPU is always True in settings.
                    if opt["useGPU"]:
                        img = img.unsqueeze(0).float().cuda()

                    features, _ = model(img)
                    current_feature = features.data[0].cpu().numpy().reshape(1, -1)
                    all_current_id_features = np.append(
                        all_current_id_features, current_feature, axis=0
                    )

                # MATLAB-side cell arrays are 1-indexed
                cam_features[int(id_) - 1] = all_current_id_features

        sio.savemat(matfile_path, {"feature": cam_features})


if __name__ == "__main__":
    opt = settings.opt
    logger = log.logger
    train_ids = read_ids(opt["dataroot"], "train")
    test_ids = read_ids(opt["dataroot"], "test")
    test_dataset = TestDataset(opt["dataroot"], test_ids, "test")

    # num_classes must match the training-time classifier head
    model = get_model(
        num_classes=len(train_ids), arch=opt["arch"], pretrained_model=pretrained_model
    )
    model.load_state_dict(torch.load(pretrained_model))

    test(model, test_dataset, test_ids)
# --- src_py/log.py : logging setup ---
import logging, os, sys, settings
from pathlib import Path

os.environ['no_proxy'] = '127.0.0.1,localhost'


def init_logger(tensorboard=True, prepend_text=""):
    """Create the module-global `logger` (file + console handlers) and,
    optionally, a pycrayon/tensorboard experiment logger.

    Arguments:
        tensorboard -- when True, also connect to a local crayon server
        prepend_text -- prefix for the log file name (e.g. "test_")
    """
    global logger, experimentLogger
    # NOTE(review): logger name looks copy-pasted from another project;
    # kept unchanged because it is a runtime identifier.
    logger = logging.getLogger('heel-contour-prediction')

    # FIX: clear old handlers so a second init_logger call does not stack
    # duplicates and emit every record twice.
    logger.handlers.clear()

    # log file handler
    print(settings.opt)
    fileHandler = logging.FileHandler(
        os.path.join(settings.opt['save'],
                     prepend_text + settings.opt['description'] + '.log'))
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    fileHandler.setFormatter(formatter)
    fileHandler.setLevel(logging.INFO)
    logger.addHandler(fileHandler)

    # output stream handler
    streamHandler = logging.StreamHandler()
    streamHandler.setFormatter(formatter)
    streamHandler.setLevel(logging.INFO)
    logger.addHandler(streamHandler)

    logger.setLevel(logging.INFO)
    logger.info('file handler and stream handler are ready for logging')

    if tensorboard:
        # FIX: import pycrayon lazily so this module imports cleanly when the
        # crayon client is not installed and tensorboard logging is off.
        from pycrayon import CrayonClient
        cc = CrayonClient(hostname="localhost")
        experimentLogger = cc.create_experiment(Path(settings.opt['save']).name)

    # log the configuration
    logger.info(settings.opt)


# --- src_py/model.py : 2-channel-input ResNet (header) ---
import os, sys, math
import numpy as np
# FIX: resnet6() below calls torch.load but model.py never imported torch,
# which crashed with NameError whenever a pretrained path was given.
import torch
import torch.nn as nn, torch.nn.functional as F


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (bias-free, as in standard ResNet)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
class BasicBlock(nn.Module):
    """Standard two-conv residual block (ResNet v1, expansion 1)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # inlined former conv3x3 helper: 3x3, padded, bias-free convs
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """y = relu(F(x) + shortcut(x))"""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # project the shortcut when shape/stride changed
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet backbone with a 2-channel input stem (deep zero padding).

    forward() returns (pooled features, classifier logits).
    """

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        # 2 input channels: one slot per modality (rgb / IR)
        self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # FIX: AdaptiveAvgPool2d generalizes the hard-coded AvgPool2d(7),
        # which only worked for 224x224 inputs; output is identical when
        # the final feature map is 7x7.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-style init for convs, unit scale / zero shift for batch norms
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack *blocks* residual blocks; add a 1x1 projection shortcut
        when the spatial stride or channel count changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        features = x.view(x.size(0), -1)
        classifier_output = self.fc(features)

        return features, classifier_output


def resnet6(num_classes, pretrained_model=None):
    """Construct the shallow (one block per stage) 2-channel ResNet.

    Arguments:
        num_classes -- size of the classifier head
        pretrained_model -- optional path of a state_dict to load
    """
    model = ResNet(block=BasicBlock, layers=[1, 1, 1, 1], num_classes=num_classes)
    if pretrained_model is not None:
        # FIX: model.py did not import torch at file level, so this branch
        # crashed with NameError; a local import keeps the fix self-contained.
        import torch
        print('loading pretrained weights from ' + pretrained_model)
        model.load_state_dict(torch.load(pretrained_model))
    return model
# --- src_py/model_resnet_pretrained.py : ImageNet ResNet-50, 2-channel stem ---
import os, sys, math
import numpy as np
import torch.nn as nn, torch.nn.functional as F
import torchvision.models as models


class ResNet50(nn.Module):
    """torchvision ResNet-50 adapted to the 2-channel modality input.

    The pretrained 3-channel conv1 weights are averaged over the RGB axis
    and replicated across the 2 new input channels.  forward() returns
    (2048-d pooled features, classifier logits).
    """

    def __init__(self, num_classes=1000, pretrained=True):
        super().__init__()
        print("creating model " + self.__class__.__name__)
        self.resnet50 = models.resnet50(pretrained=pretrained)

        # replace the stem conv with a 2-channel version
        conv1_layer = self.resnet50.conv1
        conv1_weights = conv1_layer.weight.data.clone()

        self.resnet50.conv1 = nn.Conv2d(
            2,
            out_channels=conv1_layer.out_channels,
            kernel_size=conv1_layer.kernel_size,
            padding=conv1_layer.padding,
            stride=conv1_layer.stride,
            # FIX: the original wrote `bias=not (conv1_layer is None)`, which
            # is always True; the intent was to mirror the replaced layer's
            # config -- ResNet's conv1 has no bias, so this is now False.
            bias=conv1_layer.bias is not None,
        )

        # initialize the 2-channel stem from the mean of the RGB filters
        mean_conv1_weights = conv1_weights.mean(dim=1, keepdim=True).repeat(1, 2, 1, 1)
        self.resnet50.conv1.weight.data.copy_(mean_conv1_weights)

        self.feature_dim = 2048
        self.fc = nn.Linear(self.feature_dim, num_classes)

    def forward(self, x):
        x = self.resnet50.conv1(x)
        x = self.resnet50.bn1(x)
        x = self.resnet50.relu(x)
        x = self.resnet50.maxpool(x)

        x = self.resnet50.layer1(x)
        x = self.resnet50.layer2(x)
        x = self.resnet50.layer3(x)
        x = self.resnet50.layer4(x)

        # adaptive pooling keeps the head input 2048-d for any input size
        x = F.adaptive_avg_pool2d(x, output_size=1)
        features = x.view(x.size(0), -1)
        classifier_output = self.fc(features)

        return features, classifier_output
# --- src_py/settings.py : experiment options ---
# FIX: dropped unused file-level imports (torchvision, torchnet.meter,
# torch.autograd.Variable, duplicate torch, sys) -- nothing in this
# module used them, and they are all heavyweight or third-party.
import time, os
from pathlib import Path


def init_settings(force_new_model, pretrained_model_path=None):
    """Populate the module-global `opt` dict with all experiment options.

    Arguments:
        force_new_model -- when True, create a fresh timestamped save dir;
            otherwise reuse the directory of *pretrained_model_path*
        pretrained_model_path -- checkpoint path (required when
            force_new_model is False)

    Returns:
        the populated options dict (also kept as the module global `opt`,
        which existing callers keep reading via settings.opt).
    """
    global opt
    opt = dict()
    opt["dataroot"] = "../datasets/sysu_mm01"
    opt["batch_size"] = 64
    opt["lr"] = 0.001
    opt["momentum"] = 0.9
    opt["weight_decay"] = 0.0005
    opt["nesterov"] = False
    opt["description"] = "deep-zero-padding"
    opt["useGPU"] = True
    opt["epochs"] = 20000
    opt["arch"] = "resnet50"  # resnet6 | resnet50

    # determine the log / save folder path
    if force_new_model:
        opt["save"] = (
            "../scratch/sysu_mm01/deepzeropadding-"
            + time.strftime("%d%b%Y-%H%M%S")
            + "_"
            + opt["description"]
        )
    else:
        opt["save"] = str(Path(pretrained_model_path).parent)
    print("save path : " + opt["save"])
    os.makedirs(opt["save"], exist_ok=True)

    return opt
# --- src_py/utils.py : misc helpers (display, model factory, ckpt loading) ---
# import matplotlib.pyplot as plt  # imported lazily inside show_image
import os, sys, torch
from model import resnet6
from model_resnet_pretrained import ResNet50


def show_image(image):
    """Display *image* at its native pixel size (debugging helper)."""
    # FIX: the matplotlib import was commented out at file level, so `plt`
    # was an unresolved name and this function always crashed; importing
    # lazily keeps headless runs import-clean while making the helper work.
    import matplotlib.pyplot as plt
    dpi = 80
    figsize = (image.shape[1] / float(dpi), image.shape[0] / float(dpi))
    fig = plt.figure(figsize=figsize)
    plt.imshow(image)
    fig.show()


def get_file_name(filepath):
    """Return the base file name without its (last) extension.

    FIX: uses os.path.splitext instead of split(".")[0], so names that
    contain extra dots are no longer truncated; identical for this
    project's "0001.jpg" / "deep_zero_model#156.pth" style names.
    """
    return os.path.splitext(os.path.basename(filepath))[0]


def get_model(arch, num_classes, pretrained_model):
    """Build the requested architecture and optionally load a checkpoint.

    Arguments:
        arch -- "resnet6" or "resnet50"
        num_classes -- classifier head size
        pretrained_model -- optional state_dict path to load
    """
    if arch == "resnet6":
        model = resnet6(num_classes=num_classes)
    elif arch == "resnet50":
        model = ResNet50(num_classes=num_classes)
    else:
        # NOTE(review): assert is stripped under `python -O`; kept for
        # backward compatibility, but a ValueError would be safer.
        assert False, "unknown model arch: " + arch

    if pretrained_model is not None:
        model.load_state_dict(torch.load(pretrained_model))

    return model


def load_pretrained_model(model, pretrained_model_path, verbose=False):
    """Load pretrained params into *model*, keeping only keys whose names
    and tensor sizes match the current model.

    Arguments:
        model {loaded model} -- already constructed model
        pretrained_model_path {str or dict} -- checkpoint path, or a
            state_dict given directly

    Keyword arguments:
        verbose -- print the keys that could not be matched

    Raises:
        IOError -- if the checkpoint path is not found

    Returns:
        model -- model with the matching params loaded
    """
    if isinstance(pretrained_model_path, str):
        if not os.path.exists(pretrained_model_path):
            raise IOError(
                "Can't find pretrained model: {}".format(pretrained_model_path)
            )

        print("Loading checkpoint from '{}'".format(pretrained_model_path))
        pretrained_state = torch.load(pretrained_model_path)["state_dict"]
    else:
        # in case pretrained model weights are given directly
        pretrained_state = pretrained_model_path

    print(len(pretrained_state), " keys in pretrained model")

    current_model_state = model.state_dict()
    print(len(current_model_state), " keys in current model")

    # FIX: keep the unfiltered dict around -- the old code computed the
    # verbose diagnostics from the already-filtered dict, so the
    # "not available" lists were always empty and therefore misleading.
    original_pretrained_state = pretrained_state
    pretrained_state = {
        key: val
        for key, val in pretrained_state.items()
        if key in current_model_state and val.size() == current_model_state[key].size()
    }

    print(
        len(pretrained_state),
        " keys in pretrained model are available in current model",
    )
    current_model_state.update(pretrained_state)
    model.load_state_dict(current_model_state)

    if verbose:
        non_available_keys_in_pretrained = [
            key
            for key, val in original_pretrained_state.items()
            if key not in current_model_state
            or val.size() != current_model_state[key].size()
        ]
        non_available_keys_in_current = [
            key
            for key, val in model.state_dict().items()
            if key not in original_pretrained_state
            or val.size() != original_pretrained_state[key].size()
        ]

        print(
            "not available keys in pretrained model: ", non_available_keys_in_pretrained
        )
        print("not available keys in current model: ", non_available_keys_in_current)

    return model