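├── .gitignore
├── Assets
│   └── README_CN.md
├── LICENSE
├── README.md
├── classification
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_interface.py
│   │   ├── ref
│   │   │   └── .gitkeep
│   │   └── standard_data.py
│   ├── main.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── model_interface.py
│   │   ├── simple_net.py
│   │   └── standard_net.py
│   └── utils.py
├── special
│   └── kfold
│       ├── .gitignore
│       ├── LICENSE
│       ├── README.md
│       ├── data
│       │   ├── __init__.py
│       │   ├── data_interface.py
│       │   ├── ref
│       │   │   └── .gitkeep
│       │   └── standard_data.py
│       ├── kfold_train.sh
│       ├── kfold_val.ipynb
│       ├── main.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── common.py
│       │   ├── model_interface.py
│       │   ├── simple_net.py
│       │   └── standard_net.py
│       └── utils.py
└── super-resolution
    ├── .gitignore
    ├── LICENSE
    ├── data
    │   ├── __init__.py
    │   ├── common.py
    │   ├── data_interface.py
    │   ├── recursive_up.py
    │   └── satup_data.py
    ├── main.py
    ├── model
    │   ├── __init__.py
    │   ├── common.py
    │   ├── metrics.py
    │   ├── model_interface.py
    │   ├── rdn_fuse.py
    │   └── utils.py
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | dataset/ 4 | lightning_logs/ 5 | MNIST/ 6 | weights/ 7 | /backup -------------------------------------------------------------------------------- /Assets/README_CN.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Lightning-Template 2 | 3 | [**English Version**](../README.md) 4 | 5 | ## Introduction 6 | 7 | Pytorch-Lightning is a very convenient library that can be seen as an abstraction and wrapper around Pytorch. Its strengths are high reusability, easy maintenance, and clear logic. Its weaknesses are that it is heavyweight, with a lot to learn and understand, and that its pattern of tightly binding the model to the training code is ill-suited to large real-world projects, which often contain many different models to train and test. The same holds for the data side: tightly coupling DataLoaders and the like to custom Datasets causes a similar problem, with the same code inelegantly copied and pasted over and over. 8 | 9 | After continual exploration and debugging, I distilled the following practical set of templates, which can also be seen as a further abstraction of Pytorch-Lightning. In the first version, all template content lived in the root folder, but after more than a month of use I found that building a dedicated template for each type of project improves coding efficiency further. For example, I have recently worked on both classification and super-resolution tasks; each has certain fixed requirements, so starting from a specialized template each time is faster and avoids some bugs and the debugging they cause. Task-specific code and files can also be added to each template. 10 | 11 | **Since the repository is still young, only these two templates exist for now. As I apply it to other projects, new specialized templates will be added. If you have used this template for your own task (such as NLP, GAN, or speech recognition), you are welcome to submit a PR so your template can be merged into the main repository for more people to use. If your task is not yet on the list, consider starting from the `classification` template and adapting it to your task. Since almost all templates share the same underlying code, this can be done quickly.** 12 | 13 | Everyone is welcome to try this code style. Once you are used to it, it is quite convenient to reuse, and you are unlikely to abandon it halfway. A more detailed explanation and a complete guide to Pytorch-Lightning can be found in [this article](https://zhuanlan.zhihu.com/p/353985363) on Zhihu. 14 | 15 | ## File Structure 16 | 17 | ``` 18 | root- 19 | |-data 20 | |-__init__.py 21 | |-data_interface.py 22 | |-xxxdataset1.py 23 | |-xxxdataset2.py 24 | |-... 25 | |-model 26 | |-__init__.py 27 | |-model_interface.py 28 | |-xxxmodel1.py 29 | |-xxxmodel2.py 30 | |-... 31 | |-main.py 32 | |-utils.py 33 | ``` 34 | 35 | ## Installation 36 | 37 | No installation is needed: simply run `git clone https://github.com/miracleyoo/pytorch-lightning-template.git` to get a local copy. To use it, pick the problem type you need (such as `classification`) and copy that folder directly into your project folder. 38 | 39 | ## Explanation 40 | 41 | Template architecture: 42 | 43 | - The root directory contains only a `main.py` file and a helper `utils.py`. 44 | 45 | - The `data` and `model` folders each contain an `__init__.py` file, turning them into packages for easier imports. The two `__init__.py` files contain, respectively: 46 | 47 | - `from .data_interface import DInterface` 48 | - `from .model_interface import MInterface` 49 | 50 | - In `data_interface`, create a `class DInterface(pl.LightningDataModule):` that serves as the interface to all dataset files. The `__init__()` function imports the corresponding Dataset class, `setup()` instantiates it, and you add the required `train_dataloader`, `val_dataloader`, and `test_dataloader` functions as usual. These functions are usually similar to one another, so the parts that differ can be controlled by a few input args. 51 | 52 | - Likewise, in `model_interface`, create a `class MInterface(pl.LightningModule):` class as the intermediate interface to the models. The `__init__()` function imports the corresponding model class, and you then add functions such as `configure_optimizers`, `training_step`, and `validation_step` as usual, so a single interface class controls all models. The parts that differ are controlled by input arguments. 53 | 54 | - `main.py` is only responsible for: 55 | 56 | - Defining the parser and adding parse items. (Note: if the `__init__` function of your model or dataset file has a variable that should be controlled externally, such as a `random_arg`, you can add a corresponding item directly to the parser in `main.py`, e.g. `parser.add_argument('--random_arg', default='test', type=str)`; the two `Interface` classes will automatically pass these arguments on to your model or dataset class.) 57 | - Choosing the needed `callback` functions, such as auto-saving checkpoints, Early Stop, and LR Scheduler. 58 | - Instantiating `MInterface`, `DInterface`, and `Trainer`. 59 | 60 | That's it. 61 | 62 | **Note that, so that new models and datasets can be added automatically without changing the Interface, model file names in the `model` folder should use snake case, such as `rdn_fuse.py`, while the main class in the file should use the corresponding camel case name, such as `RdnFuse`.** 63 | 64 | The same applies to the dataset folder, `data`. 65 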
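As an illustration of why the naming rule matters, the following is a minimal sketch (an assumption, not the template's actual code) of how an interface could turn a snake-case file name into its camel-case class name and import it with `importlib`. The helpers `snake_to_camel` and `load_class` are hypothetical names; the usage at the bottom loads `email.parser.Parser` from the Python standard library only so that the sketch runs without the template's `model` package.

```python
import importlib


def snake_to_camel(name: str) -> str:
    """Convert a snake-case file name such as 'rdn_fuse' to 'RdnFuse'."""
    return "".join(part.capitalize() for part in name.split("_"))


def load_class(package: str, module_name: str) -> type:
    """Import `package.module_name` and return the class named after it.

    Hypothetical helper: assumes the main class in `module_name.py`
    is the camel-case form of the file name.
    """
    module = importlib.import_module("." + module_name, package=package)
    class_name = snake_to_camel(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise ValueError(
            f"Expected class {class_name!r} in {package}.{module_name}; "
            "check that the file and class names follow the naming rule."
        ) from err


if __name__ == "__main__":
    print(snake_to_camel("rdn_fuse"))      # RdnFuse
    print(snake_to_camel("standard_net"))  # StandardNet
    # Stand-in for a hypothetical load_class("model", "rdn_fuse"): the
    # stdlib module email/parser.py happens to define a class named Parser.
    print(load_class("email", "parser"))   # <class 'email.parser.Parser'>
```

In this sketch, a file `rdn_fuse.py` whose main class were named `RDNFuse` instead of `RdnFuse` would fail to resolve and raise `ValueError`, which is the kind of parsing failure the note above warns about.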
| 66 | Although this places fairly strict requirements on naming, it does not actually get in the way; if anything, it makes your code structure clearer. Please keep this in mind when using the template to avoid parsing failures. 67 | 68 | ## Citation 69 | 70 | If this template has helped your research, please consider citing our papers: 71 | 72 | ``` 73 | @misc{https://doi.org/10.48550/arxiv.2301.06648, 74 | doi = {10.48550/ARXIV.2301.06648}, 75 | url = {https://arxiv.org/abs/2301.06648}, 76 | author = {Zhang, Zhongyang and Chai, Kaidong and Yu, Haowen and Majaj, Ramzi and Walsh, Francesca and Wang, Edward and Mahbub, Upal and Siegelmann, Hava and Kim, Donghyun and Rahman, Tauhidur}, 77 | keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences}, 78 | title = {YeLan: Event Camera-Based 3D Human Pose Estimation for Technology-Mediated Dancing in Challenging Environments with Comprehensive Motion-to-Event Simulator}, 79 | publisher = {arXiv}, 80 | year = {2023}, 81 | copyright = {arXiv.org perpetual, non-exclusive license} 82 | } 83 | 84 | @InProceedings{Zhang_2022_WACV, 85 | author = {Zhang, Zhongyang and Xu, Zhiyang and Ahmed, Zia and Salekin, Asif and Rahman, Tauhidur}, 86 | title = {Hyperspectral Image Super-Resolution in Arbitrary Input-Output Band Settings}, 87 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV) Workshops}, 88 | month = {January}, 89 | year = {2022}, 90 | pages = {749-759} 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Lightning-Template 2 | 3 | [**Chinese Version 中文版**](./Assets/README_CN.md) 4 | 5 | ## Introduction 6 | 7 | Pytorch-Lightning is a very convenient library that can be seen as an abstraction and wrapper around Pytorch. Its advantages are strong reusability, easy maintenance, and clear logic. Its disadvantages are that it is heavyweight and takes quite a bit of time to learn and understand. In addition, because it binds the model directly to the training code, it is poorly suited to real projects with multiple model and dataset files. The same is true of the data module design: the tight coupling between things like DataLoaders and custom Datasets causes a similar problem, with the same code inelegantly copied and pasted here and there. 8 | 9 | After much exploration and practice, I have distilled the following templates, which can also be seen as a further abstraction of Pytorch-Lightning.
In the first version, all the template content was under the root folder. However, after using it for more than a month, I found that more specialized templates for different types of projects boost coding efficiency further. For example, classification and super-resolution tasks each have certain fixed requirements. Starting directly from a specialized template lets the project code be written faster and avoids some bugs. 10 | 11 | **Currently, since this is still a new library, there are only these two templates. As I apply it to other projects, new specialized templates will be added. If you have used this template for your own tasks (such as NLP, GAN, speech recognition, etc.), you are welcome to submit a PR to integrate your template into the library for more people to use. If your task is not on the list yet, the `classification` template is a good starting point. Since most of the underlying logic and code of the templates is the same, this can be done very quickly.** 12 | 13 | Everyone is welcome to try this code style. Once you are used to it, it is quite convenient to reuse, and you are unlikely to abandon it halfway. A more detailed explanation and a complete guide to Pytorch-Lightning can be found in [this article](https://zhuanlan.zhihu.com/p/353985363) on Zhihu. 14 | 15 | ## File Structure 16 | 17 | ``` 18 | root- 19 | |-data 20 | |-__init__.py 21 | |-data_interface.py 22 | |-xxxdataset1.py 23 | |-xxxdataset2.py 24 | |-... 25 | |-model 26 | |-__init__.py 27 | |-model_interface.py 28 | |-xxxmodel1.py 29 | |-xxxmodel2.py 30 | |-... 31 | |-main.py 32 | |-utils.py 33 | ``` 34 | 35 | ## Installation 36 | 37 | No installation is needed. Simply run `git clone https://github.com/miracleyoo/pytorch-lightning-template.git` to get a local copy. Choose your problem type, such as `classification`, and copy the corresponding template folder into your project directory.
38 | 39 | ## Explanation of Structure 40 | 41 | - There are only `main.py` and `utils.py` in the root directory. The former is the entry point of the code, and the latter is a support file. 42 | 43 | - Both the `data` and `model` folders contain an `__init__.py` file that turns them into packages, which makes importing easier. 44 | 45 | - Create a `class DInterface(pl.LightningDataModule):` in `data_interface` to serve as the interface to all your customized Dataset files. The corresponding Dataset class is imported in the `__init__()` function, instantiation is done in `setup()`, and the `train_dataloader`, `val_dataloader`, and `test_dataloader` functions are defined. 46 | 47 | - Similarly, a `class MInterface(pl.LightningModule):` is created in `model_interface` to serve as the interface to all your model files. The corresponding model class is imported in the `__init__()` function. The only things you need to modify in the interface are functions like `configure_optimizers`, `training_step`, and `validation_step`, which control your own training process. One interface serves all models, and the differences are handled through args. 48 | 49 | - `main.py` is only responsible for the following tasks: 50 | 51 | - Define the parser and add parse items. (Note: if your model or dataset class's `__init__` takes an argument that should be controlled from outside, such as from the command line, you can add a parse item for it directly in `main.py`. For example, for a string argument called `random_arg`, add `parser.add_argument('--random_arg', default='test', type=str)` to `main.py`. The two `Interface` classes will automatically select these arguments and pass them to the corresponding model/data class.) 52 | - Choose the needed `callback` functions, such as auto-save, Early Stop, and LR Scheduler. 53 | - Instantiate `MInterface`, `DInterface`, and `Trainer`. 54 | 55 | Fin.
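The automatic argument passing described in the `main.py` item above can be pictured with the minimal sketch below. It assumes the interface holds every parsed argument in a dict and forwards only the keys that appear in the target class's `__init__` signature, found with `inspect.signature`; `instantiate` and `RandomNet` are hypothetical names used for illustration, and the template's own interface code may differ.

```python
import argparse
import inspect


class RandomNet:
    """Hypothetical model whose __init__ needs an externally controlled argument."""

    def __init__(self, random_arg: str, hidden_dim: int = 64):
        self.random_arg = random_arg
        self.hidden_dim = hidden_dim


def instantiate(cls, all_args: dict, **overrides):
    """Build `cls` using only the entries of `all_args` its __init__ accepts.

    Keys the class does not declare (e.g. trainer settings such as `lr`)
    are dropped; explicit `overrides` take priority over parsed values.
    """
    params = inspect.signature(cls.__init__).parameters
    accepted = [name for name in params if name != "self"]
    kwargs = {name: all_args[name] for name in accepted if name in all_args}
    kwargs.update(overrides)
    return cls(**kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--random_arg", default="test", type=str)
    parser.add_argument("--lr", default=1e-3, type=float)
    args = parser.parse_args(["--random_arg", "hello"])

    net = instantiate(RandomNet, vars(args))
    print(net.random_arg, net.hidden_dim)  # hello 64
```

Filtering by signature is what lets a single parser carry both model and trainer options in this sketch: `--lr` never reaches `RandomNet`, while `--random_arg` does, and arguments the parser does not define fall back to the class defaults.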
56 | 57 | ## Attention 58 | 59 | **Note: so that `MInterface` and `DInterface` can automatically parse newly added models and datasets when you simply specify the `--model_name` and `--dataset` arguments, model/dataset files use snake case (like `standard_net.py`), and the class inside uses the same name in camel case (like `StandardNet`).** 60 | 61 | The same applies to the `data` folder. 62 | 63 | Although this may seem to restrict how you name models and datasets, it also makes your code easier to read and understand. Please keep this in mind to avoid parsing issues. 64 | 65 | ## Citation 66 | 67 | If you use this template and find it helpful for your research, please consider citing our papers: 68 | 69 | ``` 70 | @article{ZHANG2023126388, 71 | title = {Neuromorphic high-frequency 3D dancing pose estimation in dynamic environment}, 72 | journal = {Neurocomputing}, 73 | volume = {547}, 74 | pages = {126388}, 75 | year = {2023}, 76 | issn = {0925-2312}, 77 | doi = {https://doi.org/10.1016/j.neucom.2023.126388}, 78 | url = {https://www.sciencedirect.com/science/article/pii/S0925231223005118}, 79 | author = {Zhongyang Zhang and Kaidong Chai and Haowen Yu and Ramzi Majaj and Francesca Walsh and Edward Wang and Upal Mahbub and Hava Siegelmann and Donghyun Kim and Tauhidur Rahman}, 80 | keywords = {Event Camera, Dynamic Vision Sensor, Neuromorphic Camera, Simulator, Dataset, Deep Learning, Human Pose Estimation, 3D Human Pose Estimation, Technology-Mediated Dancing}, 81 | } 82 | 83 | @InProceedings{Zhang_2022_WACV, 84 | author = {Zhang, Zhongyang and Xu, Zhiyang and Ahmed, Zia and Salekin, Asif and Rahman, Tauhidur}, 85 | title = {Hyperspectral Image Super-Resolution in Arbitrary Input-Output Band Settings}, 86 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV) Workshops}, 87 | month = {January}, 88 | year = {2022}, 89 | pages = {749-759} 90
| } 91 | ``` 92 | -------------------------------------------------------------------------------- /classification/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | dataset/ 4 | lightning_logs/ 5 | MNIST/ 6 | weights/ 7 | /backup -------------------------------------------------------------------------------- /classification/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /classification/README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Lightning-Template: Classification 2 | 3 | ## Introduction 4 | 5 | This directory provides the template for classification tasks. 6 | 7 | The templates for different task types differ mainly in the following points: 8 | 1. The metric monitored by the callbacks in `main.py`, and the checkpoint naming based on it (here `val_acc`). 9 | 2. `model/model_interface.py` additionally computes `val_acc`. 10 | 3. `model` includes a specialized `standard_net.py` that handles common issues with pretrained models. 11 | 4. `standard_data.py` in `data` provides data-processing methods common to classification problems. 12 | 5. Some other details. -------------------------------------------------------------------------------- /classification/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .data_interface import DInterface -------------------------------------------------------------------------------- /classification/data/data_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import inspect 16 | import importlib 17 | import pickle as pkl 18 | import pytorch_lightning as pl 19 | from torch.utils.data import DataLoader 20 | from torch.utils.data.sampler import WeightedRandomSampler 21 | 22 | 23 | class DInterface(pl.LightningDataModule): 24 | 25 | def __init__(self, num_workers=8, 26 | dataset='', 27 | **kwargs): 28 | super().__init__() 29 | self.num_workers = num_workers 30 | self.dataset = dataset 31 | self.kwargs = kwargs 32 | self.batch_size = kwargs['batch_size'] 33 | self.load_data_module() 34 | 35 | def setup(self, stage=None): 36 | # Assign train/val datasets for use in dataloaders 37 | if stage == 'fit' or stage is None: 38 | self.trainset = self.instancialize(train=True) 39 | self.valset = self.instancialize(train=False) 40 | 41 | # Assign test dataset for use in dataloader(s) 42 | if stage == 'test' or stage is None: 43 | self.testset = self.instancialize(train=False) 44 | 45 | # # If you need to balance your data using Pytorch Sampler, 46 | # # please uncomment the following lines. 47 | 48 | # with open('./data/ref/samples_weight.pkl', 'rb') as f: 49 | # self.sample_weight = pkl.load(f) 50 | 51 | # def train_dataloader(self): 52 | # sampler = WeightedRandomSampler(self.sample_weight, len(self.trainset)*20) 53 | # return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=self.num_workers, sampler = sampler) 54 | 55 | def train_dataloader(self): 56 | return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True) 57 | 58 | def val_dataloader(self): 59 | return DataLoader(self.valset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 60 | 61 | def test_dataloader(self): 62 | return DataLoader(self.testset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 63 | 64 | def load_data_module(self): 65 | name = self.dataset 66 | # Change the `snake_case.py` file name to `CamelCase` class name. 
67 | # Please always name your dataset file as `snake_case.py` and the 68 | # corresponding class as `CamelCase`. 69 | camel_name = ''.join([i.capitalize() for i in name.split('_')]) 70 | try: 71 | self.data_module = getattr(importlib.import_module( 72 | '.'+name, package=__package__), camel_name) 73 | except (ImportError, AttributeError) as e: 74 | raise ValueError( 75 | f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from e 76 | 77 | def instancialize(self, **other_args): 78 | """ Instantiate the dataset class using the corresponding parameters 79 | from the self.kwargs dictionary. You can also pass any args 80 | to overwrite the corresponding values in self.kwargs. 81 | """ 82 | class_args = inspect.getfullargspec(self.data_module.__init__).args[1:] 83 | inkeys = self.kwargs.keys() 84 | args1 = {} 85 | for arg in class_args: 86 | if arg in inkeys: 87 | args1[arg] = self.kwargs[arg] 88 | args1.update(other_args) 89 | return self.data_module(**args1) 90 | -------------------------------------------------------------------------------- /classification/data/ref/.gitkeep: -------------------------------------------------------------------------------- 1 | {\rtf1} -------------------------------------------------------------------------------- /classification/data/standard_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import torch 16 | import os.path as op 17 | import numpy as np 18 | import pickle as pkl 19 | import torch.utils.data as data 20 | 21 | from torchvision import transforms 22 | from sklearn.model_selection import train_test_split 23 | 24 | 25 | class StandardData(data.Dataset): 26 | def __init__(self, data_dir=r'data/ref', 27 | class_num=9, 28 | train=True, 29 | no_augment=True, 30 | aug_prob=0.5, 31 | img_mean=(0.485, 0.456, 0.406), 32 | img_std=(0.229, 0.224, 0.225)): 33 | # Set all input args as attributes 34 | self.__dict__.update(locals()) 35 | self.aug = train and not no_augment 36 | 37 | self.check_files() 38 | 39 | def check_files(self): 40 | # This is the core code block for loading your own dataset. 41 | # You can scan a folder, load a pickled file list, or use any 42 | # other format. The only requirement is that `self.path_list` 43 | # is assigned a valid list of sample paths. 44 | file_list_path = op.join(self.data_dir, 'file_list.pkl') 45 | with open(file_list_path, 'rb') as f: 46 | file_list = pkl.load(f) 47 | 48 | fl_train, fl_val = train_test_split( 49 | file_list, test_size=0.2, random_state=2333) 50 | self.path_list = fl_train if self.train else fl_val 51 | 52 | label_file = op.join(self.data_dir, 'label_dict.pkl') 53 | with open(label_file, 'rb') as f: 54 | self.label_dict = pkl.load(f) 55 | 56 | def __len__(self): 57 | return len(self.path_list) 58 | 59 | def to_one_hot(self, idx): 60 | out = np.zeros(self.class_num, dtype=float) 61 | out[idx] = 1 62 | return out 63 | 64 | def __getitem__(self, idx): 65 | path = self.path_list[idx] 66 | filename = op.splitext(op.basename(path))[0] 67 | img = torch.from_numpy(np.load(path)).float() # (C, H, W) tensor, as the tensor transforms below expect 68 | 69 | labels = self.to_one_hot(self.label_dict[filename.split('_')[0]]) 70 | labels = torch.from_numpy(labels).float() 71 | 72 | trans = torch.nn.Sequential( 73 | transforms.RandomHorizontalFlip(self.aug_prob), 74 | transforms.RandomVerticalFlip(self.aug_prob), 75 | transforms.RandomRotation(10), 76 |
transforms.RandomCrop(128), 77 | transforms.Normalize(self.img_mean, self.img_std) 78 | ) if self.aug else torch.nn.Sequential( 79 | transforms.CenterCrop(128), 80 | transforms.Normalize(self.img_mean, self.img_std) 81 | ) 82 | 83 | img_tensor = trans(img) 84 | 85 | return img_tensor, labels, filename -------------------------------------------------------------------------------- /classification/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # Contact: mirakuruyoo@gmai.com 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ The main entry point of the whole project. 17 | 18 | Most of the code should not need to change. Simply add 19 | all the input arguments of your model's constructor 20 | and of the dataset file's constructor to the parser. MInterface and 21 | DInterface can be seen as transparent to all your args.
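A minimal sketch of what 'transparent' means here, assuming a hypothetical `filter_kwargs` helper and a toy `ToyNet` class (neither is part of this template; the real logic lives in the `instancialize` methods): every parsed argument is offered to a class, and only the names that appear in its constructor signature are forwarded, with explicit keyword overrides applied last.

```python
import inspect


def filter_kwargs(cls, kwargs, **overrides):
    # Hypothetical helper: keep only the entries of `kwargs` whose keys
    # name a parameter of cls.__init__, then apply explicit overrides.
    params = inspect.signature(cls.__init__).parameters
    picked = {k: kwargs[k] for k in params if k != 'self' and k in kwargs}
    picked.update(overrides)
    return picked


class ToyNet:
    # Toy stand-in for a model file under model/ (hypothetical).
    def __init__(self, in_channel=3, hid=64):
        self.in_channel = in_channel
        self.hid = hid


# Parsed CLI args carry extra keys (lr, batch_size) that ToyNet ignores.
all_args = {'in_channel': 1, 'hid': 32, 'lr': 1e-3, 'batch_size': 16}
net = ToyNet(**filter_kwargs(ToyNet, all_args))
```

Under this scheme, adding a `--hid` option to the parser is enough for the value to reach any model whose constructor has a `hid` parameter.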
22 | """ 23 | import os 24 | import pytorch_lightning as pl 25 | from argparse import ArgumentParser 26 | from pytorch_lightning import Trainer 27 | import pytorch_lightning.callbacks as plc 28 | from pytorch_lightning.loggers import TensorBoardLogger 29 | 30 | from model import MInterface 31 | from data import DInterface 32 | from utils import load_model_path_by_args 33 | 34 | 35 | def load_callbacks(args): 36 | callbacks = [] 37 | callbacks.append(plc.EarlyStopping( 38 | monitor='val_acc', 39 | mode='max', 40 | patience=10, 41 | min_delta=0.001 42 | )) 43 | 44 | callbacks.append(plc.ModelCheckpoint( 45 | monitor='val_acc', 46 | filename='best-{epoch:02d}-{val_acc:.3f}', 47 | save_top_k=1, 48 | mode='max', 49 | save_last=True 50 | )) 51 | 52 | if args.lr_scheduler: 53 | callbacks.append(plc.LearningRateMonitor( 54 | logging_interval='epoch')) 55 | return callbacks 56 | 57 | 58 | def main(args): 59 | pl.seed_everything(args.seed) 60 | load_path = load_model_path_by_args(args) 61 | data_module = DInterface(**vars(args)) 62 | 63 | # The model is always rebuilt from args; if load_path is set, 64 | # trainer.fit below restores its weights and training state. 65 | model = MInterface(**vars(args)) 66 | 67 | # Attach the val_acc early-stopping and checkpoint callbacks. 68 | args.callbacks = load_callbacks(args) 69 | # # If you want to change the logger's saving folder 70 | # logger = TensorBoardLogger(save_dir='kfold_log', name=args.log_dir) 71 | # args.logger = logger 72 | 73 | 74 | trainer = Trainer.from_argparse_args(args) 75 | trainer.fit(model, data_module, ckpt_path=load_path) 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = ArgumentParser() 80 | # Basic Training Control 81 | parser.add_argument('--batch_size', default=32, type=int) 82 | parser.add_argument('--num_workers', default=8, type=int) 83 | parser.add_argument('--seed', default=1234, type=int) 84 | parser.add_argument('--lr', default=1e-3, type=float) 85 | 86 | # LR Scheduler 87 | parser.add_argument('--lr_scheduler', choices=['step', 'cosine'], type=str) 88 | parser.add_argument('--lr_decay_steps', default=20, type=int)
89 | parser.add_argument('--lr_decay_rate', default=0.5, type=float) 90 | parser.add_argument('--lr_decay_min_lr', default=1e-5, type=float) 91 | 92 | # Restart Control 93 | parser.add_argument('--load_best', action='store_true') 94 | parser.add_argument('--load_dir', default=None, type=str) 95 | parser.add_argument('--load_ver', default=None, type=str) 96 | parser.add_argument('--load_v_num', default=None, type=int) 97 | 98 | # Training Info 99 | parser.add_argument('--dataset', default='standard_data', type=str) 100 | parser.add_argument('--data_dir', default='data/ref', type=str) 101 | parser.add_argument('--model_name', default='standard_net', type=str) 102 | parser.add_argument('--loss', default='bce', type=str) 103 | parser.add_argument('--weight_decay', default=1e-5, type=float) 104 | parser.add_argument('--no_augment', action='store_true') 105 | parser.add_argument('--log_dir', default='lightning_logs', type=str) 106 | 107 | # Model Hyperparameters 108 | parser.add_argument('--hid', default=64, type=int) 109 | parser.add_argument('--block_num', default=8, type=int) 110 | parser.add_argument('--in_channel', default=3, type=int) 111 | parser.add_argument('--layer_num', default=5, type=int) 112 | 113 | # Other 114 | parser.add_argument('--aug_prob', default=0.5, type=float) 115 | 116 | # Add PyTorch Lightning's Trainer args to the parser.
117 | parser = Trainer.add_argparse_args(parser) 118 | 119 | ## Deprecated, old version 120 | # parser = Trainer.add_argparse_args( 121 | # parser.add_argument_group(title="pl.Trainer args")) 122 | 123 | # Reset the default values of some Trainer arguments 124 | parser.set_defaults(max_epochs=100) 125 | 126 | args = parser.parse_args() 127 | 128 | # List-valued arguments, named to match StandardData's img_mean/img_std 129 | args.img_mean = [0.485, 0.456, 0.406] 130 | args.img_std = [0.229, 0.224, 0.225] 131 | 132 | main(args) 133 | -------------------------------------------------------------------------------- /classification/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | 19 | 20 | def conv3x3(in_channels, out_channels, kernel_size, bias=True, stride=1): 21 | return nn.Conv2d( 22 | in_channels, out_channels, kernel_size, 23 | padding=(kernel_size//2), bias=bias, stride=stride) 24 | 25 | 26 | def mean_shift_1d(data, mean, std, base=1, add=False): 27 | if add: 28 | data = data * std / base + mean 29 | else: 30 | data = (data - mean) / std * base 31 | return data 32 | 33 | 34 | def mean_shift_2d(data, mean, std, base=1, add=False): 35 | data = data.permute(2, 3, 0, 1) 36 | 37 | if add: 38 | data = data * std / base + mean 39 | else: 40 | data = (data - mean) / std * base 41 | return data.permute(2, 3, 0, 1) 42 | 43 | 44 | class BasicBlock(nn.Sequential): 45 | def __init__( 46 | self, in_channels, out_channels, kernel_size, stride=1, bias=True, 47 | bn=False, act=nn.ReLU(True)): 48 | 49 | m = [nn.Conv2d( 50 | in_channels, out_channels, kernel_size, 51 | padding=(kernel_size//2), stride=stride, bias=bias) 52 | ] 53 | if bn: 54 | m.append(nn.BatchNorm2d(out_channels)) 55 | if act is not None: 56 | m.append(act) 57 | super(BasicBlock, self).__init__(*m) 58 | 59 | 60 | class ResBlock(nn.Module): 61 | def __init__( 62 | self, conv, n_feats, kernel_size, 63 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 64 | 65 | super(ResBlock, self).__init__() 66 | m = [] 67 | for i in range(2): 68 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 69 | if bn: 70 | m.append(nn.BatchNorm2d(n_feats)) 71 | if i == 0: 72 | 
m.append(act) 73 | 74 | self.body = nn.Sequential(*m) 75 | self.res_scale = res_scale 76 | 77 | def forward(self, x): 78 | res = self.body(x).mul(self.res_scale) 79 | res += x 80 | return res 81 | 82 | 83 | class Upsampler(nn.Sequential): 84 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 85 | 86 | m = [] 87 | if scale in (2, 4, 8): # Is scale = 2^n? 88 | for _ in range(int(math.log(scale, 2))): 89 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 90 | m.append(nn.PixelShuffle(2)) 91 | if bn: 92 | m.append(nn.BatchNorm2d(n_feats)) 93 | 94 | if act == 'relu': 95 | m.append(nn.ReLU(True)) 96 | elif act == 'prelu': 97 | m.append(nn.PReLU(n_feats)) 98 | 99 | elif scale == 3: 100 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 101 | m.append(nn.PixelShuffle(3)) 102 | if bn: 103 | m.append(nn.BatchNorm2d(n_feats)) 104 | 105 | if act == 'relu': 106 | m.append(nn.ReLU(True)) 107 | elif act == 'prelu': 108 | m.append(nn.PReLU(n_feats)) 109 | else: 110 | raise NotImplementedError 111 | 112 | super(Upsampler, self).__init__(*m) 113 | 114 | 115 | ## ---------------------- RDB Modules ------------------------ ## 116 | class RDB_Conv(nn.Module): 117 | """ Residual Dense Convolution. 118 | """ 119 | 120 | def __init__(self, inChannels, growRate, kSize=3): 121 | super(RDB_Conv, self).__init__() 122 | Cin = inChannels 123 | G = growRate 124 | self.conv = nn.Sequential(*[ 125 | nn.Conv2d(Cin, G, kSize, padding=(kSize - 1) // 2, stride=1), 126 | nn.ReLU() 127 | ]) 128 | 129 | def forward(self, x): 130 | out = self.conv(x) 131 | return torch.cat((x, out), 1) 132 | 133 | 134 | class RDB(nn.Module): 135 | """ Residual Dense Block. 
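Channel bookkeeping for this block, as a plain-Python sketch (`rdb_channel_plan` is a hypothetical helper, not part of the template): each `RDB_Conv` concatenates its input with `growRate` new feature maps, so the c-th conv sees `growRate0 + c * growRate` channels, and the 1x1 `LFF` conv maps `growRate0 + nConvLayers * growRate` channels back to `growRate0`, which keeps the residual addition with `x` shape-compatible.

```python
def rdb_channel_plan(g0, g, n_layers):
    # Input channels of each RDB_Conv: conv c receives the block input (g0)
    # plus the g maps appended by each of the c previous convs.
    conv_inputs = [g0 + c * g for c in range(n_layers)]
    # The final concat has g0 + n_layers * g channels; LFF projects it to g0.
    lff_in = g0 + n_layers * g
    return conv_inputs, lff_in


conv_inputs, lff_in = rdb_channel_plan(64, 32, 3)
```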
136 | """ 137 | 138 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 139 | super(RDB, self).__init__() 140 | G0 = growRate0 141 | G = growRate 142 | C = nConvLayers 143 | 144 | convs = [] 145 | for c in range(C): 146 | convs.append(RDB_Conv(G0 + c * G, G)) 147 | self.convs = nn.Sequential(*convs) 148 | 149 | # Local Feature Fusion 150 | self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1) 151 | 152 | def forward(self, x): 153 | return self.LFF(self.convs(x)) + x 154 | -------------------------------------------------------------------------------- /classification/model/model_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import inspect 16 | import torch 17 | import importlib 18 | from torch.nn import functional as F 19 | import torch.optim.lr_scheduler as lrs 20 | 21 | import pytorch_lightning as pl 22 | 23 | 24 | class MInterface(pl.LightningModule): 25 | def __init__(self, model_name, loss, lr, **kargs): 26 | super().__init__() 27 | self.save_hyperparameters() 28 | self.load_model() 29 | self.configure_loss() 30 | 31 | def forward(self, img): 32 | return self.model(img) 33 | 34 | def training_step(self, batch, batch_idx): 35 | img, labels, filename = batch 36 | out = self(img) 37 | loss = self.loss_function(out, labels) 38 | self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True) 39 | return loss 40 | 41 | def validation_step(self, batch, batch_idx): 42 | img, labels, filename = batch 43 | out = self(img) 44 | loss = self.loss_function(out, labels) 45 | label_digit = labels.argmax(dim=1) 46 | out_digit = out.argmax(dim=1) 47 | 48 | correct_num = (label_digit == out_digit).sum().item() 49 | 50 | self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True) 51 | self.log('val_acc', correct_num/len(out_digit), 52 | on_step=False, on_epoch=True, prog_bar=True) 53 | 54 | return (correct_num, len(out_digit)) 55 | 56 | def test_step(self, batch, batch_idx): 57 | # Here we just reuse the validation_step for testing 58 | return self.validation_step(batch, batch_idx) 59 | 60 | def on_validation_epoch_end(self): 61 | # Print an empty line so the finished progress bar stays visible 62 | self.print('') 63 | 64 | def configure_optimizers(self): 65 | if hasattr(self.hparams, 'weight_decay'): 66 | weight_decay = self.hparams.weight_decay 67 | else: 68 | weight_decay = 0 69 | optimizer = torch.optim.Adam( 70 | self.parameters(), lr=self.hparams.lr, weight_decay=weight_decay) 71 | 72 | if self.hparams.get('lr_scheduler') is None: 73 | return optimizer 74 | else: 75 | if self.hparams.lr_scheduler == 'step': 76 | scheduler = lrs.StepLR(optimizer, 77 | step_size=self.hparams.lr_decay_steps, 78 |
gamma=self.hparams.lr_decay_rate) 79 | elif self.hparams.lr_scheduler == 'cosine': 80 | scheduler = lrs.CosineAnnealingLR(optimizer, 81 | T_max=self.hparams.lr_decay_steps, 82 | eta_min=self.hparams.lr_decay_min_lr) 83 | else: 84 | raise ValueError('Invalid lr_scheduler type!') 85 | return [optimizer], [scheduler] 86 | 87 | def configure_loss(self): 88 | loss = self.hparams.loss.lower() 89 | if loss == 'mse': 90 | self.loss_function = F.mse_loss 91 | elif loss == 'l1': 92 | self.loss_function = F.l1_loss 93 | elif loss == 'bce': 94 | self.loss_function = F.binary_cross_entropy 95 | else: 96 | raise ValueError("Invalid Loss Type!") 97 | 98 | def load_model(self): 99 | name = self.hparams.model_name 100 | # Change the `snake_case.py` file name to `CamelCase` class name. 101 | # Please always name your model file as `snake_case.py` and the 102 | # corresponding class as `CamelCase`. 103 | camel_name = ''.join([i.capitalize() for i in name.split('_')]) 104 | try: 105 | Model = getattr(importlib.import_module( 106 | '.'+name, package=__package__), camel_name) 107 | except (ImportError, AttributeError) as e: 108 | raise ValueError( 109 | f'Invalid Module File Name or Invalid Class Name {name}.{camel_name}!') from e 110 | self.model = self.instancialize(Model) 111 | 112 | def instancialize(self, Model, **other_args): 113 | """ Instantiate a model using the corresponding parameters 114 | from the self.hparams dictionary. You can also pass any args 115 | to overwrite the corresponding values in self.hparams.
116 | """ 117 | class_args = inspect.getfullargspec(Model.__init__).args[1:] 118 | inkeys = self.hparams.keys() 119 | args1 = {} 120 | for arg in class_args: 121 | if arg in inkeys: 122 | args1[arg] = getattr(self.hparams, arg) 123 | args1.update(other_args) 124 | return Model(**args1) 125 | -------------------------------------------------------------------------------- /classification/model/simple_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from . import common 5 | 6 | class SimpleNet(nn.Module): 7 | def __init__(self, in_channel=3, out_channel=10, hid=128, layer_num=5): 8 | super().__init__() 9 | body = [common.conv3x3(in_channel, hid, 3), 10 | nn.ReLU()] 11 | for _ in range(layer_num-1): 12 | body.append(common.conv3x3(hid, hid, 3)) 13 | body.append(nn.ReLU()) 14 | 15 | self.body = nn.Sequential(*body) 16 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 17 | self.classifier = nn.Sequential( 18 | nn.Dropout(), 19 | nn.Linear(hid * 6 * 6, 2048), 20 | nn.ReLU(inplace=True), 21 | nn.Dropout(), 22 | nn.Linear(2048, 2048), 23 | nn.ReLU(inplace=True), 24 | nn.Linear(2048, out_channel), 25 | nn.Sigmoid() 26 | ) 27 | 28 | def forward(self, x): 29 | x = self.body(x) 30 | x = self.avgpool(x) 31 | x = torch.flatten(x, 1) 32 | x = self.classifier(x) 33 | return x 34 | -------------------------------------------------------------------------------- /classification/model/standard_net.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torch import nn 17 | 18 | class StandardNet(nn.Module): 19 | """ If you want to use a pretrained model, or simply the standard architecture 20 | implemented by official PyTorch, use this template. It lets you easily control whether 21 | to use the pretrained weights, whether to freeze the internal layers, the in/out 22 | channel numbers, and the ResNet version. It is written for ResNet, but you can 23 | also adapt it to other architectures by changing the `torch.hub.load` call. 24 | """ 25 | def __init__(self, in_channel=3, out_channel=10, resnet_name='resnet18', freeze=False, pretrained=False): 26 | super().__init__() 27 | # Load the requested ResNet from torch hub, optionally with pretrained weights 28 | self.resnet = torch.hub.load('pytorch/vision:v0.9.0', resnet_name, pretrained=pretrained) 29 | 30 | if freeze: 31 | for param in self.resnet.parameters(): 32 | param.requires_grad = False 33 | 34 | inter_ftrs = self.resnet.conv1.out_channels 35 | self.resnet.conv1 = nn.Conv2d(in_channel, inter_ftrs, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 36 | 37 | num_ftrs = self.resnet.fc.in_features 38 | self.resnet.fc = nn.Linear(num_ftrs, out_channel) 39 | 40 | self.sigmoid = nn.Sigmoid() 41 | 42 | def forward(self, x): 43 | x = self.resnet(x) 44 | x = self.sigmoid(x) 45 | return x 46 | -------------------------------------------------------------------------------- /classification/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 |
# 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from pathlib import Path 17 | 18 | def load_model_path(root=None, version=None, v_num=None, best=False): 19 | """ When best=True, return the path of the best checkpoint in a directory, 20 | i.e. the 'best-*' file with the largest epoch. Otherwise, return 21 | the last saved checkpoint. You must provide at least one of the 22 | first three args. 23 | Args: 24 | root: The root directory of checkpoints. It can also be a 25 | checkpoint file, in which case it is returned as-is. 26 | version: The name of the version you are going to load. 27 | v_num: The version's number that you are going to load. 28 | best: Whether to return the best model instead of the last one.
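A self-contained sketch of the best=True selection rule, assuming checkpoint names follow the 'best-{epoch:02d}-{val_acc:.3f}' pattern configured in main.py; with ModelCheckpoint's default metric-name insertion a file then looks like 'best-epoch=07-val_acc=0.912.ckpt'. The names `epoch_of` and `pick_best` are hypothetical and used only for illustration.

```python
from pathlib import Path


def epoch_of(ckpt_name):
    # 'best-epoch=07-val_acc=0.912.ckpt' -> stem 'best-epoch=07-val_acc=0.912'
    stem = Path(ckpt_name).stem
    epoch_field = stem.split('-')[1]          # 'epoch=07'
    return int(epoch_field.split('=')[1])     # 7


def pick_best(names):
    # Same rule as best=True: among names whose stem starts with 'best',
    # return the one with the largest epoch.
    candidates = [n for n in names if Path(n).stem.startswith('best')]
    return max(candidates, key=epoch_of)


names = ['last.ckpt',
         'best-epoch=03-val_acc=0.850.ckpt',
         'best-epoch=11-val_acc=0.901.ckpt']
```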
29 | """ 30 | def sort_by_epoch(path): 31 | name = path.stem 32 | epoch = int(name.split('-')[1].split('=')[1]) 33 | return epoch 34 | 35 | def generate_root(): 36 | if root is not None: 37 | return root 38 | elif version is not None: 39 | return str(Path('lightning_logs', version, 'checkpoints')) 40 | else: 41 | return str(Path('lightning_logs', f'version_{v_num}', 'checkpoints')) 42 | 43 | if root is None and version is None and v_num is None: 44 | return None 45 | 46 | root = generate_root() 47 | if Path(root).is_file(): 48 | return root 49 | if best: 50 | files = [i for i in Path(root).iterdir() if i.stem.startswith('best')] 51 | files.sort(key=sort_by_epoch, reverse=True) 52 | res = str(files[0]) 53 | else: 54 | res = str(Path(root) / 'last.ckpt') 55 | return res 56 | 57 | def load_model_path_by_args(args): 58 | return load_model_path(root=args.load_dir, version=args.load_ver, v_num=args.load_v_num, best=args.load_best) -------------------------------------------------------------------------------- /special/kfold/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | dataset/ 4 | lightning_logs/ 5 | MNIST/ 6 | weights/ 7 | /backup -------------------------------------------------------------------------------- /special/kfold/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /special/kfold/README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Lightning-Template: Classification 2 | 3 | ## Introduction 4 | 5 | This sub-directory provides a k-fold training and evaluation solution. 6 | 7 | 1. Add a `kfold_train.sh` script that trains all k models of a k-fold split in one run. 8 | 2. Add the corresponding arguments (`--kfold`, `--fold_num`) to `main.py`. 9 | 3. Add k-fold support to `data/standard_data.py`. 10 | 4. Add a Jupyter notebook that loads all k models, evaluates them, and visualizes the results. -------------------------------------------------------------------------------- /special/kfold/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .data_interface import DInterface -------------------------------------------------------------------------------- /special/kfold/data/data_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import inspect 16 | import importlib 17 | import pickle as pkl 18 | import pytorch_lightning as pl 19 | from torch.utils.data import DataLoader 20 | from torch.utils.data.sampler import WeightedRandomSampler 21 | 22 | 23 | class DInterface(pl.LightningDataModule): 24 | 25 | def __init__(self, num_workers=8, 26 | dataset='', 27 | **kwargs): 28 | super().__init__() 29 | self.num_workers = num_workers 30 | self.dataset = dataset 31 | self.kwargs = kwargs 32 | self.batch_size = kwargs['batch_size'] 33 | self.load_data_module() 34 | 35 | def setup(self, stage=None): 36 | # Assign train/val datasets for use in dataloaders 37 | if stage == 'fit' or stage is None: 38 | self.trainset = self.instancialize(train=True) 39 | self.valset = self.instancialize(train=False) 40 | 41 | # Assign test dataset for use in dataloader(s) 42 | if stage == 'test' or stage is None: 43 | self.testset = self.instancialize(train=False) 44 | 45 | # # If you need to balance your data using Pytorch Sampler, 46 | # # please uncomment the following lines. 47 | 48 | # with open('./data/ref/samples_weight.pkl', 'rb') as f: 49 | # self.sample_weight = pkl.load(f) 50 | 51 | # def train_dataloader(self): 52 | # sampler = WeightedRandomSampler(self.sample_weight, len(self.trainset)*20) 53 | # return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=self.num_workers, sampler = sampler) 54 | 55 | def train_dataloader(self): 56 | return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True) 57 | 58 | def val_dataloader(self): 59 | return DataLoader(self.valset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 60 | 61 | def test_dataloader(self): 62 | return DataLoader(self.testset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 63 | 64 | def load_data_module(self): 65 | name = self.dataset 66 | # Change the `snake_case.py` file name to `CamelCase` class name. 
67 | # Please always name your dataset file in `snake_case.py` and the 68 | # class inside it in the corresponding `CamelCase`. 69 | camel_name = ''.join([i.capitalize() for i in name.split('_')]) 70 | try: 71 | self.data_module = getattr(importlib.import_module( 72 | '.'+name, package=__package__), camel_name) 73 | except (ImportError, AttributeError) as e: 74 | raise ValueError( 75 | f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from e 76 | 77 | def instancialize(self, **other_args): 78 | """ Instantiate the dataset class using the matching parameters 79 | from the self.kwargs dictionary. You can also pass any args 80 | to overwrite the corresponding values in self.kwargs. 81 | """ 82 | class_args = inspect.getfullargspec(self.data_module.__init__).args[1:] 83 | inkeys = self.kwargs.keys() 84 | args1 = {} 85 | for arg in class_args: 86 | if arg in inkeys: 87 | args1[arg] = self.kwargs[arg] 88 | args1.update(other_args) 89 | return self.data_module(**args1) 90 | -------------------------------------------------------------------------------- /special/kfold/data/ref/.gitkeep: -------------------------------------------------------------------------------- 1 | {\rtf1} -------------------------------------------------------------------------------- /special/kfold/data/standard_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import torch 16 | import os.path as op 17 | import numpy as np 18 | import pickle as pkl 19 | import torch.utils.data as data 20 | 21 | from torchvision import transforms 22 | from sklearn.model_selection import train_test_split 23 | from sklearn.model_selection import KFold 24 | 25 | 26 | class StandardData(data.Dataset): 27 | def __init__(self, data_dir=r'data/ref', 28 | class_num=9, 29 | train=True, 30 | no_augment=True, 31 | aug_prob=0.5, 32 | img_mean=(0.485, 0.456, 0.406), 33 | img_std=(0.229, 0.224, 0.225), 34 | kfold=0, 35 | fold_num=0): 36 | # Set all input args as attributes 37 | self.__dict__.update(locals()) 38 | self.aug = train and not no_augment 39 | 40 | self.check_files() 41 | 42 | def check_files(self): 43 | # This is the core code block for loading your own dataset. 44 | # You can scan a folder, load a file-list pickle file, or use 45 | # any other format. The only thing you need to guarantee is 46 | # that `self.path_list` is assigned a valid value.
47 | file_list_path = op.join(self.data_dir, 'file_list.pkl') 48 | with open(file_list_path, 'rb') as f: 49 | file_list = pkl.load(f) 50 | 51 | if self.kfold != 0: 52 | kf = KFold(n_splits=self.kfold, shuffle=True, random_state=2333) 53 | fl_train_idx, fl_val_idx = list(kf.split(file_list))[self.fold_num] 54 | fl_train = np.array(file_list)[fl_train_idx] 55 | fl_val = np.array(file_list)[fl_val_idx] 56 | else: 57 | fl_train, fl_val = train_test_split( 58 | file_list, test_size=0.2, random_state=2333) 59 | 60 | self.path_list = fl_train if self.train else fl_val 61 | 62 | label_file = op.join(self.data_dir, 'label_dict.pkl') 63 | with open(label_file, 'rb') as f: 64 | self.label_dict = pkl.load(f) 65 | 66 | def __len__(self): 67 | return len(self.path_list) 68 | 69 | def to_one_hot(self, idx): 70 | out = np.zeros(self.class_num, dtype=float) 71 | out[idx] = 1 72 | return out 73 | 74 | def __getitem__(self, idx): 75 | path = self.path_list[idx] 76 | filename = op.splitext(op.basename(path))[0] 77 | img = torch.from_numpy(np.load(path)).float() # tensor transforms expect a (C, H, W) tensor 78 | 79 | labels = self.to_one_hot(self.label_dict[filename.split('_')[0]]) 80 | labels = torch.from_numpy(labels).float() 81 | 82 | trans = torch.nn.Sequential( 83 | transforms.RandomHorizontalFlip(self.aug_prob), 84 | transforms.RandomVerticalFlip(self.aug_prob), 85 | transforms.RandomRotation(10), 86 | transforms.RandomCrop(128), 87 | transforms.Normalize(self.img_mean, self.img_std) 88 | ) if self.aug else torch.nn.Sequential( 89 | transforms.CenterCrop(128), 90 | transforms.Normalize(self.img_mean, self.img_std) 91 | ) 92 | 93 | img_tensor = trans(img) 94 | 95 | return img_tensor, labels, filename -------------------------------------------------------------------------------- /special/kfold/kfold_train.sh: -------------------------------------------------------------------------------- 1 | for i in {0..4} 2 | do 3 | echo "Working on fold $i."
4 | python main.py --gpus=1 --data_dir=data/ref --batch_size=128 --model_name=simple_net --layer_num=5 --kfold=5 --fold_num=$i --log_dir=sr 5 | done -------------------------------------------------------------------------------- /special/kfold/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # Contact: mirakuruyoo@gmai.com 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ The main entrance of the whole project. 17 | 18 | Most of the code should not need changing. Just add every input 19 | argument of your model's constructor and of your dataset's 20 | constructor to the argument parser below. MInterface and 21 | DInterface pass all of your args through transparently.
22 | """ 23 | import os 24 | import pytorch_lightning as pl 25 | from argparse import ArgumentParser 26 | from pytorch_lightning import Trainer 27 | import pytorch_lightning.callbacks as plc 28 | from pytorch_lightning.loggers import TensorBoardLogger 29 | 30 | from model import MInterface 31 | from data import DInterface 32 | from utils import load_model_path_by_args 33 | 34 | 35 | def load_callbacks(args): 36 | callbacks = [] 37 | callbacks.append(plc.EarlyStopping( 38 | monitor='val_acc', 39 | mode='max', 40 | patience=10, 41 | min_delta=0.001 42 | )) 43 | 44 | callbacks.append(plc.ModelCheckpoint( 45 | monitor='val_acc', 46 | filename='best-{epoch:02d}-{val_acc:.3f}', 47 | save_top_k=1, 48 | mode='max', 49 | save_last=True 50 | )) 51 | 52 | if args.lr_scheduler: 53 | callbacks.append(plc.LearningRateMonitor( 54 | logging_interval='epoch')) 55 | return callbacks 56 | 57 | 58 | def main(args): 59 | pl.seed_everything(args.seed) 60 | load_path = load_model_path_by_args(args) 61 | data_module = DInterface(**vars(args)) 62 | 63 | model = MInterface(**vars(args)) 64 | # `load_path` is None for a fresh run. Otherwise it points to the 65 | # checkpoint selected by --load_dir/--load_ver/--load_v_num, and 66 | # `trainer.fit` below resumes from it, restoring the weights, 67 | # the optimizer state and the epoch counter. 68 | 69 | # Change `save_dir` below to change the logger's output folder. 70 | logger = TensorBoardLogger(save_dir='kfold_log', name=args.log_dir) 71 | args.callbacks = load_callbacks(args) 72 | args.logger = logger 73 | 74 | trainer = Trainer.from_argparse_args(args) 75 | trainer.fit(model, data_module, ckpt_path=load_path) 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = ArgumentParser() 80 | # Basic Training Control 81 | parser.add_argument('--batch_size', default=32, type=int) 82 | parser.add_argument('--num_workers', default=8, type=int) 83 | parser.add_argument('--seed', default=1234, type=int) 84 | parser.add_argument('--lr', default=1e-3, type=float) 85 | 86 | # LR Scheduler 87 | parser.add_argument('--lr_scheduler', choices=['step', 'cosine'], type=str) 88 | parser.add_argument('--lr_decay_steps', default=20, type=int) 89 |
parser.add_argument('--lr_decay_rate', default=0.5, type=float) 90 | parser.add_argument('--lr_decay_min_lr', default=1e-5, type=float) 91 | 92 | # Restart Control 93 | parser.add_argument('--load_best', action='store_true') 94 | parser.add_argument('--load_dir', default=None, type=str) 95 | parser.add_argument('--load_ver', default=None, type=str) 96 | parser.add_argument('--load_v_num', default=None, type=int) 97 | 98 | # Training Info 99 | parser.add_argument('--dataset', default='standard_data', type=str) 100 | parser.add_argument('--data_dir', default='data/ref', type=str) 101 | parser.add_argument('--model_name', default='standard_net', type=str) 102 | parser.add_argument('--loss', default='bce', type=str) 103 | parser.add_argument('--weight_decay', default=1e-5, type=float) 104 | parser.add_argument('--no_augment', action='store_true') 105 | parser.add_argument('--log_dir', default='lightning_logs', type=str) 106 | 107 | # Model Hyperparameters 108 | parser.add_argument('--hid', default=64, type=int) 109 | parser.add_argument('--block_num', default=8, type=int) 110 | parser.add_argument('--in_channel', default=3, type=int) 111 | parser.add_argument('--layer_num', default=5, type=int) 112 | 113 | # KFold Support 114 | parser.add_argument('--kfold', default=0, type=int) 115 | parser.add_argument('--fold_num', default=0, type=int) 116 | 117 | # Other 118 | parser.add_argument('--aug_prob', default=0.5, type=float) 119 | 120 | # Add all pl.Trainer args to the parser. 121 | parser = Trainer.add_argparse_args(parser) 122 | 123 | # Override the default values of some Trainer arguments 124 | parser.set_defaults(max_epochs=100) 125 | 126 | args = parser.parse_args() 127 | 128 | # List arguments (named to match StandardData's img_mean/img_std) 129 | args.img_mean = [0.485, 0.456, 0.406] 130 | args.img_std = [0.229, 0.224, 0.225] 131 | 132 | main(args) 133 | -------------------------------------------------------------------------------- /special/kfold/model/__init__.py:
-------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .model_interface import MInterface -------------------------------------------------------------------------------- /special/kfold/model/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | 19 | 20 | def conv3x3(in_channels, out_channels, kernel_size, bias=True, stride=1): 21 | return nn.Conv2d( 22 | in_channels, out_channels, kernel_size, 23 | padding=(kernel_size//2), bias=bias, stride=stride) 24 | 25 | 26 | def mean_shift_1d(data, mean, std, base=1, add=False): 27 | if add: 28 | data = data * std / base + mean 29 | else: 30 | data = (data - mean) / std * base 31 | return data 32 | 33 | 34 | def mean_shift_2d(data, mean, std, base=1, add=False): 35 | data = data.permute(2, 3, 0, 1) 36 | 37 | if add: 38 | data = data * std / base + mean 39 | else: 40 | data = (data - mean) / std * base 41 | return data.permute(2, 3, 0, 1) 42 | 43 | 44 | class BasicBlock(nn.Sequential): 45 | def __init__( 46 | self, in_channels, out_channels, kernel_size, stride=1, bias=True, 47 | bn=False, act=nn.ReLU(True)): 48 | 49 | m = [nn.Conv2d( 50 | in_channels, out_channels, kernel_size, 51 | padding=(kernel_size//2), stride=stride, bias=bias) 52 | ] 53 | if bn: 54 | m.append(nn.BatchNorm2d(out_channels)) 55 | if act is not None: 56 | m.append(act) 57 | super(BasicBlock, self).__init__(*m) 58 | 59 | 60 | class ResBlock(nn.Module): 61 | def __init__( 62 | self, conv, n_feats, kernel_size, 63 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 64 | 65 | super(ResBlock, self).__init__() 66 | m = [] 67 | for i in range(2): 68 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 69 | if bn: 70 | m.append(nn.BatchNorm2d(n_feats)) 71 | if i == 0: 72 | m.append(act) 73 | 74 | self.body = nn.Sequential(*m) 75 | self.res_scale = res_scale 76 | 77 | def forward(self, x): 78 | res = self.body(x).mul(self.res_scale) 79 | res += x 80 | return res 81 | 82 | 83 | class Upsampler(nn.Sequential): 84 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 85 | 86 | m = [] 87 | if scale in (2, 4, 8): # Is scale = 2^n? 
88 | for _ in range(int(math.log(scale, 2))): 89 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 90 | m.append(nn.PixelShuffle(2)) 91 | if bn: 92 | m.append(nn.BatchNorm2d(n_feats)) 93 | 94 | if act == 'relu': 95 | m.append(nn.ReLU(True)) 96 | elif act == 'prelu': 97 | m.append(nn.PReLU(n_feats)) 98 | 99 | elif scale == 3: 100 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 101 | m.append(nn.PixelShuffle(3)) 102 | if bn: 103 | m.append(nn.BatchNorm2d(n_feats)) 104 | 105 | if act == 'relu': 106 | m.append(nn.ReLU(True)) 107 | elif act == 'prelu': 108 | m.append(nn.PReLU(n_feats)) 109 | else: 110 | raise NotImplementedError 111 | 112 | super(Upsampler, self).__init__(*m) 113 | 114 | 115 | ## ---------------------- RDB Modules ------------------------ ## 116 | class RDB_Conv(nn.Module): 117 | """ Residual Dense Convolution. 118 | """ 119 | 120 | def __init__(self, inChannels, growRate, kSize=3): 121 | super(RDB_Conv, self).__init__() 122 | Cin = inChannels 123 | G = growRate 124 | self.conv = nn.Sequential(*[ 125 | nn.Conv2d(Cin, G, kSize, padding=(kSize - 1) // 2, stride=1), 126 | nn.ReLU() 127 | ]) 128 | 129 | def forward(self, x): 130 | out = self.conv(x) 131 | return torch.cat((x, out), 1) 132 | 133 | 134 | class RDB(nn.Module): 135 | """ Residual Dense Block. 
136 | """ 137 | 138 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 139 | super(RDB, self).__init__() 140 | G0 = growRate0 141 | G = growRate 142 | C = nConvLayers 143 | 144 | convs = [] 145 | for c in range(C): 146 | convs.append(RDB_Conv(G0 + c * G, G)) 147 | self.convs = nn.Sequential(*convs) 148 | 149 | # Local Feature Fusion 150 | self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1) 151 | 152 | def forward(self, x): 153 | return self.LFF(self.convs(x)) + x 154 | -------------------------------------------------------------------------------- /special/kfold/model/model_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import inspect 16 | import torch 17 | import importlib 18 | from torch.nn import functional as F 19 | import torch.optim.lr_scheduler as lrs 20 | 21 | import pytorch_lightning as pl 22 | 23 | 24 | class MInterface(pl.LightningModule): 25 | def __init__(self, model_name, loss, lr, **kargs): 26 | super().__init__() 27 | self.save_hyperparameters() 28 | self.load_model() 29 | self.configure_loss() 30 | 31 | def forward(self, img): 32 | return self.model(img) 33 | 34 | def training_step(self, batch, batch_idx): 35 | img, labels, filename = batch 36 | out = self(img) 37 | loss = self.loss_function(out, labels) 38 | self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True) 39 | return loss 40 | 41 | def validation_step(self, batch, batch_idx): 42 | img, labels, filename = batch 43 | out = self(img) 44 | loss = self.loss_function(out, labels) 45 | label_digit = labels.argmax(dim=1) 46 | out_digit = out.argmax(dim=1) 47 | 48 | correct_num = (label_digit == out_digit).sum().item() 49 | 50 | self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True) 51 | self.log('val_acc', correct_num/len(out_digit), 52 | on_step=False, on_epoch=True, prog_bar=True) 53 | 54 | return (correct_num, len(out_digit)) 55 | 56 | def test_step(self, batch, batch_idx): 57 | # Here we just reuse the validation_step for testing 58 | return self.validation_step(batch, batch_idx) 59 | 60 | def on_validation_epoch_end(self): 61 | # Print an empty line so the finished epoch's progress bar stays visible 62 | self.print('') 63 | 64 | def configure_optimizers(self): 65 | if hasattr(self.hparams, 'weight_decay'): 66 | weight_decay = self.hparams.weight_decay 67 | else: 68 | weight_decay = 0 69 | optimizer = torch.optim.Adam( 70 | self.parameters(), lr=self.hparams.lr, weight_decay=weight_decay) 71 | 72 | if self.hparams.lr_scheduler is None: 73 | return optimizer 74 | else: 75 | if self.hparams.lr_scheduler == 'step': 76 | scheduler = lrs.StepLR(optimizer, 77 | step_size=self.hparams.lr_decay_steps, 78 |
gamma=self.hparams.lr_decay_rate) 79 | elif self.hparams.lr_scheduler == 'cosine': 80 | scheduler = lrs.CosineAnnealingLR(optimizer, 81 | T_max=self.hparams.lr_decay_steps, 82 | eta_min=self.hparams.lr_decay_min_lr) 83 | else: 84 | raise ValueError('Invalid lr_scheduler type!') 85 | return [optimizer], [scheduler] 86 | 87 | def configure_loss(self): 88 | loss = self.hparams.loss.lower() 89 | if loss == 'mse': 90 | self.loss_function = F.mse_loss 91 | elif loss == 'l1': 92 | self.loss_function = F.l1_loss 93 | elif loss == 'bce': 94 | self.loss_function = F.binary_cross_entropy 95 | else: 96 | raise ValueError("Invalid Loss Type!") 97 | 98 | def load_model(self): 99 | name = self.hparams.model_name 100 | # Change the `snake_case.py` file name to `CamelCase` class name. 101 | # Please always name your model file as `snake_case.py` and the 102 | # corresponding class as `CamelCase`. 103 | camel_name = ''.join([i.capitalize() for i in name.split('_')]) 104 | try: 105 | Model = getattr(importlib.import_module( 106 | '.'+name, package=__package__), camel_name) 107 | except (ImportError, AttributeError): 108 | raise ValueError( 109 | f'Invalid Module File Name or Invalid Class Name {name}.{camel_name}!') 110 | self.model = self.instancialize(Model) 111 | 112 | def instancialize(self, Model, **other_args): 113 | """ Instantiate a model using the matching parameters 114 | from the self.hparams dictionary. You can also pass any args 115 | to overwrite the corresponding values in self.hparams.
116 | """ 117 | class_args = inspect.getfullargspec(Model.__init__).args[1:] 118 | inkeys = self.hparams.keys() 119 | args1 = {} 120 | for arg in class_args: 121 | if arg in inkeys: 122 | args1[arg] = getattr(self.hparams, arg) 123 | args1.update(other_args) 124 | return Model(**args1) 125 | -------------------------------------------------------------------------------- /special/kfold/model/simple_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from . import common 5 | 6 | class SimpleNet(nn.Module): 7 | def __init__(self, in_channel=3, out_channel=10, hid=128, layer_num=5): 8 | super().__init__() 9 | body = [common.conv3x3(in_channel, hid, 3), 10 | nn.ReLU()] 11 | for _ in range(layer_num-1): 12 | body.append(common.conv3x3(hid, hid, 3)) 13 | body.append(nn.ReLU()) 14 | 15 | self.body = nn.Sequential(*body) 16 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 17 | self.classifier = nn.Sequential( 18 | nn.Dropout(), 19 | nn.Linear(hid * 6 * 6, 2048), 20 | nn.ReLU(inplace=True), 21 | nn.Dropout(), 22 | nn.Linear(2048, 2048), 23 | nn.ReLU(inplace=True), 24 | nn.Linear(2048, out_channel), 25 | nn.Sigmoid() 26 | ) 27 | 28 | def forward(self, x): 29 | x = self.body(x) 30 | x = self.avgpool(x) 31 | x = torch.flatten(x, 1) 32 | x = self.classifier(x) 33 | return x 34 | -------------------------------------------------------------------------------- /special/kfold/model/standard_net.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torch import nn 17 | 18 | class StandardNet(nn.Module): 19 | """ If you want to use a pretrained model, or simply the standard structure implemented 20 | officially by PyTorch, please use this template. It lets you easily control whether 21 | to use the pretrained weights, whether to freeze the internal layers, the in/out 22 | channel numbers, and the ResNet version. This is made for ResNet, but you can 23 | also adapt it to other structures by changing the `torch.hub.load` content. 24 | """ 25 | def __init__(self, in_channel=3, out_channel=10, resnet_name='resnet18', freeze=False, pretrained=False): 26 | super().__init__() 27 | self.resnet = torch.hub.load('pytorch/vision:v0.9.0', resnet_name, pretrained=pretrained) 28 | 29 | if freeze: 30 | for param in self.resnet.parameters(): 31 | param.requires_grad = False 32 | 33 | inter_ftrs = self.resnet.conv1.out_channels 34 | self.resnet.conv1 = nn.Conv2d(in_channel, inter_ftrs, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 35 | 36 | num_ftrs = self.resnet.fc.in_features 37 | self.resnet.fc = nn.Linear(num_ftrs, out_channel) 38 | 39 | self.sigmoid = nn.Sigmoid() 40 | 41 | def forward(self, x): 42 | x = self.resnet(x) 43 | x = self.sigmoid(x) 44 | return x 45 | -------------------------------------------------------------------------------- /special/kfold/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from pathlib2 import Path 17 | 18 | def load_model_path(root=None, version=None, v_num=None, best=False): 19 | """ When best=True, return the path of the `best-*` checkpoint 20 | with the largest epoch in the directory. Otherwise, return 21 | the last saved model (`last.ckpt`). You must provide at least 22 | one of the first three args. 23 | Args: 24 | root: The root directory of checkpoints. It can also be a 25 | model ckpt file, in which case it is returned as-is. 26 | version: The name of the version you are going to load. 27 | v_num: The version's number that you are going to load. 28 | best: Whether to return the best model.
29 | """ 30 | def sort_by_epoch(path): 31 | name = path.stem 32 | epoch=int(name.split('-')[1].split('=')[1]) 33 | return epoch 34 | 35 | def generate_root(): 36 | if root is not None: 37 | return root 38 | elif version is not None: 39 | return str(Path('lightning_logs', version, 'checkpoints')) 40 | else: 41 | return str(Path('lightning_logs', f'version_{v_num}', 'checkpoints')) 42 | 43 | if root is None and version is None and v_num is None: 44 | return None 45 | 46 | root = generate_root() 47 | if Path(root).is_file(): 48 | return root 49 | if best: 50 | files=[i for i in list(Path(root).iterdir()) if i.stem.startswith('best')] 51 | files.sort(key=sort_by_epoch, reverse=True) 52 | res = str(files[0]) 53 | else: 54 | res = str(Path(root) / 'last.ckpt') 55 | return res 56 | 57 | def load_model_path_by_args(args): 58 | return load_model_path(root=args.load_dir, version=args.load_ver, v_num=args.load_v_num) -------------------------------------------------------------------------------- /super-resolution/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | dataset/ 4 | lightning_logs/ 5 | MNIST/ 6 | weights/ 7 | /backup -------------------------------------------------------------------------------- /super-resolution/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /super-resolution/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .data_interface import DInterface -------------------------------------------------------------------------------- /super-resolution/data/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import random 16 | import cv2 17 | import torch 18 | import numpy as np 19 | import skimage.color as sc 20 | import pickle as pkl 21 | 22 | from operator import itemgetter 23 | 24 | 25 | class dotdict(dict): 26 | """dot.notation access to dictionary attributes""" 27 | __getattr__ = dict.get 28 | __setattr__ = dict.__setitem__ 29 | __delattr__ = dict.__delitem__ 30 | 31 | 32 | def even_sample(items, number): 33 | """ Evenly sample `number` items from `items`. 34 | """ 35 | indexs = (np.linspace(0, len(items)-1, number).astype(int)).tolist() 36 | return itemgetter(*indexs)(items) 37 | 38 | 39 | def get_patch(*args, patch_size=96, scale=1, its=None): 40 | """ Images can differ in size and aspect ratio. To give every sample 41 | the same input shape, crop a patch_size*patch_size patch (96*96 by default) 42 | from the LR image and the matching (patch_size*r)*(patch_size*r) area from the HR image. 43 | Args: 44 | args: lr, hr 45 | patch_size: The height and width of the crop area on the LR image. 46 | scale: r, the upscale ratio. 47 | Returns: 48 | 0: cropped lr image. 49 | 1: cropped hr image.
50 | """ 51 | ih, iw = args[0].shape[:2] 52 | 53 | tp = int(scale * patch_size) 54 | ip = patch_size 55 | 56 | ix = random.randrange(0, iw - ip + 1) 57 | iy = random.randrange(0, ih - ip + 1) 58 | 59 | tx, ty = int(scale * ix), int(scale * iy) 60 | 61 | if its is None: 62 | its = np.zeros(len(args)-1, int) 63 | itp = (tp, ip) 64 | ret = [ 65 | args[0][iy:iy + ip, ix:ix + ip, :], 66 | *[a[(ty, iy)[b]:(ty, iy)[b] + itp[b], (tx, ix)[b]:(tx, ix)[b] + itp[b], :] for a, b in zip(args[1:], its)] 67 | ] 68 | return ret 69 | 70 | 71 | def set_channel(*args, n_channels=3): 72 | """ Check the channel number of every input. If an input does not 73 | have n_channels channels, convert it (1 -> 3 or 3 -> 1) to n_channels. 74 | Args: 75 | n_channels: the target channel number. 76 | """ 77 | def _set_channel(img): 78 | if img.ndim == 2: 79 | img = np.expand_dims(img, axis=2) 80 | 81 | c = img.shape[2] 82 | if n_channels == 1 and c == 3: 83 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 84 | elif n_channels == 3 and c == 1: 85 | img = np.concatenate([img] * n_channels, 2) 86 | 87 | return img 88 | 89 | return [_set_channel(a) for a in args] 90 | 91 | 92 | def bitdepth_convert(image, src=16, dst=8): 93 | """ Convert an image from one bit depth to another. 94 | Args: 95 | image: Input image, an ndarray. 96 | src: source bit depth. 97 | dst: target bit depth. 98 | """ 99 | coe = src - dst 100 | image = (image + 1) / (2 ** coe) - 1 101 | return image 102 | 103 | 104 | def np2Tensor(*args, color_range=255): 105 | """ Transform HWC numpy arrays into CHW float tensors. Pixel values, 106 | assumed to lie in [0, 255], are scaled by color_range / 255. 107 | Args: 108 | color_range: Max value of a single pixel in the output tensor.
109 | """ 110 | def _np2Tensor(img): 111 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 112 | tensor = torch.from_numpy(np_transpose).float() 113 | tensor.mul_(color_range / 255) 114 | return tensor 115 | 116 | return [_np2Tensor(a) for a in args] 117 | 118 | 119 | def augment(*args, hflip=True, rot=True, prob=0.5): 120 | """ Apply the same random data augmentation to a series of input images. 121 | Operations included: random horizontal flip, random vertical 122 | flip, and random transpose (combined, they also yield 90-degree rotations). 123 | Args: 124 | args: A list/tuple of images. 125 | hflip: Whether to use random horizontal flips. 126 | rot: Whether to use random vertical flips and transposes. 127 | prob: The probability of applying each operation. 128 | """ 129 | hflip = hflip and random.random() < prob 130 | vflip = rot and random.random() < prob 131 | rot90 = rot and random.random() < prob 132 | 133 | def _augment(img): 134 | if hflip: 135 | img = img[:, ::-1, :] 136 | if vflip: 137 | img = img[::-1, :, :] 138 | if rot90: 139 | img = img.transpose(1, 0, 2) 140 | 141 | return img 142 | 143 | return [_augment(a) for a in args] 144 | 145 | 146 | def black_square(lr, hr, prob=0.5): 147 | """ Randomly pick a square edge length between `min(h, w)//8` and 148 | `min(h, w)//4`, and black out such a square at a random position on both images. 149 | Args: 150 | lr: LR image. 151 | hr: HR image. 152 | prob: Probability of adding this black square. 153 | The HR/LR scale is inferred as hr.shape[0] // lr.shape[0].
154 | """ 155 | if random.random() < prob: 156 | h, w = lr.shape[:2] 157 | scale = hr.shape[0] // h 158 | max_edge = min(h, w)//4 159 | min_edge = min(h, w)//8 160 | edge = random.choice(range(min_edge, max_edge)) 161 | 162 | start_y = random.choice(range(h-edge)) 163 | start_x = random.choice(range(w-edge)) 164 | 165 | lr[start_y:(start_y+edge), start_x:(start_x+edge), :] = 0 166 | hr[int(start_y*scale):int((start_y+edge)*scale), 167 | int(start_x * scale):int((start_x+edge)*scale), :] = 0 168 | return lr, hr 169 | 170 | 171 | def down_up(*args, scales, up_prob=1, prob=0.5): 172 | """ Downscale and then upscale the input images. 173 | Args: 174 | args: A list/tuple of images. 175 | scales: A down-up scale, or a list/tuple of scales to pick from, e.g. (1.5, 2, 3, 4). 176 | up_prob: Probability of applying upsample after downsample. 177 | prob: Probability of applying this augment. 178 | """ 179 | def _down_up(img): 180 | if decision: 181 | img = cv2.resize(img, None, fx=1/scale, fy=1/scale, 182 | interpolation=cv2.INTER_CUBIC) # cv2.INTER_AREA is an alternative for downscaling 183 | if up: 184 | img = cv2.resize(img, None, fx=scale, fy=scale, 185 | interpolation=cv2.INTER_CUBIC) 186 | return img 187 | 188 | decision = random.random() < prob 189 | up = random.random() < up_prob 190 | scale = random.choice(scales) if type(scales) in (list, tuple) else scales 191 | return [_down_up(a) for a in args] 192 | -------------------------------------------------------------------------------- /super-resolution/data/data_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | import importlib 17 | import pytorch_lightning as pl 18 | from torch.utils.data import DataLoader, random_split 19 | import torchvision.transforms as transforms 20 | 21 | 22 | class DInterface(pl.LightningDataModule): 23 | 24 | def __init__(self, num_workers=8, 25 | dataset='', 26 | **kwargs): 27 | super().__init__() 28 | self.num_workers = num_workers 29 | self.dataset = dataset 30 | self.kwargs = kwargs 31 | self.batch_size = kwargs['batch_size'] 32 | self.load_data_module() 33 | 34 | def setup(self, stage=None): 35 | # Assign train/val datasets for use in dataloaders 36 | if stage == 'fit' or stage is None: 37 | self.trainset = self.instancialize(train=True) 38 | self.valset = self.instancialize(train=False) 39 | 40 | # Assign test dataset for use in dataloader(s) 41 | if stage == 'test' or stage is None: 42 | self.testset = self.instancialize(train=False) 43 | 44 | def train_dataloader(self): 45 | return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True) 46 | 47 | def val_dataloader(self): 48 | return DataLoader(self.valset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 49 | 50 | def test_dataloader(self): 51 | return DataLoader(self.testset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 52 | 53 | def load_data_module(self): 54 | name = self.dataset 55 | # Change the `snake_case.py` file name to `CamelCase` class name. 
56 | # Please always name your dataset file as `snake_case.py` and the 57 | # corresponding class as `CamelCase`. 58 | camel_name = ''.join([i.capitalize() for i in name.split('_')]) 59 | try: 60 | self.data_module = getattr(importlib.import_module( 61 | '.'+name, package=__package__), camel_name) 62 | except (ImportError, AttributeError): 63 | raise ValueError( 64 | f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') 65 | 66 | def instancialize(self, **other_args): 67 | """ Instantiate the dataset class using the matching parameters 68 | from the self.kwargs dictionary. You can also pass any args 69 | to overwrite the corresponding values in self.kwargs. 70 | """ 71 | class_args = inspect.getfullargspec(self.data_module.__init__).args[1:] 72 | inkeys = self.kwargs.keys() 73 | args1 = {} 74 | for arg in class_args: 75 | if arg in inkeys: 76 | args1[arg] = self.kwargs[arg] 77 | args1.update(other_args) 78 | return self.data_module(**args1) 79 | -------------------------------------------------------------------------------- /super-resolution/data/recursive_up.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | from pathlib2 import Path 17 | 18 | import torch.utils.data as data 19 | from .
import common 20 | 21 | 22 | class RecursiveUp(data.Dataset): 23 | def __init__(self, data_dir='dataset', 24 | train=True): 25 | super().__init__() 26 | self.train = train 27 | self.root = Path(data_dir) 28 | self.check_data() 29 | 30 | def check_data(self): 31 | self.sentinel_root = self.root / 'sentinel' 32 | self.drone_root = self.root / 'drone' 33 | self.filelist = sorted(f.name for f in self.sentinel_root.iterdir()) 34 | self.filelist = self.filelist[:-1] if self.train else self.filelist[-1:] 35 | self.scale_strs = [str(2**i)+'x' for i in range(9)] 36 | 37 | def __getitem__(self, idx): 38 | sen = np.load(self.sentinel_root / self.filelist[idx]) 39 | drone = [np.load(self.drone_root / scale_str / self.filelist[idx]) 40 | for scale_str in self.scale_strs] 41 | sen = sen.transpose(1, 2, 0) 42 | drone = [d.transpose(1, 2, 0) for d in drone] 43 | sen = common.np2Tensor(sen)[0] 44 | drone = common.np2Tensor(*drone) 45 | return sen, drone 46 | 47 | def __len__(self): 48 | return len(self.filelist) 49 | -------------------------------------------------------------------------------- /super-resolution/data/satup_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
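`RecursiveUp` above reads two things for each file name: one Sentinel array from `<data_dir>/sentinel/`, and one drone array per scale from `<data_dir>/drone/{1x,2x,...,256x}/`. It holds out the last listed file for validation. The sketch below reproduces that on-disk layout and the train/val split using only NumPy and the standard library. Several details are assumptions for illustration, not part of the template: the helper names `make_fake_dataset` and `split_files`, the toy `(3, 4, 4)` zero arrays, and the explicit sort, which is used because `Path.iterdir()` yields entries in arbitrary order.

```python
import tempfile
from pathlib import Path

import numpy as np

# Scale sub-folders RecursiveUp reads: 1x, 2x, 4x, ..., 256x (nine levels).
SCALE_STRS = [f"{2 ** i}x" for i in range(9)]


def make_fake_dataset(root, names, size=4):
    """Hypothetical helper: write tiny zero arrays in the layout RecursiveUp expects.

    Arrays are stored channel-first (C, H, W); RecursiveUp transposes them to
    (H, W, C) before calling np2Tensor. The shapes are placeholders, not real data.
    """
    root = Path(root)
    (root / "sentinel").mkdir(parents=True, exist_ok=True)
    for name in names:
        np.save(root / "sentinel" / name, np.zeros((3, size, size), np.float32))
        for scale_str in SCALE_STRS:
            scale_dir = root / "drone" / scale_str
            scale_dir.mkdir(parents=True, exist_ok=True)
            np.save(scale_dir / name, np.zeros((3, size, size), np.float32))


def split_files(root, train):
    """Hypothetical helper mirroring RecursiveUp.check_data: the last file is held out."""
    # Sorting makes the held-out file deterministic; Path.iterdir() order is arbitrary.
    files = sorted(f.name for f in (Path(root) / "sentinel").iterdir())
    return files[:-1] if train else files[-1:]


with tempfile.TemporaryDirectory() as tmp:
    make_fake_dataset(tmp, ["a.npy", "b.npy", "c.npy"])
    train_files = split_files(tmp, train=True)
    val_files = split_files(tmp, train=False)
    n_scale_dirs = len(list((Path(tmp) / "drone").iterdir()))

print(train_files, val_files)  # ['a.npy', 'b.npy'] ['c.npy']
print(n_scale_dirs)  # 9
```

With only one validation file, the split is very small. For real experiments, a larger or explicitly listed hold-out set may be preferable.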
14 | 15 | import random 16 | import os.path as op 17 | import numpy as np 18 | import pickle as pkl 19 | from pathlib2 import Path 20 | 21 | import torch.utils.data as data 22 | from . import common 23 | 24 | 25 | class SatupData(data.Dataset): 26 | def __init__(self, data_dir='dataset', 27 | color_range=255, 28 | train=True, 29 | no_augment=True, 30 | aug_prob=0.5, 31 | batch_size=1): 32 | # Set all input args as attributes 33 | self.__dict__.update(locals()) 34 | self.aug = train and not no_augment 35 | 36 | self.check_files() 37 | self.count = 0 38 | 39 | def check_files(self): 40 | middir = 'train' if self.train else 'val' 41 | info_file = Path(self.data_dir, f'{middir}_lr.pkl') 42 | with open(info_file, 'rb') as f: 43 | self.lr_list = pkl.load(f) 44 | 45 | def __len__(self): 46 | return len(self.lr_list) 47 | 48 | def __getitem__(self, idx): 49 | lrfile = self.lr_list[idx] 50 | hrfile = self.lr_list[idx].replace('LRBigEarth', 'SRBigEarth') 51 | filename = op.splitext(op.basename(hrfile))[0] 52 | lr = np.load(lrfile).transpose(1, 2, 0) 53 | hr = np.load(hrfile).transpose(1, 2, 0) 54 | lr = common.bitdepth_convert(lr, src=16, dst=8) 55 | hr = common.bitdepth_convert(hr, src=16, dst=8) 56 | 57 | if self.aug: 58 | lr, hr = common.augment(lr, hr, prob=self.aug_prob) 59 | lr, hr = common.black_square(lr, hr, prob=self.aug_prob) 60 | if self.count % self.batch_size == 0: 61 | self.aug_scale = random.choice([1.5, 2, 3, 4]) 62 | self.aug_down_up = 1 if random.random() < self.aug_prob else 0 63 | lr, hr = common.down_up( 64 | lr, hr, scales=self.aug_scale, prob=self.aug_down_up, up_prob=0) 65 | 66 | lr_tensor, hr_tensor = common.np2Tensor( 67 | lr, hr, color_range=self.color_range) 68 | 69 | self.count += 1 70 | 71 | return lr_tensor, hr_tensor, filename 72 | -------------------------------------------------------------------------------- /super-resolution/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 
Zhongyang Zhang 2 | # Contact: mirakuruyoo@gmai.com 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ The main entry point of the whole project. 17 | 18 | Most of this code should not need to change. Just add every 19 | input argument of your model's constructor and your dataset's 20 | constructor to the argument parser below; MInterface and 21 | DInterface pass all args through to them transparently. 22 | """ 23 | import pytorch_lightning as pl 24 | from argparse import ArgumentParser 25 | from pytorch_lightning import Trainer 26 | import pytorch_lightning.callbacks as plc 27 | 28 | from model import MInterface 29 | from data import DInterface 30 | from utils import load_model_path_by_args 31 | 32 | 33 | def load_callbacks(): 34 | callbacks = [] 35 | callbacks.append(plc.EarlyStopping( 36 | monitor='mpsnr', 37 | mode='max', 38 | patience=10, 39 | min_delta=0.01 40 | )) 41 | 42 | callbacks.append(plc.ModelCheckpoint( 43 | monitor='mpsnr', 44 | filename='best-{epoch:02d}-{mpsnr:.2f}-{mssim:.3f}', 45 | save_top_k=1, 46 | mode='max', 47 | save_last=True 48 | )) 49 | 50 | if args.lr_scheduler: 51 | callbacks.append(plc.LearningRateMonitor( 52 | logging_interval='epoch')) 53 | return callbacks 54 | 55 | 56 | def main(args): 57 | pl.seed_everything(args.seed) 58 | load_path = load_model_path_by_args(args) 59 | data_module = DInterface(**vars(args)) 60 | 61 | if load_path is None: 62 | model = MInterface(**vars(args)) 63 | else: 64 |
model = MInterface(**vars(args)) 65 | args.ckpt_path = load_path 66 | 67 | args.callbacks = load_callbacks() 68 | trainer = Trainer.from_argparse_args(args) 69 | trainer.fit(model, data_module, ckpt_path=load_path)  # resumes when a checkpoint was found 70 | 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = ArgumentParser() 75 | # Basic Training Control 76 | parser.add_argument('--batch_size', default=32, type=int) 77 | parser.add_argument('--num_workers', default=8, type=int) 78 | parser.add_argument('--seed', default=1234, type=int) 79 | parser.add_argument('--lr', default=1e-3, type=float) 80 | 81 | # LR Scheduler 82 | parser.add_argument('--lr_scheduler', choices=['step', 'cosine'], type=str) 83 | parser.add_argument('--lr_decay_steps', default=20, type=int) 84 | parser.add_argument('--lr_decay_rate', default=0.5, type=float) 85 | parser.add_argument('--lr_decay_min_lr', default=1e-5, type=float) 86 | 87 | # Restart Control 88 | parser.add_argument('--load_best', action='store_true') 89 | parser.add_argument('--load_dir', default=None, type=str) 90 | parser.add_argument('--load_ver', default=None, type=str) 91 | parser.add_argument('--load_v_num', default=None, type=int) 92 | 93 | # Training Info 94 | parser.add_argument('--dataset', default='satup_data', type=str) 95 | parser.add_argument('--data_dir', default='dataset', type=str) 96 | parser.add_argument('--model_name', default='rdn_fuse', type=str) 97 | parser.add_argument('--loss', default='l1', type=str) 98 | parser.add_argument('--weight_decay', default=1e-5, type=float) 99 | parser.add_argument('--no_augment', action='store_true') 100 | 101 | # Model Hyperparameters 102 | parser.add_argument('--scale', default=2, type=int) 103 | parser.add_argument('--in_bands_num', default=12, type=int) 104 | parser.add_argument('--hid', default=64, type=int) 105 | parser.add_argument('--block_num', default=8, type=int) 106 | parser.add_argument('--rdn_size', default=3, type=int) 107 | 
parser.add_argument('--rdb_growrate', default=64, type=int) 108 |
parser.add_argument('--rdb_conv_num', default=8, type=int) 109 | 110 | # Other 111 | parser.add_argument('--color_range', default=255, type=int) 112 | parser.add_argument('--aug_prob', default=0.5, type=float) 113 | 114 | # Add PyTorch Lightning's Trainer args to the parser. 115 | parser = Trainer.add_argparse_args(parser) 116 | 117 | ## Deprecated, old version 118 | # parser = Trainer.add_argparse_args( 119 | # parser.add_argument_group(title="pl.Trainer args")) 120 | 121 | # Override the default values of some Trainer arguments 122 | parser.set_defaults(max_epochs=250) 123 | 124 | args = parser.parse_args() 125 | 126 | # List arguments (per-band mean and std), not exposed on the command line 127 | args.mean_sen = [1.315, 1.211, 1.948, 1.892, 3.311, 128 | 6.535, 7.634, 8.197, 8.395, 8.341, 5.89, 3.616] 129 | args.std_sen = [5.958, 2.273, 2.299, 2.668, 2.895, 130 | 4.276, 4.978, 5.237, 5.304, 5.103, 4.298, 3.3] 131 | 132 | main(args) 133 | -------------------------------------------------------------------------------- /super-resolution/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | from .model_interface import MInterface -------------------------------------------------------------------------------- /super-resolution/model/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | 19 | 20 | def conv3x3(in_channels, out_channels, kernel_size, bias=True, stride=1): 21 | return nn.Conv2d( 22 | in_channels, out_channels, kernel_size, 23 | padding=(kernel_size//2), bias=bias, stride=stride) 24 | 25 | 26 | def mean_shift_1d(data, mean, std, base=100, add=False): 27 | if add: 28 | data = data * std / base + mean 29 | else: 30 | data = (data - mean) / std * base 31 | return data 32 | 33 | 34 | def mean_shift_2d(data, mean, std, base=100, add=False): 35 | data = data.permute(2, 3, 0, 1) 36 | 37 | if add: 38 | data = data * std / base + mean 39 | else: 40 | data = (data - mean) / std * base 41 | return data.permute(2, 3, 0, 1) 42 | 43 | 44 | class BasicBlock(nn.Sequential): 45 | def __init__( 46 | self, in_channels, out_channels, kernel_size, stride=1, bias=True, 47 | bn=False, act=nn.ReLU(True)): 48 | 49 | m = [nn.Conv2d( 50 | in_channels, out_channels, kernel_size, 51 | padding=(kernel_size//2), stride=stride, bias=bias) 52 | ] 53 | if bn: 54 | m.append(nn.BatchNorm2d(out_channels)) 55 | if act is not None: 56 | 
m.append(act) 57 | super(BasicBlock, self).__init__(*m) 58 | 59 | 60 | class ResBlock(nn.Module): 61 | def __init__( 62 | self, conv, n_feats, kernel_size, 63 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 64 | 65 | super(ResBlock, self).__init__() 66 | m = [] 67 | for i in range(2): 68 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 69 | if bn: 70 | m.append(nn.BatchNorm2d(n_feats)) 71 | if i == 0: 72 | m.append(act) 73 | 74 | self.body = nn.Sequential(*m) 75 | self.res_scale = res_scale 76 | 77 | def forward(self, x): 78 | res = self.body(x).mul(self.res_scale) 79 | res += x 80 | return res 81 | 82 | 83 | class Upsampler(nn.Sequential): 84 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 85 | 86 | m = [] 87 | if scale in (2, 4, 8): # Is scale = 2^n? 88 | for _ in range(int(math.log(scale, 2))): 89 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 90 | m.append(nn.PixelShuffle(2)) 91 | if bn: 92 | m.append(nn.BatchNorm2d(n_feats)) 93 | 94 | if act == 'relu': 95 | m.append(nn.ReLU(True)) 96 | elif act == 'prelu': 97 | m.append(nn.PReLU(n_feats)) 98 | 99 | elif scale == 3: 100 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 101 | m.append(nn.PixelShuffle(3)) 102 | if bn: 103 | m.append(nn.BatchNorm2d(n_feats)) 104 | 105 | if act == 'relu': 106 | m.append(nn.ReLU(True)) 107 | elif act == 'prelu': 108 | m.append(nn.PReLU(n_feats)) 109 | else: 110 | raise NotImplementedError 111 | 112 | super(Upsampler, self).__init__(*m) 113 | 114 | 115 | ## ---------------------- RDB Modules ------------------------ ## 116 | class RDB_Conv(nn.Module): 117 | """ Residual Dense Convolution. 
118 | """ 119 | 120 | def __init__(self, inChannels, growRate, kSize=3): 121 | super(RDB_Conv, self).__init__() 122 | Cin = inChannels 123 | G = growRate 124 | self.conv = nn.Sequential(*[ 125 | nn.Conv2d(Cin, G, kSize, padding=(kSize - 1) // 2, stride=1), 126 | nn.ReLU() 127 | ]) 128 | 129 | def forward(self, x): 130 | out = self.conv(x) 131 | return torch.cat((x, out), 1) 132 | 133 | 134 | class RDB(nn.Module): 135 | """ Residual Dense Block. 136 | """ 137 | 138 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 139 | super(RDB, self).__init__() 140 | G0 = growRate0 141 | G = growRate 142 | C = nConvLayers 143 | 144 | convs = [] 145 | for c in range(C): 146 | convs.append(RDB_Conv(G0 + c * G, G, kSize)) 147 | self.convs = nn.Sequential(*convs) 148 | 149 | # Local Feature Fusion 150 | self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1) 151 | 152 | def forward(self, x): 153 | return self.LFF(self.convs(x)) + x 154 | -------------------------------------------------------------------------------- /super-resolution/model/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Author : zhwzhong 4 | @License : (C) Copyright 2013-2018, hit 5 | @Contact : zhwzhong.hit@gmail.com 6 | @Software: PyCharm 7 | @File : metrics.py 8 | @Time : 2019/12/4 17:35 9 | @Desc : 10 | """ 11 | import numpy as np 12 | from scipy.signal import convolve2d 13 | # from skimage.measure import compare_psnr, compare_ssim 14 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 15 | 16 | 17 | def compare_ergas(x_true, x_pred, ratio): 18 | """ 19 | Calculate ERGAS, which gives a global indication of the quality of a fused image. The ideal value is 0. 20 | :param x_true: 21 | :param x_pred: 22 | :param ratio: Upsampling scale.
:return: 24 | """ 25 | x_true, x_pred = img_2d_mat(x_true=x_true, x_pred=x_pred) 26 | sum_ergas = 0 27 | for i in range(x_true.shape[0]): 28 | vec_x = x_true[i] 29 | vec_y = x_pred[i] 30 | err = vec_x - vec_y 31 | r_mse = np.mean(np.power(err, 2)) 32 | tmp = r_mse / (np.mean(vec_x)**2) 33 | sum_ergas += tmp 34 | return (100 / ratio) * np.sqrt(sum_ergas / x_true.shape[0]) 35 | 36 | 37 | def compare_sam(x_true, x_pred): 38 | """ 39 | :param x_true: HSI image:(H, W, C) 40 | :param x_pred: HSI image:(H, W, C) 41 | :return: the mean spectral angle (SAM, in degrees) between the original and the reconstructed hyperspectral image 42 | """ 43 | num = 0 44 | sum_sam = 0 45 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 46 | for x in range(x_true.shape[0]): 47 | for y in range(x_true.shape[1]): 48 | tmp_pred = x_pred[x, y].ravel() 49 | tmp_true = x_true[x, y].ravel() 50 | if np.linalg.norm(tmp_true) != 0 and np.linalg.norm(tmp_pred) != 0: 51 | sum_sam += np.arccos(np.clip( 52 | np.inner(tmp_pred, tmp_true) / (np.linalg.norm(tmp_true) * np.linalg.norm(tmp_pred)), -1.0, 1.0)) 53 | num += 1 54 | sam_deg = (sum_sam / num) * 180 / np.pi 55 | return sam_deg 56 | 57 | 58 | def compare_corr(x_true, x_pred): 59 | """ 60 | Calculate the cross correlation between x_pred and x_true. 61 | Compute the correlation coefficient of each pair of corresponding bands, then take the mean. 62 | CC is a spatial measure.
""" 64 | x_true, x_pred = img_2d_mat(x_true=x_true, x_pred=x_pred) 65 | x_true = x_true - np.mean(x_true, axis=1).reshape(-1, 1) 66 | x_pred = x_pred - np.mean(x_pred, axis=1).reshape(-1, 1) 67 | numerator = np.sum(x_true * x_pred, axis=1).reshape(-1, 1) 68 | denominator = np.sqrt(np.sum(x_true * x_true, axis=1) 69 | * np.sum(x_pred * x_pred, axis=1)).reshape(-1, 1) 70 | return (numerator / denominator).mean() 71 | 72 | 73 | def img_2d_mat(x_true, x_pred): 74 | """ 75 | # Convert a pair of 3-D multispectral images into 2-D matrices 76 | :param x_true: (H, W, C) 77 | :param x_pred: (H, W, C) 78 | :return: two matrices, each of shape (C, H * W) 79 | """ 80 | h, w, c = x_true.shape 81 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 82 | x_mat = np.zeros((c, h * w), dtype=np.float32) 83 | y_mat = np.zeros((c, h * w), dtype=np.float32) 84 | for i in range(c): 85 | x_mat[i] = x_true[:, :, i].reshape((1, -1)) 86 | y_mat[i] = x_pred[:, :, i].reshape((1, -1)) 87 | return x_mat, y_mat 88 | 89 | 90 | def compare_rmse(x_true, x_pred): 91 | """ 92 | Calculate the root mean squared error 93 | :param x_true: 94 | :param x_pred: 95 | :return: 96 | """ 97 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 98 | return np.linalg.norm(x_true - x_pred) / (np.sqrt(x_true.shape[0] * x_true.shape[1] * x_true.shape[2])) 99 | 100 | 101 | def compare_mpsnr(x_true, x_pred, data_range, detail=False): 102 | """ 103 | :param x_true: Input image; must have three dimensions (H, W, C) 104 | :param x_pred: 105 | :return: 106 | """ 107 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 108 | channels = x_true.shape[2] 109 | total_psnr = [peak_signal_noise_ratio(image_true=x_true[:, :, k], image_test=x_pred[:, :, k], data_range=data_range) 110 | for k in range(channels)] 111 | if detail: 112 | return np.mean(total_psnr), total_psnr 113 | else: 114 | return np.mean(total_psnr) 115 | 116 | 117 | def compare_mssim(x_true, x_pred, data_range, multidimension, detail=False): 
118 | """ 119 |
:param x_true: 120 | :param x_pred: 121 | :param data_range: 122 | :param multidimension: unused; every band is compared as a 2-D slice 123 | :return: 124 | """ 125 | mssim = [structural_similarity(x_true[:, :, i], x_pred[:, :, i], data_range=data_range) 126 | for i in range(x_true.shape[2])] 127 | if detail: 128 | return np.mean(mssim), mssim 129 | else: 130 | return np.mean(mssim) 131 | 132 | 133 | def compare_sid(x_true, x_pred): 134 | """ 135 | SID is an information theoretic measure for spectral similarity and discriminability. 136 | :param x_true: 137 | :param x_pred: 138 | :return: 139 | """ 140 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 141 | N = x_true.shape[2] 142 | err = np.zeros(N) 143 | for i in range(N): 144 | err[i] = abs(np.sum(x_pred[:, :, i] * np.log10((x_pred[:, :, i] + 1e-3) / (x_true[:, :, i] + 1e-3))) + 145 | np.sum(x_true[:, :, i] * np.log10((x_true[:, :, i] + 1e-3) / (x_pred[:, :, i] + 1e-3)))) 146 | return np.mean(err / (x_true.shape[1] * x_true.shape[0])) 147 | 148 | 149 | def compare_appsa(x_true, x_pred): 150 | """ 151 | 152 | :param x_true: 153 | :param x_pred: 154 | :return: 155 | """ 156 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 157 | nom = np.sum(x_true * x_pred, axis=2) 158 | denom = np.linalg.norm(x_true, axis=2) * np.linalg.norm(x_pred, axis=2) 159 | 160 | cos = np.where((nom / (denom + 1e-3)) > 1, 1, (nom / (denom + 1e-3))) 161 | appsa = np.arccos(cos) 162 | return np.sum(appsa) / (x_true.shape[1] * x_true.shape[0]) 163 | 164 | 165 | def compare_mare(x_true, x_pred): 166 | """ 167 | 168 | :param x_true: 169 | :param x_pred: 170 | :return: 171 | """ 172 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32) 173 | diff = x_true - x_pred 174 | abs_diff = np.abs(diff) 175 | # add 1 to the denominator to avoid division by zero.
176 | relative_abs_diff = np.divide(abs_diff, x_true + 1) 177 | return np.mean(relative_abs_diff) 178 | 179 | 180 | def img_qi(img1, img2, block_size=8): 181 | N = block_size ** 2 182 | sum2_filter = np.ones((block_size, block_size)) 183 | 184 | img1_sq = img1 * img1 185 | img2_sq = img2 * img2 186 | img12 = img1 * img2 187 | 188 | img1_sum = convolve2d(img1, np.rot90(sum2_filter), mode='valid') 189 | img2_sum = convolve2d(img2, np.rot90(sum2_filter), mode='valid') 190 | img1_sq_sum = convolve2d(img1_sq, np.rot90(sum2_filter), mode='valid') 191 | img2_sq_sum = convolve2d(img2_sq, np.rot90(sum2_filter), mode='valid') 192 | img12_sum = convolve2d(img12, np.rot90(sum2_filter), mode='valid') 193 | 194 | img12_sum_mul = img1_sum * img2_sum 195 | img12_sq_sum_mul = img1_sum * img1_sum + img2_sum * img2_sum 196 | numerator = 4 * (N * img12_sum - img12_sum_mul) * img12_sum_mul 197 | denominator1 = N * (img1_sq_sum + img2_sq_sum) - img12_sq_sum_mul 198 | denominator = denominator1 * img12_sq_sum_mul 199 | quality_map = np.ones(denominator.shape) 200 | index = (denominator1 == 0) & (img12_sq_sum_mul != 0) 201 | quality_map[index] = 2 * img12_sum_mul[index] / img12_sq_sum_mul[index] 202 | index = (denominator != 0) 203 | quality_map[index] = numerator[index] / denominator[index] 204 | return quality_map.mean() 205 | 206 | 207 | def compare_qave(x_true, x_pred, block_size=8): 208 | n_bands = x_true.shape[2] 209 | q_orig = np.zeros(n_bands) 210 | for idim in range(n_bands): 211 | q_orig[idim] = img_qi(x_true[:, :, idim], 212 | x_pred[:, :, idim], block_size) 213 | return q_orig.mean() 214 | 215 | 216 | def quality_assessment(x_true, x_pred, data_range, ratio, multi_dimension=False): 217 | """ 218 | :param multi_dimension: 219 | :param ratio: 220 | :param data_range: 221 | :param x_true: 222 | :param x_pred: 223 | :param block_size 224 | :return: 225 | """ 226 | result = {'MPSNR': compare_mpsnr(x_true=x_true, x_pred=x_pred, data_range=data_range), 227 | 'MSSIM': 
compare_mssim(x_true=x_true, x_pred=x_pred, data_range=data_range, 228 | multidimension=multi_dimension), 229 | # 'ERGAS': compare_ergas(x_true=x_true, x_pred=x_pred, ratio=ratio), 230 | 'SAM': compare_sam(x_true=x_true, x_pred=x_pred), 231 | 'CrossCorrelation': compare_corr(x_true=x_true, x_pred=x_pred), 232 | 'RMSE': compare_rmse(x_true=x_true, x_pred=x_pred), 233 | } 234 | return result 235 | 236 | 237 | def baseline_assessment(x_true, x_pred, data_range, multi_dimension=False): 238 | mpsnr, psnrs = compare_mpsnr(x_true=x_true, x_pred=x_pred, data_range=data_range, detail=True) 239 | mssim, ssims = compare_mssim(x_true=x_true, x_pred=x_pred, data_range=data_range, 240 | multidimension=multi_dimension, detail=True) 241 | return mpsnr, mssim, psnrs, ssims 242 | 243 | 244 | def tensor_accessment(x_true, x_pred, data_range, multi_dimension=False): 245 | x_true = x_true.transpose(0, 2, 3, 1)[0] 246 | x_pred = x_pred.transpose(0, 2, 3, 1)[0] 247 | mpsnr, psnrs = compare_mpsnr(x_true=x_true, x_pred=x_pred, data_range=data_range, detail=True) 248 | mssim, ssims = compare_mssim(x_true=x_true, x_pred=x_pred, data_range=data_range, 249 | multidimension=multi_dimension, detail=True) 250 | return mpsnr, mssim, psnrs, ssims 251 | 252 | 253 | def batch_accessment(x_true, x_pred, data_range, ratio, multi_dimension=False): 254 | scores = [] 255 | avg_score = {'MPSNR': 0, 'MSSIM': 0, 'SAM': 0, 256 | 'CrossCorrelation': 0, 'RMSE': 0} 257 | x_true = x_true.transpose(0, 2, 3, 1) 258 | x_pred = x_pred.transpose(0, 2, 3, 1) 259 | 260 | for i in range(x_true.shape[0]): 261 | scores.append(quality_assessment( 262 | x_true[i], x_pred[i], data_range, ratio, multi_dimension)) 263 | for met in avg_score.keys(): 264 | avg_score[met] = np.mean([score[met] for score in scores]) 265 | return avg_score 266 | 267 | # from scipy import io as sio 268 | # im_out = 
np.array(sio.loadmat('/home/zhwzhong/PycharmProject/HyperSR/SOAT/HyperSR/SRindices/Chikuse_EDSRViDeCNN_Blocks=9_Feats=256_Loss_H_Real_1_1_X2X2_N5new_BS32_Epo60_epoch_60_Fri_Sep_20_21:38:44_2019.mat')['output']) 269 | # im_gt = np.array(sio.loadmat('/home/zhwzhong/PycharmProject/HyperSR/SOAT/HyperSR/SRindices/Chikusei_test.mat')['gt']) 270 | # 271 | # sum_rmse, sum_sam, sum_psnr, sum_ssim, sum_ergas = [], [], [], [], [] 272 | # for i in range(im_gt.shape[0]): 273 | # print(im_out[i].shape) 274 | # score = quality_assessment(x_pred=im_out[i], x_true=im_gt[i], data_range=1, ratio=4, multi_dimension=False, block_size=8) 275 | # sum_rmse.append(score['RMSE']) 276 | # sum_psnr.append(score['MPSNR']) 277 | # sum_ssim.append(score['MSSIM']) 278 | # sum_sam.append(score['SAM']) 279 | # sum_ergas.append(score['ERGAS']) 280 | # 281 | # print(np.mean(sum_rmse), np.mean(sum_psnr), np.mean(sum_ssim), np.mean(sum_sam)) 282 | -------------------------------------------------------------------------------- /super-resolution/model/model_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
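`MInterface.load_model` and `instancialize` (like the matching methods of the data interface) map a `snake_case` file name to a `CamelCase` class and pass that class only the hyperparameters its constructor names. Below is a self-contained sketch of the pattern. It uses `inspect.signature` rather than `inspect.getargspec`, which was removed in Python 3.11. The helpers `snake_to_camel` and `filter_init_kwargs` and the `ToyNet` class are hypothetical, for illustration only; the module import via `importlib` is left out so the sketch runs without a package.

```python
import inspect


def snake_to_camel(name: str) -> str:
    """Map a module file name such as 'rdn_fuse' to its class name 'RdnFuse'."""
    return ''.join(part.capitalize() for part in name.split('_'))


def filter_init_kwargs(cls, kwargs: dict, **overrides) -> dict:
    """Keep only the entries of `kwargs` that cls.__init__ accepts by name,
    then let explicit `overrides` take precedence (like `other_args`)."""
    params = inspect.signature(cls.__init__).parameters
    accepted = {
        name for name, p in params.items()
        if name != 'self' and p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
    }
    selected = {k: v for k, v in kwargs.items() if k in accepted}
    selected.update(overrides)
    return selected


class ToyNet:  # hypothetical stand-in for a model class such as RdnFuse
    def __init__(self, scale, hid=64):
        self.scale, self.hid = scale, hid


hparams = {'scale': 4, 'hid': 32, 'lr': 1e-3, 'loss': 'l1'}
net = ToyNet(**filter_init_kwargs(ToyNet, hparams, hid=128))
print(snake_to_camel('rdn_fuse'), net.scale, net.hid)  # RdnFuse 4 128
```

Filtering by signature means training-only options such as `lr` and `loss` can live in the same hyperparameter dictionary without breaking the model constructor.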
14 | 15 | import inspect 16 | import torch 17 | import numpy as np 18 | import importlib 19 | from torch import nn 20 | from torch.nn import functional as F 21 | import torch.optim.lr_scheduler as lrs 22 | 23 | import pytorch_lightning as pl 24 | from .metrics import tensor_accessment 25 | from .utils import quantize 26 | 27 | 28 | class MInterface(pl.LightningModule): 29 | def __init__(self, model_name, loss, lr, **kwargs): 30 | super().__init__() 31 | self.save_hyperparameters() 32 | self.load_model() 33 | self.configure_loss() 34 | 35 | # Project-Specific Definitions 36 | self.hsi_index = np.r_[0, 4:12] 37 | self.rgb_index = (3, 2, 1) 38 | 39 | def forward(self, lr_hsi, hr_rgb): 40 | return self.model(lr_hsi, hr_rgb) 41 | 42 | def training_step(self, batch, batch_idx): 43 | lr, hr, _ = batch 44 | sr = self(lr, hr[:, self.rgb_index, ]) 45 | loss = self.loss_function(sr[:, self.hsi_index], hr[:, self.hsi_index]) 46 | self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True) 47 | return loss 48 | 49 | def validation_step(self, batch, batch_idx): 50 | lr, hr, _ = batch 51 | sr = self(lr, hr[:, self.rgb_index, ]) 52 | sr = quantize(sr, self.hparams.color_range) 53 | mpsnr, mssim, _, _ = tensor_accessment( 54 | x_pred=sr[:, self.hsi_index].cpu().numpy(), 55 | x_true=hr[:, self.hsi_index].cpu().numpy(), 56 | data_range=self.hparams.color_range, 57 | multi_dimension=False) 58 | 59 | self.log('mpsnr', mpsnr, on_step=False, on_epoch=True, prog_bar=True) 60 | self.log('mssim', mssim, on_step=False, on_epoch=True, prog_bar=True) 61 | 62 | def test_step(self, batch, batch_idx): 63 | # Here we just reuse the validation_step for testing 64 | return self.validation_step(batch, batch_idx) 65 | 66 | def on_validation_epoch_end(self): 67 | # Print an empty line so the finished epoch's progress bar stays on screen 68 | self.print('') 69 | # self.print(self.get_progress_bar_dict()) 70 | 71 | def configure_optimizers(self): 72 | if hasattr(self.hparams, 'weight_decay'): 73 | weight_decay =
self.hparams.weight_decay 74 | else: 75 | weight_decay = 0 76 | optimizer = torch.optim.Adam( 77 | self.parameters(), lr=self.hparams.lr, weight_decay=weight_decay) 78 | 79 | if self.hparams.lr_scheduler is None: 80 | return optimizer 81 | else: 82 | if self.hparams.lr_scheduler == 'step': 83 | scheduler = lrs.StepLR(optimizer, 84 | step_size=self.hparams.lr_decay_steps, 85 | gamma=self.hparams.lr_decay_rate) 86 | elif self.hparams.lr_scheduler == 'cosine': 87 | scheduler = lrs.CosineAnnealingLR(optimizer, 88 | T_max=self.hparams.lr_decay_steps, 89 | eta_min=self.hparams.lr_decay_min_lr) 90 | else: 91 | raise ValueError('Invalid lr_scheduler type!') 92 | return [optimizer], [scheduler] 93 | 94 | def configure_loss(self): 95 | loss = self.hparams.loss.lower() 96 | if loss == 'mse': 97 | self.loss_function = F.mse_loss 98 | elif loss == 'l1': 99 | self.loss_function = F.l1_loss 100 | else: 101 | raise ValueError("Invalid Loss Type!") 102 | 103 | def load_model(self): 104 | name = self.hparams.model_name 105 | # Change the `snake_case.py` file name to `CamelCase` class name. 106 | # Please always name your model file as `snake_case.py` and 107 | # the corresponding class as `CamelCase`. 108 | camel_name = ''.join([i.capitalize() for i in name.split('_')]) 109 | try: 110 | Model = getattr(importlib.import_module( 111 | '.'+name, package=__package__), camel_name) 112 | except (ImportError, AttributeError) as e: 113 | raise ValueError( 114 | f'Invalid Module File Name or Invalid Class Name {name}.{camel_name}!') from e 115 | self.model = self.instancialize(Model) 116 | 117 | def instancialize(self, Model, **other_args): 118 | """ Instantiate a model using the corresponding parameters 119 | from the self.hparams dictionary. You can also pass any args 120 | to overwrite the corresponding value in self.hparams.
121 | """ 122 | class_args = inspect.getfullargspec(Model.__init__).args[1:] 123 | inkeys = self.hparams.keys() 124 | args1 = {} 125 | for arg in class_args: 126 | if arg in inkeys: 127 | args1[arg] = getattr(self.hparams, arg) 128 | args1.update(other_args) 129 | return Model(**args1) 130 | -------------------------------------------------------------------------------- /super-resolution/model/rdn_fuse.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
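`RdnFuse` normalizes its inputs with `mean_shift_2d` from `common.py`. That function permutes a `(B, C, H, W)` tensor to `(H, W, B, C)` so that the per-band `mean`/`std` vectors broadcast over the last axis, then permutes back. The NumPy sketch below stands in for the torch code; `mean_shift_2d_np` and `mean_shift_2d_permute` are hypothetical names. It checks that the permute trick matches broadcasting `(1, C, 1, 1)`-shaped statistics, and that `add=True` inverts `add=False`. Statistics are matched to channels purely by position, so the mean/std vectors must be listed in the same band order as the tensor's channels.

```python
import numpy as np


def mean_shift_2d_np(data, mean, std, base=1.0, add=False):
    """Per-band (de)normalisation of a (B, C, H, W) array via (1, C, 1, 1) broadcasting."""
    mean = np.asarray(mean, dtype=np.float32).reshape(1, -1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(1, -1, 1, 1)
    if add:
        return data * std / base + mean  # undo the normalisation
    return (data - mean) / std * base    # normalise each band


def mean_shift_2d_permute(data, mean, std, base=1.0):
    """Same arithmetic as the torch version: move bands to the last axis first."""
    t = np.transpose(data, (2, 3, 0, 1))  # (H, W, B, C)
    t = (t - np.asarray(mean)) / np.asarray(std) * base
    return np.transpose(t, (2, 3, 0, 1))  # back to (B, C, H, W)


rng = np.random.default_rng(0)
x = rng.uniform(0, 20, size=(2, 3, 4, 4)).astype(np.float32)
mean, std = [1.211, 1.948, 1.892], [2.273, 2.299, 2.668]  # entries 1-3 of mean_sen/std_sen

z = mean_shift_2d_np(x, mean, std)
x_back = mean_shift_2d_np(z, mean, std, add=True)
print(np.allclose(x, x_back, atol=1e-4))                         # True
print(np.allclose(z, mean_shift_2d_permute(x, mean, std), atol=1e-5))  # True
```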
14 | 15 | import numpy as np 16 | import torch 17 | from torch import nn 18 | from functools import partial 19 | from .common import RDB, Upsampler, mean_shift_2d, conv3x3 20 | 21 | class RdnFuse(nn.Module): 22 | def __init__(self, scale, 23 | in_bands_num, 24 | hid=64, 25 | block_num=8, 26 | rdn_size=3, 27 | rdb_growrate=64, 28 | rdb_conv_num=8, 29 | mean_sen=[1.315, 1.211, 1.948, 1.892, 3.311, 6.535, 7.634, 8.197, 8.395, 8.341, 5.89, 3.616], 30 | std_sen=[5.958, 2.273, 2.299, 2.668, 2.895, 4.276, 4.978, 5.237, 5.304, 5.103, 4.298, 3.3], 31 | ): 32 | super().__init__() 33 | # Set all input args as attributes 34 | self.__dict__.update(locals()) 35 | 36 | ibn = self.in_bands_num 37 | lbn = self.in_bands_num - 3 # LR Bands Num 38 | kSize = self.rdn_size 39 | 40 | # Up-sampling net 41 | self.UPNet = Upsampler(conv3x3, self.scale, lbn) 42 | self.trans = conv3x3(ibn, hid, kSize) 43 | 44 | # Shallow feature extraction net 45 | self.SFENet = conv3x3(hid, hid, kSize) 46 | 47 | # Residual dense blocks and dense feature fusion 48 | self.RDBs_body = nn.ModuleList() 49 | for _ in range(self.block_num): 50 | self.RDBs_body.append( 51 | RDB(growRate0=hid, growRate=self.rdb_growrate, nConvLayers=self.rdb_conv_num) 52 | ) 53 | 54 | # Global Feature Fusion 55 | self.GFF_body = nn.Sequential(*[ 56 | conv3x3(self.block_num * hid, hid, 1), 57 | conv3x3(hid, hid, kSize) 58 | ]) 59 | 60 | self.tail = conv3x3(hid, lbn, kSize) 61 | 62 | def init_norm_func(self, ref): 63 | self.hsi_index = np.r_[0,4:12] 64 | self.rgb_index = np.array([3, 2, 1])  # same band order as MInterface.rgb_index, i.e. the order of hr_rgb 65 | mean_hsi = np.array(self.mean_sen)[self.hsi_index] 66 | std_hsi = np.array(self.std_sen)[self.hsi_index] 67 | mean_rgb = np.array(self.mean_sen)[self.rgb_index] 68 | std_rgb = np.array(self.std_sen)[self.rgb_index] 69 | 70 | self.sub_mean = partial(mean_shift_2d, 71 | mean=torch.tensor(mean_hsi, dtype=torch.float32).type_as(ref), 72 | std=torch.tensor(std_hsi, 
dtype=torch.float32).type_as(ref), 73 | base=1, 74 | add=False) 75 | self.add_mean =
partial(mean_shift_2d, 76 | mean=torch.tensor(mean_hsi, dtype=torch.float32).type_as(ref), 77 | std=torch.tensor(std_hsi, dtype=torch.float32).type_as(ref), 78 | base=1, 79 | add=True) 80 | self.sub_mean_rgb = partial(mean_shift_2d, 81 | mean=torch.tensor(mean_rgb, dtype=torch.float32).type_as(ref), 82 | std=torch.tensor(std_rgb, dtype=torch.float32).type_as(ref), 83 | base=1, 84 | add=False) 85 | 86 | def forward(self, x, hr_rgb): 87 | if not hasattr(self, 'sub_mean'): 88 | self.init_norm_func(x) 89 | hr_rgb_ori = hr_rgb 90 | 91 | # -- Head -- # 92 | x = x[:,np.r_[0,4:12],:,:] 93 | x = self.sub_mean(x) 94 | hr_rgb = self.sub_mean_rgb(hr_rgb) 95 | 96 | x = self.UPNet(x) 97 | x = torch.cat([x, hr_rgb], axis=1) 98 | f1 = self.trans(x) 99 | x = self.SFENet(f1) 100 | 101 | # -- Body -- # 102 | RDBs_out_body = [] 103 | for i in range(self.block_num): 104 | x = self.RDBs_body[i](x) 105 | RDBs_out_body.append(x) 106 | 107 | x = self.GFF_body(torch.cat(RDBs_out_body, 1)) 108 | x += f1 109 | 110 | x = self.tail(x) 111 | x = self.add_mean(x) 112 | x = torch.cat([x, hr_rgb_ori], axis=1) 113 | x = x[:,(0,11,10,9,1,2,3,4,5,6,7,8),] 114 | return x 115 | -------------------------------------------------------------------------------- /super-resolution/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Zhongyang Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
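`quantize` is meant to snap network outputs onto the discrete levels of an 8-bit image before PSNR/SSIM are computed in `validation_step`. The NumPy sketch below stands in for the torch version; `quantize_np` is a hypothetical name. It assumes the convention of EDSR-style code bases, which round after clamping. Without a rounding step the function only clamps values to `[0, rgb_range]` and does not quantize them.

```python
import numpy as np


def quantize_np(img, rgb_range):
    """Clamp to [0, rgb_range] and snap to the 256 levels of an 8-bit image."""
    pixel_range = 255 / rgb_range                # scale factor onto the 0..255 grid
    levels = np.clip(img * pixel_range, 0, 255)  # out-of-range values are clamped
    return np.round(levels) / pixel_range        # nearest level, rescaled back


sr = np.array([-3.2, 0.4, 127.6, 254.51, 300.0])
print(quantize_np(sr, 255).tolist())  # [0.0, 0.0, 128.0, 255.0, 255.0]

# With rgb_range=1 every output lies on the k/255 grid.
q = quantize_np(np.linspace(0, 1, 11), 1)
print(np.allclose(q * 255, np.round(q * 255)))  # True
```

Note that `np.round` (like `torch.round`) rounds exact halves to the nearest even integer.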

def quantize(img, rgb_range):
    # Map [0, rgb_range] onto the 8-bit grid [0, 255], clamp, round to the
    # nearest integer level, then map back to the original range.
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
--------------------------------------------------------------------------------
/super-resolution/utils.py:
--------------------------------------------------------------------------------
# Copyright 2021 Zhongyang Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

def load_model_path(root=None, version=None, v_num=None, best=False):
    """ When best = True, return the path of the best checkpoint in a directory,
    i.e. the "best*" checkpoint with the largest epoch. Otherwise, return the
    last saved checkpoint. At least one of the first three args must be given.
    Args:
        root: The root directory of checkpoints. It can also be a
            model ckpt file, in which case the function returns it as-is.
        version: The name of the version you are going to load.
        v_num: The number of the version you are going to load.
        best: Whether to return the best model.
    """
    def sort_by_epoch(path):
        # Expects names like "best-epoch=XX-...": parse XX from the second field.
        name = path.stem
        epoch = int(name.split('-')[1].split('=')[1])
        return epoch

    def generate_root():
        if root is not None:
            return root
        elif version is not None:
            return str(Path('lightning_logs', version, 'checkpoints'))
        else:
            return str(Path('lightning_logs', f'version_{v_num}', 'checkpoints'))

    if root is None and version is None and v_num is None:
        return None

    root = generate_root()
    if Path(root).is_file():
        return root
    if best:
        files = [i for i in Path(root).iterdir() if i.stem.startswith('best')]
        if not files:
            raise FileNotFoundError(f'No "best*" checkpoint found in {root}')
        files.sort(key=sort_by_epoch, reverse=True)
        res = str(files[0])
    else:
        res = str(Path(root) / 'last.ckpt')
    return res

def load_model_path_by_args(args):
    return load_model_path(root=args.load_dir, version=args.load_ver, v_num=args.load_v_num)
--------------------------------------------------------------------------------
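The `best=True` branch of `load_model_path` uses a simple selection rule. Among the files whose stem starts with `best`, it picks the one whose second `-`-separated field (`epoch=N`) has the largest integer N. The standalone sketch below reproduces only that rule. The helper names `epoch_of` and `pick_checkpoint` and the file names such as `best-epoch=12-val_loss=0.180.ckpt` are hypothetical. The real names depend on the `ModelCheckpoint` `filename` pattern configured in `main.py`, which is not shown here.

```python
import tempfile
from pathlib import Path


def epoch_of(ckpt: Path) -> int:
    """Parse N from a hypothetical name like 'best-epoch=N-val_loss=0.123.ckpt'.

    Mirrors sort_by_epoch(): the second '-'-separated field of the stem
    is assumed to be 'epoch=<int>'.
    """
    field = ckpt.stem.split('-')[1]  # e.g. 'epoch=07'
    return int(field.split('=')[1])  # e.g. 7


def pick_checkpoint(ckpt_dir, best=False):
    """Return the 'best*' checkpoint with the largest epoch, or 'last.ckpt'."""
    ckpt_dir = Path(ckpt_dir)
    if not best:
        return str(ckpt_dir / 'last.ckpt')
    candidates = [p for p in ckpt_dir.iterdir() if p.stem.startswith('best')]
    if not candidates:
        raise FileNotFoundError(f'no best* checkpoint in {ckpt_dir}')
    # max() with a key is equivalent to sort(reverse=True) followed by [0]
    return str(max(candidates, key=epoch_of))


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as d:
        for name in ['best-epoch=9-val_loss=0.210.ckpt',
                     'best-epoch=12-val_loss=0.180.ckpt',
                     'last.ckpt']:
            (Path(d) / name).touch()
        print(Path(pick_checkpoint(d, best=True)).name)  # -> best-epoch=12-val_loss=0.180.ckpt
        print(Path(pick_checkpoint(d)).name)             # -> last.ckpt
```

Parsing the epoch as an integer matters here. A plain string sort would rank `epoch=9` above `epoch=12` unless the epoch field is zero-padded.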