├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── README_kr.md
├── README_zh.md
├── core
│   ├── dsproc_mcls.py
│   ├── dsproc_mclsmfolder.py
│   └── mengine.py
├── dataset
│   ├── label.txt
│   ├── train.txt
│   ├── val.txt
│   └── val_dataset
│       └── 51aa9b8d0da890cd1d0c5029e3d89e3c.jpg
├── images
│   └── competition_title.png
├── infer_api.py
├── main.sh
├── main_infer.py
├── main_train.py
├── main_train_single_gpu.py
├── merge.py
├── model
│   ├── convnext.py
│   └── replknet.py
├── requirements.txt
└── toolkit
    ├── chelper.py
    ├── cmetric.py
    ├── dhelper.py
    ├── dtransform.py
    └── yacs.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.idea
2 | .DS_Store
3 | *.pth
4 | *.pyc
5 | *.ipynb
6 | __pycache__
7 | vision_rush_image*
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Base image with CUDA and cuDNN
2 | FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu20.04
3 |
4 | # Set the working directory
5 | WORKDIR /code
6 |
7 | RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && echo "Asia/Shanghai" > /etc/timezone
8 | RUN apt-get update -y
9 | RUN apt-get install software-properties-common -y && add-apt-repository ppa:deadsnakes/ppa
10 | RUN apt-get install python3.8 python3-pip curl libgl1 libglib2.0-0 ffmpeg libsm6 libxext6 -y && apt-get clean && rm -rf /var/lib/apt/lists/*
11 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 0
12 | RUN update-alternatives --set python3 /usr/bin/python3.8
13 |
14 | # Copy ./requirements.txt into the working directory and install the Python dependencies
15 | ADD ./requirements.txt /code/requirements.txt
16 | RUN pip3 install pip --upgrade -i https://pypi.mirrors.ustc.edu.cn/simple/
17 | RUN pip3 install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/ && rm -rf `pip3 cache dir`
18 |
19 | # Copy the models and code into the working directory
20 | ADD ./core /code/core
21 | ADD ./dataset /code/dataset
22 | ADD ./model /code/model
23 | ADD ./pre_model /code/pre_model
24 | ADD ./final_model_csv /code/final_model_csv
25 | ADD ./toolkit /code/toolkit
26 | ADD ./infer_api.py /code/infer_api.py
27 | ADD ./main_infer.py /code/main_infer.py
28 | ADD ./main_train.py /code/main_train.py
29 | ADD ./merge.py /code/merge.py
30 | ADD ./main.sh /code/main.sh
31 | ADD ./README.md /code/README.md
32 | ADD ./Dockerfile /code/Dockerfile
33 |
34 | # Run the Python entry point
35 | ENTRYPOINT ["python3","infer_api.py"]
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | If you like our project, please give us a star ⭐ on GitHub for the latest updates.
3 |
4 |
5 |
6 |
7 | [](https://github.com/VisionRush/DeepFakeDefenders/blob/main/LICENSE)
8 | 
9 | [](https://hits.seeyoufarm.com)
10 | 
11 | [](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aopen+is%3Aissue)
12 | [](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aissue+is%3Aclosed)
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | 💡 We also provide [[中文文档 / CHINESE DOC](README_zh.md)] and [[한국어 문서 / KOREAN DOC](README_kr.md)]. We warmly welcome and appreciate your contributions to this project.
21 |
22 | ## 📣 News
23 |
24 | * **[2024.09.05]** 🔥 We officially released the initial version of DeepFake Defenders, and we won the third prize in the Deepfake Challenge at [[the Conference on the Bund](https://www.atecup.cn/deepfake)].
25 |
26 | ## 🚀 Quick Start
27 |
28 | ### 1. Pretrained Models Preparation
29 |
30 | Before getting started, please place the ImageNet-1K pretrained weight files in the `./pre_model` directory. The download links for the weights are provided below:
31 | ```
32 | RepLKNet: https://drive.google.com/file/d/1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN/view?usp=sharing
33 | ConvNeXt: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth
34 | ```
35 |
36 | ### 2. Training from Scratch
37 |
38 | #### 2.1 Modifying the dataset path
39 |
40 | Place the training-set **(\*.txt)** file, validation-set **(\*.txt)** file, and label **(\*.txt)** file required for training in the `dataset` folder, and give them matching file names (example txt files are provided under `dataset`), as shown below.
41 |
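For reference, `label.txt` maps each category name to an index, and each line of `train.txt` / `val.txt` pairs an image path with a category index; this is the format parsed by `load_data_from_txt` in `core/dsproc_mcls.py`:

```
# label.txt: <category_name> <category_index>
real 0
fake 1

# train.txt / val.txt: <image_path> <category_index>
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/b580b1fc51d19fc25d2969de07669c21.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/df36afc7a12cf840a961743e08bdd596.jpg 1
```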
42 | #### 2.2 Modifying the Hyper-parameters
43 |
44 | For the two models (RepLKNet and ConvNeXt) used, the following parameters need to be changed in `main_train.py`:
45 |
46 | ```python
47 | # For RepLKNet
48 | cfg.network.name = 'replknet'; cfg.train.batch_size = 16
49 | # For ConvNeXt
50 | cfg.network.name = 'convnext'; cfg.train.batch_size = 24
51 | ```
52 |
53 | #### 2.3 Using the training script
54 | ##### Multi-GPU (8 GPUs were used):
55 | ```shell
56 | bash main.sh
57 | ```
58 | ##### Single-GPU:
59 | ```shell
60 | CUDA_VISIBLE_DEVICES=0 python main_train_single_gpu.py
61 | ```
62 |
63 | #### 2.4 Model Assembling
64 |
65 | Replace the ConvNeXt trained model path and the RepLKNet trained model path in `merge.py`, and execute `python merge.py` to obtain the final inference test model.
66 |
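The exact assembling logic lives in `merge.py`; as a rough, hypothetical sketch of what such a step can look like (assuming the two fine-tuned checkpoints are simply bundled under the single `state_dict` key that `extract_model_from_pth` in `main_infer.py` reads; the paths and key prefixes below are illustrative assumptions, not the repository's actual code):

```python
# Hypothetical sketch, not the repository's actual merge.py.
import torch

convnext_ckpt = torch.load('path/to/convnext_trained.pth', map_location='cpu')
replknet_ckpt = torch.load('path/to/replknet_trained.pth', map_location='cpu')

# Bundle both state dicts into one checkpoint; main_infer.py later loads
# checkpoint['state_dict'] into the combined 'all' model with strict=True.
merged = {f'convnext.{k}': v for k, v in convnext_ckpt['state_dict'].items()}
merged.update({f'replknet.{k}': v for k, v in replknet_ckpt['state_dict'].items()})

torch.save({'state_dict': merged}, './final_model_csv/final_model.pth')
```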
67 | #### 2.5 Inference
68 |
69 | The following example sends a **POST** request with the image path as the request parameter; the response is the deepfake score predicted by the model.
70 |
71 | ```python
72 | #!/usr/bin/env python
73 | # -*- coding:utf-8 -*-
74 | import requests
75 | import json
78 |
79 | header = {
80 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36'
81 | }
82 |
83 | url = 'http://ip:10005/inter_api'
84 | image_path = './dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
85 | data_map = {'img_path':image_path}
86 | response = requests.post(url, json=data_map, headers=header)
87 | content = response.content
88 | print(json.loads(content))
89 | ```
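
The response body is a single float: the deepfake score computed by `INFER_API.test` (see `infer_api.py` and `main_infer.py`).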
90 |
91 | ### 3. Deploy in Docker
92 | #### Building
93 |
94 | ```shell
95 | sudo docker build -t vision-rush-image:1.0.1 --network host .
96 | ```
97 |
98 | #### Running
99 |
100 | ```shell
101 | sudo docker run -d --name vision_rush_image --gpus=all --net host vision-rush-image:1.0.1
102 | ```
103 |
104 | ## Star History
105 |
106 | [](https://star-history.com/#DeepFakeDefenders/DeepFakeDefenders&Date)
107 |
--------------------------------------------------------------------------------
/README_kr.md:
--------------------------------------------------------------------------------
1 |
2 | If you like our project, please give us a star ⭐ on GitHub for the latest updates.
3 |
4 |
5 |
6 |
7 |
8 | [](https://github.com/VisionRush/DeepFakeDefenders/blob/main/LICENSE)
9 | 
10 | [](https://hits.seeyoufarm.com)
11 | 
12 | [](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aopen+is%3Aissue)
13 | [](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aissue+is%3Aclosed)
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | 💡 We also provide [[ENGLISH DOC](README.md)] and [[CHINESE DOC](README_zh.md)]. We warmly welcome and appreciate your contributions to this project.
22 |
23 | ## 📣 News
24 |
25 | * **[2024.09.05]** 🔥 We officially released the initial version of DeepFake Defenders, and we won the third prize in the Deepfake Challenge at [[the Conference on the Bund](https://www.atecup.cn/deepfake)].
26 |
27 | ## 🚀 Quick Start
28 |
29 | ### 1. Pretrained Models Preparation
30 |
31 | Before getting started, please place the ImageNet-1K pretrained weight files in the `./pre_model` directory. The download links for the weights are provided below:
32 | ```
33 | RepLKNet: https://drive.google.com/file/d/1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN/view?usp=sharing
34 | ConvNeXt: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth
35 | ```
36 |
37 | ### 2. Training from Scratch
38 |
39 | #### 2.1 Modifying the dataset path
40 |
41 | Place the training-set **(\*.txt)** file, validation-set **(\*.txt)** file, and label **(\*.txt)** file required for training in the `dataset` folder, and give them matching file names (example txt files are provided under `dataset`).
42 |
43 | #### 2.2 Modifying the Hyper-parameters
44 |
45 | For the two models used (RepLKNet and ConvNeXt), set the following parameters in `main_train.py`:
46 |
47 | ```python
48 | # For RepLKNet
49 | cfg.network.name = 'replknet'; cfg.train.batch_size = 16
50 | # For ConvNeXt
51 | cfg.network.name = 'convnext'; cfg.train.batch_size = 24
52 | ```
53 |
54 | #### 2.3 Using the training script
55 |
56 | ##### Multi-GPU (8 GPUs were used):
57 | ```shell
58 | bash main.sh
59 | ```
60 |
61 | ##### Single-GPU:
62 | ```shell
63 | CUDA_VISIBLE_DEVICES=0 python main_train_single_gpu.py
64 | ```
65 |
66 | #### 2.4 Model Assembling
67 |
68 | Replace the ConvNeXt trained model path and the RepLKNet trained model path in `merge.py`, and run `python merge.py` to obtain the final inference test model.
69 |
70 | #### 2.5 Inference
71 |
72 | The following example sends a **POST** request with the image path as the request parameter; the response is the deepfake score predicted by the model.
73 |
74 | ```python
75 | #!/usr/bin/env python
76 | # -*- coding:utf-8 -*-
77 | import requests
78 | import json
81 |
82 | header = {
83 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36'
84 | }
85 |
86 | url = 'http://ip:10005/inter_api'
87 | image_path = './dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
88 | data_map = {'img_path':image_path}
89 | response = requests.post(url, json=data_map, headers=header)
90 | content = response.content
91 | print(json.loads(content))
92 | ```
93 |
94 | ### 3. Deploy in Docker
95 |
96 | #### Building
97 |
98 | ```shell
99 | sudo docker build -t vision-rush-image:1.0.1 --network host .
100 | ```
101 |
102 | #### Running
103 |
104 | ```shell
105 | sudo docker run -d --name vision_rush_image --gpus=all --net host vision-rush-image:1.0.1
106 | ```
107 |
108 | ## Star History
109 |
110 | [](https://star-history.com/#DeepFakeDefenders/DeepFakeDefenders&Date)
111 |
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 |
2 | If you like our project, please give us a star ⭐ on GitHub for the latest updates.
3 |
4 |
5 |
6 |
7 | [](https://github.com/VisionRush/DeepFakeDefenders/blob/main/LICENSE)
8 | 
9 | [](https://hits.seeyoufarm.com)
10 | 
11 | [](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aopen+is%3Aissue)
12 | [](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aissue+is%3Aclosed)
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | 💡 We also provide [[ENGLISH DOC](README.md)] and [[KOREAN DOC](README_kr.md)] here. Suggestions and contributions to this project are warmly welcomed and appreciated.
21 |
22 | ## 📣 News
23 |
24 | * **[2024.09.05]** 🔥 We officially released the initial version of DeepFake Defenders, and we won the third prize in the Deepfake Challenge at [[the Conference on the Bund](https://www.atecup.cn/deepfake)].
26 |
27 | ## 🚀 Quick Start
28 | ### 1. Pretrained Models Preparation
29 | Before getting started, please place the models' ImageNet-1K pretrained weight files in the ./pre_model directory. The download links are as follows:
30 | ```
31 | RepLKNet: https://drive.google.com/file/d/1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN/view?usp=sharing
32 | ConvNeXt: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth
33 | ```
34 |
35 | ### 2. Training
36 |
37 | #### 1. Modifying the dataset path
38 | Place the training-set txt file, validation-set txt file, and label txt file required for training in the dataset folder, giving them matching file names (example txt files are provided under dataset).
39 | #### 2. Modifying the hyper-parameters
40 | For the two models used, change the following parameters in main_train.py:
41 | ```python
42 | cfg.network.name = 'replknet'; cfg.train.batch_size = 16  # RepLKNet
43 | cfg.network.name = 'convnext'; cfg.train.batch_size = 24  # ConvNeXt
44 | ```
45 |
46 | #### 3. Start training
47 | ##### Multi-GPU training (single node, 8 GPUs):
48 | ```shell
49 | bash main.sh
50 | ```
51 | ##### Single-GPU training:
52 | ```shell
53 | CUDA_VISIBLE_DEVICES=0 python main_train_single_gpu.py
54 | ```
55 |
56 | #### 4. Model Assembling
57 | Replace the ConvNeXt model path and the RepLKNet model path in merge.py, then run python merge.py to obtain the final inference test model.
58 |
59 | #### 5. Inference
60 | 
61 | The following example sends a POST request with the image path as the request parameter; the response is the deepfake score predicted by the model.
62 |
63 | ```python
64 | #!/usr/bin/env python
65 | # -*- coding:utf-8 -*-
66 | import requests
67 | import json
70 |
71 | header = {
72 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36'
73 | }
74 |
75 | url = 'http://ip:10005/inter_api'
76 | image_path = './dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
77 | data_map = {'img_path':image_path}
78 | response = requests.post(url, json=data_map, headers=header)
79 | content = response.content
80 | print(json.loads(content))
81 | ```
82 |
83 | ### 3. Docker
84 | #### 1. Building the image
85 | sudo docker build -t vision-rush-image:1.0.1 --network host .
86 | #### 2. Starting the container
87 | sudo docker run -d --name vision_rush_image --gpus=all --net host vision-rush-image:1.0.1
88 |
89 | ## Star History
90 |
91 | [](https://star-history.com/#DeepFakeDefenders/DeepFakeDefenders&Date)
92 |
--------------------------------------------------------------------------------
/core/dsproc_mcls.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from PIL import Image
4 | from collections import OrderedDict
5 | from toolkit.dhelper import traverse_recursively
6 | import numpy as np
7 | import einops
8 |
9 | from torch import nn
10 | import timm
11 | import torch.nn.functional as F
12 |
13 |
14 | class SRMConv2d_simple(nn.Module):
15 | def __init__(self, inc=3):
16 | super(SRMConv2d_simple, self).__init__()
17 | self.truc = nn.Hardtanh(-3, 3)
18 | self.kernel = torch.from_numpy(self._build_kernel(inc)).float()
19 |
20 | def forward(self, x):
21 | out = F.conv2d(x, self.kernel, stride=1, padding=2)
22 | out = self.truc(out)
23 |
24 | return out
25 |
26 | def _build_kernel(self, inc):
27 | # filter1: KB
28 | filter1 = [[0, 0, 0, 0, 0],
29 | [0, -1, 2, -1, 0],
30 | [0, 2, -4, 2, 0],
31 | [0, -1, 2, -1, 0],
32 | [0, 0, 0, 0, 0]]
33 | # filter2:KV
34 | filter2 = [[-1, 2, -2, 2, -1],
35 | [2, -6, 8, -6, 2],
36 | [-2, 8, -12, 8, -2],
37 | [2, -6, 8, -6, 2],
38 | [-1, 2, -2, 2, -1]]
39 |         # filter3: horizontal 2nd-order
40 | filter3 = [[0, 0, 0, 0, 0],
41 | [0, 0, 0, 0, 0],
42 | [0, 1, -2, 1, 0],
43 | [0, 0, 0, 0, 0],
44 | [0, 0, 0, 0, 0]]
45 |
46 | filter1 = np.asarray(filter1, dtype=float) / 4.
47 | filter2 = np.asarray(filter2, dtype=float) / 12.
48 | filter3 = np.asarray(filter3, dtype=float) / 2.
49 |         # stack the filters
50 | filters = [[filter1], # , filter1, filter1],
51 | [filter2], # , filter2, filter2],
52 | [filter3]] # , filter3, filter3]]
53 | filters = np.array(filters)
54 | filters = np.repeat(filters, inc, axis=1)
55 | return filters
56 |
57 |
58 | class MultiClassificationProcessor(torch.utils.data.Dataset):
59 |
60 | def __init__(self, transform=None):
61 | self.transformer_ = transform
62 | self.extension_ = '.jpg .jpeg .png .bmp .webp .tif .eps'
63 | # load category info
64 | self.ctg_names_ = [] # ctg_idx to ctg_name
65 | self.ctg_name2idx_ = OrderedDict() # ctg_name to ctg_idx
66 | # load image infos
67 | self.img_names_ = [] # img_idx to img_name
68 | self.img_paths_ = [] # img_idx to img_path
69 | self.img_labels_ = [] # img_idx to img_label
70 |
71 | self.srm = SRMConv2d_simple()
72 |
73 | def load_data_from_dir(self, dataset_list):
74 | """Load image from folder.
75 |
76 | Args:
77 | dataset_list: dataset list, each folder is a category, format is [file_root].
78 | """
79 | # load sample
80 | for img_root in dataset_list:
81 | ctg_name = os.path.basename(img_root)
82 | self.ctg_name2idx_[ctg_name] = len(self.ctg_names_)
83 | self.ctg_names_.append(ctg_name)
84 | img_paths = []
85 | traverse_recursively(img_root, img_paths, self.extension_)
86 | for img_path in img_paths:
87 | img_name = os.path.basename(img_path)
88 | self.img_names_.append(img_name)
89 | self.img_paths_.append(img_path)
90 | self.img_labels_.append(self.ctg_name2idx_[ctg_name])
91 | print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, len(img_paths)))
92 |
93 | def load_data_from_txt(self, img_list_txt, ctg_list_txt):
94 | """Load image from txt.
95 |
96 | Args:
97 | img_list_txt: image txt, format is [file_path, ctg_idx].
98 | ctg_list_txt: category txt, format is [ctg_name, ctg_idx].
99 | """
100 | # check
101 | assert os.path.exists(img_list_txt), 'log: does not exist: {}'.format(img_list_txt)
102 | assert os.path.exists(ctg_list_txt), 'log: does not exist: {}'.format(ctg_list_txt)
103 |
104 | # load category
105 | # : open category info file
106 | with open(ctg_list_txt) as f:
107 | ctg_infos = [line.strip() for line in f.readlines()]
108 | # :load category name & category index
109 | for ctg_info in ctg_infos:
110 | tmp = ctg_info.split(' ')
111 | ctg_name = tmp[0]
112 | ctg_idx = int(tmp[-1])
113 | self.ctg_name2idx_[ctg_name] = ctg_idx
114 | self.ctg_names_.append(ctg_name)
115 |
116 | # load sample
117 | # : open image info file
118 | with open(img_list_txt) as f:
119 | img_infos = [line.strip() for line in f.readlines()]
120 | # : load image path & category index
121 | for img_info in img_infos:
122 | tmp = img_info.split(' ')
123 |
124 | img_path = ' '.join(tmp[:-1])
125 | img_name = img_path.split('/')[-1]
126 | ctg_idx = int(tmp[-1])
127 | self.img_names_.append(img_name)
128 | self.img_paths_.append(img_path)
129 | self.img_labels_.append(ctg_idx)
130 |
131 | for ctg_name in self.ctg_names_:
132 | print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, self.img_labels_.count(self.ctg_name2idx_[ctg_name])))
133 |
134 | def _add_new_channels_worker(self, image):
135 | new_channels = []
136 |
137 | image = einops.rearrange(image, "h w c -> c h w")
138 |         image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
139 | srm = self.srm(image.unsqueeze(0)).squeeze(0)
140 | new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())
141 |
142 | new_channels = np.concatenate(new_channels, axis=2)
143 | return torch.from_numpy(new_channels).float()
144 |
145 | def add_new_channels(self, images):
146 | images_copied = einops.rearrange(images, "c h w -> h w c")
147 | new_channels = self._add_new_channels_worker(images_copied)
148 | images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
149 | images_copied = einops.rearrange(images_copied, "h w c -> c h w")
150 |
151 | return images_copied
152 |
153 | def __getitem__(self, index):
154 | img_path = self.img_paths_[index]
155 | img_label = self.img_labels_[index]
156 |
157 | img_data = Image.open(img_path).convert('RGB')
158 | img_size = img_data.size[::-1] # [h, w]
159 |
160 | if self.transformer_ is not None:
161 | img_data = self.transformer_[img_label](img_data)
162 | img_data = self.add_new_channels(img_data)
163 |
164 | return img_data, img_label, img_path, img_size
165 |
166 | def __len__(self):
167 | return len(self.img_names_)
168 |
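169 | # Usage sketch (hypothetical; mirrors how main_train.py wires this dataset up):
170 | #   tfs = [tf_for_class0, tf_for_class1]  # note: __getitem__ indexes the transform list by class label
171 | #   dataset = MultiClassificationProcessor(transform=tfs)
172 | #   dataset.load_data_from_txt('./dataset/train.txt', './dataset/label.txt')
173 | #   loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)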
--------------------------------------------------------------------------------
/core/dsproc_mclsmfolder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from PIL import Image
4 | from collections import OrderedDict
5 | from toolkit.dhelper import traverse_recursively
6 | import random
7 | from torch import nn
8 | import numpy as np
9 | import timm
10 | import einops
11 | import torch.nn.functional as F
12 |
13 |
14 | class SRMConv2d_simple(nn.Module):
15 | def __init__(self, inc=3):
16 | super(SRMConv2d_simple, self).__init__()
17 | self.truc = nn.Hardtanh(-3, 3)
18 | self.kernel = torch.from_numpy(self._build_kernel(inc)).float()
19 |
20 | def forward(self, x):
21 | out = F.conv2d(x, self.kernel, stride=1, padding=2)
22 | out = self.truc(out)
23 |
24 | return out
25 |
26 | def _build_kernel(self, inc):
27 | # filter1: KB
28 | filter1 = [[0, 0, 0, 0, 0],
29 | [0, -1, 2, -1, 0],
30 | [0, 2, -4, 2, 0],
31 | [0, -1, 2, -1, 0],
32 | [0, 0, 0, 0, 0]]
33 | # filter2:KV
34 | filter2 = [[-1, 2, -2, 2, -1],
35 | [2, -6, 8, -6, 2],
36 | [-2, 8, -12, 8, -2],
37 | [2, -6, 8, -6, 2],
38 | [-1, 2, -2, 2, -1]]
39 | # filter3:hor 2rd
40 | filter3 = [[0, 0, 0, 0, 0],
41 | [0, 0, 0, 0, 0],
42 | [0, 1, -2, 1, 0],
43 | [0, 0, 0, 0, 0],
44 | [0, 0, 0, 0, 0]]
45 |
46 | filter1 = np.asarray(filter1, dtype=float) / 4.
47 | filter2 = np.asarray(filter2, dtype=float) / 12.
48 | filter3 = np.asarray(filter3, dtype=float) / 2.
49 |         # stack the filters
50 | filters = [[filter1], # , filter1, filter1],
51 | [filter2], # , filter2, filter2],
52 | [filter3]] # , filter3, filter3]]
53 | filters = np.array(filters)
54 | filters = np.repeat(filters, inc, axis=1)
55 | return filters
56 |
57 |
58 | class MultiClassificationProcessor_mfolder(torch.utils.data.Dataset):
59 | def __init__(self, transform=None):
60 | self.transformer_ = transform
61 | self.extension_ = '.jpg .jpeg .png .bmp .webp .tif .eps'
62 | # load category info
63 | self.ctg_names_ = [] # ctg_idx to ctg_name
64 | self.ctg_name2idx_ = OrderedDict() # ctg_name to ctg_idx
65 | # load image infos
66 | self.img_names_ = [] # img_idx to img_name
67 | self.img_paths_ = [] # img_idx to img_path
68 | self.img_labels_ = [] # img_idx to img_label
69 |
70 | self.srm = SRMConv2d_simple()
71 |
72 |     def load_data_from_dir_test(self, folders):
73 |         """Load images recursively from a folder (test mode, no labels are assigned).
74 | 
75 |         Args:
76 |             folders: root folder to traverse for image files.
77 |         """
78 | print(folders)
79 | img_paths = []
80 | traverse_recursively(folders, img_paths, self.extension_)
81 |
82 | for img_path in img_paths:
83 | img_name = os.path.basename(img_path)
84 | self.img_names_.append(img_name)
85 | self.img_paths_.append(img_path)
86 |
87 | length = len(img_paths)
88 | print('log: {} image num is {}'.format(folders, length))
89 |
90 |     def load_data_from_dir(self, dataset_list):
91 |         """Load images from folders.
92 | 
93 |         Args:
94 |             dataset_list: dictionary mapping a category name to a list of folder paths.
95 |         """
96 | 
97 | for ctg_name, folders in dataset_list.items():
98 |
99 | if ctg_name not in self.ctg_name2idx_:
100 | self.ctg_name2idx_[ctg_name] = len(self.ctg_names_)
101 | self.ctg_names_.append(ctg_name)
102 |
103 | for img_root in folders:
104 | img_paths = []
105 | traverse_recursively(img_root, img_paths, self.extension_)
106 |
107 | print(img_root)
108 |
109 | length = len(img_paths)
110 | for i in range(length):
111 | img_path = img_paths[i]
112 | img_name = os.path.basename(img_path)
113 | self.img_names_.append(img_name)
114 | self.img_paths_.append(img_path)
115 | self.img_labels_.append(self.ctg_name2idx_[ctg_name])
116 |
117 | print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, length))
118 |
119 | def load_data_from_txt(self, img_list_txt, ctg_list_txt):
120 | """Load image from txt.
121 |
122 | Args:
123 | img_list_txt: image txt, format is [file_path, ctg_idx].
124 | ctg_list_txt: category txt, format is [ctg_name, ctg_idx].
125 | """
126 | # check
127 | assert os.path.exists(img_list_txt), 'log: does not exist: {}'.format(img_list_txt)
128 | assert os.path.exists(ctg_list_txt), 'log: does not exist: {}'.format(ctg_list_txt)
129 |
130 | # load category
131 | # : open category info file
132 | with open(ctg_list_txt) as f:
133 | ctg_infos = [line.strip() for line in f.readlines()]
134 | # :load category name & category index
135 | for ctg_info in ctg_infos:
136 | tmp = ctg_info.split(' ')
137 | ctg_name = tmp[0]
138 | ctg_idx = int(tmp[1])
139 | self.ctg_name2idx_[ctg_name] = ctg_idx
140 | self.ctg_names_.append(ctg_name)
141 |
142 | # load sample
143 | # : open image info file
144 | with open(img_list_txt) as f:
145 | img_infos = [line.strip() for line in f.readlines()]
146 | random.shuffle(img_infos)
147 | # : load image path & category index
148 | for img_info in img_infos:
149 | img_path, ctg_name = img_info.rsplit(' ', 1)
150 | img_name = img_path.split('/')[-1]
151 | ctg_idx = int(ctg_name)
152 | self.img_names_.append(img_name)
153 | self.img_paths_.append(img_path)
154 | self.img_labels_.append(ctg_idx)
155 |
156 | for ctg_name in self.ctg_names_:
157 | print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, self.img_labels_.count(self.ctg_name2idx_[ctg_name])))
158 |
159 | def _add_new_channels_worker(self, image):
160 | new_channels = []
161 |
162 | image = einops.rearrange(image, "h w c -> c h w")
163 |         image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
164 | srm = self.srm(image.unsqueeze(0)).squeeze(0)
165 | new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())
166 |
167 | new_channels = np.concatenate(new_channels, axis=2)
168 | return torch.from_numpy(new_channels).float()
169 |
170 | def add_new_channels(self, images):
171 | images_copied = einops.rearrange(images, "c h w -> h w c")
172 | new_channels = self._add_new_channels_worker(images_copied)
173 | images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
174 | images_copied = einops.rearrange(images_copied, "h w c -> c h w")
175 |
176 | return images_copied
177 |
178 | def __getitem__(self, index):
179 | img_path = self.img_paths_[index]
180 |
181 | img_data = Image.open(img_path).convert('RGB')
182 | img_size = img_data.size[::-1] # [h, w]
183 |
184 | all_data = []
185 | for transform in self.transformer_:
186 | current_data = transform(img_data)
187 | current_data = self.add_new_channels(current_data)
188 | all_data.append(current_data)
189 | img_label = self.img_labels_[index]
190 |
191 | return torch.stack(all_data, dim=0), img_label, img_path, img_size
192 |
193 | def __len__(self):
194 | return len(self.img_names_)
195 |
--------------------------------------------------------------------------------
/core/mengine.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import sys
4 |
5 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.parallel import DistributedDataParallel as DDP
9 | from tqdm import tqdm
10 | from toolkit.cmetric import MultiClassificationMetric, MultilabelClassificationMetric, simple_accuracy
11 | from toolkit.chelper import load_model
12 | from torch import distributed as dist
13 | from sklearn.metrics import roc_auc_score
14 | import numpy as np
15 | import time
16 |
17 |
18 | def reduce_tensor(tensor, n):
19 | rt = tensor.clone()
20 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
21 | rt /= n
22 | return rt
23 |
24 |
25 | def gather_tensor(tensor, n):
26 | rt = [torch.zeros_like(tensor) for _ in range(n)]
27 | dist.all_gather(rt, tensor)
28 | return torch.cat(rt, dim=0)
29 |
30 |
31 | class TrainEngine(object):
32 | def __init__(self, local_rank, world_size=0, DDP=False, SyncBatchNorm=False):
33 | # init setting
34 | self.local_rank = local_rank
35 | self.world_size = world_size
36 | self.device_ = f'cuda:{local_rank}'
37 | # create tool
38 | self.cls_meter_ = MultilabelClassificationMetric()
39 | self.loss_meter_ = MultiClassificationMetric()
40 | self.top1_meter_ = MultiClassificationMetric()
41 | self.DDP = DDP
42 | self.SyncBN = SyncBatchNorm
43 |
44 | def create_env(self, cfg):
45 | # create network
46 | self.netloc_ = load_model(cfg.network.name, cfg.network.class_num, self.SyncBN)
47 | print(self.netloc_)
48 |
49 | self.netloc_.cuda()
50 | if self.DDP:
51 | if self.SyncBN:
52 | self.netloc_ = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.netloc_)
53 | self.netloc_ = DDP(self.netloc_,
54 | device_ids=[self.local_rank],
55 | broadcast_buffers=True,
56 | )
57 |
58 | # create loss function
59 | self.criterion_ = nn.CrossEntropyLoss().cuda()
60 |
61 | # create optimizer
62 | self.optimizer_ = torch.optim.AdamW(self.netloc_.parameters(), lr=cfg.optimizer.lr,
63 | betas=(cfg.optimizer.beta1, cfg.optimizer.beta2), eps=cfg.optimizer.eps,
64 | weight_decay=cfg.optimizer.weight_decay)
65 |
66 | # create scheduler
67 | self.scheduler_ = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_, cfg.train.epoch_num,
68 | eta_min=cfg.scheduler.min_lr)
69 |
70 | def train_multi_class(self, train_loader, epoch_idx, ema_start):
71 | starttime = datetime.datetime.now()
72 | # switch to train mode
73 | self.netloc_.train()
74 | self.loss_meter_.reset()
75 | self.top1_meter_.reset()
76 | # train
77 | train_loader = tqdm(train_loader, desc='train', ascii=True)
78 | for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(train_loader):
79 | # set cuda
80 |             imgs_tensor = imgs_tensor.cuda()  # [batch, 6, H, W] after add_new_channels appends the SRM residual channels
81 | imgs_label = imgs_label.cuda()
82 | # clear gradients(zero the parameter gradients)
83 | self.optimizer_.zero_grad()
84 | # calc forward
85 | preds = self.netloc_(imgs_tensor)
86 | # calc acc & loss
87 | loss = self.criterion_(preds, imgs_label)
88 |
89 | # backpropagation
90 | loss.backward()
91 | # update parameters
92 | self.optimizer_.step()
93 |
94 | # EMA update
95 | if ema_start:
96 | self.ema_model.update(self.netloc_)
97 |
98 | # accumulate loss & acc
99 | acc1 = simple_accuracy(preds, imgs_label)
100 | if self.DDP:
101 | loss = reduce_tensor(loss, self.world_size)
102 | acc1 = reduce_tensor(acc1, self.world_size)
103 | self.loss_meter_.update(loss.data.item())
104 | self.top1_meter_.update(acc1.item())
105 |
106 | # eval
107 | top1 = self.top1_meter_.mean
108 | loss = self.loss_meter_.mean
109 | endtime = datetime.datetime.now()
110 | self.lr_ = self.optimizer_.param_groups[0]['lr']
111 | if self.local_rank == 0:
112 | print('log: epoch-%d, train_top1 is %f, train_loss is %f, lr is %f, time is %d' % (
113 | epoch_idx, top1, loss, self.lr_, (endtime - starttime).seconds))
114 | # return
115 | return top1, loss, self.lr_
116 |
117 | def val_multi_class(self, val_loader, epoch_idx):
118 | np.set_printoptions(suppress=True)
119 | starttime = datetime.datetime.now()
120 |         # switch to eval mode
121 | self.netloc_.eval()
122 | self.loss_meter_.reset()
123 | self.top1_meter_.reset()
124 | self.all_probs = []
125 | self.all_labels = []
126 | # eval
127 | with torch.no_grad():
128 | val_loader = tqdm(val_loader, desc='valid', ascii=True)
129 | for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(val_loader):
130 | # set cuda
131 | imgs_tensor = imgs_tensor.cuda()
132 | imgs_label = imgs_label.cuda()
133 | # calc forward
134 | preds = self.netloc_(imgs_tensor)
135 | # calc acc & loss
136 | loss = self.criterion_(preds, imgs_label)
137 | # accumulate loss & acc
138 | acc1 = simple_accuracy(preds, imgs_label)
139 |
140 | outputs_scores = nn.functional.softmax(preds, dim=1)
141 | outputs_scores = torch.cat((outputs_scores, imgs_label.unsqueeze(-1)), dim=-1)
142 |
143 | if self.DDP:
144 | loss = reduce_tensor(loss, self.world_size)
145 | acc1 = reduce_tensor(acc1, self.world_size)
146 | outputs_scores = gather_tensor(outputs_scores, self.world_size)
147 |
148 | outputs_scores, label = outputs_scores[:, -2], outputs_scores[:, -1]
149 | self.all_probs += [float(i) for i in outputs_scores]
150 | self.all_labels += [ float(i) for i in label]
151 | self.loss_meter_.update(loss.item())
152 | self.top1_meter_.update(acc1.item())
153 | # eval
154 | top1 = self.top1_meter_.mean
155 | loss = self.loss_meter_.mean
156 | auc = roc_auc_score(self.all_labels, self.all_probs)
157 |
158 | endtime = datetime.datetime.now()
159 | if self.local_rank == 0:
160 | print('log: epoch-%d, val_top1 is %f, val_loss is %f, auc is %f, time is %d' % (
161 | epoch_idx, top1, loss, auc, (endtime - starttime).seconds))
162 |
163 | # update lr
164 | self.scheduler_.step()
165 |
166 | # return
167 | return top1, loss, auc
168 |
169 | def val_ema(self, val_loader, epoch_idx):
170 | np.set_printoptions(suppress=True)
171 | starttime = datetime.datetime.now()
172 |         # switch to eval mode
173 | self.ema_model.module.eval()
174 | self.loss_meter_.reset()
175 | self.top1_meter_.reset()
176 | self.all_probs = []
177 | self.all_labels = []
178 | # eval
179 | with torch.no_grad():
180 | val_loader = tqdm(val_loader, desc='valid', ascii=True)
181 | for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(val_loader):
182 | # set cuda
183 | imgs_tensor = imgs_tensor.cuda()
184 | imgs_label = imgs_label.cuda()
185 | # calc forward
186 | preds = self.ema_model.module(imgs_tensor)
187 |
188 | # calc acc & loss
189 | loss = self.criterion_(preds, imgs_label)
190 | # accumulate loss & acc
191 | acc1 = simple_accuracy(preds, imgs_label)
192 |
193 | outputs_scores = nn.functional.softmax(preds, dim=1)
194 | outputs_scores = torch.cat((outputs_scores, imgs_label.unsqueeze(-1)), dim=-1)
195 |
196 | if self.DDP:
197 | loss = reduce_tensor(loss, self.world_size)
198 | acc1 = reduce_tensor(acc1, self.world_size)
199 | outputs_scores = gather_tensor(outputs_scores, self.world_size)
200 |
201 | outputs_scores, label = outputs_scores[:, -2], outputs_scores[:, -1]
202 | self.all_probs += [float(i) for i in outputs_scores]
203 | self.all_labels += [ float(i) for i in label]
204 | self.loss_meter_.update(loss.item())
205 | self.top1_meter_.update(acc1.item())
206 | # eval
207 | top1 = self.top1_meter_.mean
208 | loss = self.loss_meter_.mean
209 | auc = roc_auc_score(self.all_labels, self.all_probs)
210 |
211 | endtime = datetime.datetime.now()
212 | if self.local_rank == 0:
213 | print('log: epoch-%d, ema_val_top1 is %f, ema_val_loss is %f, ema_auc is %f, time is %d' % (
214 | epoch_idx, top1, loss, auc, (endtime - starttime).seconds))
215 |
216 | # return
217 | return top1, loss, auc
218 |
219 | def save_checkpoint(self, file_root, epoch_idx, train_map, val_map, ema_start):
220 |
221 | file_name = os.path.join(file_root,
222 | time.strftime('%Y%m%d-%H-%M', time.localtime()) + '-' + str(epoch_idx) + '.pth')
223 |
224 |         if self.DDP:
225 |             state_dict = self.netloc_.module.state_dict()
226 |         else:
227 |             state_dict = self.netloc_.state_dict()
228 | 
229 |         torch.save(
230 |             {
231 |                 'epoch_idx': epoch_idx,
232 |                 'state_dict': state_dict,
233 | 'train_map': train_map,
234 | 'val_map': val_map,
235 | 'lr': self.lr_,
236 | 'optimizer': self.optimizer_.state_dict(),
237 | 'scheduler': self.scheduler_.state_dict()
238 | }, file_name)
239 |
240 | if ema_start:
241 | ema_file_name = os.path.join(file_root,
242 | time.strftime('%Y%m%d-%H-%M', time.localtime()) + '-EMA-' + str(epoch_idx) + '.pth')
243 |             ema_state_dict = self.ema_model.module.module.state_dict()
244 | torch.save(
245 | {
246 | 'epoch_idx': epoch_idx,
247 |                     'state_dict': ema_state_dict,
248 | 'train_map': train_map,
249 | 'val_map': val_map,
250 | 'lr': self.lr_,
251 | 'optimizer': self.optimizer_.state_dict(),
252 | 'scheduler': self.scheduler_.state_dict()
253 | }, ema_file_name)
254 |
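255 | # Usage sketch (hypothetical; the training scripts drive this engine roughly like this):
256 | #   engine = TrainEngine(local_rank=0, world_size=1, DDP=False, SyncBatchNorm=False)
257 | #   engine.create_env(cfg)
258 | #   for epoch in range(cfg.train.epoch_num):
259 | #       train_top1, train_loss, lr = engine.train_multi_class(train_loader, epoch, ema_start=False)
260 | #       val_top1, val_loss, auc = engine.val_multi_class(val_loader, epoch)
261 | #       engine.save_checkpoint('./checkpoints', epoch, train_top1, val_top1, ema_start=False)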
--------------------------------------------------------------------------------
/dataset/label.txt:
--------------------------------------------------------------------------------
1 | real 0
2 | fake 1
3 |
--------------------------------------------------------------------------------
/dataset/train.txt:
--------------------------------------------------------------------------------
1 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/b580b1fc51d19fc25d2969de07669c21.jpg 0
2 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/df36afc7a12cf840a961743e08bdd596.jpg 1
3 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/f9cec5f76c7f653c2f57d66d7b4ecee0.jpg 1
4 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/a81dc092765e18f3e343b78418cf9371.jpg 1
5 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/56461dd348c9434f44dc810fd06a640e.jpg 1
6 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/1124b822bb0f1076a9914aa454bbd65f.jpg 1
7 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/2fa1aae309a57e975c90285001d43982.jpg 1
8 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/921604adb7ff623bd2fe32d454a1469c.jpg 0
9 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/773f7e1488a29cc52c52b154250df907.jpg 1
10 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/371ad468133750a5fdc670063a6b115a.jpg 1
11 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/780821e5db83764213aae04ac5a54671.jpg 1
12 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/39c253b508dea029854a01de7a1389ab.jpg 0
13 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/9726ea54c28b55e38a5c7cf2fbd8d9da.jpg 1
14 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/4112df80cf4849d05196dc23ecf994cd.jpg 1
15 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/f9858bf9cb1316c273d272249b725912.jpg 0
16 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/bcb50b7c399f978aeb5432c9d80d855c.jpg 0
17 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/2ffd8043985f407069d77dfaae68e032.jpg 1
18 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0bfd7972fae0bc19f6087fc3b5ac6db8.jpg 1
19 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/daf90a7842ff5bd486ec10fbed22e932.jpg 1
20 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/9dbebbfbc11e8b757b090f5a5ad3fa48.jpg 1
21 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/b8d3d8c2c6cac9fb5b485b94e553f432.jpg 1
22 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0a59fe7481dc0f9a7dc76cb0bdd3ffe6.jpg 1
23 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/5b82f90800df81625ac78e51f51f1b2e.jpg 1
24 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/badd574c91e6180ef829e2b0d67a3efb.jpg 0
25 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/7412c839b06db42aac1e022096b08031.jpg 1
26 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/81e4b3e7ce314bcd28dde338caeda836.jpg 0
27 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/aa87a563062e6b0741936609014329ab.jpg 0
28 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0a4c5fdcbe7a3dca6c5a9ee45fd32bef.jpg 1
29 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/adfa3e7ea00ca1ce7a603a297c9ae701.jpg 1
30 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/31fcc0e2f049347b7220dd9eb4f66631.jpg 1
31 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/e699df9505f47dcbb1dcef6858f921e7.jpg 1
32 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/71e7824486a7fef83fa60324dd1cbba8.jpg 1
33 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/ed25cbc58d4f41f7c97201b1ba959834.jpg 1
34 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/4b3d2176926766a4c0e259605dbbc67a.jpg 0
35 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/1dfd77d7ea1a1f05b9c2f532b2a91c62.jpg 1
36 | /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/8e6bea47a8dd71c09c0272be5e1ca584.jpg 1
--------------------------------------------------------------------------------
/dataset/val.txt:
--------------------------------------------------------------------------------
1 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/590e6bc87984f2b4e6d1ed6d4e889088.jpg 1
2 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/720f2234a138382af10b3e2bb6c373cd.jpg 1
3 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/01cb2d00e5d2412ce3cd1d1bb58d7d4e.jpg 1
4 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/41d70d6650eba9036cbb145b29ad14f7.jpg 1
5 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/f7f4df6525cdf0ec27f8f40e2e980ad6.jpg 0
6 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/1dddd03ae6911514a6f1d3117e7e3fd3.jpg 1
7 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/d33054b233cb2e0ebddbe63611626924.jpg 0
8 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/27f2e00bd12d11173422119dfad885ef.jpg 0
9 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/1a0cb2060fbc2065f2ba74f5b2833bc5.jpg 0
10 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/7e0668030bb9a6598621cc7f12600660.jpg 1
11 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/4d7548156c06f9ab12d6daa6524956ea.jpg 1
12 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/cb6a567da3e2f0bcfd19f81756242ba1.jpg 1
13 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/fbff80c8dddf176f310fc10748ce5796.jpg 0
14 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/d68dce56f306f7b0965329f2389b2d5a.jpg 1
15 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/610198886f92d595aaf7cd5c83521ccb.jpg 1
16 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/987a546ad4b3fb76552a89af9b8f5761.jpg 1
17 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/db80dfbe1bb84fe1f9c3e1f21f80561b.jpg 0
18 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/133c775e0516b078f2b951fe49d6b04a.jpg 1
19 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/9584c3c8e012f92b003498793a8a6492.jpg 1
20 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg 0
21 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/965c7d35e7a714603587a4710c357ede.jpg 1
22 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/7db2752f0d45637ff64e67f14099378e.jpg 1
23 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/cd9838425bb7e68f165b25a148ba8146.jpg 1
24 | /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/88f45da6e89e59842a9e6339d239a78f.jpg 1
--------------------------------------------------------------------------------
/dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionRush/DeepFakeDefenders/13117ecf91c215a167126a6962fdd1525f7c957e/dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg
--------------------------------------------------------------------------------
/images/competition_title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionRush/DeepFakeDefenders/13117ecf91c215a167126a6962fdd1525f7c957e/images/competition_title.png
--------------------------------------------------------------------------------
/infer_api.py:
--------------------------------------------------------------------------------
1 | import uvicorn
2 | from fastapi import FastAPI
3 | from pydantic import BaseModel, Field
7 | from main_infer import INFER_API
8 |
9 |
10 | infer_api = INFER_API()
11 |
12 | # create FastAPI instance
13 | app = FastAPI()
14 |
15 |
16 | class InputModel(BaseModel):
17 |     img_path: str = Field(..., description="image path", examples=[""])
18 |
19 | # Call model interface, post request
20 | @app.post("/inter_api")
21 | def inter_api(input_model: InputModel):
22 |     img_path = input_model.img_path
23 |     infer_api = INFER_API()  # singleton: returns the instance already initialized at import time
24 | score = infer_api.test(img_path)
25 | return score
26 |
27 |
28 | # run
29 | if __name__ == '__main__':
30 | uvicorn.run(app='infer_api:app',
31 | host='0.0.0.0',
32 | port=10005,
33 | reload=False,
34 | workers=1
35 | )
36 |
--------------------------------------------------------------------------------
/main.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main_train.py
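2 | # Note: torch.distributed.launch is deprecated on newer PyTorch releases; the equivalent launch command there is: torchrun --nproc_per_node=8 main_train.py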
--------------------------------------------------------------------------------
/main_infer.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import timm
4 | import einops
5 | import torch
6 | from torch import nn
7 | from toolkit.dtransform import create_transforms_inference, create_transforms_inference1,\
8 | create_transforms_inference2,\
9 | create_transforms_inference3,\
10 | create_transforms_inference4,\
11 | create_transforms_inference5
12 | from toolkit.chelper import load_model
13 | import torch.nn.functional as F
14 |
15 |
16 | def extract_model_from_pth(params_path, net_model):
17 |     checkpoint = torch.load(params_path, map_location='cpu')
18 | state_dict = checkpoint['state_dict']
19 |
20 | net_model.load_state_dict(state_dict, strict=True)
21 |
22 | return net_model
23 |
24 |
25 | class SRMConv2d_simple(nn.Module):
26 | def __init__(self, inc=3):
27 | super(SRMConv2d_simple, self).__init__()
28 | self.truc = nn.Hardtanh(-3, 3)
29 | self.kernel = torch.from_numpy(self._build_kernel(inc)).float()
30 |
31 | def forward(self, x):
32 | out = F.conv2d(x, self.kernel, stride=1, padding=2)
33 | out = self.truc(out)
34 |
35 | return out
36 |
37 | def _build_kernel(self, inc):
38 | # filter1: KB
39 | filter1 = [[0, 0, 0, 0, 0],
40 | [0, -1, 2, -1, 0],
41 | [0, 2, -4, 2, 0],
42 | [0, -1, 2, -1, 0],
43 | [0, 0, 0, 0, 0]]
44 | # filter2:KV
45 | filter2 = [[-1, 2, -2, 2, -1],
46 | [2, -6, 8, -6, 2],
47 | [-2, 8, -12, 8, -2],
48 | [2, -6, 8, -6, 2],
49 | [-1, 2, -2, 2, -1]]
50 | # filter3:hor 2rd
51 | filter3 = [[0, 0, 0, 0, 0],
52 | [0, 0, 0, 0, 0],
53 | [0, 1, -2, 1, 0],
54 | [0, 0, 0, 0, 0],
55 | [0, 0, 0, 0, 0]]
56 |
57 | filter1 = np.asarray(filter1, dtype=float) / 4.
58 | filter2 = np.asarray(filter2, dtype=float) / 12.
59 | filter3 = np.asarray(filter3, dtype=float) / 2.
60 |         # stack the filters
61 | filters = [[filter1], # , filter1, filter1],
62 | [filter2], # , filter2, filter2],
63 | [filter3]] # , filter3, filter3]]
64 | filters = np.array(filters)
65 | filters = np.repeat(filters, inc, axis=1)
66 | return filters
67 |
68 |
69 | class INFER_API:
70 |
71 | _instance = None
72 |
73 | def __new__(cls):
74 | if cls._instance is None:
75 | cls._instance = super(INFER_API, cls).__new__(cls)
76 | cls._instance.initialize()
77 | return cls._instance
78 |
79 | def initialize(self):
80 | self.transformer_ = [create_transforms_inference(h=512, w=512),
81 | create_transforms_inference1(h=512, w=512),
82 | create_transforms_inference2(h=512, w=512),
83 | create_transforms_inference3(h=512, w=512),
84 | create_transforms_inference4(h=512, w=512),
85 | create_transforms_inference5(h=512, w=512)]
86 | self.srm = SRMConv2d_simple()
87 |
88 | # model init
89 | self.model = load_model('all', 2)
90 | model_path = './final_model_csv/final_model.pth'
91 | self.model = extract_model_from_pth(model_path, self.model)
92 |
93 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94 |         self.model = self.model.to(self.device)
95 |
96 | self.model.eval()
97 |
98 | def _add_new_channels_worker(self, image):
99 | new_channels = []
100 |
101 | image = einops.rearrange(image, "h w c -> c h w")
102 | image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(
103 | timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
104 | srm = self.srm(image.unsqueeze(0)).squeeze(0)
105 | new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())
106 |
107 | new_channels = np.concatenate(new_channels, axis=2)
108 | return torch.from_numpy(new_channels).float()
109 |
110 | def add_new_channels(self, images):
111 | images_copied = einops.rearrange(images, "c h w -> h w c")
112 | new_channels = self._add_new_channels_worker(images_copied)
113 |         images_copied = torch.cat([images_copied, new_channels], dim=-1)  # torch.cat: torch.concatenate only exists in newer torch, while torch 1.13.1 is pinned
114 | images_copied = einops.rearrange(images_copied, "h w c -> c h w")
115 |
116 | return images_copied
117 |
118 | def test(self, img_path):
119 | # img load
120 | img_data = Image.open(img_path).convert('RGB')
121 |
122 | # transform
123 | all_data = []
124 | for transform in self.transformer_:
125 | current_data = transform(img_data)
126 | current_data = self.add_new_channels(current_data)
127 | all_data.append(current_data)
128 |         img_tensor = torch.stack(all_data, dim=0).unsqueeze(0).to(self.device)
129 |
130 | preds = self.model(img_tensor)
131 |
132 | return round(float(preds), 20)
133 |
134 |
135 | def main():
136 |     img = './dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
137 | infer_api = INFER_API()
138 | print(infer_api.test(img))
139 |
140 |
141 | if __name__ == '__main__':
142 | main()
--------------------------------------------------------------------------------
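The SRMConv2d_simple module above applies three fixed high-pass (SRM) filters and clamps the residuals to [-3, 3]; the result is appended to the RGB input as extra channels. A minimal sketch of the module in isolation, on a random tensor:

    import torch
    from main_infer import SRMConv2d_simple

    srm = SRMConv2d_simple(inc=3)
    x = torch.rand(1, 3, 512, 512)   # one RGB image scaled to [0, 1]
    residual = srm(x)                # high-frequency noise residuals, clamped to [-3, 3]
    print(residual.shape)            # torch.Size([1, 3, 512, 512])
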
/main_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import datetime
4 | import torch
5 | import sys
6 |
7 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8 | from torch.utils.tensorboard import SummaryWriter
9 | from core.dsproc_mcls import MultiClassificationProcessor
10 | from core.mengine import TrainEngine
11 | from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train
12 | from toolkit.yacs import CfgNode as CN
13 | from timm.utils import ModelEmaV3
14 |
15 | import warnings
16 | warnings.filterwarnings("ignore")
17 |
18 | # check
19 | print(torch.__version__)
20 | print(torch.cuda.is_available())
21 |
22 | # init
23 | cfg = CN(new_allowed=True)
24 |
25 | # dataset dir
26 | ctg_list = './dataset/label.txt'
27 | train_list = './dataset/train.txt'
28 | val_list = './dataset/val.txt'
29 |
30 | # : network
31 | cfg.network = CN(new_allowed=True)
32 | cfg.network.name = 'replknet'
33 | cfg.network.class_num = 2
34 | cfg.network.input_size = 384
35 |
36 | # : train params
37 | mean = (0.485, 0.456, 0.406)
38 | std = (0.229, 0.224, 0.225)
39 |
40 | cfg.train = CN(new_allowed=True)
41 | cfg.train.resume = False
42 | cfg.train.resume_path = ''
43 | cfg.train.params_path = ''
44 | cfg.train.batch_size = 16
45 | cfg.train.epoch_num = 20
46 | cfg.train.epoch_start = 0
47 | cfg.train.worker_num = 8
48 |
49 | # : optimizer params
50 | cfg.optimizer = CN(new_allowed=True)
51 | cfg.optimizer.lr = 1e-4 * 1
52 | cfg.optimizer.weight_decay = 1e-2
53 | cfg.optimizer.momentum = 0.9
54 | cfg.optimizer.beta1 = 0.9
55 | cfg.optimizer.beta2 = 0.999
56 | cfg.optimizer.eps = 1e-8
57 |
58 | # : scheduler params
59 | cfg.scheduler = CN(new_allowed=True)
60 | cfg.scheduler.min_lr = 1e-6
61 |
62 | # DDP init
63 | local_rank = int(os.environ['LOCAL_RANK'])
64 | device = 'cuda:{}'.format(local_rank)
65 | torch.cuda.set_device(local_rank)
66 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
67 | world_size = torch.distributed.get_world_size()
68 | rank = torch.distributed.get_rank()
69 |
70 | # init path
71 | task = 'competition'
72 | log_root = 'output/' + datetime.datetime.now().strftime("%Y-%m-%d") + '-' + time.strftime(
73 | "%H-%M-%S") + '_' + cfg.network.name + '_' + f"to_{task}_BinClass"
74 | if local_rank == 0:
75 | if not os.path.exists(log_root):
76 | os.makedirs(log_root)
77 | writer = SummaryWriter(log_root)
78 |
79 | # create engine
80 | train_engine = TrainEngine(local_rank, world_size, DDP=True, SyncBatchNorm=True)
81 | train_engine.create_env(cfg)
82 |
83 | # create transforms
84 | transforms_dict = {
85 | 0 : transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size)),
86 | 1 : transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size), jpeg_compression=1),
87 | }
88 |
89 | transforms_dict_test = {
90 | 0: create_transforms_inference(h=512, w=512),
91 | 1: create_transforms_inference(h=512, w=512),
92 | }
93 |
94 | transform = transforms_dict
95 | transform_test = transforms_dict_test
96 |
97 | # create dataset
98 | trainset = MultiClassificationProcessor(transform)
99 | trainset.load_data_from_txt(train_list, ctg_list)
100 |
101 | valset = MultiClassificationProcessor(transform_test)
102 | valset.load_data_from_txt(val_list, ctg_list)
103 |
104 | train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
105 | val_sampler = torch.utils.data.distributed.DistributedSampler(valset)
106 |
107 | # create dataloader
108 | train_loader = torch.utils.data.DataLoader(dataset=trainset,
109 | batch_size=cfg.train.batch_size,
110 | sampler=train_sampler,
111 | num_workers=cfg.train.worker_num,
112 | pin_memory=True,
113 | drop_last=True)
114 |
115 | val_loader = torch.utils.data.DataLoader(dataset=valset,
116 | batch_size=cfg.train.batch_size,
117 | sampler=val_sampler,
118 | num_workers=cfg.train.worker_num,
119 | pin_memory=True,
120 | drop_last=False)
121 |
122 | train_log_txtFile = log_root + "/" + "train_log.txt"
123 | f_open = open(train_log_txtFile, "w") if local_rank == 0 else None  # only rank 0 creates log_root, so only rank 0 opens the log
124 |
125 | # train & Val & Test
126 | best_test_mAP = 0.0
127 | best_test_idx = 0.0
128 | ema_start = True
129 | train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()
130 | for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
131 | # train
132 | train_top1, train_loss, train_lr = train_engine.train_multi_class(train_loader=train_loader, epoch_idx=epoch_idx, ema_start=ema_start)
133 | # val
134 | val_top1, val_loss, val_auc = train_engine.val_multi_class(val_loader=val_loader, epoch_idx=epoch_idx)
135 | # ema_val
136 | if ema_start:
137 | ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(val_loader=val_loader, epoch_idx=epoch_idx)
138 |
139 | # check mAP and save
140 | if local_rank == 0:
141 | train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)
142 |
143 | if ema_start:
144 | outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n"
145 | else:
146 | outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n"
147 |
148 | print(outInfo)
149 |
150 | f_open.write(outInfo)
151 | f_open.flush()
152 |
153 | # curve all mAP & mLoss
154 | writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
155 | writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)
156 |
157 | # curve lr
158 | writer.add_scalar('train_lr', train_lr, epoch_idx)
159 |
--------------------------------------------------------------------------------
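One caveat worth flagging: the loop above never calls set_epoch on the DistributedSampler, so each epoch visits the shards in the same shuffled order. The usual DDP pattern (a sketch reusing the names defined above) is:

    for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
        train_sampler.set_epoch(epoch_idx)  # re-seed the sampler so shuffling differs per epoch
        train_top1, train_loss, train_lr = train_engine.train_multi_class(
            train_loader=train_loader, epoch_idx=epoch_idx, ema_start=ema_start)
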
/main_train_single_gpu.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import datetime
4 | import torch
5 | import sys
6 |
7 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8 | from torch.utils.tensorboard import SummaryWriter
9 | from core.dsproc_mcls import MultiClassificationProcessor
10 | from core.mengine import TrainEngine
11 | from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train
12 | from toolkit.yacs import CfgNode as CN
13 | from timm.utils import ModelEmaV3
14 |
15 | import warnings
16 |
17 | warnings.filterwarnings("ignore")
18 |
19 | # check
20 | print(torch.__version__)
21 | print(torch.cuda.is_available())
22 |
23 | # init
24 | cfg = CN(new_allowed=True)
25 |
26 | # dataset dir
27 | ctg_list = './dataset/label.txt'
28 | train_list = './dataset/train.txt'
29 | val_list = './dataset/val.txt'
30 |
31 | # : network
32 | cfg.network = CN(new_allowed=True)
33 | cfg.network.name = 'replknet'
34 | cfg.network.class_num = 2
35 | cfg.network.input_size = 384
36 |
37 | # : train params
38 | mean = (0.485, 0.456, 0.406)
39 | std = (0.229, 0.224, 0.225)
40 |
41 | cfg.train = CN(new_allowed=True)
42 | cfg.train.resume = False
43 | cfg.train.resume_path = ''
44 | cfg.train.params_path = ''
45 | cfg.train.batch_size = 16
46 | cfg.train.epoch_num = 20
47 | cfg.train.epoch_start = 0
48 | cfg.train.worker_num = 8
49 |
50 | # : optimizer params
51 | cfg.optimizer = CN(new_allowed=True)
52 | cfg.optimizer.lr = 1e-4 * 1
53 | cfg.optimizer.weight_decay = 1e-2
54 | cfg.optimizer.momentum = 0.9
55 | cfg.optimizer.beta1 = 0.9
56 | cfg.optimizer.beta2 = 0.999
57 | cfg.optimizer.eps = 1e-8
58 |
59 | # : scheduler params
60 | cfg.scheduler = CN(new_allowed=True)
61 | cfg.scheduler.min_lr = 1e-6
62 |
63 | # init path
64 | task = 'competition'
65 | log_root = 'output/' + datetime.datetime.now().strftime("%Y-%m-%d") + '-' + time.strftime(
66 | "%H-%M-%S") + '_' + cfg.network.name + '_' + f"to_{task}_BinClass"
67 |
68 | if not os.path.exists(log_root):
69 | os.makedirs(log_root)
70 | writer = SummaryWriter(log_root)
71 |
72 | # create engine
73 | train_engine = TrainEngine(0, 0, DDP=False, SyncBatchNorm=False)
74 | train_engine.create_env(cfg)
75 |
76 | # create transforms
77 | transforms_dict = {
78 | 0: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size)),
79 | 1: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size), jpeg_compression=1),
80 | }
81 |
82 | transforms_dict_test = {
83 | 0: create_transforms_inference(h=512, w=512),
84 | 1: create_transforms_inference(h=512, w=512),
85 | }
86 |
87 | transform = transforms_dict
88 | transform_test = transforms_dict_test
89 |
90 | # create dataset
91 | trainset = MultiClassificationProcessor(transform)
92 | trainset.load_data_from_txt(train_list, ctg_list)
93 |
94 | valset = MultiClassificationProcessor(transform_test)
95 | valset.load_data_from_txt(val_list, ctg_list)
96 |
97 | # create dataloader
98 | train_loader = torch.utils.data.DataLoader(dataset=trainset,
99 | batch_size=cfg.train.batch_size,
100 | num_workers=cfg.train.worker_num,
101 | shuffle=True,
102 | pin_memory=True,
103 | drop_last=True)
104 |
105 | val_loader = torch.utils.data.DataLoader(dataset=valset,
106 | batch_size=cfg.train.batch_size,
107 | num_workers=cfg.train.worker_num,
108 | shuffle=False,
109 | pin_memory=True,
110 | drop_last=False)
111 |
112 | train_log_txtFile = log_root + "/" + "train_log.txt"
113 | f_open = open(train_log_txtFile, "w")
114 |
115 | # train & Val & Test
116 | best_test_mAP = 0.0
117 | best_test_idx = 0.0
118 | ema_start = True
119 | train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()
120 | for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
121 | # train
122 | train_top1, train_loss, train_lr = train_engine.train_multi_class(train_loader=train_loader, epoch_idx=epoch_idx,
123 | ema_start=ema_start)
124 | # val
125 | val_top1, val_loss, val_auc = train_engine.val_multi_class(val_loader=val_loader, epoch_idx=epoch_idx)
126 | # ema_val
127 | if ema_start:
128 | ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(val_loader=val_loader, epoch_idx=epoch_idx)
129 |
130 | train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)
131 |
132 | if ema_start:
133 | outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n"
134 | else:
135 | outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n"
136 |
137 | print(outInfo)
138 |
139 | f_open.write(outInfo)
140 |     # flush the log file
141 | f_open.flush()
142 |
143 | # curve all mAP & mLoss
144 | writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
145 | writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)
146 |
147 | # curve lr
148 | writer.add_scalar('train_lr', train_lr, epoch_idx)
149 |
150 |
--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
1 | from toolkit.chelper import final_model
2 | import torch
3 | import os
4 |
5 |
6 | # Trained ConvNeXt and RepLKNet paths (for reference)
7 | convnext_path = './final_model_csv/convnext_end.pth'
8 | replknet_path = './final_model_csv/replk_end.pth'
9 |
10 | model = final_model()
11 | model.convnext.load_state_dict(torch.load(convnext_path, map_location='cpu')['state_dict'], strict=True)
12 | model.replknet.load_state_dict(torch.load(replknet_path, map_location='cpu')['state_dict'], strict=True)
13 |
14 | if not os.path.exists('./final_model_csv'):  # ensure the save directory exists
15 | os.makedirs('./final_model_csv')
16 |
17 | torch.save({'state_dict': model.state_dict()}, './final_model_csv/final_model.pth')
18 |
--------------------------------------------------------------------------------
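A quick sanity check for the merged checkpoint (a sketch mirroring how extract_model_from_pth in main_infer.py consumes it):

    import torch
    from toolkit.chelper import final_model

    model = final_model()
    state = torch.load('./final_model_csv/final_model.pth', map_location='cpu')['state_dict']
    model.load_state_dict(state, strict=True)  # must round-trip exactly
    model.eval()
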
/model/convnext.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from timm.models.layers import trunc_normal_, DropPath
13 | from timm.models.registry import register_model
14 |
15 | class Block(nn.Module):
16 | r""" ConvNeXt Block. There are two equivalent implementations:
17 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
18 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
19 | We use (2) as we find it slightly faster in PyTorch
20 |
21 | Args:
22 | dim (int): Number of input channels.
23 | drop_path (float): Stochastic depth rate. Default: 0.0
24 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
25 | """
26 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
27 | super().__init__()
28 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
29 | self.norm = LayerNorm(dim, eps=1e-6)
30 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
31 | self.act = nn.GELU()
32 | self.pwconv2 = nn.Linear(4 * dim, dim)
33 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
34 | requires_grad=True) if layer_scale_init_value > 0 else None
35 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
36 |
37 | def forward(self, x):
38 | input = x
39 | x = self.dwconv(x)
40 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
41 | x = self.norm(x)
42 | x = self.pwconv1(x)
43 | x = self.act(x)
44 | x = self.pwconv2(x)
45 | if self.gamma is not None:
46 | x = self.gamma * x
47 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
48 |
49 | x = input + self.drop_path(x)
50 | return x
51 |
52 | class ConvNeXt(nn.Module):
53 | r""" ConvNeXt
54 | A PyTorch impl of : `A ConvNet for the 2020s` -
55 | https://arxiv.org/pdf/2201.03545.pdf
56 |
57 | Args:
58 | in_chans (int): Number of input image channels. Default: 3
59 | num_classes (int): Number of classes for classification head. Default: 1000
60 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
61 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
62 | drop_path_rate (float): Stochastic depth rate. Default: 0.
63 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
64 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
65 | """
66 | def __init__(self, in_chans=3, num_classes=1000,
67 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
68 | layer_scale_init_value=1e-6, head_init_scale=1.,
69 | ):
70 | super().__init__()
71 |
72 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
73 | stem = nn.Sequential(
74 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
75 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
76 | )
77 | self.downsample_layers.append(stem)
78 | for i in range(3):
79 | downsample_layer = nn.Sequential(
80 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
81 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
82 | )
83 | self.downsample_layers.append(downsample_layer)
84 |
85 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
86 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
87 | cur = 0
88 | for i in range(4):
89 | stage = nn.Sequential(
90 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
91 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
92 | )
93 | self.stages.append(stage)
94 | cur += depths[i]
95 |
96 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
97 | self.head = nn.Linear(dims[-1], num_classes)
98 |
99 | self.apply(self._init_weights)
100 | self.head.weight.data.mul_(head_init_scale)
101 | self.head.bias.data.mul_(head_init_scale)
102 |
103 | def _init_weights(self, m):
104 | if isinstance(m, (nn.Conv2d, nn.Linear)):
105 | trunc_normal_(m.weight, std=.02)
106 | nn.init.constant_(m.bias, 0)
107 |
108 | def forward_features(self, x):
109 | for i in range(4):
110 | x = self.downsample_layers[i](x)
111 | x = self.stages[i](x)
112 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
113 |
114 | def forward(self, x):
115 | x = self.forward_features(x)
116 | x = self.head(x)
117 | return x
118 |
119 | class LayerNorm(nn.Module):
120 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
121 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
122 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
123 | with shape (batch_size, channels, height, width).
124 | """
125 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
126 | super().__init__()
127 | self.weight = nn.Parameter(torch.ones(normalized_shape))
128 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
129 | self.eps = eps
130 | self.data_format = data_format
131 | if self.data_format not in ["channels_last", "channels_first"]:
132 | raise NotImplementedError
133 | self.normalized_shape = (normalized_shape, )
134 |
135 | def forward(self, x):
136 | if self.data_format == "channels_last":
137 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
138 | elif self.data_format == "channels_first":
139 | u = x.mean(1, keepdim=True)
140 | s = (x - u).pow(2).mean(1, keepdim=True)
141 | x = (x - u) / torch.sqrt(s + self.eps)
142 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
143 | return x
144 |
145 |
146 | model_urls = {
147 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
148 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
149 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
150 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
151 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
152 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
153 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
154 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
155 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
156 | }
157 |
158 | @register_model
159 | def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
160 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
161 | if pretrained:
162 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
163 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
164 | model.load_state_dict(checkpoint["model"])
165 | return model
166 |
167 | @register_model
168 | def convnext_small(pretrained=False, in_22k=False, **kwargs):
169 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
170 | if pretrained:
171 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
172 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
173 | model.load_state_dict(checkpoint["model"])
174 | return model
175 |
176 | @register_model
177 | def convnext_base(pretrained=False, in_22k=False, **kwargs):
178 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
179 | if pretrained:
180 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
181 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
182 | model.load_state_dict(checkpoint["model"])
183 | return model
184 |
185 | @register_model
186 | def convnext_large(pretrained=False, in_22k=False, **kwargs):
187 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
188 | if pretrained:
189 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
190 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
191 | model.load_state_dict(checkpoint["model"])
192 | return model
193 |
194 | @register_model
195 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
196 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
197 | if pretrained:
198 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
199 | url = model_urls['convnext_xlarge_22k']
200 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
201 | model.load_state_dict(checkpoint["model"])
202 | return model
203 |
--------------------------------------------------------------------------------
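toolkit/chelper.py instantiates this backbone with a two-class head for the real/fake task. A minimal smoke test (sketch, run from the repo root; no pretrained weights are downloaded):

    import torch
    from model.convnext import convnext_base

    net = convnext_base(num_classes=2)
    net.eval()
    with torch.no_grad():
        logits = net(torch.randn(1, 3, 384, 384))  # cfg.network.input_size is 384 during training
    print(logits.shape)  # torch.Size([1, 2])
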
/model/replknet.py:
--------------------------------------------------------------------------------
1 | # Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs (https://arxiv.org/abs/2203.06717)
2 | # Github source: https://github.com/DingXiaoH/RepLKNet-pytorch
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Based on ConvNeXt, timm, DINO and DeiT code bases
5 | # https://github.com/facebookresearch/ConvNeXt
6 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
7 | # https://github.com/facebookresearch/deit/
8 | # https://github.com/facebookresearch/dino
9 | # --------------------------------------------------------
10 | import torch
11 | import torch.nn as nn
12 | import torch.utils.checkpoint as checkpoint
13 | from timm.models.layers import DropPath
14 | import sys
15 | import os
16 |
17 | def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
18 | if type(kernel_size) is int:
19 | use_large_impl = kernel_size > 5
20 | else:
21 | assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1]
22 | use_large_impl = kernel_size[0] > 5
23 | has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ
24 | if has_large_impl and in_channels == out_channels and out_channels == groups and use_large_impl and stride == 1 and padding == kernel_size // 2 and dilation == 1:
25 | sys.path.append(os.environ['LARGE_KERNEL_CONV_IMPL'])
26 | # Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md
27 | # export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py)
28 | # TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed.
29 | # Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it.
30 | from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
31 | return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
32 | else:
33 | return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
34 | padding=padding, dilation=dilation, groups=groups, bias=bias)
35 |
36 | use_sync_bn = False
37 |
38 | def enable_sync_bn():
39 | global use_sync_bn
40 | use_sync_bn = True
41 |
42 | def get_bn(channels):
43 | if use_sync_bn:
44 | return nn.SyncBatchNorm(channels)
45 | else:
46 | return nn.BatchNorm2d(channels)
47 |
48 | def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
49 | if padding is None:
50 | padding = kernel_size // 2
51 | result = nn.Sequential()
52 | result.add_module('conv', get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
53 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False))
54 | result.add_module('bn', get_bn(out_channels))
55 | return result
56 |
57 | def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
58 | if padding is None:
59 | padding = kernel_size // 2
60 | result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
61 | stride=stride, padding=padding, groups=groups, dilation=dilation)
62 | result.add_module('nonlinear', nn.ReLU())
63 | return result
64 |
65 | def fuse_bn(conv, bn):
66 | kernel = conv.weight
67 | running_mean = bn.running_mean
68 | running_var = bn.running_var
69 | gamma = bn.weight
70 | beta = bn.bias
71 | eps = bn.eps
72 | std = (running_var + eps).sqrt()
73 | t = (gamma / std).reshape(-1, 1, 1, 1)
74 | return kernel * t, beta - running_mean * gamma / std
75 |
76 | class ReparamLargeKernelConv(nn.Module):
77 |
78 | def __init__(self, in_channels, out_channels, kernel_size,
79 | stride, groups,
80 | small_kernel,
81 | small_kernel_merged=False):
82 | super(ReparamLargeKernelConv, self).__init__()
83 | self.kernel_size = kernel_size
84 | self.small_kernel = small_kernel
85 | # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.
86 | padding = kernel_size // 2
87 | if small_kernel_merged:
88 | self.lkb_reparam = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
89 | stride=stride, padding=padding, dilation=1, groups=groups, bias=True)
90 | else:
91 | self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
92 | stride=stride, padding=padding, dilation=1, groups=groups)
93 | if small_kernel is not None:
94 | assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!'
95 | self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=small_kernel,
96 | stride=stride, padding=small_kernel//2, groups=groups, dilation=1)
97 |
98 | def forward(self, inputs):
99 | if hasattr(self, 'lkb_reparam'):
100 | out = self.lkb_reparam(inputs)
101 | else:
102 | out = self.lkb_origin(inputs)
103 | if hasattr(self, 'small_conv'):
104 | out += self.small_conv(inputs)
105 | return out
106 |
107 | def get_equivalent_kernel_bias(self):
108 | eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
109 | if hasattr(self, 'small_conv'):
110 | small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
111 | eq_b += small_b
112 | # add to the central part
113 | eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
114 | return eq_k, eq_b
115 |
116 | def merge_kernel(self):
117 | eq_k, eq_b = self.get_equivalent_kernel_bias()
118 | self.lkb_reparam = get_conv2d(in_channels=self.lkb_origin.conv.in_channels,
119 | out_channels=self.lkb_origin.conv.out_channels,
120 | kernel_size=self.lkb_origin.conv.kernel_size, stride=self.lkb_origin.conv.stride,
121 | padding=self.lkb_origin.conv.padding, dilation=self.lkb_origin.conv.dilation,
122 | groups=self.lkb_origin.conv.groups, bias=True)
123 | self.lkb_reparam.weight.data = eq_k
124 | self.lkb_reparam.bias.data = eq_b
125 | self.__delattr__('lkb_origin')
126 | if hasattr(self, 'small_conv'):
127 | self.__delattr__('small_conv')
128 |
129 |
130 | class ConvFFN(nn.Module):
131 |
132 | def __init__(self, in_channels, internal_channels, out_channels, drop_path):
133 | super().__init__()
134 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
135 | self.preffn_bn = get_bn(in_channels)
136 | self.pw1 = conv_bn(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, padding=0, groups=1)
137 | self.pw2 = conv_bn(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1)
138 | self.nonlinear = nn.GELU()
139 |
140 | def forward(self, x):
141 | out = self.preffn_bn(x)
142 | out = self.pw1(out)
143 | out = self.nonlinear(out)
144 | out = self.pw2(out)
145 | return x + self.drop_path(out)
146 |
147 |
148 | class RepLKBlock(nn.Module):
149 |
150 | def __init__(self, in_channels, dw_channels, block_lk_size, small_kernel, drop_path, small_kernel_merged=False):
151 | super().__init__()
152 | self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
153 | self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
154 | self.large_kernel = ReparamLargeKernelConv(in_channels=dw_channels, out_channels=dw_channels, kernel_size=block_lk_size,
155 | stride=1, groups=dw_channels, small_kernel=small_kernel, small_kernel_merged=small_kernel_merged)
156 | self.lk_nonlinear = nn.ReLU()
157 | self.prelkb_bn = get_bn(in_channels)
158 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
159 | print('drop path:', self.drop_path)
160 |
161 | def forward(self, x):
162 | out = self.prelkb_bn(x)
163 | out = self.pw1(out)
164 | out = self.large_kernel(out)
165 | out = self.lk_nonlinear(out)
166 | out = self.pw2(out)
167 | return x + self.drop_path(out)
168 |
169 |
170 | class RepLKNetStage(nn.Module):
171 |
172 | def __init__(self, channels, num_blocks, stage_lk_size, drop_path,
173 | small_kernel, dw_ratio=1, ffn_ratio=4,
174 | use_checkpoint=False, # train with torch.utils.checkpoint to save memory
175 | small_kernel_merged=False,
176 | norm_intermediate_features=False):
177 | super().__init__()
178 | self.use_checkpoint = use_checkpoint
179 | blks = []
180 | for i in range(num_blocks):
181 | block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path
182 | # Assume all RepLK Blocks within a stage share the same lk_size. You may tune it on your own model.
183 | replk_block = RepLKBlock(in_channels=channels, dw_channels=int(channels * dw_ratio), block_lk_size=stage_lk_size,
184 | small_kernel=small_kernel, drop_path=block_drop_path, small_kernel_merged=small_kernel_merged)
185 | convffn_block = ConvFFN(in_channels=channels, internal_channels=int(channels * ffn_ratio), out_channels=channels,
186 | drop_path=block_drop_path)
187 | blks.append(replk_block)
188 | blks.append(convffn_block)
189 | self.blocks = nn.ModuleList(blks)
190 | if norm_intermediate_features:
191 | self.norm = get_bn(channels) # Only use this with RepLKNet-XL on downstream tasks
192 | else:
193 | self.norm = nn.Identity()
194 |
195 | def forward(self, x):
196 | for blk in self.blocks:
197 | if self.use_checkpoint:
198 | x = checkpoint.checkpoint(blk, x) # Save training memory
199 | else:
200 | x = blk(x)
201 | return x
202 |
203 | class RepLKNet(nn.Module):
204 |
205 | def __init__(self, large_kernel_sizes, layers, channels, drop_path_rate, small_kernel,
206 | dw_ratio=1, ffn_ratio=4, in_channels=3, num_classes=1000, out_indices=None,
207 | use_checkpoint=False,
208 | small_kernel_merged=False,
209 | use_sync_bn=True,
210 | norm_intermediate_features=False # for RepLKNet-XL on COCO and ADE20K, use an extra BN to normalize the intermediate feature maps then feed them into the heads
211 | ):
212 | super().__init__()
213 |
214 | if num_classes is None and out_indices is None:
215 | raise ValueError('must specify one of num_classes (for pretraining) and out_indices (for downstream tasks)')
216 | elif num_classes is not None and out_indices is not None:
217 | raise ValueError('cannot specify both num_classes (for pretraining) and out_indices (for downstream tasks)')
218 | elif num_classes is not None and norm_intermediate_features:
219 | raise ValueError('for pretraining, no need to normalize the intermediate feature maps')
220 | self.out_indices = out_indices
221 | if use_sync_bn:
222 | enable_sync_bn()
223 |
224 | base_width = channels[0]
225 | self.use_checkpoint = use_checkpoint
226 | self.norm_intermediate_features = norm_intermediate_features
227 | self.num_stages = len(layers)
228 | self.stem = nn.ModuleList([
229 | conv_bn_relu(in_channels=in_channels, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=1),
230 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=1, padding=1, groups=base_width),
231 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=1, stride=1, padding=0, groups=1),
232 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=base_width)])
233 | # stochastic depth. We set block-wise drop-path rate. The higher level blocks are more likely to be dropped. This implementation follows Swin.
234 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
235 | self.stages = nn.ModuleList()
236 | self.transitions = nn.ModuleList()
237 | for stage_idx in range(self.num_stages):
238 | layer = RepLKNetStage(channels=channels[stage_idx], num_blocks=layers[stage_idx],
239 | stage_lk_size=large_kernel_sizes[stage_idx],
240 | drop_path=dpr[sum(layers[:stage_idx]):sum(layers[:stage_idx + 1])],
241 | small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio,
242 | use_checkpoint=use_checkpoint, small_kernel_merged=small_kernel_merged,
243 | norm_intermediate_features=norm_intermediate_features)
244 | self.stages.append(layer)
245 | if stage_idx < len(layers) - 1:
246 | transition = nn.Sequential(
247 | conv_bn_relu(channels[stage_idx], channels[stage_idx + 1], 1, 1, 0, groups=1),
248 | conv_bn_relu(channels[stage_idx + 1], channels[stage_idx + 1], 3, stride=2, padding=1, groups=channels[stage_idx + 1]))
249 | self.transitions.append(transition)
250 |
251 | if num_classes is not None:
252 | self.norm = get_bn(channels[-1])
253 | self.avgpool = nn.AdaptiveAvgPool2d(1)
254 | self.head = nn.Linear(channels[-1], num_classes)
255 |
256 |
257 |
258 | def forward_features(self, x):
259 | x = self.stem[0](x)
260 | for stem_layer in self.stem[1:]:
261 | if self.use_checkpoint:
262 | x = checkpoint.checkpoint(stem_layer, x) # save memory
263 | else:
264 | x = stem_layer(x)
265 |
266 | if self.out_indices is None:
267 | # Just need the final output
268 | for stage_idx in range(self.num_stages):
269 | x = self.stages[stage_idx](x)
270 | if stage_idx < self.num_stages - 1:
271 | x = self.transitions[stage_idx](x)
272 | return x
273 | else:
274 | # Need the intermediate feature maps
275 | outs = []
276 | for stage_idx in range(self.num_stages):
277 | x = self.stages[stage_idx](x)
278 | if stage_idx in self.out_indices:
279 | outs.append(self.stages[stage_idx].norm(x)) # For RepLKNet-XL normalize the features before feeding them into the heads
280 | if stage_idx < self.num_stages - 1:
281 | x = self.transitions[stage_idx](x)
282 | return outs
283 |
284 | def forward(self, x):
285 | x = self.forward_features(x)
286 | if self.out_indices:
287 | return x
288 | else:
289 | x = self.norm(x)
290 | x = self.avgpool(x)
291 | x = torch.flatten(x, 1)
292 | x = self.head(x)
293 | return x
294 |
295 | def structural_reparam(self):
296 | for m in self.modules():
297 | if hasattr(m, 'merge_kernel'):
298 | m.merge_kernel()
299 |
300 | # If your framework cannot automatically fuse BN for inference, you may do it manually.
301 | # The BNs after and before conv layers can be removed.
302 | # No need to call this if your framework supports automatic BN fusion.
303 | def deep_fuse_BN(self):
304 | for m in self.modules():
305 | if not isinstance(m, nn.Sequential):
306 | continue
307 | if not len(m) in [2, 3]: # Only handle conv-BN or conv-BN-relu
308 | continue
309 | # If you use a custom Conv2d impl, assume it also has 'kernel_size' and 'weight'
310 | if hasattr(m[0], 'kernel_size') and hasattr(m[0], 'weight') and isinstance(m[1], nn.BatchNorm2d):
311 | conv = m[0]
312 | bn = m[1]
313 | fused_kernel, fused_bias = fuse_bn(conv, bn)
314 | fused_conv = get_conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size,
315 | stride=conv.stride,
316 | padding=conv.padding, dilation=conv.dilation, groups=conv.groups, bias=True)
317 | fused_conv.weight.data = fused_kernel
318 | fused_conv.bias.data = fused_bias
319 | m[0] = fused_conv
320 | m[1] = nn.Identity()
321 |
322 |
323 | def create_RepLKNet31B(drop_path_rate=0.5, num_classes=1000, use_checkpoint=False, small_kernel_merged=False, use_sync_bn=True):
324 | return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[128,256,512,1024],
325 | drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint,
326 | small_kernel_merged=small_kernel_merged, use_sync_bn=use_sync_bn)
327 |
328 | def create_RepLKNet31L(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False):
329 | return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[192,384,768,1536],
330 | drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint,
331 | small_kernel_merged=small_kernel_merged)
332 |
333 | def create_RepLKNetXL(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False):
334 | return RepLKNet(large_kernel_sizes=[27,27,27,13], layers=[2,2,18,2], channels=[256,512,1024,2048],
335 | drop_path_rate=drop_path_rate, small_kernel=None, dw_ratio=1.5,
336 | num_classes=num_classes, use_checkpoint=use_checkpoint,
337 | small_kernel_merged=small_kernel_merged)
338 |
339 | if __name__ == '__main__':
340 | model = create_RepLKNet31B(small_kernel_merged=False)
341 | model.eval()
342 | print('------------------- training-time model -------------')
343 | print(model)
344 | x = torch.randn(2, 3, 224, 224)
345 | origin_y = model(x)
346 | model.structural_reparam()
347 | print('------------------- after re-param -------------')
348 | print(model)
349 | reparam_y = model(x)
350 | print('------------------- the difference is ------------------------')
351 | print((origin_y - reparam_y).abs().sum())
352 |
353 |
354 |
--------------------------------------------------------------------------------
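toolkit/chelper.py builds the 31B variant with a two-class head; for a single-process smoke test, SyncBN should be disabled so plain BatchNorm2d is used (a sketch):

    import torch
    from model.replknet import create_RepLKNet31B

    net = create_RepLKNet31B(num_classes=2, use_sync_bn=False)
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 384, 384))
    print(out.shape)  # torch.Size([1, 2])
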
/requirements.txt:
--------------------------------------------------------------------------------
1 | asttokens==2.4.1
2 | einops==0.8.0
3 | numpy==1.22.0
4 | opencv-python==4.8.0.74
5 | pillow==9.5.0
6 | PyYAML==6.0.1
7 | scikit-image==0.21.0
8 | scikit-learn==1.3.2
9 | tensorboard==2.14.0
10 | tensorboard-data-server==0.7.2
11 | thop==0.1.1.post2209072238
12 | timm==0.6.13
13 | tqdm==4.66.4
14 | fastapi==0.103.1
15 | uvicorn==0.22.0
16 | pydantic==1.10.9
17 | torch==1.13.1
18 | torchvision==0.14.1
19 |
20 |
--------------------------------------------------------------------------------
/toolkit/chelper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.convnext import convnext_base
4 | import timm
5 | from model.replknet import create_RepLKNet31B
6 |
7 |
8 | class augment_inputs_network(nn.Module):
9 | def __init__(self, model):
10 | super(augment_inputs_network, self).__init__()
11 | self.model = model
12 | self.adapter = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=3, stride=1, padding=1)
13 |
14 | def forward(self, x):
15 | x = self.adapter(x)
16 |         x = (x - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN, device=x.device).view(1, -1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD, device=x.device).view(1, -1, 1, 1)  # x.device also works on CPU, unlike get_device()
17 |
18 | return self.model(x)
19 |
20 |
21 | class final_model(nn.Module):  # total parameter size: ~158.65 MB
22 | def __init__(self):
23 | super(final_model, self).__init__()
24 |
25 | self.convnext = convnext_base(num_classes=2)
26 | self.convnext = augment_inputs_network(self.convnext)
27 |
28 | self.replknet = create_RepLKNet31B(num_classes=2)
29 | self.replknet = augment_inputs_network(self.replknet)
30 |
31 | def forward(self, x):
32 | B, N, C, H, W = x.shape
33 | x = x.view(-1, C, H, W)
34 |
35 | pred1 = self.convnext(x)
36 | pred2 = self.replknet(x)
37 |
38 | outputs_score1 = nn.functional.softmax(pred1, dim=1)
39 | outputs_score2 = nn.functional.softmax(pred2, dim=1)
40 |
41 | predict_score1 = outputs_score1[:, 1]
42 | predict_score2 = outputs_score2[:, 1]
43 |
44 | predict_score1 = predict_score1.view(B, N).mean(dim=-1)
45 | predict_score2 = predict_score2.view(B, N).mean(dim=-1)
46 |
47 | return torch.stack((predict_score1, predict_score2), dim=-1).mean(dim=-1)
48 |
49 |
50 | def load_model(model_name, ctg_num, use_sync_bn=True):
51 |     """Build a backbone by name for this repo.
52 | 
53 |     Args:
54 |         model_name: one of 'convnext', 'replknet', 'all' (the fused ensemble)
55 |         ctg_num: number of output classes, e.g., 2
56 |         use_sync_bn: True/False (only used by 'replknet'; defaulted so 'all' can be loaded with two args)
57 |     """
58 | if model_name == 'convnext':
59 | model = convnext_base(num_classes=ctg_num)
60 | model_path = 'pre_model/convnext_base_1k_384.pth'
61 | check_point = torch.load(model_path, map_location='cpu')['model']
62 | check_point.pop('head.weight')
63 | check_point.pop('head.bias')
64 | model.load_state_dict(check_point, strict=False)
65 |
66 | model = augment_inputs_network(model)
67 |
68 | elif model_name == 'replknet':
69 | model = create_RepLKNet31B(num_classes=ctg_num, use_sync_bn=use_sync_bn)
70 | model_path = 'pre_model/RepLKNet-31B_ImageNet-1K_384.pth'
71 | check_point = torch.load(model_path)
72 | check_point.pop('head.weight')
73 | check_point.pop('head.bias')
74 | model.load_state_dict(check_point, strict=False)
75 |
76 | model = augment_inputs_network(model)
77 |
78 | elif model_name == 'all':
79 | model = final_model()
80 |
81 | print("model_name", model_name)
82 |
83 | return model
84 |
85 |
--------------------------------------------------------------------------------
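The fused model's forward expects a 5-D batch: each image arrives as N test-time views with six channels (RGB plus the SRM residual), and the softmax fake-scores are averaged over the views and over the two backbones. A shape sketch (run on a CUDA device, since the RepLKNet branch is built with SyncBN layers by default):

    import torch
    from toolkit.chelper import final_model

    model = final_model().cuda().eval()
    x = torch.randn(1, 6, 6, 512, 512).cuda()  # (B, N views, C = RGB + SRM, H, W)
    with torch.no_grad():
        scores = model(x)
    print(scores.shape)  # torch.Size([1]): one averaged fake-probability per image
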
/toolkit/cmetric.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import sklearn
5 | import sklearn.metrics
6 |
7 |
8 | class MultilabelClassificationMetric(object):
9 | def __init__(self):
10 | super(MultilabelClassificationMetric, self).__init__()
11 | self.pred_scores_ = torch.FloatTensor() # .FloatStorage()
12 | self.grth_labels_ = torch.LongTensor() # .LongStorage()
13 |
14 | # Func:
15 | # Reset calculation.
16 | def reset(self):
17 | self.pred_scores_ = torch.FloatTensor(torch.FloatStorage())
18 | self.grth_labels_ = torch.LongTensor(torch.LongStorage())
19 |
20 | # Func:
21 | # Add prediction and groundtruth that will be used to calculate average precision.
22 | # Input:
23 | # pred_scores : predicted scores, size: [batch_size, label_dim], format: [s0, s1, ..., s19]
24 | # grth_labels : groundtruth labels, size: [batch_size, label_dim], format: [c0, c1, ..., c19]
25 | def add(self, pred_scores, grth_labels):
26 | if not torch.is_tensor(pred_scores):
27 | pred_scores = torch.from_numpy(pred_scores)
28 | if not torch.is_tensor(grth_labels):
29 | grth_labels = torch.from_numpy(grth_labels)
30 |
31 | # check
32 | assert pred_scores.dim() == 2, 'wrong pred_scores size (should be 2D with format: [batch_size, label_dim(one column per class)])'
33 | assert grth_labels.dim() == 2, 'wrong grth_labels size (should be 2D with format: [batch_size, label_dim(one column per class)])'
34 |
35 | # check storage is sufficient
36 | if self.pred_scores_.storage().size() < self.pred_scores_.numel() + pred_scores.numel():
37 | new_size = math.ceil(self.pred_scores_.storage().size() * 1.5)
38 | self.pred_scores_.storage().resize_(int(new_size + pred_scores.numel()))
39 | self.grth_labels_.storage().resize_(int(new_size + pred_scores.numel()))
40 |
41 | # store outputs and targets
42 | offset = self.pred_scores_.size(0) if self.pred_scores_.dim() > 0 else 0
43 | self.pred_scores_.resize_(offset + pred_scores.size(0), pred_scores.size(1))
44 | self.grth_labels_.resize_(offset + grth_labels.size(0), grth_labels.size(1))
45 | self.pred_scores_.narrow(0, offset, pred_scores.size(0)).copy_(pred_scores)
46 | self.grth_labels_.narrow(0, offset, grth_labels.size(0)).copy_(grth_labels)
47 |
48 | # Func:
49 | # Compute average precision.
50 | def calc_avg_precision(self):
51 | # check
52 | if self.pred_scores_.numel() == 0: return 0
53 | # calc by class
54 | aps = torch.zeros(self.pred_scores_.size(1))
55 | for cls_idx in range(self.pred_scores_.size(1)):
56 | # get pred scores & grth labels at class cls_idx
57 | cls_pred_scores = self.pred_scores_[:, cls_idx] # predictions for all images at class cls_idx, format: [img_num]
58 |             cls_grth_labels = self.grth_labels_[:, cls_idx] # truth values for all images at class cls_idx, format: [img_num]
59 |             # sort by score
60 | _, img_indices = torch.sort(cls_pred_scores, dim=0, descending=True)
61 | # calc ap
62 | TP, TPFP = 0., 0.
63 | for img_idx in img_indices:
64 | label = cls_grth_labels[img_idx]
65 | # accumulate
66 | TPFP += 1
67 | if label == 1:
68 | TP += 1
69 | aps[cls_idx] += TP / TPFP
70 | aps[cls_idx] /= (TP + 1e-5)
71 | # return
72 | return aps
73 |
74 | # Func:
75 | # Compute average precision.
76 | def calc_avg_precision2(self):
77 | self.pred_scores_ = self.pred_scores_.cpu().numpy().astype('float32')
78 | self.grth_labels_ = self.grth_labels_.cpu().numpy().astype('float32')
79 | # check
80 | if self.pred_scores_.size == 0: return 0
81 | # calc by class
82 | aps = np.zeros(self.pred_scores_.shape[1])
83 | for cls_idx in range(self.pred_scores_.shape[1]):
84 | # get pred scores & grth labels at class cls_idx
85 | cls_pred_scores = self.pred_scores_[:, cls_idx]
86 | cls_grth_labels = self.grth_labels_[:, cls_idx]
87 |             # compute ap for an object category
88 | aps[cls_idx] = sklearn.metrics.average_precision_score(cls_grth_labels, cls_pred_scores)
89 | aps[np.isnan(aps)] = 0
90 | aps = np.around(aps, decimals=4)
91 | return aps
92 |
93 |
94 | class MultiClassificationMetric(object):
95 | """Computes and stores the average and current value"""
96 | def __init__(self):
97 | super(MultiClassificationMetric, self).__init__()
98 | self.reset()
99 | self.val = 0
100 |
101 | def update(self, value, n=1):
102 | self.val = value
103 | self.sum += value
104 | self.var += value * value
105 | self.n += n
106 |
107 | if self.n == 0:
108 | self.mean, self.std = np.nan, np.nan
109 | elif self.n == 1:
110 | self.mean, self.std = self.sum, np.inf
111 | self.mean_old = self.mean
112 | self.m_s = 0.0
113 | else:
114 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n)
115 | self.m_s += (value - self.mean_old) * (value - self.mean)
116 | self.mean_old = self.mean
117 | self.std = math.sqrt(self.m_s / (self.n - 1.0))
118 |
119 | def reset(self):
120 | self.n = 0
121 | self.sum = 0.0
122 | self.var = 0.0
123 | self.val = 0.0
124 | self.mean = np.nan
125 | self.mean_old = 0.0
126 | self.m_s = 0.0
127 | self.std = np.nan
128 |
129 |
130 | def simple_accuracy(output, target):
131 | """计算预测正确的准确率"""
132 | with torch.no_grad():
133 | _, preds = torch.max(output, 1)
134 |
135 | correct = preds.eq(target).float()
136 | accuracy = correct.sum() / len(target)
137 | return accuracy
--------------------------------------------------------------------------------
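A small usage sketch for the two helpers the training engine most plausibly touches, the running-average meter and top-1 accuracy:

    import torch
    from toolkit.cmetric import MultiClassificationMetric, simple_accuracy

    meter = MultiClassificationMetric()
    for batch_loss in (0.9, 0.7, 0.5):
        meter.update(batch_loss)
    print(meter.mean)  # ~0.7, the running mean of the values seen so far

    logits = torch.tensor([[0.2, 0.8], [0.9, 0.1]])
    labels = torch.tensor([1, 0])
    print(simple_accuracy(logits, labels))  # tensor(1.)
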
/toolkit/dhelper.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def get_file_name_ext(filepath):
5 | # analyze
6 | file_name, file_ext = os.path.splitext(filepath)
7 | # return
8 | return file_name, file_ext
9 |
10 |
11 | def get_file_ext(filepath):
12 | return get_file_name_ext(filepath)[1]
13 |
14 |
15 | def traverse_recursively(fileroot, filepathes=None, extension='.*'):
16 |     """Collect all file paths under a directory recursively.
17 |     Args:
18 |         filepathes: output list the paths are appended to; created fresh if None (avoids the mutable-default pitfall)
19 |         extension: e.g. '.jpg .png .bmp .webp .tif .eps'
20 |     """
21 |     if filepathes is None: filepathes = []
22 |     items = os.listdir(fileroot)
23 | for item in items:
24 | if os.path.isfile(os.path.join(fileroot, item)):
25 | filepath = os.path.join(fileroot, item)
26 | fileext = get_file_ext(filepath).lower()
27 | if extension == '.*':
28 | filepathes.append(filepath)
29 | elif fileext in extension:
30 | filepathes.append(filepath)
31 | else:
32 | pass
33 | elif os.path.isdir(os.path.join(fileroot, item)):
34 | traverse_recursively(os.path.join(fileroot, item), filepathes, extension)
35 | else:
36 | pass
37 |
--------------------------------------------------------------------------------
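Usage sketch: collect validation images by extension (note that the extension filter is a substring match against the whole extension string, as the code above shows):

    from toolkit.dhelper import traverse_recursively

    image_paths = []
    traverse_recursively('./dataset/val_dataset', image_paths, extension='.jpg .png')
    print(len(image_paths), image_paths[:3])
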
/toolkit/dtransform.py:
--------------------------------------------------------------------------------
1 | from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
2 | from timm.data.transforms import RandomResizedCropAndInterpolation
3 | from PIL import Image
4 | import torchvision.transforms as transforms
5 | import cv2
6 | import numpy as np
7 | import torchvision.transforms.functional as F
8 |
9 |
10 | # add JPEG compression augmentation (applied with probability p)
11 | class JPEGCompression:
12 | def __init__(self, quality=10, p=0.3):
13 | self.quality = quality
14 | self.p = p
15 |
16 | def __call__(self, img):
17 | if np.random.rand() < self.p:
18 | img_np = np.array(img)
19 | _, buffer = cv2.imencode('.jpg', img_np[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), self.quality])
20 | jpeg_img = cv2.imdecode(buffer, 1)
21 | return Image.fromarray(jpeg_img[:, :, ::-1])
22 | return img
23 |
24 |
25 | # training-time augmentation (ImageNet-style recipe)
26 | def transforms_imagenet_train(
27 | img_size=(224, 224),
28 | scale=(0.08, 1.0),
29 | ratio=(3./4., 4./3.),
30 | hflip=0.5,
31 | vflip=0.5,
32 | auto_augment='rand-m9-mstd0.5-inc1',
33 | interpolation='random',
34 | mean=(0.485, 0.456, 0.406),
35 |     jpeg_compression=0,
36 | ):
37 | scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
38 | ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range
39 |
40 | primary_tfl = [
41 | RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
42 | if hflip > 0.:
43 | primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
44 | if vflip > 0.:
45 | primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
46 |
47 | secondary_tfl = []
48 | if auto_augment:
49 | assert isinstance(auto_augment, str)
50 |
51 | if isinstance(img_size, (tuple, list)):
52 | img_size_min = min(img_size)
53 | else:
54 | img_size_min = img_size
55 |
56 | aa_params = dict(
57 | translate_const=int(img_size_min * 0.45),
58 | img_mean=tuple([min(255, round(255 * x)) for x in mean]),
59 | )
60 | if auto_augment.startswith('rand'):
61 | secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
62 | elif auto_augment.startswith('augmix'):
63 | aa_params['translate_pct'] = 0.3
64 | secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
65 | else:
66 | secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
67 |
68 | if jpeg_compression == 1:
69 | secondary_tfl += [JPEGCompression(quality=10, p=0.3)]
70 |
71 | final_tfl = [transforms.ToTensor()]
72 |
73 | return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
74 |
75 |
76 | # transforms used at inference (test) time
77 | def create_transforms_inference(h=256, w=256):
78 | transformer = transforms.Compose([
79 | transforms.Resize(size=(h, w)),
80 | transforms.ToTensor(),
81 | ])
82 |
83 | return transformer
84 |
85 |
86 | def create_transforms_inference1(h=256, w=256):
87 | transformer = transforms.Compose([
88 | transforms.Lambda(lambda img: F.rotate(img, angle=90)),
89 | transforms.Resize(size=(h, w)),
90 | transforms.ToTensor(),
91 | ])
92 |
93 | return transformer
94 |
95 |
96 | def create_transforms_inference2(h=256, w=256):
97 | transformer = transforms.Compose([
98 | transforms.Lambda(lambda img: F.rotate(img, angle=180)),
99 | transforms.Resize(size=(h, w)),
100 | transforms.ToTensor(),
101 | ])
102 |
103 | return transformer
104 |
105 |
106 | def create_transforms_inference3(h=256, w=256):
107 | transformer = transforms.Compose([
108 | transforms.Lambda(lambda img: F.rotate(img, angle=270)),
109 | transforms.Resize(size=(h, w)),
110 | transforms.ToTensor(),
111 | ])
112 |
113 | return transformer
114 |
115 |
116 | def create_transforms_inference4(h=256, w=256):
117 | transformer = transforms.Compose([
118 | transforms.Lambda(lambda img: F.hflip(img)),
119 | transforms.Resize(size=(h, w)),
120 | transforms.ToTensor(),
121 | ])
122 |
123 | return transformer
124 |
125 |
126 | def create_transforms_inference5(h=256, w=256):
127 | transformer = transforms.Compose([
128 | transforms.Lambda(lambda img: F.vflip(img)),
129 | transforms.Resize(size=(h, w)),
130 | transforms.ToTensor(),
131 | ])
132 |
133 | return transformer
134 |
--------------------------------------------------------------------------------
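The six inference transforms are the test-time views stacked by INFER_API: identity, three rotations, and the two flips, each followed by a resize. A sketch applying all six to one image:

    from PIL import Image
    from toolkit.dtransform import (create_transforms_inference, create_transforms_inference1,
                                    create_transforms_inference2, create_transforms_inference3,
                                    create_transforms_inference4, create_transforms_inference5)

    factories = (create_transforms_inference, create_transforms_inference1,
                 create_transforms_inference2, create_transforms_inference3,
                 create_transforms_inference4, create_transforms_inference5)
    img = Image.open('./dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg').convert('RGB')
    views = [factory(h=512, w=512)(img) for factory in factories]
    print(len(views), views[0].shape)  # 6 torch.Size([3, 512, 512])
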
/toolkit/yacs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | """YACS -- Yet Another Configuration System is designed to be a simple
17 | configuration management system for academic and industrial research
18 | projects.
19 |
20 | See README.md for usage and examples.
21 | """
22 |
23 | import copy
24 | import io
25 | import logging
26 | import os
27 | import sys
28 | from ast import literal_eval
29 |
30 | import yaml
31 |
32 | # Flag for py2 and py3 compatibility to use when separate code paths are necessary
33 | # When _PY2 is False, we assume Python 3 is in use
34 | _PY2 = sys.version_info.major == 2
35 |
36 | # Filename extensions for loading configs from files
37 | _YAML_EXTS = {"", ".yaml", ".yml"}
38 | _PY_EXTS = {".py"}
39 |
40 | # py2 and py3 compatibility for checking file object type
41 | # We simply use this to infer py2 vs py3
42 | if _PY2:
43 | _FILE_TYPES = (file, io.IOBase)
44 | else:
45 | _FILE_TYPES = (io.IOBase,)
46 |
47 | # CfgNodes can only contain a limited set of valid types
48 | _VALID_TYPES = {tuple, list, str, int, float, bool, type(None)}
49 | # py2 allow for str and unicode
50 | if _PY2:
51 | _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821
52 |
53 | # Utilities for importing modules from file paths
54 | if _PY2:
55 | # imp is available in both py2 and py3 for now, but is deprecated in py3
56 | import imp
57 | else:
58 | import importlib.util
59 |
60 | logger = logging.getLogger(__name__)
61 |
62 |
63 | class CfgNode(dict):
64 | """
65 | CfgNode represents an internal node in the configuration tree. It's a simple
66 | dict-like container that allows for attribute-based access to keys.
67 | """
68 |
69 | IMMUTABLE = "__immutable__"
70 | DEPRECATED_KEYS = "__deprecated_keys__"
71 | RENAMED_KEYS = "__renamed_keys__"
72 | NEW_ALLOWED = "__new_allowed__"
73 |
74 | def __init__(self, init_dict=None, key_list=None, new_allowed=False):
75 | """
76 | Args:
77 |             init_dict (dict): the possibly-nested dictionary to initialize the CfgNode.
78 | key_list (list[str]): a list of names which index this CfgNode from the root.
79 | Currently only used for logging purposes.
80 |             new_allowed (bool): whether adding new keys is allowed when merging with
81 | other configs.
82 | """
83 | # Recursively convert nested dictionaries in init_dict into CfgNodes
84 | init_dict = {} if init_dict is None else init_dict
85 | key_list = [] if key_list is None else key_list
86 | init_dict = self._create_config_tree_from_dict(init_dict, key_list)
87 | super(CfgNode, self).__init__(init_dict)
88 | # Manage if the CfgNode is frozen or not
89 | self.__dict__[CfgNode.IMMUTABLE] = False
90 | # Deprecated options
91 | # If an option is removed from the code and you don't want to break existing
92 | # yaml configs, you can add the full config key as a string to the set below.
93 | self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
94 | # Renamed options
95 | # If you rename a config option, record the mapping from the old name to the new
96 | # name in the dictionary below. Optionally, if the type also changed, you can
97 | # make the value a tuple that specifies first the renamed key and then
98 | # instructions for how to edit the config file.
99 | self.__dict__[CfgNode.RENAMED_KEYS] = {
100 | # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow
101 | # 'EXAMPLE.OLD.KEY': ( # A more complex example to follow
102 | # 'EXAMPLE.NEW.KEY',
103 | # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or "
104 | # + "'foo:bar' -> ('foo', 'bar')"
105 | # ),
106 | }
107 |
108 | # Allow new attributes after initialisation
109 | self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed
110 |
111 | @classmethod
112 | def _create_config_tree_from_dict(cls, dic, key_list):
113 | """
114 | Create a configuration tree using the given dict.
115 | Any dict-like objects inside dict will be treated as a new CfgNode.
116 |
117 | Args:
118 | dic (dict):
119 | key_list (list[str]): a list of names which index this CfgNode from the root.
120 | Currently only used for logging purposes.
121 | """
122 | dic = copy.deepcopy(dic)
123 | for k, v in dic.items():
124 | if isinstance(v, dict):
125 | # Convert dict to CfgNode
126 | dic[k] = cls(v, key_list=key_list + [k])
127 | else:
128 | # Check for valid leaf type or nested CfgNode
129 | _assert_with_logging(
130 | _valid_type(v, allow_cfg_node=False),
131 | "Key {} with value {} is not a valid type; valid types: {}".format(
132 | ".".join(key_list + [str(k)]), type(v), _VALID_TYPES
133 | ),
134 | )
135 | return dic
136 |
137 | def __getattr__(self, name):
138 | if name in self:
139 | return self[name]
140 | else:
141 | raise AttributeError(name)
142 |
143 | def __setattr__(self, name, value):
144 | if self.is_frozen():
145 | raise AttributeError(
146 | "Attempted to set {} to {}, but CfgNode is immutable".format(
147 | name, value
148 | )
149 | )
150 |
151 | _assert_with_logging(
152 | name not in self.__dict__,
153 | "Invalid attempt to modify internal CfgNode state: {}".format(name),
154 | )
155 | _assert_with_logging(
156 | _valid_type(value, allow_cfg_node=True),
157 | "Invalid type {} for key {}; valid types = {}".format(
158 | type(value), name, _VALID_TYPES
159 | ),
160 | )
161 |
162 | self[name] = value
163 |
164 | def __str__(self):
165 | def _indent(s_, num_spaces):
166 | s = s_.split("\n")
167 | if len(s) == 1:
168 | return s_
169 | first = s.pop(0)
170 | s = [(num_spaces * " ") + line for line in s]
171 | s = "\n".join(s)
172 | s = first + "\n" + s
173 | return s
174 |
175 | r = ""
176 | s = []
177 | for k, v in sorted(self.items()):
178 |             separator = "\n" if isinstance(v, CfgNode) else " "
179 |             attr_str = "{}:{}{}".format(str(k), separator, str(v))
180 | attr_str = _indent(attr_str, 2)
181 | s.append(attr_str)
182 | r += "\n".join(s)
183 | return r
184 |
185 | def __repr__(self):
186 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
187 |
188 | def dump(self, **kwargs):
189 | """Dump to a string."""
190 |
191 | def convert_to_dict(cfg_node, key_list):
192 | if not isinstance(cfg_node, CfgNode):
193 | _assert_with_logging(
194 | _valid_type(cfg_node),
195 | "Key {} with value {} is not a valid type; valid types: {}".format(
196 | ".".join(key_list), type(cfg_node), _VALID_TYPES
197 | ),
198 | )
199 | return cfg_node
200 | else:
201 | cfg_dict = dict(cfg_node)
202 | for k, v in cfg_dict.items():
203 | cfg_dict[k] = convert_to_dict(v, key_list + [k])
204 | return cfg_dict
205 |
206 | self_as_dict = convert_to_dict(self, [])
207 | return yaml.safe_dump(self_as_dict, **kwargs)
208 |
209 | def merge_from_file(self, cfg_filename):
210 |         """Load a yaml config file and merge it into this CfgNode."""
211 | with open(cfg_filename, "r") as f:
212 | cfg = self.load_cfg(f)
213 | self.merge_from_other_cfg(cfg)
214 |
215 | def merge_from_other_cfg(self, cfg_other):
216 | """Merge `cfg_other` into this CfgNode."""
217 | _merge_a_into_b(cfg_other, self, self, [])
218 |
219 | def merge_from_list(self, cfg_list):
220 | """Merge config (keys, values) in a list (e.g., from command line) into
221 | this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
222 | """
223 | _assert_with_logging(
224 | len(cfg_list) % 2 == 0,
225 | "Override list has odd length: {}; it must be a list of pairs".format(
226 | cfg_list
227 | ),
228 | )
229 | root = self
230 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
231 | if root.key_is_deprecated(full_key):
232 | continue
233 | if root.key_is_renamed(full_key):
234 | root.raise_key_rename_error(full_key)
235 | key_list = full_key.split(".")
236 | d = self
237 | for subkey in key_list[:-1]:
238 | _assert_with_logging(
239 | subkey in d, "Non-existent key: {}".format(full_key)
240 | )
241 | d = d[subkey]
242 | subkey = key_list[-1]
243 | _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
244 | value = self._decode_cfg_value(v)
245 | value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
246 | d[subkey] = value
247 |
248 | def freeze(self):
249 | """Make this CfgNode and all of its children immutable."""
250 | self._immutable(True)
251 |
252 | def defrost(self):
253 | """Make this CfgNode and all of its children mutable."""
254 | self._immutable(False)
255 |
256 | def is_frozen(self):
257 | """Return mutability."""
258 | return self.__dict__[CfgNode.IMMUTABLE]
259 |
260 | def _immutable(self, is_immutable):
261 | """Set immutability to is_immutable and recursively apply the setting
262 | to all nested CfgNodes.
263 | """
264 | self.__dict__[CfgNode.IMMUTABLE] = is_immutable
265 | # Recursively set immutable state
266 | for v in self.__dict__.values():
267 | if isinstance(v, CfgNode):
268 | v._immutable(is_immutable)
269 | for v in self.values():
270 | if isinstance(v, CfgNode):
271 | v._immutable(is_immutable)
272 |
273 | def clone(self):
274 | """Recursively copy this CfgNode."""
275 | return copy.deepcopy(self)
276 |
277 | def register_deprecated_key(self, key):
278 |         """Register key (e.g. `FOO.BAR`) as a deprecated option. When merging
279 |         deprecated keys, a warning is generated and the key is ignored.
280 | """
281 | _assert_with_logging(
282 | key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
283 | "key {} is already registered as a deprecated key".format(key),
284 | )
285 | self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)
286 |
287 | def register_renamed_key(self, old_name, new_name, message=None):
288 | """Register a key as having been renamed from `old_name` to `new_name`.
289 |         When merging a renamed key, an exception is thrown alerting the user to
290 | the fact that the key has been renamed.
291 | """
292 | _assert_with_logging(
293 | old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
294 | "key {} is already registered as a renamed cfg key".format(old_name),
295 | )
296 | value = new_name
297 | if message:
298 | value = (new_name, message)
299 | self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value
300 |
301 | def key_is_deprecated(self, full_key):
302 | """Test if a key is deprecated."""
303 | if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
304 | logger.warning("Deprecated config key (ignoring): {}".format(full_key))
305 | return True
306 | return False
307 |
308 | def key_is_renamed(self, full_key):
309 | """Test if a key is renamed."""
310 | return full_key in self.__dict__[CfgNode.RENAMED_KEYS]
311 |
312 | def raise_key_rename_error(self, full_key):
313 | new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
314 | if isinstance(new_key, tuple):
315 | msg = " Note: " + new_key[1]
316 | new_key = new_key[0]
317 | else:
318 | msg = ""
319 | raise KeyError(
320 | "Key {} was renamed to {}; please update your config.{}".format(
321 | full_key, new_key, msg
322 | )
323 | )
324 |
325 | def is_new_allowed(self):
326 | return self.__dict__[CfgNode.NEW_ALLOWED]
327 |
328 | def set_new_allowed(self, is_new_allowed):
329 | """
330 | Set this config (and recursively its subconfigs) to allow merging
331 | new keys from other configs.
332 | """
333 | self.__dict__[CfgNode.NEW_ALLOWED] = is_new_allowed
334 | # Recursively set new_allowed state
335 | for v in self.__dict__.values():
336 | if isinstance(v, CfgNode):
337 | v.set_new_allowed(is_new_allowed)
338 | for v in self.values():
339 | if isinstance(v, CfgNode):
340 | v.set_new_allowed(is_new_allowed)
341 |
342 | @classmethod
343 | def load_cfg(cls, cfg_file_obj_or_str):
344 | """
345 | Load a cfg.
346 | Args:
347 | cfg_file_obj_or_str (str or file):
348 | Supports loading from:
349 | - A file object backed by a YAML file
350 | - A file object backed by a Python source file that exports an attribute
351 | "cfg" that is either a dict or a CfgNode
352 | - A string that can be parsed as valid YAML
353 | """
354 | _assert_with_logging(
355 | isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
356 | "Expected first argument to be of type {} or {}, but it was {}".format(
357 | _FILE_TYPES, str, type(cfg_file_obj_or_str)
358 | ),
359 | )
360 | if isinstance(cfg_file_obj_or_str, str):
361 | return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
362 | elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
363 | return cls._load_cfg_from_file(cfg_file_obj_or_str)
364 | else:
365 | raise NotImplementedError("Impossible to reach here (unless there's a bug)")
366 |
367 | @classmethod
368 | def _load_cfg_from_file(cls, file_obj):
369 | """Load a config from a YAML file or a Python source file."""
370 | _, file_extension = os.path.splitext(file_obj.name)
371 | if file_extension in _YAML_EXTS:
372 | return cls._load_cfg_from_yaml_str(file_obj.read())
373 | elif file_extension in _PY_EXTS:
374 | return cls._load_cfg_py_source(file_obj.name)
375 | else:
376 | raise Exception(
377 | "Attempt to load from an unsupported file type {}; "
378 | "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
379 | )
380 |
381 | @classmethod
382 | def _load_cfg_from_yaml_str(cls, str_obj):
383 | """Load a config from a YAML string encoding."""
384 | cfg_as_dict = yaml.safe_load(str_obj)
385 | return cls(cfg_as_dict)
386 |
387 | @classmethod
388 | def _load_cfg_py_source(cls, filename):
389 | """Load a config from a Python source file."""
390 | module = _load_module_from_file("yacs.config.override", filename)
391 | _assert_with_logging(
392 | hasattr(module, "cfg"),
393 | "Python module from file {} must have 'cfg' attr".format(filename),
394 | )
395 | VALID_ATTR_TYPES = {dict, CfgNode}
396 | _assert_with_logging(
397 | type(module.cfg) in VALID_ATTR_TYPES,
398 | "Imported module 'cfg' attr must be in {} but is {} instead".format(
399 | VALID_ATTR_TYPES, type(module.cfg)
400 | ),
401 | )
402 | return cls(module.cfg)
403 |
404 | @classmethod
405 | def _decode_cfg_value(cls, value):
406 | """
407 | Decodes a raw config value (e.g., from a yaml config files or command
408 | line argument) into a Python object.
409 |
410 | If the value is a dict, it will be interpreted as a new CfgNode.
411 |         If the value is a str, it will be evaluated as a Python literal.
412 | Otherwise it is returned as-is.
413 | """
414 | # Configs parsed from raw yaml will contain dictionary keys that need to be
415 | # converted to CfgNode objects
416 | if isinstance(value, dict):
417 | return cls(value)
418 | # All remaining processing is only applied to strings
419 | if not isinstance(value, str):
420 | return value
421 | # Try to interpret `value` as a:
422 | # string, number, tuple, list, dict, boolean, or None
423 | try:
424 | value = literal_eval(value)
425 | # The following two excepts allow v to pass through when it represents a
426 | # string.
427 | #
428 | # Longer explanation:
429 | # The type of v is always a string (before calling literal_eval), but
430 | # sometimes it *represents* a string and other times a data structure, like
431 | # a list. In the case that v represents a string, what we got back from the
432 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
433 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
434 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
435 | # will raise a SyntaxError.
436 | except ValueError:
437 | pass
438 | except SyntaxError:
439 | pass
440 | return value
441 |
442 |
443 | load_cfg = (
444 | CfgNode.load_cfg
445 | ) # keep this function in global scope for backward compatibility
446 |
447 |
448 | def _valid_type(value, allow_cfg_node=False):
449 | return (type(value) in _VALID_TYPES) or (
450 | allow_cfg_node and isinstance(value, CfgNode)
451 | )
452 |
453 |
454 | def _merge_a_into_b(a, b, root, key_list):
455 | """Merge config dictionary a into config dictionary b, clobbering the
456 | options in b whenever they are also specified in a.
457 | """
458 | _assert_with_logging(
459 | isinstance(a, CfgNode),
460 | "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
461 | )
462 | _assert_with_logging(
463 | isinstance(b, CfgNode),
464 | "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
465 | )
466 |
467 | for k, v_ in a.items():
468 | full_key = ".".join(key_list + [k])
469 |
470 | v = copy.deepcopy(v_)
471 | v = b._decode_cfg_value(v)
472 |
473 | if k in b:
474 | v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
475 | # Recursively merge dicts
476 | if isinstance(v, CfgNode):
477 | try:
478 | _merge_a_into_b(v, b[k], root, key_list + [k])
479 | except BaseException:
480 | raise
481 | else:
482 | b[k] = v
483 | elif b.is_new_allowed():
484 | b[k] = v
485 | else:
486 | if root.key_is_deprecated(full_key):
487 | continue
488 | elif root.key_is_renamed(full_key):
489 | root.raise_key_rename_error(full_key)
490 | else:
491 | raise KeyError("Non-existent config key: {}".format(full_key))
492 |
493 |
494 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
495 |     """Checks that `replacement`, which is intended to replace `original`, is of
496 | the right type. The type is correct if it matches exactly or is one of a few
497 | cases in which the type can be easily coerced.
498 | """
499 | original_type = type(original)
500 | replacement_type = type(replacement)
501 |
502 | # The types must match (with some exceptions)
503 | if replacement_type == original_type:
504 | return replacement
505 |
506 | # If either of them is None, allow type conversion to one of the valid types
507 | if (replacement_type == type(None) and original_type in _VALID_TYPES) or (
508 | original_type == type(None) and replacement_type in _VALID_TYPES
509 | ):
510 | return replacement
511 |
512 | # Cast replacement from from_type to to_type if the replacement and original
513 | # types match from_type and to_type
514 | def conditional_cast(from_type, to_type):
515 | if replacement_type == from_type and original_type == to_type:
516 | return True, to_type(replacement)
517 | else:
518 | return False, None
519 |
520 |     # Conditional casts
521 | # list <-> tuple
522 | casts = [(tuple, list), (list, tuple)]
523 | # For py2: allow converting from str (bytes) to a unicode string
524 | try:
525 | casts.append((str, unicode)) # noqa: F821
526 | except Exception:
527 | pass
528 |
529 | for (from_type, to_type) in casts:
530 | converted, converted_value = conditional_cast(from_type, to_type)
531 | if converted:
532 | return converted_value
533 |
534 | raise ValueError(
535 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
536 | "key: {}".format(
537 | original_type, replacement_type, original, replacement, full_key
538 | )
539 | )
540 |
541 |
542 | def _assert_with_logging(cond, msg):
543 | if not cond:
544 | logger.debug(msg)
545 | assert cond, msg
546 |
547 |
548 | def _load_module_from_file(name, filename):
549 | if _PY2:
550 | module = imp.load_source(name, filename)
551 | else:
552 | spec = importlib.util.spec_from_file_location(name, filename)
553 | module = importlib.util.module_from_spec(spec)
554 | spec.loader.exec_module(module)
555 | return module
556 |
--------------------------------------------------------------------------------
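The vendored yacs module above is typically driven through CfgNode: build a tree of defaults, override it from a YAML file or a key/value list, then freeze it. A minimal usage sketch follows; the config keys and file name are illustrative, not taken from this repo.

from toolkit.yacs import CfgNode as CN

# Defaults: nested CfgNodes with attribute-style access.
_C = CN()
_C.MODEL = CN()
_C.MODEL.NAME = "convnext"
_C.TRAIN = CN()
_C.TRAIN.LR = 1e-3

cfg = _C.clone()                          # deep copy; leaves the defaults intact
# cfg.merge_from_file("experiment.yaml")  # override from a YAML file (illustrative path)
cfg.merge_from_list(["TRAIN.LR", 5e-4])   # override from e.g. command-line pairs
cfg.freeze()                              # now `cfg.TRAIN.LR = 1.0` raises AttributeError
print(cfg.MODEL.NAME, cfg.TRAIN.LR)       # -> convnext 0.0005

Note that merge_from_list type-checks each override against the existing value via _check_and_coerce_cfg_value_type, so passing a string where a float is expected raises a ValueError unless the string parses to a compatible literal (e.g. "5e-4" is fine, "abc" is not).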