├── lib ├── __init__.py ├── core │ ├── __init__.py │ ├── api │ │ ├── __init__.py │ │ └── keypoint.py │ ├── base_trainer │ │ ├── __init__.py │ │ ├── metric.py │ │ ├── model.py │ │ └── net_work.py │ └── model │ │ ├── loss │ │ ├── __init__.py │ │ └── simpleface_loss.py │ │ ├── __init__.py │ │ └── face_model.py ├── dataset │ ├── __init__.py │ ├── headpose.py │ ├── augmentor │ │ ├── visual_augmentation.py │ │ └── augmentation.py │ └── dataietr.py └── utils │ ├── __init__.py │ ├── logger.py │ └── init.py ├── tools ├── __init__.py ├── convert_to_onnx.py └── convert_to_coreml.py ├── .idea ├── .gitignore ├── vcs.xml ├── misc.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── face_landmark_pytorch.iml ├── .gitignore ├── run.sh ├── figures └── tmp_screenshot_18.08.20192.png ├── train.py ├── README.md ├── vis.py ├── train_config.py ├── make_json.py └── LICENSE /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/core/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/core/base_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/core/model/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- -------------------------------------------------------------------------------- /lib/core/model/__init__.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.onnx 3 | *.pth 4 | *.csv 5 | *.log 6 | *.__pycache__ 7 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="0" 2 | torchrun --nproc_per_node=1 train.py 3 | -------------------------------------------------------------------------------- /figures/tmp_screenshot_18.08.20192.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/610265158/face_landmark_pytorch/HEAD/figures/tmp_screenshot_18.08.20192.png -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/face_landmark_pytorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | -------------------------------------------------------------------------------- /lib/utils/logger.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | 4 | 5 | 6 | #-*-coding:utf-8-*- 7 | 8 | import logging 9 | 10 | 11 | def get_logger(LEVEL,log_file=None): 12 | head = '[%(asctime)-15s] [%(levelname)s] %(message)s ' 13 | if LEVEL=='info': 14 | logging.basicConfig(level=logging.INFO, format=head) 15 | elif LEVEL=='debug': 16 | logging.basicConfig(level=logging.DEBUG, format=head) 17 | logger = logging.getLogger() 18 | 19 | if log_file !=None: 20 | 21 | fh = logging.FileHandler(log_file) 22 | logger.addHandler(fh) 23 | return logger 24 | 25 | logger=get_logger('info') 26 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from lib.core.base_trainer.net_work import Train 4 | from lib.dataset.dataietr import AlaskaDataIter 5 | from torch.utils.data import Dataset, DataLoader 6 | import cv2 7 | import numpy as np 8 | import pandas as pd 9 | 10 | 11 | 12 | from train_config import config as cfg 13 | 14 | import setproctitle 15 | 16 | 17 | setproctitle.setproctitle("pks") 18 | 19 | 20 | 21 | 22 | 23 | def get_fold(df,n_folds): 24 | 25 | skf = KFold(n_splits=n_folds, shuffle=True, random_state=cfg.SEED) 26 | for fold, (train_idx, val_idx) in enumerate(skf.split(df)): 27 | df.loc[val_idx, 'fold'] = fold 28 | 29 | return df 30 | 31 | 32 | def setppm100asval(df): 33 | def func(fn): 34 | if 'PPM' in fn: 35 | return 0 36 | else: 37 | return 1 38 | df['fold']=df['image'].apply(func) 39 | 40 | 41 | return df 42 | 43 | 44 | def main(): 45 | 46 | 47 | 48 | 49 | 50 | train_df = pd.read_csv(cfg.DATA.train_f_path) 51 | 52 | val_df =pd.read_csv(cfg.DATA.val_f_path) 53 | 54 | 55 | ###build trainer 56 | trainer = Train(train_df=train_df,val_df=val_df,fold=0) 57 | 58 | ### train 59 | trainer.custom_loop() 60 | 61 | if __name__=='__main__': 62 | main() 63 | 64 | -------------------------------------------------------------------------------- /lib/utils/init.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import tensorflow as tf 4 | def init(*args): 5 | 6 | if len(args)==1: 7 | 
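        # a single positional argument is treated as the path to a frozen .pb graph;
        # two arguments are treated as (meta_graph_path, checkpoint_path) and restored via tf.train.import_meta_graph below.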
use_pb=True 8 | pb_path=args[0] 9 | else: 10 | use_pb=False 11 | meta_path=args[0] 12 | restore_model_path=args[1] 13 | 14 | def ini_ckpt(): 15 | graph = tf.Graph() 16 | graph.as_default() 17 | configProto = tf.ConfigProto() 18 | configProto.gpu_options.allow_growth = True 19 | sess = tf.Session(config=configProto) 20 | #load_model(model_path, sess) 21 | saver = tf.train.import_meta_graph(meta_path) 22 | saver.restore(sess, restore_model_path) 23 | 24 | print("Model restred!") 25 | return (graph, sess) 26 | def init_pb(model_path): 27 | config = tf.ConfigProto() 28 | config.gpu_options.per_process_gpu_memory_fraction = 0.2 29 | compute_graph = tf.Graph() 30 | compute_graph.as_default() 31 | sess = tf.Session(config=config) 32 | with tf.gfile.GFile(model_path, 'rb') as fid: 33 | graph_def = tf.GraphDef() 34 | graph_def.ParseFromString(fid.read()) 35 | tf.import_graph_def(graph_def, name='') 36 | 37 | # saver = tf.train.Saver(tf.global_variables()) 38 | # saver.save(sess, save_path='./tmp.ckpt') 39 | return (compute_graph, sess) 40 | 41 | if use_pb: 42 | model = init_pb(pb_path) 43 | else: 44 | model = ini_ckpt() 45 | 46 | graph = model[0] 47 | sess = model[1] 48 | 49 | return graph,sess -------------------------------------------------------------------------------- /lib/core/base_trainer/metric.py: -------------------------------------------------------------------------------- 1 | import sklearn 2 | from sklearn import metrics 3 | from sklearn.metrics import confusion_matrix 4 | 5 | import numpy as np 6 | import torch.nn as nn 7 | from train_config import config as cfg 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from sklearn.metrics import roc_auc_score 13 | from lib.utils.logger import logger 14 | 15 | import warnings 16 | 17 | warnings.filterwarnings('ignore') 18 | 19 | 20 | class AverageMeter(object): 21 | """Computes and stores the average and current value""" 22 | 23 | def __init__(self): 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | 39 | class ROCAUCMeter(object): 40 | def __init__(self): 41 | self.reset() 42 | 43 | def reset(self): 44 | 45 | self.y_true_11 = None 46 | self.y_pred_11 = None 47 | 48 | def update(self, y_true, y_pred): 49 | y_true = y_true.cpu().numpy() 50 | 51 | y_pred = torch.sigmoid(y_pred).data.cpu().numpy() 52 | 53 | if self.y_true_11 is None: 54 | self.y_true_11 = y_true 55 | self.y_pred_11 = y_pred 56 | else: 57 | self.y_true_11 = np.concatenate((self.y_true_11, y_true), axis=0) 58 | self.y_pred_11 = np.concatenate((self.y_pred_11, y_pred), axis=0) 59 | 60 | @property 61 | def avg(self): 62 | 63 | aucs = [] 64 | for i in range(11): 65 | aucs.append(roc_auc_score(self.y_true_11[:, i], self.y_pred_11[:, i])) 66 | print(np.round(aucs, 4)) 67 | 68 | return np.mean(aucs) 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /lib/core/model/face_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | from torch.nn import Parameter 7 | 8 | 9 | from train_config import config as cfg 10 | 11 | 12 | import timm 13 | 14 | 15 | class Net(nn.Module): 16 | def __init__(self, num_classes=1000): 17 | super().__init__() 18 | 19 | 20 | 
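        # timm backbone: 'tf_mobilenetv3_large_minimal_100' with ImageNet-pretrained weights;
        # exportable=True asks timm for export-friendly layer variants, which helps the torch.onnx.export call at the bottom of this file.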
self.model = timm.create_model('tf_mobilenetv3_large_minimal_100', pretrained=True,exportable=True) 21 | 22 | 23 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 24 | 25 | self._fc = nn.Linear(1280 , num_classes, bias=True) 26 | 27 | def forward(self, inputs): 28 | 29 | #do preprocess 30 | 31 | 32 | bs = inputs.size(0) 33 | # Convolution layers 34 | x = self.model.forward_features(inputs) 35 | 36 | fm = x.view(bs, -1) 37 | x = self._fc(fm) 38 | 39 | return x 40 | 41 | 42 | 43 | if __name__=='__main__': 44 | import torch 45 | import torchvision 46 | 47 | dummy_input = torch.randn(1, 3, 224, 224, device='cpu') 48 | model = Net() 49 | 50 | ### load your weights 51 | model.eval() 52 | # Providing input and output names sets the display names for values 53 | # within the model's graph. Setting these does not change the semantics 54 | # of the graph; it is only for readability. 55 | # 56 | # The inputs to the network consist of the flat list of inputs (i.e. 57 | # the values you would pass to the forward() method) followed by the 58 | # flat list of parameters. You can partially specify names, i.e. provide 59 | # a list here shorter than the number of inputs to the model, and we will 60 | # only set that subset of names, starting from the beginning. 61 | 62 | 63 | torch.onnx.export(model, dummy_input, "classifier.onnx",opset_version=11 ) 64 | -------------------------------------------------------------------------------- /tools/convert_to_onnx.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | 5 | from lib.core.base_trainer.model import COTRAIN 6 | 7 | 8 | import torch 9 | import torchvision 10 | import argparse 11 | import re 12 | 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--model', type=str,default=None, help='the thres for detect') 17 | args = parser.parse_args() 18 | 19 | model_path=args.model 20 | 21 | device=torch.device('cpu') 22 | dummy_input = torch.randn(1, 3,128, 128 , device='cpu') 23 | 24 | style_model = COTRAIN(inference=True).to(device) 25 | style_model.eval() 26 | if model_path is not None: 27 | 28 | state_dict = torch.load(model_path,map_location=device) 29 | # remove saved deprecated running_* keys in InstanceNorm from the checkpoint 30 | 31 | style_model.load_state_dict(state_dict,strict=False) 32 | style_model.to(device) 33 | 34 | 35 | 36 | 37 | ### load your weights 38 | style_model.eval() 39 | # Providing input and output names sets the display names for values 40 | # within the model's graph. Setting these does not change the semantics 41 | # of the graph; it is only for readability. 42 | # 43 | # The inputs to the network consist of the flat list of inputs (i.e. 44 | # the values you would pass to the forward() method) followed by the 45 | # flat list of parameters. You can partially specify names, i.e. provide 46 | # a list here shorter than the number of inputs to the model, and we will 47 | # only set that subset of names, starting from the beginning. 48 | 49 | 50 | torch.onnx.export(style_model, 51 | dummy_input, 52 | "face_kps.onnx" , 53 | # dynamic_axes={'input': {2: 'height', 54 | # 3: 'width', 55 | # }}, 56 | input_names=["input"], 57 | output_names=["output"],opset_version=12) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Note: This code is not maintained. 
Please refer to Peppa_Pig_Face_Landmark (the TRAIN subdir); there is a better model there. 2 | # face_landmark 3 | A simple face alignment method, based on PyTorch. 4 | 5 | 6 | ## introduction 7 | This is the PyTorch branch. 8 | 9 | It is simple and flexible, trained with wing loss and multi-task learning, with data augmentation based on head pose and face attributes (eye state and mouth state). 10 | 11 | [CN blog](https://blog.csdn.net/qq_35606924/article/details/99711208) 12 | 13 | I also suggest you try another project that includes face detection and keypoints, with some extra optimizations; you can check it out at **[[Peppa_Pig_Face_Engine]](https://github.com/610265158/Peppa_Pig_Face_Engine).** 14 | 15 | Contact me if you have any problem with it. 2120140200@mail.nankai.edu.cn :) 16 | 17 | demo pictures: 18 | 19 | ![samples](https://github.com/610265158/face_landmark/blob/master/figures/tmp_screenshot_18.08.20192.png) 20 | 21 | ![gifs](https://github.com/610265158/Peppa_Pig_Face_Engine/blob/master/figure/sample.gif) 22 | 23 | This gif is from github.com/610265158/Peppa_Pig_Face_Engine. 24 | 25 | pretrained model: 26 | 27 | ###### shufflenetv2_1.0 28 | + [baidu disk](https://pan.baidu.com/s/1MK3wI0nrZUOA8yU0ChWvBw) (code 9x2m) 29 | 30 | 31 | 32 | ## requirement 33 | 34 | + pytorch 35 | 36 | + tensorpack (for data provider) 37 | 38 | + opencv 39 | 40 | + python 3.6 41 | 42 | 43 | ## usage 44 | 45 | ### train 46 | 47 | 1. Download the full [300W](https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/) data set, including [300VW](https://ibug.doc.ic.ac.uk/resources/300-VW/) (parse it as images and make the labels the same format as 300W): 48 | ``` 49 | ├── 300VW 50 | │   ├── 001_annot 51 | │   ├── 002_annot 52 | │ .... 53 | ├── 300W 54 | │   ├── 01_Indoor 55 | │   └── 02_Outdoor 56 | ├── AFW 57 | │   └── afw 58 | ├── HELEN 59 | │   ├── testset 60 | │   └── trainset 61 | ├── IBUG 62 | │   └── ibug 63 | ├── LFPW 64 | │   ├── testset 65 | │   └── trainset 66 | ``` 67 | 68 | 2. Run `python make_json.py` to produce train.csv and val.csv 69 | (if you want to train on your own data, read the produced files; the format is quite simple) 70 | 71 | 3. Then run: `python train.py` 72 | 73 | 4.
by default it trained with shufflenetv2_1.0 74 | 75 | 76 | ### visualization 77 | 78 | ``` 79 | python vis.py --model ./model/keypoints.pth 80 | ``` 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from lib.dataset.dataietr import AlaskaDataIter 4 | from train_config import config 5 | from lib.core.base_trainer.model import COTRAIN 6 | 7 | import torch 8 | import time 9 | import argparse 10 | 11 | from torch.utils.data import DataLoader 12 | import numpy as np 13 | import os 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 15 | 16 | import cv2 17 | from train_config import config as cfg 18 | cfg.TRAIN.batch_size=1 19 | 20 | 21 | df=pd.read_csv(cfg.DATA.val_f_path) 22 | val_genererator = AlaskaDataIter(df, 23 | img_root=cfg.DATA.root_path, 24 | training_flag=False, shuffle=False) 25 | val_ds=DataLoader(val_genererator, 26 | cfg.TRAIN.batch_size, 27 | num_workers=1,shuffle=False) 28 | 29 | 30 | def vis(weight): 31 | device = torch.device("cuda" if torch.cuda.is_available() else 'cpu') 32 | model=COTRAIN(inference=True) 33 | model.eval() 34 | state_dict = torch.load(weight, map_location=device) 35 | model.load_state_dict(state_dict, strict=False) 36 | 37 | 38 | 39 | for step, (ids, images, kps) in enumerate(val_ds): 40 | 41 | 42 | # kps = kps.to(device).float() 43 | 44 | img_show = np.array(images)*255 45 | print(img_show.shape) 46 | 47 | img_show=np.transpose(img_show[0],axes=[1,2,0]).astype(np.uint8) 48 | img_show=np.ascontiguousarray(img_show) 49 | images=images.to(device) 50 | print(images.size()) 51 | 52 | start=time.time() 53 | with torch.no_grad(): 54 | res=model(images) 55 | res=res.detach().numpy() 56 | print(res) 57 | print('xxxx',time.time()-start) 58 | #print(res) 59 | 60 | 61 | 62 | 63 | landmark = np.array(res[0][0:136]).reshape([-1, 2]) 64 | 65 | for _index in range(landmark.shape[0]): 66 | x_y = landmark[_index] 67 | #print(x_y) 68 | cv2.circle(img_show, center=(int(x_y[0] * 128), 69 | int(x_y[1] * 128)), 70 | color=(255, 122, 122), radius=1, thickness=2) 71 | 72 | cv2.imshow('tmp',img_show) 73 | cv2.waitKey(0) 74 | 75 | 76 | def load_checkpoint(net, checkpoint): 77 | # from collections import OrderedDict 78 | # 79 | # temp = OrderedDict() 80 | # if 'state_dict' in checkpoint: 81 | # checkpoint = dict(checkpoint['state_dict']) 82 | # for k in checkpoint: 83 | # k2 = 'module.'+k if not k.startswith('module.') else k 84 | # temp[k2] = checkpoint[k] 85 | 86 | net.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu')), strict=True) 87 | if __name__=='__main__': 88 | parser = argparse.ArgumentParser(description='Start train.') 89 | 90 | parser.add_argument('--model', dest='model', type=str, default=None, \ 91 | help='the model to use') 92 | 93 | args = parser.parse_args() 94 | 95 | 96 | vis(args.model) 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /train_config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import numpy as np 5 | from easydict import EasyDict as edict 6 | 7 | config = edict() 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 9 | 10 | config.TRAIN = edict() 11 | #### below are params for dataiter 12 | config.TRAIN.process_num = 4 13 | config.TRAIN.prefetch_size = 30 14 | ############ 15 | 16 | config.TRAIN.num_gpu = 1 17 | 
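#### batch settings: batch_size is what each dataloader step feeds the model;
#### accumulation_batch_size is presumably the effective batch size the trainer accumulates gradients towards.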
config.TRAIN.batch_size = 128 18 | config.TRAIN.validatiojn_batch_size = 128 19 | config.TRAIN.accumulation_batch_size=128 20 | config.TRAIN.log_interval = 10 ##10 iters for a log msg 21 | config.TRAIN.epoch = 50 22 | config.TRAIN.early_stop=20 23 | config.TRAIN.test_interval=1 24 | 25 | config.TRAIN.init_lr = 1.e-3 26 | config.TRAIN.warmup_step=1500 27 | config.TRAIN.weight_decay_factor = 5.e-4 ####l2 28 | config.TRAIN.vis=False #### if to check the training data 29 | config.TRAIN.mix_precision=False ##use mix precision to speedup, tf1.14 at least 30 | config.TRAIN.opt='AdamW' ##Adam or SGD 31 | config.TRAIN.gradient_clip=5 32 | 33 | config.MODEL = edict() 34 | config.MODEL.model_path = './models/' ## save directory 35 | config.MODEL.hin = 128 # input size during training , 128,160, depends on 36 | config.MODEL.win = 128 37 | 38 | config.MODEL.out_channel=136+3+4 # output vector 68 points , 3 headpose ,4 cls params,(left eye, right eye, mouth, big mouth open) 39 | 40 | config.MODEL.pretrained_model='./models/fold0_epoch_8_val_loss_12.163321_val_mse_0.249408.pth' 41 | config.DATA = edict() 42 | 43 | config.DATA.root_path='' 44 | config.DATA.train_f_path='train.csv' 45 | config.DATA.val_f_path='val.csv' 46 | 47 | ############the model is trained with RGB mode 48 | config.DATA.root_path='../tmp_crop_data_face_landmark_pytorch' 49 | 50 | config.DATA.base_extend_range=[0.1,0.2] ###extand 51 | config.DATA.scale_factor=[0.7,1.35] ###scales 52 | 53 | config.DATA.symmetry = [(0, 16), (1, 15), (2, 14), (3, 13), (4, 12), (5, 11), (6, 10), (7, 9), (8, 8), 54 | (17, 26), (18, 25), (19, 24), (20, 23), (21, 22), 55 | (31, 35), (32, 34), 56 | (36, 45), (37, 44), (38, 43), (39, 42), (40, 47), (41, 46), 57 | (48, 54), (49, 53), (50, 52), (55, 59), (56, 58), (60, 64), (61, 63), (65, 67)] 58 | 59 | 60 | 61 | 62 | 63 | weights=[1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1., #####bouding 64 | 1.5,1.5,1.5,1.5,1.5, 1.5,1.5,1.5,1.5,1.5, #####eyebows 65 | 1.,1.,1.,1.,1.,1.,1.,1.,1., #####nose 66 | 1.5,1.5,1.5,1.5,1.5,1.5, 1.5,1.5,1.5,1.5,1.5,1.5, ####eyes 67 | 1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1. 
#####mouth 68 | ] 69 | weights_xy=[[x,x] for x in weights] 70 | 71 | config.DATA.weights = np.array(weights_xy,dtype=np.float32).reshape([-1]) 72 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /lib/core/api/keypoint.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | import tensorflow as tf 3 | import cv2 4 | import numpy as np 5 | import time 6 | 7 | from lib.utils.init import init 8 | from train_config import config 9 | 10 | 11 | class Keypoints: 12 | 13 | def __init__(self,pb): 14 | self.PADDING_FLAG=True 15 | self.ANCHOR_SIZE=0 16 | self._pb_path=pb 17 | 18 | self._graph = tf.Graph() 19 | self._sess = tf.Session(graph=self._graph) 20 | 21 | with self._graph.as_default(): 22 | 23 | self._graph, self._sess = init(self._pb_path) 24 | self.img_input = tf.get_default_graph().get_tensor_by_name('tower_0/images:0') 25 | self.embeddings_keypoints = tf.get_default_graph().get_tensor_by_name('tower_0/prediction:0') 26 | self.training = tf.get_default_graph().get_tensor_by_name('training_flag:0') 27 | 28 | 29 | def simple_run(self,img): 30 | 31 | with self._graph.as_default(): 32 | img=np.expand_dims(img,axis=0) 33 | 34 | star=time.time() 35 | for i in range(10): 36 | [landmarkers] = self._sess.run([self.embeddings_keypoints], \ 37 | feed_dict={self.img_input: img, \ 38 | self.training: False}) 39 | print((time.time()-star)/10) 40 | return landmarkers 41 | def run(self,_img,_bboxs): 42 | img_batch = [] 43 | coordinate_correct_batch = [] 44 | pad_batch = [] 45 | 46 | res_landmarks_batch=[] 47 | res_state_batch=[] 48 | with self._sess.as_default(): 49 | with self._graph.as_default(): 50 | for _box in _bboxs: 51 | crop_img=self._one_shot_preprocess(_img,_box) 52 | 53 | landmarkers = self._sess.run([self.embeddings_keypoints], \ 54 | feed_dict={self.img_input: crop_img, \ 55 | self.training: False}) 56 | 57 | 58 | 59 | 60 | def _one_shot_preprocess(self,image,bbox): 61 | 62 | 63 | add = int(np.max(bbox)) 64 | bimg = cv2.copyMakeBorder(image, add, add, add, add, borderType=cv2.BORDER_CONSTANT, 65 | value=config.DATA.pixel_means.tolist()) 66 | 67 | bbox += add 68 | 69 | 70 | bbox_width = bbox[2] - bbox[0] 71 | bbox_height = bbox[3] - bbox[1] 72 | 73 | ###extend 74 | bbox[0] -= config.DATA.base_extend_range[0] * bbox_width 75 | bbox[1] -= config.DATA.base_extend_range[1] * bbox_height 76 | bbox[2] += config.DATA.base_extend_range[0] * bbox_width 77 | bbox[3] += config.DATA.base_extend_range[1] * bbox_height 78 | 79 | 80 | ## 81 | bbox = bbox.astype(np.int) 82 | crop_image = bimg[bbox[1]:bbox[3], bbox[0]:bbox[2], :] 83 | crop_image = cv2.resize(crop_image, (config.MODEL.hin, config.MODEL.win), 84 | interpolation=cv2.INTER_AREA) 85 | 86 | 87 | if config.MODEL.channel==1: 88 | crop_image=cv2.cvtColor(crop_image,cv2.COLOR_RGB2GRAY) 89 | crop_image=np.expand_dims(crop_image,axis=-1) 90 | 91 | 92 | return crop_image -------------------------------------------------------------------------------- /lib/dataset/headpose.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | 8 | 9 | 10 | # object_pts = np.float32([[6.825897, 6.760612, 4.402142], 11 | # [1.330353, 7.122144, 6.903745], 12 | # [-1.330353, 7.122144, 6.903745], 13 | # [-6.825897, 6.760612, 4.402142], 14 | # [5.311432, 5.485328, 3.987654], 15 | # [1.789930, 5.393625, 4.413414], 16 | # [-1.789930, 5.393625, 4.413414], 17 | 
# [-5.311432, 5.485328, 3.987654], 18 | # [2.005628, 1.409845, 6.165652], 19 | # [-2.005628, 1.409845, 6.165652], 20 | # [2.774015, -2.080775, 5.048531], 21 | # [-2.774015, -2.080775, 5.048531], 22 | # [0.000000, -3.116408, 6.097667], 23 | # [0.000000, -7.415691, 4.070434]]) 24 | object_pts = np.float32([[6.825897, 6.760612, 4.402142], 25 | [1.330353, 7.122144, 6.903745], 26 | [-1.330353, 7.122144, 6.903745], 27 | [-6.825897, 6.760612, 4.402142], 28 | [5.311432, 5.485328, 3.987654], 29 | [1.789930, 5.393625, 4.413414], 30 | [-1.789930, 5.393625, 4.413414], 31 | [-5.311432, 5.485328, 3.987654], 32 | [2.005628, 1.409845, 6.165652], 33 | [-2.005628, 1.409845, 6.165652]]) 34 | reprojectsrc = np.float32([[10.0, 10.0, 10.0], 35 | [10.0, 10.0, -10.0], 36 | [10.0, -10.0, -10.0], 37 | [10.0, -10.0, 10.0], 38 | [-10.0, 10.0, 10.0], 39 | [-10.0, 10.0, -10.0], 40 | [-10.0, -10.0, -10.0], 41 | [-10.0, -10.0, 10.0]]) 42 | 43 | line_pairs = [[0, 1], [1, 2], [2, 3], [3, 0], 44 | [4, 5], [5, 6], [6, 7], [7, 4], 45 | [0, 4], [1, 5], [2, 6], [3, 7]] 46 | 47 | 48 | def get_head_pose(shape,img): 49 | h,w,_=img.shape 50 | K = [w, 0.0, w//2, 51 | 0.0, w, h//2, 52 | 0.0, 0.0, 1.0] 53 | # Assuming no lens distortion 54 | D = [0, 0, 0.0, 0.0, 0] 55 | 56 | cam_matrix = np.array(K).reshape(3, 3).astype(np.float32) 57 | dist_coeffs = np.array(D).reshape(5, 1).astype(np.float32) 58 | 59 | 60 | 61 | # image_pts = np.float32([shape[17], shape[21], shape[22], shape[26], shape[36], 62 | # shape[39], shape[42], shape[45], shape[31], shape[35], 63 | # shape[48], shape[54], shape[57], shape[8]]) 64 | image_pts = np.float32([shape[17], shape[21], shape[22], shape[26], shape[36], 65 | shape[39], shape[42], shape[45], shape[31], shape[35]]) 66 | _, rotation_vec, translation_vec = cv2.solvePnP(object_pts, image_pts, cam_matrix, dist_coeffs) 67 | 68 | reprojectdst, _ = cv2.projectPoints(reprojectsrc, rotation_vec, translation_vec, cam_matrix, 69 | dist_coeffs) 70 | 71 | reprojectdst = tuple(map(tuple, reprojectdst.reshape(8, 2))) 72 | 73 | # calc euler angle 74 | rotation_mat, _ = cv2.Rodrigues(rotation_vec) 75 | pose_mat = cv2.hconcat((rotation_mat, translation_vec)) 76 | _, _, _, _, _, _, euler_angle = cv2.decomposeProjectionMatrix(pose_mat) 77 | 78 | return reprojectdst, euler_angle 79 | -------------------------------------------------------------------------------- /tools/convert_to_coreml.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | from lib.core.base_trainer.model import Net,COTRAIN 5 | 6 | import argparse 7 | 8 | import re 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument('--model', type=str,default=None, help='the thres for detect') 13 | args = parser.parse_args() 14 | 15 | model_path=args.model 16 | 17 | import warnings 18 | warnings.simplefilter(action='ignore', category=FutureWarning) 19 | 20 | import torch 21 | 22 | import coremltools as ct 23 | from coremltools.proto import FeatureTypes_pb2 as ft 24 | device=torch.device("cuda" if torch.cuda.is_available() else 'cpu') 25 | dummy_input = torch.randn(1, 3,288, 160, device='cpu') 26 | style_model = COTRAIN(inference=True).to(device) 27 | ### load your weights 28 | style_model.eval() 29 | 30 | 31 | 32 | if model_path is not None: 33 | state_dict = torch.load(model_path,map_location=device) 34 | # remove saved deprecated running_* keys in InstanceNorm from the checkpoint 35 | 36 | style_model.load_state_dict(state_dict) 37 | style_model.to(device) 38 | 
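# eval() freezes dropout / batch-norm statistics so the trace below is deterministic;
# torch.jit.trace then records a static TorchScript graph that coremltools converts further down.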
style_model.eval() 39 | trace = torch.jit.trace(style_model, dummy_input) 40 | 41 | # shapes = [(1,3, 128*i, 128*j) for i in range(1, 10) for j in range(1,10)] 42 | # input_shape = ct.EnumeratedShapes(shapes=shapes) 43 | 44 | # input_shape = ct.Shape(shape=(1, 3,ct.RangeDim(256, 256*10), ct.RangeDim(256, 256*10))) 45 | # Convert the model 46 | mlmodel = ct.convert( 47 | trace, 48 | inputs=[ct.ImageType(name="__input", scale=1/255., 49 | shape=dummy_input.shape)], 50 | minimum_deployment_target=ct.target.iOS14 51 | 52 | ) 53 | 54 | 55 | 56 | 57 | spec = mlmodel.get_spec() 58 | 59 | # Edit the spec 60 | ct.utils.rename_feature(spec, '__input', 'image') 61 | ct.utils.rename_feature(spec, 'var_917', 'output') 62 | 63 | 64 | # to update the shapes of the first input 65 | # input_name = spec.description.input[0].name 66 | # img_size_ranges = ct.models.neural_network.flexible_shape_utils.NeuralNetworkImageSizeRange() 67 | # img_size_ranges.add_height_range((256, -1)) 68 | # img_size_ranges.add_width_range((256, -1)) 69 | 70 | # ct.models.neural_network.flexible_shape_utils.update_image_size_range(spec, 71 | # feature_name=input_name,size_range=img_size_ranges) 72 | 73 | # save out the updated model 74 | mlmodel = ct.models.MLModel(spec) 75 | print(mlmodel) 76 | print(spec.description.output) 77 | # for output in spec.description.output: 78 | # # Change output type 79 | # output.type.imageType.colorSpace = ft.ImageFeatureType.ColorSpace.Value('RGB') 80 | # # channels, height, width = tuple(output.type.multiArrayType.shape) 81 | # # 82 | # # # Set image shape 83 | # output.type.imageType.width = 256 84 | # output.type.imageType.height = 256 85 | 86 | # output_name = spec.description.output[0].name 87 | # ct.models.neural_network.flexible_shape_utils.update_image_size_range(spec, 88 | # feature_name=output_name,size_range=img_size_ranges) 89 | 90 | 91 | mlmodel = ct.models.MLModel(spec) 92 | print(mlmodel) 93 | 94 | 95 | from coremltools.models.neural_network import quantization_utils 96 | from coremltools.models.neural_network.quantization_utils import AdvancedQuantizedLayerSelector 97 | 98 | selector = AdvancedQuantizedLayerSelector( 99 | skip_layer_types=['batchnorm','bias'], 100 | minimum_conv_kernel_channels=4, 101 | minimum_conv_weight_count=4096 102 | ) 103 | 104 | model_fp8 = quantization_utils.quantize_weights(mlmodel, nbits=8,quantization_mode='linear',selector=selector) 105 | 106 | 107 | 108 | fp_8_file='./neural_style.mlmodel' 109 | model_fp8.save(fp_8_file) -------------------------------------------------------------------------------- /lib/core/model/loss/simpleface_loss.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import sys 3 | 4 | 5 | import torch 6 | import math 7 | import numpy as np 8 | 9 | from train_config import config as cfg 10 | 11 | bce_losser=torch.nn.BCEWithLogitsLoss() 12 | def _wing_loss(landmarks, labels, w=10.0, epsilon=2.0, weights=1.): 13 | """ 14 | Arguments: 15 | landmarks, labels: float tensors with shape [batch_size, landmarks]. landmarks means x1,x2,x3,x4...y1,y2,y3,y4 1-D 16 | w, epsilon: a float numbers. 17 | Returns: 18 | a float tensor with shape []. 
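    The piecewise form computed below is the Wing loss:
        loss(x) = w * ln(1 + |x| / epsilon)   for |x| <= w
        loss(x) = |x| - C                     otherwise, with C = w * (1 - ln(1 + w / epsilon)),
    i.e. logarithmic near zero and L1-like for large residuals; per-point weights from cfg.DATA.weights are applied before reduction.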
19 | """ 20 | 21 | x = landmarks - labels 22 | c = w * (1.0 - math.log(1.0 + w / epsilon)) 23 | absolute_x = torch.abs(x) 24 | losses = torch.where( 25 | torch.gt(absolute_x, w), absolute_x - c, 26 | w * torch.log(1.0 + absolute_x / epsilon) 27 | 28 | ) 29 | losses = losses * torch.tensor(cfg.DATA.weights,device='cuda') 30 | loss = torch.sum(torch.mean(losses * weights, dim=[0])) 31 | 32 | return loss 33 | 34 | 35 | def _mse(landmarks, labels, weights=1.): 36 | return torch.mean(0.5 * (landmarks - labels) *(landmarks - labels)* weights) 37 | 38 | 39 | def l1(landmarks, labels): 40 | return torch.mean(landmarks - labels) 41 | 42 | 43 | def calculate_loss(predict_keypoints, label_keypoints): 44 | 45 | landmark_label = label_keypoints[:, 0:136] 46 | pose_label = label_keypoints[:, 136:139] 47 | leye_cls_label = label_keypoints[:, 139] 48 | reye_cls_label = label_keypoints[:, 140] 49 | mouth_cls_label = label_keypoints[:, 141] 50 | big_mouth_cls_label = label_keypoints[:, 142] 51 | 52 | landmark_predict = predict_keypoints[:, 0:136] 53 | pose_predict = predict_keypoints[:, 136:139] 54 | leye_cls_predict = predict_keypoints[:, 139] 55 | reye_cls_predict = predict_keypoints[:, 140] 56 | mouth_cls_predict = predict_keypoints[:, 141] 57 | big_mouth_cls_predict = predict_keypoints[:, 142] 58 | 59 | loss = _wing_loss(landmark_predict, landmark_label) 60 | 61 | loss_pose = _mse(pose_predict, pose_label) 62 | 63 | leye_loss = bce_losser(leye_cls_predict,leye_cls_label) 64 | reye_loss = bce_losser(reye_cls_predict,reye_cls_label) 65 | 66 | mouth_loss = bce_losser(mouth_cls_predict,mouth_cls_label) 67 | mouth_loss_big = bce_losser(big_mouth_cls_predict,big_mouth_cls_label) 68 | mouth_loss = mouth_loss + mouth_loss_big 69 | 70 | # ##make crosssentropy 71 | # leye_cla_correct_prediction = tf.equal( 72 | # tf.cast(tf.greater_equal(tf.nn.sigmoid(leye_cls_predict), 0.5), tf.int32), 73 | # tf.cast(leye_cla_label, tf.int32)) 74 | # leye_cla_accuracy = tf.reduce_mean(tf.cast(leye_cla_correct_prediction, tf.float32)) 75 | # 76 | # reye_cla_correct_prediction = tf.equal( 77 | # tf.cast(tf.greater_equal(tf.nn.sigmoid(reye_cla_predict), 0.5), tf.int32), 78 | # tf.cast(reye_cla_label, tf.int32)) 79 | # reye_cla_accuracy = tf.reduce_mean(tf.cast(reye_cla_correct_prediction, tf.float32)) 80 | # mouth_cla_correct_prediction = tf.equal( 81 | # tf.cast(tf.greater_equal(tf.nn.sigmoid(mouth_cla_predict), 0.5), tf.int32), 82 | # tf.cast(mouth_cla_label, tf.int32)) 83 | # mouth_cla_accuracy = tf.reduce_mean(tf.cast(mouth_cla_correct_prediction, tf.float32)) 84 | 85 | #### l2 regularization_losses 86 | # l2_loss = [] 87 | # l2_reg = tf.keras.regularizers.l2(cfg.TRAIN.weight_decay_factor) 88 | # variables_restore = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES) 89 | # for var in variables_restore: 90 | # if 'weight' in var.name: 91 | # l2_loss.append(l2_reg(var)) 92 | # regularization_losses = tf.add_n(l2_loss, name='l1_loss') 93 | 94 | return loss + loss_pose + leye_loss + reye_loss + mouth_loss 95 | 96 | 97 | -------------------------------------------------------------------------------- /lib/dataset/augmentor/visual_augmentation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | 5 | def pixel_jitter(src,p=0.5,max_=5.): 6 | 7 | src=src.astype(np.float32) 8 | 9 | pattern=(np.random.rand(src.shape[0], src.shape[1],src.shape[2])-0.5)*2*max_ 10 | img = src + pattern 11 | 12 | img[img<0]=0 13 | img[img >255] = 255 14 | 15 | img = 
img.astype(np.uint8) 16 | 17 | return img 18 | 19 | def gray(src): 20 | g_img=cv2.cvtColor(src,cv2.COLOR_RGB2GRAY) 21 | src[:,:,0]=g_img 22 | src[:,:,1]=g_img 23 | src[:,:,2]=g_img 24 | return src 25 | 26 | def swap_change(src): 27 | a = [0,1,2] 28 | 29 | k = random.sample(a, 3) 30 | 31 | res=src.copy() 32 | res[:,:,0]=src[:,:,k[0]] 33 | res[:, :, 1] = src[:, :, k[1]] 34 | res[:, :, 2] = src[:, :, k[2]] 35 | return res 36 | 37 | 38 | def Img_dropout(src,max_pattern_ratio=0.05): 39 | pattern=np.ones_like(src) 40 | width_ratio = random.uniform(0, max_pattern_ratio) 41 | height_ratio = random.uniform(0, max_pattern_ratio) 42 | width=src.shape[1] 43 | height=src.shape[0] 44 | block_width=width*width_ratio 45 | block_height=height*height_ratio 46 | width_start=int(random.uniform(0,width-block_width)) 47 | width_end=int(width_start+block_width) 48 | height_start=int(random.uniform(0,height-block_height)) 49 | height_end=int(height_start+block_height) 50 | pattern[height_start:height_end,width_start:width_end,:]=0 51 | img=src*pattern 52 | return img 53 | 54 | 55 | 56 | def blur_heatmap(src, ksize=(3, 3)): 57 | for i in range(src.shape[2]): 58 | src[:, :, i] = cv2.GaussianBlur(src[:, :, i], ksize, 0) 59 | amin, amax = src[:, :, i].min(), src[:, :, i].max() # 求最大最小值 60 | if amax>0: 61 | src[:, :, i] = (src[:, :, i] - amin) / (amax - amin) # (矩阵元素-最小值)/(最大值-最小值) 62 | return src 63 | def blur(src,ksize=(3,3)): 64 | for i in range(src.shape[2]): 65 | src[:, :, i]=cv2.GaussianBlur(src[:, :, i],ksize,1.5) 66 | return src 67 | 68 | 69 | 70 | 71 | def adjust_contrast(image, factor): 72 | """ Adjust contrast of an image. 73 | Args 74 | image: Image to adjust. 75 | factor: A factor for adjusting contrast. 76 | """ 77 | mean = image.mean(axis=0).mean(axis=0) 78 | return _clip((image - mean) * factor + mean) 79 | 80 | 81 | def adjust_brightness(image, delta): 82 | """ Adjust brightness of an image 83 | Args 84 | image: Image to adjust. 85 | delta: Brightness offset between -1 and 1 added to the pixel values. 86 | """ 87 | return _clip(image + delta * 255) 88 | 89 | 90 | def adjust_hue(image, delta): 91 | """ Adjust hue of an image. 92 | Args 93 | image: Image to adjust. 94 | delta: An interval between -1 and 1 for the amount added to the hue channel. 95 | The values are rotated if they exceed 180. 96 | """ 97 | image[..., 0] = np.mod(image[..., 0] + delta * 180, 180) 98 | return image 99 | 100 | 101 | def adjust_saturation(image, factor): 102 | """ Adjust saturation of an image. 103 | Args 104 | image: Image to adjust. 105 | factor: An interval for the factor multiplying the saturation values of each pixel. 106 | """ 107 | image[..., 1] = np.clip(image[..., 1] * factor, 0, 255) 108 | return image 109 | 110 | 111 | def _clip(image): 112 | """ 113 | Clip and convert an image to np.uint8. 114 | Args 115 | image: Image to clip. 116 | """ 117 | return np.clip(image, 0, 255).astype(np.uint8) 118 | def _uniform(val_range): 119 | """ Uniformly sample from the given range. 120 | Args 121 | val_range: A pair of lower and upper bound. 
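    Returns
        A float sampled uniformly from [val_range[0], val_range[1]) via np.random.uniform.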
122 | """ 123 | return np.random.uniform(val_range[0], val_range[1]) 124 | 125 | 126 | class ColorDistort(): 127 | 128 | def __init__( 129 | self, 130 | contrast_range=(0.8, 1.2), 131 | brightness_range=(-.2, .2), 132 | hue_range=(-0.1, 0.1), 133 | saturation_range=(0.8, 1.2) 134 | ): 135 | self.contrast_range = contrast_range 136 | self.brightness_range = brightness_range 137 | self.hue_range = hue_range 138 | self.saturation_range = saturation_range 139 | 140 | def __call__(self, image): 141 | 142 | 143 | if self.contrast_range is not None: 144 | contrast_factor = _uniform(self.contrast_range) 145 | image = adjust_contrast(image,contrast_factor) 146 | if self.brightness_range is not None: 147 | brightness_delta = _uniform(self.brightness_range) 148 | image = adjust_brightness(image, brightness_delta) 149 | 150 | if self.hue_range is not None or self.saturation_range is not None: 151 | 152 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 153 | 154 | if self.hue_range is not None: 155 | hue_delta = _uniform(self.hue_range) 156 | image = adjust_hue(image, hue_delta) 157 | 158 | if self.saturation_range is not None: 159 | saturation_factor = _uniform(self.saturation_range) 160 | image = adjust_saturation(image, saturation_factor) 161 | 162 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 163 | 164 | return image 165 | 166 | 167 | 168 | 169 | class DsfdVisualAug(): 170 | pass -------------------------------------------------------------------------------- /make_json.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import json 5 | import traceback 6 | import cv2 7 | import pandas as pd 8 | 9 | from tqdm import tqdm 10 | ''' 11 | i decide to merge more data from CelebA, the data anns will be complex, so json maybe a better way. 
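The script walks data_dir for 300W-style images with matching .pts files, shifts every 68-point annotation into a face-centered crop frame (rescaled so the longer side stays <= 512 px), and writes train.csv / val.csv with two columns: 'image' (the crop file name, directory separators replaced by underscores) and 'keypoint' (the 136 crop-relative x,y values stored as a list string).
The 'keypoint' column is later parsed in lib/dataset/dataietr.py roughly as:
    kps = np.array([float(x) for x in row['keypoint'][1:-1].split(',')]).reshape([-1, 2])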
12 | ''' 13 | 14 | 15 | 16 | 17 | 18 | data_dir='/media/lz/ssd_2/coco_data/facelandmark/PUB' ########points to your director,300w 19 | #celeba_data_dir='CELEBA' ########points to your director,CELEBA 20 | 21 | 22 | train_json='./train.csv' 23 | val_json='./val.csv' 24 | save_dir='../tmp_crop_data_face_landmark_pytorch' 25 | 26 | if not os.access(save_dir,os.F_OK): 27 | os.mkdir(save_dir) 28 | 29 | def GetFileList(dir, fileList): 30 | newDir = dir 31 | if os.path.isfile(dir): 32 | fileList.append(dir) 33 | elif os.path.isdir(dir): 34 | for s in os.listdir(dir): 35 | 36 | # if s == "pts": 37 | # continue 38 | newDir=os.path.join(dir,s) 39 | GetFileList(newDir, fileList) 40 | return fileList 41 | 42 | 43 | 44 | 45 | pic_list=[] 46 | GetFileList(data_dir,pic_list) 47 | 48 | pic_list=[x for x in pic_list if '.jpg' in x or 'png' in x or 'jpeg' in x ] 49 | 50 | 51 | ratio=0.95 52 | train_list=[x for x in pic_list if 'AFW' not in x] 53 | val_list=[x for x in pic_list if 'AFW' in x] 54 | 55 | # train_list=[x for x in pic_list if '300W/' not in x] 56 | # val_list=[x for x in pic_list if '300W/' in x] 57 | 58 | 59 | def process_data(data_list,csv_nm): 60 | 61 | 62 | global cnt 63 | image_list=[] 64 | keypoint_list=[] 65 | 66 | for pic in tqdm(data_list): 67 | one_image_ann={} 68 | 69 | ### image_path 70 | one_image_ann['image_path_raw']=pic 71 | 72 | #### keypoints 73 | pts=pic.rsplit('.',1)[0]+'.pts' 74 | if os.access(pic,os.F_OK) and os.access(pts,os.F_OK): 75 | try: 76 | tmp=[] 77 | with open(pts) as p_f: 78 | labels=p_f.readlines()[3:-1] 79 | for _one_p in labels: 80 | xy = _one_p.rstrip().split(' ') 81 | tmp.append([float(xy[0]),float(xy[1])]) 82 | 83 | one_image_ann['keypoints'] = tmp 84 | 85 | label = np.array(tmp).reshape((-1, 2)) 86 | bbox = [float(np.min(label[:, 0])), float(np.min(label[:, 1])), float(np.max(label[:, 0])), float(np.max(label[:, 1]))] 87 | one_image_ann['bbox'] = bbox 88 | 89 | ### placeholder 90 | one_image_ann['attr'] = None 91 | 92 | ###### crop it 93 | 94 | image=cv2.imread(one_image_ann['image_path_raw'],cv2.IMREAD_COLOR) 95 | 96 | h,w,c=image.shape 97 | 98 | ##expanded for 99 | bbox_int = [int(x) for x in bbox] 100 | bbox_width = bbox_int[2] - bbox_int[0] 101 | bbox_height = bbox_int[3] - bbox_int[1] 102 | 103 | center_x=(bbox_int[2] + bbox_int[0])//2 104 | center_y=(bbox_int[3] + bbox_int[1])//2 105 | 106 | x1=int(center_x-bbox_width*2) 107 | x1=x1 if x1>=0 else 0 108 | 109 | y1 = int(center_y - bbox_height*2) 110 | y1 = y1 if y1 >= 0 else 0 111 | 112 | x2 = int(center_x + bbox_width*2) 113 | x2 = x2 if x2 512: 124 | scale=512/max(hh,ww) 125 | else: 126 | scale=1 127 | crop_face=cv2.resize(crop_face,None,fx=scale,fy=scale) 128 | 129 | one_image_ann['bbox'][0] *= scale 130 | one_image_ann['bbox'][1] *= scale 131 | one_image_ann['bbox'][2] *= scale 132 | one_image_ann['bbox'][3] *= scale 133 | 134 | 135 | x1*=scale 136 | y1 *= scale 137 | x2 *= scale 138 | y2 *= scale 139 | for i in range(len(one_image_ann['keypoints'])): 140 | one_image_ann['keypoints'][i][0]*= scale 141 | one_image_ann['keypoints'][i][1]*= scale 142 | 143 | fname= one_image_ann['image_path_raw'].split('PUB/')[-1] 144 | 145 | fname=fname.replace('/','_').replace('/','_') 146 | 147 | 148 | # cv2.imwrite(one_image_ann['image_name'],crop_face) 149 | 150 | 151 | one_image_ann['bbox'][0] -= x1 152 | one_image_ann['bbox'][1] -= y1 153 | one_image_ann['bbox'][2] -= x1 154 | one_image_ann['bbox'][3] -= y1 155 | 156 | for i in range(len(one_image_ann['keypoints'])): 157 | one_image_ann['keypoints'][i][0]-=x1 
158 | one_image_ann['keypoints'][i][1]-=y1 159 | 160 | 161 | keypoint=list(np.array(one_image_ann['keypoints']).reshape(-1).astype(np.float32)) 162 | # [x1,y1,x2,y2]=[int(x) for x in one_image_ann['bbox']] 163 | # 164 | # cv2.rectangle(crop_face,(x1,y1),(x2,y2),thickness=2,color=(255,0,0)) 165 | # 166 | # landmark=np.array(one_image_ann['keypoints']) 167 | # 168 | # for _index in range(landmark.shape[0]): 169 | # x_y = landmark[_index] 170 | # # print(x_y) 171 | # cv2.circle(crop_face, center=(int(x_y[0] ), 172 | # int(x_y[1] )), 173 | # color=(255, 0, 0), radius=2, thickness=4) 174 | # 175 | # 176 | # cv2.imshow('ss', crop_face) 177 | # cv2.waitKey(0) 178 | 179 | image_list.append(fname) 180 | keypoint_list.append(keypoint) 181 | # json_list.append(one_image_ann) 182 | except: 183 | print(pic) 184 | 185 | print(traceback.print_exc()) 186 | 187 | 188 | # with open(json_nm, 'w') as f: 189 | # json.dump(json_list, f, indent=2) 190 | 191 | data_dict={'image':image_list, 192 | 'keypoint':keypoint_list} 193 | df=pd.DataFrame(data_dict) 194 | 195 | df.to_csv(csv_nm,index=False) 196 | 197 | 198 | process_data(train_list,train_json) 199 | 200 | 201 | process_data(val_list,val_json) 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /lib/dataset/dataietr.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import random 4 | import cv2 5 | import json 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | from lib.utils.logger import logger 11 | 12 | 13 | import traceback 14 | 15 | from train_config import config as cfg 16 | import albumentations as A 17 | import os 18 | import copy 19 | from lib.dataset.augmentor.augmentation import Rotate_aug,\ 20 | Affine_aug,\ 21 | Mirror,\ 22 | Padding_aug 23 | 24 | from lib.dataset.headpose import get_head_pose 25 | 26 | 27 | 28 | cv2.setNumThreads(0) 29 | cv2.ocl.setUseOpenCL(False) 30 | 31 | class AlaskaDataIter(): 32 | def __init__(self, df,img_root, 33 | training_flag=True,shuffle=True): 34 | self.eye_close_thres = 0.02 35 | self.mouth_close_thres = 0.02 36 | self.big_mouth_open_thres = 0.08 37 | 38 | 39 | 40 | self.training_flag = training_flag 41 | self.shuffle = shuffle 42 | self.img_root_path=img_root 43 | 44 | 45 | self.df=df 46 | if self.training_flag: 47 | self.balance() 48 | 49 | self.train_trans = A.Compose([ 50 | 51 | A.RandomBrightnessContrast(p=0.3), 52 | A.HueSaturationValue(p=0.3), 53 | A.OneOf([A.GaussianBlur(), 54 | A.MotionBlur()], 55 | p=0.1), 56 | A.ToGray(p=0.1), 57 | A.OneOf([A.GaussNoise(), 58 | A.ISONoise()], 59 | p=0.1), 60 | ]) 61 | 62 | 63 | self.val_trans=A.Compose([ 64 | 65 | A.Resize(height=cfg.MODEL.hin, 66 | width=cfg.MODEL.win) 67 | 68 | ]) 69 | 70 | 71 | 72 | def __getitem__(self, item): 73 | 74 | return self.single_map_func(self.df.iloc[item], self.training_flag) 75 | 76 | def __len__(self): 77 | 78 | return len(self.df) 79 | 80 | def balance(self, ): 81 | df = copy.deepcopy(self.df) 82 | 83 | expanded=[] 84 | lar_count = 0 85 | for i in tqdm(range(len(df))): 86 | 87 | ann=df.iloc[i] 88 | 89 | ### 300w balance, according to keypoints 90 | 91 | label = ann['keypoint'][1:-1].split(',') 92 | label = [float(x) for x in label] 93 | label = np.array(label).reshape([-1, 2]) 94 | 95 | bbox = [float(np.min(label[:, 0])), float(np.min(label[:, 1])), float(np.max(label[:, 0])), 96 | float(np.max(label[:, 1]))] 97 | 98 | bbox_width = bbox[2] - bbox[0] 99 | bbox_height = bbox[3] - bbox[1] 100 | 101 | # if bbox_width < 50 or bbox_height < 50: 102 | # res_anns.remove(ann) 103 | 104 | left_eye_close = np.sqrt( 105 | np.square(label[37, 0] - label[41, 0]) + 106 | np.square(label[37, 1] - label[41, 1])) / bbox_height < self.eye_close_thres \ 107 | or np.sqrt(np.square(label[38, 0] - label[40, 0]) + 108 | np.square(label[38, 1] - label[40, 1])) / bbox_height < self.eye_close_thres 109 | right_eye_close = np.sqrt( 110 | np.square(label[43, 0] - label[47, 0]) + 111 | np.square(label[43, 1] - label[47, 1])) / bbox_height < self.eye_close_thres \ 112 | or np.sqrt(np.square(label[44, 0] - label[46, 0]) + 113 | np.square( 114 | label[44, 1] - label[46, 1])) / bbox_height < self.eye_close_thres 115 | if left_eye_close or right_eye_close: 116 | for i in range(5): 117 | expanded.append(ann) 118 | ###half face 119 | if np.sqrt(np.square(label[36, 0] - label[45, 0]) + 120 | np.square(label[36, 1] - label[45, 1])) / bbox_width < 0.5: 121 | for i in range(10): 122 | expanded.append(ann) 123 | 124 | if np.sqrt(np.square(label[62, 0] - label[66, 0]) + 125 | np.square(label[62, 1] - label[66, 1])) / bbox_height > 0.15: 126 | for i in range(10): 127 | expanded.append(ann) 128 | 129 | if np.sqrt(np.square(label[62, 0] - 
label[66, 0]) + 130 | np.square(label[62, 1] - label[66, 1])) / cfg.MODEL.hin > self.big_mouth_open_thres: 131 | for i in range(25): 132 | expanded.append(ann) 133 | ##########eyes diff aug 134 | if left_eye_close and not right_eye_close: 135 | for i in range(20): 136 | expanded.append(ann) 137 | lar_count += 1 138 | if not left_eye_close and right_eye_close: 139 | for i in range(20): 140 | expanded.append(ann) 141 | lar_count += 15 142 | 143 | self.df=self.df.append(expanded) 144 | logger.info('befor balance the dataset contains %d images' % (len(df))) 145 | logger.info('after balanced the datasets contains %d samples' % (len(self.df))) 146 | 147 | def augmentationCropImage(self, img, bbox, joints=None, is_training=True): 148 | 149 | bbox = np.array(bbox).reshape(4, ).astype(np.float32) 150 | add = max(img.shape[0], img.shape[1]) 151 | 152 | bimg = cv2.copyMakeBorder(img, add, add, add, add, borderType=cv2.BORDER_CONSTANT) 153 | 154 | objcenter = np.array([(bbox[0] + bbox[2]) / 2., (bbox[1] + bbox[3]) / 2.]) 155 | bbox += add 156 | objcenter += add 157 | 158 | joints[:, :2] += add 159 | 160 | gt_width = (bbox[2] - bbox[0]) 161 | gt_height = (bbox[3] - bbox[1]) 162 | 163 | crop_width_half = gt_width * (1 + cfg.DATA.base_extend_range[0] * 2) // 2 164 | crop_height_half = gt_height * (1 + cfg.DATA.base_extend_range[1] * 2) // 2 165 | 166 | if is_training: 167 | min_x = int(objcenter[0] - crop_width_half + \ 168 | random.uniform(-cfg.DATA.base_extend_range[0], cfg.DATA.base_extend_range[0]) * gt_width) 169 | max_x = int(objcenter[0] + crop_width_half + \ 170 | random.uniform(-cfg.DATA.base_extend_range[0], cfg.DATA.base_extend_range[0]) * gt_width) 171 | min_y = int(objcenter[1] - crop_height_half + \ 172 | random.uniform(-cfg.DATA.base_extend_range[1], cfg.DATA.base_extend_range[1]) * gt_height) 173 | max_y = int(objcenter[1] + crop_height_half + \ 174 | random.uniform(-cfg.DATA.base_extend_range[1], cfg.DATA.base_extend_range[1]) * gt_height) 175 | else: 176 | min_x = int(objcenter[0] - crop_width_half) 177 | max_x = int(objcenter[0] + crop_width_half) 178 | min_y = int(objcenter[1] - crop_height_half) 179 | max_y = int(objcenter[1] + crop_height_half) 180 | 181 | joints[:, 0] = joints[:, 0] - min_x 182 | joints[:, 1] = joints[:, 1] - min_y 183 | 184 | img = bimg[min_y:max_y, min_x:max_x, :] 185 | 186 | crop_image_height, crop_image_width, _ = img.shape 187 | joints[:, 0] = joints[:, 0] / crop_image_width 188 | joints[:, 1] = joints[:, 1] / crop_image_height 189 | 190 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, 191 | cv2.INTER_LANCZOS4] 192 | interp_method = random.choice(interp_methods) 193 | 194 | img = cv2.resize(img, (cfg.MODEL.win, cfg.MODEL.hin), interpolation=interp_method) 195 | 196 | joints[:, 0] = joints[:, 0] * cfg.MODEL.win 197 | joints[:, 1] = joints[:, 1] * cfg.MODEL.hin 198 | return img, joints 199 | def single_map_func(self, dp, is_training): 200 | """Data augmentation function.""" 201 | ####customed here 202 | 203 | 204 | fn=dp['image'] 205 | image=cv2.imread(os.path.join(self.img_root_path,fn)) 206 | kps=dp['keypoint'][1:-1].split(',') 207 | kps=[float(x) for x in kps] 208 | kps=np.array(kps).reshape([-1,2]) 209 | bbox = [float(np.min(kps[:, 0])), float(np.min(kps[:, 1])), float(np.max(kps[:, 0])), 210 | float(np.max(kps[:, 1]))] 211 | 212 | bbox = np.array(bbox) 213 | 214 | ### random crop and resize 215 | crop_image, label = self.augmentationCropImage(image, bbox, kps, is_training) 216 | 217 | if is_training: 218 | 219 | if 
random.uniform(0, 1) > 0.5: 220 | crop_image, label = Mirror(crop_image, label=label, symmetry=cfg.DATA.symmetry) 221 | if random.uniform(0, 1) > 0.0: 222 | angle = random.uniform(-45, 45) 223 | crop_image, label = Rotate_aug(crop_image, label=label, angle=angle) 224 | 225 | if random.uniform(0, 1) > 0.5: 226 | strength = random.uniform(0, 50) 227 | crop_image, label = Affine_aug(crop_image, strength=strength, label=label) 228 | 229 | if random.uniform(0, 1) > 0.5: 230 | crop_image = Padding_aug(crop_image, 0.3) 231 | 232 | transformed = self.train_trans(image=crop_image) 233 | crop_image = transformed['image'] 234 | 235 | 236 | 237 | #######head pose 238 | reprojectdst, euler_angle = get_head_pose(label, crop_image) 239 | PRY = euler_angle.reshape([-1]).astype(np.float32) / 90. 240 | 241 | ######cla_label 242 | cla_label = np.zeros([4]) 243 | if np.sqrt(np.square(label[37, 0] - label[41, 0]) + 244 | np.square(label[37, 1] - label[41, 1])) / cfg.MODEL.hin < self.eye_close_thres \ 245 | or np.sqrt(np.square(label[38, 0] - label[40, 0]) + 246 | np.square(label[38, 1] - label[40, 1])) / cfg.MODEL.hin < self.eye_close_thres: 247 | cla_label[0] = 1 248 | if np.sqrt(np.square(label[43, 0] - label[47, 0]) + 249 | np.square(label[43, 1] - label[47, 1])) / cfg.MODEL.hin < self.eye_close_thres \ 250 | or np.sqrt(np.square(label[44, 0] - label[46, 0]) + 251 | np.square(label[44, 1] - label[46, 1])) / cfg.MODEL.hin < self.eye_close_thres: 252 | cla_label[1] = 1 253 | if np.sqrt(np.square(label[61, 0] - label[67, 0]) + 254 | np.square(label[61, 1] - label[67, 1])) / cfg.MODEL.hin < self.mouth_close_thres \ 255 | or np.sqrt(np.square(label[62, 0] - label[66, 0]) + 256 | np.square(label[62, 1] - label[66, 1])) / cfg.MODEL.hin < self.mouth_close_thres \ 257 | or np.sqrt(np.square(label[63, 0] - label[65, 0]) + 258 | np.square(label[63, 1] - label[65, 1])) / cfg.MODEL.hin < self.mouth_close_thres: 259 | cla_label[2] = 1 260 | 261 | ### mouth open big 1 mean true 262 | if np.sqrt(np.square(label[62, 0] - label[66, 0]) + 263 | np.square(label[62, 1] - label[66, 1])) / cfg.MODEL.hin > self.big_mouth_open_thres: 264 | cla_label[3] = 1 265 | 266 | crop_image_height, crop_image_width, _ = crop_image.shape 267 | 268 | label = label.astype(np.float32) 269 | 270 | label[:, 0] = label[:, 0] / crop_image_width 271 | label[:, 1] = label[:, 1] / crop_image_height 272 | crop_image = crop_image.astype(np.uint8) 273 | 274 | 275 | crop_image = crop_image.astype(np.float32) 276 | 277 | crop_image = np.transpose(crop_image, axes=[2, 0, 1]) 278 | crop_image/=255. 
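# Target layout assembled below (it matches the 136+3+4 output head in
# lib/core/base_trainer/model.py): the 68 (x, y) landmarks are flattened to
# 136 values normalized by the crop width/height, followed by the 3 Euler
# pose angles scaled by 1/90 and 4 binary state flags (left eye closed,
# right eye closed, mouth closed, mouth wide open), i.e. a 143-dim vector.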
279 | label = label.reshape([-1]).astype(np.float32) 280 | cla_label = cla_label.astype(np.float32) 281 | 282 | label = np.concatenate([label, PRY, cla_label], axis=0) 283 | 284 | 285 | return fn,crop_image,label 286 | -------------------------------------------------------------------------------- /lib/core/base_trainer/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | from train_config import config as cfg 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import timm 11 | 12 | from torchvision.models.mobilenetv3 import InvertedResidual,InvertedResidualConfig 13 | 14 | 15 | 16 | bn_momentum=0.1 17 | 18 | 19 | 20 | def upsample_x_like_y(x,y): 21 | size = y.shape[-2:] 22 | x=F.interpolate(x, size=size, mode='bilinear') 23 | 24 | return x 25 | 26 | # from lib.core.base_trainer.mobileone import MobileOneBlock 27 | class SeparableConv2d(nn.Module): 28 | """ Separable Conv 29 | """ 30 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding=0, bias=False, 31 | channel_multiplier=1., pw_kernel_size=1): 32 | super(SeparableConv2d, self).__init__() 33 | 34 | 35 | self.conv_dw = nn.Sequential(nn.Conv2d( 36 | int(in_channels*channel_multiplier), int(in_channels*channel_multiplier), kernel_size, 37 | stride=stride, dilation=dilation, padding=padding, groups=int(in_channels*channel_multiplier)), 38 | nn.BatchNorm2d(in_channels,momentum=bn_momentum), 39 | nn.ReLU(inplace=True) 40 | ) 41 | self.conv_pw = nn.Conv2d( 42 | int(in_channels*channel_multiplier), out_channels, pw_kernel_size, padding=0, bias=bias) 43 | 44 | @property 45 | def in_channels(self): 46 | return self.conv_dw.in_channels 47 | 48 | @property 49 | def out_channels(self): 50 | return self.conv_pw.out_channels 51 | 52 | def forward(self, x): 53 | 54 | x = self.conv_dw(x) 55 | x = self.conv_pw(x) 56 | return x 57 | 58 | 59 | 60 | class ASPPPooling(nn.Module): 61 | def __init__(self, in_channels, out_channels): 62 | super(ASPPPooling, self).__init__() 63 | self.pool=nn.Sequential( 64 | nn.AdaptiveAvgPool2d(1), 65 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 66 | nn.BatchNorm2d(out_channels,momentum=bn_momentum), 67 | nn.ReLU()) 68 | def forward(self, x): 69 | 70 | y=x 71 | 72 | 73 | x = self.pool(x) 74 | 75 | x= upsample_x_like_y(x,y) 76 | return x 77 | 78 | 79 | class ASPP(nn.Module): 80 | def __init__(self, in_channels, atrous_rates): 81 | super(ASPP, self).__init__() 82 | out_channels = 512 83 | 84 | rate1, rate2, rate3 = tuple(atrous_rates) 85 | 86 | self.fm_conx1=nn.Sequential( 87 | nn.Conv2d(in_channels, out_channels//4, 1, bias=False), 88 | nn.BatchNorm2d(out_channels//4,momentum=bn_momentum), 89 | nn.ReLU()) 90 | 91 | self.fm_convx3_rate2=nn.Sequential( 92 | nn.Conv2d(in_channels, out_channels//4, kernel_size=3, padding=2, bias=False,dilation=rate1), 93 | nn.BatchNorm2d(out_channels//4,momentum=bn_momentum), 94 | nn.ReLU(inplace=True) 95 | ) 96 | 97 | self.fm_convx3_rate4=nn.Sequential( 98 | nn.Conv2d(in_channels, out_channels//4, kernel_size=3, padding=4, bias=False,dilation=rate2), 99 | nn.BatchNorm2d(out_channels//4,momentum=bn_momentum), 100 | nn.ReLU(inplace=True) 101 | ) 102 | 103 | self.fm_convx3_rate8=nn.Sequential( 104 | nn.Conv2d(in_channels, out_channels//4, kernel_size=3, padding=8, bias=False,dilation=rate3), 105 | nn.BatchNorm2d(out_channels//4,momentum=bn_momentum), 106 | nn.ReLU(inplace=True) 107 | ) 108 | 109 | 
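# Image-level pooling branch: together with the 1x1 conv and the three
# dilated 3x3 convs above, it emits out_channels//4 feature maps; forward()
# concatenates all five branches and the 1x1 "project" conv below fuses the
# resulting out_channels//4 * 5 channels back to out_channels.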
self.fm_pool=ASPPPooling(in_channels=in_channels,out_channels=out_channels//4) 110 | 111 | self.project = nn.Sequential( 112 | nn.Conv2d(out_channels//4*5, out_channels, 1, bias=False), 113 | nn.BatchNorm2d(out_channels,momentum=bn_momentum), 114 | nn.ReLU(inplace=True)) 115 | 116 | def forward(self, x): 117 | 118 | fm1=self.fm_conx1(x) 119 | fm2=self.fm_convx3_rate2(x) 120 | fm4=self.fm_convx3_rate4(x) 121 | fm8=self.fm_convx3_rate8(x) 122 | fm_pool=self.fm_pool(x) 123 | 124 | res = torch.cat([fm1,fm2,fm4,fm8,fm_pool], dim=1) 125 | 126 | return self.project(res) 127 | 128 | 129 | 130 | class FeatureFuc(nn.Module): 131 | def __init__(self, inchannels=128): 132 | super(FeatureFuc, self).__init__() 133 | 134 | 135 | self.block1=InvertedResidual(cnf=InvertedResidualConfig( inchannels, 5, 256, inchannels, False, "RE", 1, 1, 1), 136 | norm_layer = partial(nn.BatchNorm2d, momentum=bn_momentum)) 137 | 138 | self.block2=InvertedResidual(cnf=InvertedResidualConfig( inchannels, 5, 256, inchannels, False, "RE", 1, 1, 1), 139 | norm_layer = partial(nn.BatchNorm2d, momentum=bn_momentum)) 140 | 141 | 142 | def forward(self, x): 143 | 144 | y1=self.block1(x) 145 | 146 | y2=self.block2(y1) 147 | 148 | return y2 149 | 150 | class SCSEModule(nn.Module): 151 | def __init__(self, in_channels, reduction=4): 152 | super().__init__() 153 | self.cSE = nn.Sequential( 154 | nn.AdaptiveAvgPool2d(1), 155 | nn.Conv2d(in_channels, in_channels // reduction, 1), 156 | nn.ReLU(inplace=True), 157 | nn.Conv2d(in_channels // reduction, in_channels, 1), 158 | nn.Sigmoid(), 159 | ) 160 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) 161 | 162 | def forward(self, x): 163 | return x *self.cSE(x) + x*self.sSE(x) 164 | 165 | 166 | 167 | 168 | 169 | 170 | class Heading(nn.Module): 171 | 172 | def __init__(self,encoder_channels): 173 | super(Heading, self).__init__() 174 | self.extra_feature = FeatureFuc(encoder_channels[-1]) 175 | self.aspp = ASPP(encoder_channels[-1], [2, 4, 8]) 176 | 177 | 178 | 179 | 180 | 181 | 182 | def forward(self,features): 183 | 184 | img,encx2,encx4,encx8,encx16=features 185 | 186 | ## add extra feature 187 | encx16=self.extra_feature(encx16) 188 | encx16=self.aspp(encx16) 189 | 190 | return [encx2,encx4,encx8,encx16] 191 | 192 | class Net(nn.Module): 193 | def __init__(self,): 194 | super(Net, self).__init__() 195 | 196 | self.encoder = timm.create_model(model_name='mobilenetv3_large_100', 197 | pretrained=True, 198 | features_only=True, 199 | out_indices=[0,1,2,4], 200 | bn_momentum=bn_momentum, 201 | in_chans=3, 202 | output_stride=16, 203 | ) 204 | 205 | # self.encoder.blocks[4][1]=nn.Identity() 206 | self.encoder.blocks[5]=nn.Identity() 207 | self.encoder.blocks[6]=nn.Identity() 208 | 209 | 210 | self.encoder_out_channels = [3, 16, 24, 40, 112] #mobilenetv3 211 | 212 | 213 | self.head=Heading(self.encoder_out_channels) 214 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 215 | 216 | self._fc = nn.Linear(512, 136+3+4, bias=True) 217 | 218 | def forward(self, x): 219 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 220 | 221 | bs = x.size(0) 222 | features=self.encoder(x) 223 | 224 | features=[x]+features 225 | 226 | [encx2,encx4,encx8,encx16]=self.head(features) 227 | fm=self._avg_pooling(encx16) 228 | 229 | fm = fm.view(bs, -1) 230 | x = self._fc(fm) 231 | 232 | 233 | return x,[encx2,encx4,encx8,encx16] 234 | 235 | 236 | 237 | class TeacherNet(nn.Module): 238 | def __init__(self,): 239 | super(TeacherNet, self).__init__() 240 | 241 | self.encoder = 
timm.create_model(model_name='tf_efficientnet_b0_ns', 242 | pretrained=True, 243 | features_only=True, 244 | out_indices=[0,1,2,3], 245 | bn_momentum=bn_momentum, 246 | in_chans=3, 247 | ) 248 | 249 | # self.encoder.out_channels=[3, 24 , 40, 64,176] 250 | self.encoder.out_channels=[3,16, 24, 40, 112] 251 | self.head = Heading(self.encoder.out_channels) 252 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 253 | 254 | self._fc = nn.Linear(512, 136 + 3 + 4, bias=True) 255 | 256 | def forward(self, x): 257 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 258 | bs = x.size(0) 259 | features=self.encoder(x) 260 | 261 | features=[x]+features 262 | 263 | [encx2,encx4,encx8,encx16] = self.head(features) 264 | 265 | fm = self._avg_pooling(encx16) 266 | 267 | fm = fm.view(bs, -1) 268 | 269 | x = self._fc(fm) 270 | 271 | return x,[encx2,encx4,encx8,encx16] 272 | 273 | 274 | 275 | class COTRAIN(nn.Module): 276 | def __init__(self,inference=False): 277 | super(COTRAIN, self).__init__() 278 | 279 | self.inference=inference 280 | self.student=Net() 281 | self.teacher=TeacherNet() 282 | 283 | self.MSELoss=nn.MSELoss() 284 | 285 | # self.DiceLoss = segmentation_models_pytorch.losses.DiceLoss(mode='multilabel',eps=1e-5) 286 | self.BCELoss = nn.BCEWithLogitsLoss() 287 | 288 | 289 | self.act=nn.Sigmoid() 290 | 291 | 292 | def distill_loss(self,student_pres,teacher_pres): 293 | 294 | 295 | num_level=len(student_pres) 296 | loss=0 297 | for i in range(num_level): 298 | loss+=self.MSELoss(student_pres[i],teacher_pres[i]) 299 | 300 | return loss/num_level 301 | 302 | 303 | def criterion(self,y_pred, y_true): 304 | 305 | return 0.5*self.BCELoss(y_pred, y_true) + 0.5*self.DiceLoss(y_pred, y_true) 306 | 307 | def _wing_loss(self,landmarks, labels, w=10.0, epsilon=2.0, weights=1.): 308 | """ 309 | Arguments: 310 | landmarks, labels: float tensors with shape [batch_size, landmarks]. landmarks means x1,x2,x3,x4...y1,y2,y3,y4 1-D 311 | w, epsilon: a float numbers. 312 | Returns: 313 | a float tensor with shape []. 
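        Note: this implements the wing loss. For x = landmarks - labels each
        element contributes w * ln(1 + |x| / epsilon) when |x| < w and
        |x| - C otherwise, with C = w * (1 - ln(1 + w / epsilon)) chosen so
        the two pieces meet at |x| = w; the terms are re-weighted by
        cfg.DATA.weights, averaged over the batch and summed.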
314 | """ 315 | 316 | x = landmarks - labels 317 | c = w * (1.0 - math.log(1.0 + w / epsilon)) 318 | absolute_x = torch.abs(x) 319 | losses = torch.where( 320 | torch.gt(absolute_x, w), absolute_x - c, 321 | w * torch.log(1.0 + absolute_x / epsilon) 322 | 323 | ) 324 | losses = losses * torch.tensor(cfg.DATA.weights, device='cuda') 325 | loss = torch.sum(torch.mean(losses * weights, dim=[0])) 326 | 327 | return loss 328 | def loss(self,predict_keypoints, label_keypoints): 329 | 330 | landmark_label = label_keypoints[:, 0:136] 331 | pose_label = label_keypoints[:, 136:139] 332 | leye_cls_label = label_keypoints[:, 139] 333 | reye_cls_label = label_keypoints[:, 140] 334 | mouth_cls_label = label_keypoints[:, 141] 335 | big_mouth_cls_label = label_keypoints[:, 142] 336 | 337 | landmark_predict = predict_keypoints[:, 0:136] 338 | pose_predict = predict_keypoints[:, 136:139] 339 | leye_cls_predict = predict_keypoints[:, 139] 340 | reye_cls_predict = predict_keypoints[:, 140] 341 | mouth_cls_predict = predict_keypoints[:, 141] 342 | big_mouth_cls_predict = predict_keypoints[:, 142] 343 | 344 | loss = self._wing_loss(landmark_predict, landmark_label) 345 | 346 | loss_pose = self.MSELoss(pose_predict, pose_label) 347 | 348 | leye_loss = self.BCELoss (leye_cls_predict, leye_cls_label) 349 | reye_loss = self.BCELoss (reye_cls_predict, reye_cls_label) 350 | 351 | mouth_loss = self.BCELoss (mouth_cls_predict, mouth_cls_label) 352 | mouth_loss_big = self.BCELoss (big_mouth_cls_predict, big_mouth_cls_label) 353 | mouth_loss = mouth_loss + mouth_loss_big 354 | 355 | 356 | 357 | return loss + loss_pose + leye_loss + reye_loss + mouth_loss 358 | 359 | return current_loss 360 | 361 | 362 | def forward(self, x,gt=None): 363 | 364 | student_pre,student_fms=self.student(x) 365 | 366 | if self.inference: 367 | 368 | return student_pre 369 | 370 | teacher_pre,teacher_fms=self.teacher(x) 371 | 372 | distill_loss=self.distill_loss(student_fms,teacher_fms) 373 | 374 | student_loss=self.loss(student_pre,gt) 375 | 376 | teacher_loss=self.loss(teacher_pre,gt) 377 | 378 | return student_loss,teacher_loss,distill_loss,student_pre 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | if __name__=='__main__': 388 | import torch 389 | import torchvision 390 | 391 | from thop import profile 392 | 393 | dummy_x = torch.randn(1, 3, 288, 160, device='cpu') 394 | 395 | model = COTRAIN(inference=True) 396 | 397 | input = torch.randn(1, 3, 288, 160) 398 | flops, params = profile(model, inputs=(input,)) 399 | print(flops/1024/1024) 400 | 401 | -------------------------------------------------------------------------------- /lib/dataset/augmentor/augmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import random 5 | import math 6 | from train_config import config as cfg 7 | 8 | ######May wrong, when use it check it 9 | def Rotate_aug(src,angle,label=None,center=None,scale=1.0): 10 | ''' 11 | :param src: src image 12 | :param label: label should be numpy array with [[x1,y1], 13 | [x2,y2], 14 | [x3,y3]...] 
15 | :param angle: 16 | :param center: 17 | :param scale: 18 | :return: the rotated image and the points 19 | ''' 20 | image=src 21 | (h, w) = image.shape[:2] 22 | # 若未指定旋转中心,则将图像中心设为旋转中心 23 | if center is None: 24 | center = (w / 2, h / 2) 25 | # 执行旋转 26 | M = cv2.getRotationMatrix2D(center, angle, scale) 27 | if label is None: 28 | for i in range(image.shape[2]): 29 | image[:,:,i] = cv2.warpAffine(image[:,:,i], M, (w, h), 30 | flags=cv2.INTER_CUBIC, 31 | borderMode=cv2.BORDER_CONSTANT) 32 | return image,None 33 | else: 34 | label=label.T 35 | ####make it as a 3x3 RT matrix 36 | full_M=np.row_stack((M,np.asarray([0,0,1]))) 37 | img_rotated = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC, 38 | borderMode=cv2.BORDER_CONSTANT) 39 | ###make the label as 3xN matrix 40 | full_label = np.row_stack((label, np.ones(shape=(1,label.shape[1])))) 41 | label_rotated=np.dot(full_M,full_label) 42 | label_rotated=label_rotated[0:2,:] 43 | #label_rotated = label_rotated.astype(np.int32) 44 | label_rotated=label_rotated.T 45 | return img_rotated,label_rotated 46 | def Rotate_coordinate(label,rt_matrix): 47 | if rt_matrix.shape[0]==2: 48 | rt_matrix=np.row_stack((rt_matrix, np.asarray([0, 0, 1]))) 49 | full_label = np.row_stack((label, np.ones(shape=(1, label.shape[1])))) 50 | label_rotated = np.dot(rt_matrix, full_label) 51 | label_rotated = label_rotated[0:2, :] 52 | return label_rotated 53 | 54 | 55 | def box_to_point(boxes): 56 | ''' 57 | 58 | :param boxes: [n,x,y,x,y] 59 | :return: [4n,x,y] 60 | ''' 61 | ##caution the boxes are ymin xmin ymax xmax 62 | points_set=np.zeros(shape=[4*boxes.shape[0],2]) 63 | 64 | for i in range(boxes.shape[0]): 65 | points_set[4 * i]=np.array([boxes[i][0],boxes[i][1]]) 66 | points_set[4 * i+1] =np.array([boxes[i][0],boxes[i][3]]) 67 | points_set[4 * i+2] =np.array([boxes[i][2],boxes[i][3]]) 68 | points_set[4 * i+3] =np.array([boxes[i][2],boxes[i][1]]) 69 | 70 | 71 | return points_set 72 | 73 | 74 | def point_to_box(points): 75 | boxes=[] 76 | points=points.reshape([-1,4,2]) 77 | 78 | for i in range(points.shape[0]): 79 | box=[np.min(points[i][:,0]),np.min(points[i][:,1]),np.max(points[i][:,0]),np.max(points[i][:,1])] 80 | 81 | boxes.append(box) 82 | 83 | return np.array(boxes) 84 | 85 | 86 | def Rotate_with_box(src,angle,boxes=None,center=None,scale=1.0): 87 | ''' 88 | :param src: src image 89 | :param label: label should be numpy array with [[x1,y1], 90 | [x2,y2], 91 | [x3,y3]...] 
92 | :param angle:angel 93 | :param center: 94 | :param scale: 95 | :return: the rotated image and the points 96 | ''' 97 | 98 | label=box_to_point(boxes) 99 | image=src 100 | (h, w) = image.shape[:2] 101 | # 若未指定旋转中心,则将图像中心设为旋转中心 102 | 103 | 104 | if center is None: 105 | center = (w / 2, h / 2) 106 | # 执行旋转 107 | M = cv2.getRotationMatrix2D(center, angle, scale) 108 | 109 | new_size=Rotate_coordinate(np.array([[0,w,w,0], 110 | [0,0,h,h]]), M) 111 | 112 | new_h,new_w=np.max(new_size[1])-np.min(new_size[1]),np.max(new_size[0])-np.min(new_size[0]) 113 | 114 | scale=min(h/new_h,w/new_w) 115 | 116 | M = cv2.getRotationMatrix2D(center, angle, scale) 117 | 118 | if boxes is None: 119 | for i in range(image.shape[2]): 120 | image[:,:,i] = cv2.warpAffine(image[:,:,i], M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT) 121 | return image,None 122 | else: 123 | label=label.T 124 | ####make it as a 3x3 RT matrix 125 | full_M=np.row_stack((M,np.asarray([0,0,1]))) 126 | img_rotated = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT) 127 | ###make the label as 3xN matrix 128 | full_label = np.row_stack((label, np.ones(shape=(1,label.shape[1])))) 129 | label_rotated=np.dot(full_M,full_label) 130 | label_rotated=label_rotated[0:2,:] 131 | #label_rotated = label_rotated.astype(np.int32) 132 | label_rotated=label_rotated.T 133 | 134 | boxes_rotated = point_to_box(label_rotated) 135 | return img_rotated,boxes_rotated 136 | 137 | ###CAUTION:its not ok for transform with label for perspective _aug 138 | def Perspective_aug(src,strength,label=None): 139 | image = src 140 | pts_base = np.float32([[0, 0], [300, 0], [0, 300], [300, 300]]) 141 | pts1=np.random.rand(4, 2)*random.uniform(-strength,strength)+pts_base 142 | pts1=pts1.astype(np.float32) 143 | #pts1 =np.float32([[56, 65], [368, 52], [28, 387], [389, 398]]) 144 | M = cv2.getPerspectiveTransform(pts1, pts_base) 145 | trans_img = cv2.warpPerspective(image, M, (src.shape[1], src.shape[0])) 146 | 147 | label_rotated=None 148 | if label is not None: 149 | label=label.T 150 | full_label = np.row_stack((label, np.ones(shape=(1, label.shape[1])))) 151 | label_rotated = np.dot(M, full_label) 152 | label_rotated=label_rotated.astype(np.int32) 153 | label_rotated=label_rotated.T 154 | return trans_img,label_rotated 155 | 156 | def Affine_aug(src,strength,label=None): 157 | image = src 158 | pts_base = np.float32([[10,100],[200,50],[100,250]]) 159 | pts1 = np.random.rand(3, 2) * random.uniform(-strength, strength) + pts_base 160 | pts1 = pts1.astype(np.float32) 161 | M = cv2.getAffineTransform(pts1, pts_base) 162 | trans_img = cv2.warpAffine(image, M, (image.shape[1], image.shape[0]) , 163 | borderMode=cv2.BORDER_CONSTANT) 164 | label_rotated=None 165 | if label is not None: 166 | label=label.T 167 | full_label = np.row_stack((label, np.ones(shape=(1, label.shape[1])))) 168 | label_rotated = np.dot(M, full_label) 169 | #label_rotated = label_rotated.astype(np.int32) 170 | label_rotated=label_rotated.T 171 | return trans_img,label_rotated 172 | def Padding_aug(src,max_pattern_ratio=0.05): 173 | src=src.astype(np.float32) 174 | pattern=np.ones_like(src) 175 | ratio = random.uniform(0, max_pattern_ratio) 176 | 177 | height,width,_=src.shape 178 | 179 | if random.uniform(0,1)>0.5: 180 | if random.uniform(0, 1) > 0.5: 181 | pattern[0:int(ratio*height),:,:]=0 182 | else: 183 | pattern[height-int(ratio * height):, :, :] = 0 184 | else: 185 | if random.uniform(0, 1) > 0.5: 186 | pattern[:,0:int(ratio * width), :] = 0 187 
| else: 188 | pattern[:,width-int(ratio * width):, :] = 0 189 | 190 | 191 | bias_pattern=(1-pattern) 192 | 193 | img=src*pattern+bias_pattern 194 | 195 | img=img.astype(np.uint8) 196 | return img 197 | 198 | def Blur_heatmaps(src, ksize=(3, 3)): 199 | for i in range(src.shape[2]): 200 | src[:, :, i] = cv2.GaussianBlur(src[:, :, i], ksize, 0) 201 | amin, amax = src[:, :, i].min(), src[:, :, i].max() # 求最大最小值 202 | if amax>0: 203 | src[:, :, i] = (src[:, :, i] - amin) / (amax - amin) # (矩阵元素-最小值)/(最大值-最小值) 204 | return src 205 | def Blur_aug(src,ksize=(3,3)): 206 | for i in range(src.shape[2]): 207 | src[:, :, i]=cv2.GaussianBlur(src[:, :, i],ksize,1.5) 208 | return src 209 | 210 | def Img_dropout(src,max_pattern_ratio=0.05): 211 | width_ratio = random.uniform(0, max_pattern_ratio) 212 | height_ratio = random.uniform(0, max_pattern_ratio) 213 | width=src.shape[1] 214 | height=src.shape[0] 215 | block_width=width*width_ratio 216 | block_height=height*height_ratio 217 | width_start=int(random.uniform(0,width-block_width)) 218 | width_end=int(width_start+block_width) 219 | height_start=int(random.uniform(0,height-block_height)) 220 | height_end=int(height_start+block_height) 221 | src[height_start:height_end,width_start:width_end,:]=0 222 | 223 | return src 224 | 225 | def Fill_img(img_raw,target_height,target_width,label=None): 226 | 227 | ###sometimes use in objs detects 228 | channel=img_raw.shape[2] 229 | raw_height = img_raw.shape[0] 230 | raw_width = img_raw.shape[1] 231 | if raw_width / raw_height >= target_width / target_height: 232 | shape_need = [int(target_height / target_width * raw_width), raw_width, channel] 233 | img_fill = np.zeros(shape_need, dtype=img_raw.dtype) 234 | shift_x=(img_fill.shape[1]-raw_width)//2 235 | shift_y=(img_fill.shape[0]-raw_height)//2 236 | for i in range(channel): 237 | img_fill[shift_y:raw_height+shift_y, shift_x:raw_width+shift_x, i] = img_raw[:,:,i] 238 | else: 239 | shape_need = [raw_height, int(target_width / target_height * raw_height), channel] 240 | img_fill = np.zeros(shape_need, dtype=img_raw.dtype) 241 | shift_x = (img_fill.shape[1] - raw_width) // 2 242 | shift_y = (img_fill.shape[0] - raw_height) // 2 243 | for i in range(channel): 244 | img_fill[shift_y:raw_height + shift_y, shift_x:raw_width + shift_x, i] = img_raw[:, :, i] 245 | if label is None: 246 | return img_fill,shift_x,shift_y 247 | else: 248 | label[:,0]+=shift_x 249 | label[:, 1]+=shift_y 250 | return img_fill,label 251 | def Random_crop(src,shrink): 252 | h,w,_=src.shape 253 | 254 | h_shrink=int(h*shrink) 255 | w_shrink = int(w * shrink) 256 | bimg = cv2.copyMakeBorder(src, h_shrink, h_shrink, w_shrink, w_shrink, borderType=cv2.BORDER_CONSTANT, 257 | value=(0,0,0)) 258 | 259 | start_h=random.randint(0,2*h_shrink) 260 | start_w=random.randint(0,2*w_shrink) 261 | 262 | target_img=bimg[start_h:start_h+h,start_w:start_w+w,:] 263 | 264 | return target_img 265 | 266 | def box_in_img(img,boxes,min_overlap=0.5): 267 | 268 | raw_bboxes = np.array(boxes) 269 | 270 | face_area=(boxes[:,3]-boxes[:,1])*(boxes[:,2]-boxes[:,0]) 271 | 272 | h,w,_=img.shape 273 | boxes[:, 0][boxes[:, 0] <=0] =0 274 | boxes[:, 0][boxes[:, 0] >=w] = w 275 | boxes[:, 2][boxes[:, 2] <= 0] = 0 276 | boxes[:, 2][boxes[:, 2] >= w] = w 277 | 278 | boxes[:, 1][boxes[:, 1] <= 0] = 0 279 | boxes[:, 1][boxes[:, 1] >= h] = h 280 | 281 | boxes[:, 3][boxes[:, 3] <= 0] = 0 282 | boxes[:, 3][boxes[:, 3] >= h] = h 283 | 284 | boxes_in = [] 285 | for i in range(boxes.shape[0]): 286 | box=boxes[i] 287 | if 
((box[3]-box[1])*(box[2]-box[0]))/face_area[i]>min_overlap : 288 | boxes_in.append(boxes[i]) 289 | 290 | boxes_in = np.array(boxes_in) 291 | return boxes_in 292 | 293 | def Random_scale_withbbox(image,bboxes,target_shape,jitter=0.5): 294 | 295 | ###the boxes is in ymin,xmin,ymax,xmax mode 296 | hi, wi, _ = image.shape 297 | 298 | while 1: 299 | if len(bboxes)==0: 300 | print('errrrrrr') 301 | bboxes_=np.array(bboxes) 302 | crop_h = int(hi * random.uniform(0.2, 1)) 303 | crop_w = int(wi * random.uniform(0.2, 1)) 304 | 305 | start_h = random.randint(0, hi - crop_h) 306 | start_w = random.randint(0, wi - crop_w) 307 | 308 | croped = image[start_h:start_h + crop_h, start_w:start_w + crop_w, :] 309 | 310 | bboxes_[:, 0] = bboxes_[:, 0] - start_w 311 | bboxes_[:, 1] = bboxes_[:, 1] - start_h 312 | bboxes_[:, 2] = bboxes_[:, 2] - start_w 313 | bboxes_[:, 3] = bboxes_[:, 3] - start_h 314 | 315 | 316 | bboxes_fix=box_in_img(croped,bboxes_) 317 | if len(bboxes_fix)>0: 318 | break 319 | 320 | 321 | ###use box 322 | h,w=target_shape 323 | croped_h,croped_w,_=croped.shape 324 | 325 | croped_h_w_ratio=croped_h/croped_w 326 | 327 | rescale_h=int(h * random.uniform(0.5, 1)) 328 | 329 | rescale_w = int(rescale_h/(random.uniform(0.7, 1.3)*croped_h_w_ratio)) 330 | rescale_w=np.clip(rescale_w,0,w) 331 | 332 | image=cv2.resize(croped,(rescale_w,rescale_h)) 333 | 334 | new_image=np.zeros(shape=[h,w,3],dtype=np.uint8) 335 | 336 | dx = int(random.randint(0, w - rescale_w)) 337 | dy = int(random.randint(0, h - rescale_h)) 338 | 339 | new_image[dy:dy+rescale_h,dx:dx+rescale_w,:]=image 340 | 341 | bboxes_fix[:, 0] = bboxes_fix[:, 0] * rescale_w/ croped_w+dx 342 | bboxes_fix[:, 1] = bboxes_fix[:, 1] * rescale_h / croped_h+dy 343 | bboxes_fix[:, 2] = bboxes_fix[:, 2] * rescale_w / croped_w+dx 344 | bboxes_fix[:, 3] = bboxes_fix[:, 3] * rescale_h / croped_h+dy 345 | 346 | 347 | 348 | return new_image,bboxes_fix 349 | 350 | 351 | def Random_flip(im, boxes): 352 | 353 | im_lr = np.fliplr(im).copy() 354 | h,w,_ = im.shape 355 | xmin = w - boxes[:,2] 356 | xmax = w - boxes[:,0] 357 | boxes[:,0] = xmin 358 | boxes[:,2] = xmax 359 | return im_lr, boxes 360 | 361 | 362 | def Mirror(src,label=None,symmetry=None): 363 | 364 | img = cv2.flip(src, 1) 365 | if label is None: 366 | return img,label 367 | 368 | width=img.shape[1] 369 | cod = [] 370 | allc = [] 371 | for i in range(label.shape[0]): 372 | x, y = label[i][0], label[i][1] 373 | if x >= 0: 374 | x = width - 1 - x 375 | cod.append((x, y)) 376 | # **** the joint index depends on the dataset **** 377 | for (q, w) in symmetry: 378 | cod[q], cod[w] = cod[w], cod[q] 379 | for i in range(label.shape[0]): 380 | allc.append(cod[i][0]) 381 | allc.append(cod[i][1]) 382 | label = np.array(allc).reshape(label.shape[0], 2) 383 | return img,label 384 | 385 | def produce_heat_maps(label,map_size,stride,sigma): 386 | def produce_heat_map(center,map_size,stride,sigma): 387 | grid_y = map_size[0] // stride 388 | grid_x = map_size[1] // stride 389 | start = stride / 2.0 - 0.5 390 | y_range = [i for i in range(grid_y)] 391 | x_range = [i for i in range(grid_x)] 392 | xx, yy = np.meshgrid(x_range, y_range) 393 | xx = xx * stride + start 394 | yy = yy * stride + start 395 | d2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2 396 | exponent = d2 / 2.0 / sigma / sigma 397 | heatmap = np.exp(-exponent) 398 | 399 | am = np.amax(heatmap) 400 | if am > 0: 401 | heatmap /= am / 255. 
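# The map built above is a single Gaussian response on a grid downsampled by
# `stride`: value = exp(-d^2 / (2 * sigma^2)) with d the distance to the
# keypoint center, rescaled so the peak equals 255; produce_heat_maps stacks
# one such map per keypoint.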
402 | 403 | return heatmap 404 | all_keypoints = label 405 | point_num = all_keypoints.shape[0] 406 | heatmaps_this_img=np.zeros([map_size[0]//stride,map_size[1]//stride,point_num]) 407 | for k in range(point_num): 408 | heatmap = produce_heat_map([all_keypoints[k][0],all_keypoints[k][1]], map_size, stride, sigma) 409 | heatmaps_this_img[:,:,k]=heatmap 410 | return heatmaps_this_img 411 | 412 | def visualize_heatmap_target(heatmap): 413 | map_size=heatmap.shape[0:2] 414 | frame_num = heatmap.shape[2] 415 | heat_ = np.zeros([map_size[0], map_size[1]]) 416 | for i in range(frame_num): 417 | heat_ = heat_ + heatmap[:, :, i] 418 | cv2.namedWindow('heat_map', 0) 419 | cv2.imshow('heat_map', heat_) 420 | cv2.waitKey(0) 421 | 422 | 423 | def produce_heatmaps_with_bbox(image,label,h_out,w_out,num_klass,ksize=9,sigma=0): 424 | heatmap=np.zeros(shape=[h_out,w_out,num_klass]) 425 | 426 | h,w,_=image.shape 427 | 428 | for single_box in label: 429 | if single_box[4]>=0: 430 | ####box center (x,y) 431 | center=[(single_box[0]+single_box[2])/2/w,(single_box[1]+single_box[3])/2/h] ###0-1 432 | 433 | heatmap[round(center[1]*h_out),round(center[0]*w_out),int(single_box[4]) ]=1. 434 | 435 | heatmap = cv2.GaussianBlur(heatmap, (ksize,ksize), sigma) 436 | am = np.amax(heatmap) 437 | if am>0: 438 | heatmap /= am / 255. 439 | heatmap=np.expand_dims(heatmap,-1) 440 | return heatmap 441 | 442 | 443 | def produce_heatmaps_with_keypoint(image,label,h_out,w_out,num_klass,ksize=7,sigma=0): 444 | heatmap=np.zeros(shape=[h_out,w_out,num_klass]) 445 | 446 | h,w,_=image.shape 447 | 448 | for i in range(label.shape[0]): 449 | single_point=label[i] 450 | 451 | if single_point[0]>0 and single_point[1]>0: 452 | 453 | heatmap[int(single_point[1]*(h_out-1)),int(single_point[0]*(w_out-1)),i ]=1. 454 | 455 | heatmap = cv2.GaussianBlur(heatmap, (ksize,ksize), sigma) 456 | am = np.amax(heatmap) 457 | if am>0: 458 | heatmap /= am / 255. 459 | return heatmap 460 | 461 | 462 | 463 | 464 | 465 | if __name__=='__main__': 466 | pass 467 | -------------------------------------------------------------------------------- /lib/core/base_trainer/net_work.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | import numpy as np 3 | import sklearn.metrics 4 | import cv2 5 | import time 6 | import os 7 | 8 | from torch.utils.data import DataLoader, DistributedSampler 9 | 10 | 11 | from lib.dataset.dataietr import AlaskaDataIter 12 | from train_config import config as cfg 13 | #from lib.dataset.dataietr import DataIter 14 | 15 | import sklearn.metrics 16 | from lib.utils.logger import logger 17 | 18 | from lib.core.base_trainer.model import Net,COTRAIN 19 | 20 | import random 21 | 22 | from lib.core.base_trainer.metric import * 23 | import torch 24 | import torch.nn.functional as F 25 | 26 | 27 | 28 | 29 | if not cfg.TRAIN.vis: 30 | torch.distributed.init_process_group(backend="nccl",) 31 | 32 | class Train(object): 33 | """Train class. 
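    Builds the train/val AlaskaDataIter loaders (with a DistributedSampler
    when torch.distributed is initialised, plain DataParallel otherwise),
    instantiates the COTRAIN student/teacher pair, and custom_loop() runs
    training with mixed precision, gradient accumulation/clipping, warm-up
    and a cosine learning-rate schedule.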
34 | """ 35 | 36 | def __init__(self,train_df,val_df,fold): 37 | 38 | self.ddp= not cfg.TRAIN.vis 39 | 40 | # self.ddp=False 41 | 42 | if self.ddp: 43 | 44 | self.train_generator = AlaskaDataIter(train_df,cfg.DATA.root_path ,training_flag=True, shuffle=False) 45 | 46 | 47 | self.train_sampler=DistributedSampler(self.train_generator, 48 | shuffle=True) 49 | self.train_ds = DataLoader(self.train_generator, 50 | 51 | cfg.TRAIN.batch_size, 52 | num_workers=cfg.TRAIN.process_num, 53 | sampler=self.train_sampler) 54 | 55 | self.val_generator = AlaskaDataIter(val_df,cfg.DATA.root_path , training_flag=False, shuffle=False) 56 | 57 | 58 | self.val_sampler=DistributedSampler(self.val_generator, 59 | shuffle=False) 60 | 61 | self.val_ds = DataLoader(self.val_generator, 62 | cfg.TRAIN.validatiojn_batch_size, 63 | num_workers=cfg.TRAIN.process_num, 64 | sampler=self.val_sampler) 65 | 66 | local_rank = torch.distributed.get_rank() 67 | torch.cuda.set_device(local_rank) 68 | self.device = torch.device("cuda", local_rank) 69 | 70 | else: 71 | self.train_generator = AlaskaDataIter(train_df, 72 | img_root=cfg.DATA.root_path, 73 | training_flag=True, shuffle=False) 74 | 75 | self.train_ds = DataLoader(self.train_generator, 76 | cfg.TRAIN.batch_size, 77 | num_workers=cfg.TRAIN.process_num,shuffle=True) 78 | 79 | self.val_generator = AlaskaDataIter(val_df, 80 | img_root=cfg.DATA.root_path, 81 | training_flag=False, shuffle=False) 82 | 83 | self.val_ds = DataLoader(self.val_generator, 84 | cfg.TRAIN.validatiojn_batch_size, 85 | num_workers=cfg.TRAIN.process_num,shuffle=False) 86 | 87 | self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu') 88 | 89 | self.fold = fold 90 | self.fp16=cfg.TRAIN.mix_precision 91 | 92 | self.init_lr = cfg.TRAIN.init_lr 93 | self.warup_step = cfg.TRAIN.warmup_step 94 | self.epochs = cfg.TRAIN.epoch 95 | self.batch_size = cfg.TRAIN.batch_size 96 | self.l2_regularization = cfg.TRAIN.weight_decay_factor 97 | 98 | self.early_stop = cfg.TRAIN.early_stop 99 | 100 | self.accumulation_step = cfg.TRAIN.accumulation_batch_size // cfg.TRAIN.batch_size 101 | 102 | 103 | self.gradient_clip = cfg.TRAIN.gradient_clip 104 | 105 | self.save_dir=cfg.MODEL.model_path 106 | #### make the device 107 | 108 | 109 | self.model = COTRAIN().to(self.device) 110 | 111 | self.load_weight() 112 | # self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model) 113 | 114 | if 'Adamw' in cfg.TRAIN.opt: 115 | 116 | self.optimizer=torch.optim.AdamW(self.model.parameters(), 117 | lr=self.init_lr, 118 | weight_decay=self.l2_regularization) 119 | 120 | else: 121 | self.optimizer = torch.optim.SGD(self.model.parameters(), 122 | lr=self.init_lr, 123 | momentum=0.9, 124 | weight_decay=self.l2_regularization) 125 | 126 | 127 | 128 | 129 | if self.ddp: 130 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, 131 | device_ids=[local_rank], 132 | output_device=local_rank, 133 | static_graph=True, 134 | find_unused_parameters=True 135 | 136 | ) 137 | else: 138 | self.model=torch.nn.DataParallel(self.model) 139 | 140 | if cfg.TRAIN.vis: 141 | self.vis() 142 | 143 | 144 | 145 | ###control vars 146 | self.iter_num = 0 147 | 148 | # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,mode='max', patience=5, 149 | # min_lr=1e-6,factor=0.5,verbose=True) 150 | 151 | # if cfg.TRAIN.lr_scheduler=='cos': 152 | if 1: 153 | logger.info('lr_scheduler.CosineAnnealingLR') 154 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, 155 | self.epochs, 156 | 
eta_min=1.e-7) 157 | else: 158 | logger.info('lr_scheduler.ReduceLROnPlateau') 159 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 160 | mode='max', 161 | patience=5, 162 | min_lr=1e-7, 163 | factor=cfg.TRAIN.lr_scheduler_factor, 164 | verbose=True) 165 | 166 | self.scaler = torch.cuda.amp.GradScaler() 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | def metric(self,targets,preds ): 175 | 176 | 177 | mse=torch.mean((preds-targets)**2) 178 | 179 | mad=torch.mean(torch.abs(preds-targets)) 180 | 181 | return mad,mse 182 | 183 | def custom_loop(self): 184 | """Custom training and testing loop. 185 | Args: 186 | train_dist_dataset: Training dataset created using strategy. 187 | test_dist_dataset: Testing dataset created using strategy. 188 | strategy: Distribution strategy. 189 | Returns: 190 | train_loss, train_accuracy, test_loss, test_accuracy 191 | """ 192 | 193 | def distributed_train_epoch(epoch_num): 194 | 195 | self.train_sampler.set_epoch(epoch_num) 196 | 197 | summary_loss = AverageMeter() 198 | 199 | summary_studen_loss = AverageMeter() 200 | summary_teacher_loss = AverageMeter() 201 | summary_distill_loss = AverageMeter() 202 | 203 | self.model.train() 204 | 205 | 206 | for k,(ids,images, kps) in enumerate(self.train_ds): 207 | 208 | if epoch_num<10: 209 | ###excute warm up in the first epoch 210 | if self.warup_step>0: 211 | if self.iter_num < self.warup_step: 212 | for param_group in self.optimizer.param_groups: 213 | param_group['lr'] = self.iter_num / float(self.warup_step) * self.init_lr 214 | lr = param_group['lr'] 215 | 216 | logger.info('warm up with learning rate: [%f]' % (lr)) 217 | 218 | start=time.time() 219 | 220 | data = images.to(self.device).float() 221 | kps = kps.to(self.device).float() 222 | 223 | batch_size = data.shape[0] 224 | ### 225 | 226 | with torch.cuda.amp.autocast(enabled=self.fp16): 227 | 228 | student_loss, teacher_loss, distill_loss,mate = self.model(data,kps) 229 | 230 | # calculate the final loss, backward the loss, and update the model 231 | current_loss = student_loss+ teacher_loss+ distill_loss 232 | 233 | if torch.isnan(current_loss): 234 | print('there is a nan loss ') 235 | 236 | summary_loss.update(current_loss.detach().item(), batch_size) 237 | summary_studen_loss.update(student_loss.detach().item(), batch_size) 238 | summary_teacher_loss.update(teacher_loss.detach().item(), batch_size) 239 | summary_distill_loss.update(distill_loss.detach().item(), batch_size) 240 | 241 | self.scaler.scale(current_loss).backward() 242 | 243 | if ((self.iter_num + 1) % self.accumulation_step) == 0: 244 | if self.gradient_clip>0: 245 | self.scaler.unscale_(self.optimizer) 246 | nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.gradient_clip, norm_type=2) 247 | self.scaler.step(self.optimizer) 248 | self.scaler.update() 249 | self.optimizer.zero_grad() 250 | 251 | 252 | self.iter_num+=1 253 | time_cost_per_batch=time.time()-start 254 | 255 | images_per_sec=cfg.TRAIN.batch_size/time_cost_per_batch 256 | 257 | 258 | if self.iter_num%cfg.TRAIN.log_interval==0: 259 | 260 | 261 | log_message = '[fold %d], ' \ 262 | 'Train Step %d, ' \ 263 | '[%d/%d ] ' \ 264 | 'summary_loss: %.6f, ' \ 265 | 'summary_student_loss: %.6f, ' \ 266 | 'summary_teacher_loss: %.6f, ' \ 267 | 'summary_distill_loss: %.6f, ' \ 268 | 'time: %.6f, ' \ 269 | 'speed %d images/persec'% ( 270 | self.fold, 271 | self.iter_num, 272 | k, 273 | len(self.train_ds), 274 | summary_loss.avg, 275 | summary_studen_loss.avg, 276 | summary_teacher_loss.avg, 277 | 
summary_distill_loss.avg, 278 | time.time() - start, 279 | images_per_sec) 280 | logger.info(log_message) 281 | 282 | 283 | 284 | 285 | return summary_loss 286 | 287 | def distributed_test_epoch(epoch_num): 288 | summary_loss = AverageMeter() 289 | summary_dice= AverageMeter() 290 | summary_mad= AverageMeter() 291 | summary_mse= AverageMeter() 292 | 293 | 294 | self.model.eval() 295 | t = time.time() 296 | 297 | with torch.no_grad(): 298 | for step,(ids,images,kps) in enumerate(self.val_ds): 299 | 300 | data = images.to(self.device).float() 301 | kps = kps.to(self.device).float() 302 | 303 | batch_size = data.shape[0] 304 | 305 | loss,_,_,output = self.model(data,kps) 306 | 307 | 308 | if self.ddp: 309 | torch.distributed.all_reduce(loss.div_(torch.distributed.get_world_size())) 310 | 311 | summary_loss.update(loss.detach().item(), batch_size) 312 | 313 | 314 | 315 | val_mad,val_mse =self.metric(kps[:,:136],output[:,:136]) 316 | 317 | if self.ddp: 318 | 319 | torch.distributed.all_reduce(val_mad.div_(torch.distributed.get_world_size())) 320 | torch.distributed.all_reduce(val_mse.div_(torch.distributed.get_world_size())) 321 | 322 | 323 | summary_mad.update(val_mad.detach().item(), batch_size) 324 | summary_mse.update(val_mse.detach().item(), batch_size) 325 | 326 | if step % cfg.TRAIN.log_interval == 0: 327 | 328 | log_message = '[fold %d], ' \ 329 | 'Val Step %d, ' \ 330 | 'summary_loss: %.6f, ' \ 331 | 'summary_mad: %.6f, ' \ 332 | 'summary_mse: %.6f, ' \ 333 | 'time: %.6f' % ( 334 | self.fold,step, 335 | summary_loss.avg, 336 | summary_mad.avg, 337 | summary_mse.avg, 338 | time.time() - t) 339 | 340 | logger.info(log_message) 341 | 342 | 343 | 344 | return summary_loss,summary_mad,summary_mse 345 | 346 | 347 | 348 | 349 | best_roc_auc=0. 350 | not_improvement=0 351 | for epoch in range(self.epochs): 352 | 353 | for param_group in self.optimizer.param_groups: 354 | lr=param_group['lr'] 355 | logger.info('learning rate: [%f]' %(lr)) 356 | t=time.time() 357 | 358 | 359 | summary_loss = distributed_train_epoch(epoch) 360 | train_epoch_log_message = '[fold %d], ' \ 361 | '[RESULT]: TRAIN. Epoch: %d,' \ 362 | ' summary_loss: %.5f,' \ 363 | ' time:%.5f' % ( 364 | self.fold, 365 | epoch, 366 | summary_loss.avg, 367 | (time.time() - t)) 368 | logger.info(train_epoch_log_message) 369 | 370 | 371 | 372 | 373 | 374 | if epoch%cfg.TRAIN.test_interval==0 and epoch>0 or epoch%10==0: 375 | 376 | summary_loss ,summary_mad,summary_mse= distributed_test_epoch(epoch) 377 | 378 | val_epoch_log_message = '[fold %d], ' \ 379 | '[RESULT]: VAL. 
Epoch: %d,' \ 380 | ' summary_loss: %.5f,' \ 381 | ' mad_score: %.5f,' \ 382 | ' mse_score: %.5f,' \ 383 | ' time:%.5f' % ( 384 | self.fold, 385 | epoch, 386 | summary_loss.avg, 387 | summary_mad.avg, 388 | summary_mse.avg, 389 | (time.time() - t)) 390 | 391 | logger.info(val_epoch_log_message) 392 | 393 | self.scheduler.step() 394 | # self.scheduler.step(acc_score.avg) 395 | 396 | #### save model 397 | if not os.access(cfg.MODEL.model_path, os.F_OK): 398 | os.mkdir(cfg.MODEL.model_path) 399 | ###save the best auc model 400 | 401 | #### save the model every end of epoch 402 | #### save the model every end of epoch 403 | current_model_saved_name='./models/fold%d_epoch_%d_val_loss_%.6f_val_mse_%.6f.pth'%(self.fold, 404 | epoch, 405 | summary_loss.avg, 406 | summary_mse.avg,) 407 | logger.info('A model saved to %s' % current_model_saved_name) 408 | #### save the model every end of epoch 409 | if self.ddp and torch.distributed.get_rank() == 0 : 410 | torch.save(self.model.module.state_dict(),current_model_saved_name) 411 | else: 412 | torch.save(self.model.module.state_dict(),current_model_saved_name) 413 | 414 | else: 415 | self.scheduler.step() 416 | ####switch back 417 | 418 | 419 | # save_checkpoint({ 420 | # 'state_dict': self.model.state_dict(), 421 | # },iters=epoch,tag=current_model_saved_name) 422 | 423 | 424 | 425 | # if cur_roc_auc_score>best_roc_auc: 426 | # best_roc_auc=cur_roc_auc_score 427 | # logger.info(' best metric score update as %.6f' % (best_roc_auc)) 428 | # else: 429 | # not_improvement+=1 430 | 431 | if not_improvement>=self.early_stop: 432 | logger.info(' best metric score not improvement for %d, break'%(self.early_stop)) 433 | break 434 | 435 | 436 | 437 | def load_weight(self): 438 | if cfg.MODEL.pretrained_model is not None: 439 | state_dict=torch.load(cfg.MODEL.pretrained_model, map_location=self.device) 440 | self.model.load_state_dict(state_dict,strict=False) 441 | 442 | 443 | def vis(self): 444 | print('it is here') 445 | 446 | print('show it, here') 447 | 448 | 449 | # state_dict=torch.load('/Users/liangzi/Downloads/fold1_epoch_29_val_loss_0.049467_val_dice_0.964158.pth', map_location=self.device) 450 | # self.model.load_state_dict(state_dict,strict=False) 451 | self.model.eval() 452 | for id,images, kps in self.train_ds: 453 | 454 | 455 | 456 | for i in range(images.shape[0]): 457 | 458 | example_image=np.array(images[i]*255,dtype=np.uint8) 459 | example_image=np.transpose(example_image,[1,2,0]) 460 | example_image=np.ascontiguousarray(example_image) 461 | example_kps=np.array(kps[i][:136].reshape(-1,2))*128 462 | print(example_kps) 463 | 464 | for _index in range(example_kps.shape[0]): 465 | x_y = example_kps[_index] 466 | 467 | cv2.circle(example_image, (int(x_y[0] ),int(x_y[1] )), 468 | color=(255, 0, 0), radius=2, thickness=4) 469 | 470 | 471 | cv2.imshow('ss', example_image) 472 | cv2.waitKey(0) 473 | 474 | 475 | 476 | --------------------------------------------------------------------------------
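A minimal inference sketch, not part of the repository: it assumes a checkpoint written by Train.custom_loop() (the path below is hypothetical), that cfg.MODEL.win / cfg.MODEL.hin give the network input size, and that the input image is already a face crop; preprocessing mirrors single_map_func in lib/dataset/dataietr.py (BGR image resized, transposed to CHW, scaled to [0, 1]).

import cv2
import numpy as np
import torch

from lib.core.base_trainer.model import COTRAIN
from train_config import config as cfg

model = COTRAIN(inference=True)
# hypothetical checkpoint path; any .pth saved by the training loop above should fit
state = torch.load('models/example.pth', map_location='cpu')
model.load_state_dict(state, strict=False)
model.eval()

img = cv2.imread('face_crop.jpg')                      # assumed pre-cropped face
img = cv2.resize(img, (cfg.MODEL.win, cfg.MODEL.hin))  # same size the data loader produces
x = torch.from_numpy(np.transpose(img, (2, 0, 1))).float().unsqueeze(0) / 255.

with torch.no_grad():
    pred = model(x)                    # inference=True returns only the student prediction
kps = pred[0, :136].reshape(-1, 2)     # 68 (x, y) landmarks, normalized to the crop
pose = pred[0, 136:139] * 90.          # Euler angles (labels were divided by 90 in the loader)
states = torch.sigmoid(pred[0, 139:])  # left/right eye closed, mouth closed, mouth wide-open scores

Scaling kps by the crop width/height recovers pixel coordinates, the same way Train.vis() in net_work.py does with its hard-coded 128.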