├── README.MD ├── dataset.py ├── main.py ├── models ├── __init__.py ├── basic_module.py ├── loss.py └── resnet.py ├── requirements.txt ├── scene.pth ├── scripts └── data_process.py └── utils.py /README.MD: -------------------------------------------------------------------------------- 1 | # PyTorch Baseline for AI challenger scene 2 | AI Challenger [场景分类竞赛](https://challenger.ai/competition/scene/)的示例代码 3 | 4 | ## 使用 5 | ### 1.环境配置 6 | 安装: 7 | - [PyTorch](http://pytorch.org/),根据官网说明下载指定的版本即可。 8 | - 第三方依赖: `pip install -r requirements.txt` 9 | 10 | ### 2.数据预处理: 11 | 12 | 这一步主要是根据json文件进行简单的处理,生成二进制scene.pth,可以跳过这一步,直接下载随程序附带的scene.pth 13 | 14 | 如果你想自己生成scene.pth,修改scripts/data_process.py 中的文件路径,然后运行 15 | ```bash 16 | python scripts/data_process.py 17 | ``` 18 | 19 | ### 3.启动visdom 20 | 可视化工具[visdom](https://github.com/facebookresearch/visdom) 21 | ```bash 22 | nohup python2 -m visdom.server& 23 | ``` 24 | 25 | ### 4.训练 26 | 训练之前还需要新建`checkpoints`文件夹用来保存模型`mkdir checkpoints`。 27 | 注意修改utils.py 中文件路径 28 | 29 | ```bash 30 | python main.py train --model='resnet34' 31 | ``` 32 | 在 Titan Xp下,大概90分钟可以在验证集上得到大约0.938的准确率 33 | 34 | 35 | ```bash 36 | python main.py train --model='resnet365' 37 | ``` 38 | 39 | 使用place365的预训练模型resnet50, 可以在验证集达到**0.957**的top3分数, 40 | 41 | 42 | 打开浏览器 输入http://ip:8097 可以看到训练过程。visdom 中要用到两个js文件`plotly.min.js`和`react-grid-layout.min.js`,这两个js文件被防火墙所拦截~ 所以你可能需要自备梯子才能用visdom。 43 | 44 | 另外一个解决方法是: 45 | `locate locate visdom/static/index.html`,修改index.html中两行js的地址 46 | 47 | ### 5.提交 48 | ```bash 49 | python main.py submit --model='resnet34' --load-path='res34_1018_2204_0.938002232143' 50 | 51 | ``` 52 | 会在当前目录生成`result.json`文件,直接提交即可 53 | 54 | ## 关于CPU运行 55 | 把所有`.cuda()`代码去掉,就能使得程序在CPU上运行 56 | 57 | 58 | ## 各个文件说明 59 | 欢迎参考之前在[知乎专栏](https://zhuanlan.zhihu.com/p/29024978)写过的一篇关于PyTorch的文件组织安排的文章了解每个文件的作用 60 | 61 | `models/`: 存放各个模型定义,所有的模型继承自`basic_module.py`中的`BasicModule`. 62 | 63 | `models/resnet.py`: 对torchvision中的resnet18, resnet34, resnet50, resnet101 和resnet152 进行了简单的封装。 64 | 65 | `dataset.py`: 数据加载相关 66 | 67 | `main.py`: 主程序,包含训练和测试 68 | 69 | `utils.py`: 可视化工具visdom的封装,计算top3准确率函数,可配置变量(可通过命令行参数修改,也可以通过修改文件配置)等。 70 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import torch as t 3 | from torchvision import transforms 4 | from torch.utils import data 5 | import os 6 | import PIL 7 | import random 8 | 9 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 10 | std=[0.229, 0.224, 0.225]) 11 | 12 | def load(path): 13 | return PIL.Image.open(path).convert('RGB') 14 | 15 | class ClsDataset(data.Dataset): 16 | def __init__(self,opt): 17 | self.opt = opt 18 | self.datas = t.load(opt.meta_path) 19 | 20 | self.val_transforms = transforms.Compose([ 21 | transforms.Scale(opt.img_size), 22 | transforms.CenterCrop(opt.img_size), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | self.train_transforms = transforms.Compose([ 27 | transforms.RandomSizedCrop(opt.img_size), 28 | transforms.RandomHorizontalFlip(), 29 | transforms.ToTensor(), 30 | normalize, 31 | ]) 32 | self.train() 33 | 34 | def __getitem__(self,index): 35 | img_path = os.path.join(self.path,self.imgs[index]) 36 | img = load(img_path) 37 | img = self.transforms(img) 38 | return img,self.labels[index],self.imgs[index] 39 | 40 | def train(self): 41 | data = self.datas['train'] 42 | self.imgs,self.labels = data['ids'],data['labels'] 43 | self.path = self.opt.train_dir 44 | self.transforms = self.train_transforms 45 | return self 46 | 47 | def test(self): 48 | data= self.datas['test1'] 49 | self.imgs,self.labels = data['ids'],data['labels'] 50 | self.path = self.opt.test_dir 51 | self.transforms=self.val_transforms 52 | return self 53 | 54 | def val(self): 55 | data = self.datas['val'] 56 | self.imgs,self.labels = data['ids'],data['labels'] 57 | self.path = self.opt.val_dir 58 | self.transforms=self.val_transforms 59 | return self 60 | 61 | def __len__(self): 62 | return len(self.imgs) 63 | 64 | if __name__=='__main__': 65 | test 66 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import os,json,time,fire,ipdb,tqdm 3 | 4 | import numpy as np 5 | import torch as t 6 | from torch import nn, optim 7 | from torch.autograd import Variable 8 | from torchnet import meter 9 | 10 | from utils import Visualizer,topk_acc,opt 11 | from dataset import ClsDataset 12 | import models 13 | 14 | vis = Visualizer(env=opt.env) 15 | 16 | def submit(**kwargs): 17 | ''' 18 | 测试验证集,并生成可提交的json文件 19 | ''' 20 | opt.parse(kwargs) 21 | 22 | # 模型 23 | model = getattr(models,opt.model)(opt) 24 | model.load(opt.load_path) 25 | model.eval().cuda() 26 | 27 | # 数据 28 | dataset = ClsDataset(opt) 29 | dataset.test() 30 | dataloader = t.utils.data.DataLoader(dataset,opt.batch_size, shuffle=False, num_workers=opt.workers,pin_memory=True) 31 | 32 | # 测试 33 | results = [] 34 | for ii, data in tqdm.tqdm(enumerate(dataloader)): 35 | input, label,image_ids = data 36 | val_input = Variable(input, volatile=True).cuda() 37 | val_label = Variable(label.type(t.LongTensor), volatile=True).cuda() 38 | score = model(val_input) 39 | predict = score.data.topk(k=3)[1].tolist() 40 | result = [ {"image_id": image_id, 41 | "label_id": label_id } 42 | for image_id,label_id in zip(image_ids,predict) ] 43 | results+=result 44 | 45 | # 保存文件 46 | with open(opt.result_path,'w') as f: 47 | json.dump(results,f) 48 | 49 | 50 | def val(model,dataset): 51 | ''' 52 | 计算模型在验证集上的准确率 53 | 返回top1和top3的准确率 54 | ''' 55 | model.eval() 56 | dataset.val() 57 | acc_meter = meter.AverageValueMeter() 58 | top1_meter = meter.AverageValueMeter() 59 | dataloader = t.utils.data.DataLoader(dataset,opt.batch_size, opt.shuffle, num_workers=opt.workers,pin_memory=True) 60 | for ii, data in tqdm.tqdm(enumerate(dataloader)): 61 | input, label,_ = data 62 | val_input = Variable(input, volatile=True).cuda() 63 | val_label = Variable(label.type(t.LongTensor), volatile=True).cuda() 64 | score = model(val_input) 65 | acc = topk_acc(score.data,label.cuda() ) 66 | top1 = topk_acc(score.data,label.cuda(),k=1) 67 | acc_meter.add(acc) 68 | top1_meter.add(top1) 69 | model.train() 70 | dataset.train() 71 | print(acc_meter.value()[0],top1_meter.value()[0]) 72 | return acc_meter.value()[0], top1_meter.value()[0] 73 | 74 | 75 | def train(**kwargs): 76 | ''' 77 | 训练模型 78 | ''' 79 | opt.parse(kwargs) 80 | 81 | lr1, lr2 = opt.lr1, opt.lr2 82 | vis.vis.env = opt.env 83 | 84 | # 模型 85 | model = getattr(models,opt.model)(opt) 86 | if opt.load_path: 87 | model.load(opt.load_path) 88 | print(model) 89 | model.cuda() 90 | optimizer = model.get_optimizer(lr1,lr2) 91 | criterion = getattr(models,opt.loss)() 92 | 93 | # 指标:求均值 94 | loss_meter = meter.AverageValueMeter() 95 | acc_meter = meter.AverageValueMeter() 96 | top1_meter = meter.AverageValueMeter() 97 | 98 | step = 0 99 | max_acc = 0 100 | vis.vis.texts = '' 101 | 102 | # 数据 103 | dataset = ClsDataset(opt) 104 | dataloader = t.utils.data.DataLoader(dataset, opt.batch_size, opt.shuffle, num_workers=opt.workers,pin_memory=True) 105 | 106 | # 训练 107 | for epoch in range(opt.max_epoch): 108 | loss_meter.reset() 109 | acc_meter.reset() 110 | top1_meter.reset() 111 | 112 | for ii, data in tqdm.tqdm(enumerate(dataloader, 0)): 113 | # 训练 114 | optimizer.zero_grad() 115 | input, label,_ = data 116 | input = Variable(input.cuda()) 117 | label = Variable(label.cuda()) 118 | output = model(input).view(input.size(0),-1) 119 | error = criterion(output, label) 120 | error.backward() 121 | optimizer.step() 122 | 123 | # 计算损失的均值和训练集的准确率均值 124 | loss_meter.add(error.data[0]) 125 | acc = topk_acc(output.data,label.data) 126 | acc_meter.add(acc) 127 | top1_acc = topk_acc(output.data,label.data,k=1) 128 | top1_meter.add(top1_acc) 129 | 130 | # 可视化 131 | if (ii+1) % opt.plot_every == 0: 132 | if os.path.exists(opt.debug_file): 133 | ipdb.set_trace() 134 | 135 | log_values = dict(loss = loss_meter.value()[0], 136 | train_acc = acc_meter.value()[0], 137 | epoch = epoch, 138 | ii = ii, 139 | train_top1_acc= top1_meter.value()[0] 140 | ) 141 | vis.plot_many(log_values) 142 | 143 | # 数据跑一遍之后,计算在验证集上的分数 144 | accuracy,top1_accuracy = val(model,dataset) 145 | vis.plot('val_acc', accuracy) 146 | vis.plot('val_top1',top1_accuracy) 147 | info = time.strftime('[%m%d_%H%M%S]') + 'epoch:{epoch},val_acc:{val_acc},lr:{lr},val_top1:{val_top1},train_acc:{train_acc}
'.format( 148 | epoch=epoch, 149 | lr=lr1, 150 | val_acc=accuracy, 151 | val_top1=top1_accuracy, 152 | train_acc=acc_meter.value() 153 | ) 154 | vis.vis.texts += info 155 | vis.vis.text(vis.vis.texts, win=u'log') 156 | 157 | # 调整学习率 158 | # 如果验证集上准确率降低了,就降低学习率,并加载之前的最好模型 159 | # 否则保存模型,并记下模型保存路径 160 | if accuracy > max_acc: 161 | max_acc = accuracy 162 | best_path = model.save(accuracy) 163 | else: 164 | if lr1==0: lr1=lr2 165 | model.load(best_path) 166 | lr1, lr2 = lr1 *opt.lr_decay, lr2 * opt.lr_decay 167 | optimizer = model.get_optimizer(lr1,lr2) 168 | 169 | vis.vis.save([opt.env]) 170 | 171 | if __name__ == '__main__': 172 | fire.Fire() 173 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from resnet import resnet18,resnet34,resnet50,resnet101,resnet152,resnet365 2 | from loss import * 3 | -------------------------------------------------------------------------------- /models/basic_module.py: -------------------------------------------------------------------------------- 1 | 2 | #coding:utf8 3 | import torch as t 4 | import time 5 | class BasicModule(t.nn.Module): 6 | ''' 7 | 封装了nn.Module 8 | ''' 9 | 10 | def __init__(self,opt=None): 11 | super(BasicModule,self).__init__() 12 | self.model_name=str(type(self).__name__)# 默认名字 13 | self.opt = opt 14 | 15 | def load(self, path,map_location=lambda storage, loc: storage): 16 | checkpoint = t.load(path,map_location=map_location) 17 | if 'opt' in checkpoint: 18 | self.load_state_dict(checkpoint['d']) 19 | print('old config:') 20 | print(checkpoint['opt']) 21 | else: 22 | self.load_state_dict(checkpoint) 23 | # for k,v in checkpoint['opt'].items(): 24 | # setattr(self.opt,k,v) 25 | 26 | def save(self, name=''): 27 | format = 'checkpoints/'+self.model_name+'_%m%d_%H%M_' 28 | file_name = time.strftime(format) + str(name) 29 | 30 | state_dict = self.state_dict() 31 | opt_state_dict = dict(self.opt.state_dict()) 32 | optimizer_state_dict = self.optimizer.state_dict() 33 | 34 | t.save({'d':state_dict,'opt':opt_state_dict,'optimizer':optimizer_state_dict}, file_name) 35 | return file_name 36 | 37 | def get_optimizer(self,lr1,lr2): 38 | self.optimizer = t.optim.Adam( 39 | [ 40 | {'params': self.features.parameters(), 'lr': lr1}, 41 | {'params': self.classifier.parameters(), 'lr':lr2} 42 | ] ) 43 | return self.optimizer 44 | 45 | def update_optimizer(self,lr1,lr2): 46 | param_groups = self.optimizer.param_groups 47 | param_groups[0]['lr']=lr1 48 | param_groups[1]['lr']=lr2 49 | return self.optimizer 50 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import torch as t 3 | def celoss(): 4 | return t.nn.CrossEntropyLoss() 5 | 6 | def bloss(): 7 | def loss(s,l): 8 | pass 9 | return loss -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import torchvision as tv 3 | from torch import nn 4 | from basic_module import BasicModule 5 | import torch as t 6 | 7 | class ResNet(BasicModule): 8 | def __init__(self,model,opt=None,feature_dim=2048,name='resnet'): 9 | super(ResNet, self).__init__(opt) 10 | self.model_name=name 11 | 12 | model.avgpool = nn.AdaptiveAvgPool2d(1) 13 | del model.fc 14 | model.fc = lambda x:x 15 | self.features = model 16 | self.classifier = nn.Linear(feature_dim,80) 17 | 18 | def forward(self,x): 19 | features = self.features(x) 20 | return self.classifier(features) 21 | 22 | def resnet18(opt): 23 | model = tv.models.resnet18(pretrained=not opt.load_path) 24 | return ResNet(model,opt,feature_dim=512,name='res18') 25 | 26 | def resnet34(opt): 27 | model = tv.models.resnet34(pretrained=not opt.load_path) 28 | return ResNet(model,opt,feature_dim=512,name='res34') 29 | 30 | def resnet50(opt): 31 | model = tv.models.resnet50(pretrained=not opt.load_path) 32 | return ResNet(model,opt,name='res50') 33 | 34 | def resnet101(opt): 35 | model = tv.models.resnet101(pretrained=not opt.load_path) 36 | return ResNet(model,opt,name='res101') 37 | 38 | def resnet152(opt): 39 | model = tv.models.resnet152(pretrained=not opt.load_path) 40 | return ResNet(model,opt,name='res152') 41 | 42 | def resnet365(opt): 43 | model = t.load('checkpoints/whole_resnet50_places365.pth.tar') 44 | # model = tv.models.resnet50() 45 | return ResNet(model,opt,name='res_365') 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | visdom 2 | fire 3 | ipdb 4 | torchvision 5 | git+https://github.com/pytorch/tnt.git@master -------------------------------------------------------------------------------- /scene.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyuntc/scene-baseline/8d105747c81d87fa5e9f84b4999b3ea7abd389da/scene.pth -------------------------------------------------------------------------------- /scripts/data_process.py: -------------------------------------------------------------------------------- 1 | # 数据预处理脚本 2 | 3 | 4 | 5 | train_ann_file = '/data/image/ai_cha/scene/ai_challenger_scene_train_20170904/scene_train_annotations_20170904.json' 6 | val_ann_file = '/data/image/ai_cha/scene/ai_challenger_scene_validation_20170908/scene_validation_annotations_20170908.json' 7 | test_dir = '/data/image/ai_cha/scene/ai_challenger_scene_test_a_20170922/scene_test_a_images_20170922/' 8 | clas_file = '/data/image/ai_cha/scene/ai_challenger_scene_test_a_20170922/scene_classes.csv' 9 | 10 | import pandas as pd 11 | from collections import namedtuple 12 | # ClassName = namedtuple('ClassName',['zh','en']) 13 | # SceneDataAll = namedtuple('SceneDataAll',['train','val','test1','label']) 14 | # SceneData = namedtuple('SceneData',['ids','labels','id2ix']) 15 | 16 | 17 | a=pd.read_csv(clas_file,header=None) 18 | label_ids,label_zh,label_en = a[0],a[1],a[2] 19 | 20 | #79: ClassName(zh=u'\u96ea\u5c4b/\u51b0\u96d5(\u5c71)', en='igloo/ice_engraving')} 21 | id2label = {k:(v1.decode('utf8'),v2) for k,v1,v2 in zip(label_ids,label_zh,label_en)} 22 | # id2label = {k:ClassName(v1.decode('utf8'),v2) for k,v1,v2 in zip(label_ids,label_zh,label_en)} 23 | 24 | import json 25 | with open(train_ann_file) as f: 26 | datas = json.load(f) 27 | 28 | ids = [ii['image_id'] for ii in datas] 29 | labels = [int(ii['label_id']) for ii in datas] 30 | id2ix = {id:ix for ix,id in enumerate(ids)} 31 | 32 | 33 | #train = SceneData(ids,labels,id2ix) 34 | train = dict(ids=ids,labels = labels,id2ix=id2ix) 35 | 36 | with open(val_ann_file) as f: 37 | datas = json.load(f) 38 | 39 | ids = [ii['image_id'] for ii in datas] 40 | labels = [int(ii['label_id']) for ii in datas] 41 | id2ix = {id:ix for ix,id in enumerate(ids)} 42 | val = dict(ids=ids,labels = labels,id2ix=id2ix) 43 | # val = SceneData(ids,labels,id2ix) 44 | import os 45 | ids = os.listdir(test_dir) 46 | id2ix = {id:ix for ix,id in enumerate(ids)} 47 | # test = SceneData(ids,None,id2ix) 48 | test = dict(ids=ids,labels = labels,id2ix=id2ix) 49 | # all = SceneDataAll(train,val,test,id2label) 50 | 51 | all = dict(train=train,test1=test,val=val,id2label=id2label) 52 | import torch as t 53 | t.save(all,'scene.pth') -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import visdom 3 | import numpy as np 4 | 5 | class Visualizer(): 6 | ''' 7 | 对可视化工具visdom的封装 8 | ''' 9 | def __init__(self, env, **kwargs): 10 | import visdom 11 | self.vis = visdom.Visdom(env=env, **kwargs) 12 | self.index = {} 13 | 14 | def plot_many(self, d): 15 | for k, v in d.iteritems(): 16 | self.plot(k, v) 17 | 18 | def plot(self, name, y): 19 | x = self.index.get(name, 0) 20 | self.vis.line(Y=np.array([y]), X=np.array([x]), 21 | win=unicode(name), 22 | opts=dict(title=name), 23 | update=None if x == 0 else 'append' 24 | ) 25 | self.index[name] = x + 1 26 | 27 | def topk_acc(score,label,k=3): 28 | ''' 29 | topk accuracy,默认是top3准确率 30 | ''' 31 | topk = score.topk(k)[1] 32 | label = label.view(-1,1).expand_as(topk) 33 | acc = (label == topk).float().sum()/(0.0+label.size(0)) 34 | return acc 35 | 36 | class Config: 37 | train_dir = '/data/image/ai_cha/scene/sl/train/' 38 | test_dir = '/data/image/ai_cha/scene/sl/testa' 39 | val_dir = '/data/image/ai_cha/scene/sl/val' 40 | meta_path = '/data/image/ai_cha/scene/sl/scene.pth' 41 | img_size=256 42 | 43 | lr1 = 0 44 | lr2 = 0.0005 45 | lr_decay = 0.5 46 | batch_size = 128 47 | max_epoch = 100 48 | debug_file = '/tmp/debugc' 49 | shuffle = True 50 | env = 'scene' # visdom env 51 | plot_every = 10 # 每10步可视化一次 52 | 53 | workers = 4 # CPU多线程加载数据 54 | load_path=None# 55 | model = 'resnet50'# 具体名称查看 models/__init__.py 56 | loss='celoss' 57 | result_path='result.json' #提交文件保存路径 58 | 59 | def parse(self,kwargs,print_=True): 60 | ''' 61 | 根据字典kwargs 更新 config参数 62 | ''' 63 | for k,v in kwargs.iteritems(): 64 | if not hasattr(self,k): 65 | raise Exception("opt has not attribute <%s>" %k) 66 | setattr(self,k,v) 67 | if print_: 68 | print('user config:') 69 | print('#################################') 70 | for k in dir(self): 71 | if not k.startswith('_') and k!='parse' and k!='state_dict': 72 | print k,getattr(self,k) 73 | print('#################################') 74 | return self 75 | 76 | 77 | def state_dict(self): 78 | return {k:getattr(self,k) for k in dir(self) if not k.startswith('_') and k!='parse' and k!='state_dict' } 79 | 80 | Config.parse = parse 81 | Config.state_dict = state_dict 82 | opt = Config() 83 | --------------------------------------------------------------------------------