class DefaultConfigs(object):
    """Central path / hyper-parameter configuration, imported everywhere as `config`."""
    # 1. string parameters
    train_data = "./data/train/"          # filled by move.py (train + val merged)
    test_data = "./data/test/"
    val_data = "no"                       # "no" => no separate val dir; split is done at runtime
    model_name = "resnet50"
    weights = "./checkpoints/"
    best_models = weights + "best_model/"
    submit = "./submit/"
    logs = "./logs/"
    gpus = "1"

    # 2. numeric parameters
    epochs = 40
    batch_size = 8
    img_height = 650
    img_weight = 650      # NOTE: historical typo for "width"; kept because other files read it
    img_width = img_weight  # preferred, correctly-spelled alias (backward compatible)
    num_classes = 59      # 61 original classes minus the two dropped by move.py (44, 45)
    seed = 888
    lr = 1e-4
    lr_decay = 1e-4
    weight_decay = 1e-4

config = DefaultConfigs()
# -------------------- move.py --------------------
# Re-organise the raw competition data: copy every image listed in the two
# annotation json files into ./data/train/<class_id>/.  Classes 44 and 45 are
# dropped, and ids > 45 are shifted down by 2 so the labels stay contiguous
# (this is why config.num_classes is 59, and why main.py's test() adds 2 back).
import json
import shutil
import os
from tqdm import tqdm

# One folder per (remapped) class.  exist_ok=True replaces the old bare
# try/except around the whole loop, which silently skipped EVERY remaining
# folder as soon as one mkdir failed.
for i in range(59):
    os.makedirs("./data/train/" + str(i), exist_ok=True)

file_train = json.load(open("./data/temp/labels/AgriculturalDisease_train_annotations.json", "r", encoding="utf-8"))
file_val = json.load(open("./data/temp/labels/AgriculturalDisease_validation_annotations.json", "r", encoding="utf-8"))

# train and val are merged; main.py re-splits them with a stratified split.
file_list = file_train + file_val

for file in tqdm(file_list):
    filename = file["image_id"]
    origin_path = "./data/temp/images/" + filename
    ids = file["disease_class"]
    if ids == 44 or ids == 45:
        # these two classes are removed entirely
        continue
    if ids > 45:
        # close the gap left by the two dropped classes
        ids = ids - 2
    save_path = "./data/train/" + str(ids) + "/"
    shutil.copy(origin_path, save_path)

# -------------------- models/model.py --------------------
import torch
import torchvision
import torch.nn.functional as F
from torch import nn
from config import config

def generate_model():
    """Build a DenseNet-169 based classifier head.

    Not used by main.py (which uses get_net); kept for reference.
    Returns a DenseModel wrapping a pretrained densenet169.
    """
    class DenseModel(nn.Module):
        def __init__(self, pretrained_model):
            super(DenseModel, self).__init__()
            self.classifier = nn.Linear(pretrained_model.classifier.in_features, config.num_classes)

            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    # kaiming_normal (no underscore) is deprecated/removed in
                    # modern torch; use the in-place variant.
                    nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.Linear):
                    m.bias.data.zero_()

            self.features = pretrained_model.features
            # individual dense blocks exposed as attributes (not used in forward)
            self.layer1 = pretrained_model.features._modules['denseblock1']
            self.layer2 = pretrained_model.features._modules['denseblock2']
            self.layer3 = pretrained_model.features._modules['denseblock3']
            self.layer4 = pretrained_model.features._modules['denseblock4']

        def forward(self, x):
            features = self.features(x)
            out = F.relu(features, inplace=True)
            # kernel_size=8 assumes a fixed input resolution -- TODO confirm
            out = F.avg_pool2d(out, kernel_size=8).view(features.size(0), -1)
            # F.sigmoid is deprecated; torch.sigmoid is the supported spelling
            out = torch.sigmoid(self.classifier(out))
            return out

    return DenseModel(torchvision.models.densenet169(pretrained=True))

