├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── README_kr.md ├── README_zh.md ├── core ├── dsproc_mcls.py ├── dsproc_mclsmfolder.py └── mengine.py ├── dataset ├── label.txt ├── train.txt ├── val.txt └── val_dataset │ └── 51aa9b8d0da890cd1d0c5029e3d89e3c.jpg ├── images └── competition_title.png ├── infer_api.py ├── main.sh ├── main_infer.py ├── main_train.py ├── main_train_single_gpu.py ├── merge.py ├── model ├── convnext.py └── replknet.py ├── requirements.txt └── toolkit ├── chelper.py ├── cmetric.py ├── dhelper.py ├── dtransform.py └── yacs.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.idea 2 | .DS_Store 3 | *.pth 4 | *.pyc 5 | *.ipynb 6 | __pycache__ 7 | vision_rush_image* -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image with the required CUDA version 2 | FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu20.04 3 | 4 | # Set the working directory 5 | WORKDIR /code 6 | 7 | RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && echo "Asia/Shanghai" > /etc/timezone 8 | RUN apt-get update -y 9 | RUN apt-get install software-properties-common -y && add-apt-repository ppa:deadsnakes/ppa 10 | RUN apt-get install python3.8 python3-pip curl libgl1 libglib2.0-0 ffmpeg libsm6 libxext6 -y && apt-get clean && rm -rf /var/lib/apt/lists/* 11 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 0 12 | RUN update-alternatives --set python3 /usr/bin/python3.8 13 | 14 | # Copy ./requirements.txt into the working directory and install the Python dependencies. 15 | ADD ./requirements.txt /code/requirements.txt 16 | RUN pip3 install pip --upgrade -i https://pypi.mirrors.ustc.edu.cn/simple/ 17 | RUN pip3 install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/ && rm -rf `pip3 cache dir` 18 | 19 | # Copy the models and code into the working directory 20 | ADD ./core /code/core 21 | ADD ./dataset /code/dataset 22 | ADD ./model /code/model 23 | ADD ./pre_model /code/pre_model 24 | ADD ./final_model_csv /code/final_model_csv 25 | ADD ./toolkit /code/toolkit 26 | ADD ./infer_api.py /code/infer_api.py 27 | ADD ./main_infer.py /code/main_infer.py 28 | ADD ./main_train.py /code/main_train.py 29 | ADD ./merge.py /code/merge.py 30 | ADD ./main.sh /code/main.sh 31 | ADD ./README.md /code/README.md 32 | ADD ./Dockerfile /code/Dockerfile 33 | 34 | # Run the inference API 35 | ENTRYPOINT ["python3","infer_api.py"] 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepFake Defenders 2 | **If you like our project, please give us a star ⭐ on GitHub for the latest updates.** 3 | 4 | 5 | 6 | 7 | [![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/VisionRush/DeepFakeDefenders/blob/main/LICENSE) 8 | ![GitHub contributors](https://img.shields.io/github/contributors/VisionRush/DeepFakeDefenders) 9 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FVisionRush%2FDeepFakeDefenders&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitors&edge_flat=false)](https://hits.seeyoufarm.com) 10 | ![GitHub Repo stars](https://img.shields.io/github/stars/VisionRush/DeepFakeDefenders) 11 | [![GitHub issues](https://img.shields.io/github/issues/VisionRush/DeepFakeDefenders?color=critical&label=Issues)](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aopen+is%3Aissue) 12 | [![GitHub closed issues](https://img.shields.io/github/issues-closed/VisionRush/DeepFakeDefenders?color=success&label=Issues)](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aissue+is%3Aclosed) 13 | 14 | 15 | 16 | ![competition](images/competition_title.png) 17 | 18 | 
19 | 20 | 21 | 💡 We also provide [[中文文档 / CHINESE DOC](README_zh.md)] and [[한국어 문서 / KOREAN DOC](README_kr.md)]. We warmly welcome and appreciate your contributions to this project. 22 | 23 | ## 📣 News 24 | 25 | * **[2024.09.05]** 🔥 We officially released the initial version of DeepFake Defenders, and we won third prize in the deepfake challenge at [[the Conference on the Bund](https://www.atecup.cn/deepfake)]. 26 | 27 | ## 🚀 Quick Start 28 | 29 | ### 1. Pretrained Models Preparation 30 | 31 | Before getting started, please place the ImageNet-1K pretrained weight files in the `./pre_model` directory. The download links for the weights are provided below: 32 | ``` 33 | RepLKNet: https://drive.google.com/file/d/1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN/view?usp=sharing 34 | ConvNeXt: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth 35 | ``` 36 | 37 | ### 2. Training from Scratch 38 | 39 | #### 2.1 Modifying the dataset path 40 | 41 | Place the training-set **(\*.txt)** file, validation-set **(\*.txt)** file, and label **(\*.txt)** file required for training in the `dataset` folder, and keep the same file names as the provided examples (sample txt files are included under `dataset`). 42 | 43 | #### 2.2 Modifying the Hyper-parameters 44 | 45 | For the two models (RepLKNet and ConvNeXt) used, the following parameters need to be changed in `main_train.py`: 46 | 47 | ```python 48 | # For RepLKNet 49 | cfg.network.name = 'replknet'; cfg.train.batch_size = 16 50 | # For ConvNeXt 51 | cfg.network.name = 'convnext'; cfg.train.batch_size = 24 52 | ``` 53 | 54 | #### 2.3 Using the training script 55 | ##### Multi-GPU (8 GPUs were used): 56 | ```shell 57 | bash main.sh 58 | ``` 59 | ##### Single-GPU: 60 | ```shell 61 | CUDA_VISIBLE_DEVICES=0 python main_train_single_gpu.py 62 | ``` 63 | 64 | #### 2.4 Model Ensembling 65 | 66 | Replace the trained ConvNeXt model path and the trained RepLKNet model path in `merge.py`, then execute `python merge.py` to obtain the final inference model (a hedged sketch of this merging step is included at the end of this README). 67 | 68 | #### 2.5 Inference 69 | 70 | The following example uses the **POST** request interface with the image path as the request parameter; the response is the deepfake score predicted by the model. 71 | 72 | ```python 73 | #!/usr/bin/env python 74 | # -*- coding:utf-8 -*- 75 | import requests 76 | import json 77 | 78 | 79 | header = { 80 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36' 81 | } 82 | 83 | url = 'http://ip:10005/inter_api' 84 | image_path = './dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg' 85 | data_map = {'img_path':image_path} 86 | response = requests.post(url, data=json.dumps(data_map), headers=header) 87 | content = response.content 88 | print(json.loads(content)) 89 | ``` 90 | 91 | ### 3. Deploy in Docker 92 | #### Building 93 | 94 | ```shell 95 | sudo docker build -t vision-rush-image:1.0.1 --network host . 96 | ``` 97 | 98 | #### Running 99 | 100 | ```shell 101 | sudo docker run -d --name vision_rush_image --gpus=all --net host vision-rush-image:1.0.1 102 | ``` 103 | 104 | ## Star History 105 | 106 | [![Star History Chart](https://api.star-history.com/svg?repos=VisionRush/DeepFakeDefenders&type=Date)](https://star-history.com/#VisionRush/DeepFakeDefenders&Date) 107 | 
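For reference, below is a minimal, hypothetical sketch of the model-ensembling step described in section 2.4. It assumes each training checkpoint stores its weights under a `state_dict` key (as `core/mengine.py` saves them) and that the combined inference model built by `load_model('all', 2)` nests the two backbones under prefixed submodules; the prefixes, paths, and merging logic are illustrative assumptions, not the confirmed contents of `merge.py`.

```python
# Hypothetical sketch of the merge step (the real merge.py may differ).
# Assumptions: checkpoints keep weights under 'state_dict', and the combined
# model expects 'convnext_model.' / 'replknet_model.' submodule prefixes.
import torch

convnext_ckpt = torch.load('path/to/convnext_best.pth', map_location='cpu')  # placeholder path
replknet_ckpt = torch.load('path/to/replknet_best.pth', map_location='cpu')  # placeholder path

merged = {}
for key, value in convnext_ckpt['state_dict'].items():
    merged['convnext_model.' + key] = value   # prefix is an assumption
for key, value in replknet_ckpt['state_dict'].items():
    merged['replknet_model.' + key] = value   # prefix is an assumption

torch.save({'state_dict': merged}, './final_model_csv/final_model.pth')
```

The output path matches the location that `main_infer.py` loads the final model from (`./final_model_csv/final_model.pth`).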
-------------------------------------------------------------------------------- /README_kr.md: -------------------------------------------------------------------------------- 1 | # DeepFake Defenders 2 | **저희의 프로젝트가 마음에 드신다면, GitHub에서 별 ⭐ 을 눌러 최신 업데이트를 받아보세요.** 3 | 4 | 5 | 6 | 7 | 8 | [![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/VisionRush/DeepFakeDefenders/blob/main/LICENSE) 9 | ![GitHub contributors](https://img.shields.io/github/contributors/VisionRush/DeepFakeDefenders) 10 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FVisionRush%2FDeepFakeDefenders&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitors&edge_flat=false)](https://hits.seeyoufarm.com) 11 | ![GitHub Repo stars](https://img.shields.io/github/stars/VisionRush/DeepFakeDefenders) 12 | [![GitHub issues](https://img.shields.io/github/issues/VisionRush/DeepFakeDefenders?color=critical&label=Issues)](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aopen+is%3Aissue) 13 | [![GitHub closed issues](https://img.shields.io/github/issues-closed/VisionRush/DeepFakeDefenders?color=success&label=Issues)](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aissue+is%3Aclosed) 14 | 15 | 16 | ![competition](images/competition_title.png) 17 | 18 | 19 | 
20 | 21 | 💡 [[영어 문서 / ENGLISH DOC](README.md)]와 [[중국어 문서 / CHINESE DOC](README_zh.md)]를 제공하고 있습니다. 저희는 이 프로젝트에 대한 기여를 매우 환영하고 감사드립니다. 22 | 23 | ## 📣 뉴스 24 | 25 | * **[2024.09.05]** 🔥 DeepFake Defenders의 초기 버전을 공식적으로 릴리즈했으며, [[Bund에서의 컨퍼런스](https://www.atecup.cn/deepfake)]의 deepfake challenge에서 3등을 수상했습니다. 26 | 27 | ## 🚀 빠르게 시작하기 28 | 29 | ### 1. 사전에 훈련된 모델 준비하기 30 | 31 | 시작하기 전, ImageNet-1K로 사전에 훈련된 가중치 파일들을 `./pre_model` 디렉토리에 넣어주세요. 가중치 파일들의 다운로드 링크들은 아래와 같습니다. 32 | ``` 33 | RepLKNet: https://drive.google.com/file/d/1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN/view?usp=sharing 34 | ConvNeXt: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth 35 | ``` 36 | 37 | ### 2. 처음부터 훈련시키기 38 | 39 | #### 2.1 데이터셋의 경로 조정하기 40 | 41 | 학습에 필요한 트레이닝셋 **(\*.txt)** 파일, 벨리데이션셋 **(\*.txt)** 파일, 라벨 **(\*.txt)** 파일을 dataset 폴더에 넣고, 파일들을 같은 이름으로 지정하세요. (dataset 아래에 다양한 txt 예제들이 있습니다) 42 | 43 | #### 2.2 하이퍼 파라미터 조정하기 44 | 45 | 두 모델(RepLKNet과 ConvNeXt)을 위해 `main_train.py`의 파라미터가 아래와 같이 설정되어야 합니다. 46 | 47 | ```python 48 | # RepLKNet으로 설정 49 | cfg.network.name = 'replknet'; cfg.train.batch_size = 16 50 | # ConvNeXt으로 설정 51 | cfg.network.name = 'convnext'; cfg.train.batch_size = 24 52 | ``` 53 | 54 | #### 2.3 훈련 스크립트 사용하기 55 | 56 | ##### 다중 GPU (GPU 8개가 사용되었습니다): 57 | ```shell 58 | bash main.sh 59 | ``` 60 | 61 | ##### 단일 GPU: 62 | ```shell 63 | CUDA_VISIBLE_DEVICES=0 python main_train_single_gpu.py 64 | ``` 65 | 66 | #### 2.4 모델 조립하기 67 | 68 | `merge.py`의 ConvNeXt로 훈련된 모델 경로와 RepLKNet으로 훈련된 모델 경로를 바꾸고, `python merge.py`를 실행시켜 최종 인퍼런스 테스트 모델을 만듭니다. 69 | 70 | #### 2.5 인퍼런스 71 | 72 | 다음의 예제는 **POST** 요청 인터페이스를 사용하여 이미지 경로를 매개변수로 요청하고, 모델이 예측한 딥페이크 점수를 응답으로 출력합니다. 73 | 74 | ```python 75 | #!/usr/bin/env python 76 | # -*- coding:utf-8 -*- 77 | import requests 78 | import json 79 | 80 | 81 | 82 | header = { 83 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36' 84 | } 85 | 86 | url = 'http://ip:10005/inter_api' 87 | image_path = './dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg' 88 | data_map = {'img_path':image_path} 89 | response = requests.post(url, data=json.dumps(data_map), headers=header) 90 | content = response.content 91 | print(json.loads(content)) 92 | ``` 93 | 94 | ### 3. Docker에 배포하기 95 | 96 | #### 빌드하기 97 | 98 | ```shell 99 | sudo docker build -t vision-rush-image:1.0.1 --network host . 100 | ``` 101 | 102 | #### 실행시키기 103 | 104 | ```shell 105 | sudo docker run -d --name vision_rush_image --gpus=all --net host vision-rush-image:1.0.1 106 | ``` 107 | 108 | ## Star History 109 | 110 | [![Star History Chart](https://api.star-history.com/svg?repos=VisionRush/DeepFakeDefenders&type=Date)](https://star-history.com/#VisionRush/DeepFakeDefenders&Date) 111 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 
1 | # DeepFake Defenders 2 | **如果您喜欢我们的项目,请在 GitHub 上给我们一个 Star ⭐ 以获取最新更新。** 3 | 4 | 5 | 6 | 7 | [![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/VisionRush/DeepFakeDefenders/blob/main/LICENSE) 8 | ![GitHub contributors](https://img.shields.io/github/contributors/VisionRush/DeepFakeDefenders) 9 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FVisionRush%2FDeepFakeDefenders&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitors&edge_flat=false)](https://hits.seeyoufarm.com) 10 | ![GitHub Repo stars](https://img.shields.io/github/stars/VisionRush/DeepFakeDefenders) 11 | [![GitHub issues](https://img.shields.io/github/issues/VisionRush/DeepFakeDefenders?color=critical&label=Issues)](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aopen+is%3Aissue) 12 | [![GitHub closed issues](https://img.shields.io/github/issues-closed/VisionRush/DeepFakeDefenders?color=success&label=Issues)](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aissue+is%3Aclosed) 13 | 14 | 15 | 16 | ![competition](images/competition_title.png) 17 | 18 | 
19 | 20 | 21 | 💡 我们在这里提供了 [[英文文档 / ENGLISH DOC](README.md)] 和 [[韩文文档 / KOREAN DOC](README_kr.md)],我们十分欢迎和感谢您能够对我们的项目提出建议和贡献。 22 | 23 | ## 📣 News 24 | 25 | * **[2024.09.05]** 🔥 我们正式发布了 DeepFake Defenders 的初始版本,并在 [[外滩大会](https://www.atecup.cn/deepfake)] 的 Deepfake 挑战赛中获得了三等奖。 26 | 27 | 28 | ## 🚀 快速开始 29 | ### 一、预训练模型准备 30 | 在开始使用之前,请将模型的ImageNet-1K预训练权重文件放置在`./pre_model`目录下,权重下载链接如下: 31 | ``` 32 | RepLKNet: https://drive.google.com/file/d/1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN/view?usp=sharing 33 | ConvNeXt: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth 34 | ``` 35 | 36 | ### 二、训练 37 | 38 | #### 1. 更改数据集路径 39 | 将训练所需的训练集txt文件、验证集txt文件以及标签txt文件分别放置在dataset文件夹下,并命名为相同的文件名(dataset下有各个txt示例) 40 | #### 2. 更改超参数 41 | 针对所采用的两个模型,在main_train.py中分别需要更改如下参数: 42 | ```python 43 | cfg.network.name = 'replknet'; cfg.train.batch_size = 16  # RepLKNet 44 | cfg.network.name = 'convnext'; cfg.train.batch_size = 24  # ConvNeXt 45 | ``` 46 | 47 | #### 3. 启动训练 48 | ##### 单机多卡训练:(8卡) 49 | ```shell 50 | bash main.sh 51 | ``` 52 | ##### 单机单卡训练: 53 | ```shell 54 | CUDA_VISIBLE_DEVICES=0 python main_train_single_gpu.py 55 | ``` 56 | 57 | #### 4. 模型融合 58 | 在merge.py中更改ConvNeXt模型路径以及RepLKNet模型路径,执行python merge.py后获取最终推理测试模型。 59 | 60 | #### 5. 推理 61 | 62 | 示例如下:通过 POST 请求接口发起请求,请求参数为图像路径,响应输出为模型预测的 deepfake 分数。 63 | 64 | ```python 65 | #!/usr/bin/env python 66 | # -*- coding:utf-8 -*- 67 | import requests 68 | import json 69 | 70 | 71 | header = { 72 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36' 73 | } 74 | 75 | url = 'http://ip:10005/inter_api' 76 | image_path = './dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg' 77 | data_map = {'img_path':image_path} 78 | response = requests.post(url, data=json.dumps(data_map), headers=header) 79 | content = response.content 80 | print(json.loads(content)) 81 | ``` 82 | 83 | ### 三、Docker 84 | #### 1. Docker构建 85 | sudo docker build -t vision-rush-image:1.0.1 --network host . 86 | #### 2. 容器启动
87 | sudo docker run -d --name vision_rush_image --gpus=all --net host vision-rush-image:1.0.1 88 | 89 | ## Star History 90 | 91 | [![Star History Chart](https://api.star-history.com/svg?repos=VisionRush/DeepFakeDefenders&type=Date)](https://star-history.com/#VisionRush/DeepFakeDefenders&Date) 92 | -------------------------------------------------------------------------------- /core/dsproc_mcls.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | from collections import OrderedDict 5 | from toolkit.dhelper import traverse_recursively 6 | import numpy as np 7 | import einops 8 | 9 | from torch import nn 10 | import timm 11 | import torch.nn.functional as F 12 | 13 | 14 | class SRMConv2d_simple(nn.Module): 15 | def __init__(self, inc=3): 16 | super(SRMConv2d_simple, self).__init__() 17 | self.truc = nn.Hardtanh(-3, 3) 18 | self.kernel = torch.from_numpy(self._build_kernel(inc)).float() 19 | 20 | def forward(self, x): 21 | out = F.conv2d(x, self.kernel, stride=1, padding=2) 22 | out = self.truc(out) 23 | 24 | return out 25 | 26 | def _build_kernel(self, inc): 27 | # filter1: KB 28 | filter1 = [[0, 0, 0, 0, 0], 29 | [0, -1, 2, -1, 0], 30 | [0, 2, -4, 2, 0], 31 | [0, -1, 2, -1, 0], 32 | [0, 0, 0, 0, 0]] 33 | # filter2: KV 34 | filter2 = [[-1, 2, -2, 2, -1], 35 | [2, -6, 8, -6, 2], 36 | [-2, 8, -12, 8, -2], 37 | [2, -6, 8, -6, 2], 38 | [-1, 2, -2, 2, -1]] 39 | # filter3: horizontal 2nd-order 40 | filter3 = [[0, 0, 0, 0, 0], 41 | [0, 0, 0, 0, 0], 42 | [0, 1, -2, 1, 0], 43 | [0, 0, 0, 0, 0], 44 | [0, 0, 0, 0, 0]] 45 | 46 | filter1 = np.asarray(filter1, dtype=float) / 4. 47 | filter2 = np.asarray(filter2, dtype=float) / 12. 48 | filter3 = np.asarray(filter3, dtype=float) / 2. 49 | # stack the filters 50 | filters = [[filter1], # , filter1, filter1], 51 | [filter2], # , filter2, filter2], 52 | [filter3]] # , filter3, filter3]] 53 | filters = np.array(filters) 54 | filters = np.repeat(filters, inc, axis=1) 55 | return filters 56 | 57 | 58 | class MultiClassificationProcessor(torch.utils.data.Dataset): 59 | 60 | def __init__(self, transform=None): 61 | self.transformer_ = transform 62 | self.extension_ = '.jpg .jpeg .png .bmp .webp .tif .eps' 63 | # load category info 64 | self.ctg_names_ = [] # ctg_idx to ctg_name 65 | self.ctg_name2idx_ = OrderedDict() # ctg_name to ctg_idx 66 | # load image infos 67 | self.img_names_ = [] # img_idx to img_name 68 | self.img_paths_ = [] # img_idx to img_path 69 | self.img_labels_ = [] # img_idx to img_label 70 | 71 | self.srm = SRMConv2d_simple() 72 | 73 | def load_data_from_dir(self, dataset_list): 74 | """Load image from folder. 75 | 76 | Args: 77 | dataset_list: dataset list, each folder is a category, format is [file_root]. 78 | """ 79 | # load sample 80 | for img_root in dataset_list: 81 | ctg_name = os.path.basename(img_root) 82 | self.ctg_name2idx_[ctg_name] = len(self.ctg_names_) 83 | self.ctg_names_.append(ctg_name) 84 | img_paths = [] 85 | traverse_recursively(img_root, img_paths, self.extension_) 86 | for img_path in img_paths: 87 | img_name = os.path.basename(img_path) 88 | self.img_names_.append(img_name) 89 | self.img_paths_.append(img_path) 90 | self.img_labels_.append(self.ctg_name2idx_[ctg_name]) 91 | print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, len(img_paths))) 92 | 93 | def load_data_from_txt(self, img_list_txt, ctg_list_txt): 94 | """Load image from txt. 
95 | 96 | Args: 97 | img_list_txt: image txt, format is [file_path, ctg_idx]. 98 | ctg_list_txt: category txt, format is [ctg_name, ctg_idx]. 99 | """ 100 | # check 101 | assert os.path.exists(img_list_txt), 'log: does not exist: {}'.format(img_list_txt) 102 | assert os.path.exists(ctg_list_txt), 'log: does not exist: {}'.format(ctg_list_txt) 103 | 104 | # load category 105 | # : open category info file 106 | with open(ctg_list_txt) as f: 107 | ctg_infos = [line.strip() for line in f.readlines()] 108 | # : load category name & category index 109 | for ctg_info in ctg_infos: 110 | tmp = ctg_info.split(' ') 111 | ctg_name = tmp[0] 112 | ctg_idx = int(tmp[-1]) 113 | self.ctg_name2idx_[ctg_name] = ctg_idx 114 | self.ctg_names_.append(ctg_name) 115 | 116 | # load sample 117 | # : open image info file 118 | with open(img_list_txt) as f: 119 | img_infos = [line.strip() for line in f.readlines()] 120 | # : load image path & category index 121 | for img_info in img_infos: 122 | tmp = img_info.split(' ') 123 | 124 | img_path = ' '.join(tmp[:-1]) 125 | img_name = img_path.split('/')[-1] 126 | ctg_idx = int(tmp[-1]) 127 | self.img_names_.append(img_name) 128 | self.img_paths_.append(img_path) 129 | self.img_labels_.append(ctg_idx) 130 | 131 | for ctg_name in self.ctg_names_: 132 | print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, self.img_labels_.count(self.ctg_name2idx_[ctg_name]))) 133 | 134 | def _add_new_channels_worker(self, image): 135 | new_channels = [] 136 | 137 | image = einops.rearrange(image, "h w c -> c h w") 138 | image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1) 139 | srm = self.srm(image.unsqueeze(0)).squeeze(0) 140 | new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy()) 141 | 142 | new_channels = np.concatenate(new_channels, axis=2) 143 | return torch.from_numpy(new_channels).float() 144 | 145 | def add_new_channels(self, images): 146 | images_copied = einops.rearrange(images, "c h w -> h w c") 147 | new_channels = self._add_new_channels_worker(images_copied) 148 | images_copied = torch.concatenate([images_copied, new_channels], dim=-1) 149 | images_copied = einops.rearrange(images_copied, "h w c -> c h w") 150 | 151 | return images_copied 152 | 153 | def __getitem__(self, index): 154 | img_path = self.img_paths_[index] 155 | img_label = self.img_labels_[index] 156 | 157 | img_data = Image.open(img_path).convert('RGB') 158 | img_size = img_data.size[::-1] # [h, w] 159 | 160 | if self.transformer_ is not None: 161 | img_data = self.transformer_[img_label](img_data) 162 | img_data = self.add_new_channels(img_data) 163 | 164 | return img_data, img_label, img_path, img_size 165 | 166 | def __len__(self): 167 | return len(self.img_names_) 168 | 
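A usage note on the dataset class above: `add_new_channels` appends the three SRM noise-residual maps to the RGB channels, so the networks consume 6-channel inputs. A minimal sketch, assuming `torchvision` is available and using the 512×512 size that the inference transforms in `main_infer.py` use:

```python
# Minimal sketch: build a 6-channel tensor (RGB + 3 SRM residual maps)
# via MultiClassificationProcessor.add_new_channels.
from PIL import Image
from torchvision import transforms
from core.dsproc_mcls import MultiClassificationProcessor

to_tensor = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
proc = MultiClassificationProcessor(transform=None)

img = Image.open('./dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg').convert('RGB')
rgb = to_tensor(img)                 # shape [3, 512, 512]
six_ch = proc.add_new_channels(rgb)  # shape [6, 512, 512]
print(six_ch.shape)
```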
-------------------------------------------------------------------------------- /core/dsproc_mclsmfolder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | from collections import OrderedDict 5 | from toolkit.dhelper import traverse_recursively 6 | import random 7 | from torch import nn 8 | import numpy as np 9 | import timm 10 | import einops 11 | import torch.nn.functional as F 12 | 13 | 14 | class SRMConv2d_simple(nn.Module): 15 | def __init__(self, inc=3): 16 | super(SRMConv2d_simple, self).__init__() 17 | self.truc = nn.Hardtanh(-3, 3) 18 | self.kernel = torch.from_numpy(self._build_kernel(inc)).float() 19 | 20 | def forward(self, x): 21 | out = F.conv2d(x, self.kernel, stride=1, padding=2) 22 | out = self.truc(out) 23 | 24 | return out 25 | 26 | def _build_kernel(self, inc): 27 | # filter1: KB 28 | filter1 = [[0, 0, 0, 0, 0], 29 | [0, -1, 2, -1, 0], 30 | [0, 2, -4, 2, 0], 31 | [0, -1, 2, -1, 0], 32 | [0, 0, 0, 0, 0]] 33 | # filter2: KV 34 | filter2 = [[-1, 2, -2, 2, -1], 35 | [2, -6, 8, -6, 2], 36 | [-2, 8, -12, 8, -2], 37 | [2, -6, 8, -6, 2], 38 | [-1, 2, -2, 2, -1]] 39 | # filter3: horizontal 2nd-order 40 | filter3 = [[0, 0, 0, 0, 0], 41 | [0, 0, 0, 0, 0], 42 | [0, 1, -2, 1, 0], 43 | [0, 0, 0, 0, 0], 44 | [0, 0, 0, 0, 0]] 45 | 46 | filter1 = np.asarray(filter1, dtype=float) / 4. 47 | filter2 = np.asarray(filter2, dtype=float) / 12. 48 | filter3 = np.asarray(filter3, dtype=float) / 2. 49 | # stack the filters 50 | filters = [[filter1], # , filter1, filter1], 51 | [filter2], # , filter2, filter2], 52 | [filter3]] # , filter3, filter3]] 53 | filters = np.array(filters) 54 | filters = np.repeat(filters, inc, axis=1) 55 | return filters 56 | 57 | 58 | class MultiClassificationProcessor_mfolder(torch.utils.data.Dataset): 59 | def __init__(self, transform=None): 60 | self.transformer_ = transform 61 | self.extension_ = '.jpg .jpeg .png .bmp .webp .tif .eps' 62 | # load category info 63 | self.ctg_names_ = [] # ctg_idx to ctg_name 64 | self.ctg_name2idx_ = OrderedDict() # ctg_name to ctg_idx 65 | # load image infos 66 | self.img_names_ = [] # img_idx to img_name 67 | self.img_paths_ = [] # img_idx to img_path 68 | self.img_labels_ = [] # img_idx to img_label 69 | 70 | self.srm = SRMConv2d_simple() 71 | 72 | def load_data_from_dir_test(self, folders): 73 | 74 | # Load images from a folder (test mode, no labels). 75 | 76 | # Args: 77 | # folders: image root folder, traversed recursively. 78 | print(folders) 79 | img_paths = [] 80 | traverse_recursively(folders, img_paths, self.extension_) 81 | 82 | for img_path in img_paths: 83 | img_name = os.path.basename(img_path) 84 | self.img_names_.append(img_name) 85 | self.img_paths_.append(img_path) 86 | 87 | length = len(img_paths) 88 | print('log: {} image num is {}'.format(folders, length)) 89 | 90 | def load_data_from_dir(self, dataset_list): 91 | 92 | # Load images from folders. 93 | 94 | # Args: 95 | # dataset_list: dictionary where key is a label and value is a list of folder paths. 96 | 97 | for ctg_name, folders in dataset_list.items(): 98 | 99 | if ctg_name not in self.ctg_name2idx_: 100 | self.ctg_name2idx_[ctg_name] = len(self.ctg_names_) 101 | self.ctg_names_.append(ctg_name) 102 | 103 | for img_root in folders: 104 | img_paths = [] 105 | traverse_recursively(img_root, img_paths, self.extension_) 106 | 107 | print(img_root) 108 | 109 | length = len(img_paths) 110 | for i in range(length): 111 | img_path = img_paths[i] 112 | img_name = os.path.basename(img_path) 113 | self.img_names_.append(img_name) 114 | self.img_paths_.append(img_path) 115 | self.img_labels_.append(self.ctg_name2idx_[ctg_name]) 116 | 117 | print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, length)) 118 | 119 | def load_data_from_txt(self, img_list_txt, ctg_list_txt): 120 | """Load image from txt. 121 | 122 | Args: 123 | img_list_txt: image txt, format is [file_path, ctg_idx]. 124 | ctg_list_txt: category txt, format is [ctg_name, ctg_idx]. 
125 | """ 126 | # check 127 | assert os.path.exists(img_list_txt), 'log: does not exist: {}'.format(img_list_txt) 128 | assert os.path.exists(ctg_list_txt), 'log: does not exist: {}'.format(ctg_list_txt) 129 | 130 | # load category 131 | # : open category info file 132 | with open(ctg_list_txt) as f: 133 | ctg_infos = [line.strip() for line in f.readlines()] 134 | # :load category name & category index 135 | for ctg_info in ctg_infos: 136 | tmp = ctg_info.split(' ') 137 | ctg_name = tmp[0] 138 | ctg_idx = int(tmp[1]) 139 | self.ctg_name2idx_[ctg_name] = ctg_idx 140 | self.ctg_names_.append(ctg_name) 141 | 142 | # load sample 143 | # : open image info file 144 | with open(img_list_txt) as f: 145 | img_infos = [line.strip() for line in f.readlines()] 146 | random.shuffle(img_infos) 147 | # : load image path & category index 148 | for img_info in img_infos: 149 | img_path, ctg_name = img_info.rsplit(' ', 1) 150 | img_name = img_path.split('/')[-1] 151 | ctg_idx = int(ctg_name) 152 | self.img_names_.append(img_name) 153 | self.img_paths_.append(img_path) 154 | self.img_labels_.append(ctg_idx) 155 | 156 | for ctg_name in self.ctg_names_: 157 | print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, self.img_labels_.count(self.ctg_name2idx_[ctg_name]))) 158 | 159 | def _add_new_channels_worker(self, image): 160 | new_channels = [] 161 | 162 | image = einops.rearrange(image, "h w c -> c h w") 163 | image = (image- torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1) 164 | srm = self.srm(image.unsqueeze(0)).squeeze(0) 165 | new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy()) 166 | 167 | new_channels = np.concatenate(new_channels, axis=2) 168 | return torch.from_numpy(new_channels).float() 169 | 170 | def add_new_channels(self, images): 171 | images_copied = einops.rearrange(images, "c h w -> h w c") 172 | new_channels = self._add_new_channels_worker(images_copied) 173 | images_copied = torch.concatenate([images_copied, new_channels], dim=-1) 174 | images_copied = einops.rearrange(images_copied, "h w c -> c h w") 175 | 176 | return images_copied 177 | 178 | def __getitem__(self, index): 179 | img_path = self.img_paths_[index] 180 | 181 | img_data = Image.open(img_path).convert('RGB') 182 | img_size = img_data.size[::-1] # [h, w] 183 | 184 | all_data = [] 185 | for transform in self.transformer_: 186 | current_data = transform(img_data) 187 | current_data = self.add_new_channels(current_data) 188 | all_data.append(current_data) 189 | img_label = self.img_labels_[index] 190 | 191 | return torch.stack(all_data, dim=0), img_label, img_path, img_size 192 | 193 | def __len__(self): 194 | return len(self.img_names_) 195 | -------------------------------------------------------------------------------- /core/mengine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import sys 4 | 5 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from tqdm import tqdm 10 | from toolkit.cmetric import MultiClassificationMetric, MultilabelClassificationMetric, simple_accuracy 11 | from toolkit.chelper import load_model 12 | from torch import distributed as dist 13 | from sklearn.metrics import roc_auc_score 14 | import numpy as np 15 | import time 16 | 17 | 18 | 
def reduce_tensor(tensor, n): 19 | rt = tensor.clone() 20 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 21 | rt /= n 22 | return rt 23 | 24 | 25 | def gather_tensor(tensor, n): 26 | rt = [torch.zeros_like(tensor) for _ in range(n)] 27 | dist.all_gather(rt, tensor) 28 | return torch.cat(rt, dim=0) 29 | 30 | 31 | class TrainEngine(object): 32 | def __init__(self, local_rank, world_size=0, DDP=False, SyncBatchNorm=False): 33 | # init setting 34 | self.local_rank = local_rank 35 | self.world_size = world_size 36 | self.device_ = f'cuda:{local_rank}' 37 | # create tools 38 | self.cls_meter_ = MultilabelClassificationMetric() 39 | self.loss_meter_ = MultiClassificationMetric() 40 | self.top1_meter_ = MultiClassificationMetric() 41 | self.DDP = DDP 42 | self.SyncBN = SyncBatchNorm 43 | 44 | def create_env(self, cfg): 45 | # create network 46 | self.netloc_ = load_model(cfg.network.name, cfg.network.class_num, self.SyncBN) 47 | print(self.netloc_) 48 | 49 | self.netloc_.cuda() 50 | if self.DDP: 51 | if self.SyncBN: 52 | self.netloc_ = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.netloc_) 53 | self.netloc_ = DDP(self.netloc_, 54 | device_ids=[self.local_rank], 55 | broadcast_buffers=True, 56 | ) 57 | 58 | # create loss function 59 | self.criterion_ = nn.CrossEntropyLoss().cuda() 60 | 61 | # create optimizer 62 | self.optimizer_ = torch.optim.AdamW(self.netloc_.parameters(), lr=cfg.optimizer.lr, 63 | betas=(cfg.optimizer.beta1, cfg.optimizer.beta2), eps=cfg.optimizer.eps, 64 | weight_decay=cfg.optimizer.weight_decay) 65 | 66 | # create scheduler 67 | self.scheduler_ = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_, cfg.train.epoch_num, 68 | eta_min=cfg.scheduler.min_lr) 69 | 70 | def train_multi_class(self, train_loader, epoch_idx, ema_start): 71 | starttime = datetime.datetime.now() 72 | # switch to train mode 73 | self.netloc_.train() 74 | self.loss_meter_.reset() 75 | self.top1_meter_.reset() 76 | # train 77 | train_loader = tqdm(train_loader, desc='train', ascii=True) 78 | for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(train_loader): 79 | # set cuda 80 | imgs_tensor = imgs_tensor.cuda() # [256, 3, 224, 224] 81 | imgs_label = imgs_label.cuda() 82 | # clear gradients(zero the parameter gradients) 83 | self.optimizer_.zero_grad() 84 | # calc forward 85 | preds = self.netloc_(imgs_tensor) 86 | # calc acc & loss 87 | loss = self.criterion_(preds, imgs_label) 88 | 89 | # backpropagation 90 | loss.backward() 91 | # update parameters 92 | self.optimizer_.step() 93 | 94 | # EMA update (note: self.ema_model is expected to be attached by the training script) 95 | if ema_start: 96 | self.ema_model.update(self.netloc_) 97 | 98 | # accumulate loss & acc 99 | acc1 = simple_accuracy(preds, imgs_label) 100 | if self.DDP: 101 | loss = reduce_tensor(loss, self.world_size) 102 | acc1 = reduce_tensor(acc1, self.world_size) 103 | self.loss_meter_.update(loss.data.item()) 104 | self.top1_meter_.update(acc1.item()) 105 | 106 | # eval 107 | top1 = self.top1_meter_.mean 108 | loss = self.loss_meter_.mean 109 | endtime = datetime.datetime.now() 110 | self.lr_ = self.optimizer_.param_groups[0]['lr'] 111 | if self.local_rank == 0: 112 | print('log: epoch-%d, train_top1 is %f, train_loss is %f, lr is %f, time is %d' % ( 113 | epoch_idx, top1, loss, self.lr_, (endtime - starttime).seconds)) 114 | # return 115 | return top1, loss, self.lr_ 116 | 117 | def val_multi_class(self, val_loader, epoch_idx): 118 | np.set_printoptions(suppress=True) 119 | starttime = datetime.datetime.now() 120 | # switch to eval mode 121 | self.netloc_.eval() 122 | self.loss_meter_.reset() 123 | 
self.top1_meter_.reset() 124 | self.all_probs = [] 125 | self.all_labels = [] 126 | # eval 127 | with torch.no_grad(): 128 | val_loader = tqdm(val_loader, desc='valid', ascii=True) 129 | for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(val_loader): 130 | # set cuda 131 | imgs_tensor = imgs_tensor.cuda() 132 | imgs_label = imgs_label.cuda() 133 | # calc forward 134 | preds = self.netloc_(imgs_tensor) 135 | # calc acc & loss 136 | loss = self.criterion_(preds, imgs_label) 137 | # accumulate loss & acc 138 | acc1 = simple_accuracy(preds, imgs_label) 139 | 140 | outputs_scores = nn.functional.softmax(preds, dim=1) 141 | outputs_scores = torch.cat((outputs_scores, imgs_label.unsqueeze(-1)), dim=-1) 142 | 143 | if self.DDP: 144 | loss = reduce_tensor(loss, self.world_size) 145 | acc1 = reduce_tensor(acc1, self.world_size) 146 | outputs_scores = gather_tensor(outputs_scores, self.world_size) 147 | 148 | outputs_scores, label = outputs_scores[:, -2], outputs_scores[:, -1] 149 | self.all_probs += [float(i) for i in outputs_scores] 150 | self.all_labels += [float(i) for i in label] 151 | self.loss_meter_.update(loss.item()) 152 | self.top1_meter_.update(acc1.item()) 153 | # eval 154 | top1 = self.top1_meter_.mean 155 | loss = self.loss_meter_.mean 156 | auc = roc_auc_score(self.all_labels, self.all_probs) 157 | 158 | endtime = datetime.datetime.now() 159 | if self.local_rank == 0: 160 | print('log: epoch-%d, val_top1 is %f, val_loss is %f, auc is %f, time is %d' % ( 161 | epoch_idx, top1, loss, auc, (endtime - starttime).seconds)) 162 | 163 | # update lr (stepped once per epoch, after validation) 164 | self.scheduler_.step() 165 | 166 | # return 167 | return top1, loss, auc 168 | 169 | def val_ema(self, val_loader, epoch_idx): 170 | np.set_printoptions(suppress=True) 171 | starttime = datetime.datetime.now() 172 | # switch to eval mode 173 | self.ema_model.module.eval() 174 | self.loss_meter_.reset() 175 | self.top1_meter_.reset() 176 | self.all_probs = [] 177 | self.all_labels = [] 178 | # eval 179 | with torch.no_grad(): 180 | val_loader = tqdm(val_loader, desc='valid', ascii=True) 181 | for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(val_loader): 182 | # set cuda 183 | imgs_tensor = imgs_tensor.cuda() 184 | imgs_label = imgs_label.cuda() 185 | # calc forward 186 | preds = self.ema_model.module(imgs_tensor) 187 | 188 | # calc acc & loss 189 | loss = self.criterion_(preds, imgs_label) 190 | # accumulate loss & acc 191 | acc1 = simple_accuracy(preds, imgs_label) 192 | 193 | outputs_scores = nn.functional.softmax(preds, dim=1) 194 | outputs_scores = torch.cat((outputs_scores, imgs_label.unsqueeze(-1)), dim=-1) 195 | 196 | if self.DDP: 197 | loss = reduce_tensor(loss, self.world_size) 198 | acc1 = reduce_tensor(acc1, self.world_size) 199 | outputs_scores = gather_tensor(outputs_scores, self.world_size) 200 | 201 | outputs_scores, label = outputs_scores[:, -2], outputs_scores[:, -1] 202 | self.all_probs += [float(i) for i in outputs_scores] 203 | self.all_labels += [float(i) for i in label] 204 | self.loss_meter_.update(loss.item()) 205 | self.top1_meter_.update(acc1.item()) 206 | # eval 207 | top1 = self.top1_meter_.mean 208 | loss = self.loss_meter_.mean 209 | auc = roc_auc_score(self.all_labels, self.all_probs) 210 | 211 | endtime = datetime.datetime.now() 212 | if self.local_rank == 0: 213 | print('log: epoch-%d, ema_val_top1 is %f, ema_val_loss is %f, ema_auc is %f, time is %d' % ( 214 | epoch_idx, top1, loss, auc, (endtime - starttime).seconds)) 215 | 216 | # return 217 | return top1, loss, auc 218 | 219 | 
def save_checkpoint(self, file_root, epoch_idx, train_map, val_map, ema_start): 220 | 221 | file_name = os.path.join(file_root, 222 | time.strftime('%Y%m%d-%H-%M', time.localtime()) + '-' + str(epoch_idx) + '.pth') 223 | 224 | if self.DDP: 225 | state_dict_ = self.netloc_.module.state_dict() 226 | else: 227 | state_dict_ = self.netloc_.state_dict() 228 | 229 | torch.save( 230 | { 231 | 'epoch_idx': epoch_idx, 232 | 'state_dict': state_dict_, 233 | 'train_map': train_map, 234 | 'val_map': val_map, 235 | 'lr': self.lr_, 236 | 'optimizer': self.optimizer_.state_dict(), 237 | 'scheduler': self.scheduler_.state_dict() 238 | }, file_name) 239 | 240 | if ema_start: 241 | ema_file_name = os.path.join(file_root, 242 | time.strftime('%Y%m%d-%H-%M', time.localtime()) + '-EMA-' + str(epoch_idx) + '.pth') 243 | ema_state_dict = self.ema_model.module.module.state_dict() 244 | torch.save( 245 | { 246 | 'epoch_idx': epoch_idx, 247 | 'state_dict': ema_state_dict, 248 | 'train_map': train_map, 249 | 'val_map': val_map, 250 | 'lr': self.lr_, 251 | 'optimizer': self.optimizer_.state_dict(), 252 | 'scheduler': self.scheduler_.state_dict() 253 | }, ema_file_name) 254 | 
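A usage note on the checkpoints written by `save_checkpoint` above: each file bundles the model weights with optimizer and scheduler state, so a run can be resumed or evaluated from a single `.pth`. A minimal restore sketch (the path is a placeholder; the network name and class count must match what was used at training time):

```python
# Hedged sketch: restore a checkpoint produced by TrainEngine.save_checkpoint
# (single-GPU, non-DDP case; the .pth path is a placeholder).
import torch
from toolkit.chelper import load_model

ckpt = torch.load('path/to/20240905-12-00-10.pth', map_location='cpu')
model = load_model('replknet', 2)        # must match cfg.network.name / cfg.network.class_num
model.load_state_dict(ckpt['state_dict'])
start_epoch = ckpt['epoch_idx'] + 1
# optimizer / scheduler state are also available under ckpt['optimizer'] and ckpt['scheduler']
```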
-------------------------------------------------------------------------------- /dataset/label.txt: -------------------------------------------------------------------------------- 1 | real 0 2 | fake 1 3 | -------------------------------------------------------------------------------- /dataset/train.txt: -------------------------------------------------------------------------------- 1 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/b580b1fc51d19fc25d2969de07669c21.jpg 0 2 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/df36afc7a12cf840a961743e08bdd596.jpg 1 3 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/f9cec5f76c7f653c2f57d66d7b4ecee0.jpg 1 4 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/a81dc092765e18f3e343b78418cf9371.jpg 1 5 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/56461dd348c9434f44dc810fd06a640e.jpg 1 6 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/1124b822bb0f1076a9914aa454bbd65f.jpg 1 7 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/2fa1aae309a57e975c90285001d43982.jpg 1 8 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/921604adb7ff623bd2fe32d454a1469c.jpg 0 9 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/773f7e1488a29cc52c52b154250df907.jpg 1 10 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/371ad468133750a5fdc670063a6b115a.jpg 1 11 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/780821e5db83764213aae04ac5a54671.jpg 1 12 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/39c253b508dea029854a01de7a1389ab.jpg 0 13 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/9726ea54c28b55e38a5c7cf2fbd8d9da.jpg 1 14 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/4112df80cf4849d05196dc23ecf994cd.jpg 1 15 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/f9858bf9cb1316c273d272249b725912.jpg 0 16 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/bcb50b7c399f978aeb5432c9d80d855c.jpg 0 17 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/2ffd8043985f407069d77dfaae68e032.jpg 1 18 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0bfd7972fae0bc19f6087fc3b5ac6db8.jpg 1 19 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/daf90a7842ff5bd486ec10fbed22e932.jpg 1 20 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/9dbebbfbc11e8b757b090f5a5ad3fa48.jpg 1 21 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/b8d3d8c2c6cac9fb5b485b94e553f432.jpg 1 22 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0a59fe7481dc0f9a7dc76cb0bdd3ffe6.jpg 1 23 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/5b82f90800df81625ac78e51f51f1b2e.jpg 1 24 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/badd574c91e6180ef829e2b0d67a3efb.jpg 0 25 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/7412c839b06db42aac1e022096b08031.jpg 1 26 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/81e4b3e7ce314bcd28dde338caeda836.jpg 0 27 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/aa87a563062e6b0741936609014329ab.jpg 0 28 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0a4c5fdcbe7a3dca6c5a9ee45fd32bef.jpg 1 29 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/adfa3e7ea00ca1ce7a603a297c9ae701.jpg 1 30 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/31fcc0e2f049347b7220dd9eb4f66631.jpg 1 31 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/e699df9505f47dcbb1dcef6858f921e7.jpg 1 32 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/71e7824486a7fef83fa60324dd1cbba8.jpg 1 33 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/ed25cbc58d4f41f7c97201b1ba959834.jpg 1 34 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/4b3d2176926766a4c0e259605dbbc67a.jpg 0 35 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/1dfd77d7ea1a1f05b9c2f532b2a91c62.jpg 1 36 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/8e6bea47a8dd71c09c0272be5e1ca584.jpg 1 -------------------------------------------------------------------------------- /dataset/val.txt: -------------------------------------------------------------------------------- 1 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/590e6bc87984f2b4e6d1ed6d4e889088.jpg 1 2 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/720f2234a138382af10b3e2bb6c373cd.jpg 1 3 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/01cb2d00e5d2412ce3cd1d1bb58d7d4e.jpg 1 4 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/41d70d6650eba9036cbb145b29ad14f7.jpg 1 5 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/f7f4df6525cdf0ec27f8f40e2e980ad6.jpg 0 6 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/1dddd03ae6911514a6f1d3117e7e3fd3.jpg 1 7 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/d33054b233cb2e0ebddbe63611626924.jpg 0 8 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/27f2e00bd12d11173422119dfad885ef.jpg 0 9 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/1a0cb2060fbc2065f2ba74f5b2833bc5.jpg 0 10 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/7e0668030bb9a6598621cc7f12600660.jpg 1 11 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/4d7548156c06f9ab12d6daa6524956ea.jpg 1 12 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/cb6a567da3e2f0bcfd19f81756242ba1.jpg 1 13 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/fbff80c8dddf176f310fc10748ce5796.jpg 0 14 | 
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/d68dce56f306f7b0965329f2389b2d5a.jpg 1 15 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/610198886f92d595aaf7cd5c83521ccb.jpg 1 16 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/987a546ad4b3fb76552a89af9b8f5761.jpg 1 17 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/db80dfbe1bb84fe1f9c3e1f21f80561b.jpg 0 18 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/133c775e0516b078f2b951fe49d6b04a.jpg 1 19 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/9584c3c8e012f92b003498793a8a6492.jpg 1 20 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg 0 21 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/965c7d35e7a714603587a4710c357ede.jpg 1 22 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/7db2752f0d45637ff64e67f14099378e.jpg 1 23 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/cd9838425bb7e68f165b25a148ba8146.jpg 1 24 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/88f45da6e89e59842a9e6339d239a78f.jpg 1 -------------------------------------------------------------------------------- /dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionRush/DeepFakeDefenders/13117ecf91c215a167126a6962fdd1525f7c957e/dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg -------------------------------------------------------------------------------- /images/competition_title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionRush/DeepFakeDefenders/13117ecf91c215a167126a6962fdd1525f7c957e/images/competition_title.png -------------------------------------------------------------------------------- /infer_api.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | from fastapi import FastAPI, Body 3 | from pydantic import BaseModel, Field 4 | import sys 5 | import os 6 | import json 7 | from main_infer import INFER_API 8 | 9 | 10 | infer_api = INFER_API() 11 | 12 | # create FastAPI instance 13 | app = FastAPI() 14 | 15 | 16 | class inputModel(BaseModel): 17 | img_path: str = Field(..., description="image path", examples=[""]) 18 | 19 | # Model-inference endpoint (POST request) 20 | @app.post("/inter_api") 21 | def inter_api(input_model: inputModel): 22 | img_path = input_model.img_path 23 | # INFER_API is a singleton, so reuse the module-level instance 24 | score = infer_api.test(img_path) 25 | return score 26 | 27 | 28 | # run 29 | if __name__ == '__main__': 30 | uvicorn.run(app='infer_api:app', 31 | host='0.0.0.0', 32 | port=10005, 33 | reload=False, 34 | workers=1 35 | ) 36 | 
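Besides the HTTP route above, the scorer can also be called in-process: `INFER_API` is the singleton defined in `main_infer.py`, and its `test()` method is the same entry point `inter_api` invokes. A minimal sketch, assuming the merged weights at `./final_model_csv/final_model.pth` are present:

```python
# Hedged sketch: score an image directly, without starting the FastAPI server.
from main_infer import INFER_API

infer = INFER_API()
score = infer.test('./dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg')
print(score)  # deepfake score predicted by the merged model
```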
create_transforms_inference2,\ 9 | create_transforms_inference3,\ 10 | create_transforms_inference4,\ 11 | create_transforms_inference5 12 | from toolkit.chelper import load_model 13 | import torch.nn.functional as F 14 | 15 | 16 | def extract_model_from_pth(params_path, net_model): 17 | checkpoint = torch.load(params_path) 18 | state_dict = checkpoint['state_dict'] 19 | 20 | net_model.load_state_dict(state_dict, strict=True) 21 | 22 | return net_model 23 | 24 | 25 | class SRMConv2d_simple(nn.Module): 26 | def __init__(self, inc=3): 27 | super(SRMConv2d_simple, self).__init__() 28 | self.truc = nn.Hardtanh(-3, 3) 29 | self.kernel = torch.from_numpy(self._build_kernel(inc)).float() 30 | 31 | def forward(self, x): 32 | out = F.conv2d(x, self.kernel, stride=1, padding=2) 33 | out = self.truc(out) 34 | 35 | return out 36 | 37 | def _build_kernel(self, inc): 38 | # filter1: KB 39 | filter1 = [[0, 0, 0, 0, 0], 40 | [0, -1, 2, -1, 0], 41 | [0, 2, -4, 2, 0], 42 | [0, -1, 2, -1, 0], 43 | [0, 0, 0, 0, 0]] 44 | # filter2: KV 45 | filter2 = [[-1, 2, -2, 2, -1], 46 | [2, -6, 8, -6, 2], 47 | [-2, 8, -12, 8, -2], 48 | [2, -6, 8, -6, 2], 49 | [-1, 2, -2, 2, -1]] 50 | # filter3: horizontal 2nd-order 51 | filter3 = [[0, 0, 0, 0, 0], 52 | [0, 0, 0, 0, 0], 53 | [0, 1, -2, 1, 0], 54 | [0, 0, 0, 0, 0], 55 | [0, 0, 0, 0, 0]] 56 | 57 | filter1 = np.asarray(filter1, dtype=float) / 4. 58 | filter2 = np.asarray(filter2, dtype=float) / 12. 59 | filter3 = np.asarray(filter3, dtype=float) / 2. 60 | # stack the filters 61 | filters = [[filter1], # , filter1, filter1], 62 | [filter2], # , filter2, filter2], 63 | [filter3]] # , filter3, filter3]] 64 | filters = np.array(filters) 65 | filters = np.repeat(filters, inc, axis=1) 66 | return filters 67 | 68 | 69 | class INFER_API: 70 | 71 | _instance = None 72 | 73 | def __new__(cls): 74 | if cls._instance is None: 75 | cls._instance = super(INFER_API, cls).__new__(cls) 76 | cls._instance.initialize() 77 | return cls._instance 78 | 79 | def initialize(self): 80 | self.transformer_ = [create_transforms_inference(h=512, w=512), 81 | create_transforms_inference1(h=512, w=512), 82 | create_transforms_inference2(h=512, w=512), 83 | create_transforms_inference3(h=512, w=512), 84 | create_transforms_inference4(h=512, w=512), 85 | create_transforms_inference5(h=512, w=512)] 86 | self.srm = SRMConv2d_simple() 87 | 88 | # model init 89 | self.model = load_model('all', 2) 90 | model_path = './final_model_csv/final_model.pth' 91 | self.model = extract_model_from_pth(model_path, self.model) 92 | 93 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 94 | self.model = self.model.to(device) 95 | 96 | self.model.eval() 97 | 98 | def _add_new_channels_worker(self, image): 99 | new_channels = [] 100 | 101 | image = einops.rearrange(image, "h w c -> c h w") 102 | image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor( 103 | timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1) 104 | srm = self.srm(image.unsqueeze(0)).squeeze(0) 105 | new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy()) 106 | 107 | new_channels = np.concatenate(new_channels, axis=2) 108 | return torch.from_numpy(new_channels).float() 109 | 110 | def add_new_channels(self, images): 111 | images_copied = einops.rearrange(images, "c h w -> h w c") 112 | new_channels = self._add_new_channels_worker(images_copied) 113 | images_copied = torch.concatenate([images_copied, new_channels], dim=-1) 114 | images_copied = einops.rearrange(images_copied, "h w c -> c
h w") 115 | 116 | return images_copied 117 | 118 | def test(self, img_path): 119 | # img load 120 | img_data = Image.open(img_path).convert('RGB') 121 | 122 | # transform 123 | all_data = [] 124 | for transform in self.transformer_: 125 | current_data = transform(img_data) 126 | current_data = self.add_new_channels(current_data) 127 | all_data.append(current_data) 128 | img_tensor = torch.stack(all_data, dim=0).unsqueeze(0).cuda() 129 | 130 | preds = self.model(img_tensor) 131 | 132 | return round(float(preds), 20) 133 | 134 | 135 | def main(): 136 | img = '51aa9b8d0da890cd1d0c5029e3d89e3c.jpg' 137 | infer_api = INFER_API() 138 | print(infer_api.test(img)) 139 | 140 | 141 | if __name__ == '__main__': 142 | main() -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import datetime 4 | import torch 5 | import sys 6 | 7 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | from torch.utils.tensorboard import SummaryWriter 9 | from core.dsproc_mcls import MultiClassificationProcessor 10 | from core.mengine import TrainEngine 11 | from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train 12 | from toolkit.yacs import CfgNode as CN 13 | from timm.utils import ModelEmaV3 14 | 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | 18 | # check 19 | print(torch.__version__) 20 | print(torch.cuda.is_available()) 21 | 22 | # init 23 | cfg = CN(new_allowed=True) 24 | 25 | # dataset dir 26 | ctg_list = './dataset/label.txt' 27 | train_list = './dataset/train.txt' 28 | val_list = './dataset/val.txt' 29 | 30 | # : network 31 | cfg.network = CN(new_allowed=True) 32 | cfg.network.name = 'replknet' 33 | cfg.network.class_num = 2 34 | cfg.network.input_size = 384 35 | 36 | # : train params 37 | mean = (0.485, 0.456, 0.406) 38 | std = (0.229, 0.224, 0.225) 39 | 40 | cfg.train = CN(new_allowed=True) 41 | cfg.train.resume = False 42 | cfg.train.resume_path = '' 43 | cfg.train.params_path = '' 44 | cfg.train.batch_size = 16 45 | cfg.train.epoch_num = 20 46 | cfg.train.epoch_start = 0 47 | cfg.train.worker_num = 8 48 | 49 | # : optimizer params 50 | cfg.optimizer = CN(new_allowed=True) 51 | cfg.optimizer.lr = 1e-4 * 1 52 | cfg.optimizer.weight_decay = 1e-2 53 | cfg.optimizer.momentum = 0.9 54 | cfg.optimizer.beta1 = 0.9 55 | cfg.optimizer.beta2 = 0.999 56 | cfg.optimizer.eps = 1e-8 57 | 58 | # : scheduler params 59 | cfg.scheduler = CN(new_allowed=True) 60 | cfg.scheduler.min_lr = 1e-6 61 | 62 | # DDP init 63 | local_rank = int(os.environ['LOCAL_RANK']) 64 | device = 'cuda:{}'.format(local_rank) 65 | torch.cuda.set_device(local_rank) 66 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 67 | world_size = torch.distributed.get_world_size() 68 | rank = torch.distributed.get_rank() 69 | 70 | # init path 71 | task = 'competition' 72 | log_root = 'output/' + datetime.datetime.now().strftime("%Y-%m-%d") + '-' + time.strftime( 73 | "%H-%M-%S") + '_' + cfg.network.name + '_' + f"to_{task}_BinClass" 74 | if local_rank == 0: 75 | if not os.path.exists(log_root): 76 | os.makedirs(log_root) 77 | writer = SummaryWriter(log_root) 78 | 79 | # create engine 80 | train_engine = TrainEngine(local_rank, world_size, DDP=True, SyncBatchNorm=True) 81 | train_engine.create_env(cfg) 82 | 83 | # create transforms 84 | transforms_dict ={ 85 | 0 : 
transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size)), 86 | 1 : transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size), jpeg_compression=1), 87 | } 88 | 89 | transforms_dict_test ={ 90 | 0: create_transforms_inference(h=512, w=512), 91 | 1: create_transforms_inference(h=512, w=512), 92 | } 93 | 94 | transform = transforms_dict 95 | transform_test = transforms_dict_test 96 | 97 | # create dataset 98 | trainset = MultiClassificationProcessor(transform) 99 | trainset.load_data_from_txt(train_list, ctg_list) 100 | 101 | valset = MultiClassificationProcessor(transform_test) 102 | valset.load_data_from_txt(val_list, ctg_list) 103 | 104 | train_sampler = torch.utils.data.distributed.DistributedSampler(trainset) 105 | val_sampler = torch.utils.data.distributed.DistributedSampler(valset) 106 | 107 | # create dataloader 108 | train_loader = torch.utils.data.DataLoader(dataset=trainset, 109 | batch_size=cfg.train.batch_size, 110 | sampler=train_sampler, 111 | num_workers=cfg.train.worker_num, 112 | pin_memory=True, 113 | drop_last=True) 114 | 115 | val_loader = torch.utils.data.DataLoader(dataset=valset, 116 | batch_size=cfg.train.batch_size, 117 | sampler=val_sampler, 118 | num_workers=cfg.train.worker_num, 119 | pin_memory=True, 120 | drop_last=False) 121 | 122 | train_log_txtFile = log_root + "/" + "train_log.txt" 123 | f_open = open(train_log_txtFile, "w") 124 | 125 | # train & Val & Test 126 | best_test_mAP = 0.0 127 | best_test_idx = 0.0 128 | ema_start = True 129 | train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda() 130 | for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num): 131 | # train 132 | train_top1, train_loss, train_lr = train_engine.train_multi_class(train_loader=train_loader, epoch_idx=epoch_idx, ema_start=ema_start) 133 | # val 134 | val_top1, val_loss, val_auc = train_engine.val_multi_class(val_loader=val_loader, epoch_idx=epoch_idx) 135 | # ema_val 136 | if ema_start: 137 | ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(val_loader=val_loader, epoch_idx=epoch_idx) 138 | 139 | # check mAP and save 140 | if local_rank == 0: 141 | train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start) 142 | 143 | if ema_start: 144 | outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n" 145 | else: 146 | outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n" 147 | 148 | print(outInfo) 149 | 150 | f_open.write(outInfo) 151 | f_open.flush() 152 | 153 | # curve all mAP & mLoss 154 | writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx) 155 | writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx) 156 | 157 | # curve lr 158 | writer.add_scalar('train_lr', train_lr, epoch_idx) 159 | -------------------------------------------------------------------------------- /main_train_single_gpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import datetime 4 | import torch 5 | import sys 6 | 7 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | from torch.utils.tensorboard import SummaryWriter 9 | from core.dsproc_mcls import MultiClassificationProcessor 10 | from 
core.mengine import TrainEngine 11 | from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train 12 | from toolkit.yacs import CfgNode as CN 13 | from timm.utils import ModelEmaV3 14 | 15 | import warnings 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | # check 20 | print(torch.__version__) 21 | print(torch.cuda.is_available()) 22 | 23 | # init 24 | cfg = CN(new_allowed=True) 25 | 26 | # dataset dir 27 | ctg_list = './dataset/label.txt' 28 | train_list = './dataset/train.txt' 29 | val_list = './dataset/val.txt' 30 | 31 | # : network 32 | cfg.network = CN(new_allowed=True) 33 | cfg.network.name = 'replknet' 34 | cfg.network.class_num = 2 35 | cfg.network.input_size = 384 36 | 37 | # : train params 38 | mean = (0.485, 0.456, 0.406) 39 | std = (0.229, 0.224, 0.225) 40 | 41 | cfg.train = CN(new_allowed=True) 42 | cfg.train.resume = False 43 | cfg.train.resume_path = '' 44 | cfg.train.params_path = '' 45 | cfg.train.batch_size = 16 46 | cfg.train.epoch_num = 20 47 | cfg.train.epoch_start = 0 48 | cfg.train.worker_num = 8 49 | 50 | # : optimizer params 51 | cfg.optimizer = CN(new_allowed=True) 52 | cfg.optimizer.lr = 1e-4 * 1 53 | cfg.optimizer.weight_decay = 1e-2 54 | cfg.optimizer.momentum = 0.9 55 | cfg.optimizer.beta1 = 0.9 56 | cfg.optimizer.beta2 = 0.999 57 | cfg.optimizer.eps = 1e-8 58 | 59 | # : scheduler params 60 | cfg.scheduler = CN(new_allowed=True) 61 | cfg.scheduler.min_lr = 1e-6 62 | 63 | # init path 64 | task = 'competition' 65 | log_root = 'output/' + datetime.datetime.now().strftime("%Y-%m-%d") + '-' + time.strftime( 66 | "%H-%M-%S") + '_' + cfg.network.name + '_' + f"to_{task}_BinClass" 67 | 68 | if not os.path.exists(log_root): 69 | os.makedirs(log_root) 70 | writer = SummaryWriter(log_root) 71 | 72 | # create engine 73 | train_engine = TrainEngine(0, 0, DDP=False, SyncBatchNorm=False) 74 | train_engine.create_env(cfg) 75 | 76 | # create transforms 77 | transforms_dict = { 78 | 0: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size)), 79 | 1: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size), jpeg_compression=1), 80 | } 81 | 82 | transforms_dict_test = { 83 | 0: create_transforms_inference(h=512, w=512), 84 | 1: create_transforms_inference(h=512, w=512), 85 | } 86 | 87 | transform = transforms_dict 88 | transform_test = transforms_dict_test 89 | 90 | # create dataset 91 | trainset = MultiClassificationProcessor(transform) 92 | trainset.load_data_from_txt(train_list, ctg_list) 93 | 94 | valset = MultiClassificationProcessor(transform_test) 95 | valset.load_data_from_txt(val_list, ctg_list) 96 | 97 | # create dataloader 98 | train_loader = torch.utils.data.DataLoader(dataset=trainset, 99 | batch_size=cfg.train.batch_size, 100 | num_workers=cfg.train.worker_num, 101 | shuffle=True, 102 | pin_memory=True, 103 | drop_last=True) 104 | 105 | val_loader = torch.utils.data.DataLoader(dataset=valset, 106 | batch_size=cfg.train.batch_size, 107 | num_workers=cfg.train.worker_num, 108 | shuffle=False, 109 | pin_memory=True, 110 | drop_last=False) 111 | 112 | train_log_txtFile = log_root + "/" + "train_log.txt" 113 | f_open = open(train_log_txtFile, "w") 114 | 115 | # train & Val & Test 116 | best_test_mAP = 0.0 117 | best_test_idx = 0.0 118 | ema_start = True 119 | train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda() 120 | for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num): 121 | # train 122 | train_top1, train_loss, train_lr = 
train_engine.train_multi_class(train_loader=train_loader, epoch_idx=epoch_idx, 123 | ema_start=ema_start) 124 | # val 125 | val_top1, val_loss, val_auc = train_engine.val_multi_class(val_loader=val_loader, epoch_idx=epoch_idx) 126 | # ema_val 127 | if ema_start: 128 | ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(val_loader=val_loader, epoch_idx=epoch_idx) 129 | 130 | train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start) 131 | 132 | if ema_start: 133 | outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss}, val_top1={val_top1}, val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n" 134 | else: 135 | outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss}, val_top1={val_top1}, val_loss={val_loss}, val_auc={val_auc} \n" 136 | 137 | print(outInfo) 138 | 139 | f_open.write(outInfo) 140 | # flush the log file to disk 141 | f_open.flush() 142 | 143 | # curve all mAP & mLoss 144 | writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx) 145 | writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx) 146 | 147 | # curve lr 148 | writer.add_scalar('train_lr', train_lr, epoch_idx) 149 | 150 | -------------------------------------------------------------------------------- /merge.py: -------------------------------------------------------------------------------- 1 | from toolkit.chelper import final_model 2 | import torch 3 | import os 4 | 5 | 6 | # Paths to the trained ConvNeXt and RepLKNet checkpoints to be merged 7 | convnext_path = './final_model_csv/convnext_end.pth' 8 | replknet_path = './final_model_csv/replk_end.pth' 9 | 10 | model = final_model() 11 | model.convnext.load_state_dict(torch.load(convnext_path, map_location='cpu')['state_dict'], strict=True) 12 | model.replknet.load_state_dict(torch.load(replknet_path, map_location='cpu')['state_dict'], strict=True) 13 | 14 | if not os.path.exists('./final_model_csv'): 15 | os.makedirs('./final_model_csv') 16 | 17 | torch.save({'state_dict': model.state_dict()}, './final_model_csv/final_model.pth') 18 | -------------------------------------------------------------------------------- /model/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_, DropPath 13 | from timm.models.registry import register_model 14 | 15 | class Block(nn.Module): 16 | r""" ConvNeXt Block. There are two equivalent implementations: 17 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 18 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 19 | We use (2) as we find it slightly faster in PyTorch 20 | 21 | Args: 22 | dim (int): Number of input channels. 23 | drop_path (float): Stochastic depth rate. Default: 0.0 24 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
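    Example (illustrative sketch): a Block preserves both the spatial size and the
    channel count, so it can be stacked without any shape bookkeeping:
        blk = Block(dim=96, drop_path=0.1)
        out = blk(torch.randn(2, 96, 56, 56))  # out.shape == (2, 96, 56, 56)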
25 | """ 26 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 27 | super().__init__() 28 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 29 | self.norm = LayerNorm(dim, eps=1e-6) 30 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 31 | self.act = nn.GELU() 32 | self.pwconv2 = nn.Linear(4 * dim, dim) 33 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 34 | requires_grad=True) if layer_scale_init_value > 0 else None 35 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 36 | 37 | def forward(self, x): 38 | input = x 39 | x = self.dwconv(x) 40 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 41 | x = self.norm(x) 42 | x = self.pwconv1(x) 43 | x = self.act(x) 44 | x = self.pwconv2(x) 45 | if self.gamma is not None: 46 | x = self.gamma * x 47 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 48 | 49 | x = input + self.drop_path(x) 50 | return x 51 | 52 | class ConvNeXt(nn.Module): 53 | r""" ConvNeXt 54 | A PyTorch impl of : `A ConvNet for the 2020s` - 55 | https://arxiv.org/pdf/2201.03545.pdf 56 | 57 | Args: 58 | in_chans (int): Number of input image channels. Default: 3 59 | num_classes (int): Number of classes for classification head. Default: 1000 60 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 61 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 62 | drop_path_rate (float): Stochastic depth rate. Default: 0. 63 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 64 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 65 | """ 66 | def __init__(self, in_chans=3, num_classes=1000, 67 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 68 | layer_scale_init_value=1e-6, head_init_scale=1., 69 | ): 70 | super().__init__() 71 | 72 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 73 | stem = nn.Sequential( 74 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 75 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 76 | ) 77 | self.downsample_layers.append(stem) 78 | for i in range(3): 79 | downsample_layer = nn.Sequential( 80 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 81 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 82 | ) 83 | self.downsample_layers.append(downsample_layer) 84 | 85 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 86 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 87 | cur = 0 88 | for i in range(4): 89 | stage = nn.Sequential( 90 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 91 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 92 | ) 93 | self.stages.append(stage) 94 | cur += depths[i] 95 | 96 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 97 | self.head = nn.Linear(dims[-1], num_classes) 98 | 99 | self.apply(self._init_weights) 100 | self.head.weight.data.mul_(head_init_scale) 101 | self.head.bias.data.mul_(head_init_scale) 102 | 103 | def _init_weights(self, m): 104 | if isinstance(m, (nn.Conv2d, nn.Linear)): 105 | trunc_normal_(m.weight, std=.02) 106 | nn.init.constant_(m.bias, 0) 107 | 108 | def forward_features(self, x): 109 | for i in range(4): 110 | x = self.downsample_layers[i](x) 111 | x = self.stages[i](x) 112 | return 
self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 113 | 114 | def forward(self, x): 115 | x = self.forward_features(x) 116 | x = self.head(x) 117 | return x 118 | 119 | class LayerNorm(nn.Module): 120 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 121 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 122 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 123 | with shape (batch_size, channels, height, width). 124 | """ 125 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 126 | super().__init__() 127 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 128 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 129 | self.eps = eps 130 | self.data_format = data_format 131 | if self.data_format not in ["channels_last", "channels_first"]: 132 | raise NotImplementedError 133 | self.normalized_shape = (normalized_shape, ) 134 | 135 | def forward(self, x): 136 | if self.data_format == "channels_last": 137 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 138 | elif self.data_format == "channels_first": 139 | u = x.mean(1, keepdim=True) 140 | s = (x - u).pow(2).mean(1, keepdim=True) 141 | x = (x - u) / torch.sqrt(s + self.eps) 142 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 143 | return x 144 | 145 | 146 | model_urls = { 147 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 148 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 149 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 150 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 151 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 152 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 153 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 154 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 155 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 156 | } 157 | 158 | @register_model 159 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 160 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 161 | if pretrained: 162 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 163 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 164 | model.load_state_dict(checkpoint["model"]) 165 | return model 166 | 167 | @register_model 168 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 169 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 170 | if pretrained: 171 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 172 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 173 | model.load_state_dict(checkpoint["model"]) 174 | return model 175 | 176 | @register_model 177 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 178 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 179 | if pretrained: 180 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 181 | 
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 182 | model.load_state_dict(checkpoint["model"]) 183 | return model 184 | 185 | @register_model 186 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 187 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 188 | if pretrained: 189 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 190 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 191 | model.load_state_dict(checkpoint["model"]) 192 | return model 193 | 194 | @register_model 195 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 196 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 197 | if pretrained: 198 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 199 | url = model_urls['convnext_xlarge_22k'] 200 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 201 | model.load_state_dict(checkpoint["model"]) 202 | return model 203 | -------------------------------------------------------------------------------- /model/replknet.py: -------------------------------------------------------------------------------- 1 | # Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs (https://arxiv.org/abs/2203.06717) 2 | # Github source: https://github.com/DingXiaoH/RepLKNet-pytorch 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Based on ConvNeXt, timm, DINO and DeiT code bases 5 | # https://github.com/facebookresearch/ConvNeXt 6 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 7 | # https://github.com/facebookresearch/deit/ 8 | # https://github.com/facebookresearch/dino 9 | # --------------------------------------------------------' 10 | import torch 11 | import torch.nn as nn 12 | import torch.utils.checkpoint as checkpoint 13 | from timm.models.layers import DropPath 14 | import sys 15 | import os 16 | 17 | def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias): 18 | if type(kernel_size) is int: 19 | use_large_impl = kernel_size > 5 20 | else: 21 | assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1] 22 | use_large_impl = kernel_size[0] > 5 23 | has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ 24 | if has_large_impl and in_channels == out_channels and out_channels == groups and use_large_impl and stride == 1 and padding == kernel_size // 2 and dilation == 1: 25 | sys.path.append(os.environ['LARGE_KERNEL_CONV_IMPL']) 26 | # Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md 27 | # export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py) 28 | # TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed. 29 | # Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it. 
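# The implicit-GEMM path below is only taken when the LARGE_KERNEL_CONV_IMPL env var
# is set and the conv is a large (k > 5) depthwise one, i.e.
# groups == in_channels == out_channels, stride == 1, dilation == 1 and
# "same" padding (k // 2); every other configuration falls back to nn.Conv2d.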
30 | from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM 31 | return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias) 32 | else: 33 | return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 34 | padding=padding, dilation=dilation, groups=groups, bias=bias) 35 | 36 | use_sync_bn = False 37 | 38 | def enable_sync_bn(): 39 | global use_sync_bn 40 | use_sync_bn = True 41 | 42 | def get_bn(channels): 43 | if use_sync_bn: 44 | return nn.SyncBatchNorm(channels) 45 | else: 46 | return nn.BatchNorm2d(channels) 47 | 48 | def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1): 49 | if padding is None: 50 | padding = kernel_size // 2 51 | result = nn.Sequential() 52 | result.add_module('conv', get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 53 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)) 54 | result.add_module('bn', get_bn(out_channels)) 55 | return result 56 | 57 | def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1): 58 | if padding is None: 59 | padding = kernel_size // 2 60 | result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 61 | stride=stride, padding=padding, groups=groups, dilation=dilation) 62 | result.add_module('nonlinear', nn.ReLU()) 63 | return result 64 | 65 | def fuse_bn(conv, bn): 66 | kernel = conv.weight 67 | running_mean = bn.running_mean 68 | running_var = bn.running_var 69 | gamma = bn.weight 70 | beta = bn.bias 71 | eps = bn.eps 72 | std = (running_var + eps).sqrt() 73 | t = (gamma / std).reshape(-1, 1, 1, 1) 74 | return kernel * t, beta - running_mean * gamma / std 75 | 76 | class ReparamLargeKernelConv(nn.Module): 77 | 78 | def __init__(self, in_channels, out_channels, kernel_size, 79 | stride, groups, 80 | small_kernel, 81 | small_kernel_merged=False): 82 | super(ReparamLargeKernelConv, self).__init__() 83 | self.kernel_size = kernel_size 84 | self.small_kernel = small_kernel 85 | # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly. 86 | padding = kernel_size // 2 87 | if small_kernel_merged: 88 | self.lkb_reparam = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 89 | stride=stride, padding=padding, dilation=1, groups=groups, bias=True) 90 | else: 91 | self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 92 | stride=stride, padding=padding, dilation=1, groups=groups) 93 | if small_kernel is not None: 94 | assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!' 
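# The small kernel is a parallel branch used only during training; at deployment,
# merge_kernel() fuses each conv-BN pair via fuse_bn() and zero-pads the small
# kernel up to kernel_size, collapsing both branches into the single lkb_reparam conv
# (see get_equivalent_kernel_bias below).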
95 | self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=small_kernel, 96 | stride=stride, padding=small_kernel//2, groups=groups, dilation=1) 97 | 98 | def forward(self, inputs): 99 | if hasattr(self, 'lkb_reparam'): 100 | out = self.lkb_reparam(inputs) 101 | else: 102 | out = self.lkb_origin(inputs) 103 | if hasattr(self, 'small_conv'): 104 | out += self.small_conv(inputs) 105 | return out 106 | 107 | def get_equivalent_kernel_bias(self): 108 | eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) 109 | if hasattr(self, 'small_conv'): 110 | small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn) 111 | eq_b += small_b 112 | # add to the central part 113 | eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4) 114 | return eq_k, eq_b 115 | 116 | def merge_kernel(self): 117 | eq_k, eq_b = self.get_equivalent_kernel_bias() 118 | self.lkb_reparam = get_conv2d(in_channels=self.lkb_origin.conv.in_channels, 119 | out_channels=self.lkb_origin.conv.out_channels, 120 | kernel_size=self.lkb_origin.conv.kernel_size, stride=self.lkb_origin.conv.stride, 121 | padding=self.lkb_origin.conv.padding, dilation=self.lkb_origin.conv.dilation, 122 | groups=self.lkb_origin.conv.groups, bias=True) 123 | self.lkb_reparam.weight.data = eq_k 124 | self.lkb_reparam.bias.data = eq_b 125 | self.__delattr__('lkb_origin') 126 | if hasattr(self, 'small_conv'): 127 | self.__delattr__('small_conv') 128 | 129 | 130 | class ConvFFN(nn.Module): 131 | 132 | def __init__(self, in_channels, internal_channels, out_channels, drop_path): 133 | super().__init__() 134 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 135 | self.preffn_bn = get_bn(in_channels) 136 | self.pw1 = conv_bn(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, padding=0, groups=1) 137 | self.pw2 = conv_bn(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1) 138 | self.nonlinear = nn.GELU() 139 | 140 | def forward(self, x): 141 | out = self.preffn_bn(x) 142 | out = self.pw1(out) 143 | out = self.nonlinear(out) 144 | out = self.pw2(out) 145 | return x + self.drop_path(out) 146 | 147 | 148 | class RepLKBlock(nn.Module): 149 | 150 | def __init__(self, in_channels, dw_channels, block_lk_size, small_kernel, drop_path, small_kernel_merged=False): 151 | super().__init__() 152 | self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1) 153 | self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1) 154 | self.large_kernel = ReparamLargeKernelConv(in_channels=dw_channels, out_channels=dw_channels, kernel_size=block_lk_size, 155 | stride=1, groups=dw_channels, small_kernel=small_kernel, small_kernel_merged=small_kernel_merged) 156 | self.lk_nonlinear = nn.ReLU() 157 | self.prelkb_bn = get_bn(in_channels) 158 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 159 | print('drop path:', self.drop_path) 160 | 161 | def forward(self, x): 162 | out = self.prelkb_bn(x) 163 | out = self.pw1(out) 164 | out = self.large_kernel(out) 165 | out = self.lk_nonlinear(out) 166 | out = self.pw2(out) 167 | return x + self.drop_path(out) 168 | 169 | 170 | class RepLKNetStage(nn.Module): 171 | 172 | def __init__(self, channels, num_blocks, stage_lk_size, drop_path, 173 | small_kernel, dw_ratio=1, ffn_ratio=4, 174 | use_checkpoint=False, # train with torch.utils.checkpoint to save memory 175 | small_kernel_merged=False, 176 | norm_intermediate_features=False): 177 | super().__init__() 178 | self.use_checkpoint = use_checkpoint 179 | blks = [] 180 | for i in range(num_blocks): 181 | block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path 182 | # Assume all RepLK Blocks within a stage share the same lk_size. You may tune it on your own model. 183 | replk_block = RepLKBlock(in_channels=channels, dw_channels=int(channels * dw_ratio), block_lk_size=stage_lk_size, 184 | small_kernel=small_kernel, drop_path=block_drop_path, small_kernel_merged=small_kernel_merged) 185 | convffn_block = ConvFFN(in_channels=channels, internal_channels=int(channels * ffn_ratio), out_channels=channels, 186 | drop_path=block_drop_path) 187 | blks.append(replk_block) 188 | blks.append(convffn_block) 189 | self.blocks = nn.ModuleList(blks) 190 | if norm_intermediate_features: 191 | self.norm = get_bn(channels) # Only use this with RepLKNet-XL on downstream tasks 192 | else: 193 | self.norm = nn.Identity() 194 | 195 | def forward(self, x): 196 | for blk in self.blocks: 197 | if self.use_checkpoint: 198 | x = checkpoint.checkpoint(blk, x) # Save training memory 199 | else: 200 | x = blk(x) 201 | return x 202 | 203 | class RepLKNet(nn.Module): 204 | 205 | def __init__(self, large_kernel_sizes, layers, channels, drop_path_rate, small_kernel, 206 | dw_ratio=1, ffn_ratio=4, in_channels=3, num_classes=1000, out_indices=None, 207 | use_checkpoint=False, 208 | small_kernel_merged=False, 209 | use_sync_bn=True, 210 | norm_intermediate_features=False # for RepLKNet-XL on COCO and ADE20K, use an extra BN to normalize the intermediate feature maps then feed them into the heads 211 | ): 212 | super().__init__() 213 | 214 | if num_classes is None and out_indices is None: 215 | raise ValueError('must specify one of num_classes (for pretraining) and out_indices (for downstream tasks)') 216 | elif num_classes is not None and out_indices is not None: 217 | raise ValueError('cannot specify both num_classes (for pretraining) and out_indices (for downstream tasks)') 218 | elif num_classes is not None and norm_intermediate_features: 219 | raise ValueError('for pretraining, no need to normalize the intermediate feature maps') 220 | self.out_indices = out_indices 221 | if use_sync_bn: 222 | enable_sync_bn() 223 | 224 | base_width = channels[0] 225 | self.use_checkpoint = use_checkpoint 226 | self.norm_intermediate_features = norm_intermediate_features 227 | self.num_stages = len(layers) 228 | self.stem = nn.ModuleList([ 229 | conv_bn_relu(in_channels=in_channels, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=1), 230 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=1, padding=1, groups=base_width), 231 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=1, stride=1, padding=0, groups=1), 232 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=2, padding=1, 
groups=base_width)]) 233 | # stochastic depth. We set block-wise drop-path rate. The higher level blocks are more likely to be dropped. This implementation follows Swin. 234 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))] 235 | self.stages = nn.ModuleList() 236 | self.transitions = nn.ModuleList() 237 | for stage_idx in range(self.num_stages): 238 | layer = RepLKNetStage(channels=channels[stage_idx], num_blocks=layers[stage_idx], 239 | stage_lk_size=large_kernel_sizes[stage_idx], 240 | drop_path=dpr[sum(layers[:stage_idx]):sum(layers[:stage_idx + 1])], 241 | small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio, 242 | use_checkpoint=use_checkpoint, small_kernel_merged=small_kernel_merged, 243 | norm_intermediate_features=norm_intermediate_features) 244 | self.stages.append(layer) 245 | if stage_idx < len(layers) - 1: 246 | transition = nn.Sequential( 247 | conv_bn_relu(channels[stage_idx], channels[stage_idx + 1], 1, 1, 0, groups=1), 248 | conv_bn_relu(channels[stage_idx + 1], channels[stage_idx + 1], 3, stride=2, padding=1, groups=channels[stage_idx + 1])) 249 | self.transitions.append(transition) 250 | 251 | if num_classes is not None: 252 | self.norm = get_bn(channels[-1]) 253 | self.avgpool = nn.AdaptiveAvgPool2d(1) 254 | self.head = nn.Linear(channels[-1], num_classes) 255 | 256 | 257 | 258 | def forward_features(self, x): 259 | x = self.stem[0](x) 260 | for stem_layer in self.stem[1:]: 261 | if self.use_checkpoint: 262 | x = checkpoint.checkpoint(stem_layer, x) # save memory 263 | else: 264 | x = stem_layer(x) 265 | 266 | if self.out_indices is None: 267 | # Just need the final output 268 | for stage_idx in range(self.num_stages): 269 | x = self.stages[stage_idx](x) 270 | if stage_idx < self.num_stages - 1: 271 | x = self.transitions[stage_idx](x) 272 | return x 273 | else: 274 | # Need the intermediate feature maps 275 | outs = [] 276 | for stage_idx in range(self.num_stages): 277 | x = self.stages[stage_idx](x) 278 | if stage_idx in self.out_indices: 279 | outs.append(self.stages[stage_idx].norm(x)) # For RepLKNet-XL normalize the features before feeding them into the heads 280 | if stage_idx < self.num_stages - 1: 281 | x = self.transitions[stage_idx](x) 282 | return outs 283 | 284 | def forward(self, x): 285 | x = self.forward_features(x) 286 | if self.out_indices: 287 | return x 288 | else: 289 | x = self.norm(x) 290 | x = self.avgpool(x) 291 | x = torch.flatten(x, 1) 292 | x = self.head(x) 293 | return x 294 | 295 | def structural_reparam(self): 296 | for m in self.modules(): 297 | if hasattr(m, 'merge_kernel'): 298 | m.merge_kernel() 299 | 300 | # If your framework cannot automatically fuse BN for inference, you may do it manually. 301 | # The BNs after and before conv layers can be removed. 302 | # No need to call this if your framework support automatic BN fusion. 
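# A worked sketch of the folding performed by fuse_bn() above: given a conv weight W
# and a BatchNorm with running stats (mu, var), affine params (gamma, beta) and eps,
#     W_fused = W * gamma / sqrt(var + eps)
#     b_fused = beta - mu * gamma / sqrt(var + eps)
# so at eval time conv followed by BN equals a single biased convolution.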
303 | def deep_fuse_BN(self): 304 | for m in self.modules(): 305 | if not isinstance(m, nn.Sequential): 306 | continue 307 | if not len(m) in [2, 3]: # Only handle conv-BN or conv-BN-relu 308 | continue 309 | # If you use a custom Conv2d impl, assume it also has 'kernel_size' and 'weight' 310 | if hasattr(m[0], 'kernel_size') and hasattr(m[0], 'weight') and isinstance(m[1], nn.BatchNorm2d): 311 | conv = m[0] 312 | bn = m[1] 313 | fused_kernel, fused_bias = fuse_bn(conv, bn) 314 | fused_conv = get_conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size, 315 | stride=conv.stride, 316 | padding=conv.padding, dilation=conv.dilation, groups=conv.groups, bias=True) 317 | fused_conv.weight.data = fused_kernel 318 | fused_conv.bias.data = fused_bias 319 | m[0] = fused_conv 320 | m[1] = nn.Identity() 321 | 322 | 323 | def create_RepLKNet31B(drop_path_rate=0.5, num_classes=1000, use_checkpoint=False, small_kernel_merged=False, use_sync_bn=True): 324 | return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[128,256,512,1024], 325 | drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint, 326 | small_kernel_merged=small_kernel_merged, use_sync_bn=use_sync_bn) 327 | 328 | def create_RepLKNet31L(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False): 329 | return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[192,384,768,1536], 330 | drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint, 331 | small_kernel_merged=small_kernel_merged) 332 | 333 | def create_RepLKNetXL(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False): 334 | return RepLKNet(large_kernel_sizes=[27,27,27,13], layers=[2,2,18,2], channels=[256,512,1024,2048], 335 | drop_path_rate=drop_path_rate, small_kernel=None, dw_ratio=1.5, 336 | num_classes=num_classes, use_checkpoint=use_checkpoint, 337 | small_kernel_merged=small_kernel_merged) 338 | 339 | if __name__ == '__main__': 340 | model = create_RepLKNet31B(small_kernel_merged=False) 341 | model.eval() 342 | print('------------------- training-time model -------------') 343 | print(model) 344 | x = torch.randn(2, 3, 224, 224) 345 | origin_y = model(x) 346 | model.structural_reparam() 347 | print('------------------- after re-param -------------') 348 | print(model) 349 | reparam_y = model(x) 350 | print('------------------- the difference is ------------------------') 351 | print((origin_y - reparam_y).abs().sum()) 352 | 353 | 354 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asttokens==2.4.1 2 | einops==0.8.0 3 | numpy==1.22.0 4 | opencv-python==4.8.0.74 5 | pillow==9.5.0 6 | PyYAML==6.0.1 7 | scikit-image==0.21.0 8 | scikit-learn==1.3.2 9 | tensorboard==2.14.0 10 | tensorboard-data-server==0.7.2 11 | thop==0.1.1.post2209072238 12 | timm==0.6.13 13 | tqdm==4.66.4 14 | fastapi==0.103.1 15 | uvicorn==0.22.0 16 | pydantic==1.10.9 17 | torch==1.13.1 18 | torchvision==0.14.1 19 | 20 | -------------------------------------------------------------------------------- /toolkit/chelper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.convnext import convnext_base 4 | import timm 5 | from model.replknet import create_RepLKNet31B 6 | 7 | 8 | class 
augment_inputs_network(nn.Module): 9 | def __init__(self, model): 10 | super(augment_inputs_network, self).__init__() 11 | self.model = model 12 | self.adapter = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=3, stride=1, padding=1) 13 | 14 | def forward(self, x): 15 | x = self.adapter(x) 16 | x = (x - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN, device=x.device).view(1, -1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD, device=x.device).view(1, -1, 1, 1) 17 | 18 | return self.model(x) 19 | 20 | 21 | class final_model(nn.Module): # total parameters: ~158.6 MB 22 | def __init__(self): 23 | super(final_model, self).__init__() 24 | 25 | self.convnext = convnext_base(num_classes=2) 26 | self.convnext = augment_inputs_network(self.convnext) 27 | 28 | self.replknet = create_RepLKNet31B(num_classes=2) 29 | self.replknet = augment_inputs_network(self.replknet) 30 | 31 | def forward(self, x): 32 | B, N, C, H, W = x.shape 33 | x = x.view(-1, C, H, W) 34 | 35 | pred1 = self.convnext(x) 36 | pred2 = self.replknet(x) 37 | 38 | outputs_score1 = nn.functional.softmax(pred1, dim=1) 39 | outputs_score2 = nn.functional.softmax(pred2, dim=1) 40 | 41 | predict_score1 = outputs_score1[:, 1] 42 | predict_score2 = outputs_score2[:, 1] 43 | 44 | predict_score1 = predict_score1.view(B, N).mean(dim=-1) 45 | predict_score2 = predict_score2.view(B, N).mean(dim=-1) 46 | 47 | return torch.stack((predict_score1, predict_score2), dim=-1).mean(dim=-1) 48 | 49 | 50 | def load_model(model_name, ctg_num, use_sync_bn=False): 51 | """Load a classification backbone by name. 52 | 53 | Args: 54 | model_name: 'convnext', 'replknet', or 'all' (the merged two-branch model) 55 | ctg_num: number of output classes, e.g., 2 56 | use_sync_bn: True/False, only used by the 'replknet' branch 57 | """ 58 | if model_name == 'convnext': 59 | model = convnext_base(num_classes=ctg_num) 60 | model_path = 'pre_model/convnext_base_1k_384.pth' 61 | check_point = torch.load(model_path, map_location='cpu')['model'] 62 | check_point.pop('head.weight') 63 | check_point.pop('head.bias') 64 | model.load_state_dict(check_point, strict=False) 65 | 66 | model = augment_inputs_network(model) 67 | 68 | elif model_name == 'replknet': 69 | model = create_RepLKNet31B(num_classes=ctg_num, use_sync_bn=use_sync_bn) 70 | model_path = 'pre_model/RepLKNet-31B_ImageNet-1K_384.pth' 71 | check_point = torch.load(model_path) 72 | check_point.pop('head.weight') 73 | check_point.pop('head.bias') 74 | model.load_state_dict(check_point, strict=False) 75 | 76 | model = augment_inputs_network(model) 77 | 78 | elif model_name == 'all': 79 | model = final_model() 80 | 81 | print("model_name", model_name) 82 | 83 | return model 84 | 85 | -------------------------------------------------------------------------------- /toolkit/cmetric.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import sklearn 5 | import sklearn.metrics 6 | 7 | 8 | class MultilabelClassificationMetric(object): 9 | def __init__(self): 10 | super(MultilabelClassificationMetric, self).__init__() 11 | self.pred_scores_ = torch.FloatTensor() # .FloatStorage() 12 | self.grth_labels_ = torch.LongTensor() # .LongStorage() 13 | 14 | # Func: 15 | # Reset calculation. 16 | def reset(self): 17 | self.pred_scores_ = torch.FloatTensor(torch.FloatStorage()) 18 | self.grth_labels_ = torch.LongTensor(torch.LongStorage()) 19 | 20 | # Func: 21 | # Add prediction and groundtruth that will be used to calculate average precision.
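#   (usage sketch: call once per validation batch, e.g.
#    metric.add(torch.rand(8, 20), torch.randint(0, 2, (8, 20))), then read the
#    per-class APs from calc_avg_precision2() after the last batch)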
22 | # Input: 23 | # pred_scores : predicted scores, size: [batch_size, label_dim], format: [s0, s1, ..., s19] 24 | # grth_labels : groundtruth labels, size: [batch_size, label_dim], format: [c0, c1, ..., c19] 25 | def add(self, pred_scores, grth_labels): 26 | if not torch.is_tensor(pred_scores): 27 | pred_scores = torch.from_numpy(pred_scores) 28 | if not torch.is_tensor(grth_labels): 29 | grth_labels = torch.from_numpy(grth_labels) 30 | 31 | # check 32 | assert pred_scores.dim() == 2, 'wrong pred_scores size (should be 2D with format: [batch_size, label_dim(one column per class)])' 33 | assert grth_labels.dim() == 2, 'wrong grth_labels size (should be 2D with format: [batch_size, label_dim(one column per class)])' 34 | 35 | # check storage is sufficient 36 | if self.pred_scores_.storage().size() < self.pred_scores_.numel() + pred_scores.numel(): 37 | new_size = math.ceil(self.pred_scores_.storage().size() * 1.5) 38 | self.pred_scores_.storage().resize_(int(new_size + pred_scores.numel())) 39 | self.grth_labels_.storage().resize_(int(new_size + pred_scores.numel())) 40 | 41 | # store outputs and targets 42 | offset = self.pred_scores_.size(0) if self.pred_scores_.dim() > 0 else 0 43 | self.pred_scores_.resize_(offset + pred_scores.size(0), pred_scores.size(1)) 44 | self.grth_labels_.resize_(offset + grth_labels.size(0), grth_labels.size(1)) 45 | self.pred_scores_.narrow(0, offset, pred_scores.size(0)).copy_(pred_scores) 46 | self.grth_labels_.narrow(0, offset, grth_labels.size(0)).copy_(grth_labels) 47 | 48 | # Func: 49 | # Compute average precision. 50 | def calc_avg_precision(self): 51 | # check 52 | if self.pred_scores_.numel() == 0: return 0 53 | # calc by class 54 | aps = torch.zeros(self.pred_scores_.size(1)) 55 | for cls_idx in range(self.pred_scores_.size(1)): 56 | # get pred scores & grth labels at class cls_idx 57 | cls_pred_scores = self.pred_scores_[:, cls_idx] # predictions for all images at class cls_idx, format: [img_num] 58 | cls_grth_labels = self.grth_labels_[:, cls_idx] # truth values for all images at class cls_idx, format: [img_num] 59 | # sort by score 60 | _, img_indices = torch.sort(cls_pred_scores, dim=0, descending=True) 61 | # calc ap 62 | TP, TPFP = 0., 0. 63 | for img_idx in img_indices: 64 | label = cls_grth_labels[img_idx] 65 | # accumulate 66 | TPFP += 1 67 | if label == 1: 68 | TP += 1 69 | aps[cls_idx] += TP / TPFP 70 | aps[cls_idx] /= (TP + 1e-5) 71 | # return 72 | return aps 73 | 74 | # Func: 75 | # Compute average precision per class via sklearn.
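# sklearn.metrics.average_precision_score summarizes the ranked scores as
#     AP = sum_n (R_n - R_{n-1}) * P_n
# where P_n and R_n are the precision and recall at the n-th score threshold.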
76 | def calc_avg_precision2(self): 77 | self.pred_scores_ = self.pred_scores_.cpu().numpy().astype('float32') 78 | self.grth_labels_ = self.grth_labels_.cpu().numpy().astype('float32') 79 | # check 80 | if self.pred_scores_.size == 0: return 0 81 | # calc by class 82 | aps = np.zeros(self.pred_scores_.shape[1]) 83 | for cls_idx in range(self.pred_scores_.shape[1]): 84 | # get pred scores & grth labels at class cls_idx 85 | cls_pred_scores = self.pred_scores_[:, cls_idx] 86 | cls_grth_labels = self.grth_labels_[:, cls_idx] 87 | # compute AP for an object category 88 | aps[cls_idx] = sklearn.metrics.average_precision_score(cls_grth_labels, cls_pred_scores) 89 | aps[np.isnan(aps)] = 0 90 | aps = np.around(aps, decimals=4) 91 | return aps 92 | 93 | 94 | class MultiClassificationMetric(object): 95 | """Computes and stores the average and current value""" 96 | def __init__(self): 97 | super(MultiClassificationMetric, self).__init__() 98 | self.reset() 99 | self.val = 0 100 | 101 | def update(self, value, n=1): 102 | self.val = value 103 | self.sum += value 104 | self.var += value * value 105 | self.n += n 106 | 107 | if self.n == 0: 108 | self.mean, self.std = np.nan, np.nan 109 | elif self.n == 1: 110 | self.mean, self.std = self.sum, np.inf 111 | self.mean_old = self.mean 112 | self.m_s = 0.0 113 | else: 114 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) 115 | self.m_s += (value - self.mean_old) * (value - self.mean) 116 | self.mean_old = self.mean 117 | self.std = math.sqrt(self.m_s / (self.n - 1.0)) 118 | 119 | def reset(self): 120 | self.n = 0 121 | self.sum = 0.0 122 | self.var = 0.0 123 | self.val = 0.0 124 | self.mean = np.nan 125 | self.mean_old = 0.0 126 | self.m_s = 0.0 127 | self.std = np.nan 128 | 129 | 130 | def simple_accuracy(output, target): 131 | """Compute the top-1 accuracy of the predictions.""" 132 | with torch.no_grad(): 133 | _, preds = torch.max(output, 1) 134 | 135 | correct = preds.eq(target).float() 136 | accuracy = correct.sum() / len(target) 137 | return accuracy -------------------------------------------------------------------------------- /toolkit/dhelper.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def get_file_name_ext(filepath): 5 | # analyze 6 | file_name, file_ext = os.path.splitext(filepath) 7 | # return 8 | return file_name, file_ext 9 | 10 | 11 | def get_file_ext(filepath): 12 | return get_file_name_ext(filepath)[1] 13 | 14 | 15 | def traverse_recursively(fileroot, filepathes=[], extension='.*'): 16 | """Traverse all file paths under the given directory recursively. 17 | 18 | Args: 19 | fileroot: root directory to search; filepathes: output list that matching paths are appended to (pass a fresh list, since the default [] is shared across calls). 20 | extension: e.g. '.jpg .png .bmp .webp .tif .eps', or '.*' to match every file. 21 | """
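    # Usage sketch with a hypothetical root directory:
    #   jpgs = []
    #   traverse_recursively('./dataset', jpgs, extension='.jpg')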
22 | items = os.listdir(fileroot) 23 | for item in items: 24 | if os.path.isfile(os.path.join(fileroot, item)): 25 | filepath = os.path.join(fileroot, item) 26 | fileext = get_file_ext(filepath).lower() 27 | if extension == '.*': 28 | filepathes.append(filepath) 29 | elif fileext in extension: 30 | filepathes.append(filepath) 31 | else: 32 | pass 33 | elif os.path.isdir(os.path.join(fileroot, item)): 34 | traverse_recursively(os.path.join(fileroot, item), filepathes, extension) 35 | else: 36 | pass 37 | -------------------------------------------------------------------------------- /toolkit/dtransform.py: -------------------------------------------------------------------------------- 1 | from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform 2 | from timm.data.transforms import RandomResizedCropAndInterpolation 3 | from PIL import Image 4 | import torchvision.transforms as transforms 5 | import cv2 6 | import numpy as np 7 | import torchvision.transforms.functional as F 8 | 9 | 10 | # JPEG compression augmentation 11 | class JPEGCompression: 12 | def __init__(self, quality=10, p=0.3): 13 | self.quality = quality 14 | self.p = p 15 | 16 | def __call__(self, img): 17 | if np.random.rand() < self.p: 18 | img_np = np.array(img) 19 | _, buffer = cv2.imencode('.jpg', img_np[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), self.quality]) 20 | jpeg_img = cv2.imdecode(buffer, 1) 21 | return Image.fromarray(jpeg_img[:, :, ::-1]) 22 | return img 23 | 24 | 25 | # base ImageNet-style training augmentation 26 | def transforms_imagenet_train( 27 | img_size=(224, 224), 28 | scale=(0.08, 1.0), 29 | ratio=(3./4., 4./3.), 30 | hflip=0.5, 31 | vflip=0.5, 32 | auto_augment='rand-m9-mstd0.5-inc1', 33 | interpolation='random', 34 | mean=(0.485, 0.456, 0.406), 35 | jpeg_compression=0, 36 | ): 37 | scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range 38 | ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range 39 | 40 | primary_tfl = [ 41 | RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)] 42 | if hflip > 0.: 43 | primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] 44 | if vflip > 0.: 45 | primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] 46 | 47 | secondary_tfl = [] 48 | if auto_augment: 49 | assert isinstance(auto_augment, str) 50 | 51 | if isinstance(img_size, (tuple, list)): 52 | img_size_min = min(img_size) 53 | else: 54 | img_size_min = img_size 55 | 56 | aa_params = dict( 57 | translate_const=int(img_size_min * 0.45), 58 | img_mean=tuple([min(255, round(255 * x)) for x in mean]), 59 | ) 60 | if auto_augment.startswith('rand'): 61 | secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] 62 | elif auto_augment.startswith('augmix'): 63 | aa_params['translate_pct'] = 0.3 64 | secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] 65 | else: 66 | secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] 67 | 68 | if jpeg_compression == 1: 69 | secondary_tfl += [JPEGCompression(quality=10, p=0.3)] 70 | 71 | final_tfl = [transforms.ToTensor()] 72 | 73 | return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) 74 | 75 | 76 | # transforms used at inference (test) time 77 | def create_transforms_inference(h=256, w=256): 78 | transformer = transforms.Compose([ 79 | transforms.Resize(size=(h, w)), 80 | transforms.ToTensor(), 81 | ]) 82 | 83 | return transformer 84 | 85 | 86 | def create_transforms_inference1(h=256, w=256): 87 | transformer = transforms.Compose([ 88 |
transforms.Lambda(lambda img: F.rotate(img, angle=90)), 89 | transforms.Resize(size=(h, w)), 90 | transforms.ToTensor(), 91 | ]) 92 | 93 | return transformer 94 | 95 | 96 | def create_transforms_inference2(h=256, w=256): 97 | transformer = transforms.Compose([ 98 | transforms.Lambda(lambda img: F.rotate(img, angle=180)), 99 | transforms.Resize(size=(h, w)), 100 | transforms.ToTensor(), 101 | ]) 102 | 103 | return transformer 104 | 105 | 106 | def create_transforms_inference3(h=256, w=256): 107 | transformer = transforms.Compose([ 108 | transforms.Lambda(lambda img: F.rotate(img, angle=270)), 109 | transforms.Resize(size=(h, w)), 110 | transforms.ToTensor(), 111 | ]) 112 | 113 | return transformer 114 | 115 | 116 | def create_transforms_inference4(h=256, w=256): 117 | transformer = transforms.Compose([ 118 | transforms.Lambda(lambda img: F.hflip(img)), 119 | transforms.Resize(size=(h, w)), 120 | transforms.ToTensor(), 121 | ]) 122 | 123 | return transformer 124 | 125 | 126 | def create_transforms_inference5(h=256, w=256): 127 | transformer = transforms.Compose([ 128 | transforms.Lambda(lambda img: F.vflip(img)), 129 | transforms.Resize(size=(h, w)), 130 | transforms.ToTensor(), 131 | ]) 132 | 133 | return transformer 134 | -------------------------------------------------------------------------------- /toolkit/yacs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | """YACS -- Yet Another Configuration System is designed to be a simple 17 | configuration management system for academic and industrial research 18 | projects. 19 | 20 | See README.md for usage and examples. 
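Example (illustrative sketch, mirroring how main_train.py builds its config):

    from toolkit.yacs import CfgNode as CN
    cfg = CN(new_allowed=True)
    cfg.network = CN(new_allowed=True)
    cfg.network.name = 'replknet'
    print(cfg.dump())  # serializes the config tree as YAML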
21 | """ 22 | 23 | import copy 24 | import io 25 | import logging 26 | import os 27 | import sys 28 | from ast import literal_eval 29 | 30 | import yaml 31 | 32 | # Flag for py2 and py3 compatibility to use when separate code paths are necessary 33 | # When _PY2 is False, we assume Python 3 is in use 34 | _PY2 = sys.version_info.major == 2 35 | 36 | # Filename extensions for loading configs from files 37 | _YAML_EXTS = {"", ".yaml", ".yml"} 38 | _PY_EXTS = {".py"} 39 | 40 | # py2 and py3 compatibility for checking file object type 41 | # We simply use this to infer py2 vs py3 42 | if _PY2: 43 | _FILE_TYPES = (file, io.IOBase) 44 | else: 45 | _FILE_TYPES = (io.IOBase,) 46 | 47 | # CfgNodes can only contain a limited set of valid types 48 | _VALID_TYPES = {tuple, list, str, int, float, bool, type(None)} 49 | # py2 allow for str and unicode 50 | if _PY2: 51 | _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 52 | 53 | # Utilities for importing modules from file paths 54 | if _PY2: 55 | # imp is available in both py2 and py3 for now, but is deprecated in py3 56 | import imp 57 | else: 58 | import importlib.util 59 | 60 | logger = logging.getLogger(__name__) 61 | 62 | 63 | class CfgNode(dict): 64 | """ 65 | CfgNode represents an internal node in the configuration tree. It's a simple 66 | dict-like container that allows for attribute-based access to keys. 67 | """ 68 | 69 | IMMUTABLE = "__immutable__" 70 | DEPRECATED_KEYS = "__deprecated_keys__" 71 | RENAMED_KEYS = "__renamed_keys__" 72 | NEW_ALLOWED = "__new_allowed__" 73 | 74 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 75 | """ 76 | Args: 77 | init_dict (dict): the possibly-nested dictionary to initailize the CfgNode. 78 | key_list (list[str]): a list of names which index this CfgNode from the root. 79 | Currently only used for logging purposes. 80 | new_allowed (bool): whether adding new key is allowed when merging with 81 | other configs. 82 | """ 83 | # Recursively convert nested dictionaries in init_dict into CfgNodes 84 | init_dict = {} if init_dict is None else init_dict 85 | key_list = [] if key_list is None else key_list 86 | init_dict = self._create_config_tree_from_dict(init_dict, key_list) 87 | super(CfgNode, self).__init__(init_dict) 88 | # Manage if the CfgNode is frozen or not 89 | self.__dict__[CfgNode.IMMUTABLE] = False 90 | # Deprecated options 91 | # If an option is removed from the code and you don't want to break existing 92 | # yaml configs, you can add the full config key as a string to the set below. 93 | self.__dict__[CfgNode.DEPRECATED_KEYS] = set() 94 | # Renamed options 95 | # If you rename a config option, record the mapping from the old name to the new 96 | # name in the dictionary below. Optionally, if the type also changed, you can 97 | # make the value a tuple that specifies first the renamed key and then 98 | # instructions for how to edit the config file. 99 | self.__dict__[CfgNode.RENAMED_KEYS] = { 100 | # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow 101 | # 'EXAMPLE.OLD.KEY': ( # A more complex example to follow 102 | # 'EXAMPLE.NEW.KEY', 103 | # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or " 104 | # + "'foo:bar' -> ('foo', 'bar')" 105 | # ), 106 | } 107 | 108 | # Allow new attributes after initialisation 109 | self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed 110 | 111 | @classmethod 112 | def _create_config_tree_from_dict(cls, dic, key_list): 113 | """ 114 | Create a configuration tree using the given dict. 
        Any dict-like objects inside dict will be treated as a new CfgNode.

        Args:
            dic (dict):
            key_list (list[str]): a list of names which index this CfgNode from the root.
                Currently only used for logging purposes.
        """
        dic = copy.deepcopy(dic)
        for k, v in dic.items():
            if isinstance(v, dict):
                # Convert dict to CfgNode
                dic[k] = cls(v, key_list=key_list + [k])
            else:
                # Check for valid leaf type or nested CfgNode
                _assert_with_logging(
                    _valid_type(v, allow_cfg_node=False),
                    "Key {} with value {} is not a valid type; valid types: {}".format(
                        ".".join(key_list + [str(k)]), type(v), _VALID_TYPES
                    ),
                )
        return dic

    def __getattr__(self, name):
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        if self.is_frozen():
            raise AttributeError(
                "Attempted to set {} to {}, but CfgNode is immutable".format(
                    name, value
                )
            )

        _assert_with_logging(
            name not in self.__dict__,
            "Invalid attempt to modify internal CfgNode state: {}".format(name),
        )
        _assert_with_logging(
            _valid_type(value, allow_cfg_node=True),
            "Invalid type {} for key {}; valid types = {}".format(
                type(value), name, _VALID_TYPES
            ),
        )

        self[name] = value

    def __str__(self):
        def _indent(s_, num_spaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        r = ""
        s = []
        for k, v in sorted(self.items()):
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        r += "\n".join(s)
        return r

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())

    def dump(self, **kwargs):
        """Dump to a string."""

        def convert_to_dict(cfg_node, key_list):
            if not isinstance(cfg_node, CfgNode):
                _assert_with_logging(
                    _valid_type(cfg_node),
                    "Key {} with value {} is not a valid type; valid types: {}".format(
                        ".".join(key_list), type(cfg_node), _VALID_TYPES
                    ),
                )
                return cfg_node
            else:
                cfg_dict = dict(cfg_node)
                for k, v in cfg_dict.items():
                    cfg_dict[k] = convert_to_dict(v, key_list + [k])
                return cfg_dict

        self_as_dict = convert_to_dict(self, [])
        return yaml.safe_dump(self_as_dict, **kwargs)

    def merge_from_file(self, cfg_filename):
        """Load a yaml config file and merge it into this CfgNode."""
        with open(cfg_filename, "r") as f:
            cfg = self.load_cfg(f)
        self.merge_from_other_cfg(cfg)

    def merge_from_other_cfg(self, cfg_other):
        """Merge `cfg_other` into this CfgNode."""
        _merge_a_into_b(cfg_other, self, self, [])

    def merge_from_list(self, cfg_list):
        """Merge config (keys, values) in a list (e.g., from command line) into
        this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
222 | """ 223 | _assert_with_logging( 224 | len(cfg_list) % 2 == 0, 225 | "Override list has odd length: {}; it must be a list of pairs".format( 226 | cfg_list 227 | ), 228 | ) 229 | root = self 230 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 231 | if root.key_is_deprecated(full_key): 232 | continue 233 | if root.key_is_renamed(full_key): 234 | root.raise_key_rename_error(full_key) 235 | key_list = full_key.split(".") 236 | d = self 237 | for subkey in key_list[:-1]: 238 | _assert_with_logging( 239 | subkey in d, "Non-existent key: {}".format(full_key) 240 | ) 241 | d = d[subkey] 242 | subkey = key_list[-1] 243 | _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key)) 244 | value = self._decode_cfg_value(v) 245 | value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key) 246 | d[subkey] = value 247 | 248 | def freeze(self): 249 | """Make this CfgNode and all of its children immutable.""" 250 | self._immutable(True) 251 | 252 | def defrost(self): 253 | """Make this CfgNode and all of its children mutable.""" 254 | self._immutable(False) 255 | 256 | def is_frozen(self): 257 | """Return mutability.""" 258 | return self.__dict__[CfgNode.IMMUTABLE] 259 | 260 | def _immutable(self, is_immutable): 261 | """Set immutability to is_immutable and recursively apply the setting 262 | to all nested CfgNodes. 263 | """ 264 | self.__dict__[CfgNode.IMMUTABLE] = is_immutable 265 | # Recursively set immutable state 266 | for v in self.__dict__.values(): 267 | if isinstance(v, CfgNode): 268 | v._immutable(is_immutable) 269 | for v in self.values(): 270 | if isinstance(v, CfgNode): 271 | v._immutable(is_immutable) 272 | 273 | def clone(self): 274 | """Recursively copy this CfgNode.""" 275 | return copy.deepcopy(self) 276 | 277 | def register_deprecated_key(self, key): 278 | """Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated 279 | keys a warning is generated and the key is ignored. 280 | """ 281 | _assert_with_logging( 282 | key not in self.__dict__[CfgNode.DEPRECATED_KEYS], 283 | "key {} is already registered as a deprecated key".format(key), 284 | ) 285 | self.__dict__[CfgNode.DEPRECATED_KEYS].add(key) 286 | 287 | def register_renamed_key(self, old_name, new_name, message=None): 288 | """Register a key as having been renamed from `old_name` to `new_name`. 289 | When merging a renamed key, an exception is thrown alerting to user to 290 | the fact that the key has been renamed. 
291 | """ 292 | _assert_with_logging( 293 | old_name not in self.__dict__[CfgNode.RENAMED_KEYS], 294 | "key {} is already registered as a renamed cfg key".format(old_name), 295 | ) 296 | value = new_name 297 | if message: 298 | value = (new_name, message) 299 | self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value 300 | 301 | def key_is_deprecated(self, full_key): 302 | """Test if a key is deprecated.""" 303 | if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]: 304 | logger.warning("Deprecated config key (ignoring): {}".format(full_key)) 305 | return True 306 | return False 307 | 308 | def key_is_renamed(self, full_key): 309 | """Test if a key is renamed.""" 310 | return full_key in self.__dict__[CfgNode.RENAMED_KEYS] 311 | 312 | def raise_key_rename_error(self, full_key): 313 | new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key] 314 | if isinstance(new_key, tuple): 315 | msg = " Note: " + new_key[1] 316 | new_key = new_key[0] 317 | else: 318 | msg = "" 319 | raise KeyError( 320 | "Key {} was renamed to {}; please update your config.{}".format( 321 | full_key, new_key, msg 322 | ) 323 | ) 324 | 325 | def is_new_allowed(self): 326 | return self.__dict__[CfgNode.NEW_ALLOWED] 327 | 328 | def set_new_allowed(self, is_new_allowed): 329 | """ 330 | Set this config (and recursively its subconfigs) to allow merging 331 | new keys from other configs. 332 | """ 333 | self.__dict__[CfgNode.NEW_ALLOWED] = is_new_allowed 334 | # Recursively set new_allowed state 335 | for v in self.__dict__.values(): 336 | if isinstance(v, CfgNode): 337 | v.set_new_allowed(is_new_allowed) 338 | for v in self.values(): 339 | if isinstance(v, CfgNode): 340 | v.set_new_allowed(is_new_allowed) 341 | 342 | @classmethod 343 | def load_cfg(cls, cfg_file_obj_or_str): 344 | """ 345 | Load a cfg. 
        Args:
            cfg_file_obj_or_str (str or file):
                Supports loading from:
                - A file object backed by a YAML file
                - A file object backed by a Python source file that exports an attribute
                  "cfg" that is either a dict or a CfgNode
                - A string that can be parsed as valid YAML
        """
        _assert_with_logging(
            isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
            "Expected first argument to be of type {} or {}, but it was {}".format(
                _FILE_TYPES, str, type(cfg_file_obj_or_str)
            ),
        )
        if isinstance(cfg_file_obj_or_str, str):
            return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
        elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
            return cls._load_cfg_from_file(cfg_file_obj_or_str)
        else:
            raise NotImplementedError("Impossible to reach here (unless there's a bug)")

    @classmethod
    def _load_cfg_from_file(cls, file_obj):
        """Load a config from a YAML file or a Python source file."""
        _, file_extension = os.path.splitext(file_obj.name)
        if file_extension in _YAML_EXTS:
            return cls._load_cfg_from_yaml_str(file_obj.read())
        elif file_extension in _PY_EXTS:
            return cls._load_cfg_py_source(file_obj.name)
        else:
            raise Exception(
                "Attempt to load from an unsupported file type {}; "
                "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
            )

    @classmethod
    def _load_cfg_from_yaml_str(cls, str_obj):
        """Load a config from a YAML string encoding."""
        cfg_as_dict = yaml.safe_load(str_obj)
        return cls(cfg_as_dict)

    @classmethod
    def _load_cfg_py_source(cls, filename):
        """Load a config from a Python source file."""
        module = _load_module_from_file("yacs.config.override", filename)
        _assert_with_logging(
            hasattr(module, "cfg"),
            "Python module from file {} must have 'cfg' attr".format(filename),
        )
        VALID_ATTR_TYPES = {dict, CfgNode}
        _assert_with_logging(
            type(module.cfg) in VALID_ATTR_TYPES,
            "Imported module 'cfg' attr must be in {} but is {} instead".format(
                VALID_ATTR_TYPES, type(module.cfg)
            ),
        )
        return cls(module.cfg)

    @classmethod
    def _decode_cfg_value(cls, value):
        """
        Decodes a raw config value (e.g., from a yaml config file or command
        line argument) into a Python object.

        If the value is a dict, it will be interpreted as a new CfgNode.
        If the value is a str, it will be evaluated as a Python literal.
        Otherwise it is returned as-is.
        """
        # Configs parsed from raw yaml will contain dictionary keys that need to be
        # converted to CfgNode objects
        if isinstance(value, dict):
            return cls(value)
        # All remaining processing is only applied to strings
        if not isinstance(value, str):
            return value
        # Try to interpret `value` as a:
        # string, number, tuple, list, dict, boolean, or None
        try:
            value = literal_eval(value)
        # The following two except clauses allow v to pass through when it
        # represents a string.
        #
        # Longer explanation:
        # The type of v is always a string (before calling literal_eval), but
        # sometimes it *represents* a string and other times a data structure, like
        # a list. In the case that v represents a string, what we got back from the
        # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
        # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
        # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
        # will raise a SyntaxError.
        except ValueError:
            pass
        except SyntaxError:
            pass
        return value


load_cfg = (
    CfgNode.load_cfg
)  # keep this function in global scope for backward compatibility


def _valid_type(value, allow_cfg_node=False):
    return (type(value) in _VALID_TYPES) or (
        allow_cfg_node and isinstance(value, CfgNode)
    )


def _merge_a_into_b(a, b, root, key_list):
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.
    """
    _assert_with_logging(
        isinstance(a, CfgNode),
        "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
    )
    _assert_with_logging(
        isinstance(b, CfgNode),
        "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
    )

    for k, v_ in a.items():
        full_key = ".".join(key_list + [k])

        v = copy.deepcopy(v_)
        v = b._decode_cfg_value(v)

        if k in b:
            v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
            # Recursively merge dicts
            if isinstance(v, CfgNode):
                try:
                    _merge_a_into_b(v, b[k], root, key_list + [k])
                except BaseException:
                    raise
            else:
                b[k] = v
        elif b.is_new_allowed():
            b[k] = v
        else:
            if root.key_is_deprecated(full_key):
                continue
            elif root.key_is_renamed(full_key):
                root.raise_key_rename_error(full_key)
            else:
                raise KeyError("Non-existent config key: {}".format(full_key))


def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    """Checks that `replacement`, which is intended to replace `original`, is of
    the right type. The type is correct if it matches exactly or is one of a few
    cases in which the type can be easily coerced.
    """
    original_type = type(original)
    replacement_type = type(replacement)

    # The types must match (with some exceptions)
    if replacement_type == original_type:
        return replacement

    # If either of them is None, allow type conversion to one of the valid types
    if (replacement_type == type(None) and original_type in _VALID_TYPES) or (
        original_type == type(None) and replacement_type in _VALID_TYPES
    ):
        return replacement

    # Cast replacement from from_type to to_type if the replacement and original
    # types match from_type and to_type
    def conditional_cast(from_type, to_type):
        if replacement_type == from_type and original_type == to_type:
            return True, to_type(replacement)
        else:
            return False, None

    # Conditional casts
    # list <-> tuple
    casts = [(tuple, list), (list, tuple)]
    # For py2: allow converting from str (bytes) to a unicode string
    try:
        casts.append((str, unicode))  # noqa: F821
    except Exception:
        pass

    for (from_type, to_type) in casts:
        converted, converted_value = conditional_cast(from_type, to_type)
        if converted:
            return converted_value

    raise ValueError(
        "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
        "key: {}".format(
            original_type, replacement_type, original, replacement, full_key
        )
    )


def _assert_with_logging(cond, msg):
    if not cond:
        logger.debug(msg)
    assert cond, msg


def _load_module_from_file(name, filename):
    if _PY2:
        module = imp.load_source(name, filename)
    else:
        spec = importlib.util.spec_from_file_location(name, filename)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    return module
--------------------------------------------------------------------------------
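This vendored yacs module is driven in the usual yacs style: build a default `CfgNode` tree, merge overrides from a YAML file or a flat key/value list, then freeze it. A minimal sketch; the keys below are illustrative, not the project's actual config schema:

    from toolkit.yacs import CfgNode as CN

    _C = CN()
    _C.MODEL = CN()
    _C.MODEL.NAME = 'convnext'
    _C.TRAIN = CN()
    _C.TRAIN.LR = 1e-3

    cfg = _C.clone()
    cfg.merge_from_list(['TRAIN.LR', '5e-4'])  # command-line style overrides
    cfg.freeze()                               # reject further mutation
    print(cfg.TRAIN.LR)                        # -> 0.0005

Note that `merge_from_list` decodes each string value with `literal_eval` and coerces it to the type of the default, so '5e-4' lands as a float; merging an unknown key raises `KeyError` unless `new_allowed` was set.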