├── ckpts
│   └── README.md
├── SDxl
│   └── README.md
├── outputs
│   └── README.md
├── data
│   ├── GMDD
│   │   └── phase1
│   │       ├── valset
│   │       │   └── README.md
│   │       ├── trainset
│   │       │   └── README.md
│   │       ├── trainset_label.txt
│   │       └── valset_label.txt
│   └── GMDD_test
│       └── phase2
│           ├── testset1_seen
│           │   └── README.md
│           └── testset1_seen_nolabel.txt
├── figures
│   ├── path_to_save_combined_image.jpg
│   └── Read_images_and_show.py
├── network
│   ├── SRM_Net.py
│   ├── SimpleModel.py
│   ├── IPD_Net.py
│   ├── util
│   │   ├── CBAM.py
│   │   ├── SRM.py
│   │   ├── NL.py
│   │   └── resnet.py
│   └── trainer.py
├── options
│   ├── test_config.yaml
│   └── train_config.yaml
├── LICENSE
├── gen_SDxl.py
├── API.py
├── README.md
├── test.py
├── train.py
└── dataloder
    └── CustomDataset.py

/ckpts/README.md:
--------------------------------------------------------------------------------
The training weights are stored here.
--------------------------------------------------------------------------------
/SDxl/README.md:
--------------------------------------------------------------------------------
The output of running gen_SDxl.py will be saved here.
--------------------------------------------------------------------------------
/outputs/README.md:
--------------------------------------------------------------------------------
The output of running test.py will be saved here.
--------------------------------------------------------------------------------
/data/GMDD/phase1/valset/README.md:
--------------------------------------------------------------------------------
Download the dataset and place it according to this path.
--------------------------------------------------------------------------------
/data/GMDD/phase1/trainset/README.md:
--------------------------------------------------------------------------------
Download the dataset and place it according to this path.
--------------------------------------------------------------------------------
/data/GMDD_test/phase2/testset1_seen/README.md:
--------------------------------------------------------------------------------
Download the dataset and place it according to this path.
--------------------------------------------------------------------------------
/figures/path_to_save_combined_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbljdgb/Kaggle-2024-Deepfake-Image-Detection/HEAD/figures/path_to_save_combined_image.jpg
--------------------------------------------------------------------------------
/data/GMDD/phase1/trainset_label.txt:
--------------------------------------------------------------------------------
img_name,target
3381ccbc4df9e7778b720d53a2987014.jpg,1
63fee8a89581307c0b4fd05a48e0ff79.jpg,0
7eb4553a58ab5a05ba59b40725c903fd.jpg,0
920085930764461878d67b71703778e8.jpg,1
f6320687a93ccb0c5fa892dc3361b804.jpg,1
74970d23dab29994ce4513f1c6faaaa5.jpg,1
--------------------------------------------------------------------------------
/data/GMDD/phase1/valset_label.txt:
--------------------------------------------------------------------------------
img_name,target
cd0e3907b3312f6046b98187fc25f9c7.jpg,1
aa92be19d0adf91a641301cfcce71e8a.jpg,0
5413a0b706d33ed0208e2e4e2cacaa06.jpg,0
c90f2cfd5b5fd759febcdfa8ccade77b.jpg,1
b9c3a3900c92767e2e9035765f5acb06.jpg,1
e861870d8acddafcc07e529ee459a452.jpg,1
b18fa89b2a8ebf0de89a1e0886d15e14.jpg,0
--------------------------------------------------------------------------------
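
The label files above are plain CSV (`img_name,target`, one row per image) despite the `.txt` extension; judging by the diffusion-swap logic in `dataloder/CustomDataset.py`, `target` 1 appears to mark a generated (fake) image and 0 a real one. A minimal reading sketch (assumes pandas, already a project dependency, and running from the repository root):

```python
import pandas as pd

# trainset_label.txt is CSV despite the .txt extension
labels = pd.read_csv("data/GMDD/phase1/trainset_label.txt")
print(labels.head())
print(labels["target"].value_counts())  # class balance, as CustomDataset prints it
```
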
/data/GMDD_test/phase2/testset1_seen_nolabel.txt:
--------------------------------------------------------------------------------
img_name
c5decfb888b08593e980537191776a84.jpg
46f7e69cdc47b9ed7c3efd935c88993c.jpg
4a5c44d7fa626fd5ef556166f07c669a.jpg
63a47345219a2f0b8c0215e9f405c745.jpg
36d40bca5f09e2c5d59ea826f7fe5f5a.jpg
c90c524f4b44b4fe4e79bcfd1c5260a2.jpg
d9e05c0424a708d3ddc4c2359a91c959.jpg
2eafc03ca65be4ab33df0271691e1c18.jpg
e40c3e3537a13e4bf005d46f254dd415.jpg
--------------------------------------------------------------------------------
/network/SRM_Net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from network.util.SRM import SRMConv2d_simple
import timm


class SRM_Net(nn.Module):
    """A timm backbone preceded by a fixed SRM high-pass filter."""

    def __init__(self, name, pretrained):
        super().__init__()

        # Fixed (non-learnable) SRM noise-residual filter on the RGB input
        self.SRM_k = SRMConv2d_simple(inc=3, learnable=False)

        self.model = timm.create_model(name, pretrained=pretrained, num_classes=2)

    def forward(self, x):

        # Following PatchCraft: extract the SRM noise residual first,
        # then classify the residual with the backbone
        x = self.SRM_k(x)

        x = self.model(x)

        return x
--------------------------------------------------------------------------------
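
A quick shape check for `SRM_Net` (a minimal sketch; `pretrained=False` avoids the weight download, and the backbone name matches one used in `network/trainer.py`):

```python
import torch
from network.SRM_Net import SRM_Net

# Build without pretrained weights for a pure shape test
net = SRM_Net(name="efficientnet_b4.ra2_in1k", pretrained=False)
out = net(torch.randn(2, 3, 320, 320))
print(out.shape)  # expected: torch.Size([2, 2])
```
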
/figures/Read_images_and_show.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
from tqdm import tqdm

# This folder contains nothing but images.
# The second assignment overrides the first; comment one out to switch sets.
folder_path = "data/GMDD/phase1/trainset/"
folder_path = "data/GMDD/phase1/valset/"
# Where to save the example image
output_img = "figures/output.png"

# Collects the distinct image sizes seen so far
image_sizes = set()

# List the images in the folder
images = os.listdir(folder_path)

# Scan every image and record its size
for idx in tqdm(range(len(images))):
    image_path = os.path.join(folder_path, images[idx])
    img = Image.open(image_path)
    image_sizes.add(img.size)
    if idx % 50000 == 0:
        print(image_sizes)
        # img.save(output_img)

# Report how many distinct image sizes were found
print(f"Found {len(image_sizes)} distinct image sizes")

# Print all the distinct sizes if needed
print("The distinct image sizes are:")
for size in image_sizes:
    print(size)
--------------------------------------------------------------------------------
/options/test_config.yaml:
--------------------------------------------------------------------------------
save_output_path: 'outputs' # folder for test results; save_output_path+name is the output folder
name: 'IPD_Net' # subfolder for test results; save_output_path+name is the output folder
network: 'IPD_Net' # network architecture; the available choices are listed in network/trainer.py
manualSeed: 0 # random seed
test_batchSize: 32 # test batch size
workers: 8 # number of dataloader workers
gpu: 2 # main GPU index
resize_or_crop: 'resize' # whether to "resize" or "crop" images before feeding the model
input_shape: 512 # image size fed to the model after the "resize"/"crop"
pretrained_path: "ckpts/IPD_Net/ckpt_best.pth" # path to your own trained weights; empty means none
select_test: -1 # if not -1, run on only that many samples as a smoke test (e.g. 100)

diFF_prob: 0 # probability of swapping in diffusion-generated images; always 0 at test time
flag: True # marks that this is a test run; always True here
placeholder: "ckpts/vit_base_patch16_224_dino/ckpt_best_0.9102912723.pth" # placeholder weights used when building the 512-px ViT
 
data:
  test_dir_path: "data/GMDD_test/phase2/testset1_seen" # folder containing the test images
  test_label: "data/GMDD_test/phase2/testset1_seen_nolabel.txt" # list of the test image names
--------------------------------------------------------------------------------
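
Both config files are consumed with `yaml.safe_load` (as in `test.py`/`train.py`); nested keys are plain dict lookups. A minimal sketch:

```python
import yaml

with open("options/test_config.yaml", "r") as f:
    config = yaml.safe_load(f)

print(config["network"], config["input_shape"])   # 'IPD_Net' 512
print(config["data"]["test_dir_path"])            # the test image folder
```
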
/network/SimpleModel.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet


class SimpleModel(nn.Module):

    def __init__(self, config):
        super(SimpleModel, self).__init__()

        self.config = config
        self.model = EfficientNet.from_pretrained('efficientnet-b2')
        # Replace the classifier head (efficientnet-b2 has 1408 features)
        self.model._fc = nn.Linear(in_features=1408, out_features=self.config['backbone_config']['num_classes'])
        # Initialize the new head's parameters
        torch.nn.init.normal_(self.model._fc.weight.data, 0.0, 0.02)

    def forward(self, x):
        x = self.model(x)

        return x


if __name__ == "__main__":
    import yaml
    # Load the config file. NOTE: this legacy smoke test expects a config that
    # provides config['backbone_config']['num_classes']; no such file ships
    # with this repo's options/ folder.
    with open('training/config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    efficientnet = SimpleModel(config)
    output = efficientnet(torch.rand((4, 3, 256, 256)))  # forward pass
    print(output)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 gbljdgb

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/gen_SDxl.py:
--------------------------------------------------------------------------------
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
import torch
import os
import uuid  # used to generate random unique IDs
import random
from tqdm import tqdm

########################################
gpu = 'cuda:3'
gen_num = 10000                          # how many images to generate
output_dir = "SDxl"                      # output folder (matches diff_path in train_config.yaml)
input_dir = "data/GMDD/phase1/trainset"  # source images for img2img
########################################

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True)
pipe = pipe.to(gpu)


image_files = [f for f in os.listdir(input_dir)]

for i in tqdm(range(gen_num)):
    selected_image = random.choice(image_files)
    selected_image_path = os.path.join(input_dir, selected_image)

    prompt = "a human face"

    init_image = load_image(selected_image_path).convert("RGB")

    image = pipe(prompt, image=init_image).images[0]

    # Use a random UUID as the file name
    random_filename = str(uuid.uuid4()) + '.png'
    image.save(os.path.join(output_dir, random_filename))
--------------------------------------------------------------------------------
/network/IPD_Net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
try:
    from network.util.resnet import resnet50
    from network.util.NL import NLBlockND
    from network.util.SRM import SRMConv2d_simple
except ImportError:
    from util.resnet import resnet50
    from util.NL import NLBlockND
    from util.SRM import SRMConv2d_simple


class IPD_Net(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.res = resnet50()

        self.PCL = NLBlockND(in_channels=2048, dimension=2, mode="dot")

        self.SRM_k = SRMConv2d_simple(inc=3, learnable=False)

        self.a_pool = nn.AdaptiveAvgPool2d((1, 1))

        # With 512x512 inputs, layer4 yields 16x16 feature maps, so the
        # pairwise-similarity map reshapes to 256 channels (see network/util/NL.py)
        self.fc1 = nn.Linear(256, 2)

    def forward(self, x):

        # Following PatchCraft: apply the SRM high-pass filter first
        x = self.SRM_k(x)

        # Stem: conv + BN + ReLU + max-pool
        x = self.res.conv1(x)
        x = self.res.bn1(x)
        x = self.res.relu(x)
        x = self.res.maxpool(x)

        # ResNet-50 as the backbone
        x = self.res.layer1(x)
        x = self.res.layer2(x)
        x = self.res.layer3(x)
        x = self.res.layer4(x)

        # Compute the pairwise-consistency (PCL) similarity map M
        x = self.PCL(x)

        # Adaptive pooling, then the classification head
        x = self.a_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)

        return x


if __name__ == "__main__":
    # Build a random (N, C, H, W) tensor with N=2, C=3, H=W=512
    N, C, H, W = 2, 3, 512, 512
    input_tensor = torch.randn(N, C, H, W)

    model = IPD_Net(None)

    # Run the input through the network
    output = model(input_tensor)

    # Print the output shape
    print("Output shape:", output.shape)
--------------------------------------------------------------------------------
/API.py:
--------------------------------------------------------------------------------
import yaml
import torch
from flask import Flask, request, jsonify
import numpy as np
import random
import torch.nn.functional as F
from network.trainer import trainer
from PIL import Image
import torchvision.transforms as transforms

app = Flask(__name__)


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


# Load the config file
with open('options/test_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

setup_seed(config['manualSeed'])

# Build the model
TRAINER = trainer(config, TEST=True)
TRAINER.set_mode(mode='eval')

transform = transforms.Compose([transforms.Resize((config['input_shape'], config['input_shape'])),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])


@app.route('/predict', methods=['POST'])
def predict():
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    image_file = request.files['image']
    # Convert the uploaded image into the input format the model expects
    data_dict = preprocess_image(image_file)

    with torch.no_grad():
        TRAINER.set_input(data_dict)
        TRAINER.forward()
        TRAINER.output = F.softmax(TRAINER.output, dim=1)

    pred_prob = TRAINER.output[0][1].cpu().item()  # probability of the positive class for the single image
    return jsonify({'probability': pred_prob})


def preprocess_image(image_file):
    # Preprocess the image and return a data_dict in the expected format
    image = Image.open(image_file.stream).convert('RGB')
    sample = transform(image)
    data_dict = {}
    data_dict['image'] = sample.unsqueeze(0)
    # The label is unused at inference time; the image tensor is reused as a
    # placeholder so that trainer.set_input() finds the key it expects.
    data_dict['label'] = sample.unsqueeze(0)
    return data_dict


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=10086)
--------------------------------------------------------------------------------
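
Besides the `curl` call shown in the README, the `/predict` endpoint can be exercised from Python. A minimal client sketch (assumes the `requests` package, a running `python API.py`, and a hypothetical local image `sample.jpg`):

```python
import requests

# POST one image to the running API (see API.py, port 10086)
with open("sample.jpg", "rb") as f:  # hypothetical test image
    resp = requests.post("http://localhost:10086/predict", files={"image": f})

# The server returns {"probability": p}, the softmax score of the positive class
print(resp.json())
```
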
/options/train_config.yaml:
--------------------------------------------------------------------------------
nEpochs: 100 # maximum number of training epochs
save_ckpt_path: 'ckpts' # folder for checkpoints; save_ckpt_path+name is the output folder
name: 'IPD_Net' # subfolder for checkpoints; save_ckpt_path+name is the output folder
network: 'IPD_Net' # network architecture; the available choices are listed in network/trainer.py
manualSeed: 0 # random seed
printFreq: 10 # how many times to print per epoch
train_batchSize: 32 # training batch size
val_batchSize: 32 # validation batch size
workers: 8 # number of dataloader workers
gpu: 2 # main GPU index
select_test: -1 # if not -1, run on only that many samples as a smoke test (e.g. 100)
resize_or_crop: 'resize' # whether to "resize" or "crop" images before feeding the model
input_shape: 512 # image size fed to the model after the "resize"/"crop"
pretrained_path: "" # path to a pretrained model; empty means none, non-empty means continue fine-tuning

# Image augmentations applied during training
flip_prob: 0.5 # horizontal-flip probability
rotate_prob: 0.1 # rotation probability
rotate_limit: [-10, 10] # rotation range in degrees
blur_prob: 0.1 # Gaussian-blur probability
blur_sig: [0.0,1.0] # Gaussian-blur sigma range
brightness_prob: 0.1 # brightness/contrast-jitter probability
brightness_limit: [0, 0.1] # brightness range
contrast_limit: [0, 0.1] # contrast range
jpeg_prob: 0.1 # JPEG-compression probability
jpeg_method: ['cv2','pil'] # JPEG-compression backends
jpeg_qual: [80, 100] # JPEG quality range
diFF_prob: 0 # probability of swapping a fake image for a diffusion-generated one
diff_path: "SDxl" # folder holding the diffusion-generated images

flag: false # marks that this is a training run; always False here

data:
  train_dir_path: "data/GMDD/phase1/trainset"
  val_dir_path: "data/GMDD/phase1/valset"
  train_label: "data/GMDD/phase1/trainset_label.txt"
  val_label: "data/GMDD/phase1/valset_label.txt"

scheduler: # learning-rate schedule
  type: CosineAnnealingLR
  CosineAnnealingLR: # cosine annealing
    T_max: 100 # number of iterations until the learning rate reaches its minimum
    eta_min: 0.000001 # the minimum learning rate
    last_epoch: -1 # the epoch the schedule resumes from
    verbose: True # whether to log learning-rate changes

optimizer:
  type: sgd
  adam:
    lr: 0.0001 # learning rate
    beta1: 0.9 # beta1 for Adam optimizer
    beta2: 0.999 # beta2 for Adam optimizer
    eps: 0.00000001 # epsilon for Adam optimizer
    weight_decay: 0.0005 # weight decay for regularization
    amsgrad: false
  sgd:
    lr: 0.0001 # learning rate
    momentum: 0.9 # momentum for SGD optimizer
    weight_decay: 0.0005 # weight decay for regularization
--------------------------------------------------------------------------------
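
The `optimizer` block is dispatched on its `type` key, and only the matching sub-dict is read (mirroring the logic in `network/trainer.py`). A minimal sketch with a throwaway model:

```python
import yaml
import torch.nn as nn
import torch.optim as optim

with open("options/train_config.yaml", "r") as f:
    config = yaml.safe_load(f)

model = nn.Linear(4, 2)  # throwaway model, just to supply parameters

opt_type = config["optimizer"]["type"]   # 'sgd' in the shipped config
opt_args = config["optimizer"][opt_type]
if opt_type == "sgd":
    optimizer = optim.SGD(model.parameters(), lr=opt_args["lr"],
                          momentum=opt_args["momentum"],
                          weight_decay=opt_args["weight_decay"])
print(optimizer)
```
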
/network/util/CBAM.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class ChannelAttentionModule(nn.Module):

    def __init__(self, channel, reduction=16):
        super(ChannelAttentionModule, self).__init__()
        mid_channel = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Linear(in_features=channel, out_features=mid_channel),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=mid_channel, out_features=channel))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x).view(
            x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        maxout = self.shared_MLP(self.max_pool(x).view(
            x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        return self.sigmoid(avgout + maxout)


class SpatialAttentionModule(nn.Module):

    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2,
                                out_channels=1,
                                kernel_size=7,
                                stride=1,
                                padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out


class CBAM(nn.Module):

    def __init__(self, channel):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        # x: (N,C,H,W), self.channel_attention(x) -> (N,C,1,1)
        out = self.channel_attention(x) * x
        # out: (N,C,H,W), self.spatial_attention(out) -> (N,1,H,W)
        out = self.spatial_attention(out) * out
        return out
--------------------------------------------------------------------------------
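
`CBAM.py` ships no self-test of its own; a minimal shape check (the channel count here is arbitrary):

```python
import torch
from network.util.CBAM import CBAM

block = CBAM(channel=64)
x = torch.randn(2, 64, 32, 32)
out = block(x)        # attention-reweighted features
print(out.shape)      # expected: torch.Size([2, 64, 32, 32])
```
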
/README.md:
--------------------------------------------------------------------------------
# [Kaggle-2024-competitions] Inclusion: The Global Multimedia Deepfake Detection

Welcome to the open-source repository for the **Inclusion: The Global Multimedia Deepfake Detection** competition on Kaggle. This project aims to advance the detection of deepfake media across various formats and platforms.

![Sample images from the datasets](/figures/path_to_save_combined_image.jpg "faces")

## Competition Overview

The competition consists of two phases:

### Phase 1: [Track 1: Deepfake Image Detection](https://www.kaggle.com/competitions/multi-ffdi)
- **Objective**: [In the first phase, only the training and validation sets are released, and the leaderboard is ranked on the validation set]
- **Duration**: [Jun 30 - Aug 22]
- **Evaluation Metric**: [AUC]
- **Ranking**: [38/706]
- **AUC Score**: [0.9982558829]

### Phase 2: [Track 1: submitting test results](https://www.kaggle.com/competitions/multi-ffdi-phase2)
- **Objective**: [In the second phase, the test set is released, and the leaderboard is ranked on the test set]
- **Duration**: [Aug 15 - Aug 22]
- **Evaluation Metric**: [AUC]
- **Ranking**: [46/184]
- **AUC Score**: [0.9551696556]

## Getting Started

### Prerequisites
- Python 3.x
- [PyTorch, numpy and other common deep learning libraries]
- diffusers[torch]

### How to Run the Code
1. **Prepare your environment**:

   Ensure that you have Python 3.x installed and that you have set up a virtual environment (optional but recommended).

2. **Clone the repository**:
   ```bash
   git clone https://github.com/gbljdgb/Kaggle-2024-Deepfake-Image-Detection.git
   cd Kaggle-2024-Deepfake-Image-Detection
   ```

3. **Prepare a supplementary training set generated by a diffusion model (optional)**:

   The images generated by the diffusion model are saved in the ./SDxl folder.

   ```bash
   python gen_SDxl.py
   ```

4. **Train and test**:

   The weights are stored in the ./ckpts folder and the outputs in the ./outputs folder.

   ```bash
   python train.py
   python test.py
   ```

### How to Use the API after training
1. **Modify the configuration file**

   Before using the API, please make sure you modify the configuration file `test_config.yaml`. You can find the file at the following link:

   [Changing test_config.yaml](https://github.com/gbljdgb/Kaggle-2024-Deepfake-Image-Detection/blob/main/options/test_config.yaml)

2. **Run the API**:
   ```bash
   python API.py
   curl -X POST -F "image=@[image_path_put_here]" http://localhost:10086/predict
   ```
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import yaml
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import random
import datetime
import csv
import torch.nn.functional as F

from dataloder.CustomDataset import CustomDataset
from network.trainer import trainer

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # torch.backends.cudnn.deterministic = True

if __name__ == "__main__":
    # Load the config file
    with open('options/test_config.yaml', 'r') as f:
        config = yaml.safe_load(f)
    if not os.path.exists(os.path.join(config['save_output_path'], config['name'])):
        os.mkdir(os.path.join(config['save_output_path'], config['name']))
    with open(os.path.join(config['save_output_path'], config['name'], 'config.yaml'), 'w') as f:
        yaml.safe_dump(config, f)

    # Timestamp
    dt = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"[+]Start time: {dt}")

    # Set the random seed
    setup_seed(config['manualSeed'])

    # Build the custom dataset
    test_dataset = CustomDataset(config, mode='test')

    # Build the dataloader
    test_loader = DataLoader(test_dataset,
                             batch_size=config['test_batchSize'],
                             num_workers=config['workers'],
                             collate_fn=test_dataset.collate_fn,
                             shuffle=False)

    # Build the model
    TRAINER = trainer(config, TEST=True)

    # Collect (image name, predicted probability) pairs
    prediction_results = []

    # Evaluate the model
    TRAINER.set_mode(mode='eval')

    with torch.no_grad():
        for data_dict in tqdm(test_loader):
            TRAINER.set_input(data_dict)
            TRAINER.forward()
            TRAINER.output = F.softmax(TRAINER.output, dim=1)

            # Record the prediction for each image; at test time the "label"
            # is the image name encoded as ASCII codes (see CustomDataset)
            for i, prediction in zip(TRAINER.label, TRAINER.output):
                img_name = ''.join([chr(num.item()) for num in i])
                pred_prob = prediction[1].cpu().item()
                prediction_results.append([img_name, pred_prob])

    # Output CSV path
    output_csv_file = os.path.join(config['save_output_path'], config['name'], 'output.csv')

    # Write the results to the CSV file
    with open(output_csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Header row
        writer.writerow(['img_name', 'y_pred'])
        # One row per image
        writer.writerows(prediction_results)

    print(f"Results saved to {output_csv_file}")
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import yaml
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import random
import datetime

from dataloder.CustomDataset import CustomDataset
from network.trainer import trainer

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # torch.backends.cudnn.deterministic = True

if __name__ == "__main__":
    # Load the config file
    with open('options/train_config.yaml', 'r') as f:
        config = yaml.safe_load(f)
    if not os.path.exists(os.path.join(config['save_ckpt_path'], config['name'])):
        os.mkdir(os.path.join(config['save_ckpt_path'], config['name']))
    with open(os.path.join(config['save_ckpt_path'], config['name'], 'config.yaml'), 'w') as f:
        yaml.safe_dump(config, f)

    # Timestamp
    dt = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"[+]Start time: {dt}")

    # Set the random seed
    setup_seed(config['manualSeed'])

    # Build the custom datasets
    train_dataset = CustomDataset(config, mode='train')
    val_dataset = CustomDataset(config, mode='val')

    # Build the dataloaders
    train_loader = DataLoader(train_dataset,
                              batch_size=config['train_batchSize'],
                              num_workers=config['workers'],
                              collate_fn=train_dataset.collate_fn,
                              shuffle=True)
    len_train_dataloader = len(train_loader)
    val_loader = DataLoader(val_dataset,
                            batch_size=config['val_batchSize'],
                            num_workers=config['workers'],
                            collate_fn=val_dataset.collate_fn,
                            shuffle=False)

    # Build the model
    TRAINER = trainer(config)

    # Train the model
    for epoch in range(config['nEpochs']):

        TRAINER.set_mode(mode='train')

        for index, data_dict in enumerate(tqdm(train_loader)):

            TRAINER.set_input(data_dict)
            TRAINER.optimize_parameters()

            if len(train_loader) // config['printFreq'] == 0:
                config['printFreq'] = 1
            if index % (len(train_loader) // config['printFreq']) == 0:
                print(f"[+]Batch Loss: {TRAINER.loss:.8f}")

        TRAINER.set_mode(mode='eval')

        with torch.no_grad():
            TRAINER.test_reset()  # reset the metrics
            for data_dict in tqdm(val_loader):
                TRAINER.set_input(data_dict)
                TRAINER.test_forward()
            auc, acc = TRAINER.test_finish()

        TRAINER.save_ckpt(acc, auc, epoch)
        print(f"[+]Epoch[{epoch}], ACC:[{acc}], AUC:[{auc}]")

        TRAINER.scheduler_step()  # learning-rate step

    TRAINER.writer.close()
    print(f"[+]TensorBoard writer closed")
    print("[+]Train Finish!")
--------------------------------------------------------------------------------
/network/util/SRM.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# borrowed from https://github.com/SCLBD/DeepfakeBench/blob/main/training/detectors/srm_detector.py#L302

class SRMConv2d_simple(nn.Module):

    def __init__(self, inc=3, learnable=False):
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3,3,5,5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        # self.hor_kernel = self._build_kernel().transpose(0,1,3,2)

    def forward(self, x):
        '''
        x: imgs (Batch, 3, H, W)
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.truc(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0], [0, -1, 2, -1, 0], [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0], [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1], [2, -6, 8, -6, 2], [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2], [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order
        filter3 = [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
        # alternative filter3: Laplacian-style 2nd-order
        # filter3 = [[0, 0, 0, 0, 0],
        #            [0, 0, 1, 0, 0],
        #            [0, 1, -4, 1, 0],
        #            [0, 0, 1, 0, 0],
        #            [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # stack the filters
        filters = [
            [filter1],  # , filter1, filter1],
            [filter2],  # , filter2, filter2],
            [filter3]
        ]  # , filter3, filter3]]  # (3,3,5,5)
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        filters = torch.FloatTensor(filters)  # (3,3,5,5)
        return filters


class SRMConv2d_Separate(nn.Module):

    def __init__(self, inc, outc, learnable=False):
        super(SRMConv2d_Separate, self).__init__()
        self.inc = inc
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3*inc,1,5,5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        # self.hor_kernel = self._build_kernel().transpose(0,1,3,2)
        self.out_conv = nn.Sequential(
            nn.Conv2d(3 * inc, outc, 1, 1, 0, 1, 1, bias=False),
            nn.BatchNorm2d(outc),
            nn.ReLU(inplace=True)
        )

        for ly in self.out_conv.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)

    def forward(self, x):
        '''
        x: imgs (Batch, inc, H, W)
        kernel: (outc, inc, kH, kW)
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
        out = self.truc(out)
        out = self.out_conv(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]
        # alternative filter3: Laplacian-style 2nd-order
        # filter3 = [[0, 0, 0, 0, 0],
        #            [0, 0, 1, 0, 0],
        #            [0, 1, -4, 1, 0],
        #            [0, 0, 1, 0, 0],
        #            [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # stack the filters
        filters = [[filter1],  # , filter1, filter1],
                   [filter2],  # , filter2, filter2],
                   [filter3]]  # , filter3, filter3]]  # (3,3,5,5) => (3,1,5,5)
        filters = np.array(filters)
        # filters = np.repeat(filters, inc, axis=1)
        filters = np.repeat(filters, inc, axis=0)
        filters = torch.FloatTensor(filters)  # (3*inc,1,5,5)
        # print(filters.size())
        return filters
--------------------------------------------------------------------------------
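
A quick sanity check that `SRMConv2d_simple` acts as a fixed high-pass filter: all three kernels sum to zero, so a constant image should give a (numerically) zero residual away from the zero-padded border. A minimal sketch:

```python
import torch
from network.util.SRM import SRMConv2d_simple

srm = SRMConv2d_simple(inc=3, learnable=False)

# A constant image has no high-frequency content: away from the padded
# border the residual is ~0 (up to float rounding of the 1/12 weights)
flat = torch.ones(1, 3, 32, 32)
print(srm(flat)[..., 4:-4, 4:-4].abs().max())  # ~tensor(0.)

# padding=2 with a 5x5 kernel preserves the spatial size
noisy = torch.randn(1, 3, 32, 32)
print(srm(noisy).shape)  # torch.Size([1, 3, 32, 32])
```
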
/network/util/NL.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F


# borrowed from https://github.com/tea1528/Non-Local-NN-Pytorch/blob/master/models/non_local.py
class NLBlockND(nn.Module):

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 mode='embedded',
                 dimension=3,
                 bn_layer=True):
        """Implementation of the Non-Local Block with 4 different pairwise functions, without the subsampling trick
        args:
            in_channels: original channel size (1024 in the paper)
            inter_channels: channel size inside the block; if not specified, reduced to half (512 in the paper)
            mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation
            dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal)
            bn_layer: whether to add batch norm
        """
        super(NLBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']:
            raise ValueError(
                '`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`'
            )

        self.mode = mode
        self.dimension = dimension

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # the channel size is reduced to half inside the block
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        # assign appropriate convolutional, max pool, and batch norm layers for the different dimensions
        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        # the function g in the paper, implemented as a conv with kernel size 1
        self.g = conv_nd(in_channels=self.in_channels,
                         out_channels=self.inter_channels,
                         kernel_size=1)

        # add a BatchNorm layer after the last conv layer
        if bn_layer:
            self.W_z = nn.Sequential(
                conv_nd(in_channels=self.inter_channels,
                        out_channels=self.in_channels,
                        kernel_size=1), bn(self.in_channels))
            # from section 4.1 of the paper: initializing the BN params ensures the initial state of the non-local block is an identity mapping
            nn.init.constant_(self.W_z[1].weight, 0)
            nn.init.constant_(self.W_z[1].bias, 0)
        else:
            self.W_z = conv_nd(in_channels=self.inter_channels,
                               out_channels=self.in_channels,
                               kernel_size=1)

            # from section 3.3 of the paper: initializing Wz to 0 lets this block be inserted into any existing architecture
            nn.init.constant_(self.W_z.weight, 0)
            nn.init.constant_(self.W_z.bias, 0)

        # define theta and phi for all operations except gaussian
        if self.mode == "embedded" or self.mode == "dot" or self.mode == "concatenate":
            self.theta = conv_nd(in_channels=self.in_channels,
                                 out_channels=self.inter_channels,
                                 kernel_size=1)
            self.phi = conv_nd(in_channels=self.in_channels,
                               out_channels=self.inter_channels,
                               kernel_size=1)

        if self.mode == "concatenate":
            self.W_f = nn.Sequential(
                nn.Conv2d(in_channels=self.inter_channels * 2,
                          out_channels=1,
                          kernel_size=1), nn.ReLU())

    def forward(self, x):
        """
        args
            x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension=2; (N, C, T) for dimension=1
        """
        # NOTE: this is a stripped-down variant: it returns the raw pairwise
        # similarity map theta(x)^T . phi(x) without the softmax/normalization,
        # g(), or W_z projection of the full non-local block; IPD_Net consumes
        # this map directly as its "PCL" feature.

        batch_size = x.size(0)

        # (N, C, THW)
        # this reshaping and permutation is from the spacetime_nonlocal function in the original Caffe2 implementation
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        theta_x = theta_x.permute(0, 2, 1)

        # f: (N, THW, THW) pairwise similarity between all spatial positions
        f = torch.matmul(theta_x, phi_x)
        y = f

        # contiguous here just allocates a contiguous chunk of memory
        y = y.permute(0, 2, 1).contiguous()

        # reshape the THW x THW map so the second THW axis becomes channels
        y = y.view(batch_size, -1, *x.size()[2:])

        return y


if __name__ == '__main__':
    import torch

    for bn_layer in [True, False]:
        img = torch.zeros(2, 3, 20)
        net = NLBlockND(in_channels=3,
                        mode='concatenate',
                        dimension=1,
                        bn_layer=bn_layer)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NLBlockND(in_channels=3,
                        mode='concatenate',
                        dimension=2,
                        bn_layer=bn_layer)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NLBlockND(in_channels=3,
                        mode='concatenate',
                        dimension=3,
                        bn_layer=bn_layer)
        out = net(img)
        print(out.size())
--------------------------------------------------------------------------------
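
How the shapes work out when IPD_Net uses this block on 512x512 inputs (a minimal sketch with the same `dot`-mode configuration):

```python
import torch
from network.util.NL import NLBlockND

# ResNet-50 layer4 output for a 512x512 input: (N, 2048, 16, 16)
feat = torch.randn(2, 2048, 16, 16)
pcl = NLBlockND(in_channels=2048, dimension=2, mode="dot")

# The similarity map is (N, 16*16, 16*16) and is reshaped back to
# (N, 256, 16, 16), which is why IPD_Net's head is nn.Linear(256, 2)
out = pcl(feat)
print(out.shape)  # torch.Size([2, 256, 16, 16])
```
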
/network/util/resnet.py:
--------------------------------------------------------------------------------
# borrowed from https://blog.csdn.net/weixin_44023658/article/details/105843701

import torch.nn as nn
import torch
import torch.utils.model_zoo

# for ResNet-18/34
class BasicBlock(nn.Module):
    expansion = 1  # multiplier on the number of kernels in each conv

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):  # downsample corresponds to the dashed (projection) shortcut
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)  # batch norm
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x  # value carried on the shortcut
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

# for ResNet-50/101/152
class Bottleneck(nn.Module):
    expansion = 4  # 4x expansion

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,  # output channels x4
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):  # block is the residual block type; include_top allows building more complex networks on top later
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # adaptive: output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

def resnet50(num_classes=1000, include_top=True):
    pretrained_dict = torch.utils.model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
    model.load_state_dict(pretrained_dict)
    return model

def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
--------------------------------------------------------------------------------
/dataloder/CustomDataset.py:
--------------------------------------------------------------------------------
from PIL import Image
from torch.utils.data import Dataset
import torch
import torchvision.transforms as transforms
import pandas as pd
import os
import numpy as np
from random import random, choice, randint
import cv2
from io import BytesIO
from scipy.ndimage import gaussian_filter

def cv2_jpg(img, compress_val):
    img_cv2 = img[:, :, ::-1]
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    result, encimg = cv2.imencode('.jpg', img_cv2, encode_param)
    decimg = cv2.imdecode(encimg, 1)
    return decimg[:, :, ::-1]

def pil_jpg(img, compress_val):
    out = BytesIO()
    img = Image.fromarray(img)
    img.save(out, format='jpeg', quality=compress_val)
    img = Image.open(out)
    # load from memory before BytesIO closes
    img = np.array(img)
    out.close()
    return img

jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg}
def jpeg_from_key(img, compress_val, key):
    method = jpeg_dict[key]
    return method(img, compress_val)

def sample_discrete(s):
    if len(s) == 1:
        return s[0]
    return choice(s)

def gaussian_blur(img, sigma):
    gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
    gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
    gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)

def sample_continuous(s):
    if len(s) == 1:
        return s[0]
    if len(s) == 2:
        rg = s[1] - s[0]
        return random() * rg + s[0]
    raise ValueError("Length of iterable s should be 1 or 2.")

def data_augment(img, config):
    img = np.array(img)

    if random() < config['blur_prob']:
        sig = sample_continuous(config['blur_sig'])
        gaussian_blur(img, sig)

    if random() < config['jpeg_prob']:
        method = sample_discrete(config['jpeg_method'])
        qual = randint(*config['jpeg_qual'])
        img = jpeg_from_key(img, qual, method)

    return Image.fromarray(img)

def create_transforms(config, mode):
    if mode == 'train':
        isTrain = 1
    else:
        isTrain = 0

    if config['resize_or_crop'] == 'resize':
        size_func = transforms.Resize((config['input_shape'], config['input_shape']))
    elif config['resize_or_crop'] == 'crop':
        if isTrain:
            size_func = transforms.RandomCrop((config['input_shape'], config['input_shape']))
        else:
            size_func = transforms.CenterCrop((config['input_shape'], config['input_shape']))

    if isTrain:  # training mode
        flip_func = transforms.RandomHorizontalFlip(p=config['flip_prob'])  # horizontal flip
        rotate_func = transforms.RandomApply([transforms.RandomRotation(degrees=config['rotate_limit'])], p=config['rotate_prob'])  # rotation
        aug_func = transforms.Lambda(lambda img: data_augment(img, config))  # JPEG compression + Gaussian blur
        brightness_func = transforms.RandomApply([transforms.ColorJitter(brightness=config['brightness_limit'], contrast=config['contrast_limit'])], p=config['brightness_prob'])
    else:
        flip_func = transforms.Lambda(lambda img: img)
        rotate_func = transforms.Lambda(lambda img: img)
        aug_func = transforms.Lambda(lambda img: img)
        brightness_func = transforms.Lambda(lambda img: img)

    return transforms.Compose([
        aug_func,
        flip_func,
        rotate_func,
        brightness_func,
        size_func,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

# Custom dataset class
class CustomDataset(Dataset):

    def __init__(self, config, mode):
        self.config = config  # configuration
        self.mode = mode      # mode
        self.data = None
        self.length = None    # length of the loaded dataset

        if mode == "train":
            self.data = pd.read_csv(config['data']['train_label'])
            print(self.data['target'].value_counts())  # print the 0/1 counts of the training set

            # Balance the 0/1 counts of the training set
            # (assumes class 1 outnumbers class 0, as in this competition's data)
            class_0 = self.data[self.data['target'] == 0]
            class_1 = self.data[self.data['target'] == 1]
            # Difference between the two class counts
            difference = len(class_1) - len(class_0)
            # Oversample the minority class
            class_0_upsampled = class_0.sample(difference, replace=True)
            # Concatenate the original data with the duplicated rows
            balanced_data = pd.concat([self.data, class_0_upsampled])
            # Shuffle
            self.data = balanced_data.sample(frac=1).reset_index(drop=True)
            print(self.data['target'].value_counts())  # print the 0/1 counts again

            self.length = len(self.data)
            self.transform = create_transforms(config, mode=mode)

        elif mode == 'val':
            self.data = pd.read_csv(config['data']['val_label'])
            print(self.data['target'].value_counts())  # print the 0/1 counts of the validation set

            self.length = len(self.data)
            self.transform = create_transforms(config, mode=mode)

        elif mode == 'test':  # test mode
            self.data = pd.read_csv(config['data']['test_label'])

            self.length = len(self.data)
            self.transform = create_transforms(config, mode=mode)

        else:
            raise RuntimeError('Unknown dataset mode')

        if config['select_test'] != -1:
            self.length = config['select_test']  # smoke-test switch: set to e.g. 100 to use only 100 images

        if config['diFF_prob'] != 0.:
            print('[+]Diffusion-generated dataset enabled')
            folder_path = config['diff_path']
            self.diFF_file_paths = [os.path.join(folder_path, file_name) for file_name in os.listdir(folder_path)]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.mode == 'train':
            img_path = os.path.join(self.config['data'][f'{self.mode}_dir_path'], self.data.iloc[idx]['img_name'])
            label = int(self.data.iloc[idx]['target'])

            # With probability diFF_prob, replace a fake (label 1) image
            # with a diffusion-generated one
            if random() < self.config['diFF_prob'] and label == 1:
                img_path = choice(self.diFF_file_paths)

            sample = Image.open(img_path).convert('RGB')
            sample = self.transform(sample)

            one_hot = torch.zeros(2)
            one_hot[label] = 1.

            return sample, one_hot

        elif self.mode == 'val':
            img_path = os.path.join(self.config['data'][f'{self.mode}_dir_path'], self.data.iloc[idx]['img_name'])
            label = int(self.data.iloc[idx]['target'])

            sample = Image.open(img_path).convert('RGB')
            sample = self.transform(sample)

            one_hot = torch.zeros(2)
            one_hot[label] = 1.

            return sample, one_hot

        elif self.mode == 'test':
            img_path = os.path.join(self.config['data'][f'{self.mode}_dir_path'], self.data.iloc[idx]['img_name'])
            img_name = self.data.iloc[idx]['img_name']  # this is what gets written to the output CSV
            ascii_values = [ord(c) for c in img_name]
            img_name_tensor = torch.tensor(ascii_values)  # encode the image name as a tensor of ASCII codes

            sample = Image.open(img_path).convert('RGB')
            sample = self.transform(sample)

            return sample, img_name_tensor

        else:
            raise RuntimeError('ERROR!')

    @staticmethod
    def collate_fn(batch):
        # Separate the image and label tensors
        images, labels = zip(*batch)

        # Stack the image and label tensors
        images = torch.stack(images, dim=0)
        labels = torch.stack(labels, dim=0)

        # Create a dictionary of the tensors
        data_dict = {}
        data_dict['image'] = images
        data_dict['label'] = labels
        return data_dict
--------------------------------------------------------------------------------
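
A quick way to eyeball the training-time blur/JPEG corruption in isolation (a minimal sketch; the probabilities are forced to 1 so both corruptions always fire, and only the `pil` JPEG backend is used to keep the demo dependency-light):

```python
import numpy as np
from PIL import Image
from dataloder.CustomDataset import data_augment

# Force both corruptions on for the demo
cfg = {'blur_prob': 1.0, 'blur_sig': [0.0, 1.0],
       'jpeg_prob': 1.0, 'jpeg_method': ['pil'], 'jpeg_qual': [80, 100]}

img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
aug = data_augment(img, cfg)
print(aug.size, aug.mode)  # still a 64x64 RGB PIL image
```
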
/network/trainer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from network.IPD_Net import IPD_Net
import os
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np
import timm
from timm.models.registry import register_model
from timm.models.vision_transformer import _create_vision_transformer
from torch.utils.tensorboard import SummaryWriter

class trainer():
    def __init__(self, config, TEST=False):
        """ Pass TEST=True when testing rather than training or validating. """

        self.config = config

        # Build the model
        if config['network'] == "IPD_Net":  # ResNet-50 backbone, pretrained on ImageNet-1k
            if config['pretrained_path'] != "":  # load our own trained checkpoint
                self.model = IPD_Net(config)
                self.load_ckpt(config['pretrained_path'])
                print(f"[+]Pretrained model [{config['pretrained_path']}] loaded")
            else:
                self.model = IPD_Net(config)
                print(f'[+]Backbone pretrained weights loaded')
        elif config['network'] == "efficientnet_b4.ra2_in1k":  # dataset: ImageNet-1k, train image size: 320 x 320
            if config['pretrained_path'] != "":  # load our own trained checkpoint
                self.model = timm.create_model('efficientnet_b4.ra2_in1k', pretrained=False, num_classes=2)
                self.load_ckpt(config['pretrained_path'])
                print(f"[+]Pretrained model [{config['pretrained_path']}] loaded")
            else:
                self.model = timm.create_model('efficientnet_b4.ra2_in1k', pretrained=True, num_classes=2)
                print(f'[+]Backbone pretrained weights loaded')
        elif config['network'] == "SRM->efficientnet_b4.ra2_in1k":
            from network.SRM_Net import SRM_Net
            if config['pretrained_path'] != "":  # load our own trained checkpoint
                self.model = SRM_Net(name="efficientnet_b4.ra2_in1k", pretrained=False)
                self.load_ckpt(config['pretrained_path'])
                print(f"[+]Pretrained model [{config['pretrained_path']}] loaded")
            else:
                self.model = SRM_Net(name="efficientnet_b4.ra2_in1k", pretrained=True)
                print(f'[+]Backbone pretrained weights loaded')
        elif config['network'] == "swin_base_patch4_window7_224_ms_in1k":
            if config['pretrained_path'] != "":  # load our own trained checkpoint
                self.model = timm.create_model('swin_base_patch4_window7_224.ms_in1k', pretrained=False, num_classes=2)
                self.load_ckpt(config['pretrained_path'])
                print(f"[+]Pretrained model [{config['pretrained_path']}] loaded")
            else:
                self.model = timm.create_model('swin_base_patch4_window7_224.ms_in1k', pretrained=True, num_classes=2)
                print(f'[+]Backbone pretrained weights loaded')
        elif config['network'] == "vit_base_patch16_224":  # NOTE: this one is pretrained on ImageNet-21k
            if config['pretrained_path'] != "":  # load our own trained checkpoint
                self.model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=2)
                self.load_ckpt(config['pretrained_path'])
                print(f"[+]Pretrained model [{config['pretrained_path']}] loaded")
            else:
                self.model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=2)
                print(f'[+]Backbone pretrained weights loaded')
        elif config['network'] == "vit_base_patch16_224_dino":  # this one is ImageNet-1k
            if config['pretrained_path'] != "":  # load our own trained checkpoint
                self.model = timm.create_model('vit_base_patch16_224.dino', pretrained=False, num_classes=2)
                self.load_ckpt(config['pretrained_path'])
                print(f"[+]Pretrained model [{config['pretrained_path']}] loaded")
            else:
                self.model = timm.create_model('vit_base_patch16_224.dino', pretrained=True, num_classes=2)
                print(f'[+]Backbone pretrained weights loaded')
        elif config['network'] == "vit_base_patch16_224_dino_as_pretrained_vit_base_patch16_512_as_finetine":
            @register_model  # register the 512-px variant with timm
            def vit_base_patch16_512(pretrained: bool = False, **kwargs):
                model_args = dict(img_size=512)
                model = _create_vision_transformer('vit_base_patch16_224.dino', pretrained=pretrained, **dict(model_args, **kwargs))
                return model
            if config['pretrained_path'] != "" and not config['flag']:  # load an old 224-px checkpoint and fine-tune at 512
                self.model = timm.create_model('vit_base_patch16_512')
                cfg = self.model.default_cfg
                cfg['file'] = './tmp_file.pt'
                torch.save(torch.load(config['pretrained_path'])['model_state_dict'], cfg['file'])  # dump the 224-px weights so timm can load them into the 512-px model
                self.model = timm.create_model('vit_base_patch16_512', pretrained=True, pretrained_cfg=cfg, num_classes=2)
                os.remove(cfg['file'])
                print(f"[+]Pretrained model [{config['pretrained_path']}] loaded")
            elif config['pretrained_path'] != "" and config['flag']:  # load a checkpoint for testing
                self.model = timm.create_model('vit_base_patch16_512')
                cfg = self.model.default_cfg
                cfg['file'] = './tmp_file.pt'
                torch.save(torch.load(config['placeholder'])['model_state_dict'], cfg['file'])  # throwaway 224-px weights, used only to force timm to build the right architecture; almost any checkpoint would do, even a 512-px one
                self.model = timm.create_model('vit_base_patch16_512', pretrained=True, pretrained_cfg=cfg, num_classes=2)
                os.remove(cfg['file'])
                self.load_ckpt(config['pretrained_path'])
            else:
                self.model = timm.create_model('vit_base_patch16_512', pretrained=True, num_classes=2)
                print(f'[+]Backbone pretrained weights loaded')
        elif config['network'] == "tf_efficientnet_b3.ns_jft_in1k":  # this one is ImageNet-1k
            if config['pretrained_path'] != "":  # load our own trained checkpoint
                self.model = timm.create_model('tf_efficientnet_b3.ns_jft_in1k', pretrained=False, num_classes=2)
                self.load_ckpt(config['pretrained_path'])
                print(f"[+]Pretrained model [{config['pretrained_path']}] loaded")
            else:
                self.model = timm.create_model('tf_efficientnet_b3.ns_jft_in1k', pretrained=True, num_classes=2)
                print(f'[+]Backbone pretrained weights loaded')
        elif config['network'] == "vit_base_patch16_224_as_pretrained_vit_base_patch16_512_as_finetine":
            @register_model  # register the 512-px variant with timm
            def vit_base_patch16_512(pretrained: bool = False, **kwargs):
                model_args = dict(img_size=512)
                model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
                return model
            if config['pretrained_path'] != "" and not config['flag']:  # load an old 224-px checkpoint and fine-tune at 512
                self.model = timm.create_model('vit_base_patch16_512')
                cfg = self.model.default_cfg
                cfg['file'] = './tmp_file.pt'
                torch.save(torch.load(config['pretrained_path'])['model_state_dict'], cfg['file'])
                self.model = timm.create_model('vit_base_patch16_512', pretrained=True, pretrained_cfg=cfg, num_classes=2)
                os.remove(cfg['file'])
                print(f"[+]Pretrained model [{config['pretrained_path']}] loaded")
            elif config['pretrained_path'] != "" and config['flag']:  # load a checkpoint for testing
                self.model = timm.create_model('vit_base_patch16_512')
                cfg = self.model.default_cfg
                cfg['file'] = './tmp_file.pt'
                torch.save(torch.load(config['placeholder'])['model_state_dict'], cfg['file'])
                self.model = timm.create_model('vit_base_patch16_512', pretrained=True, pretrained_cfg=cfg, num_classes=2)
                os.remove(cfg['file'])
                self.load_ckpt(config['pretrained_path'])
            else:
                self.model = timm.create_model('vit_base_patch16_512', pretrained=True, num_classes=2)
                print(f'[+]Backbone pretrained weights loaded')
        else:
            raise ValueError("[-]The selected network architecture is not provided")
        print(f"[+]Selected network architecture: {config['network']}")


        # GPU setup
        self.device = torch.device(f"cuda:{config['gpu']}" if torch.cuda.is_available() else "cpu")  # main GPU
        self.model.to(self.device)  # move the model to the main GPU


        if not TEST:
            self.ckpt_dir = os.path.join(self.config['save_ckpt_path'], self.config['name'])  # folder the checkpoints are saved to
            self.writer = SummaryWriter(self.ckpt_dir)  # the first argument tells the writer where to put its summaries

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Optimizer
        optimizer_type = config['optimizer']['type']
        optimizer_args = config['optimizer'][optimizer_type]
        if config['optimizer']['type'] == 'adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=optimizer_args['lr'],
                                        betas=(optimizer_args['beta1'],
                                               optimizer_args['beta2']),
                                        eps=optimizer_args['eps'],
                                        weight_decay=optimizer_args['weight_decay'],
                                        amsgrad=optimizer_args['amsgrad'])
        elif config['optimizer']['type'] == 'sgd':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=optimizer_args['lr'],
                                       momentum=optimizer_args['momentum'],
                                       weight_decay=optimizer_args['weight_decay'])
        else:
            raise ValueError("[-]The selected optimizer is not provided")

        # Learning-rate schedule
        scheduler_type = config['scheduler']['type']
        scheduler_args = config['scheduler'][scheduler_type]
        if scheduler_type == 'CosineAnnealingLR':
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=scheduler_args['T_max'],
                eta_min=scheduler_args['eta_min'],
                last_epoch=scheduler_args['last_epoch'],
                verbose=scheduler_args['verbose'])
        else:
            raise ValueError("[-]The selected learning-rate schedule is not provided")

        # Best ACC/AUC so far, updated once per epoch
        self.best_acc = 0.
        self.best_auc = 0.

        # Ground truths and predictions accumulated over one validation epoch
        self.gt_list = None
        self.pre_list = None

    def scheduler_step(self):
        self.scheduler.step()

    def load_ckpt(self, ckpt_path):
        # Map the checkpoint directly onto the configured GPU
        load_ = torch.load(ckpt_path, map_location=f"cuda:{self.config['gpu']}")
        self.model.load_state_dict(load_['model_state_dict'])
        print(f"[+]ACC: {load_['acc']}")
        print(f"[+]AUC: {load_['auc']}")

    def set_mode(self, mode):
        if mode == 'train':
            print("[+]Train mode enabled")
            self.model.train()
        else:
            print("[+]Eval mode enabled")
            self.model.eval()

    def set_input(self, data_dict: dict):
        self.input = data_dict['image'].to(self.device)
        self.label = data_dict['label'].to(self.device)

    def forward(self):
        self.output = self.model(self.input)
        return self.output

    def get_loss(self):
        """ Loss for one batch. """
        return self.criterion(self.output, self.label)

    def optimize_parameters(self):
        self.forward()
        self.loss = self.get_loss()
        self.optimizer.zero_grad()  # clear the gradients
        self.loss.backward()        # backpropagation
        """ for name, parms in self.model.named_parameters():
            print(f'-->name: {name}, -->grad_requires: {parms.requires_grad}, -->grad_value: {parms.grad}') """
        self.optimizer.step()       # gradient descent

    def test_reset(self):
        """ Reset the metrics before each validation run. """
        self.gt_list = []
        self.pre_list = []

    def test_forward(self):
        self.forward()
        self.gt_list.extend(self.label[:, 1].cpu().numpy())
        self.output = F.softmax(self.output, dim=1)  # softmax, since this is binary classification
        self.pre_list.extend(self.output[:, 1].cpu().numpy())

    def test_finish(self):
        self.gt_list, self.pre_list = np.array(self.gt_list), np.array(self.pre_list)
        auc = roc_auc_score(self.gt_list, self.pre_list)
        acc = accuracy_score(self.gt_list, self.pre_list > 0.5)
        return auc, acc

    def save_ckpt(self, acc, auc, epoch):
        """ Save the checkpoint. """
        self.writer.add_scalar("ACC", acc, epoch)
        self.writer.add_scalar("AUC", auc, epoch)
        self.writer.add_scalar("lr", self.optimizer.state_dict()['param_groups'][0]['lr'], epoch)

        if acc > self.best_acc:
            self.best_acc = acc
        if auc > self.best_auc:
            self.best_auc = auc
            if not os.path.exists(self.ckpt_dir):
                os.makedirs(self.ckpt_dir)
            ckpt_path = os.path.join(self.ckpt_dir, f'ckpt_best.pth')
            checkpoint = {
                'model_state_dict': self.model.state_dict(),
                'acc': acc,
                'auc': auc,
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epoch': epoch,
            }
            torch.save(checkpoint, ckpt_path)
            print(f"[+]Best weights saved to [{ckpt_path}]")



if __name__ == "__main__":
    import yaml
    # Load the config file (legacy smoke test: this block only parses the YAML)
    with open('training/config.yaml', 'r') as f:
        config = yaml.safe_load(f)
--------------------------------------------------------------------------------
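
Putting the pieces together, a minimal forward-pass smoke test of the default IPD_Net setup (a sketch, not the official entry point; it assumes the ImageNet ResNet-50 weights can be downloaded, clears `pretrained_path` so no local checkpoint is needed, and runs from the repository root):

```python
import yaml
import torch
from network.trainer import trainer

with open("options/train_config.yaml", "r") as f:
    config = yaml.safe_load(f)
config["pretrained_path"] = ""   # no local checkpoint for this smoke test

t = trainer(config)              # builds IPD_Net plus the SGD/cosine schedule
t.set_mode(mode="eval")

# One fake batch in the collate_fn format: an image tensor and a one-hot label
data = {"image": torch.randn(1, 3, 512, 512),
        "label": torch.tensor([[1.0, 0.0]])}
t.set_input(data)
print(t.forward().shape)         # expected: torch.Size([1, 2])
```
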