def get_net():
    """Return a pretrained ResNet-50 with its head replaced for config.num_classes.

    AdaptiveAvgPool2d(1) lets the network accept arbitrary input sizes
    (the project trains at 650x650, not the ImageNet 224x224).
    """
    model = torchvision.models.resnet50(pretrained=True)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.fc = nn.Linear(2048, config.num_classes)
    return model

# -------------------- README.md (beginning; the document continues in the
# next section of this dump) --------------------
# ### 声明:开源只是为了方便大家交流学习,数据请勿用于商业用途!!!!转载或解读请注明出处,谢谢!
# **背景** 很早之前开源过 pytorch 进行图像分类的代码, 历时两个多月的学习和总结,
# 近期也做了升级. 在此基础上写了一个 Ai Challenger 农作物竞赛的 baseline 供大家交流.
# **2018 年 12 月 13 日更新** 新增数据集下载链接: [百度网盘]
# (https://pan.baidu.com/s/16f1nQchS-zBtzSWn9Guyyg) 提取码: iksk
# 数据集是 10 月 23 日更新后的新数据集, 包含训练集、验证集、测试集A/B.
11 | 另外最近有同学拿到类似的数据,想做分类的任务,但是这份代码是针对这次比赛开源的,在数据读取方式上会有区别,对于新手来说不太友好,我开源了一份针对图像分类任务的代码,并附上简单教程,相信看完后能比较轻松使用 pytorch 进行图像分类。 12 | 13 | 教程: [从实例掌握 pytorch 进行图像分类](http://www.spytensor.com/index.php/archives/21/) 14 | 15 | 代码: [pytorch-image-classification](https://github.com/spytensor/pytorch-image-classification) 16 | 17 | **2018年 10 月 30 日更新** 18 | 19 | 新增 `data_aug.py` 用于线下数据增强,由于时间问题,这个比赛不再做啦,这些增强方式大家有需要可以研究一下,支持的增强方式: 20 | 21 | - 高斯噪声 22 | - 亮度变化 23 | - 左右翻转 24 | - 上下翻转 25 | - 色彩抖动 26 | - 对比度变化 27 | - 锐度变化 28 | 29 | 注:对比度增强在可视化后,主观感觉特征更明显了,目前我还未跑完。提醒一下,如果做了对比度增强,在测试集的时候最好也做一下。 30 | 31 | 个人博客:[超杰](http://spytensor.com/) 32 | 33 | 比赛地址:[农作物病害检测](https://challenger.ai/competition/pdr2018) 34 | 35 | 完整代码地址:[plants_disease_detection](https://github.com/spytensor/plants_disease_detection) 36 | 37 | 注: 38 | 欢迎大佬学习交流啊,这份代码可改进的地方太多了, 39 | 如果大佬们有啥改进的意见请指导! 40 | 联系方式:zhuchaojie@buaa.edu.cn 41 | 42 | **成绩**:线上 0.8805,线下0.875,由于划分存在随机性,可能复现会出现波动,已经尽可能排除随机种子的干扰了。 43 | 44 | ## 提醒 45 | 46 | `main.py` 中的test函数已经修正,执行后在 `./submit/`中会得到提交格式的 json 文件,现已支持 Focalloss 和交叉验证,需要的自行修改一下就可以了。 47 | 依赖中的 pytorch 版本请保持一致,不然可能会有一些小 BUG。 48 | 49 | ### 1. 依赖 50 | 51 | python3.6 pytorch0.4.1 52 | 53 | ### 2. 关于数据的处理 54 | 55 | 首先说明,使用的数据为官方更新后的数据,并做了一个统计分析(下文会给出),最后决定删除第 44 类和第 45 类。 56 | 并且由于数据分布的原因,我将 train 和 val 数据集合并后,采用随机划分。 57 | 58 | 数据增强方式: 59 | 60 | - RandomRotation(30) 61 | - RandomHorizontalFlip() 62 | - RandomVerticalFlip() 63 | - RandomAffine(45) 64 | 65 | 图片尺寸选择了 650,暂时没有对这个尺寸进行调优(毕竟太忙了。。) 66 | 67 | ### 3. 模型选择 68 | 69 | 模型目前就尝试了 resnet50,后续有卡的话再说吧。。。 70 | 71 | ### 4. 
# -------------------- README.md (continued) --------------------
# ### 4. 超参数设置: 详情在 config.py 中
# ### 5. 使用方法:
#   - 第一步: 将测试集图片复制到 `data/test/` 下
#   - 第二步: 将训练集和验证集中的图片都复制到 `data/temp/images/` 下,
#     将两个 `json` 文件放到 `data/temp/labels/` 下
#   - 执行 move.py 文件
#   - 执行 main.py 进行训练
# ### 6. 数据分布图:
#   训练集  ![train](http://www.spytensor.com/images/plants/train.png)
#   验证集  ![val](http://www.spytensor.com/images/plants/val.png)
#   全部数据集 ![all](http://www.spytensor.com/images/plants/all.png)

# -------------------- dataset/dataloader.py --------------------
from torch.utils.data import Dataset
from torchvision import transforms as T
from config import config
from PIL import Image
from itertools import chain
from glob import glob
from tqdm import tqdm
import random
import numpy as np
import pandas as pd
import os
import torch

# 1. set random seeds for reproducibility
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)

