├── docs ├── Table_V.jpg ├── akmnet.png ├── module.jpg ├── Table_III.jpg ├── Table_IV.jpg ├── Table_VII.jpg └── akmnetoverview.png ├── TEST.py ├── util.py ├── datasets ├── dataset.py └── transforms.py ├── README.md ├── main.py └── models └── resnet.py /docs/Table_V.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trunpm/AKMNet-Micro-Expression/HEAD/docs/Table_V.jpg -------------------------------------------------------------------------------- /docs/akmnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trunpm/AKMNet-Micro-Expression/HEAD/docs/akmnet.png -------------------------------------------------------------------------------- /docs/module.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trunpm/AKMNet-Micro-Expression/HEAD/docs/module.jpg -------------------------------------------------------------------------------- /docs/Table_III.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trunpm/AKMNet-Micro-Expression/HEAD/docs/Table_III.jpg -------------------------------------------------------------------------------- /docs/Table_IV.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trunpm/AKMNet-Micro-Expression/HEAD/docs/Table_IV.jpg -------------------------------------------------------------------------------- /docs/Table_VII.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trunpm/AKMNet-Micro-Expression/HEAD/docs/Table_VII.jpg -------------------------------------------------------------------------------- /docs/akmnetoverview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trunpm/AKMNet-Micro-Expression/HEAD/docs/akmnetoverview.png -------------------------------------------------------------------------------- /TEST.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torchvision 7 | from torch.autograd import Variable 8 | from collections import OrderedDict 9 | 10 | 11 | ###Data require 12 | import argparse 13 | from datasets.dataset import VolumeDataset 14 | from datasets.transforms import * 15 | from torch.utils.data import DataLoader 16 | 17 | # ###Model require 18 | from models import resnet 19 | 20 | 21 | 22 | parser = argparse.ArgumentParser('Resnets') 23 | parser.add_argument('--seed', type=int, default=1) 24 | 25 | # ========================= Data Configs ========================== 26 | parser.add_argument('--data_root_train', type=str, default='') 27 | parser.add_argument('--list_file_train', type=str, default='./Train.txt') 28 | parser.add_argument('--data_root_test', type=str, default='') 29 | parser.add_argument('--list_file_test', type=str, default='./Test.txt') 30 | parser.add_argument('--modality', type=str, default='Gray', help='RGB | Gray') 31 | parser.add_argument('--batch_size', type=int, default=1) 32 | parser.add_argument('--num_workers', type=int, default=16) 33 | 34 | # ========================= Model Configs ========================== 35 | parser.add_argument('--num_classes', 
default=4, type=int, help='Number of classes')
36 | parser.add_argument('--no_cuda', action='store_true', help='If true, cuda is not used.')
37 | parser.set_defaults(no_cuda=False)
38 | 
39 | parser.add_argument('--device_ids', type=int, default=1)
40 | 
41 | # ========================= Model Save ==========================
42 | parser.add_argument('--checkpoint_path', type=str, default='')
43 | 
44 | args = parser.parse_args()
45 | 
46 | 
47 | 
48 | ###main
49 | args.output_file = './TEST.txt'
50 | ###load model
51 | model = resnet.resnet18(pretrained=False, num_classes=args.num_classes)
52 | if args.checkpoint_path == '':
53 |     args.checkpoint_path = './pt/epoch10.pt'
54 | model.load_state_dict(torch.load(args.checkpoint_path))
55 | if not args.no_cuda:
56 |     model = model.cuda(args.device_ids)
57 | 
58 | 
59 | test_dataset = VolumeDataset(data_root=args.data_root_test, list_file_root=args.list_file_test, modality=args.modality,
60 |                              transform=torchvision.transforms.Compose([
61 |                                  GroupScale((128,128)),
62 |                                  ToTorchFormatTensor(div=True),
63 |                              ]),
64 |                              )
65 | test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False,num_workers=args.num_workers,drop_last=False)
66 | 
67 | model.eval()
68 | count_correct = 0.
69 | with torch.no_grad():
70 |     for i_batch, sample_batch in enumerate(test_loader):
71 |         Volume = Variable(sample_batch['Volume']).cuda(args.device_ids)
72 |         labels = Variable(sample_batch['label']).long().cuda(args.device_ids)
73 | 
74 |         Bw,B,outputs = model(Volume)
75 | 
76 |         _,pred = torch.max(outputs, 1)
77 |         count_correct += torch.sum(pred == labels)
78 | 
79 |         with open(args.output_file, 'a') as out_file:
80 |             out_file.write('labels is:{0} pred is:{1}\n'.format(labels.data[0].cpu().numpy(),pred.data[0].cpu().numpy()))
81 |             out_file.write('B is:{0}\n'.format(B.data[0].cpu().numpy().tolist()))
82 |             out_file.write('Bw is:{0}\n'.format(Bw.data[0].cpu().numpy().tolist()))
83 | print("Total acc is:", float(count_correct) / len(test_loader.dataset))
84 | 
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from collections import OrderedDict
5 | import time
6 | import math
7 | from torch.autograd import Variable
8 | import numpy as np
9 | 
10 | 
11 | 
12 | def tr_epoch(model, data_loader, criterion1, criterionB, criterionBw, optimizer, args):
13 |     # training-----------------------------
14 |     model.train()
15 |     loss_value1 = 0.
16 |     loss_valueBw = 0.
17 |     loss_valueB = 0.
18 | for i_batch, sample_batch in enumerate(data_loader): 19 | Volume = Variable(sample_batch['Volume']).cuda(args.device_ids) 20 | labels = Variable(sample_batch['label']).long().cuda(args.device_ids) 21 | 22 | Bw,B,outputs = model(Volume) 23 | loss1 = criterion1(outputs, labels) 24 | lossBw = criterionBw(Bw) 25 | lossB = criterionB(B) 26 | loss = loss1+lossBw+lossB 27 | loss_value1 += loss1 28 | loss_valueBw += lossBw 29 | loss_valueB += lossB 30 | 31 | if (i_batch+1)>int(np.floor(len(data_loader.dataset)/8))*8: 32 | loss = loss/(len(data_loader.dataset)-int(np.floor(len(data_loader.dataset)/8))*8) 33 | loss.backward() 34 | if (i_batch+1)==len(data_loader.dataset): 35 | optimizer.step() 36 | optimizer.zero_grad() 37 | else: 38 | loss = loss/8 39 | loss.backward() 40 | if (i_batch+1)%8==0: 41 | optimizer.step() 42 | optimizer.zero_grad() 43 | 44 | print('epoch Loss1: {:.6f}'.format(float(loss_value1.data)/(i_batch+1)), 'epoch LossBw: {:.6f}'.format(float(loss_valueBw.data)/(i_batch+1)), 'epoch LossB: {:.6f}'.format(float(loss_valueB.data)/(i_batch+1))) 45 | with open('./logtrain.txt', 'a') as out_file: 46 | out_file.write('epoch Loss1:{0},epoch LossBw:{1},epoch LossB:{2}'.format(float(loss_value1.data)/(i_batch+1),float(loss_valueBw.data)/(i_batch+1),float(loss_valueB.data)/(i_batch+1))+'\n') 47 | 48 | def ts_epoch(model, data_loader, criterion1, criterionB, criterionBw, args): 49 | model.eval() 50 | count_correct = 0. 51 | loss_value1 = 0. 52 | loss_valueBw = 0. 53 | loss_valueB = 0. 54 | with torch.no_grad(): 55 | for i_batch, sample_batch in enumerate(data_loader): 56 | Volume = Variable(sample_batch['Volume']).cuda(args.device_ids) 57 | labels = Variable(sample_batch['label']).long().cuda(args.device_ids) 58 | 59 | Bw,B,outputs = model(Volume) 60 | 61 | loss1 = criterion1(outputs, labels) 62 | lossBw = criterionBw(Bw) 63 | lossB = criterionB(B) 64 | loss_value1 += loss1 65 | loss_valueBw += lossBw 66 | loss_valueB += lossB 67 | 68 | _,pred = torch.max(outputs, 1) 69 | count_correct += torch.sum(pred == labels) 70 | 71 | print('Test Loss1: {:.6f}'.format(float(loss_value1.data)/(i_batch+1)),'epoch LossBw: {:.6f}'.format(float(loss_valueBw.data)/(i_batch+1)), 'epoch LossB: {:.6f}'.format(float(loss_valueB.data)/(i_batch+1))) 72 | print('Acc is:', float(count_correct) / len(data_loader.dataset)) 73 | with open('./logtest.txt', 'a') as out_file: 74 | out_file.write('Test Loss1:{0},epoch LossBw:{1},epoch LossB:{2},acc is:{3}'.format(float(loss_value1.data)/(i_batch+1),float(loss_valueBw.data)/(i_batch+1),float(loss_valueB.data)/(i_batch+1),float(count_correct) / len(data_loader.dataset))+'\n') 75 | 76 | return float(count_correct) / len(data_loader.dataset) 77 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from PIL import ImageFile 5 | ImageFile.LOAD_TRUNCATED_IMAGES = True 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | 13 | 14 | class SampleProperty(object): 15 | def __init__(self, row): 16 | self._data = row 17 | 18 | @property 19 | def path(self): 20 | return self._data[0] 21 | 22 | @property 23 | def length(self): 24 | return self._data[1] 25 | 26 | @property 27 | def label(self): 28 | return int(self._data[2]) 29 | 30 | def Getimagesname(samplepath): 31 | imagesname=[] 32 | numstr_num={} 33 | if 
samplepath.find('/CASME/')!=-1: 34 | for image in os.listdir(samplepath): 35 | l = image.find('-') 36 | e = image.find('.') 37 | head = image[0:l+1] 38 | numstr_num[image[l+1:e]] = int(image[l+1:e]) 39 | end = '.jpg' 40 | numstr_num = sorted(numstr_num.items(),key=lambda x:x[1]) 41 | for t in numstr_num: 42 | imagesname.append(head+t[0]+end) 43 | return imagesname 44 | if samplepath.find('/CASMEII/')!=-1: 45 | for image in os.listdir(samplepath): 46 | l = image.find('img') 47 | e = image.find('.') 48 | head = image[0:l+3] 49 | numstr_num[image[l+3:e]] = int(image[l+3:e]) 50 | end = '.jpg' 51 | numstr_num = sorted(numstr_num.items(),key=lambda x:x[1]) 52 | for t in numstr_num: 53 | imagesname.append(head+t[0]+end) 54 | return imagesname 55 | if samplepath.find('/SAMM/')!=-1: 56 | for image in os.listdir(samplepath): 57 | l = image.find('_') 58 | e = image.find('.') 59 | head = image[0:l+1] 60 | numstr_num[image[l+1:e]] = int(image[l+1:e]) 61 | end = '.jpg' 62 | numstr_num = sorted(numstr_num.items(),key=lambda x:x[1]) 63 | for t in numstr_num: 64 | imagesname.append(head+t[0]+end) 65 | return imagesname 66 | if samplepath.find('/SMIC/')!=-1: 67 | for image in os.listdir(samplepath): 68 | l = image.find('image') 69 | e = image.find('.') 70 | head = 'image' 71 | numstr_num[image[l+5:e]] = int(image[l+5:e]) 72 | end = '.bmp' 73 | numstr_num = sorted(numstr_num.items(),key=lambda x:x[1]) 74 | for t in numstr_num: 75 | imagesname.append(head+t[0]+end) 76 | return imagesname 77 | 78 | 79 | 80 | class VolumeDataset(data.Dataset): 81 | def __init__(self, data_root, list_file_root, modality='Gray', transform=None): 82 | self.data_root = data_root 83 | self.list_file_root = list_file_root 84 | self.modality = modality 85 | self.transform = transform 86 | self._images_load() 87 | def _images_load(self): 88 | self.Sample_List = [SampleProperty(x.strip().split(' ')) for x in open(self.list_file_root)] 89 | 90 | def __getitem__(self, idx): 91 | sample = self.Sample_List[idx] 92 | 93 | Volume_temp = list() 94 | imagesname = Getimagesname(sample.path) 95 | for i in imagesname: 96 | if self.modality == 'RGB': 97 | image = Image.open(os.path.join(self.data_root, sample.path, i)).convert('RGB') 98 | if self.modality == 'Gray': 99 | image = Image.open(os.path.join(self.data_root, sample.path, i)).convert('L') 100 | Volume_temp.append(image) 101 | 102 | if self.transform is not None: 103 | Volume = self.transform(Volume_temp) ###C L H W 104 | 105 | SampleVolum = {'Volume': Volume, 'label': sample.label} 106 | return SampleVolum 107 | 108 | def __len__(self): 109 | return len(self.Sample_List) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AKMNet for Micro-Expression 2 | Recognizing Micro-Expression in Video Clip with Adaptive Key-Frame Mining 3 | 4 | ![alt text](docs/akmnetoverview.png 'overview of the network') 5 | 6 | # Purpose 7 | The existing representation based on various deep learning techniques learned from a full video clip is usually redundant. In addition, methods utilizing the single apex frame of each video clip require expert annotations and sacrifice the temporal dynamics. In our paper, we propose a novel end-to-end deep learning architecture, referred to as adaptive key-frame mining network (AKMNet). 
Operating on the video clip of a micro-expression, AKMNet learns a discriminative spatio-temporal representation by combining the spatial features of self-learned local key frames with their global temporal dynamics.
8 | 
9 | # Citation
10 | Peng, Min, Chongyang Wang, Yuan Gao, Tao Bi, Tong Chen, Yu Shi, and Xiang-Dong Zhou. "[Recognizing Micro-expression in Video Clip with Adaptive Key-frame Mining](https://arxiv.org/abs/2009.09179)", arXiv preprint arXiv:2009.09179 (2020).
11 | 
12 | # Platform and dependencies
13 | Ubuntu 14.04, Python 3.7, CUDA 8.0, cuDNN 6.0+
14 | pytorch==1.0.0
15 | 
16 | # Data Preparation (optional)
17 | * Download the datasets
18 | CASMEI: http://fu.psych.ac.cn/CASME/casme-en.php
19 | CASMEII: http://fu.psych.ac.cn/CASME/casme2-en.php
20 | SAMM: http://www2.docm.mmu.ac.uk/STAFF/m.yap/dataset.php
21 | SMIC: https://www.oulu.fi/cmvs/node/41319
22 | * Preprocessing
23 | 1. You can also use the data in the *cropped* folder of each dataset to conduct the experiments. For the SAMM dataset, the face detection and alignment method is the same as in the paper *CASME II: An Improved Spontaneous Micro-Expression Database and the Baseline Evaluation*.
24 | 2. For the phase-based video magnification method, please refer to http://people.csail.mit.edu/nwadhwa/phase-video/
25 | In our method, frame (sequence-length) normalization is not needed, because the design of all modules in AKMNet is independent of the length of the input video clip.
26 | 
27 | # Method
28 | ![image](https://github.com/Trunpm/AKMNet-Micro-Expression/blob/main/docs/module.jpg)
29 | 
30 | # Experiment
31 | * Comparison Experiment
32 | 
33 | | *Methods* |*CASMEI*|*CASMEII*|*SMIC*|*SAMM*|
34 | |:-----------------:|:--------:|:----------:|:----------:|:----------:|
35 | | `LBP-TOP` | 0.6618 | 0.3843 | 0.3598 | 0.3899 |
36 | | `LBP-SIP` | 0.6026 | 0.4784 | 0.3842 | 0.5220 |
37 | | `STCLQP` | 0.6349 | 0.5922 | 0.5366 | 0.5283 |
38 | | `HIGO` | 0.5781 | 0.5137 | 0.3720 | 0.4465 |
39 | | `FHOFO` | 0.6720 | 0.6471 | 0.5366 | 0.6038 |
40 | | `MDMO` | 0.6825 | 0.6314 | 0.5793 | 0.6164 |
41 | | `Macro2Micro` | 0.6772 | 0.6078 | - | 0.6436 |
42 | | `MicroAttention` | 0.6825 | 0.6431 | - | 0.6489 |
43 | | `ATNet` | 0.6720 | 0.6039 | - | 0.6543 |
44 | | `STSTNet` | 0.6349 | 0.5529 | 0.5488 | 0.6289 |
45 | | `STRCN-G` | 0.7090 | 0.6039 | 0.6280 | 0.6478 |
46 | | **AKMNet** |**0.7566** |**0.6706** |**0.7256** |**0.7170**|
47 | 
48 | * Justification of the Adaptive Key-Frame Mining Module
49 | 
50 | | *Methods* |*CASMEI*|*CASMEII*|*SMIC*|*SAMM*|
51 | |:-----------------:|:--------:|:----------:|:----------:|:----------:|
52 | | `AKMNetva-all` | 0.6618 | 0.3843 | 0.3598 | 0.3899 |
53 | | `AKMNetva-random` |0.6138 |0.6118 |0.5427 |0.6289 |
54 | | `AKMNetva-norm16` |0.6667 |0.6314 |0.5976 |0.6604 |
55 | | `AKMNetva-norm32` |0.6825 |0.6392 |0.6434 |0.6478 |
56 | | `AKMNetva-norm64` |0.7090 |0.6392 |0.6463 |0.6164 |
57 | | `AKMNetva-norm128` |0.6984 |0.6431 |0.6646 |0.6792 |
58 | | **AKMNet** |**0.7566** |**0.6706** |**0.7256** |**0.7170**|
59 | 
60 | * Ablation Experiment
61 | 
62 | | *Methods* |*CASMEI*|*CASMEII*|*SMIC*|*SAMM*|
63 | |:-----------------:|:--------:|:----------:|:----------:|:----------:|
64 | | `AKMNet-s12` |0.6984 |0.6392 |0.6463 |0.6667 |
65 | | `AKMNet-s13` |0.7354 |0.6431 |0.6463 |0.6604 |
66 | | `AKMNet-s23` |0.7249 |0.6549 |0.6707 |0.6918 |
67 | | `AKMNet-s123` |0.7566 |0.6706 |0.7256 |0.7170 |
68 | 
69 | * Annotated Apex Frame VS ‘Most Informative’ Frame (SMIC is excluded because it provides no apex-frame annotation)
70 | 
71 | | *Methods* |*Frame*|*CASMEI*|*CASMEII*|*SAMM*|
72 | |:-----------------:|:--------:|:----------:|:----------:|:----------:|
73 | | `Resnet18` |Apex frame |0.6772 |0.6078 |0.6436 |
74 | | `Resnet18` |Max-key frame |0.6825| 0.6392| 0.6486 |
75 | | `VGG-11` |Apex frame |0.6667 |0.6235| 0.6277 |
76 | | `VGG-11` |Max-key frame |0.6931 |0.6353 |0.6649 |
77 | 
78 | * How to use
79 | For each LOSO (leave-one-subject-out) experiment:
80 | First, set *list_file_train* and *list_file_test* in `main.py` properly. Each of them is a plain-text list file whose lines look like this:
81 | */home/XXX/fold/sub01/EP01_12__alpha15 19 3*
82 | *...*
83 | where */home/XXX/fold/sub01/EP01_12__alpha15* is a folder containing the image sequence of one micro-expression sample, 19 is the length of the clip, and 3 is its label. A minimal sketch for generating such list files is given below.
84 | Second, set *premodel* in `main.py` if you have a pretrained model.
85 | Third, run `python main.py` in your terminal.
86 | 
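Below is a minimal sketch (not part of the released code) for writing `Train.txt`/`Test.txt` for one LOSO split. It assumes the data are laid out as *<data_root>/<subject>/<clip folder>* with one image per frame, and that you supply your own `label_of(subject, clip)` lookup built from the dataset's annotation sheet; the function name `make_loso_lists` is only illustrative.

```python
import os

def make_loso_lists(data_root, test_subject, label_of,
                    train_file='Train.txt', test_file='Test.txt'):
    """Write LOSO list files with lines: '<clip folder> <num frames> <label>'.

    data_root    -- directory laid out as <data_root>/<subject>/<clip folder>
    test_subject -- subject id (e.g. 'sub01') held out for testing
    label_of     -- callable mapping (subject, clip name) -> integer label,
                    e.g. a lookup into the dataset's annotation sheet
    """
    with open(train_file, 'w') as ftr, open(test_file, 'w') as fts:
        for subject in sorted(os.listdir(data_root)):
            subject_dir = os.path.join(data_root, subject)
            if not os.path.isdir(subject_dir):
                continue
            for clip in sorted(os.listdir(subject_dir)):
                clip_dir = os.path.join(subject_dir, clip)
                if not os.path.isdir(clip_dir):
                    continue
                num_frames = len(os.listdir(clip_dir))  # assumes one image file per frame
                line = '{} {} {}\n'.format(clip_dir, num_frames, label_of(subject, clip))
                (fts if subject == test_subject else ftr).write(line)

# Example call (hypothetical paths and label table):
# make_loso_lists('/home/XXX/fold', 'sub01', lambda s, c: my_label_table[(s, c)])
```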
87 | * If you have questions, post them in GitHub issues.
88 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 | import os
4 | import numpy as np
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | from torch.autograd import Variable
13 | import torchvision
14 | import torchvision.models as models
15 | 
16 | ###Data require
17 | import argparse
18 | from datasets.dataset import VolumeDataset
19 | from datasets.transforms import *
20 | from torch.utils.data import DataLoader
21 | 
22 | # ###Model require
23 | from models import resnet
24 | from util import tr_epoch, ts_epoch
25 | 
26 | 
27 | 
28 | parser = argparse.ArgumentParser('Resnets')
29 | parser.add_argument('--seed', type=int, default=1)
30 | 
31 | # ========================= Data Configs ==========================
32 | parser.add_argument('--data_root_train', type=str, default='')
33 | parser.add_argument('--list_file_train', type=str, default='./Train.txt')
34 | parser.add_argument('--data_root_test', type=str, default='')
35 | parser.add_argument('--list_file_test', type=str, default='./Test.txt')
36 | parser.add_argument('--modality', type=str, default='Gray', help='RGB | Gray')
37 | parser.add_argument('--batch_size', type=int, default=1)
38 | parser.add_argument('--num_workers', type=int, default=16)
39 | 
40 | # ========================= Model Configs ==========================
41 | parser.add_argument('--premodel', default='XXX/epoch100.pt', type=str, help='Pretrained model (.pth)')
42 | parser.add_argument('--num_classes', default=4, type=int, help='Number of classes')
43 | parser.add_argument('--no_cuda', action='store_true', help='If true, cuda is not used.')
44 | parser.set_defaults(no_cuda=False)
45 | 
46 | parser.add_argument('--lr', type=float, default=1e-3)
47 | parser.add_argument('--PenaltyBw', type=float, default=1)
48 | parser.add_argument('--PenaltyB', type=float, default=0.1)
49 | parser.add_argument('--momentum', type=float, default=0.9)
50 | parser.add_argument('--weight_decay', type=float, default=0.0005)
51 | parser.add_argument('--epoch', type=int, default=40)
52 | parser.add_argument('--device_ids', type=int, default=1)
53 | 
54 | # ========================= Model Save ==========================
55 | parser.add_argument('--save_path', type=str, default='./pt')
56 | parser.add_argument('--checkpoint_path', type=str, default='')
57 | 
58 | args = parser.parse_args()
59 | 
60 | 
61 | 
62 | ###Data read
63 | train_dataset = 
VolumeDataset(data_root=args.data_root_train, list_file_root=args.list_file_train, modality=args.modality, 64 | transform=torchvision.transforms.Compose([ 65 | GroupScaleRandomCrop((144,144),(128,128)), 66 | ToTorchFormatTensor(div=True), 67 | ]), 68 | ) 69 | train_loader = DataLoader(train_dataset,batch_size=args.batch_size,shuffle=True,num_workers=args.num_workers,drop_last=False) 70 | 71 | test_dataset = VolumeDataset(data_root=args.data_root_test, list_file_root=args.list_file_test, modality=args.modality, 72 | transform=torchvision.transforms.Compose([ 73 | GroupScale((128,128)), 74 | ToTorchFormatTensor(div=True), 75 | ]), 76 | ) 77 | test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False,num_workers=args.num_workers,drop_last=False) 78 | 79 | 80 | # # ###Model 81 | model = resnet.resnet18(pretrained=False, num_classes=args.num_classes) 82 | if args.premodel: 83 | pretrained_dict = torch.load(args.premodel) 84 | model_dict = model.state_dict() 85 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys() and v.size() == model_dict[k].size()} 86 | missed_params = [k for k, v in model_dict.items() if not k in pretrained_dict.keys()] 87 | print('loaded params/tot params:{}/{}'.format(len(pretrained_dict),len(model_dict))) 88 | print('miss matched params:',missed_params) 89 | model_dict.update(pretrained_dict) 90 | model.load_state_dict(model_dict) 91 | if not args.no_cuda: 92 | model = model.cuda(args.device_ids) 93 | print(model) 94 | 95 | 96 | # ###Hyperparam 97 | criterion1 = nn.CrossEntropyLoss() 98 | if not args.no_cuda: 99 | criterion1 = criterion1.cuda(args.device_ids) 100 | optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 101 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epoch/5, eta_min=1e-8, last_epoch=-1) 102 | 103 | class BwLoss(nn.Module): 104 | def __init__(self): 105 | super(BwLoss, self).__init__() 106 | def forward(self, Bw): 107 | loss_Bw = 0.0 108 | for i in range(Bw.shape[0]): 109 | temp = Bw[i,:] 110 | loss_Bw += 2.0-(torch.mean(temp[temp>torch.mean(temp)])-torch.mean(temp[temp=10: 139 | if Acc>=Acc_best: 140 | Acc_best = Acc 141 | torch.save(model.state_dict(), args.save_path + '/' + 'epoch' + str(epoch) + '.pt') -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | 44 | 45 | class 
GroupRandomHorizontalFlip(object):
46 |     """Randomly horizontally flips the given PIL.Image with a probability of 0.5
47 |     """
48 |     def __init__(self, is_flow=False):
49 |         self.is_flow = is_flow
50 | 
51 |     def __call__(self, img_group, is_flow=False):
52 |         v = random.random()
53 |         if v < 0.5:
54 |             ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
55 |             if self.is_flow:
56 |                 for i in range(0, len(ret), 2):
57 |                     ret[i] = ImageOps.invert(ret[i])  # invert flow pixel values when flipping
58 |             return ret
59 |         else:
60 |             return img_group
61 | 
62 | 
63 | class GroupScale(object):
64 |     """ Rescales the input PIL.Image to the given 'size'.
65 |     'size' will be the size of the smaller edge.
66 |     For example, if height > width, then image will be
67 |     rescaled to (size * height / width, size)
68 |     size: size of the smaller edge
69 |     interpolation: Default: PIL.Image.BILINEAR
70 |     """
71 | 
72 |     def __init__(self, size, interpolation=Image.BILINEAR):
73 |         self.worker = torchvision.transforms.Resize(size, interpolation)
74 | 
75 |     def __call__(self, img_group):
76 |         return [self.worker(img) for img in img_group]
77 | 
78 | 
79 | class GroupScaleRandomCrop(object):
80 |     def __init__(self, size, size2, interpolation=Image.BILINEAR):
81 |         self.worker = torchvision.transforms.Resize(size, interpolation)
82 |         self.RandomCrop = GroupRandomCrop(size2)
83 |         self.worker2 = torchvision.transforms.Resize(size2, interpolation)
84 | 
85 |     def __call__(self, img_group):
86 |         if random.random() < 0.5:
87 |             return self.RandomCrop([self.worker(img) for img in img_group])
88 |         else:
89 |             return [self.worker2(img) for img in img_group]
90 | 
91 | 
92 | class GroupOverSample(object):
93 |     def __init__(self, crop_size, scale_size=None):
94 |         self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
95 | 
96 |         if scale_size is not None:
97 |             self.scale_worker = GroupScale(scale_size)
98 |         else:
99 |             self.scale_worker = None
100 | 
101 |     def __call__(self, img_group):
102 | 
103 |         if self.scale_worker is not None:
104 |             img_group = self.scale_worker(img_group)
105 | 
106 |         image_w, image_h = img_group[0].size
107 |         crop_w, crop_h = self.crop_size
108 | 
109 |         offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)
110 |         oversample_group = list()
111 |         for o_w, o_h in offsets:
112 |             normal_group = list()
113 |             flip_group = list()
114 |             for i, img in enumerate(img_group):
115 |                 crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
116 |                 normal_group.append(crop)
117 |                 flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
118 | 
119 |                 if img.mode == 'L' and i % 2 == 0:
120 |                     flip_group.append(ImageOps.invert(flip_crop))
121 |                 else:
122 |                     flip_group.append(flip_crop)
123 | 
124 |             oversample_group.extend(normal_group)
125 |             oversample_group.extend(flip_group)
126 |         return oversample_group
127 | 
128 | 
129 | class GroupMultiScaleCrop(object):
130 | 
131 |     def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
132 |         self.scales = scales if scales is not None else [1, .875, .75, .66]  # crop scales relative to the shorter edge
133 |         self.max_distort = max_distort
134 |         self.fix_crop = fix_crop
135 |         self.more_fix_crop = more_fix_crop
136 |         self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
137 |         self.interpolation = Image.BILINEAR
138 | 
139 |     def __call__(self, img_group):
140 | 
141 |         im_size = img_group[0].size
142 | 
143 |         crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
144 |         crop_img_group = [img.crop((offset_w, 
offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 145 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 146 | for img in crop_img_group] 147 | return ret_img_group 148 | 149 | def _sample_crop_size(self, im_size): 150 | image_w, image_h = im_size[0], im_size[1] 151 | 152 | # find a crop size 153 | base_size = min(image_w, image_h) 154 | crop_sizes = [int(base_size * x) for x in self.scales] 155 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 156 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 157 | 158 | pairs = [] 159 | for i, h in enumerate(crop_h): 160 | for j, w in enumerate(crop_w): 161 | if abs(i - j) <= self.max_distort: 162 | pairs.append((w, h)) 163 | 164 | crop_pair = random.choice(pairs) 165 | if not self.fix_crop: 166 | w_offset = random.randint(0, image_w - crop_pair[0]) 167 | h_offset = random.randint(0, image_h - crop_pair[1]) 168 | else: 169 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 170 | 171 | return crop_pair[0], crop_pair[1], w_offset, h_offset 172 | 173 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 174 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 175 | return random.choice(offsets) 176 | 177 | @staticmethod 178 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 179 | w_step = (image_w - crop_w) // 4 180 | h_step = (image_h - crop_h) // 4 181 | 182 | ret = list() 183 | ret.append((0, 0)) # upper left 184 | ret.append((4 * w_step, 0)) # upper right 185 | ret.append((0, 4 * h_step)) # lower left 186 | ret.append((4 * w_step, 4 * h_step)) # lower right 187 | ret.append((2 * w_step, 2 * h_step)) # center 188 | 189 | if more_fix_crop: 190 | ret.append((0, 2 * h_step)) # center left 191 | ret.append((4 * w_step, 2 * h_step)) # center right 192 | ret.append((2 * w_step, 4 * h_step)) # lower center 193 | ret.append((2 * w_step, 0 * h_step)) # upper center 194 | 195 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 196 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 197 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 198 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 199 | 200 | return ret 201 | 202 | 203 | class GroupRandomSizedCrop(object): 204 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 205 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 206 | This is popularly used to train the Inception networks 207 | size: size of the smaller edge 208 | interpolation: Default: PIL.Image.BILINEAR 209 | """ 210 | def __init__(self, size, interpolation=Image.BILINEAR): 211 | self.size = size 212 | self.interpolation = interpolation 213 | 214 | def __call__(self, img_group): 215 | for attempt in range(10): 216 | area = img_group[0].size[0] * img_group[0].size[1] 217 | target_area = random.uniform(0.08, 1.0) * area 218 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 219 | 220 | w = int(round(math.sqrt(target_area * aspect_ratio))) 221 | h = int(round(math.sqrt(target_area / aspect_ratio))) 222 | 223 | if random.random() < 0.5: 224 | w, h = h, w 225 | 226 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 227 | x1 = random.randint(0, img_group[0].size[0] - w) 228 | y1 = random.randint(0, img_group[0].size[1] - h) 229 | found = True 230 | break 231 | else: 232 | found = False 233 | x1 = 0 234 | y1 = 0 235 | 236 | if found: 237 | out_group = list() 238 | for img in img_group: 239 | img = img.crop((x1, y1, x1 + w, y1 + h)) 240 | assert(img.size == (w, h)) 241 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 242 | return out_group 243 | else: 244 | # Fallback 245 | scale = GroupScale(self.size, interpolation=self.interpolation) 246 | crop = GroupRandomCrop(self.size) 247 | return crop(scale(img_group)) 248 | 249 | 250 | 251 | class ToTorchFormatTensor(object): 252 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C x T) in the range [0, 255] 253 | to a torch.FloatTensor of shape (C x T x H x W) in the range [0.0, 1.0] """ 254 | def __init__(self, div=True): 255 | self.div = div 256 | 257 | def __call__(self, img_group): 258 | if img_group[0].mode == 'L': 259 | imgs = torch.from_numpy(np.concatenate([np.expand_dims(np.expand_dims(x, 2), 3) for x in img_group], axis=3)).permute(2, 3, 0, 1).contiguous() 260 | elif img_group[0].mode == 'RGB': 261 | imgs = torch.from_numpy(np.concatenate([np.expand_dims(x, 3) for x in img_group], axis=3)).permute(2, 3, 0, 1).contiguous() 262 | return imgs.float().div(255) if self.div else imgs.float() 263 | 264 | 265 | class GroupNormalize(object): 266 | def __init__(self, mean, std): 267 | self.mean = mean 268 | self.std = std 269 | 270 | def __call__(self, tensor): 271 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 272 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 273 | 274 | # TODO: make efficient 275 | for t, m, s in zip(tensor, rep_mean, rep_std): 276 | t.sub_(m).div_(s) 277 | 278 | return tensor 279 | 280 | class IdentityTransform(object): 281 | 282 | def __call__(self, data): 283 | return data -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, groups=groups, bias=False) 23 | 24 | 25 | def conv1x1(in_planes, out_planes, stride=1): 26 | """1x1 convolution""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, 
planes, stride=1, downsample=None, groups=1, norm_layer=None): 34 | super(BasicBlock, self).__init__() 35 | if norm_layer is None: 36 | norm_layer = nn.BatchNorm2d 37 | if groups != 1: 38 | raise ValueError('BasicBlock only supports groups=1') 39 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = norm_layer(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes) 44 | self.bn2 = norm_layer(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | identity = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | 61 | out += identity 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None): 71 | super(Bottleneck, self).__init__() 72 | if norm_layer is None: 73 | norm_layer = nn.BatchNorm2d 74 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 75 | self.conv1 = conv1x1(inplanes, planes) 76 | self.bn1 = norm_layer(planes) 77 | self.conv2 = conv3x3(planes, planes, stride, groups) 78 | self.bn2 = norm_layer(planes) 79 | self.conv3 = conv1x1(planes, planes * self.expansion) 80 | self.bn3 = norm_layer(planes * self.expansion) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | identity = x 87 | 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv2(out) 93 | out = self.bn2(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv3(out) 97 | out = self.bn3(out) 98 | 99 | if self.downsample is not None: 100 | identity = self.downsample(x) 101 | 102 | out += identity 103 | out = self.relu(out) 104 | 105 | return out 106 | 107 | 108 | 109 | # class BinarizedF(Function): 110 | # def forward(self, input): 111 | # self.save_for_backward(input) 112 | # ones = torch.ones_like(input) 113 | # zeros = torch.zeros_like(input) 114 | # output = torch.where(input>0,ones,zeros) 115 | # return output 116 | # def backward(self, output_grad): 117 | # input, = self.saved_tensors 118 | # ones = torch.ones_like(input) 119 | # zeros = torch.zeros_like(input) 120 | # input_grad = output_grad*torch.where((0=torch.mean(input[n,:]),ones,zeros) 137 | return output 138 | def backward(self, output_grad): 139 | input, = self.saved_tensors 140 | ones = torch.ones_like(input) 141 | zeros = torch.zeros_like(input) 142 | input_grad = output_grad.clone() 143 | for n in range(input.shape[0]): 144 | input_grad[n,:] = output_grad[n,:]*torch.where((1>torch.mean(input[n,:]))&(0