├── lib
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   └── keypoint.py
│   │   ├── base_trainer
│   │   │   ├── __init__.py
│   │   │   ├── metric.py
│   │   │   ├── model.py
│   │   │   └── net_work.py
│   │   └── model
│   │       ├── loss
│   │       │   ├── __init__.py
│   │       │   └── simpleface_loss.py
│   │       ├── __init__.py
│   │       └── face_model.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── headpose.py
│   │   ├── augmentor
│   │   │   ├── visual_augmentation.py
│   │   │   └── augmentation.py
│   │   └── dataietr.py
│   └── utils
│       ├── __init__.py
│       ├── logger.py
│       └── init.py
├── tools
│   ├── __init__.py
│   ├── convert_to_onnx.py
│   └── convert_to_coreml.py
├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── face_landmark_pytorch.iml
├── .gitignore
├── run.sh
├── figures
│   └── tmp_screenshot_18.08.20192.png
├── train.py
├── README.md
├── vis.py
├── train_config.py
├── make_json.py
└── LICENSE
/lib/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/core/api/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/core/base_trainer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/core/model/loss/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
--------------------------------------------------------------------------------
/lib/core/model/__init__.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.onnx
3 | *.pth
4 | *.csv
5 | *.log
6 | *.__pycache__
7 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES="0"
2 | torchrun --nproc_per_node=1 train.py
3 |
--------------------------------------------------------------------------------
/figures/tmp_screenshot_18.08.20192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/610265158/face_landmark_pytorch/HEAD/figures/tmp_screenshot_18.08.20192.png
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/face_landmark_pytorch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/lib/utils/logger.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
2 |
3 |
4 |
5 |
6 |
7 |
8 | import logging
9 |
10 |
11 | def get_logger(LEVEL,log_file=None):
12 | head = '[%(asctime)-15s] [%(levelname)s] %(message)s '
13 | if LEVEL=='info':
14 | logging.basicConfig(level=logging.INFO, format=head)
15 | elif LEVEL=='debug':
16 | logging.basicConfig(level=logging.DEBUG, format=head)
17 | logger = logging.getLogger()
18 |
19 |     if log_file is not None:
20 |
21 | fh = logging.FileHandler(log_file)
22 | logger.addHandler(fh)
23 | return logger
24 |
25 | logger=get_logger('info')
26 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from lib.core.base_trainer.net_work import Train
4 | from lib.dataset.dataietr import AlaskaDataIter
5 | from torch.utils.data import Dataset, DataLoader
6 | import cv2
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from sklearn.model_selection import KFold
11 |
12 | from train_config import config as cfg
13 |
14 | import setproctitle
15 |
16 |
17 | setproctitle.setproctitle("pks")
18 |
19 |
20 |
21 |
22 |
23 | def get_fold(df,n_folds):
24 |
25 | skf = KFold(n_splits=n_folds, shuffle=True, random_state=cfg.SEED)
26 | for fold, (train_idx, val_idx) in enumerate(skf.split(df)):
27 | df.loc[val_idx, 'fold'] = fold
28 |
29 | return df
30 |
31 |
32 | def setppm100asval(df):
33 | def func(fn):
34 | if 'PPM' in fn:
35 | return 0
36 | else:
37 | return 1
38 | df['fold']=df['image'].apply(func)
39 |
40 |
41 | return df
42 |
43 |
44 | def main():
45 |
46 |
47 |
48 |
49 |
50 | train_df = pd.read_csv(cfg.DATA.train_f_path)
51 |
52 | val_df =pd.read_csv(cfg.DATA.val_f_path)
53 |
54 |
55 | ###build trainer
56 | trainer = Train(train_df=train_df,val_df=val_df,fold=0)
57 |
58 | ### train
59 | trainer.custom_loop()
60 |
61 | if __name__=='__main__':
62 | main()
63 |
64 |
--------------------------------------------------------------------------------
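
For reference, a minimal sketch of how the `get_fold` helper above could be used to build fold splits; `main()` does not call it and instead reads pre-split CSVs. The seed value here is an assumption, since the `train_config.py` in this repo does not define `cfg.SEED`.

```
import pandas as pd
from sklearn.model_selection import KFold

SEED = 42  # assumption: cfg.SEED is not set in train_config.py
df = pd.read_csv('train.csv')

# same logic as get_fold(df, n_folds=5) above
kf = KFold(n_splits=5, shuffle=True, random_state=SEED)
for fold, (train_idx, val_idx) in enumerate(kf.split(df)):
    df.loc[val_idx, 'fold'] = fold

train_df = df[df['fold'] != 0]
val_df = df[df['fold'] == 0]
```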
/lib/utils/init.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import tensorflow as tf
4 | def init(*args):
5 |
6 | if len(args)==1:
7 | use_pb=True
8 | pb_path=args[0]
9 | else:
10 | use_pb=False
11 | meta_path=args[0]
12 | restore_model_path=args[1]
13 |
14 | def ini_ckpt():
15 | graph = tf.Graph()
16 | graph.as_default()
17 | configProto = tf.ConfigProto()
18 | configProto.gpu_options.allow_growth = True
19 | sess = tf.Session(config=configProto)
20 | #load_model(model_path, sess)
21 | saver = tf.train.import_meta_graph(meta_path)
22 | saver.restore(sess, restore_model_path)
23 |
24 |         print("Model restored!")
25 | return (graph, sess)
26 | def init_pb(model_path):
27 | config = tf.ConfigProto()
28 | config.gpu_options.per_process_gpu_memory_fraction = 0.2
29 | compute_graph = tf.Graph()
30 | compute_graph.as_default()
31 | sess = tf.Session(config=config)
32 | with tf.gfile.GFile(model_path, 'rb') as fid:
33 | graph_def = tf.GraphDef()
34 | graph_def.ParseFromString(fid.read())
35 | tf.import_graph_def(graph_def, name='')
36 |
37 | # saver = tf.train.Saver(tf.global_variables())
38 | # saver.save(sess, save_path='./tmp.ckpt')
39 | return (compute_graph, sess)
40 |
41 | if use_pb:
42 | model = init_pb(pb_path)
43 | else:
44 | model = ini_ckpt()
45 |
46 | graph = model[0]
47 | sess = model[1]
48 |
49 | return graph,sess
--------------------------------------------------------------------------------
/lib/core/base_trainer/metric.py:
--------------------------------------------------------------------------------
1 | import sklearn
2 | from sklearn import metrics
3 | from sklearn.metrics import confusion_matrix
4 |
5 | import numpy as np
6 | import torch.nn as nn
7 | from train_config import config as cfg
8 |
9 | import torch
10 |
11 |
12 | from sklearn.metrics import roc_auc_score
13 | from lib.utils.logger import logger
14 |
15 | import warnings
16 |
17 | warnings.filterwarnings('ignore')
18 |
19 |
20 | class AverageMeter(object):
21 | """Computes and stores the average and current value"""
22 |
23 | def __init__(self):
24 | self.reset()
25 |
26 | def reset(self):
27 | self.val = 0
28 | self.avg = 0
29 | self.sum = 0
30 | self.count = 0
31 |
32 | def update(self, val, n=1):
33 | self.val = val
34 | self.sum += val * n
35 | self.count += n
36 | self.avg = self.sum / self.count
37 |
38 |
39 | class ROCAUCMeter(object):
40 | def __init__(self):
41 | self.reset()
42 |
43 | def reset(self):
44 |
45 | self.y_true_11 = None
46 | self.y_pred_11 = None
47 |
48 | def update(self, y_true, y_pred):
49 | y_true = y_true.cpu().numpy()
50 |
51 | y_pred = torch.sigmoid(y_pred).data.cpu().numpy()
52 |
53 | if self.y_true_11 is None:
54 | self.y_true_11 = y_true
55 | self.y_pred_11 = y_pred
56 | else:
57 | self.y_true_11 = np.concatenate((self.y_true_11, y_true), axis=0)
58 | self.y_pred_11 = np.concatenate((self.y_pred_11, y_pred), axis=0)
59 |
60 | @property
61 | def avg(self):
62 |
63 | aucs = []
64 | for i in range(11):
65 | aucs.append(roc_auc_score(self.y_true_11[:, i], self.y_pred_11[:, i]))
66 | print(np.round(aucs, 4))
67 |
68 | return np.mean(aucs)
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
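
A small usage sketch for `AverageMeter` above; the class name and behavior are taken directly from this file.

```
from lib.core.base_trainer.metric import AverageMeter

meter = AverageMeter()
meter.update(0.8, n=32)   # mean loss 0.8 over a batch of 32 samples
meter.update(0.6, n=32)
print(meter.avg)          # 0.7, the sample-weighted running mean
```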
/lib/core/model/face_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | from torch.nn import Parameter
7 |
8 |
9 | from train_config import config as cfg
10 |
11 |
12 | import timm
13 |
14 |
15 | class Net(nn.Module):
16 | def __init__(self, num_classes=1000):
17 | super().__init__()
18 |
19 |
20 | self.model = timm.create_model('tf_mobilenetv3_large_minimal_100', pretrained=True,exportable=True)
21 |
22 |
23 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
24 |
25 | self._fc = nn.Linear(1280 , num_classes, bias=True)
26 |
27 | def forward(self, inputs):
28 |
29 | #do preprocess
30 |
31 |
32 | bs = inputs.size(0)
33 | # Convolution layers
34 | x = self.model.forward_features(inputs)
35 |         # global-average-pool the backbone feature map before the fully connected layer
36 |         # (note: _fc.in_features must match the channel count of forward_features for the chosen backbone)
37 |         fm = self._avg_pooling(x).view(bs, -1)
38 |         x = self._fc(fm)
39 | return x
40 |
41 |
42 |
43 | if __name__=='__main__':
44 | import torch
45 | import torchvision
46 |
47 | dummy_input = torch.randn(1, 3, 224, 224, device='cpu')
48 | model = Net()
49 |
50 | ### load your weights
51 | model.eval()
52 | # Providing input and output names sets the display names for values
53 | # within the model's graph. Setting these does not change the semantics
54 | # of the graph; it is only for readability.
55 | #
56 | # The inputs to the network consist of the flat list of inputs (i.e.
57 | # the values you would pass to the forward() method) followed by the
58 | # flat list of parameters. You can partially specify names, i.e. provide
59 | # a list here shorter than the number of inputs to the model, and we will
60 | # only set that subset of names, starting from the beginning.
61 |
62 |
63 | torch.onnx.export(model, dummy_input, "classifier.onnx",opset_version=11 )
64 |
--------------------------------------------------------------------------------
/tools/convert_to_onnx.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 |
4 |
5 | from lib.core.base_trainer.model import COTRAIN
6 |
7 |
8 | import torch
9 | import torchvision
10 | import argparse
11 | import re
12 |
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--model', type=str, default=None, help='path to the trained .pth weights')
17 | args = parser.parse_args()
18 |
19 | model_path=args.model
20 |
21 | device=torch.device('cpu')
22 | dummy_input = torch.randn(1, 3,128, 128 , device='cpu')
23 |
24 | style_model = COTRAIN(inference=True).to(device)
25 | style_model.eval()
26 | if model_path is not None:
27 |
28 | state_dict = torch.load(model_path,map_location=device)
29 | # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
30 |
31 | style_model.load_state_dict(state_dict,strict=False)
32 | style_model.to(device)
33 |
34 |
35 |
36 |
37 | ### load your weights
38 | style_model.eval()
39 | # Providing input and output names sets the display names for values
40 | # within the model's graph. Setting these does not change the semantics
41 | # of the graph; it is only for readability.
42 | #
43 | # The inputs to the network consist of the flat list of inputs (i.e.
44 | # the values you would pass to the forward() method) followed by the
45 | # flat list of parameters. You can partially specify names, i.e. provide
46 | # a list here shorter than the number of inputs to the model, and we will
47 | # only set that subset of names, starting from the beginning.
48 |
49 |
50 | torch.onnx.export(style_model,
51 | dummy_input,
52 | "face_kps.onnx" ,
53 | # dynamic_axes={'input': {2: 'height',
54 | # 3: 'width',
55 | # }},
56 | input_names=["input"],
57 | output_names=["output"],opset_version=12)
--------------------------------------------------------------------------------
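
A quick way to sanity-check the exported graph, as a sketch that assumes `onnxruntime` is installed and `face_kps.onnx` was produced by the script above.

```
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('face_kps.onnx')
dummy = np.random.randn(1, 3, 128, 128).astype(np.float32)  # matches the export's dummy_input
out = sess.run(['output'], {'input': dummy})[0]              # names set via input_names/output_names above
print(out.shape)
```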
/README.md:
--------------------------------------------------------------------------------
1 | ## Note: This code is not maintained. Please refer to Peppa_Pig_Face_Landmark (the TRAIN subdir); there is a better model there.
2 | # face_landmark
3 | A simple face alignment method, based on PyTorch
4 |
5 |
6 | ## introduction
7 | This is the PyTorch branch.
8 |
9 | It is simple and flexible: the model is trained with wing loss and multi-task learning, with data augmentation based on head pose and face attributes (eye state and mouth state).
10 |
11 | [CN blog](https://blog.csdn.net/qq_35606924/article/details/99711208)
12 |
13 | I also suggest trying another project that includes face detection and keypoints, with some extra optimizations: **[[pappa_pig_face_engine]](https://github.com/610265158/Peppa_Pig_Face_Engine).**
14 |
15 | Contact me if you have any problems with it. 2120140200@mail.nankai.edu.cn :)
16 |
17 | demo pictures:
18 |
19 | 
20 |
21 | 
22 |
23 | (this gif is from github.com/610265158/Peppa_Pig_Face_Engine)
24 |
25 | pretrained model:
26 |
27 | ###### shufflenetv2_1.0
28 | + [baidu disk](https://pan.baidu.com/s/1MK3wI0nrZUOA8yU0ChWvBw) (code 9x2m)
29 |
30 |
31 |
32 | ## requirements
33 |
34 | + pytorch
35 |
36 | + tensorpack (for data provider)
37 |
38 | + opencv
39 |
40 | + python 3.6
41 |
42 |
43 | ## usage
44 |
45 | ### train
46 |
47 | 1. download the full [300W](https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/) dataset, including [300VW](https://ibug.doc.ic.ac.uk/resources/300-VW/) (parse the videos as images, and make the labels the same format as 300W)
48 | ```
49 | ├── 300VW
50 | │ ├── 001_annot
51 | │ ├── 002_annot
52 | │ ....
53 | ├── 300W
54 | │ ├── 01_Indoor
55 | │ └── 02_Outdoor
56 | ├── AFW
57 | │ └── afw
58 | ├── HELEN
59 | │ ├── testset
60 | │ └── trainset
61 | ├── IBUG
62 | │ └── ibug
63 | ├── LFPW
64 | │ ├── testset
65 | │ └── trainset
66 | ```
67 |
68 | 2. run `python make_json.py` to produce train.csv and val.csv
69 | (if you want to train on your own data, read the files it produces; the format is quite simple)
70 |
71 | 3. then, run: `python train.py`
72 |
73 | 4. by default it is trained with shufflenetv2_1.0
74 |
75 |
76 | ### visualization
77 |
78 | ```
79 | python vis.py --model ./model/keypoints.pth
80 | ```
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
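
For reference, a minimal sketch of reading the CSV files written by `make_json.py`; the column names and string parsing follow `make_json.py` and `lib/dataset/dataietr.py` below.

```
import numpy as np
import pandas as pd

df = pd.read_csv('train.csv')
row = df.iloc[0]

# the 'keypoint' column stores 136 floats as a stringified list: x1, y1, x2, y2, ...
values = [float(v) for v in row['keypoint'][1:-1].split(',')]
landmarks = np.array(values, dtype=np.float32).reshape(-1, 2)  # (68, 2) points in crop coordinates

print(row['image'], landmarks.shape)
```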
/vis.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from lib.dataset.dataietr import AlaskaDataIter
4 | from train_config import config
5 | from lib.core.base_trainer.model import COTRAIN
6 |
7 | import torch
8 | import time
9 | import argparse
10 |
11 | from torch.utils.data import DataLoader
12 | import numpy as np
13 | import os
14 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
15 |
16 | import cv2
17 | from train_config import config as cfg
18 | cfg.TRAIN.batch_size=1
19 |
20 |
21 | df=pd.read_csv(cfg.DATA.val_f_path)
22 | val_genererator = AlaskaDataIter(df,
23 | img_root=cfg.DATA.root_path,
24 | training_flag=False, shuffle=False)
25 | val_ds=DataLoader(val_genererator,
26 | cfg.TRAIN.batch_size,
27 | num_workers=1,shuffle=False)
28 |
29 |
30 | def vis(weight):
31 | device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
32 | model=COTRAIN(inference=True)
33 | model.eval()
34 | state_dict = torch.load(weight, map_location=device)
35 | model.load_state_dict(state_dict, strict=False)
36 |
37 |
38 |
39 | for step, (ids, images, kps) in enumerate(val_ds):
40 |
41 |
42 | # kps = kps.to(device).float()
43 |
44 | img_show = np.array(images)*255
45 | print(img_show.shape)
46 |
47 | img_show=np.transpose(img_show[0],axes=[1,2,0]).astype(np.uint8)
48 | img_show=np.ascontiguousarray(img_show)
49 | images=images.to(device)
50 | print(images.size())
51 |
52 | start=time.time()
53 | with torch.no_grad():
54 | res=model(images)
55 | res=res.detach().numpy()
56 | print(res)
57 |         print('inference time:', time.time()-start)
58 | #print(res)
59 |
60 |
61 |
62 |
63 | landmark = np.array(res[0][0:136]).reshape([-1, 2])
64 |
65 | for _index in range(landmark.shape[0]):
66 | x_y = landmark[_index]
67 | #print(x_y)
68 | cv2.circle(img_show, center=(int(x_y[0] * 128),
69 | int(x_y[1] * 128)),
70 | color=(255, 122, 122), radius=1, thickness=2)
71 |
72 | cv2.imshow('tmp',img_show)
73 | cv2.waitKey(0)
74 |
75 |
76 | def load_checkpoint(net, checkpoint):
77 | # from collections import OrderedDict
78 | #
79 | # temp = OrderedDict()
80 | # if 'state_dict' in checkpoint:
81 | # checkpoint = dict(checkpoint['state_dict'])
82 | # for k in checkpoint:
83 | # k2 = 'module.'+k if not k.startswith('module.') else k
84 | # temp[k2] = checkpoint[k]
85 |
86 | net.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu')), strict=True)
87 | if __name__=='__main__':
88 | parser = argparse.ArgumentParser(description='Start train.')
89 |
90 | parser.add_argument('--model', dest='model', type=str, default=None, \
91 | help='the model to use')
92 |
93 | args = parser.parse_args()
94 |
95 |
96 | vis(args.model)
97 |
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
/train_config.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import numpy as np
5 | from easydict import EasyDict as edict
6 |
7 | config = edict()
8 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
9 |
10 | config.TRAIN = edict()
11 | #### below are params for dataiter
12 | config.TRAIN.process_num = 4
13 | config.TRAIN.prefetch_size = 30
14 | ############
15 |
16 | config.TRAIN.num_gpu = 1
17 | config.TRAIN.batch_size = 128
18 | config.TRAIN.validatiojn_batch_size = 128
19 | config.TRAIN.accumulation_batch_size=128
20 | config.TRAIN.log_interval = 10 ##10 iters for a log msg
21 | config.TRAIN.epoch = 50
22 | config.TRAIN.early_stop=20
23 | config.TRAIN.test_interval=1
24 |
25 | config.TRAIN.init_lr = 1.e-3
26 | config.TRAIN.warmup_step=1500
27 | config.TRAIN.weight_decay_factor = 5.e-4 ####l2
28 | config.TRAIN.vis=False #### whether to visualize the training data
29 | config.TRAIN.mix_precision=False ## use mixed precision to speed up training
30 | config.TRAIN.opt='AdamW' ## AdamW, Adam or SGD
31 | config.TRAIN.gradient_clip=5
32 |
33 | config.MODEL = edict()
34 | config.MODEL.model_path = './models/' ## save directory
35 | config.MODEL.hin = 128 # input height during training, e.g. 128 or 160, depending on the model
36 | config.MODEL.win = 128 # input width during training
37 |
38 | config.MODEL.out_channel=136+3+4 # output vector: 68 points (136 values), 3 head pose, 4 cls params (left eye, right eye, mouth, big mouth open)
39 |
40 | config.MODEL.pretrained_model='./models/fold0_epoch_8_val_loss_12.163321_val_mse_0.249408.pth'
41 | config.DATA = edict()
42 |
43 | config.DATA.root_path=''
44 | config.DATA.train_f_path='train.csv'
45 | config.DATA.val_f_path='val.csv'
46 |
47 | ############the model is trained with RGB mode
48 | config.DATA.root_path='../tmp_crop_data_face_landmark_pytorch'
49 |
50 | config.DATA.base_extend_range=[0.1,0.2] ### extend range for the face bbox crop
51 | config.DATA.scale_factor=[0.7,1.35] ### scale augmentation range
52 |
53 | config.DATA.symmetry = [(0, 16), (1, 15), (2, 14), (3, 13), (4, 12), (5, 11), (6, 10), (7, 9), (8, 8),
54 | (17, 26), (18, 25), (19, 24), (20, 23), (21, 22),
55 | (31, 35), (32, 34),
56 | (36, 45), (37, 44), (38, 43), (39, 42), (40, 47), (41, 46),
57 | (48, 54), (49, 53), (50, 52), (55, 59), (56, 58), (60, 64), (61, 63), (65, 67)]
58 |
59 |
60 |
61 |
62 |
63 | weights=[1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1., ##### face contour
64 |          1.5,1.5,1.5,1.5,1.5, 1.5,1.5,1.5,1.5,1.5, ##### eyebrows
65 |          1.,1.,1.,1.,1.,1.,1.,1.,1., ##### nose
66 |          1.5,1.5,1.5,1.5,1.5,1.5, 1.5,1.5,1.5,1.5,1.5,1.5, ##### eyes
67 |          1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1. ##### mouth
68 |          ]
69 | weights_xy=[[x,x] for x in weights]
70 |
71 | config.DATA.weights = np.array(weights_xy,dtype=np.float32).reshape([-1])
72 |
73 |
74 |
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
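
For reference, a sketch of how the 143-dimensional output vector (`config.MODEL.out_channel = 136 + 3 + 4`) is sliced downstream; the index ranges follow `lib/core/model/loss/simpleface_loss.py`.

```
import torch

pred = torch.randn(1, 136 + 3 + 4)             # one model output

landmarks = pred[:, 0:136].reshape(-1, 68, 2)  # 68 (x, y) points, normalized to the input size
headpose = pred[:, 136:139]                    # head pose regression
attrs = torch.sigmoid(pred[:, 139:143])        # left eye, right eye, mouth, big mouth open
```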
/lib/core/api/keypoint.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
2 | import tensorflow as tf
3 | import cv2
4 | import numpy as np
5 | import time
6 |
7 | from lib.utils.init import init
8 | from train_config import config
9 |
10 |
11 | class Keypoints:
12 |
13 | def __init__(self,pb):
14 | self.PADDING_FLAG=True
15 | self.ANCHOR_SIZE=0
16 | self._pb_path=pb
17 |
18 | self._graph = tf.Graph()
19 | self._sess = tf.Session(graph=self._graph)
20 |
21 | with self._graph.as_default():
22 |
23 | self._graph, self._sess = init(self._pb_path)
24 | self.img_input = tf.get_default_graph().get_tensor_by_name('tower_0/images:0')
25 | self.embeddings_keypoints = tf.get_default_graph().get_tensor_by_name('tower_0/prediction:0')
26 | self.training = tf.get_default_graph().get_tensor_by_name('training_flag:0')
27 |
28 |
29 | def simple_run(self,img):
30 |
31 | with self._graph.as_default():
32 | img=np.expand_dims(img,axis=0)
33 |
34 | star=time.time()
35 | for i in range(10):
36 | [landmarkers] = self._sess.run([self.embeddings_keypoints], \
37 | feed_dict={self.img_input: img, \
38 | self.training: False})
39 | print((time.time()-star)/10)
40 | return landmarkers
41 | def run(self,_img,_bboxs):
42 | img_batch = []
43 | coordinate_correct_batch = []
44 | pad_batch = []
45 |
46 | res_landmarks_batch=[]
47 | res_state_batch=[]
48 | with self._sess.as_default():
49 | with self._graph.as_default():
50 | for _box in _bboxs:
51 | crop_img=self._one_shot_preprocess(_img,_box)
52 |
53 | landmarkers = self._sess.run([self.embeddings_keypoints], \
54 | feed_dict={self.img_input: crop_img, \
55 | self.training: False})
56 |
57 |
58 |
59 |
60 | def _one_shot_preprocess(self,image,bbox):
61 |
62 |
63 | add = int(np.max(bbox))
64 | bimg = cv2.copyMakeBorder(image, add, add, add, add, borderType=cv2.BORDER_CONSTANT,
65 | value=config.DATA.pixel_means.tolist())
66 |
67 | bbox += add
68 |
69 |
70 | bbox_width = bbox[2] - bbox[0]
71 | bbox_height = bbox[3] - bbox[1]
72 |
73 | ###extend
74 | bbox[0] -= config.DATA.base_extend_range[0] * bbox_width
75 | bbox[1] -= config.DATA.base_extend_range[1] * bbox_height
76 | bbox[2] += config.DATA.base_extend_range[0] * bbox_width
77 | bbox[3] += config.DATA.base_extend_range[1] * bbox_height
78 |
79 |
80 | ##
81 |         bbox = bbox.astype(np.int32)
82 | crop_image = bimg[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
83 | crop_image = cv2.resize(crop_image, (config.MODEL.hin, config.MODEL.win),
84 | interpolation=cv2.INTER_AREA)
85 |
86 |
87 | if config.MODEL.channel==1:
88 | crop_image=cv2.cvtColor(crop_image,cv2.COLOR_RGB2GRAY)
89 | crop_image=np.expand_dims(crop_image,axis=-1)
90 |
91 |
92 | return crop_image
--------------------------------------------------------------------------------
/lib/dataset/headpose.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
2 |
3 |
4 | import cv2
5 | import numpy as np
6 |
7 |
8 |
9 |
10 | # object_pts = np.float32([[6.825897, 6.760612, 4.402142],
11 | # [1.330353, 7.122144, 6.903745],
12 | # [-1.330353, 7.122144, 6.903745],
13 | # [-6.825897, 6.760612, 4.402142],
14 | # [5.311432, 5.485328, 3.987654],
15 | # [1.789930, 5.393625, 4.413414],
16 | # [-1.789930, 5.393625, 4.413414],
17 | # [-5.311432, 5.485328, 3.987654],
18 | # [2.005628, 1.409845, 6.165652],
19 | # [-2.005628, 1.409845, 6.165652],
20 | # [2.774015, -2.080775, 5.048531],
21 | # [-2.774015, -2.080775, 5.048531],
22 | # [0.000000, -3.116408, 6.097667],
23 | # [0.000000, -7.415691, 4.070434]])
24 | object_pts = np.float32([[6.825897, 6.760612, 4.402142],
25 | [1.330353, 7.122144, 6.903745],
26 | [-1.330353, 7.122144, 6.903745],
27 | [-6.825897, 6.760612, 4.402142],
28 | [5.311432, 5.485328, 3.987654],
29 | [1.789930, 5.393625, 4.413414],
30 | [-1.789930, 5.393625, 4.413414],
31 | [-5.311432, 5.485328, 3.987654],
32 | [2.005628, 1.409845, 6.165652],
33 | [-2.005628, 1.409845, 6.165652]])
34 | reprojectsrc = np.float32([[10.0, 10.0, 10.0],
35 | [10.0, 10.0, -10.0],
36 | [10.0, -10.0, -10.0],
37 | [10.0, -10.0, 10.0],
38 | [-10.0, 10.0, 10.0],
39 | [-10.0, 10.0, -10.0],
40 | [-10.0, -10.0, -10.0],
41 | [-10.0, -10.0, 10.0]])
42 |
43 | line_pairs = [[0, 1], [1, 2], [2, 3], [3, 0],
44 | [4, 5], [5, 6], [6, 7], [7, 4],
45 | [0, 4], [1, 5], [2, 6], [3, 7]]
46 |
47 |
48 | def get_head_pose(shape,img):
49 | h,w,_=img.shape
50 | K = [w, 0.0, w//2,
51 | 0.0, w, h//2,
52 | 0.0, 0.0, 1.0]
53 | # Assuming no lens distortion
54 | D = [0, 0, 0.0, 0.0, 0]
55 |
56 | cam_matrix = np.array(K).reshape(3, 3).astype(np.float32)
57 | dist_coeffs = np.array(D).reshape(5, 1).astype(np.float32)
58 |
59 |
60 |
61 | # image_pts = np.float32([shape[17], shape[21], shape[22], shape[26], shape[36],
62 | # shape[39], shape[42], shape[45], shape[31], shape[35],
63 | # shape[48], shape[54], shape[57], shape[8]])
64 | image_pts = np.float32([shape[17], shape[21], shape[22], shape[26], shape[36],
65 | shape[39], shape[42], shape[45], shape[31], shape[35]])
66 | _, rotation_vec, translation_vec = cv2.solvePnP(object_pts, image_pts, cam_matrix, dist_coeffs)
67 |
68 | reprojectdst, _ = cv2.projectPoints(reprojectsrc, rotation_vec, translation_vec, cam_matrix,
69 | dist_coeffs)
70 |
71 | reprojectdst = tuple(map(tuple, reprojectdst.reshape(8, 2)))
72 |
73 | # calc euler angle
74 | rotation_mat, _ = cv2.Rodrigues(rotation_vec)
75 | pose_mat = cv2.hconcat((rotation_mat, translation_vec))
76 | _, _, _, _, _, _, euler_angle = cv2.decomposeProjectionMatrix(pose_mat)
77 |
78 | return reprojectdst, euler_angle
79 |
--------------------------------------------------------------------------------
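
A usage sketch for `get_head_pose` above: `shape` is a (68, 2) array of landmark pixel coordinates in `img`, and the returned `euler_angle` holds pitch, yaw and roll. The file paths here are placeholders.

```
import cv2
import numpy as np
from lib.dataset.headpose import get_head_pose

img = cv2.imread('face.jpg')              # placeholder image path
shape = np.load('landmarks_68x2.npy')     # placeholder: matching 68-point landmarks
reprojectdst, euler_angle = get_head_pose(shape, img)
pitch, yaw, roll = euler_angle[:, 0]
print(pitch, yaw, roll)
```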
/tools/convert_to_coreml.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 |
4 | from lib.core.base_trainer.model import Net,COTRAIN
5 |
6 | import argparse
7 |
8 | import re
9 |
10 | parser = argparse.ArgumentParser()
11 |
12 | parser.add_argument('--model', type=str, default=None, help='path to the trained .pth weights')
13 | args = parser.parse_args()
14 |
15 | model_path=args.model
16 |
17 | import warnings
18 | warnings.simplefilter(action='ignore', category=FutureWarning)
19 |
20 | import torch
21 |
22 | import coremltools as ct
23 | from coremltools.proto import FeatureTypes_pb2 as ft
24 | device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
25 | dummy_input = torch.randn(1, 3,288, 160, device='cpu')
26 | style_model = COTRAIN(inference=True).to(device)
27 | ### load your weights
28 | style_model.eval()
29 |
30 |
31 |
32 | if model_path is not None:
33 | state_dict = torch.load(model_path,map_location=device)
34 | # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
35 |
36 | style_model.load_state_dict(state_dict)
37 | style_model.to(device)
38 | style_model.eval()
39 | trace = torch.jit.trace(style_model, dummy_input)
40 |
41 | # shapes = [(1,3, 128*i, 128*j) for i in range(1, 10) for j in range(1,10)]
42 | # input_shape = ct.EnumeratedShapes(shapes=shapes)
43 |
44 | # input_shape = ct.Shape(shape=(1, 3,ct.RangeDim(256, 256*10), ct.RangeDim(256, 256*10)))
45 | # Convert the model
46 | mlmodel = ct.convert(
47 | trace,
48 | inputs=[ct.ImageType(name="__input", scale=1/255.,
49 | shape=dummy_input.shape)],
50 | minimum_deployment_target=ct.target.iOS14
51 |
52 | )
53 |
54 |
55 |
56 |
57 | spec = mlmodel.get_spec()
58 |
59 | # Edit the spec
60 | ct.utils.rename_feature(spec, '__input', 'image')
61 | ct.utils.rename_feature(spec, 'var_917', 'output')
62 |
63 |
64 | # to update the shapes of the first input
65 | # input_name = spec.description.input[0].name
66 | # img_size_ranges = ct.models.neural_network.flexible_shape_utils.NeuralNetworkImageSizeRange()
67 | # img_size_ranges.add_height_range((256, -1))
68 | # img_size_ranges.add_width_range((256, -1))
69 |
70 | # ct.models.neural_network.flexible_shape_utils.update_image_size_range(spec,
71 | # feature_name=input_name,size_range=img_size_ranges)
72 |
73 | # save out the updated model
74 | mlmodel = ct.models.MLModel(spec)
75 | print(mlmodel)
76 | print(spec.description.output)
77 | # for output in spec.description.output:
78 | # # Change output type
79 | # output.type.imageType.colorSpace = ft.ImageFeatureType.ColorSpace.Value('RGB')
80 | # # channels, height, width = tuple(output.type.multiArrayType.shape)
81 | # #
82 | # # # Set image shape
83 | # output.type.imageType.width = 256
84 | # output.type.imageType.height = 256
85 |
86 | # output_name = spec.description.output[0].name
87 | # ct.models.neural_network.flexible_shape_utils.update_image_size_range(spec,
88 | # feature_name=output_name,size_range=img_size_ranges)
89 |
90 |
91 | mlmodel = ct.models.MLModel(spec)
92 | print(mlmodel)
93 |
94 |
95 | from coremltools.models.neural_network import quantization_utils
96 | from coremltools.models.neural_network.quantization_utils import AdvancedQuantizedLayerSelector
97 |
98 | selector = AdvancedQuantizedLayerSelector(
99 | skip_layer_types=['batchnorm','bias'],
100 | minimum_conv_kernel_channels=4,
101 | minimum_conv_weight_count=4096
102 | )
103 |
104 | model_fp8 = quantization_utils.quantize_weights(mlmodel, nbits=8,quantization_mode='linear',selector=selector)
105 |
106 |
107 |
108 | fp_8_file='./face_kps.mlmodel'
109 | model_fp8.save(fp_8_file)
--------------------------------------------------------------------------------
/lib/core/model/loss/simpleface_loss.py:
--------------------------------------------------------------------------------
1 | # -*-coding:utf-8-*-
2 | import sys
3 |
4 |
5 | import torch
6 | import math
7 | import numpy as np
8 |
9 | from train_config import config as cfg
10 |
11 | bce_losser=torch.nn.BCEWithLogitsLoss()
12 | def _wing_loss(landmarks, labels, w=10.0, epsilon=2.0, weights=1.):
13 | """
14 | Arguments:
15 | landmarks, labels: float tensors with shape [batch_size, landmarks]. landmarks means x1,x2,x3,x4...y1,y2,y3,y4 1-D
16 | w, epsilon: a float numbers.
17 | Returns:
18 | a float tensor with shape [].
19 | """
20 |
21 | x = landmarks - labels
22 | c = w * (1.0 - math.log(1.0 + w / epsilon))
23 | absolute_x = torch.abs(x)
24 | losses = torch.where(
25 | torch.gt(absolute_x, w), absolute_x - c,
26 | w * torch.log(1.0 + absolute_x / epsilon)
27 |
28 | )
29 |     losses = losses * torch.tensor(cfg.DATA.weights, device=landmarks.device)
30 | loss = torch.sum(torch.mean(losses * weights, dim=[0]))
31 |
32 | return loss
33 |
34 |
35 | def _mse(landmarks, labels, weights=1.):
36 | return torch.mean(0.5 * (landmarks - labels) *(landmarks - labels)* weights)
37 |
38 |
39 | def l1(landmarks, labels):
40 |     return torch.mean(torch.abs(landmarks - labels))
41 |
42 |
43 | def calculate_loss(predict_keypoints, label_keypoints):
44 |
45 | landmark_label = label_keypoints[:, 0:136]
46 | pose_label = label_keypoints[:, 136:139]
47 | leye_cls_label = label_keypoints[:, 139]
48 | reye_cls_label = label_keypoints[:, 140]
49 | mouth_cls_label = label_keypoints[:, 141]
50 | big_mouth_cls_label = label_keypoints[:, 142]
51 |
52 | landmark_predict = predict_keypoints[:, 0:136]
53 | pose_predict = predict_keypoints[:, 136:139]
54 | leye_cls_predict = predict_keypoints[:, 139]
55 | reye_cls_predict = predict_keypoints[:, 140]
56 | mouth_cls_predict = predict_keypoints[:, 141]
57 | big_mouth_cls_predict = predict_keypoints[:, 142]
58 |
59 | loss = _wing_loss(landmark_predict, landmark_label)
60 |
61 | loss_pose = _mse(pose_predict, pose_label)
62 |
63 | leye_loss = bce_losser(leye_cls_predict,leye_cls_label)
64 | reye_loss = bce_losser(reye_cls_predict,reye_cls_label)
65 |
66 | mouth_loss = bce_losser(mouth_cls_predict,mouth_cls_label)
67 | mouth_loss_big = bce_losser(big_mouth_cls_predict,big_mouth_cls_label)
68 | mouth_loss = mouth_loss + mouth_loss_big
69 |
70 | # ##make crosssentropy
71 | # leye_cla_correct_prediction = tf.equal(
72 | # tf.cast(tf.greater_equal(tf.nn.sigmoid(leye_cls_predict), 0.5), tf.int32),
73 | # tf.cast(leye_cla_label, tf.int32))
74 | # leye_cla_accuracy = tf.reduce_mean(tf.cast(leye_cla_correct_prediction, tf.float32))
75 | #
76 | # reye_cla_correct_prediction = tf.equal(
77 | # tf.cast(tf.greater_equal(tf.nn.sigmoid(reye_cla_predict), 0.5), tf.int32),
78 | # tf.cast(reye_cla_label, tf.int32))
79 | # reye_cla_accuracy = tf.reduce_mean(tf.cast(reye_cla_correct_prediction, tf.float32))
80 | # mouth_cla_correct_prediction = tf.equal(
81 | # tf.cast(tf.greater_equal(tf.nn.sigmoid(mouth_cla_predict), 0.5), tf.int32),
82 | # tf.cast(mouth_cla_label, tf.int32))
83 | # mouth_cla_accuracy = tf.reduce_mean(tf.cast(mouth_cla_correct_prediction, tf.float32))
84 |
85 | #### l2 regularization_losses
86 | # l2_loss = []
87 | # l2_reg = tf.keras.regularizers.l2(cfg.TRAIN.weight_decay_factor)
88 | # variables_restore = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
89 | # for var in variables_restore:
90 | # if 'weight' in var.name:
91 | # l2_loss.append(l2_reg(var))
92 | # regularization_losses = tf.add_n(l2_loss, name='l1_loss')
93 |
94 | return loss + loss_pose + leye_loss + reye_loss + mouth_loss
95 |
96 |
97 |
--------------------------------------------------------------------------------
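
For reference, the wing loss above is piecewise: `w*ln(1 + |x|/epsilon)` for small errors and `|x| - C` for large ones, with `C = w*(1 - ln(1 + w/epsilon))` chosen so the two pieces meet at `|x| = w`. A small numerical check:

```
import math

w, epsilon = 10.0, 2.0
C = w * (1.0 - math.log(1.0 + w / epsilon))

at_w_log = w * math.log(1.0 + w / epsilon)  # log branch evaluated at |x| = w
at_w_lin = w - C                            # linear branch evaluated at |x| = w
assert abs(at_w_log - at_w_lin) < 1e-9      # continuity at the switch point
```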
/lib/dataset/augmentor/visual_augmentation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 |
5 | def pixel_jitter(src,p=0.5,max_=5.):
6 |
7 | src=src.astype(np.float32)
8 |
9 | pattern=(np.random.rand(src.shape[0], src.shape[1],src.shape[2])-0.5)*2*max_
10 | img = src + pattern
11 |
12 | img[img<0]=0
13 | img[img >255] = 255
14 |
15 | img = img.astype(np.uint8)
16 |
17 | return img
18 |
19 | def gray(src):
20 | g_img=cv2.cvtColor(src,cv2.COLOR_RGB2GRAY)
21 | src[:,:,0]=g_img
22 | src[:,:,1]=g_img
23 | src[:,:,2]=g_img
24 | return src
25 |
26 | def swap_change(src):
27 | a = [0,1,2]
28 |
29 | k = random.sample(a, 3)
30 |
31 | res=src.copy()
32 | res[:,:,0]=src[:,:,k[0]]
33 | res[:, :, 1] = src[:, :, k[1]]
34 | res[:, :, 2] = src[:, :, k[2]]
35 | return res
36 |
37 |
38 | def Img_dropout(src,max_pattern_ratio=0.05):
39 | pattern=np.ones_like(src)
40 | width_ratio = random.uniform(0, max_pattern_ratio)
41 | height_ratio = random.uniform(0, max_pattern_ratio)
42 | width=src.shape[1]
43 | height=src.shape[0]
44 | block_width=width*width_ratio
45 | block_height=height*height_ratio
46 | width_start=int(random.uniform(0,width-block_width))
47 | width_end=int(width_start+block_width)
48 | height_start=int(random.uniform(0,height-block_height))
49 | height_end=int(height_start+block_height)
50 | pattern[height_start:height_end,width_start:width_end,:]=0
51 | img=src*pattern
52 | return img
53 |
54 |
55 |
56 | def blur_heatmap(src, ksize=(3, 3)):
57 | for i in range(src.shape[2]):
58 | src[:, :, i] = cv2.GaussianBlur(src[:, :, i], ksize, 0)
59 |         amin, amax = src[:, :, i].min(), src[:, :, i].max()  # min and max values
60 |         if amax>0:
61 |             src[:, :, i] = (src[:, :, i] - amin) / (amax - amin)  # normalize to [0, 1]: (x - min) / (max - min)
62 | return src
63 | def blur(src,ksize=(3,3)):
64 | for i in range(src.shape[2]):
65 | src[:, :, i]=cv2.GaussianBlur(src[:, :, i],ksize,1.5)
66 | return src
67 |
68 |
69 |
70 |
71 | def adjust_contrast(image, factor):
72 | """ Adjust contrast of an image.
73 | Args
74 | image: Image to adjust.
75 | factor: A factor for adjusting contrast.
76 | """
77 | mean = image.mean(axis=0).mean(axis=0)
78 | return _clip((image - mean) * factor + mean)
79 |
80 |
81 | def adjust_brightness(image, delta):
82 | """ Adjust brightness of an image
83 | Args
84 | image: Image to adjust.
85 | delta: Brightness offset between -1 and 1 added to the pixel values.
86 | """
87 | return _clip(image + delta * 255)
88 |
89 |
90 | def adjust_hue(image, delta):
91 | """ Adjust hue of an image.
92 | Args
93 | image: Image to adjust.
94 | delta: An interval between -1 and 1 for the amount added to the hue channel.
95 | The values are rotated if they exceed 180.
96 | """
97 | image[..., 0] = np.mod(image[..., 0] + delta * 180, 180)
98 | return image
99 |
100 |
101 | def adjust_saturation(image, factor):
102 | """ Adjust saturation of an image.
103 | Args
104 | image: Image to adjust.
105 | factor: An interval for the factor multiplying the saturation values of each pixel.
106 | """
107 | image[..., 1] = np.clip(image[..., 1] * factor, 0, 255)
108 | return image
109 |
110 |
111 | def _clip(image):
112 | """
113 | Clip and convert an image to np.uint8.
114 | Args
115 | image: Image to clip.
116 | """
117 | return np.clip(image, 0, 255).astype(np.uint8)
118 | def _uniform(val_range):
119 | """ Uniformly sample from the given range.
120 | Args
121 | val_range: A pair of lower and upper bound.
122 | """
123 | return np.random.uniform(val_range[0], val_range[1])
124 |
125 |
126 | class ColorDistort():
127 |
128 | def __init__(
129 | self,
130 | contrast_range=(0.8, 1.2),
131 | brightness_range=(-.2, .2),
132 | hue_range=(-0.1, 0.1),
133 | saturation_range=(0.8, 1.2)
134 | ):
135 | self.contrast_range = contrast_range
136 | self.brightness_range = brightness_range
137 | self.hue_range = hue_range
138 | self.saturation_range = saturation_range
139 |
140 | def __call__(self, image):
141 |
142 |
143 | if self.contrast_range is not None:
144 | contrast_factor = _uniform(self.contrast_range)
145 | image = adjust_contrast(image,contrast_factor)
146 | if self.brightness_range is not None:
147 | brightness_delta = _uniform(self.brightness_range)
148 | image = adjust_brightness(image, brightness_delta)
149 |
150 | if self.hue_range is not None or self.saturation_range is not None:
151 |
152 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
153 |
154 | if self.hue_range is not None:
155 | hue_delta = _uniform(self.hue_range)
156 | image = adjust_hue(image, hue_delta)
157 |
158 | if self.saturation_range is not None:
159 | saturation_factor = _uniform(self.saturation_range)
160 | image = adjust_saturation(image, saturation_factor)
161 |
162 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
163 |
164 | return image
165 |
166 |
167 |
168 |
169 | class DsfdVisualAug():
170 | pass
--------------------------------------------------------------------------------
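
A usage sketch for `ColorDistort` above; it expects a uint8 BGR image (OpenCV convention, since the conversions use COLOR_BGR2HSV) and returns a color-jittered copy. The random input is a stand-in for a real face crop.

```
import numpy as np
from lib.dataset.augmentor.visual_augmentation import ColorDistort

aug = ColorDistort()
image = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)  # stand-in for a real face crop
jittered = aug(image)
print(jittered.shape, jittered.dtype)
```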
/make_json.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import json
5 | import traceback
6 | import cv2
7 | import pandas as pd
8 |
9 | from tqdm import tqdm
10 | '''
11 | I decided to merge more data from CelebA; the annotations will be more complex, so json may be a better format.
12 | '''
13 |
14 |
15 |
16 |
17 |
18 | data_dir='/media/lz/ssd_2/coco_data/facelandmark/PUB' ######## point to your directory, 300W
19 | #celeba_data_dir='CELEBA' ######## point to your directory, CELEBA
20 |
21 |
22 | train_json='./train.csv'
23 | val_json='./val.csv'
24 | save_dir='../tmp_crop_data_face_landmark_pytorch'
25 |
26 | if not os.access(save_dir,os.F_OK):
27 | os.mkdir(save_dir)
28 |
29 | def GetFileList(dir, fileList):
30 | newDir = dir
31 | if os.path.isfile(dir):
32 | fileList.append(dir)
33 | elif os.path.isdir(dir):
34 | for s in os.listdir(dir):
35 |
36 | # if s == "pts":
37 | # continue
38 | newDir=os.path.join(dir,s)
39 | GetFileList(newDir, fileList)
40 | return fileList
41 |
42 |
43 |
44 |
45 | pic_list=[]
46 | GetFileList(data_dir,pic_list)
47 |
48 | pic_list=[x for x in pic_list if '.jpg' in x or 'png' in x or 'jpeg' in x ]
49 |
50 |
51 | ratio=0.95
52 | train_list=[x for x in pic_list if 'AFW' not in x]
53 | val_list=[x for x in pic_list if 'AFW' in x]
54 |
55 | # train_list=[x for x in pic_list if '300W/' not in x]
56 | # val_list=[x for x in pic_list if '300W/' in x]
57 |
58 |
59 | def process_data(data_list,csv_nm):
60 |
61 |
62 | global cnt
63 | image_list=[]
64 | keypoint_list=[]
65 |
66 | for pic in tqdm(data_list):
67 | one_image_ann={}
68 |
69 | ### image_path
70 | one_image_ann['image_path_raw']=pic
71 |
72 | #### keypoints
73 | pts=pic.rsplit('.',1)[0]+'.pts'
74 | if os.access(pic,os.F_OK) and os.access(pts,os.F_OK):
75 | try:
76 | tmp=[]
77 | with open(pts) as p_f:
78 | labels=p_f.readlines()[3:-1]
79 | for _one_p in labels:
80 | xy = _one_p.rstrip().split(' ')
81 | tmp.append([float(xy[0]),float(xy[1])])
82 |
83 | one_image_ann['keypoints'] = tmp
84 |
85 | label = np.array(tmp).reshape((-1, 2))
86 | bbox = [float(np.min(label[:, 0])), float(np.min(label[:, 1])), float(np.max(label[:, 0])), float(np.max(label[:, 1]))]
87 | one_image_ann['bbox'] = bbox
88 |
89 | ### placeholder
90 | one_image_ann['attr'] = None
91 |
92 | ###### crop it
93 |
94 | image=cv2.imread(one_image_ann['image_path_raw'],cv2.IMREAD_COLOR)
95 |
96 | h,w,c=image.shape
97 |
98 | ##expanded for
99 | bbox_int = [int(x) for x in bbox]
100 | bbox_width = bbox_int[2] - bbox_int[0]
101 | bbox_height = bbox_int[3] - bbox_int[1]
102 |
103 | center_x=(bbox_int[2] + bbox_int[0])//2
104 | center_y=(bbox_int[3] + bbox_int[1])//2
105 |
106 | x1=int(center_x-bbox_width*2)
107 | x1=x1 if x1>=0 else 0
108 |
109 | y1 = int(center_y - bbox_height*2)
110 | y1 = y1 if y1 >= 0 else 0
111 |
112 | x2 = int(center_x + bbox_width*2)
113 | x2 = x2 if x2 <= w else w
114 |
115 | y2 = int(center_y + bbox_height*2)
116 | y2 = y2 if y2 <= h else h
117 |
118 | ### crop the expanded face region
119 | crop_face = image[y1:y2, x1:x2, :]
120 |
121 | hh, ww, _ = crop_face.shape
122 |
123 | if max(hh, ww) > 512:
124 | scale=512/max(hh,ww)
125 | else:
126 | scale=1
127 | crop_face=cv2.resize(crop_face,None,fx=scale,fy=scale)
128 |
129 | one_image_ann['bbox'][0] *= scale
130 | one_image_ann['bbox'][1] *= scale
131 | one_image_ann['bbox'][2] *= scale
132 | one_image_ann['bbox'][3] *= scale
133 |
134 |
135 | x1*=scale
136 | y1 *= scale
137 | x2 *= scale
138 | y2 *= scale
139 | for i in range(len(one_image_ann['keypoints'])):
140 | one_image_ann['keypoints'][i][0]*= scale
141 | one_image_ann['keypoints'][i][1]*= scale
142 |
143 | fname= one_image_ann['image_path_raw'].split('PUB/')[-1]
144 |
145 | fname=fname.replace('/','_')
146 |
147 |
148 | cv2.imwrite(os.path.join(save_dir, fname), crop_face)  ### save the crop into save_dir, which config.DATA.root_path points to
149 |
150 |
151 | one_image_ann['bbox'][0] -= x1
152 | one_image_ann['bbox'][1] -= y1
153 | one_image_ann['bbox'][2] -= x1
154 | one_image_ann['bbox'][3] -= y1
155 |
156 | for i in range(len(one_image_ann['keypoints'])):
157 | one_image_ann['keypoints'][i][0]-=x1
158 | one_image_ann['keypoints'][i][1]-=y1
159 |
160 |
161 | keypoint=list(np.array(one_image_ann['keypoints']).reshape(-1).astype(np.float32))
162 | # [x1,y1,x2,y2]=[int(x) for x in one_image_ann['bbox']]
163 | #
164 | # cv2.rectangle(crop_face,(x1,y1),(x2,y2),thickness=2,color=(255,0,0))
165 | #
166 | # landmark=np.array(one_image_ann['keypoints'])
167 | #
168 | # for _index in range(landmark.shape[0]):
169 | # x_y = landmark[_index]
170 | # # print(x_y)
171 | # cv2.circle(crop_face, center=(int(x_y[0] ),
172 | # int(x_y[1] )),
173 | # color=(255, 0, 0), radius=2, thickness=4)
174 | #
175 | #
176 | # cv2.imshow('ss', crop_face)
177 | # cv2.waitKey(0)
178 |
179 | image_list.append(fname)
180 | keypoint_list.append(keypoint)
181 | # json_list.append(one_image_ann)
182 | except:
183 | print(pic)
184 |
185 | print(traceback.print_exc())
186 |
187 |
188 | # with open(json_nm, 'w') as f:
189 | # json.dump(json_list, f, indent=2)
190 |
191 | data_dict={'image':image_list,
192 | 'keypoint':keypoint_list}
193 | df=pd.DataFrame(data_dict)
194 |
195 | df.to_csv(csv_nm,index=False)
196 |
197 |
198 | process_data(train_list,train_json)
199 |
200 |
201 | process_data(val_list,val_json)
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/lib/dataset/dataietr.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import random
4 | import cv2
5 | import json
6 | import numpy as np
7 | import pandas as pd
8 | from tqdm import tqdm
9 |
10 | from lib.utils.logger import logger
11 |
12 |
13 | import traceback
14 |
15 | from train_config import config as cfg
16 | import albumentations as A
17 | import os
18 | import copy
19 | from lib.dataset.augmentor.augmentation import Rotate_aug,\
20 | Affine_aug,\
21 | Mirror,\
22 | Padding_aug
23 |
24 | from lib.dataset.headpose import get_head_pose
25 |
26 |
27 |
28 | cv2.setNumThreads(0)
29 | cv2.ocl.setUseOpenCL(False)
30 |
31 | class AlaskaDataIter():
32 | def __init__(self, df,img_root,
33 | training_flag=True,shuffle=True):
34 | self.eye_close_thres = 0.02
35 | self.mouth_close_thres = 0.02
36 | self.big_mouth_open_thres = 0.08
37 |
38 |
39 |
40 | self.training_flag = training_flag
41 | self.shuffle = shuffle
42 | self.img_root_path=img_root
43 |
44 |
45 | self.df=df
46 | if self.training_flag:
47 | self.balance()
48 |
49 | self.train_trans = A.Compose([
50 |
51 | A.RandomBrightnessContrast(p=0.3),
52 | A.HueSaturationValue(p=0.3),
53 | A.OneOf([A.GaussianBlur(),
54 | A.MotionBlur()],
55 | p=0.1),
56 | A.ToGray(p=0.1),
57 | A.OneOf([A.GaussNoise(),
58 | A.ISONoise()],
59 | p=0.1),
60 | ])
61 |
62 |
63 | self.val_trans=A.Compose([
64 |
65 | A.Resize(height=cfg.MODEL.hin,
66 | width=cfg.MODEL.win)
67 |
68 | ])
69 |
70 |
71 |
72 | def __getitem__(self, item):
73 |
74 | return self.single_map_func(self.df.iloc[item], self.training_flag)
75 |
76 | def __len__(self):
77 |
78 | return len(self.df)
79 |
80 | def balance(self, ):
81 | df = copy.deepcopy(self.df)
82 |
83 | expanded=[]
84 | lar_count = 0
85 | for i in tqdm(range(len(df))):
86 |
87 | ann=df.iloc[i]
88 |
89 | ### 300w balance, according to keypoints
90 |
91 | label = ann['keypoint'][1:-1].split(',')
92 | label = [float(x) for x in label]
93 | label = np.array(label).reshape([-1, 2])
94 |
95 | bbox = [float(np.min(label[:, 0])), float(np.min(label[:, 1])), float(np.max(label[:, 0])),
96 | float(np.max(label[:, 1]))]
97 |
98 | bbox_width = bbox[2] - bbox[0]
99 | bbox_height = bbox[3] - bbox[1]
100 |
101 | # if bbox_width < 50 or bbox_height < 50:
102 | # res_anns.remove(ann)
103 |
104 | left_eye_close = np.sqrt(
105 | np.square(label[37, 0] - label[41, 0]) +
106 | np.square(label[37, 1] - label[41, 1])) / bbox_height < self.eye_close_thres \
107 | or np.sqrt(np.square(label[38, 0] - label[40, 0]) +
108 | np.square(label[38, 1] - label[40, 1])) / bbox_height < self.eye_close_thres
109 | right_eye_close = np.sqrt(
110 | np.square(label[43, 0] - label[47, 0]) +
111 | np.square(label[43, 1] - label[47, 1])) / bbox_height < self.eye_close_thres \
112 | or np.sqrt(np.square(label[44, 0] - label[46, 0]) +
113 | np.square(
114 | label[44, 1] - label[46, 1])) / bbox_height < self.eye_close_thres
115 | if left_eye_close or right_eye_close:
116 | for i in range(5):
117 | expanded.append(ann)
118 | ###half face
119 | if np.sqrt(np.square(label[36, 0] - label[45, 0]) +
120 | np.square(label[36, 1] - label[45, 1])) / bbox_width < 0.5:
121 | for i in range(10):
122 | expanded.append(ann)
123 |
124 | if np.sqrt(np.square(label[62, 0] - label[66, 0]) +
125 | np.square(label[62, 1] - label[66, 1])) / bbox_height > 0.15:
126 | for i in range(10):
127 | expanded.append(ann)
128 |
129 | if np.sqrt(np.square(label[62, 0] - label[66, 0]) +
130 | np.square(label[62, 1] - label[66, 1])) / cfg.MODEL.hin > self.big_mouth_open_thres:
131 | for i in range(25):
132 | expanded.append(ann)
133 | ##########eyes diff aug
134 | if left_eye_close and not right_eye_close:
135 | for i in range(20):
136 | expanded.append(ann)
137 | lar_count += 1
138 | if not left_eye_close and right_eye_close:
139 | for i in range(20):
140 | expanded.append(ann)
141 | lar_count += 15
142 |
143 |         self.df = pd.concat([self.df, pd.DataFrame(expanded)], ignore_index=True)
144 |         logger.info('before balancing the dataset contains %d samples' % (len(df)))
145 |         logger.info('after balancing the dataset contains %d samples' % (len(self.df)))
146 |
147 | def augmentationCropImage(self, img, bbox, joints=None, is_training=True):
148 |
149 | bbox = np.array(bbox).reshape(4, ).astype(np.float32)
150 | add = max(img.shape[0], img.shape[1])
151 |
152 | bimg = cv2.copyMakeBorder(img, add, add, add, add, borderType=cv2.BORDER_CONSTANT)
153 |
154 | objcenter = np.array([(bbox[0] + bbox[2]) / 2., (bbox[1] + bbox[3]) / 2.])
155 | bbox += add
156 | objcenter += add
157 |
158 | joints[:, :2] += add
159 |
160 | gt_width = (bbox[2] - bbox[0])
161 | gt_height = (bbox[3] - bbox[1])
162 |
163 | crop_width_half = gt_width * (1 + cfg.DATA.base_extend_range[0] * 2) // 2
164 | crop_height_half = gt_height * (1 + cfg.DATA.base_extend_range[1] * 2) // 2
165 |
166 | if is_training:
167 | min_x = int(objcenter[0] - crop_width_half + \
168 | random.uniform(-cfg.DATA.base_extend_range[0], cfg.DATA.base_extend_range[0]) * gt_width)
169 | max_x = int(objcenter[0] + crop_width_half + \
170 | random.uniform(-cfg.DATA.base_extend_range[0], cfg.DATA.base_extend_range[0]) * gt_width)
171 | min_y = int(objcenter[1] - crop_height_half + \
172 | random.uniform(-cfg.DATA.base_extend_range[1], cfg.DATA.base_extend_range[1]) * gt_height)
173 | max_y = int(objcenter[1] + crop_height_half + \
174 | random.uniform(-cfg.DATA.base_extend_range[1], cfg.DATA.base_extend_range[1]) * gt_height)
175 | else:
176 | min_x = int(objcenter[0] - crop_width_half)
177 | max_x = int(objcenter[0] + crop_width_half)
178 | min_y = int(objcenter[1] - crop_height_half)
179 | max_y = int(objcenter[1] + crop_height_half)
180 |
181 | joints[:, 0] = joints[:, 0] - min_x
182 | joints[:, 1] = joints[:, 1] - min_y
183 |
184 | img = bimg[min_y:max_y, min_x:max_x, :]
185 |
186 | crop_image_height, crop_image_width, _ = img.shape
187 | joints[:, 0] = joints[:, 0] / crop_image_width
188 | joints[:, 1] = joints[:, 1] / crop_image_height
189 |
190 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST,
191 | cv2.INTER_LANCZOS4]
192 | interp_method = random.choice(interp_methods)
193 |
194 | img = cv2.resize(img, (cfg.MODEL.win, cfg.MODEL.hin), interpolation=interp_method)
195 |
196 | joints[:, 0] = joints[:, 0] * cfg.MODEL.win
197 | joints[:, 1] = joints[:, 1] * cfg.MODEL.hin
198 | return img, joints
199 | def single_map_func(self, dp, is_training):
200 | """Data augmentation function."""
201 |         #### custom parsing here
202 |
203 |
204 | fn=dp['image']
205 | image=cv2.imread(os.path.join(self.img_root_path,fn))
206 | kps=dp['keypoint'][1:-1].split(',')
207 | kps=[float(x) for x in kps]
208 | kps=np.array(kps).reshape([-1,2])
209 | bbox = [float(np.min(kps[:, 0])), float(np.min(kps[:, 1])), float(np.max(kps[:, 0])),
210 | float(np.max(kps[:, 1]))]
211 |
212 | bbox = np.array(bbox)
213 |
214 | ### random crop and resize
215 | crop_image, label = self.augmentationCropImage(image, bbox, kps, is_training)
216 |
217 | if is_training:
218 |
219 | if random.uniform(0, 1) > 0.5:
220 | crop_image, label = Mirror(crop_image, label=label, symmetry=cfg.DATA.symmetry)
221 | if random.uniform(0, 1) > 0.0:
222 | angle = random.uniform(-45, 45)
223 | crop_image, label = Rotate_aug(crop_image, label=label, angle=angle)
224 |
225 | if random.uniform(0, 1) > 0.5:
226 | strength = random.uniform(0, 50)
227 | crop_image, label = Affine_aug(crop_image, strength=strength, label=label)
228 |
229 | if random.uniform(0, 1) > 0.5:
230 | crop_image = Padding_aug(crop_image, 0.3)
231 |
232 | transformed = self.train_trans(image=crop_image)
233 | crop_image = transformed['image']
234 |
235 |
236 |
237 | #######head pose
238 | reprojectdst, euler_angle = get_head_pose(label, crop_image)
239 | PRY = euler_angle.reshape([-1]).astype(np.float32) / 90.
240 |
241 | ######cla_label
242 | cla_label = np.zeros([4])
243 | if np.sqrt(np.square(label[37, 0] - label[41, 0]) +
244 | np.square(label[37, 1] - label[41, 1])) / cfg.MODEL.hin < self.eye_close_thres \
245 | or np.sqrt(np.square(label[38, 0] - label[40, 0]) +
246 | np.square(label[38, 1] - label[40, 1])) / cfg.MODEL.hin < self.eye_close_thres:
247 | cla_label[0] = 1
248 | if np.sqrt(np.square(label[43, 0] - label[47, 0]) +
249 | np.square(label[43, 1] - label[47, 1])) / cfg.MODEL.hin < self.eye_close_thres \
250 | or np.sqrt(np.square(label[44, 0] - label[46, 0]) +
251 | np.square(label[44, 1] - label[46, 1])) / cfg.MODEL.hin < self.eye_close_thres:
252 | cla_label[1] = 1
253 | if np.sqrt(np.square(label[61, 0] - label[67, 0]) +
254 | np.square(label[61, 1] - label[67, 1])) / cfg.MODEL.hin < self.mouth_close_thres \
255 | or np.sqrt(np.square(label[62, 0] - label[66, 0]) +
256 | np.square(label[62, 1] - label[66, 1])) / cfg.MODEL.hin < self.mouth_close_thres \
257 | or np.sqrt(np.square(label[63, 0] - label[65, 0]) +
258 | np.square(label[63, 1] - label[65, 1])) / cfg.MODEL.hin < self.mouth_close_thres:
259 | cla_label[2] = 1
260 |
261 |         ### big mouth open: 1 means true
262 | if np.sqrt(np.square(label[62, 0] - label[66, 0]) +
263 | np.square(label[62, 1] - label[66, 1])) / cfg.MODEL.hin > self.big_mouth_open_thres:
264 | cla_label[3] = 1
265 |
266 | crop_image_height, crop_image_width, _ = crop_image.shape
267 |
268 | label = label.astype(np.float32)
269 |
270 | label[:, 0] = label[:, 0] / crop_image_width
271 | label[:, 1] = label[:, 1] / crop_image_height
272 | crop_image = crop_image.astype(np.uint8)
273 |
274 |
275 | crop_image = crop_image.astype(np.float32)
276 |
277 | crop_image = np.transpose(crop_image, axes=[2, 0, 1])
278 | crop_image/=255.
279 | label = label.reshape([-1]).astype(np.float32)
280 | cla_label = cla_label.astype(np.float32)
281 |
282 | label = np.concatenate([label, PRY, cla_label], axis=0)
283 |
284 |
285 | return fn,crop_image,label
286 |
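287 | # Illustrative usage sketch (not part of the original file). It assumes a csv
288 | # file with 'image' and 'keypoint' columns in the format parsed by
289 | # single_map_func above, and that cfg.DATA.root_path points at the image
290 | # folder; the file name 'train.csv' is only a placeholder.
291 | if __name__ == '__main__':
292 |     from torch.utils.data import DataLoader
293 | 
294 |     demo_df = pd.read_csv('train.csv')
295 |     dataset = AlaskaDataIter(demo_df, img_root=cfg.DATA.root_path, training_flag=True)
296 |     loader = DataLoader(dataset, batch_size=4, num_workers=0, shuffle=True)
297 |     fns, images, labels = next(iter(loader))
298 |     # images: [4, 3, hin, win] float in [0, 1]; labels: [4, 136 + 3 + 4]
299 |     print(images.shape, labels.shape)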
--------------------------------------------------------------------------------
/lib/core/base_trainer/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | from train_config import config as cfg
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import timm
11 |
12 | from torchvision.models.mobilenetv3 import InvertedResidual,InvertedResidualConfig
13 |
14 |
15 |
16 | bn_momentum=0.1
17 |
18 |
19 |
20 | def upsample_x_like_y(x,y):
21 | size = y.shape[-2:]
22 | x=F.interpolate(x, size=size, mode='bilinear')
23 |
24 | return x
25 |
26 | # from lib.core.base_trainer.mobileone import MobileOneBlock
27 | class SeparableConv2d(nn.Module):
28 | """ Separable Conv
29 | """
30 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding=0, bias=False,
31 | channel_multiplier=1., pw_kernel_size=1):
32 | super(SeparableConv2d, self).__init__()
33 |
34 |
35 | self.conv_dw = nn.Sequential(nn.Conv2d(
36 | int(in_channels*channel_multiplier), int(in_channels*channel_multiplier), kernel_size,
37 | stride=stride, dilation=dilation, padding=padding, groups=int(in_channels*channel_multiplier)),
38 |                                      nn.BatchNorm2d(int(in_channels*channel_multiplier),momentum=bn_momentum),
39 | nn.ReLU(inplace=True)
40 | )
41 | self.conv_pw = nn.Conv2d(
42 | int(in_channels*channel_multiplier), out_channels, pw_kernel_size, padding=0, bias=bias)
43 |
44 | @property
45 | def in_channels(self):
46 |         return self.conv_dw[0].in_channels
47 |
48 | @property
49 | def out_channels(self):
50 | return self.conv_pw.out_channels
51 |
52 | def forward(self, x):
53 |
54 | x = self.conv_dw(x)
55 | x = self.conv_pw(x)
56 | return x
57 |
58 |
59 |
60 | class ASPPPooling(nn.Module):
61 | def __init__(self, in_channels, out_channels):
62 | super(ASPPPooling, self).__init__()
63 | self.pool=nn.Sequential(
64 | nn.AdaptiveAvgPool2d(1),
65 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
66 | nn.BatchNorm2d(out_channels,momentum=bn_momentum),
67 | nn.ReLU())
68 | def forward(self, x):
69 |
70 | y=x
71 |
72 |
73 | x = self.pool(x)
74 |
75 | x= upsample_x_like_y(x,y)
76 | return x
77 |
78 |
79 | class ASPP(nn.Module):
80 | def __init__(self, in_channels, atrous_rates):
81 | super(ASPP, self).__init__()
82 | out_channels = 512
83 |
84 | rate1, rate2, rate3 = tuple(atrous_rates)
85 |
86 | self.fm_conx1=nn.Sequential(
87 | nn.Conv2d(in_channels, out_channels//4, 1, bias=False),
88 | nn.BatchNorm2d(out_channels//4,momentum=bn_momentum),
89 | nn.ReLU())
90 |
91 | self.fm_convx3_rate2=nn.Sequential(
92 | nn.Conv2d(in_channels, out_channels//4, kernel_size=3, padding=2, bias=False,dilation=rate1),
93 | nn.BatchNorm2d(out_channels//4,momentum=bn_momentum),
94 | nn.ReLU(inplace=True)
95 | )
96 |
97 | self.fm_convx3_rate4=nn.Sequential(
98 | nn.Conv2d(in_channels, out_channels//4, kernel_size=3, padding=4, bias=False,dilation=rate2),
99 | nn.BatchNorm2d(out_channels//4,momentum=bn_momentum),
100 | nn.ReLU(inplace=True)
101 | )
102 |
103 | self.fm_convx3_rate8=nn.Sequential(
104 | nn.Conv2d(in_channels, out_channels//4, kernel_size=3, padding=8, bias=False,dilation=rate3),
105 | nn.BatchNorm2d(out_channels//4,momentum=bn_momentum),
106 | nn.ReLU(inplace=True)
107 | )
108 |
109 | self.fm_pool=ASPPPooling(in_channels=in_channels,out_channels=out_channels//4)
110 |
111 | self.project = nn.Sequential(
112 | nn.Conv2d(out_channels//4*5, out_channels, 1, bias=False),
113 | nn.BatchNorm2d(out_channels,momentum=bn_momentum),
114 | nn.ReLU(inplace=True))
115 |
116 | def forward(self, x):
117 |
118 | fm1=self.fm_conx1(x)
119 | fm2=self.fm_convx3_rate2(x)
120 | fm4=self.fm_convx3_rate4(x)
121 | fm8=self.fm_convx3_rate8(x)
122 | fm_pool=self.fm_pool(x)
123 |
124 | res = torch.cat([fm1,fm2,fm4,fm8,fm_pool], dim=1)
125 |
126 | return self.project(res)
127 |
128 |
129 |
130 | class FeatureFuc(nn.Module):
131 | def __init__(self, inchannels=128):
132 | super(FeatureFuc, self).__init__()
133 |
134 |
135 | self.block1=InvertedResidual(cnf=InvertedResidualConfig( inchannels, 5, 256, inchannels, False, "RE", 1, 1, 1),
136 | norm_layer = partial(nn.BatchNorm2d, momentum=bn_momentum))
137 |
138 | self.block2=InvertedResidual(cnf=InvertedResidualConfig( inchannels, 5, 256, inchannels, False, "RE", 1, 1, 1),
139 | norm_layer = partial(nn.BatchNorm2d, momentum=bn_momentum))
140 |
141 |
142 | def forward(self, x):
143 |
144 | y1=self.block1(x)
145 |
146 | y2=self.block2(y1)
147 |
148 | return y2
149 |
150 | class SCSEModule(nn.Module):
151 | def __init__(self, in_channels, reduction=4):
152 | super().__init__()
153 | self.cSE = nn.Sequential(
154 | nn.AdaptiveAvgPool2d(1),
155 | nn.Conv2d(in_channels, in_channels // reduction, 1),
156 | nn.ReLU(inplace=True),
157 | nn.Conv2d(in_channels // reduction, in_channels, 1),
158 | nn.Sigmoid(),
159 | )
160 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
161 |
162 | def forward(self, x):
163 | return x *self.cSE(x) + x*self.sSE(x)
164 |
165 |
166 |
167 |
168 |
169 |
170 | class Heading(nn.Module):
171 |
172 | def __init__(self,encoder_channels):
173 | super(Heading, self).__init__()
174 | self.extra_feature = FeatureFuc(encoder_channels[-1])
175 | self.aspp = ASPP(encoder_channels[-1], [2, 4, 8])
176 |
177 |
178 |
179 |
180 |
181 |
182 | def forward(self,features):
183 |
184 | img,encx2,encx4,encx8,encx16=features
185 |
186 | ## add extra feature
187 | encx16=self.extra_feature(encx16)
188 | encx16=self.aspp(encx16)
189 |
190 | return [encx2,encx4,encx8,encx16]
191 |
192 | class Net(nn.Module):
193 | def __init__(self,):
194 | super(Net, self).__init__()
195 |
196 | self.encoder = timm.create_model(model_name='mobilenetv3_large_100',
197 | pretrained=True,
198 | features_only=True,
199 | out_indices=[0,1,2,4],
200 | bn_momentum=bn_momentum,
201 | in_chans=3,
202 | output_stride=16,
203 | )
204 |
205 | # self.encoder.blocks[4][1]=nn.Identity()
206 | self.encoder.blocks[5]=nn.Identity()
207 | self.encoder.blocks[6]=nn.Identity()
208 |
209 |
210 | self.encoder_out_channels = [3, 16, 24, 40, 112] #mobilenetv3
211 |
212 |
213 | self.head=Heading(self.encoder_out_channels)
214 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
215 |
216 | self._fc = nn.Linear(512, 136+3+4, bias=True)
217 |
218 | def forward(self, x):
219 |         """Sequentially pass `x` through the model's encoder and head"""
220 |
221 | bs = x.size(0)
222 | features=self.encoder(x)
223 |
224 | features=[x]+features
225 |
226 | [encx2,encx4,encx8,encx16]=self.head(features)
227 | fm=self._avg_pooling(encx16)
228 |
229 | fm = fm.view(bs, -1)
230 | x = self._fc(fm)
231 |
232 |
233 | return x,[encx2,encx4,encx8,encx16]
234 |
235 |
236 |
237 | class TeacherNet(nn.Module):
238 | def __init__(self,):
239 | super(TeacherNet, self).__init__()
240 |
241 | self.encoder = timm.create_model(model_name='tf_efficientnet_b0_ns',
242 | pretrained=True,
243 | features_only=True,
244 | out_indices=[0,1,2,3],
245 | bn_momentum=bn_momentum,
246 | in_chans=3,
247 | )
248 |
249 | # self.encoder.out_channels=[3, 24 , 40, 64,176]
250 | self.encoder.out_channels=[3,16, 24, 40, 112]
251 | self.head = Heading(self.encoder.out_channels)
252 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
253 |
254 | self._fc = nn.Linear(512, 136 + 3 + 4, bias=True)
255 |
256 | def forward(self, x):
257 |         """Sequentially pass `x` through the model's encoder and head"""
258 | bs = x.size(0)
259 | features=self.encoder(x)
260 |
261 | features=[x]+features
262 |
263 | [encx2,encx4,encx8,encx16] = self.head(features)
264 |
265 | fm = self._avg_pooling(encx16)
266 |
267 | fm = fm.view(bs, -1)
268 |
269 | x = self._fc(fm)
270 |
271 | return x,[encx2,encx4,encx8,encx16]
272 |
273 |
274 |
275 | class COTRAIN(nn.Module):
276 | def __init__(self,inference=False):
277 | super(COTRAIN, self).__init__()
278 |
279 | self.inference=inference
280 | self.student=Net()
281 | self.teacher=TeacherNet()
282 |
283 | self.MSELoss=nn.MSELoss()
284 |
285 | # self.DiceLoss = segmentation_models_pytorch.losses.DiceLoss(mode='multilabel',eps=1e-5)
286 | self.BCELoss = nn.BCEWithLogitsLoss()
287 |
288 |
289 | self.act=nn.Sigmoid()
290 |
291 |
292 | def distill_loss(self,student_pres,teacher_pres):
293 |
294 |
295 | num_level=len(student_pres)
296 | loss=0
297 | for i in range(num_level):
298 | loss+=self.MSELoss(student_pres[i],teacher_pres[i])
299 |
300 | return loss/num_level
301 |
302 |
303 | def criterion(self,y_pred, y_true):
304 |         # unused in this pipeline; requires self.DiceLoss, which is commented out above
305 |         return 0.5*self.BCELoss(y_pred, y_true) + 0.5*self.DiceLoss(y_pred, y_true)
306 |
307 | def _wing_loss(self,landmarks, labels, w=10.0, epsilon=2.0, weights=1.):
308 | """
309 | Arguments:
310 |             landmarks, labels: float tensors with shape [batch_size, num_landmarks*2], flattened as x1,y1,x2,y2,...
311 |             w, epsilon: float numbers.
312 | Returns:
313 | a float tensor with shape [].
314 | """
315 |
316 | x = landmarks - labels
317 | c = w * (1.0 - math.log(1.0 + w / epsilon))
318 | absolute_x = torch.abs(x)
319 | losses = torch.where(
320 | torch.gt(absolute_x, w), absolute_x - c,
321 | w * torch.log(1.0 + absolute_x / epsilon)
322 |
323 | )
324 |         losses = losses * torch.tensor(cfg.DATA.weights, device=landmarks.device)
325 | loss = torch.sum(torch.mean(losses * weights, dim=[0]))
326 |
327 | return loss
328 | def loss(self,predict_keypoints, label_keypoints):
329 |
330 | landmark_label = label_keypoints[:, 0:136]
331 | pose_label = label_keypoints[:, 136:139]
332 | leye_cls_label = label_keypoints[:, 139]
333 | reye_cls_label = label_keypoints[:, 140]
334 | mouth_cls_label = label_keypoints[:, 141]
335 | big_mouth_cls_label = label_keypoints[:, 142]
336 |
337 | landmark_predict = predict_keypoints[:, 0:136]
338 | pose_predict = predict_keypoints[:, 136:139]
339 | leye_cls_predict = predict_keypoints[:, 139]
340 | reye_cls_predict = predict_keypoints[:, 140]
341 | mouth_cls_predict = predict_keypoints[:, 141]
342 | big_mouth_cls_predict = predict_keypoints[:, 142]
343 |
344 | loss = self._wing_loss(landmark_predict, landmark_label)
345 |
346 | loss_pose = self.MSELoss(pose_predict, pose_label)
347 |
348 | leye_loss = self.BCELoss (leye_cls_predict, leye_cls_label)
349 | reye_loss = self.BCELoss (reye_cls_predict, reye_cls_label)
350 |
351 | mouth_loss = self.BCELoss (mouth_cls_predict, mouth_cls_label)
352 | mouth_loss_big = self.BCELoss (big_mouth_cls_predict, big_mouth_cls_label)
353 | mouth_loss = mouth_loss + mouth_loss_big
354 |
355 |
356 |
357 | return loss + loss_pose + leye_loss + reye_loss + mouth_loss
358 |
359 | 
360 |
361 |
362 | def forward(self, x,gt=None):
363 |
364 | student_pre,student_fms=self.student(x)
365 |
366 | if self.inference:
367 |
368 | return student_pre
369 |
370 | teacher_pre,teacher_fms=self.teacher(x)
371 |
372 | distill_loss=self.distill_loss(student_fms,teacher_fms)
373 |
374 | student_loss=self.loss(student_pre,gt)
375 |
376 | teacher_loss=self.loss(teacher_pre,gt)
377 |
378 | return student_loss,teacher_loss,distill_loss,student_pre
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 | if __name__=='__main__':
388 | import torch
389 | import torchvision
390 |
391 | from thop import profile
392 |
393 | dummy_x = torch.randn(1, 3, 288, 160, device='cpu')
394 |
395 | model = COTRAIN(inference=True)
396 |
397 |     flops, params = profile(model, inputs=(dummy_x,))
398 |     print('FLOPs (M): %.2f' % (flops / 1e6))
399 |     print('params (M): %.2f' % (params / 1e6))
400 |
401 |
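402 | # Illustrative inference sketch (not part of the original file). It assumes the
403 | # timm pretrained weights can be downloaded and uses a 128x128 input, which is
404 | # only an assumption about cfg.MODEL.hin / cfg.MODEL.win; the 143 outputs are
405 | # the 136 landmark coordinates, 3 pose angles and 4 state logits split apart in
406 | # loss() above.
407 | def _demo_inference():
408 |     model = COTRAIN(inference=True).eval()
409 |     with torch.no_grad():
410 |         out = model(torch.randn(1, 3, 128, 128))
411 |     print(out.shape)  # expected: torch.Size([1, 143])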
--------------------------------------------------------------------------------
/lib/dataset/augmentor/augmentation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import random
5 | import math
6 | from train_config import config as cfg
7 |
8 | ###### May be wrong; check it before relying on it
9 | def Rotate_aug(src,angle,label=None,center=None,scale=1.0):
10 | '''
11 | :param src: src image
12 | :param label: label should be numpy array with [[x1,y1],
13 | [x2,y2],
14 | [x3,y3]...]
15 | :param angle:
16 | :param center:
17 | :param scale:
18 | :return: the rotated image and the points
19 | '''
20 | image=src
21 | (h, w) = image.shape[:2]
22 |     # if no rotation center is given, use the image center as the rotation center
23 | if center is None:
24 | center = (w / 2, h / 2)
25 |     # perform the rotation
26 | M = cv2.getRotationMatrix2D(center, angle, scale)
27 | if label is None:
28 | for i in range(image.shape[2]):
29 | image[:,:,i] = cv2.warpAffine(image[:,:,i], M, (w, h),
30 | flags=cv2.INTER_CUBIC,
31 | borderMode=cv2.BORDER_CONSTANT)
32 | return image,None
33 | else:
34 | label=label.T
35 | ####make it as a 3x3 RT matrix
36 | full_M=np.row_stack((M,np.asarray([0,0,1])))
37 | img_rotated = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC,
38 | borderMode=cv2.BORDER_CONSTANT)
39 | ###make the label as 3xN matrix
40 | full_label = np.row_stack((label, np.ones(shape=(1,label.shape[1]))))
41 | label_rotated=np.dot(full_M,full_label)
42 | label_rotated=label_rotated[0:2,:]
43 | #label_rotated = label_rotated.astype(np.int32)
44 | label_rotated=label_rotated.T
45 | return img_rotated,label_rotated
46 | def Rotate_coordinate(label,rt_matrix):
47 | if rt_matrix.shape[0]==2:
48 | rt_matrix=np.row_stack((rt_matrix, np.asarray([0, 0, 1])))
49 | full_label = np.row_stack((label, np.ones(shape=(1, label.shape[1]))))
50 | label_rotated = np.dot(rt_matrix, full_label)
51 | label_rotated = label_rotated[0:2, :]
52 | return label_rotated
53 |
54 |
55 | def box_to_point(boxes):
56 | '''
57 |
58 | :param boxes: [n,x,y,x,y]
59 | :return: [4n,x,y]
60 | '''
61 | ##caution the boxes are ymin xmin ymax xmax
62 | points_set=np.zeros(shape=[4*boxes.shape[0],2])
63 |
64 | for i in range(boxes.shape[0]):
65 | points_set[4 * i]=np.array([boxes[i][0],boxes[i][1]])
66 | points_set[4 * i+1] =np.array([boxes[i][0],boxes[i][3]])
67 | points_set[4 * i+2] =np.array([boxes[i][2],boxes[i][3]])
68 | points_set[4 * i+3] =np.array([boxes[i][2],boxes[i][1]])
69 |
70 |
71 | return points_set
72 |
73 |
74 | def point_to_box(points):
75 | boxes=[]
76 | points=points.reshape([-1,4,2])
77 |
78 | for i in range(points.shape[0]):
79 | box=[np.min(points[i][:,0]),np.min(points[i][:,1]),np.max(points[i][:,0]),np.max(points[i][:,1])]
80 |
81 | boxes.append(box)
82 |
83 | return np.array(boxes)
84 |
85 |
86 | def Rotate_with_box(src,angle,boxes=None,center=None,scale=1.0):
87 | '''
88 | :param src: src image
89 | :param label: label should be numpy array with [[x1,y1],
90 | [x2,y2],
91 | [x3,y3]...]
92 |     :param angle: rotation angle in degrees
93 | :param center:
94 | :param scale:
95 | :return: the rotated image and the points
96 | '''
97 |
98 | label=box_to_point(boxes)
99 | image=src
100 | (h, w) = image.shape[:2]
101 |     # if no rotation center is given, use the image center as the rotation center
102 |
103 |
104 | if center is None:
105 | center = (w / 2, h / 2)
106 |     # perform the rotation
107 | M = cv2.getRotationMatrix2D(center, angle, scale)
108 |
109 | new_size=Rotate_coordinate(np.array([[0,w,w,0],
110 | [0,0,h,h]]), M)
111 |
112 | new_h,new_w=np.max(new_size[1])-np.min(new_size[1]),np.max(new_size[0])-np.min(new_size[0])
113 |
114 | scale=min(h/new_h,w/new_w)
115 |
116 | M = cv2.getRotationMatrix2D(center, angle, scale)
117 |
118 | if boxes is None:
119 | for i in range(image.shape[2]):
120 | image[:,:,i] = cv2.warpAffine(image[:,:,i], M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)
121 | return image,None
122 | else:
123 | label=label.T
124 | ####make it as a 3x3 RT matrix
125 | full_M=np.row_stack((M,np.asarray([0,0,1])))
126 | img_rotated = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)
127 | ###make the label as 3xN matrix
128 | full_label = np.row_stack((label, np.ones(shape=(1,label.shape[1]))))
129 | label_rotated=np.dot(full_M,full_label)
130 | label_rotated=label_rotated[0:2,:]
131 | #label_rotated = label_rotated.astype(np.int32)
132 | label_rotated=label_rotated.T
133 |
134 | boxes_rotated = point_to_box(label_rotated)
135 | return img_rotated,boxes_rotated
136 |
137 | ### CAUTION: the label transform is not correct for Perspective_aug
138 | def Perspective_aug(src,strength,label=None):
139 | image = src
140 | pts_base = np.float32([[0, 0], [300, 0], [0, 300], [300, 300]])
141 | pts1=np.random.rand(4, 2)*random.uniform(-strength,strength)+pts_base
142 | pts1=pts1.astype(np.float32)
143 | #pts1 =np.float32([[56, 65], [368, 52], [28, 387], [389, 398]])
144 | M = cv2.getPerspectiveTransform(pts1, pts_base)
145 | trans_img = cv2.warpPerspective(image, M, (src.shape[1], src.shape[0]))
146 |
147 | label_rotated=None
148 | if label is not None:
149 | label=label.T
150 | full_label = np.row_stack((label, np.ones(shape=(1, label.shape[1]))))
151 | label_rotated = np.dot(M, full_label)
152 | label_rotated=label_rotated.astype(np.int32)
153 | label_rotated=label_rotated.T
154 | return trans_img,label_rotated
155 |
156 | def Affine_aug(src,strength,label=None):
157 | image = src
158 | pts_base = np.float32([[10,100],[200,50],[100,250]])
159 | pts1 = np.random.rand(3, 2) * random.uniform(-strength, strength) + pts_base
160 | pts1 = pts1.astype(np.float32)
161 | M = cv2.getAffineTransform(pts1, pts_base)
162 | trans_img = cv2.warpAffine(image, M, (image.shape[1], image.shape[0]) ,
163 | borderMode=cv2.BORDER_CONSTANT)
164 | label_rotated=None
165 | if label is not None:
166 | label=label.T
167 | full_label = np.row_stack((label, np.ones(shape=(1, label.shape[1]))))
168 | label_rotated = np.dot(M, full_label)
169 | #label_rotated = label_rotated.astype(np.int32)
170 | label_rotated=label_rotated.T
171 | return trans_img,label_rotated
172 | def Padding_aug(src,max_pattern_ratio=0.05):
173 | src=src.astype(np.float32)
174 | pattern=np.ones_like(src)
175 | ratio = random.uniform(0, max_pattern_ratio)
176 |
177 | height,width,_=src.shape
178 |
179 | if random.uniform(0,1)>0.5:
180 | if random.uniform(0, 1) > 0.5:
181 | pattern[0:int(ratio*height),:,:]=0
182 | else:
183 | pattern[height-int(ratio * height):, :, :] = 0
184 | else:
185 | if random.uniform(0, 1) > 0.5:
186 | pattern[:,0:int(ratio * width), :] = 0
187 | else:
188 | pattern[:,width-int(ratio * width):, :] = 0
189 |
190 |
191 | bias_pattern=(1-pattern)
192 |
193 | img=src*pattern+bias_pattern
194 |
195 | img=img.astype(np.uint8)
196 | return img
197 |
198 | def Blur_heatmaps(src, ksize=(3, 3)):
199 | for i in range(src.shape[2]):
200 | src[:, :, i] = cv2.GaussianBlur(src[:, :, i], ksize, 0)
201 |         amin, amax = src[:, :, i].min(), src[:, :, i].max()  # find the min and max values
202 | if amax>0:
203 |             src[:, :, i] = (src[:, :, i] - amin) / (amax - amin)  # normalize: (value - min) / (max - min)
204 | return src
205 | def Blur_aug(src,ksize=(3,3)):
206 | for i in range(src.shape[2]):
207 | src[:, :, i]=cv2.GaussianBlur(src[:, :, i],ksize,1.5)
208 | return src
209 |
210 | def Img_dropout(src,max_pattern_ratio=0.05):
211 | width_ratio = random.uniform(0, max_pattern_ratio)
212 | height_ratio = random.uniform(0, max_pattern_ratio)
213 | width=src.shape[1]
214 | height=src.shape[0]
215 | block_width=width*width_ratio
216 | block_height=height*height_ratio
217 | width_start=int(random.uniform(0,width-block_width))
218 | width_end=int(width_start+block_width)
219 | height_start=int(random.uniform(0,height-block_height))
220 | height_end=int(height_start+block_height)
221 | src[height_start:height_end,width_start:width_end,:]=0
222 |
223 | return src
224 |
225 | def Fill_img(img_raw,target_height,target_width,label=None):
226 |
227 |     ### sometimes used in object detection
228 | channel=img_raw.shape[2]
229 | raw_height = img_raw.shape[0]
230 | raw_width = img_raw.shape[1]
231 | if raw_width / raw_height >= target_width / target_height:
232 | shape_need = [int(target_height / target_width * raw_width), raw_width, channel]
233 | img_fill = np.zeros(shape_need, dtype=img_raw.dtype)
234 | shift_x=(img_fill.shape[1]-raw_width)//2
235 | shift_y=(img_fill.shape[0]-raw_height)//2
236 | for i in range(channel):
237 | img_fill[shift_y:raw_height+shift_y, shift_x:raw_width+shift_x, i] = img_raw[:,:,i]
238 | else:
239 | shape_need = [raw_height, int(target_width / target_height * raw_height), channel]
240 | img_fill = np.zeros(shape_need, dtype=img_raw.dtype)
241 | shift_x = (img_fill.shape[1] - raw_width) // 2
242 | shift_y = (img_fill.shape[0] - raw_height) // 2
243 | for i in range(channel):
244 | img_fill[shift_y:raw_height + shift_y, shift_x:raw_width + shift_x, i] = img_raw[:, :, i]
245 | if label is None:
246 | return img_fill,shift_x,shift_y
247 | else:
248 | label[:,0]+=shift_x
249 | label[:, 1]+=shift_y
250 | return img_fill,label
251 | def Random_crop(src,shrink):
252 | h,w,_=src.shape
253 |
254 | h_shrink=int(h*shrink)
255 | w_shrink = int(w * shrink)
256 | bimg = cv2.copyMakeBorder(src, h_shrink, h_shrink, w_shrink, w_shrink, borderType=cv2.BORDER_CONSTANT,
257 | value=(0,0,0))
258 |
259 | start_h=random.randint(0,2*h_shrink)
260 | start_w=random.randint(0,2*w_shrink)
261 |
262 | target_img=bimg[start_h:start_h+h,start_w:start_w+w,:]
263 |
264 | return target_img
265 |
266 | def box_in_img(img,boxes,min_overlap=0.5):
267 |
268 | raw_bboxes = np.array(boxes)
269 |
270 | face_area=(boxes[:,3]-boxes[:,1])*(boxes[:,2]-boxes[:,0])
271 |
272 | h,w,_=img.shape
273 | boxes[:, 0][boxes[:, 0] <=0] =0
274 | boxes[:, 0][boxes[:, 0] >=w] = w
275 | boxes[:, 2][boxes[:, 2] <= 0] = 0
276 | boxes[:, 2][boxes[:, 2] >= w] = w
277 |
278 | boxes[:, 1][boxes[:, 1] <= 0] = 0
279 | boxes[:, 1][boxes[:, 1] >= h] = h
280 |
281 | boxes[:, 3][boxes[:, 3] <= 0] = 0
282 | boxes[:, 3][boxes[:, 3] >= h] = h
283 |
284 | boxes_in = []
285 | for i in range(boxes.shape[0]):
286 | box=boxes[i]
287 | if ((box[3]-box[1])*(box[2]-box[0]))/face_area[i]>min_overlap :
288 | boxes_in.append(boxes[i])
289 |
290 | boxes_in = np.array(boxes_in)
291 | return boxes_in
292 |
293 | def Random_scale_withbbox(image,bboxes,target_shape,jitter=0.5):
294 |
295 | ###the boxes is in ymin,xmin,ymax,xmax mode
296 | hi, wi, _ = image.shape
297 |
298 | while 1:
299 | if len(bboxes)==0:
300 |             print('error: empty bboxes passed to Random_scale_withbbox')
301 | bboxes_=np.array(bboxes)
302 | crop_h = int(hi * random.uniform(0.2, 1))
303 | crop_w = int(wi * random.uniform(0.2, 1))
304 |
305 | start_h = random.randint(0, hi - crop_h)
306 | start_w = random.randint(0, wi - crop_w)
307 |
308 | croped = image[start_h:start_h + crop_h, start_w:start_w + crop_w, :]
309 |
310 | bboxes_[:, 0] = bboxes_[:, 0] - start_w
311 | bboxes_[:, 1] = bboxes_[:, 1] - start_h
312 | bboxes_[:, 2] = bboxes_[:, 2] - start_w
313 | bboxes_[:, 3] = bboxes_[:, 3] - start_h
314 |
315 |
316 | bboxes_fix=box_in_img(croped,bboxes_)
317 | if len(bboxes_fix)>0:
318 | break
319 |
320 |
321 | ###use box
322 | h,w=target_shape
323 | croped_h,croped_w,_=croped.shape
324 |
325 | croped_h_w_ratio=croped_h/croped_w
326 |
327 | rescale_h=int(h * random.uniform(0.5, 1))
328 |
329 | rescale_w = int(rescale_h/(random.uniform(0.7, 1.3)*croped_h_w_ratio))
330 | rescale_w=np.clip(rescale_w,0,w)
331 |
332 | image=cv2.resize(croped,(rescale_w,rescale_h))
333 |
334 | new_image=np.zeros(shape=[h,w,3],dtype=np.uint8)
335 |
336 | dx = int(random.randint(0, w - rescale_w))
337 | dy = int(random.randint(0, h - rescale_h))
338 |
339 | new_image[dy:dy+rescale_h,dx:dx+rescale_w,:]=image
340 |
341 | bboxes_fix[:, 0] = bboxes_fix[:, 0] * rescale_w/ croped_w+dx
342 | bboxes_fix[:, 1] = bboxes_fix[:, 1] * rescale_h / croped_h+dy
343 | bboxes_fix[:, 2] = bboxes_fix[:, 2] * rescale_w / croped_w+dx
344 | bboxes_fix[:, 3] = bboxes_fix[:, 3] * rescale_h / croped_h+dy
345 |
346 |
347 |
348 | return new_image,bboxes_fix
349 |
350 |
351 | def Random_flip(im, boxes):
352 |
353 | im_lr = np.fliplr(im).copy()
354 | h,w,_ = im.shape
355 | xmin = w - boxes[:,2]
356 | xmax = w - boxes[:,0]
357 | boxes[:,0] = xmin
358 | boxes[:,2] = xmax
359 | return im_lr, boxes
360 |
361 |
362 | def Mirror(src,label=None,symmetry=None):
363 |
364 | img = cv2.flip(src, 1)
365 | if label is None:
366 | return img,label
367 |
368 | width=img.shape[1]
369 | cod = []
370 | allc = []
371 | for i in range(label.shape[0]):
372 | x, y = label[i][0], label[i][1]
373 | if x >= 0:
374 | x = width - 1 - x
375 | cod.append((x, y))
376 | # **** the joint index depends on the dataset ****
377 | for (q, w) in symmetry:
378 | cod[q], cod[w] = cod[w], cod[q]
379 | for i in range(label.shape[0]):
380 | allc.append(cod[i][0])
381 | allc.append(cod[i][1])
382 | label = np.array(allc).reshape(label.shape[0], 2)
383 | return img,label
384 |
385 | def produce_heat_maps(label,map_size,stride,sigma):
386 | def produce_heat_map(center,map_size,stride,sigma):
387 | grid_y = map_size[0] // stride
388 | grid_x = map_size[1] // stride
389 | start = stride / 2.0 - 0.5
390 | y_range = [i for i in range(grid_y)]
391 | x_range = [i for i in range(grid_x)]
392 | xx, yy = np.meshgrid(x_range, y_range)
393 | xx = xx * stride + start
394 | yy = yy * stride + start
395 | d2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2
396 | exponent = d2 / 2.0 / sigma / sigma
397 | heatmap = np.exp(-exponent)
398 |
399 | am = np.amax(heatmap)
400 | if am > 0:
401 | heatmap /= am / 255.
402 |
403 | return heatmap
404 | all_keypoints = label
405 | point_num = all_keypoints.shape[0]
406 | heatmaps_this_img=np.zeros([map_size[0]//stride,map_size[1]//stride,point_num])
407 | for k in range(point_num):
408 | heatmap = produce_heat_map([all_keypoints[k][0],all_keypoints[k][1]], map_size, stride, sigma)
409 | heatmaps_this_img[:,:,k]=heatmap
410 | return heatmaps_this_img
411 |
412 | def visualize_heatmap_target(heatmap):
413 | map_size=heatmap.shape[0:2]
414 | frame_num = heatmap.shape[2]
415 | heat_ = np.zeros([map_size[0], map_size[1]])
416 | for i in range(frame_num):
417 | heat_ = heat_ + heatmap[:, :, i]
418 | cv2.namedWindow('heat_map', 0)
419 | cv2.imshow('heat_map', heat_)
420 | cv2.waitKey(0)
421 |
422 |
423 | def produce_heatmaps_with_bbox(image,label,h_out,w_out,num_klass,ksize=9,sigma=0):
424 | heatmap=np.zeros(shape=[h_out,w_out,num_klass])
425 |
426 | h,w,_=image.shape
427 |
428 | for single_box in label:
429 | if single_box[4]>=0:
430 | ####box center (x,y)
431 | center=[(single_box[0]+single_box[2])/2/w,(single_box[1]+single_box[3])/2/h] ###0-1
432 |
433 | heatmap[round(center[1]*h_out),round(center[0]*w_out),int(single_box[4]) ]=1.
434 |
435 | heatmap = cv2.GaussianBlur(heatmap, (ksize,ksize), sigma)
436 | am = np.amax(heatmap)
437 | if am>0:
438 | heatmap /= am / 255.
439 | heatmap=np.expand_dims(heatmap,-1)
440 | return heatmap
441 |
442 |
443 | def produce_heatmaps_with_keypoint(image,label,h_out,w_out,num_klass,ksize=7,sigma=0):
444 | heatmap=np.zeros(shape=[h_out,w_out,num_klass])
445 |
446 | h,w,_=image.shape
447 |
448 | for i in range(label.shape[0]):
449 | single_point=label[i]
450 |
451 | if single_point[0]>0 and single_point[1]>0:
452 |
453 | heatmap[int(single_point[1]*(h_out-1)),int(single_point[0]*(w_out-1)),i ]=1.
454 |
455 | heatmap = cv2.GaussianBlur(heatmap, (ksize,ksize), sigma)
456 | am = np.amax(heatmap)
457 | if am>0:
458 | heatmap /= am / 255.
459 | return heatmap
460 |
461 |
462 |
463 |
464 |
465 | if __name__=='__main__':
466 | pass
467 |
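468 | # Illustrative sketch (not part of the original file): Mirror() flips the image
469 | # horizontally and then swaps the left/right keypoint indices listed in
470 | # `symmetry`. The single (0, 1) pair below is a toy assumption; the real pairs
471 | # come from cfg.DATA.symmetry in train_config.py.
472 | def _demo_mirror():
473 |     img = np.zeros((100, 100, 3), dtype=np.uint8)
474 |     pts = np.array([[10., 50.], [90., 50.]])  # a left point and a right point
475 |     flipped_img, flipped_pts = Mirror(img, label=pts, symmetry=[(0, 1)])
476 |     # after flipping and swapping, index 0 is again the left-side point
477 |     print(flipped_pts)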
--------------------------------------------------------------------------------
/lib/core/base_trainer/net_work.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
2 | import numpy as np
3 | import sklearn.metrics
4 | import cv2
5 | import time
6 | import os
7 |
8 | from torch.utils.data import DataLoader, DistributedSampler
9 |
10 |
11 | from lib.dataset.dataietr import AlaskaDataIter
12 | from train_config import config as cfg
13 | #from lib.dataset.dataietr import DataIter
14 |
15 | import sklearn.metrics
16 | from lib.utils.logger import logger
17 |
18 | from lib.core.base_trainer.model import Net,COTRAIN
19 |
20 | import random
21 |
22 | from lib.core.base_trainer.metric import *
23 | import torch
24 | import torch.nn.functional as F
25 | import torch.nn as nn
26 |
27 |
28 |
29 | if not cfg.TRAIN.vis:
30 | torch.distributed.init_process_group(backend="nccl",)
31 |
32 | class Train(object):
33 | """Train class.
34 | """
35 |
36 | def __init__(self,train_df,val_df,fold):
37 |
38 | self.ddp= not cfg.TRAIN.vis
39 |
40 | # self.ddp=False
41 |
42 | if self.ddp:
43 |
44 | self.train_generator = AlaskaDataIter(train_df,cfg.DATA.root_path ,training_flag=True, shuffle=False)
45 |
46 |
47 | self.train_sampler=DistributedSampler(self.train_generator,
48 | shuffle=True)
49 | self.train_ds = DataLoader(self.train_generator,
50 |
51 | cfg.TRAIN.batch_size,
52 | num_workers=cfg.TRAIN.process_num,
53 | sampler=self.train_sampler)
54 |
55 | self.val_generator = AlaskaDataIter(val_df,cfg.DATA.root_path , training_flag=False, shuffle=False)
56 |
57 |
58 | self.val_sampler=DistributedSampler(self.val_generator,
59 | shuffle=False)
60 |
61 | self.val_ds = DataLoader(self.val_generator,
62 | cfg.TRAIN.validatiojn_batch_size,
63 | num_workers=cfg.TRAIN.process_num,
64 | sampler=self.val_sampler)
65 |
66 | local_rank = torch.distributed.get_rank()
67 | torch.cuda.set_device(local_rank)
68 | self.device = torch.device("cuda", local_rank)
69 |
70 | else:
71 | self.train_generator = AlaskaDataIter(train_df,
72 | img_root=cfg.DATA.root_path,
73 | training_flag=True, shuffle=False)
74 |
75 | self.train_ds = DataLoader(self.train_generator,
76 | cfg.TRAIN.batch_size,
77 | num_workers=cfg.TRAIN.process_num,shuffle=True)
78 |
79 | self.val_generator = AlaskaDataIter(val_df,
80 | img_root=cfg.DATA.root_path,
81 | training_flag=False, shuffle=False)
82 |
83 | self.val_ds = DataLoader(self.val_generator,
84 | cfg.TRAIN.validatiojn_batch_size,
85 | num_workers=cfg.TRAIN.process_num,shuffle=False)
86 |
87 | self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
88 |
89 | self.fold = fold
90 | self.fp16=cfg.TRAIN.mix_precision
91 |
92 | self.init_lr = cfg.TRAIN.init_lr
93 | self.warup_step = cfg.TRAIN.warmup_step
94 | self.epochs = cfg.TRAIN.epoch
95 | self.batch_size = cfg.TRAIN.batch_size
96 | self.l2_regularization = cfg.TRAIN.weight_decay_factor
97 |
98 | self.early_stop = cfg.TRAIN.early_stop
99 |
100 | self.accumulation_step = cfg.TRAIN.accumulation_batch_size // cfg.TRAIN.batch_size
101 |
102 |
103 | self.gradient_clip = cfg.TRAIN.gradient_clip
104 |
105 | self.save_dir=cfg.MODEL.model_path
106 | #### make the device
107 |
108 |
109 | self.model = COTRAIN().to(self.device)
110 |
111 | self.load_weight()
112 | # self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
113 |
114 | if 'Adamw' in cfg.TRAIN.opt:
115 |
116 | self.optimizer=torch.optim.AdamW(self.model.parameters(),
117 | lr=self.init_lr,
118 | weight_decay=self.l2_regularization)
119 |
120 | else:
121 | self.optimizer = torch.optim.SGD(self.model.parameters(),
122 | lr=self.init_lr,
123 | momentum=0.9,
124 | weight_decay=self.l2_regularization)
125 |
126 |
127 |
128 |
129 | if self.ddp:
130 | self.model = torch.nn.parallel.DistributedDataParallel(self.model,
131 | device_ids=[local_rank],
132 | output_device=local_rank,
133 | static_graph=True,
134 | find_unused_parameters=True
135 |
136 | )
137 | else:
138 | self.model=torch.nn.DataParallel(self.model)
139 |
140 | if cfg.TRAIN.vis:
141 | self.vis()
142 |
143 |
144 |
145 | ###control vars
146 | self.iter_num = 0
147 |
148 | # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,mode='max', patience=5,
149 | # min_lr=1e-6,factor=0.5,verbose=True)
150 |
151 | # if cfg.TRAIN.lr_scheduler=='cos':
152 | if 1:
153 | logger.info('lr_scheduler.CosineAnnealingLR')
154 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
155 | self.epochs,
156 | eta_min=1.e-7)
157 | else:
158 | logger.info('lr_scheduler.ReduceLROnPlateau')
159 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
160 | mode='max',
161 | patience=5,
162 | min_lr=1e-7,
163 | factor=cfg.TRAIN.lr_scheduler_factor,
164 | verbose=True)
165 |
166 | self.scaler = torch.cuda.amp.GradScaler()
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 | def metric(self,targets,preds ):
175 |
176 |
177 | mse=torch.mean((preds-targets)**2)
178 |
179 | mad=torch.mean(torch.abs(preds-targets))
180 |
181 | return mad,mse
182 |
183 | def custom_loop(self):
184 | """Custom training and testing loop.
185 |         Runs distributed_train_epoch every epoch and distributed_test_epoch at
186 |         the configured validation interval, steps the lr scheduler, and saves a
187 |         checkpoint after each validated epoch.
188 | 
189 |         Returns:
190 |           nothing; losses and metrics are written to the logger.
191 | """
192 |
193 | def distributed_train_epoch(epoch_num):
194 |
195 | self.train_sampler.set_epoch(epoch_num)
196 |
197 | summary_loss = AverageMeter()
198 |
199 | summary_studen_loss = AverageMeter()
200 | summary_teacher_loss = AverageMeter()
201 | summary_distill_loss = AverageMeter()
202 |
203 | self.model.train()
204 |
205 |
206 | for k,(ids,images, kps) in enumerate(self.train_ds):
207 |
208 | if epoch_num<10:
209 |                     ### execute learning-rate warm-up while iter_num is below warmup_step
210 | if self.warup_step>0:
211 | if self.iter_num < self.warup_step:
212 | for param_group in self.optimizer.param_groups:
213 | param_group['lr'] = self.iter_num / float(self.warup_step) * self.init_lr
214 | lr = param_group['lr']
215 |
216 | logger.info('warm up with learning rate: [%f]' % (lr))
217 |
218 | start=time.time()
219 |
220 | data = images.to(self.device).float()
221 | kps = kps.to(self.device).float()
222 |
223 | batch_size = data.shape[0]
224 | ###
225 |
226 | with torch.cuda.amp.autocast(enabled=self.fp16):
227 |
228 | student_loss, teacher_loss, distill_loss,mate = self.model(data,kps)
229 |
230 | # calculate the final loss, backward the loss, and update the model
231 | current_loss = student_loss+ teacher_loss+ distill_loss
232 |
233 | if torch.isnan(current_loss):
234 | print('there is a nan loss ')
235 |
236 | summary_loss.update(current_loss.detach().item(), batch_size)
237 | summary_studen_loss.update(student_loss.detach().item(), batch_size)
238 | summary_teacher_loss.update(teacher_loss.detach().item(), batch_size)
239 | summary_distill_loss.update(distill_loss.detach().item(), batch_size)
240 |
241 | self.scaler.scale(current_loss).backward()
242 |
243 | if ((self.iter_num + 1) % self.accumulation_step) == 0:
244 | if self.gradient_clip>0:
245 | self.scaler.unscale_(self.optimizer)
246 | nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.gradient_clip, norm_type=2)
247 | self.scaler.step(self.optimizer)
248 | self.scaler.update()
249 | self.optimizer.zero_grad()
250 |
251 |
252 | self.iter_num+=1
253 | time_cost_per_batch=time.time()-start
254 |
255 | images_per_sec=cfg.TRAIN.batch_size/time_cost_per_batch
256 |
257 |
258 | if self.iter_num%cfg.TRAIN.log_interval==0:
259 |
260 |
261 | log_message = '[fold %d], ' \
262 | 'Train Step %d, ' \
263 | '[%d/%d ] ' \
264 | 'summary_loss: %.6f, ' \
265 | 'summary_student_loss: %.6f, ' \
266 | 'summary_teacher_loss: %.6f, ' \
267 | 'summary_distill_loss: %.6f, ' \
268 | 'time: %.6f, ' \
269 |                               'speed %d images/sec' % (
270 | self.fold,
271 | self.iter_num,
272 | k,
273 | len(self.train_ds),
274 | summary_loss.avg,
275 | summary_studen_loss.avg,
276 | summary_teacher_loss.avg,
277 | summary_distill_loss.avg,
278 | time.time() - start,
279 | images_per_sec)
280 | logger.info(log_message)
281 |
282 |
283 |
284 |
285 | return summary_loss
286 |
287 | def distributed_test_epoch(epoch_num):
288 | summary_loss = AverageMeter()
289 | summary_dice= AverageMeter()
290 | summary_mad= AverageMeter()
291 | summary_mse= AverageMeter()
292 |
293 |
294 | self.model.eval()
295 | t = time.time()
296 |
297 | with torch.no_grad():
298 | for step,(ids,images,kps) in enumerate(self.val_ds):
299 |
300 | data = images.to(self.device).float()
301 | kps = kps.to(self.device).float()
302 |
303 | batch_size = data.shape[0]
304 |
305 | loss,_,_,output = self.model(data,kps)
306 |
307 |
308 | if self.ddp:
309 | torch.distributed.all_reduce(loss.div_(torch.distributed.get_world_size()))
310 |
311 | summary_loss.update(loss.detach().item(), batch_size)
312 |
313 |
314 |
315 | val_mad,val_mse =self.metric(kps[:,:136],output[:,:136])
316 |
317 | if self.ddp:
318 |
319 | torch.distributed.all_reduce(val_mad.div_(torch.distributed.get_world_size()))
320 | torch.distributed.all_reduce(val_mse.div_(torch.distributed.get_world_size()))
321 |
322 |
323 | summary_mad.update(val_mad.detach().item(), batch_size)
324 | summary_mse.update(val_mse.detach().item(), batch_size)
325 |
326 | if step % cfg.TRAIN.log_interval == 0:
327 |
328 | log_message = '[fold %d], ' \
329 | 'Val Step %d, ' \
330 | 'summary_loss: %.6f, ' \
331 | 'summary_mad: %.6f, ' \
332 | 'summary_mse: %.6f, ' \
333 | 'time: %.6f' % (
334 | self.fold,step,
335 | summary_loss.avg,
336 | summary_mad.avg,
337 | summary_mse.avg,
338 | time.time() - t)
339 |
340 | logger.info(log_message)
341 |
342 |
343 |
344 | return summary_loss,summary_mad,summary_mse
345 |
346 |
347 |
348 |
349 | best_roc_auc=0.
350 | not_improvement=0
351 | for epoch in range(self.epochs):
352 |
353 | for param_group in self.optimizer.param_groups:
354 | lr=param_group['lr']
355 | logger.info('learning rate: [%f]' %(lr))
356 | t=time.time()
357 |
358 |
359 | summary_loss = distributed_train_epoch(epoch)
360 | train_epoch_log_message = '[fold %d], ' \
361 | '[RESULT]: TRAIN. Epoch: %d,' \
362 | ' summary_loss: %.5f,' \
363 | ' time:%.5f' % (
364 | self.fold,
365 | epoch,
366 | summary_loss.avg,
367 | (time.time() - t))
368 | logger.info(train_epoch_log_message)
369 |
370 |
371 |
372 |
373 |
374 | if epoch%cfg.TRAIN.test_interval==0 and epoch>0 or epoch%10==0:
375 |
376 | summary_loss ,summary_mad,summary_mse= distributed_test_epoch(epoch)
377 |
378 | val_epoch_log_message = '[fold %d], ' \
379 | '[RESULT]: VAL. Epoch: %d,' \
380 | ' summary_loss: %.5f,' \
381 | ' mad_score: %.5f,' \
382 | ' mse_score: %.5f,' \
383 | ' time:%.5f' % (
384 | self.fold,
385 | epoch,
386 | summary_loss.avg,
387 | summary_mad.avg,
388 | summary_mse.avg,
389 | (time.time() - t))
390 |
391 | logger.info(val_epoch_log_message)
392 |
393 | self.scheduler.step()
394 | # self.scheduler.step(acc_score.avg)
395 |
396 | #### save model
397 | if not os.access(cfg.MODEL.model_path, os.F_OK):
398 | os.mkdir(cfg.MODEL.model_path)
399 | ###save the best auc model
400 |
401 | #### save the model every end of epoch
402 | 
403 | current_model_saved_name='./models/fold%d_epoch_%d_val_loss_%.6f_val_mse_%.6f.pth'%(self.fold,
404 | epoch,
405 | summary_loss.avg,
406 | summary_mse.avg,)
407 |                 logger.info('A model saved to %s' % current_model_saved_name)
408 |                 #### in DDP, save from rank 0 only to avoid concurrent writes
409 |                 if not self.ddp or torch.distributed.get_rank() == 0:
410 |                     torch.save(self.model.module.state_dict(),current_model_saved_name)
411 | 
412 | 
413 |
414 | else:
415 | self.scheduler.step()
416 | ####switch back
417 |
418 |
419 | # save_checkpoint({
420 | # 'state_dict': self.model.state_dict(),
421 | # },iters=epoch,tag=current_model_saved_name)
422 |
423 |
424 |
425 | # if cur_roc_auc_score>best_roc_auc:
426 | # best_roc_auc=cur_roc_auc_score
427 | # logger.info(' best metric score update as %.6f' % (best_roc_auc))
428 | # else:
429 | # not_improvement+=1
430 |
431 | if not_improvement>=self.early_stop:
432 | logger.info(' best metric score not improvement for %d, break'%(self.early_stop))
433 | break
434 |
435 |
436 |
437 | def load_weight(self):
438 | if cfg.MODEL.pretrained_model is not None:
439 | state_dict=torch.load(cfg.MODEL.pretrained_model, map_location=self.device)
440 | self.model.load_state_dict(state_dict,strict=False)
441 |
442 |
443 | def vis(self):
444 |         # draw the ground-truth keypoints on the training images
445 |         print('visualizing training samples')
446 | 
447 |
448 |
449 | # state_dict=torch.load('/Users/liangzi/Downloads/fold1_epoch_29_val_loss_0.049467_val_dice_0.964158.pth', map_location=self.device)
450 | # self.model.load_state_dict(state_dict,strict=False)
451 | self.model.eval()
452 | for id,images, kps in self.train_ds:
453 |
454 |
455 |
456 | for i in range(images.shape[0]):
457 |
458 | example_image=np.array(images[i]*255,dtype=np.uint8)
459 | example_image=np.transpose(example_image,[1,2,0])
460 | example_image=np.ascontiguousarray(example_image)
461 | example_kps=np.array(kps[i][:136].reshape(-1,2))*128
462 | print(example_kps)
463 |
464 | for _index in range(example_kps.shape[0]):
465 | x_y = example_kps[_index]
466 |
467 | cv2.circle(example_image, (int(x_y[0] ),int(x_y[1] )),
468 | color=(255, 0, 0), radius=2, thickness=4)
469 |
470 |
471 | cv2.imshow('ss', example_image)
472 | cv2.waitKey(0)
473 |
474 |
475 |
476 |
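477 | # Illustrative sketch (not part of the original file): the mixed-precision plus
478 | # gradient-accumulation pattern used in distributed_train_epoch above, reduced
479 | # to a tiny regression problem so the control flow is easier to follow.
480 | def _amp_accumulation_sketch(accumulation_step=4):
481 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
482 |     use_amp = device == 'cuda'
483 |     model = torch.nn.Linear(8, 1).to(device)
484 |     opt = torch.optim.SGD(model.parameters(), lr=0.1)
485 |     scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
486 |     for step in range(2 * accumulation_step):
487 |         x, y = torch.randn(16, 8, device=device), torch.randn(16, 1, device=device)
488 |         with torch.cuda.amp.autocast(enabled=use_amp):
489 |             loss = F.mse_loss(model(x), y)
490 |         scaler.scale(loss).backward()  # gradients accumulate across mini-batches
491 |         if (step + 1) % accumulation_step == 0:
492 |             scaler.step(opt)  # one optimizer update per accumulation_step batches
493 |             scaler.update()
494 |             opt.zero_grad()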
--------------------------------------------------------------------------------