# 2. define dataset
class ChaojieDataset(Dataset):
    """Dataset over a pandas DataFrame of file paths (plus labels unless test).

    Args:
        label_list: DataFrame with a "filename" column and, unless test=True,
            a "label" column (as produced by get_files).
        transforms: optional torchvision transform; a default train/eval
            pipeline is built when None.
        train: when False (and not test), augmentation is disabled.
        test: when True, __getitem__ returns (image, filename) instead of
            (image, label).
    """

    def __init__(self, label_list, transforms=None, train=True, test=False):
        self.test = test
        self.train = train
        if self.test:
            self.imgs = [row["filename"] for _, row in label_list.iterrows()]
        else:
            self.imgs = [(row["filename"], row["label"]) for _, row in label_list.iterrows()]
        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
            if self.test or not train:
                # evaluation pipeline: resize + normalize only
                self.transforms = T.Compose([
                    T.Resize((config.img_weight, config.img_height)),
                    T.ToTensor(),
                    normalize])
            else:
                # training pipeline with geometric augmentation
                self.transforms = T.Compose([
                    T.Resize((config.img_weight, config.img_height)),
                    T.RandomRotation(30),
                    T.RandomHorizontalFlip(),
                    T.RandomVerticalFlip(),
                    T.RandomAffine(45),
                    T.ToTensor(),
                    normalize])
        else:
            self.transforms = transforms

    def __getitem__(self, index):
        if self.test:
            filename = self.imgs[index]
            # convert("RGB"): some jpegs decode as grayscale/CMYK/RGBA, which
            # would crash the 3-channel Normalize above
            img = Image.open(filename).convert("RGB")
            return self.transforms(img), filename
        filename, label = self.imgs[index]
        img = Image.open(filename).convert("RGB")
        return self.transforms(img), label

    def __len__(self):
        return len(self.imgs)

def collate_fn(batch):
    """Stack images into one tensor; keep labels/filenames as a plain list."""
    imgs = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]
    return torch.stack(imgs, 0), labels

def get_files(root, mode):
    """Collect file paths under `root` into a DataFrame.

    mode == "test": flat directory of images -> DataFrame("filename").
    mode in ("train", "val"): one sub-folder per class id ->
    DataFrame("filename", "label"), label parsed from the folder name.
    Any other mode prints a warning and returns None (the original
    `elif mode != "test"` made that branch unreachable).
    """
    if mode == "test":
        files = [root + img for img in os.listdir(root)]
        return pd.DataFrame({"filename": files})
    if mode in ("train", "val"):
        all_data_path, labels = [], []
        image_folders = [root + folder for folder in os.listdir(root)]
        all_images = list(chain.from_iterable(
            glob(folder + "/*.jpg") + glob(folder + "/*.JPG")
            for folder in image_folders))
        print("loading train dataset")
        for file in tqdm(all_images):
            all_data_path.append(file)
            labels.append(int(file.split("/")[-2]))  # parent folder name == class id
        return pd.DataFrame({"filename": all_data_path, "label": labels})
    print("check the mode please!")

