├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── data_gen.py
├── demo.py
├── extract.py
├── human_colormap.mat
├── images
│   ├── 0_image.png
│   ├── 0_merged.png
│   ├── 0_out.png
│   ├── 1_image.png
│   ├── 1_merged.png
│   ├── 1_out.png
│   ├── 2_image.png
│   ├── 2_merged.png
│   ├── 2_out.png
│   ├── 3_image.png
│   ├── 3_merged.png
│   ├── 3_out.png
│   ├── 4_image.png
│   ├── 4_merged.png
│   ├── 4_out.png
│   ├── 5_image.png
│   ├── 5_merged.png
│   ├── 5_out.png
│   ├── 6_image.png
│   ├── 6_merged.png
│   ├── 6_out.png
│   ├── 7_image.png
│   ├── 7_merged.png
│   ├── 7_out.png
│   ├── 8_image.png
│   ├── 8_merged.png
│   ├── 8_out.png
│   ├── 9_image.png
│   ├── 9_merged.png
│   ├── 9_out.png
│   └── deeplabv3.png
├── models.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 刘杨
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Look Into Person
2 | 
3 | Human Parsing with DeepLabv3 in PyTorch.
4 | 
5 | ## Dependencies
6 | - [NumPy](http://docs.scipy.org/doc/numpy-1.10.1/user/install.html)
7 | - [PyTorch](https://pytorch.org/)
8 | - [OpenCV](https://opencv-python-tutroals.readthedocs.io/en/latest/)
9 | 
10 | ## Dataset
11 | 
12 | ![image](https://github.com/foamliu/Look-Into-Person/raw/master/images/dataset.png)
13 | 
14 | Follow the [instructions](http://sysu-hcp.net/lip/index.php) to download the Look-Into-Person (LIP) dataset.
15 | 
16 | ## Architecture
17 | 
18 | DeepLabV3 with a ResNet-50 backbone (`torchvision.models.segmentation.deeplabv3_resnet50`)
19 | 
20 | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/deeplabv3.png)
21 | 
22 | 
23 | ## Usage
24 | ### Data Pre-processing
25 | Extract the downloaded dataset:
26 | ```bash
27 | $ python extract.py
28 | ```
29 | 
30 | ### Train
31 | ```bash
32 | $ python train.py
33 | ```
34 | 
35 | To visualize training progress, run in your terminal:
36 | ```bash
37 | $ tensorboard --logdir runs
38 | ```
39 | 
40 | ### Demo
41 | 
42 | Download the [pre-trained model](https://github.com/foamliu/Look-Into-Person/releases/download/v1.0/model.11-0.8409.hdf5) and put it into the models folder.
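Note: `demo.py` in this repository loads `BEST_checkpoint.tar`, the file written by `train.py` via `utils.save_checkpoint`, rather than the `.hdf5` release file. A minimal sketch of that loading step, paraphrasing `demo.py` (the dummy input and the explicit `.to(device)` are illustrative additions, not code from the repo):

```python
import torch

from config import device  # 'cuda' if available, else 'cpu' (see config.py)

# BEST_checkpoint.tar stores the whole (DataParallel-wrapped) model object under the 'model' key
checkpoint = torch.load('BEST_checkpoint.tar')
model = checkpoint['model'].to(device)
model.eval()

with torch.no_grad():
    x = torch.zeros(1, 3, 320, 320, device=device)  # im_size = 320 in config.py
    out = model(x)['out']  # logits, shape [1, num_classes, 320, 320]
```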
43 | 44 | ```bash 45 | $ python demo.py 46 | ``` 47 | 48 | Input | Merged | Output | 49 | |---|---|---| 50 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/0_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/0_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/0_out.png)| 51 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/1_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/1_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/1_out.png)| 52 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/2_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/2_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/2_out.png)| 53 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/3_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/3_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/3_out.png)| 54 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/4_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/4_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/4_out.png)| 55 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/5_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/5_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/5_out.png)| 56 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/6_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/6_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/6_out.png)| 57 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/7_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/7_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/7_out.png)| 58 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/8_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/8_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/8_out.png)| 59 | |![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/9_image.png) | ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/9_merged.png)| ![image](https://github.com/foamliu/Look-Into-Person-v2/raw/master/images/9_out.png)| 60 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io 3 | import torch 4 | 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors 6 | 7 | im_size = 320 8 | channel = 3 9 | batch_size = 16 10 | epochs = 1000 11 | patience = 50 12 | num_train_samples = 28280 13 | num_valid_samples = 5000 14 | num_classes = 20 15 | weight_decay = 1e-2 16 | 17 | # Training parameters 18 | num_workers = 1 # for data-loading; right now, only 1 works with h5py 19 | grad_clip = 5. 
# clip gradients at this absolute value (see utils.clip_gradient)
20 | print_freq = 100  # print training/validation stats every print_freq batches
21 | checkpoint = None  # path to checkpoint, None if none
22 | 
23 | mat = scipy.io.loadmat('human_colormap.mat')
24 | color_map = (mat['colormap'] * 256).astype(np.int32)
25 | 
--------------------------------------------------------------------------------
/data_gen.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | 
4 | import cv2 as cv
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | 
10 | from config import im_size, color_map, num_classes
11 | 
12 | train_images_folder = 'data/instance-level_human_parsing/Training/Images'
13 | train_categories_folder = 'data/instance-level_human_parsing/Training/Category_ids'
14 | valid_images_folder = 'data/instance-level_human_parsing/Validation/Images'
15 | valid_categories_folder = 'data/instance-level_human_parsing/Validation/Category_ids'
16 | 
17 | # Data augmentation and normalization for training
18 | # Just normalization for validation
19 | data_transforms = {
20 |     'train': transforms.Compose([
21 |         transforms.ToTensor(),
22 |         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
23 |     ]),
24 |     'valid': transforms.Compose([
25 |         transforms.ToTensor(),
26 |         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
27 |     ]),
28 | }
29 | 
30 | 
31 | def get_category(categories_folder, name):
32 |     filename = os.path.join(categories_folder, name + '.png')
33 |     semantic = cv.imread(filename, 0)  # flag 0: read the label map as a single-channel (grayscale) image
34 |     return semantic
35 | 
36 | 
37 | def to_bgr(y_pred):
38 |     ret = np.zeros((im_size, im_size, 3), np.float32)
39 |     for r in range(im_size):
40 |         for c in range(im_size):
41 |             color_id = y_pred[r, c]
42 |             # print("color_id: " + str(color_id))
43 |             ret[r, c, :] = color_map[color_id]
44 |     ret = ret.astype(np.uint8)
45 |     return ret
46 | 
47 | 
48 | def random_choice(image_size):
49 |     height, width = image_size
50 |     crop_height, crop_width = im_size, im_size
51 |     x = random.randint(0, max(0, width - crop_width))
52 |     y = random.randint(0, max(0, height - crop_height))
53 |     return x, y
54 | 
55 | 
56 | def safe_crop(mat, x, y):
57 |     crop_height, crop_width = im_size, im_size
58 |     if len(mat.shape) == 2:
59 |         ret = np.zeros((crop_height, crop_width), np.uint8)
60 |     else:
61 |         ret = np.zeros((crop_height, crop_width, 3), np.uint8)
62 |     crop = mat[y:y + crop_height, x:x + crop_width]
63 |     h, w = crop.shape[:2]
64 |     ret[0:h, 0:w] = crop  # pad with zeros if the crop runs off the image
65 |     return ret
66 | 
67 | 
68 | class LIPDataset(Dataset):
69 |     def __init__(self, split):
70 |         self.usage = split
71 | 
72 |         if split == 'train':
73 |             id_file = 'data/instance-level_human_parsing/Training/train_id.txt'
74 |             self.images_folder = train_images_folder
75 |             self.categories_folder = train_categories_folder
76 |         else:
77 |             id_file = 'data/instance-level_human_parsing/Validation/val_id.txt'
78 |             self.images_folder = valid_images_folder
79 |             self.categories_folder = valid_categories_folder
80 | 
81 |         with open(id_file, 'r') as f:
82 |             self.names = f.read().splitlines()
83 | 
84 |         self.transformer = data_transforms[split]
85 | 
86 |     def __getitem__(self, i):
87 |         name = self.names[i]
88 |         filename = os.path.join(self.images_folder, name + '.jpg')
89 |         img = cv.imread(filename)
90 |         image_size = img.shape[:2]
91 |         category = get_category(self.categories_folder, name)
92 | 
93 |         x, y = random_choice(image_size)
94 |         img = safe_crop(img, x, y)
95 |         category = safe_crop(category, x, y)
96 | 
category = np.clip(category, 0, num_classes - 1) 97 | 98 | if np.random.random_sample() > 0.5: 99 | img = np.fliplr(img) 100 | category = np.fliplr(category) 101 | 102 | img = img[..., ::-1] # RGB 103 | img = transforms.ToPILImage()(img) 104 | img = self.transformer(img) 105 | 106 | y = category 107 | 108 | return img, torch.from_numpy(y.copy()) 109 | 110 | def __len__(self): 111 | return len(self.names) 112 | 113 | 114 | if __name__ == "__main__": 115 | dataset = LIPDataset('train') 116 | print(dataset[0]) 117 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # import the necessary packages 2 | import os 3 | import random 4 | 5 | import cv2 as cv 6 | import numpy as np 7 | import torch 8 | from torchvision import transforms 9 | 10 | from config import device, im_size, num_classes 11 | from data_gen import random_choice, safe_crop, to_bgr, data_transforms 12 | from utils import ensure_folder 13 | 14 | if __name__ == '__main__': 15 | checkpoint = 'BEST_checkpoint.tar' 16 | checkpoint = torch.load(checkpoint) 17 | model = checkpoint['model'] 18 | model = model.to(device) 19 | model.eval() 20 | 21 | transformer = data_transforms['valid'] 22 | 23 | ensure_folder('images') 24 | 25 | test_images_folder = 'data/instance-level_human_parsing/Testing/Images' 26 | id_file = 'data/instance-level_human_parsing/Testing/test_id.txt' 27 | with open(id_file, 'r') as f: 28 | names = f.read().splitlines() 29 | 30 | samples = random.sample(names, 10) 31 | 32 | for i in range(len(samples)): 33 | image_name = samples[i] 34 | filename = os.path.join(test_images_folder, image_name + '.jpg') 35 | image = cv.imread(filename) 36 | image_size = image.shape[:2] 37 | 38 | x, y = random_choice(image_size) 39 | image = safe_crop(image, x, y) 40 | print('Start processing image: {}'.format(filename)) 41 | 42 | x_test = torch.zeros((1, 3, im_size, im_size), dtype=torch.float) 43 | img = image[..., ::-1] # RGB 44 | img = transforms.ToPILImage()(img) 45 | img = transformer(img) 46 | x_test[0:, 0:3, :, :] = img 47 | 48 | with torch.no_grad(): 49 | out = model(x_test)['out'] 50 | 51 | out = out.cpu().numpy()[0] 52 | out = np.argmax(out, axis=0) 53 | out = to_bgr(out) 54 | 55 | ret = image * 0.6 + out * 0.4 56 | ret = ret.astype(np.uint8) 57 | 58 | if not os.path.exists('images'): 59 | os.makedirs('images') 60 | 61 | cv.imwrite('images/{}_image.png'.format(i), image) 62 | cv.imwrite('images/{}_merged.png'.format(i), ret) 63 | cv.imwrite('images/{}_out.png'.format(i), out) 64 | -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import tarfile 2 | 3 | 4 | if __name__ == '__main__': 5 | filename = 'data/instance-level_human_parsing.tar.gz' 6 | print('Extracting {}...'.format(filename)) 7 | with tarfile.open(filename) as tar: 8 | tar.extractall('data') 9 | 10 | 11 | -------------------------------------------------------------------------------- /human_colormap.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/human_colormap.mat -------------------------------------------------------------------------------- /images/0_image.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/0_image.png -------------------------------------------------------------------------------- /images/0_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/0_merged.png -------------------------------------------------------------------------------- /images/0_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/0_out.png -------------------------------------------------------------------------------- /images/1_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/1_image.png -------------------------------------------------------------------------------- /images/1_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/1_merged.png -------------------------------------------------------------------------------- /images/1_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/1_out.png -------------------------------------------------------------------------------- /images/2_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/2_image.png -------------------------------------------------------------------------------- /images/2_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/2_merged.png -------------------------------------------------------------------------------- /images/2_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/2_out.png -------------------------------------------------------------------------------- /images/3_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/3_image.png -------------------------------------------------------------------------------- /images/3_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/3_merged.png -------------------------------------------------------------------------------- /images/3_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/3_out.png 
-------------------------------------------------------------------------------- /images/4_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/4_image.png -------------------------------------------------------------------------------- /images/4_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/4_merged.png -------------------------------------------------------------------------------- /images/4_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/4_out.png -------------------------------------------------------------------------------- /images/5_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/5_image.png -------------------------------------------------------------------------------- /images/5_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/5_merged.png -------------------------------------------------------------------------------- /images/5_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/5_out.png -------------------------------------------------------------------------------- /images/6_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/6_image.png -------------------------------------------------------------------------------- /images/6_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/6_merged.png -------------------------------------------------------------------------------- /images/6_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/6_out.png -------------------------------------------------------------------------------- /images/7_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/7_image.png -------------------------------------------------------------------------------- /images/7_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/7_merged.png -------------------------------------------------------------------------------- /images/7_out.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/7_out.png -------------------------------------------------------------------------------- /images/8_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/8_image.png -------------------------------------------------------------------------------- /images/8_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/8_merged.png -------------------------------------------------------------------------------- /images/8_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/8_out.png -------------------------------------------------------------------------------- /images/9_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/9_image.png -------------------------------------------------------------------------------- /images/9_merged.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/9_merged.png -------------------------------------------------------------------------------- /images/9_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/9_out.png -------------------------------------------------------------------------------- /images/deeplabv3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Look-Into-Person-PyTorch/ca524bac51e3b54a0d723e746ee400905567adcb/images/deeplabv3.png -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torchsummary import summary 2 | from torchvision import models 3 | 4 | from config import im_size, device 5 | 6 | if __name__ == '__main__': 7 | model = models.segmentation.deeplabv3_resnet50(pretrained=False, progress=True, num_classes=20) 8 | model = model.to(device) 9 | summary(model, input_size=(3, im_size, im_size)) 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.utils.tensorboard import SummaryWriter 5 | from torchvision import models 6 | 7 | from config import device, grad_clip, print_freq, num_classes 8 | from data_gen import LIPDataset 9 | from utils import parse_args, save_checkpoint, AverageMeter, clip_gradient, get_logger, get_learning_rate, \ 10 | adjust_learning_rate, accuracy 11 | 12 | 13 | def train_net(args): 14 | torch.manual_seed(7) 15 | np.random.seed(7) 
16 |     checkpoint = args.checkpoint
17 |     start_epoch = 0
18 |     best_loss = float('inf')
19 |     writer = SummaryWriter()
20 |     epochs_since_improvement = 0
21 | 
22 |     # Initialize / load checkpoint
23 |     if checkpoint is None:
24 |         model = models.segmentation.deeplabv3_resnet50(pretrained=False, progress=True, num_classes=num_classes)
25 |         model = nn.DataParallel(model)
26 | 
27 |         if args.optimizer == 'sgd':
28 |             optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom,
29 |                                         weight_decay=args.weight_decay)
30 |         else:
31 |             optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
32 | 
33 |     else:
34 |         checkpoint = torch.load(checkpoint)
35 |         start_epoch = checkpoint['epoch'] + 1
36 |         epochs_since_improvement = checkpoint['epochs_since_improvement']
37 |         model = checkpoint['model']
38 |         optimizer = checkpoint['optimizer']
39 | 
40 |     logger = get_logger()
41 | 
42 |     # Move to GPU, if available
43 |     model = model.to(device)
44 | 
45 |     # Loss function
46 |     criterion = nn.CrossEntropyLoss().to(device)
47 | 
48 |     # Custom dataloaders
49 |     train_dataset = LIPDataset('train')
50 |     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
51 |     valid_dataset = LIPDataset('valid')
52 |     valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
53 | 
54 |     # Epochs
55 |     for epoch in range(start_epoch, args.end_epoch):
56 |         if epochs_since_improvement == 10:
57 |             break
58 | 
59 |         if epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
60 |             adjust_learning_rate(optimizer, 0.6)
61 | 
62 |         # One epoch's training
63 |         train_loss, train_acc = train(train_loader=train_loader,
64 |                                       model=model,
65 |                                       criterion=criterion,
66 |                                       optimizer=optimizer,
67 |                                       epoch=epoch,
68 |                                       logger=logger)
69 |         effective_lr = get_learning_rate(optimizer)
70 |         print('Current effective learning rate: {}\n'.format(effective_lr))
71 | 
72 |         writer.add_scalar('Train_Loss', train_loss, epoch)
73 |         writer.add_scalar('Train_Accuracy', train_acc, epoch)
74 | 
75 |         # One epoch's validation
76 |         valid_loss, valid_acc = valid(valid_loader=valid_loader,
77 |                                       model=model,
78 |                                       criterion=criterion,
79 |                                       logger=logger)
80 | 
81 |         writer.add_scalar('Valid_Loss', valid_loss, epoch)
82 |         writer.add_scalar('Valid_Accuracy', valid_acc, epoch)
83 | 
84 |         # Check if there was an improvement
85 |         is_best = valid_loss < best_loss
86 |         best_loss = min(valid_loss, best_loss)
87 |         if not is_best:
88 |             epochs_since_improvement += 1
89 |             print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
90 |         else:
91 |             epochs_since_improvement = 0
92 | 
93 |         # Save checkpoint
94 |         save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)
95 | 
96 | 
97 | def train(train_loader, model, criterion, optimizer, epoch, logger):
98 |     model.train()  # train mode (dropout and batchnorm are active)
99 | 
100 |     losses = AverageMeter()
101 |     accs = AverageMeter()
102 | 
103 |     # Batches
104 |     for i, (img, label) in enumerate(train_loader):
105 |         # Move to GPU, if available
106 |         img = img.type(torch.FloatTensor).to(device)  # [N, 3, 320, 320]
107 |         label = label.type(torch.LongTensor).to(device)  # [N, 320, 320]
108 | 
109 |         # Forward prop.
110 |         out = model(img)['out']  # [N, num_classes, 320, 320]
111 | 
112 |         # Calculate loss
113 |         loss = criterion(out, label)
114 |         acc = accuracy(out, label)
115 | 
116 |         # Back prop.
117 |         optimizer.zero_grad()
118 |         loss.backward()
119 | 
120 |         # Clip gradients
121 |         clip_gradient(optimizer, grad_clip)
122 | 
123 |         # Update weights
124 |         optimizer.step()
125 | 
126 |         # Keep track of metrics
127 |         losses.update(loss.item())
128 |         accs.update(acc)
129 | 
130 |         # Print status
131 | 
132 |         if i % print_freq == 0:
133 |             status = 'Epoch: [{0}][{1}/{2}]\t' \
134 |                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
135 |                      'Accuracy {acc.val:.4f} ({acc.avg:.4f})'.format(epoch, i, len(train_loader), loss=losses, acc=accs)
136 |             logger.info(status)
137 | 
138 |     return losses.avg, accs.avg
139 | 
140 | 
141 | def valid(valid_loader, model, criterion, logger):
142 |     model.eval()  # eval mode (dropout disabled, batchnorm uses running statistics)
143 | 
144 |     losses = AverageMeter()
145 |     accs = AverageMeter()
146 | 
147 |     # Batches
148 |     for img, label in valid_loader:
149 |         # Move to GPU, if available
150 |         img = img.type(torch.FloatTensor).to(device)  # [N, 3, 320, 320]
151 |         label = label.type(torch.LongTensor).to(device)  # [N, 320, 320]
152 | 
153 |         # Forward prop.
154 |         out = model(img)['out']  # [N, num_classes, 320, 320]
155 | 
156 |         # Calculate loss
157 |         loss = criterion(out, label)
158 |         acc = accuracy(out, label)
159 | 
160 |         # Keep track of metrics
161 |         losses.update(loss.item())
162 |         accs.update(acc)
163 | 
164 |     # Print status
165 |     status = 'Validation: Loss {loss.avg:.4f} Accuracy {acc.avg:.4f}\n'.format(loss=losses, acc=accs)
166 | 
167 |     logger.info(status)
168 | 
169 |     return losses.avg, accs.avg
170 | 
171 | 
172 | def main():
173 |     global args
174 |     args = parse_args()
175 |     train_net(args)
176 | 
177 | 
178 | if __name__ == '__main__':
179 |     main()
180 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | 
5 | import cv2 as cv
6 | import numpy as np
7 | import torch
8 | 
9 | from config import im_size
10 | 
11 | 
12 | def to_categorical(y, num_classes):
13 |     """ 1-hot encodes a tensor """
14 |     return np.eye(num_classes, dtype='uint8')[y]
15 | 
16 | 
17 | def clip_gradient(optimizer, grad_clip):
18 |     """
19 |     Clips gradients computed during backpropagation to avoid explosion of gradients.
20 |     :param optimizer: optimizer with the gradients to be clipped
21 |     :param grad_clip: clip value
22 |     """
23 |     for group in optimizer.param_groups:
24 |         for param in group['params']:
25 |             if param.grad is not None:
26 |                 param.grad.data.clamp_(-grad_clip, grad_clip)
27 | 
28 | 
29 | def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, loss, is_best):
30 |     state = {'epoch': epoch,
31 |              'epochs_since_improvement': epochs_since_improvement,
32 |              'loss': loss,
33 |              'model': model,
34 |              'optimizer': optimizer}
35 |     # filename = 'checkpoint_' + str(epoch) + '_' + str(loss) + '.tar'
36 |     filename = 'checkpoint.tar'
37 |     torch.save(state, filename)
38 |     # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
39 |     if is_best:
40 |         torch.save(state, 'BEST_checkpoint.tar')
41 | 
42 | 
43 | class AverageMeter(object):
44 |     """
45 |     Keeps track of most recent, average, sum, and count of a metric.
46 | """ 47 | 48 | def __init__(self): 49 | self.reset() 50 | 51 | def reset(self): 52 | self.val = 0 53 | self.avg = 0 54 | self.sum = 0 55 | self.count = 0 56 | 57 | def update(self, val, n=1): 58 | self.val = val 59 | self.sum += val * n 60 | self.count += n 61 | self.avg = self.sum / self.count 62 | 63 | 64 | class LossMeterBag(object): 65 | 66 | def __init__(self, name_list): 67 | self.meter_dict = dict() 68 | self.name_list = name_list 69 | for name in self.name_list: 70 | self.meter_dict[name] = AverageMeter() 71 | 72 | def update(self, val_list): 73 | for i, name in enumerate(self.name_list): 74 | val = val_list[i] 75 | self.meter_dict[name].update(val) 76 | 77 | def __str__(self): 78 | ret = '' 79 | for name in self.name_list: 80 | ret += '{0}:\t {1:.4f}({2:.4f})\t'.format(name, self.meter_dict[name].val, self.meter_dict[name].avg) 81 | 82 | return ret 83 | 84 | 85 | def adjust_learning_rate(optimizer, shrink_factor): 86 | """ 87 | Shrinks learning rate by a specified factor. 88 | :param optimizer: optimizer whose learning rate must be shrunk. 89 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 90 | """ 91 | 92 | print("\nDECAYING learning rate.") 93 | for param_group in optimizer.param_groups: 94 | param_group['lr'] = param_group['lr'] * shrink_factor 95 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 96 | 97 | 98 | def get_learning_rate(optimizer): 99 | for param_group in optimizer.param_groups: 100 | return param_group['lr'] 101 | 102 | 103 | def accuracy(scores, targets, k=1): 104 | batch_size = targets.size(0) 105 | _, ind = scores.topk(k, 1, True, True) 106 | ind = torch.squeeze(ind, dim=1) 107 | correct = ind.eq(targets) 108 | correct_total = correct.view(-1).float().sum() # 0D tensor 109 | return correct_total.item() * (100.0 / batch_size / im_size / im_size) 110 | 111 | 112 | def parse_args(): 113 | parser = argparse.ArgumentParser(description='Train face network') 114 | # general 115 | parser.add_argument('--end-epoch', type=int, default=50, help='training epoch size.') 116 | parser.add_argument('--lr', type=float, default=0.0001, help='start learning rate') 117 | parser.add_argument('--lr-step', type=int, default=10, help='period of learning rate decay') 118 | parser.add_argument('--optimizer', default='sgd', help='optimizer') 119 | parser.add_argument('--weight-decay', type=float, default=0.0005, help='weight decay') 120 | parser.add_argument('--mom', type=float, default=0.9, help='momentum') 121 | parser.add_argument('--batch-size', type=int, default=16, help='batch size in each context') 122 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint') 123 | parser.add_argument('--pretrained', type=bool, default=False, help='pretrained model') 124 | args = parser.parse_args() 125 | return args 126 | 127 | 128 | def get_logger(): 129 | logger = logging.getLogger() 130 | handler = logging.StreamHandler() 131 | formatter = logging.Formatter("%(asctime)s %(levelname)s \t%(message)s") 132 | handler.setFormatter(formatter) 133 | logger.addHandler(handler) 134 | logger.setLevel(logging.DEBUG) 135 | return logger 136 | 137 | 138 | def safe_crop(mat, x, y, crop_size=(im_size, im_size)): 139 | crop_height, crop_width = crop_size 140 | if len(mat.shape) == 2: 141 | ret = np.zeros((crop_height, crop_width), np.float32) 142 | else: 143 | ret = np.zeros((crop_height, crop_width, 3), np.float32) 144 | crop = mat[y:y + crop_height, x:x + crop_width] 145 | h, w = crop.shape[:2] 146 | ret[0:h, 0:w] = 
crop 147 | if crop_size != (im_size, im_size): 148 | ret = cv.resize(ret, dsize=(im_size, im_size), interpolation=cv.INTER_NEAREST) 149 | return ret 150 | 151 | 152 | def ensure_folder(folder): 153 | if not os.path.exists(folder): 154 | os.makedirs(folder) 155 | --------------------------------------------------------------------------------
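For reference, the command-line flags consumed by `train.py` are exactly those defined in `parse_args()` above. A typical invocation spelling out the declared defaults (the values shown are those defaults, not tuned settings):

```bash
$ python train.py --optimizer sgd --lr 0.0001 --mom 0.9 --weight-decay 0.0005 --batch-size 16 --end-epoch 50
```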