# -------------------- data_aug.py (header; the script body continues in the
# next section of this dump) --------------------
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
import shutil
import cv2
# -------------------- data_aug.py --------------------
# Offline augmentation script: for every image in ./data/train/<class_id>/
# write noise / brightness / flip / contrast variants into ./aug/train/<class_id>/.
# (Rotation / brightness / color / sharpness variants from earlier revisions
# were disabled by the author and have been removed here.)
from PIL import Image, ImageEnhance
import os
import random
import numpy as np
import cv2
from skimage.util import random_noise
from skimage import exposure

image_number = 0

raw_path = "./data/train/"
new_path = "./aug/train/"

def addNoise(img):
    """Add gaussian noise.

    random_noise returns floats in [0, 1], so scale back to [0, 255].
    (The original comment claimed "multiply by 5", which was a typo.)
    """
    return random_noise(img, mode='gaussian', seed=13, clip=True) * 255

def changeLight(img):
    """Random gamma adjustment with gamma in [0.5, 1.5] (>1 darkens, <1 brightens)."""
    rate = random.uniform(0.5, 1.5)
    return exposure.adjust_gamma(img, rate)

# exist_ok=True replaces the old bare try/except, which aborted the whole loop
# (silently) as soon as one directory already existed.
for i in range(59):
    os.makedirs(new_path + os.sep + str(i), exist_ok=True)

for raw_dir_name in range(59):
    raw_dir_name = str(raw_dir_name)
    saved_image_path = new_path + raw_dir_name + "/"
    raw_image_path = raw_path + raw_dir_name + "/"

    raw_image_file_path = [raw_image_path + name for name in os.listdir(raw_image_path)]

    for x in raw_image_file_path:
        img = Image.open(x)
        cv_image = cv2.imread(x)

        # noise / light variants (OpenCV pipeline)
        gau_image = addNoise(cv_image)
        light = changeLight(cv_image)
        light_and_gau = addNoise(light)

        base = os.path.basename(x)
        # addNoise returns float64; clip+cast to uint8 so cv2.imwrite stores a
        # valid 8-bit image instead of truncating silently
        cv2.imwrite(saved_image_path + "gau_" + base,
                    np.clip(gau_image, 0, 255).astype(np.uint8))
        cv2.imwrite(saved_image_path + "light_" + base, light)
        # prefix fixed to "gau_light_" (was "gau_light" without the separator,
        # inconsistent with every other variant prefix)
        cv2.imwrite(saved_image_path + "gau_light_" + base,
                    np.clip(light_and_gau, 0, 255).astype(np.uint8))

        # 1. flips (PIL pipeline)
        img_flip_left_right = img.transpose(Image.FLIP_LEFT_RIGHT)
        img_flip_top_bottom = img.transpose(Image.FLIP_TOP_BOTTOM)

        # 2. contrast (per the README: if contrast is augmented here, apply the
        # same enhancement to the test set as well)
        enh_con = ImageEnhance.Contrast(img)
        contrast = 1.5
        image_contrasted = enh_con.enhance(contrast)

        # save originals and variants
        img.save(saved_image_path + base)
        img_flip_left_right.save(saved_image_path + "left_right_" + base)
        img_flip_top_bottom.save(saved_image_path + "top_bottom_" + base)
        image_contrasted.save(saved_image_path + "contrasted_" + base)

        image_number += 1
        print("convert pictures :%s size:%s mode:%s" % (image_number, img.size, img.mode))
# -------------------- utils.py --------------------
import shutil
import torch
import sys
import os
import json
import numpy as np
from config import config
from torch import nn
import torch.nn.functional as F

def save_checkpoint(state, is_best, fold):
    """Persist the latest checkpoint; mirror it into best_models when best so far."""
    filename = config.weights + config.model_name + os.sep + str(fold) + os.sep + "_checkpoint.pth.tar"
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, config.best_models + config.model_name + os.sep + str(fold) + os.sep + 'model_best.pth.tar')

class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 3 epochs."""
    lr = config.lr * (0.1 ** (epoch // 3))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def schedule(current_epoch, current_lrs, **logs):
    """Hand-tuned per-param-group schedule (expects 6 groups; group 0 runs 10x slower).

    NOTE(review): the original dump lost indentation; the `current_epoch >= 2`
    block is placed after the loop (it uses the loop's final lr) -- confirm
    against the upstream repository if this function is ever used.
    """
    lrs = [1e-3, 1e-4, 0.5e-4, 1e-5, 0.5e-5]
    epochs = [0, 1, 6, 8, 12]
    for lr, epoch in zip(lrs, epochs):
        if current_epoch >= epoch:
            current_lrs[5] = lr
    if current_epoch >= 2:
        current_lrs[4] = lr * 1
        current_lrs[3] = lr * 1
        current_lrs[2] = lr * 1
        current_lrs[1] = lr * 1
        current_lrs[0] = lr * 0.1
    return current_lrs

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape(-1) instead of view(-1): the slice is non-contiguous and
            # view() raises a RuntimeError on modern torch versions
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

class Logger(object):
    """Tee writes to stdout and, once open() has been called, to a log file."""
    def __init__(self):
        self.terminal = sys.stdout
        self.file = None

    def open(self, file, mode=None):
        if mode is None:
            mode = 'w'
        self.file = open(file, mode)

    def write(self, message, is_terminal=1, is_file=1):
        # carriage-return progress lines are terminal-only
        if '\r' in message:
            is_file = 0

        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()

        if is_file == 1:
            # NOTE: requires open() to have been called first
            self.file.write(message)
            self.file.flush()

    def flush(self):
        # needed for python 3 compatibility: handles the flush command by
        # doing nothing
        pass

def get_learning_rate(optimizer):
    """Return the learning rate of the first param group."""
    lr = []
    for param_group in optimizer.param_groups:
        lr += [param_group['lr']]
    # we support only one param_group here
    lr = lr[0]
    return lr

def time_to_str(t, mode='min'):
    """Format a duration t (seconds) as 'H hr M min' or 'M min S sec'."""
    if mode == 'min':
        t = int(t) / 60
        hr = t // 60
        min = t % 60
        return '%2d hr %02d min' % (hr, min)
    elif mode == 'sec':
        t = int(t)
        min = t // 60
        sec = t % 60
        return '%2d min %02d sec' % (min, sec)
    else:
        raise NotImplementedError

class FocalLoss(nn.Module):
    """Focal loss built on top of softmax cross entropy.

    loss = balance_param * (1 - pt)^focusing_param * CE
    (The original also computed torch.log(cross_entropy) into an unused
    variable -- dead code that could produce NaNs -- removed.)
    """

    def __init__(self, focusing_param=2, balance_param=0.25):
        super(FocalLoss, self).__init__()
        self.focusing_param = focusing_param
        self.balance_param = balance_param

    def forward(self, output, target):
        logpt = - F.cross_entropy(output, target)
        pt = torch.exp(logpt)
        focal_loss = -((1 - pt) ** self.focusing_param) * logpt
        balanced_focal_loss = self.balance_param * focal_loss
        return balanced_focal_loss

class MyEncoder(json.JSONEncoder):
    """JSON encoder that understands numpy scalar and array types."""
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(MyEncoder, self).default(obj)

# -------------------- main.py (header; the rest of the file continues in the
# next section of this dump) --------------------
import os
import random
import time
import json
import torch
import torchvision
import numpy as np
import pandas as pd
import warnings
from datetime import datetime
from torch import nn, optim
from config import config
from collections import OrderedDict
from torch.autograd import Variable
from torch.utils.data import DataLoader
from dataset.dataloader import *
from sklearn.model_selection import train_test_split, StratifiedKFold
from timeit import default_timer as timer
from models.model import *
from utils import *

# 1. set random seed and cudnn performance
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus
torch.backends.cudnn.benchmark = True
warnings.filterwarnings('ignore')

# 2. evaluate func
def evaluate(val_loader, model, criterion):
    """Run one pass over val_loader; return [avg_loss, avg_top1, avg_top2]."""
    # 2.1 define meters
    losses = AverageMeter()
    top1 = AverageMeter()
    top2 = AverageMeter()
    # 2.2 switch to evaluate mode and confirm model has been transferred to cuda
    model.cuda()
    model.eval()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            input = Variable(input).cuda()
            # labels arrive as a plain list from collate_fn; convert to a tensor
            target = Variable(torch.from_numpy(np.array(target)).long()).cuda()
            # 2.2.1 compute output
            output = model(input)
            loss = criterion(output, target)

            # 2.2.2 measure accuracy and record loss
            precision1, precision2 = accuracy(output, target, topk=(1, 2))
            losses.update(loss.item(), input.size(0))
            top1.update(precision1[0], input.size(0))
            top2.update(precision2[0], input.size(0))

    return [losses.avg, top1.avg, top2.avg]

# 3. test model on the public dataset and save the probability matrix
def test(test_loader, model, folds):
    """Predict the test set and write ./submit/baseline.json in submission format."""
    # 3.1 confirm the model is on cuda
    csv_map = OrderedDict({"filename": [], "probability": []})
    model.cuda()
    model.eval()
    with open("./submit/baseline.json", "w", encoding="utf-8") as f:
        submit_results = []
        for i, (input, filepath) in enumerate(tqdm(test_loader)):
            # 3.2 keep only the basename of each file path
            filepath = [os.path.basename(x) for x in filepath]
            with torch.no_grad():
                image_var = Variable(input).cuda()
                # 3.3 forward + softmax over the class dimension
                y_pred = model(image_var)
                smax = nn.Softmax(1)
                smax_out = smax(y_pred)
            # 3.4 save probabilities as ';'-joined strings
            csv_map["filename"].extend(filepath)
            for output in smax_out:
                prob = ";".join([str(i) for i in output.data.tolist()])
                csv_map["probability"].append(prob)
        result = pd.DataFrame(csv_map)
        result["probability"] = result["probability"].map(lambda x: [float(i) for i in x.split(";")])
        for index, row in result.iterrows():
            pred_label = np.argmax(row['probability'])
            # undo move.py's label remap: classes 44/45 were removed from
            # training, so predicted ids >= 44 map back to originals by +2
            if pred_label > 43:
                pred_label = pred_label + 2
            submit_results.append({"image_id": row['filename'], "disease_class": pred_label})
        json.dump(submit_results, f, ensure_ascii=False, cls=MyEncoder)

# 4. main training entry point
def main():
    """Train resnet50 on the (merged, re-split) data, then predict the test set."""
    fold = 0
    # 4.1 mkdirs
    if not os.path.exists(config.submit):
        os.mkdir(config.submit)
    if not os.path.exists(config.weights):
        os.mkdir(config.weights)
    if not os.path.exists(config.best_models):
        os.mkdir(config.best_models)
    if not os.path.exists(config.logs):
        os.mkdir(config.logs)
    if not os.path.exists(config.weights + config.model_name + os.sep + str(fold) + os.sep):
        os.makedirs(config.weights + config.model_name + os.sep + str(fold) + os.sep)
    if not os.path.exists(config.best_models + config.model_name + os.sep + str(fold) + os.sep):
        os.makedirs(config.best_models + config.model_name + os.sep + str(fold) + os.sep)
    # 4.2 get model and optimizer
    model = get_net()
    #model = torch.nn.DataParallel(model)
    model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=config.lr, amsgrad=True, weight_decay=config.weight_decay)
    criterion = nn.CrossEntropyLoss().cuda()
    #criterion = FocalLoss().cuda()
    log = Logger()
    log.open(config.logs + "log_train.txt", mode="a")
    log.write("\n----------------------------------------------- [START %s] %s\n\n" % (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), '-' * 51))
    # 4.3 some parameters for K-fold and model restart
    start_epoch = 0
    best_precision1 = 0
    best_precision_save = 0
    resume = False

    # 4.4 restart the training process
    if resume:
        # BUGFIX: the path previously omitted config.model_name and could
        # never match what save_checkpoint writes
        checkpoint = torch.load(config.best_models + config.model_name + os.sep + str(fold) + os.sep + "model_best.pth.tar")
        start_epoch = checkpoint["epoch"]
        fold = checkpoint["fold"]
        best_precision1 = checkpoint["best_precision1"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

    # 4.5 get files and split (a StratifiedKFold variant existed here and was
    # replaced by a single stratified hold-out split)
    train_ = get_files(config.train_data, "train")
    test_files = get_files(config.test_data, "test")
    train_data_list, val_data_list = train_test_split(train_, test_size=0.15, stratify=train_["label"])

    # 4.5.4 load dataset
    train_dataloader = DataLoader(ChaojieDataset(train_data_list), batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True)
    val_dataloader = DataLoader(ChaojieDataset(val_data_list, train=False), batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=False)
    test_dataloader = DataLoader(ChaojieDataset(test_files, test=True), batch_size=1, shuffle=False, pin_memory=False)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # 4.5.5.1 define metrics
    train_losses = AverageMeter()
    train_top1 = AverageMeter()
    train_top2 = AverageMeter()
    valid_loss = [np.inf, 0, 0]
    model.train()

    # logs header
    log.write('** start training here! **\n')
    log.write('                           |------------ VALID -------------|----------- TRAIN -------------|------Accuracy------|------------|\n')
    log.write('lr       iter     epoch    | loss   top-1  top-2            | loss   top-1  top-2           |    Current Best    | time       |\n')
    log.write('-------------------------------------------------------------------------------------------------------------------------------\n')

    # 4.5.5 train
    start = timer()
    for epoch in range(start_epoch, config.epochs):
        scheduler.step(epoch)
        for iter, (input, target) in enumerate(train_dataloader):
            # make sure we are back in train mode after each evaluate() call
            model.train()
            input = Variable(input).cuda()
            target = Variable(torch.from_numpy(np.array(target)).long()).cuda()
            output = model(input)
            loss = criterion(output, target)

            precision1_train, precision2_train = accuracy(output, target, topk=(1, 2))
            train_losses.update(loss.item(), input.size(0))
            train_top1.update(precision1_train[0], input.size(0))
            train_top2.update(precision2_train[0], input.size(0))
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr = get_learning_rate(optimizer)
            # in-place progress line on the terminal
            print('\r', end='', flush=True)
            print('%0.4f %5.1f %6.1f        | %0.3f  %0.3f  %0.3f         | %0.3f  %0.3f  %0.3f         | %s  | %s' % (
                lr, iter / len(train_dataloader) + epoch, epoch,
                valid_loss[0], valid_loss[1], valid_loss[2],
                train_losses.avg, train_top1.avg, train_top2.avg, str(best_precision_save),
                time_to_str((timer() - start), 'min')), end='', flush=True)

        # evaluate once per epoch
        lr = get_learning_rate(optimizer)
        valid_loss = evaluate(val_dataloader, model, criterion)
        is_best = valid_loss[1] > best_precision1
        best_precision1 = max(valid_loss[1], best_precision1)
        try:
            # top-1 may be a cuda tensor; keep a plain number for display
            best_precision_save = best_precision1.cpu().data.numpy()
        except Exception:
            pass
        save_checkpoint({
            "epoch": epoch + 1,
            "model_name": config.model_name,
            "state_dict": model.state_dict(),
            "best_precision1": best_precision1,
            "optimizer": optimizer.state_dict(),
            "fold": fold,
            "valid_loss": valid_loss,
        }, is_best, fold)

        print("\r", end="", flush=True)
        log.write('%0.4f %5.1f %6.1f        | %0.3f  %0.3f  %0.3f          | %0.3f  %0.3f  %0.3f         | %s  | %s' % (
            lr, 0 + epoch, epoch,
            valid_loss[0], valid_loss[1], valid_loss[2],
            train_losses.avg, train_top1.avg, train_top2.avg, str(best_precision_save),
            time_to_str((timer() - start), 'min')))
        log.write('\n')
        time.sleep(0.01)

    # reload the best checkpoint and generate the submission
    best_model = torch.load(config.best_models + os.sep + config.model_name + os.sep + str(fold) + os.sep + 'model_best.pth.tar')
    model.load_state_dict(best_model["state_dict"])
    test(test_dataloader, model, fold)

if __name__ == "__main__":
    main()