├── README.md ├── facial_landmarks ├── 4.jpg ├── check_weights.py ├── facial_lm_model.py ├── inference.py ├── inference_folder.py ├── model_weights │ ├── face_landmark.tflite │ └── facial_landmarks.pth ├── test_images │ └── m@35189.2e16d0ba.fill-514x514.jpg ├── tflite2pt.py └── utils.py └── iris ├── inference.py ├── iris2.jpg ├── irismodel.py ├── model_weights ├── iris_landmark.tflite └── irislandmarks.pth ├── tfite2pt.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # mediapipe_pytorch 2 | PyTorch implementation of Google's MediaPipe models: the Iris Landmark model and the Face Mesh model. 3 | 4 | 5 | ## Face Mesh Model 6 | The facial_landmarks folder contains the PyTorch implementation of the paper Real-time Facial Surface Geometry from Monocular Video on Mobile GPUs (https://arxiv.org/pdf/1907.06724.pdf). 7 | * For inference 8 | ```bash 9 | cd facial_landmarks 10 | python inference.py 11 | ``` 12 | 13 | ## Iris Landmark Model 14 | The iris folder contains the PyTorch implementation of the paper Real-time Pupil Tracking from Monocular Video for Digital Puppetry (https://arxiv.org/pdf/2006.11341). 15 | * For inference 16 | ```bash 17 | cd iris 18 | python inference.py 19 | ``` 20 | 21 | ## Conversion Issues 22 | * TFLite uses slightly different padding than PyTorch. 23 | * Instead of using the padding parameter of the conv layer, the padding is applied manually before the convolution. 24 | * Change the conv layer's padding value to 0. 25 | * Misleading results: 26 | * nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1, bias=True) 27 | * Correction: 28 | * nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=0, bias=True) 29 | * x = nn.ReflectionPad2d((1, 0, 1, 0))(x) # Apply padding before convolution. 30 | 31 | -------------------------------------------------------------------------------- /facial_landmarks/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiqq111/mediapipe_pytorch/4eb026c8fdcdceda28cb7d96bec64c150577f289/facial_landmarks/4.jpg -------------------------------------------------------------------------------- /facial_landmarks/check_weights.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import tensorflow as tf 3 | import torch 4 | import cv2 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from facial_lm_model import FacialLM_Model 8 | from tflite import Model 9 | 10 | data = open("model_weights/face_landmark.tflite", "rb").read() 11 | model = Model.GetRootAsModel(data, 0) 12 | 13 | tflite_graph = model.Subgraphs(0) 14 | tflite_graph.Name() 15 | 16 | # Tensor name to index mapping 17 | tflite_tensor_dict = {} 18 | for i in range(tflite_graph.TensorsLength()): 19 | tflite_tensor_dict[tflite_graph.Tensors(i).Name().decode("utf8")] = i 20 | 21 | 22 | parameters = {} 23 | for i in range(tflite_graph.TensorsLength()): 24 | tensor = tflite_graph.Tensors(i) 25 | if tensor.Buffer() > 0: 26 | name = tensor.Name().decode("utf8") 27 | parameters[name] = tensor.Buffer() 28 | else: 29 | # Tensors with buffer index 0 do not hold weights 30 | print(tensor.Name().decode("utf8")) 31 | 32 | print("Total parameters: ", len(parameters)) 33 | 34 | 35 | def get_weights(tensor_name): 36 | index = tflite_tensor_dict[tensor_name] 37 | tensor = tflite_graph.Tensors(index) 38 | 39 | buffer = tensor.Buffer() 40 | shape = [tensor.Shape(i) for i in range(tensor.ShapeLength())] 41 | 42 | weights = 
model.Buffers(buffer).DataAsNumpy() 43 | weights = weights.view(dtype=np.float32) 44 | weights = weights.reshape(shape) 45 | return weights 46 | 47 | net = FacialLM_Model() 48 | weights = torch.load('facial_landmarks.pth') 49 | net.load_state_dict(weights) 50 | net = net.eval() 51 | 52 | # net(torch.randn(2,3,64,64))[0].shape 53 | 54 | probable_names = [] 55 | for i in range(0, tflite_graph.TensorsLength()): 56 | tensor = tflite_graph.Tensors(i) 57 | if tensor.Buffer() > 0 and tensor.Type() == 0: 58 | probable_names.append(tensor.Name().decode("utf-8")) 59 | 60 | pt2tflite_keys = {} 61 | i = 0 62 | for name, params in net.state_dict().items(): 63 | print(name) 64 | if i < 83: 65 | pt2tflite_keys[name] = probable_names[i] 66 | i += 1 67 | 68 | matched_keys = { 69 | 'confidence.2.depthwiseconv_conv.0.weight': 'depthwise_conv2d_16/Kernel', 70 | 'confidence.2.depthwiseconv_conv.0.bias' : 'depthwise_conv2d_16/Bias', 71 | 'confidence.2.depthwiseconv_conv.1.weight': 'conv2d_17/Kernel', 72 | 'confidence.2.depthwiseconv_conv.1.bias': 'conv2d_17/Bias', 73 | 'confidence.2.prelu.weight': 'p_re_lu_17/Alpha', 74 | 75 | 'confidence.3.weight': 'conv2d_18/Kernel', 76 | 'confidence.3.bias': 'conv2d_18/Bias', 77 | 'confidence.4.weight': 'p_re_lu_18/Alpha', 78 | 79 | 'confidence.5.depthwiseconv_conv.0.weight': 'depthwise_conv2d_17/Kernel', 80 | 'confidence.5.depthwiseconv_conv.0.bias' : 'depthwise_conv2d_17/Bias', 81 | 'confidence.5.depthwiseconv_conv.1.weight': 'conv2d_19/Kernel', 82 | 'confidence.5.depthwiseconv_conv.1.bias': 'conv2d_19/Bias', 83 | 'confidence.5.prelu.weight': 'p_re_lu_19/Alpha', 84 | 85 | 86 | 87 | 'confidence.6.weight': 'conv2d_20/Kernel', 88 | 'confidence.6.bias': 'conv2d_20/Bias', 89 | 90 | 'facial_landmarks.0.depthwiseconv_conv.0.weight': 'depthwise_conv2d_22/Kernel', 91 | 'facial_landmarks.0.depthwiseconv_conv.0.bias': 'depthwise_conv2d_22/Bias', 92 | 'facial_landmarks.0.depthwiseconv_conv.1.weight': 'conv2d_27/Kernel', 93 | 'facial_landmarks.0.depthwiseconv_conv.1.bias': 'conv2d_27/Bias', 94 | 'facial_landmarks.0.prelu.weight': 'p_re_lu_25/Alpha', 95 | 96 | 97 | 'facial_landmarks.1.weight': 'conv2d_28/Kernel', 98 | 'facial_landmarks.1.bias': 'conv2d_28/Bias', 99 | 'facial_landmarks.2.weight': 'p_re_lu_26/Alpha', 100 | 101 | 'facial_landmarks.3.depthwiseconv_conv.0.weight': 'depthwise_conv2d_23/Kernel', 102 | 'facial_landmarks.3.depthwiseconv_conv.0.bias': 'depthwise_conv2d_23/Bias', 103 | 'facial_landmarks.3.depthwiseconv_conv.1.weight': 'conv2d_29/Kernel', 104 | 'facial_landmarks.3.depthwiseconv_conv.1.bias': 'conv2d_29/Bias', 105 | 'facial_landmarks.3.prelu.weight': 'p_re_lu_27/Alpha', 106 | 107 | 'facial_landmarks.4.weight': 'conv2d_30/Kernel', 108 | 'facial_landmarks.4.bias': 'conv2d_30/Bias', 109 | 110 | } 111 | 112 | pt2tflite_keys.update(matched_keys) 113 | 114 | for key, value in pt2tflite_keys.items(): 115 | # print(key, value) 116 | tflite_ = parameters[value] 117 | W = get_weights(value) 118 | if W.ndim == 4: 119 | if 'depthwise' in value: 120 | # (1, 3, 3, 32) --> (32, 1, 3, 3) 121 | # for depthwise conv 122 | W = W.transpose((3, 0, 1, 2)) 123 | else: 124 | W = W.transpose((0, 3, 1, 2)) 125 | elif W.ndim == 3: 126 | # prelu 127 | W = W.reshape(-1) 128 | tflite_ = W 129 | 130 | torch_ = net.state_dict()[key] 131 | # print(key, value, tflite_.shape, torch_.shape) 132 | np.testing.assert_array_almost_equal(torch_.cpu().detach().numpy(), tflite_, decimal=3) 133 | print("matching ::", np.allclose(torch_.cpu().detach().numpy(), tflite_, atol=1e-03)) 
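# --- Optional end-to-end sanity check (a hedged sketch, not part of the original script) ---
# Matching the raw weight tensors does not exercise the padding changes described in the
# README, so one random input can additionally be pushed through both backends. This assumes
# the TFLite interpreter API available via the `tensorflow` import above and that the first
# TFLite output is the landmark tensor, as in inference.py.
# interpreter = tf.lite.Interpreter("model_weights/face_landmark.tflite")
# interpreter.allocate_tensors()
# inp = np.random.uniform(-1.0, 1.0, size=(1, 192, 192, 3)).astype(np.float32)
# interpreter.set_tensor(interpreter.get_input_details()[0]['index'], inp)
# interpreter.invoke()
# tflite_out = interpreter.get_tensor(interpreter.get_output_details()[0]['index']).reshape(-1)
# torch_out = net(torch.from_numpy(inp).permute(0, 3, 1, 2))[0].detach().numpy().reshape(-1)
# print("outputs close ::", np.allclose(torch_out, tflite_out, atol=1e-02))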
-------------------------------------------------------------------------------- /facial_landmarks/facial_lm_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils import FacialLMBasicBlock 6 | 7 | 8 | class FacialLM_Model(nn.Module): 9 | """[MediaPipe facial_landmark model in PyTorch] 10 | 11 | Args: 12 | nn ([type]): [description] 13 | 14 | Returns: 15 | [type]: [description] 16 | """ 17 | # 1x1x1x1404, 1x1x1x1 18 | # 1x1404x1x1 19 | 20 | def __init__(self): 21 | """[summary]""" 22 | super(FacialLM_Model, self).__init__() 23 | 24 | self.backbone = nn.Sequential( 25 | 26 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=0, bias=True), 27 | nn.PReLU(16), 28 | 29 | FacialLMBasicBlock(16, 16), 30 | FacialLMBasicBlock(16, 16), 31 | FacialLMBasicBlock(16, 32, stride=2), # pad 32 | 33 | FacialLMBasicBlock(32, 32), 34 | FacialLMBasicBlock(32, 32), 35 | FacialLMBasicBlock(32, 64, stride=2), 36 | 37 | FacialLMBasicBlock(64, 64), 38 | FacialLMBasicBlock(64, 64), 39 | FacialLMBasicBlock(64, 128, stride=2), 40 | 41 | FacialLMBasicBlock(128, 128), 42 | FacialLMBasicBlock(128, 128), 43 | FacialLMBasicBlock(128, 128, stride=2), 44 | 45 | FacialLMBasicBlock(128, 128), 46 | FacialLMBasicBlock(128, 128) 47 | ) 48 | 49 | # facial_landmark head 50 | # @TODO change name from self.confidence to self.facial_landmarks 51 | self.confidence = nn.Sequential( 52 | FacialLMBasicBlock(128, 128, stride=2), 53 | FacialLMBasicBlock(128, 128), 54 | FacialLMBasicBlock(128, 128), 55 | # ---- 56 | nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1, stride=1, padding=0, bias=True), 57 | nn.PReLU(32), 58 | 59 | FacialLMBasicBlock(32, 32), 60 | nn.Conv2d(in_channels=32, out_channels=1404, kernel_size=3, stride=3, padding=0, bias=True) 61 | ) 62 | 63 | # confidence score head 64 | # @TODO change name from self.facial_landmarks to self.confidence 65 | self.facial_landmarks = nn.Sequential( 66 | FacialLMBasicBlock(128, 128, stride=2), 67 | nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1, stride=1, padding=0, bias=True), 68 | nn.PReLU(32), 69 | 70 | FacialLMBasicBlock(32, 32), 71 | nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=3, padding=0, bias=True) 72 | ) 73 | 74 | 75 | @torch.no_grad() 76 | def forward(self, x): 77 | """ forward prop 78 | 79 | Args: 80 | x ([torch.Tensor]): [input Tensor] 81 | 82 | Returns: 83 | [list]: [facial_landmarks, confidence] 84 | facial_landmarks: 1 x 1404 x 1 x 1 85 | (468 x 3) 86 | (x, y, z) 87 | (x, y) corresponds to image pixel locations 88 | confidence: 1 x 1 x 1 x 1 89 | 468 face landmarks 90 | """ 91 | 92 | # @TODO remove 93 | with torch.no_grad(): 94 | x = nn.ReflectionPad2d((1, 0, 1, 0))(x) 95 | features = self.backbone(x) 96 | 97 | # @TODO change the names 98 | confidence = self.facial_landmarks(features) 99 | 100 | facial_landmarks = self.confidence(features) 101 | 102 | return [facial_landmarks, confidence] 103 | # return [facial_landmarks.view(x.shape[0], -1), confidence.reshape(x.shape[0], -1)] 104 | 105 | 106 | def predict(self, img): 107 | """ single image inference 108 | 109 | Args: 110 | img ([type]): [description] 111 | 112 | Returns: 113 | [type]: [description] 114 | """ 115 | if isinstance(img, np.ndarray): 116 | img = torch.from_numpy(img).permute((2, 0, 1)) 117 | 118 | return self.batch_predict(img.unsqueeze(0)) 119 | 120 | 121 | def batch_predict(self, x): 122 | """ 
batch inference 123 | currently only single image inference is supported 124 | 125 | Args: 126 | x ([type]): [description] 127 | 128 | Returns: 129 | [type]: [description] 130 | """ 131 | if isinstance(x, np.ndarray): 132 | x = torch.from_numpy(x).permute((0, 3, 1, 2)) 133 | 134 | facial_landmarks, confidence = self.forward(x) 135 | return facial_landmarks, confidence 136 | # return facial_landmarks.view(x.shape[0], -1), confidence.view(x.shape[0], -1) 137 | 138 | 139 | def test(self): 140 | """ Sample Inference""" 141 | inp = torch.randn(1, 3, 192, 192) 142 | output = self(inp) 143 | print(output[0].shape, output[1].shape) 144 | 145 | 146 | # m = FacialLM_Model() 147 | # m.test() 148 | """ 149 | m = FacialLM_Model() 150 | inp = torch.randn(1, 3, 192, 192) 151 | output = m(inp) 152 | print(output[0].shape, output[1].shape) 153 | """ -------------------------------------------------------------------------------- /facial_landmarks/inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import tensorflow as tf 3 | import torch 4 | import cv2 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from facial_lm_model import FacialLM_Model 8 | from utils import pad_image 9 | 10 | 11 | class FaceMesh: 12 | """ 13 | mediapipe face mesh model inference in pytorch and tflite 14 | """ 15 | def __init__(self, model_path=None): 16 | """[summary] 17 | 18 | Args: 19 | model_path ([type], optional): [description]. Defaults to None. 20 | """ 21 | # @TODO change model_path 22 | # tflite model 23 | self.interpreter = tf.lite.Interpreter("model_weights/face_landmark.tflite") 24 | 25 | # pytorch model 26 | self.torch_model = FacialLM_Model() 27 | weights = torch.load('model_weights/facial_landmarks.pth') 28 | self.torch_model.load_state_dict(weights) 29 | self.torch_model = self.torch_model.eval() 30 | 31 | 32 | def __call__(self, img_path): 33 | """[summary] 34 | 35 | Args: 36 | img_path ([str]): [image path] 37 | 38 | Returns: 39 | [list]: [face landmarks and confidence] 40 | """ 41 | self.input_details = self.interpreter.get_input_details() 42 | self.output_details = self.interpreter.get_output_details() 43 | 44 | self.interpreter.allocate_tensors() 45 | blob = cv2.imread(img_path).astype(np.float32) 46 | blob = cv2.cvtColor(blob, cv2.COLOR_BGR2RGB) 47 | 48 | blob = pad_image(blob, desired_size=192) 49 | 50 | # -1 to 1 norm 51 | # blob /= 255 # x.float() / 127.5 - 1.0 52 | # @TODO /255 works better for few images 53 | blob = (blob / 127.5) - 1.0 54 | # blob = (blob - 0.5) / 0.5 55 | # blob = blob / 127.5 56 | # blob = (blob - 128) / 255.0 57 | 58 | facial_landmarks_torch, confidence_torch = self.torch_model.predict(blob) 59 | 60 | blob = np.expand_dims(blob, axis=0) 61 | 62 | self.interpreter.set_tensor(self.input_details[0]['index'], blob) 63 | 64 | self.interpreter.invoke() 65 | 66 | facial_landmarks = self.interpreter.get_tensor(self.output_details[0]['index']) 67 | # confidence = self.interpreter.get_tensor(self.output_details[1]['index']) 68 | 69 | # np.testing.assert_array_almost_equal(facial_landmarks_torch.cpu().detach().numpy(), facial_landmarks, decimal=3) 70 | print("TFLite and torch values are matching ::", np.allclose(facial_landmarks_torch.cpu().detach().numpy(), facial_landmarks, atol=1e-02)) 71 | return facial_landmarks_torch, confidence_torch 72 | 73 | 74 | m = FaceMesh() 75 | img_path = '4.jpg' 76 | facial_landmarks_torch, confidence_torch = m(img_path) 77 | 78 | im = cv2.imread(img_path) 79 | im = pad_image(im, 
desired_size=192) 80 | 81 | facial_landmarks_ = facial_landmarks_torch.reshape(-1) 82 | np.save('output', facial_landmarks_) 83 | for idx in range(468): 84 | cv2.circle(im, (int(facial_landmarks_[idx*3]), int(facial_landmarks_[idx*3 + 1])), 1, (200, 160, 75), -1) 85 | 86 | # cv2.imwrite('output.jpg', im) 87 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 88 | plt.figure(figsize=(10,10)) 89 | plt.imshow(im) 90 | plt.show() -------------------------------------------------------------------------------- /facial_landmarks/inference_folder.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from facial_lm_model import FacialLM_Model 5 | from utils import pad_image 6 | import glob, os 7 | from tqdm import tqdm 8 | 9 | 10 | class FaceMesh: 11 | """ 12 | mediapipe face mesh inefernce in pytorch 13 | """ 14 | def __init__(self, model_path='model_weights/facial_landmarks.pth'): 15 | """[summary] 16 | 17 | Args: 18 | model_path (str, optional): [description]. 19 | Defaults to 'model_weights/facial_landmarks.pth'. 20 | """ 21 | self.torch_model = FacialLM_Model() 22 | weights = torch.load(model_path) 23 | self.torch_model.load_state_dict(weights) 24 | self.torch_model = self.torch_model.eval() 25 | 26 | 27 | def __call__(self, img_path): 28 | """[summary] 29 | 30 | Args: 31 | img_path ([str]): [image path] 32 | 33 | Returns: 34 | [list]: [face landmarks and confidence] 35 | """ 36 | blob = cv2.imread(img_path).astype(np.float32) 37 | blob = cv2.cvtColor(blob, cv2.COLOR_BGR2RGB) 38 | 39 | blob = pad_image(blob, desired_size=192) 40 | 41 | # -1 to 1 norm 42 | blob = (blob/127.5) - 1.0 43 | 44 | facial_landmarks_torch, confidence_torch = self.torch_model.predict(blob) 45 | 46 | return facial_landmarks_torch, confidence_torch 47 | 48 | 49 | model = FaceMesh() 50 | img_paths = glob.glob('test_images/*') 51 | output_dir = 'outputs' 52 | os.makedirs(output_dir, exist_ok=True) 53 | 54 | for img_path in tqdm(img_paths): 55 | facial_landmarks_torch, confidence_torch = model(img_path) 56 | 57 | im = cv2.imread(img_path) 58 | im = pad_image(im) 59 | 60 | facial_landmarks_ = facial_landmarks_torch.reshape(-1) 61 | 62 | for idx in range(468): 63 | cv2.circle(im, (int(facial_landmarks_[idx*3]), int(facial_landmarks_[idx*3 + 1])), 1, (0, 0, 255), -1) 64 | 65 | # im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 66 | filename = os.path.join(output_dir, os.path.basename(img_path)) 67 | cv2.imwrite(filename, im) 68 | # import matplotlib.pyplot as plt 69 | # plt.imshow(im) 70 | # plt.show() -------------------------------------------------------------------------------- /facial_landmarks/model_weights/face_landmark.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiqq111/mediapipe_pytorch/4eb026c8fdcdceda28cb7d96bec64c150577f289/facial_landmarks/model_weights/face_landmark.tflite -------------------------------------------------------------------------------- /facial_landmarks/model_weights/facial_landmarks.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiqq111/mediapipe_pytorch/4eb026c8fdcdceda28cb7d96bec64c150577f289/facial_landmarks/model_weights/facial_landmarks.pth -------------------------------------------------------------------------------- /facial_landmarks/test_images/m@35189.2e16d0ba.fill-514x514.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tiqq111/mediapipe_pytorch/4eb026c8fdcdceda28cb7d96bec64c150577f289/facial_landmarks/test_images/m@35189.2e16d0ba.fill-514x514.jpg -------------------------------------------------------------------------------- /facial_landmarks/tflite2pt.py: -------------------------------------------------------------------------------- 1 | from tflite import Model 2 | import numpy as np 3 | from collections import OrderedDict 4 | from facial_lm_model import FacialLM_Model 5 | from utils import GetKeysDict 6 | import torch 7 | 8 | data = open("model_weights/face_landmark.tflite", "rb").read() 9 | model = Model.GetRootAsModel(data, 0) 10 | 11 | tflite_graph = model.Subgraphs(0) 12 | tflite_graph.Name() 13 | 14 | # Tensor name to index mapping 15 | tflite_tensor_dict = {} 16 | for i in range(tflite_graph.TensorsLength()): 17 | tflite_tensor_dict[tflite_graph.Tensors(i).Name().decode("utf8")] = i 18 | 19 | def get_weights(tensor_name): 20 | index = tflite_tensor_dict[tensor_name] 21 | tensor = tflite_graph.Tensors(index) 22 | 23 | buffer = tensor.Buffer() 24 | shape = [tensor.Shape(i) for i in range(tensor.ShapeLength())] 25 | 26 | weights = model.Buffers(buffer).DataAsNumpy() 27 | weights = weights.view(dtype=np.float32) 28 | weights = weights.reshape(shape) 29 | return weights 30 | 31 | 32 | # Store the weights in dict 33 | parameters = {} 34 | for i in range(tflite_graph.TensorsLength()): 35 | tensor = tflite_graph.Tensors(i) 36 | if tensor.Buffer() > 0: 37 | name = tensor.Name().decode("utf8") 38 | parameters[name] = tensor.Buffer() 39 | else: 40 | # Buffer value less than zero are not weights 41 | print(tensor.Name().decode("utf8")) 42 | 43 | print("Total parameters: ", len(parameters)) 44 | 45 | pt_model = FacialLM_Model() 46 | # pt_model(torch.randn(2,3,64,64))[0].shape 47 | 48 | probable_names = [] 49 | for i in range(0, tflite_graph.TensorsLength()): 50 | tensor = tflite_graph.Tensors(i) 51 | if tensor.Buffer() > 0 and tensor.Type() == 0: 52 | probable_names.append(tensor.Name().decode("utf-8")) 53 | 54 | pt2tflite_keys = {} 55 | i = 0 56 | for name, params in pt_model.state_dict().items(): 57 | # first 83 nodes names are perfectly matched 58 | if i < 83: 59 | pt2tflite_keys[name] = probable_names[i] 60 | i += 1 61 | 62 | # Remaining nodes 63 | matched_keys = GetKeysDict().facial_landmark_dict 64 | 65 | # update the remaining keys 66 | pt2tflite_keys.update(matched_keys) 67 | 68 | new_state_dict = OrderedDict() 69 | 70 | for pt_key, tflite_key in pt2tflite_keys.items(): 71 | weight = get_weights(tflite_key) 72 | print(pt_key, tflite_key, weight.shape, pt_model.state_dict()[pt_key].shape) 73 | 74 | # if pt_key == 'facial_landmarks.4.weight': 75 | # weight = weight.transpose((0, 3, 1, 2)) 76 | # weight = weight.transpose((0, 3, 2, 1)) 77 | # print(weight.shape) 78 | # print(weight) 79 | # else: 80 | if weight.ndim == 4: 81 | if 'depthwise' in tflite_key: 82 | # (1, 3, 3, 32) --> (32, 1, 3, 3) 83 | # for depthwise conv 84 | weight = weight.transpose((3, 0, 1, 2)) 85 | else: 86 | weight = weight.transpose((0, 3, 1, 2)) 87 | 88 | if 'p_re_lu' in tflite_key: 89 | weight = weight.reshape(-1) 90 | 91 | new_state_dict[pt_key] = torch.from_numpy(weight) 92 | 93 | pt_model.load_state_dict(new_state_dict, strict=True) 94 | 95 | torch.save(pt_model.state_dict(), "facial_landmarks.pth") -------------------------------------------------------------------------------- /facial_landmarks/utils.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import torch.nn as nn 3 | import cv2 4 | import torch.nn.functional as F 5 | 6 | 7 | class FacialLMBasicBlock(nn.Module): 8 | """ Building block for mediapipe facial landmark model 9 | 10 | DepthwiseConv + Conv + PRelu 11 | downsampling + channel padding for some blocks (when stride=2) 12 | channel padding values - 16, 32, 64 13 | 14 | Args: 15 | nn ([type]): [description] 16 | """ 17 | 18 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1): 19 | super(FacialLMBasicBlock, self).__init__() 20 | 21 | self.stride = stride 22 | self.channel_pad = out_channels - in_channels 23 | 24 | # TFLite uses slightly different padding than PyTorch 25 | # on the depthwise conv layer when the stride is 2. 26 | if stride == 2: 27 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 28 | padding = 0 29 | else: 30 | padding = (kernel_size - 1) // 2 31 | 32 | self.depthwiseconv_conv = nn.Sequential( 33 | nn.Conv2d(in_channels=in_channels, out_channels=in_channels, 34 | kernel_size=kernel_size, stride=stride, padding=padding, 35 | groups=in_channels, bias=True), 36 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 37 | kernel_size=1, stride=1, padding=0, bias=True), 38 | ) 39 | 40 | self.prelu = nn.PReLU(out_channels) 41 | 42 | 43 | def forward(self, x): 44 | """[summary] 45 | 46 | Args: 47 | x ([torch.Tensor]): [input tensor] 48 | 49 | Returns: 50 | [torch.Tensor]: [features] 51 | """ 52 | 53 | if self.stride == 2: 54 | h = F.pad(x, (0, 2, 0, 2), "constant", 0) 55 | x = self.max_pool(x) 56 | else: 57 | h = x 58 | 59 | if self.channel_pad > 0: 60 | x = F.pad(x, (0, 0, 0, 0, 0, self.channel_pad), "constant", 0) 61 | 62 | return self.prelu(self.depthwiseconv_conv(h) + x) 63 | 64 | 65 | def pad_image(im, desired_size=192): 66 | """Resize while preserving aspect ratio, then pad to a square of desired_size. 67 | 68 | Args: 69 | im ([cv2 image]): [input image] 70 | desired_size (int, optional): [target side length]. Defaults to 192. 
71 | 72 | Returns: 73 | [cv2 image]: [resized image] 74 | """ 75 | old_size = im.shape[:2] # old_size is in (height, width) format 76 | 77 | ratio = float(desired_size)/max(old_size) 78 | new_size = tuple([int(x*ratio) for x in old_size]) 79 | 80 | # new_size should be in (width, height) format 81 | 82 | im = cv2.resize(im, (new_size[1], new_size[0])) 83 | 84 | delta_w = desired_size - new_size[1] 85 | delta_h = desired_size - new_size[0] 86 | top, bottom = delta_h//2, delta_h-(delta_h//2) 87 | left, right = delta_w//2, delta_w-(delta_w//2) 88 | 89 | color = [0, 0, 0] 90 | new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, 91 | value=color) 92 | 93 | return new_im 94 | 95 | 96 | class GetKeysDict: 97 | """ 98 | maps pytorch keys to tflite keys 99 | """ 100 | def __init__(self): 101 | self.facial_landmark_dict = { 102 | 'confidence.2.depthwiseconv_conv.0.weight': 'depthwise_conv2d_16/Kernel', 103 | 'confidence.2.depthwiseconv_conv.0.bias' : 'depthwise_conv2d_16/Bias', 104 | 'confidence.2.depthwiseconv_conv.1.weight': 'conv2d_17/Kernel', 105 | 'confidence.2.depthwiseconv_conv.1.bias': 'conv2d_17/Bias', 106 | 'confidence.2.prelu.weight': 'p_re_lu_17/Alpha', 107 | 108 | 'confidence.3.weight': 'conv2d_18/Kernel', 109 | 'confidence.3.bias': 'conv2d_18/Bias', 110 | 'confidence.4.weight': 'p_re_lu_18/Alpha', 111 | 112 | 'confidence.5.depthwiseconv_conv.0.weight': 'depthwise_conv2d_17/Kernel', 113 | 'confidence.5.depthwiseconv_conv.0.bias' : 'depthwise_conv2d_17/Bias', 114 | 'confidence.5.depthwiseconv_conv.1.weight': 'conv2d_19/Kernel', 115 | 'confidence.5.depthwiseconv_conv.1.bias': 'conv2d_19/Bias', 116 | 'confidence.5.prelu.weight': 'p_re_lu_19/Alpha', 117 | 118 | 119 | 120 | 'confidence.6.weight': 'conv2d_20/Kernel', 121 | 'confidence.6.bias': 'conv2d_20/Bias', 122 | 123 | 'facial_landmarks.0.depthwiseconv_conv.0.weight': 'depthwise_conv2d_22/Kernel', 124 | 'facial_landmarks.0.depthwiseconv_conv.0.bias': 'depthwise_conv2d_22/Bias', 125 | 'facial_landmarks.0.depthwiseconv_conv.1.weight': 'conv2d_27/Kernel', 126 | 'facial_landmarks.0.depthwiseconv_conv.1.bias': 'conv2d_27/Bias', 127 | 'facial_landmarks.0.prelu.weight': 'p_re_lu_25/Alpha', 128 | 129 | 130 | 'facial_landmarks.1.weight': 'conv2d_28/Kernel', 131 | 'facial_landmarks.1.bias': 'conv2d_28/Bias', 132 | 'facial_landmarks.2.weight': 'p_re_lu_26/Alpha', 133 | 134 | 'facial_landmarks.3.depthwiseconv_conv.0.weight': 'depthwise_conv2d_23/Kernel', 135 | 'facial_landmarks.3.depthwiseconv_conv.0.bias': 'depthwise_conv2d_23/Bias', 136 | 'facial_landmarks.3.depthwiseconv_conv.1.weight': 'conv2d_29/Kernel', 137 | 'facial_landmarks.3.depthwiseconv_conv.1.bias': 'conv2d_29/Bias', 138 | 'facial_landmarks.3.prelu.weight': 'p_re_lu_27/Alpha', 139 | 140 | 'facial_landmarks.4.weight': 'conv2d_30/Kernel', 141 | 'facial_landmarks.4.bias': 'conv2d_30/Bias', 142 | 143 | } -------------------------------------------------------------------------------- /iris/inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import tensorflow as tf 3 | import torch 4 | import cv2 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from irismodel import IrisLM 8 | 9 | 10 | def plot(img_path, iris, eye_contour): 11 | im = cv2.imread(img_path) 12 | im = pad_image(im, desired_size=64) 13 | 14 | lm = iris[0] 15 | h, w, _ = im.shape 16 | 17 | cv2.circle(im, (int(lm[0]), int(lm[1])), 2, (0, 255, 0), -1) 18 | cv2.circle(im, (int(lm[3]), int(lm[4])), 1, (255, 0, 255), -1) 19 | 
cv2.circle(im, (int(lm[6]), int(lm[7])), 2, (255, 0, 255), -1) 20 | cv2.circle(im, (int(lm[9] ), int(lm[10])), 1, (255, 0, 255), -1) 21 | cv2.circle(im, (int(lm[12] ), int(lm[13])), 1, (255, 0, 255), -1) 22 | 23 | eye_contour 24 | for idx in range(71): 25 | cv2.circle(im, (int(eye_contour[0][idx*3]), int(eye_contour[0][idx*3 + 1])), 1, (0, 0, 255), -1) 26 | 27 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 28 | plt.figure(figsize=(10,10)) 29 | plt.imshow(im) 30 | plt.show() 31 | 32 | 33 | def pad_image(im, desired_size=64): 34 | 35 | old_size = im.shape[:2] # old_size is in (height, width) format 36 | 37 | ratio = float(desired_size)/max(old_size) 38 | new_size = tuple([int(x*ratio) for x in old_size]) 39 | 40 | # new_size should be in (width, height) format 41 | 42 | im = cv2.resize(im, (new_size[1], new_size[0])) 43 | 44 | delta_w = desired_size - new_size[1] 45 | delta_h = desired_size - new_size[0] 46 | top, bottom = delta_h//2, delta_h-(delta_h//2) 47 | left, right = delta_w//2, delta_w-(delta_w//2) 48 | 49 | color = [0, 0, 0] 50 | new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, 51 | value=color) 52 | 53 | new_im.shape 54 | return new_im 55 | 56 | 57 | class Model: 58 | def __init__(self): 59 | """[summary] 60 | 61 | Args: 62 | model_path ([str]): [path] 63 | """ 64 | # @TODO change path 65 | # add model paths for both pt and tflite models 66 | self.interpreter = tf.lite.Interpreter(model_path="model_weights/iris_landmark.tflite") # Model Loading 67 | self.net = IrisLM() 68 | weights = torch.load('model_weights/irislandmarks.pth') 69 | self.net.load_state_dict(weights) 70 | self.net = self.net.eval() 71 | 72 | def __call__(self, img_path): 73 | self.input_details = self.interpreter.get_input_details() 74 | self.output_details = self.interpreter.get_output_details() 75 | print(self.input_details) 76 | print(self.output_details) 77 | 78 | self.interpreter.allocate_tensors() 79 | blob = cv2.imread(img_path).astype(np.float32) 80 | blob = pad_image(blob, desired_size=64) 81 | 82 | blob /= 255 # x.float() / 127.5 - 1.0 83 | 84 | eye_contour_torch, iris_torch = self.net.predict(blob) 85 | 86 | blob = np.expand_dims(blob, axis=0) 87 | 88 | self.interpreter.set_tensor(self.input_details[0]['index'], blob) 89 | 90 | self.interpreter.invoke() 91 | 92 | eye_contour = self.interpreter.get_tensor(self.output_details[0]['index']) 93 | iris = self.interpreter.get_tensor(self.output_details[1]['index']) 94 | np.testing.assert_array_almost_equal(eye_contour_torch.cpu().detach().numpy(), eye_contour, decimal=3) 95 | print("Are tflite and torch values matching? 
::", np.allclose(eye_contour_torch.cpu().detach().numpy(), eye_contour, atol=1e-03)) 96 | return eye_contour, iris 97 | 98 | m = Model() 99 | img_path = 'iris2.jpg' 100 | eye_contour, iris = m(img_path) 101 | plot(img_path, iris, eye_contour) -------------------------------------------------------------------------------- /iris/iris2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiqq111/mediapipe_pytorch/4eb026c8fdcdceda28cb7d96bec64c150577f289/iris/iris2.jpg -------------------------------------------------------------------------------- /iris/irismodel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils import IrisBlock 6 | 7 | 8 | class IrisLM(nn.Module): 9 | """[summary] 10 | 11 | Args: 12 | nn ([type]): [description] 13 | """ 14 | def __init__(self): 15 | """[summary] 16 | """ 17 | super(IrisLM, self).__init__() 18 | 19 | self.backbone = nn.Sequential( 20 | 21 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=0, bias=True), 22 | nn.PReLU(64), 23 | 24 | IrisBlock(64, 64), IrisBlock(64, 64), 25 | IrisBlock(64, 64), IrisBlock(64, 64), 26 | IrisBlock(64, 128, stride=2), 27 | 28 | IrisBlock(128, 128), IrisBlock(128, 128), 29 | IrisBlock(128, 128), IrisBlock(128, 128), 30 | IrisBlock(128, 128, stride=2) 31 | ) 32 | 33 | # iris_contour head 34 | self.iris_contour = nn.Sequential( 35 | IrisBlock(128, 128), IrisBlock(128, 128), 36 | IrisBlock(128, 128, stride=2), 37 | IrisBlock(128, 128), IrisBlock(128, 128), 38 | IrisBlock(128, 128, stride=2), 39 | IrisBlock(128, 128), IrisBlock(128, 128), 40 | nn.Conv2d(in_channels=128, out_channels=15, kernel_size=2, stride=1, padding=0, bias=True) 41 | ) 42 | 43 | # eye_contour head 44 | self.eye_contour = nn.Sequential( 45 | IrisBlock(128, 128), IrisBlock(128, 128), 46 | IrisBlock(128, 128, stride=2), 47 | IrisBlock(128, 128), IrisBlock(128, 128), 48 | IrisBlock(128, 128, stride=2), 49 | IrisBlock(128, 128), IrisBlock(128, 128), 50 | nn.Conv2d(in_channels=128, out_channels=213, kernel_size=2, stride=1, padding=0, bias=True) 51 | ) 52 | 53 | 54 | @torch.no_grad() 55 | def forward(self, x): 56 | """ forward prop 57 | 58 | Args: 59 | x ([torch.Tensor]): [input Tensor] 60 | 61 | Returns: 62 | [list]: [eye_contour, iris_contour] 63 | eye_contour (batch_size, 213) 64 | (71 points) 65 | (x, y, z) 66 | (x, y) corresponds to image pixel locations 67 | iris_contour (batch_size, 15) 68 | (5, 3) 5 points 69 | """ 70 | with torch.no_grad(): 71 | x = F.pad(x, [0, 1, 0, 1], "constant", 0) 72 | 73 | # (_, 128, 8, 8) 74 | features = self.backbone(x) 75 | 76 | # (_, 213, 1, 1) 77 | eye_contour = self.eye_contour(features) 78 | 79 | # (_, 15, 1, 1) 80 | iris_contour = self.iris_contour(features) 81 | # (batch_size, 213) (batch_size, 15) 82 | return [eye_contour.view(x.shape[0], -1), iris_contour.reshape(x.shape[0], -1)] 83 | 84 | 85 | def predict(self, img): 86 | """ single image inference 87 | 88 | Args: 89 | img ([type]): [description] 90 | 91 | Returns: 92 | [type]: [description] 93 | """ 94 | if isinstance(img, np.ndarray): 95 | img = torch.from_numpy(img).permute((2, 0, 1)) 96 | 97 | return self.batch_predict(img.unsqueeze(0)) 98 | 99 | 100 | def batch_predict(self, x): 101 | """ batch inference 102 | 103 | Args: 104 | x ([type]): [description] 105 | 106 | Returns: 107 | [type]: [description] 108 | """ 109 | if 
isinstance(x, np.ndarray): 110 | x = torch.from_numpy(x).permute((0, 3, 1, 2)) 111 | 112 | eye_contour, iris_contour = self.forward(x) 113 | 114 | return eye_contour.view(x.shape[0], -1), iris_contour.view(x.shape[0], -1) 115 | 116 | 117 | def test(self): 118 | """ Sample Inference""" 119 | inp = torch.randn(1, 3, 64, 64) 120 | output = self(inp) 121 | print(output[0].shape, output[1].shape) -------------------------------------------------------------------------------- /iris/model_weights/iris_landmark.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiqq111/mediapipe_pytorch/4eb026c8fdcdceda28cb7d96bec64c150577f289/iris/model_weights/iris_landmark.tflite -------------------------------------------------------------------------------- /iris/model_weights/irislandmarks.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiqq111/mediapipe_pytorch/4eb026c8fdcdceda28cb7d96bec64c150577f289/iris/model_weights/irislandmarks.pth -------------------------------------------------------------------------------- /iris/tfite2pt.py: -------------------------------------------------------------------------------- 1 | from tflite import Model 2 | import numpy as np 3 | from collections import OrderedDict 4 | from irismodel import IrisLM 5 | import torch 6 | from utils import GetKeysDict 7 | 8 | 9 | data = open("model_weights/iris_landmark.tflite", "rb").read() 10 | model = Model.GetRootAsModel(data, 0) 11 | 12 | tflite_graph = model.Subgraphs(0) 13 | tflite_graph.Name() 14 | 15 | tflite_tensor_dict = {} 16 | for i in range(tflite_graph.TensorsLength()): 17 | tflite_tensor_dict[tflite_graph.Tensors(i).Name().decode("utf8")] = i 18 | 19 | parameters = {} 20 | for i in range(tflite_graph.TensorsLength()): 21 | tensor = tflite_graph.Tensors(i) 22 | if tensor.Buffer() > 0: 23 | name = tensor.Name().decode("utf8") 24 | parameters[name] = tensor.Buffer() 25 | else: 26 | # Buffer value less than zero are not weights 27 | print(tensor.Name().decode("utf8")) 28 | 29 | print("Total parameters: ", len(parameters)) 30 | 31 | 32 | def get_weights(tensor_name): 33 | index = tflite_tensor_dict[tensor_name] 34 | tensor = tflite_graph.Tensors(index) 35 | 36 | buffer = tensor.Buffer() 37 | shape = [tensor.Shape(i) for i in range(tensor.ShapeLength())] 38 | 39 | weights = model.Buffers(buffer).DataAsNumpy() 40 | weights = weights.view(dtype=np.float32) 41 | weights = weights.reshape(shape) 42 | return weights 43 | 44 | 45 | net = IrisLM() 46 | # net(torch.randn(2,3,64,64))[0].shape 47 | 48 | probable_names = [] 49 | for i in range(0, tflite_graph.TensorsLength()): 50 | tensor = tflite_graph.Tensors(i) 51 | if tensor.Buffer() > 0 and tensor.Type() == 0: 52 | probable_names.append(tensor.Name().decode("utf-8")) 53 | 54 | pt2tflite_keys = {} 55 | i = 0 56 | for name, params in net.state_dict().items(): 57 | print(name) 58 | if i < 85: 59 | pt2tflite_keys[name] = probable_names[i] 60 | i += 1 61 | 62 | matched_keys = GetKeysDict().iris_landmark_dict 63 | pt2tflite_keys.update(matched_keys) 64 | 65 | new_state_dict = OrderedDict() 66 | 67 | for pt_key, tflite_key in pt2tflite_keys.items(): 68 | weight = get_weights(tflite_key) 69 | print(pt_key, tflite_key) 70 | 71 | if weight.ndim == 4: 72 | if weight.shape[0] == 1: 73 | weight = weight.transpose((3, 0, 1, 2)) 74 | else: 75 | weight = weight.transpose((0, 3, 1, 2)) 76 | elif weight.ndim == 3: 77 | weight = weight.reshape(-1) 78 | 79 | 
new_state_dict[pt_key] = torch.from_numpy(weight) 80 | 81 | net.load_state_dict(new_state_dict, strict=True) 82 | 83 | torch.save(net.state_dict(), "model_weights/irislandmarks.pth") -------------------------------------------------------------------------------- /iris/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class IrisBlock(nn.Module): 7 | """ Building block for mediapipe iris landmark model 8 | 9 | COnv + PRelu + DepthwiseConv + Conv + PRelu 10 | downsampling + channel padding for few blocks(when stride=2) 11 | channel padding values - 12 | 13 | Args: 14 | nn ([type]): [description] 15 | """ 16 | def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 1): 17 | super(IrisBlock, self).__init__() 18 | 19 | self.stride = stride 20 | self.channel_pad = out_channels - in_channels 21 | 22 | padding = (kernel_size - 1) // 2 23 | 24 | self.conv_prelu = nn.Sequential( 25 | nn.Conv2d(in_channels=in_channels, out_channels=int(out_channels/2), kernel_size=stride, stride=stride, padding=0, bias=True), 26 | nn.PReLU(int(out_channels/2)) 27 | ) 28 | self.depthwiseconv_conv = nn.Sequential( 29 | nn.Conv2d(in_channels=int(out_channels/2), out_channels=int(out_channels/2), 30 | kernel_size=kernel_size, stride=1, padding=padding, 31 | groups=int(out_channels/2), bias=True), 32 | nn.Conv2d(in_channels=int(out_channels/2), out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=True), 33 | ) 34 | 35 | # Downsample 36 | if stride == 2: 37 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 38 | 39 | self.prelu = nn.PReLU(out_channels) 40 | 41 | 42 | @torch.no_grad() 43 | def forward(self, x): 44 | """[summary] 45 | 46 | Args: 47 | x ([torch.Tensor]): [input tensor] 48 | 49 | Returns: 50 | [torch.Tensor]: [featues] 51 | """ 52 | out = self.conv_prelu(x) 53 | out = self.depthwiseconv_conv(out) 54 | 55 | if self.stride == 2: 56 | x = self.max_pool(x) 57 | 58 | if self.channel_pad > 0: 59 | x = F.pad(x, (0, 0, 0, 0, 0, self.channel_pad), "constant", 0) 60 | 61 | return self.prelu(out + x) 62 | 63 | 64 | class GetKeysDict: 65 | def __init__(self): 66 | self.iris_landmark_dict = { 67 | 'eye_contour.0.conv_prelu.0.weight': 'conv2d_21/Kernel', 68 | 'eye_contour.0.conv_prelu.0.bias': 'conv2d_21/Bias', 69 | 70 | 'eye_contour.0.conv_prelu.1.weight': 'p_re_lu_21/Alpha', 71 | 'eye_contour.0.depthwiseconv_conv.0.weight': 'depthwise_conv2d_10/Kernel', 72 | 'eye_contour.0.depthwiseconv_conv.0.bias': 'depthwise_conv2d_10/Bias', 73 | 'eye_contour.0.depthwiseconv_conv.1.weight': 'conv2d_22/Kernel', 74 | 'eye_contour.0.depthwiseconv_conv.1.bias': 'conv2d_22/Bias', 75 | 'eye_contour.0.prelu.weight': 'p_re_lu_22/Alpha', 76 | 'eye_contour.1.conv_prelu.0.weight': 'conv2d_23/Kernel', 77 | 'eye_contour.1.conv_prelu.0.bias': 'conv2d_23/Bias', 78 | 'eye_contour.1.conv_prelu.1.weight': 'p_re_lu_23/Alpha', 79 | 'eye_contour.1.depthwiseconv_conv.0.weight': 'depthwise_conv2d_11/Kernel', 80 | 'eye_contour.1.depthwiseconv_conv.0.bias': 'depthwise_conv2d_11/Bias', 81 | 'eye_contour.1.depthwiseconv_conv.1.weight': 'conv2d_24/Kernel', 82 | 'eye_contour.1.depthwiseconv_conv.1.bias': 'conv2d_24/Bias', 83 | 'eye_contour.1.prelu.weight': 'p_re_lu_24/Alpha', 84 | 'eye_contour.2.conv_prelu.0.weight': 'conv2d_25/Kernel', 85 | 'eye_contour.2.conv_prelu.0.bias': 'conv2d_25/Bias', 86 | 'eye_contour.2.conv_prelu.1.weight': 'p_re_lu_25/Alpha', 87 | 
'eye_contour.2.depthwiseconv_conv.0.weight': 'depthwise_conv2d_12/Kernel', 88 | 'eye_contour.2.depthwiseconv_conv.0.bias': 'depthwise_conv2d_12/Bias', 89 | 'eye_contour.2.depthwiseconv_conv.1.weight': 'conv2d_26/Kernel', 90 | 'eye_contour.2.depthwiseconv_conv.1.bias': 'conv2d_26/Bias', 91 | 'eye_contour.2.prelu.weight': 'p_re_lu_26/Alpha', 92 | 'eye_contour.3.conv_prelu.0.weight': 'conv2d_27/Kernel', 93 | 'eye_contour.3.conv_prelu.0.bias': 'conv2d_27/Bias', 94 | 'eye_contour.3.conv_prelu.1.weight': 'p_re_lu_27/Alpha', 95 | 'eye_contour.3.depthwiseconv_conv.0.weight': 'depthwise_conv2d_13/Kernel', 96 | 'eye_contour.3.depthwiseconv_conv.0.bias': 'depthwise_conv2d_13/Bias', 97 | 'eye_contour.3.depthwiseconv_conv.1.weight': 'conv2d_28/Kernel', 98 | 'eye_contour.3.depthwiseconv_conv.1.bias': 'conv2d_28/Bias', 99 | 'eye_contour.3.prelu.weight': 'p_re_lu_28/Alpha', 100 | 'eye_contour.4.conv_prelu.0.weight': 'conv2d_29/Kernel', 101 | 'eye_contour.4.conv_prelu.0.bias': 'conv2d_29/Bias', 102 | 'eye_contour.4.conv_prelu.1.weight': 'p_re_lu_29/Alpha', 103 | 'eye_contour.4.depthwiseconv_conv.0.weight': 'depthwise_conv2d_14/Kernel', 104 | 'eye_contour.4.depthwiseconv_conv.0.bias': 'depthwise_conv2d_14/Bias', 105 | 'eye_contour.4.depthwiseconv_conv.1.weight': 'conv2d_30/Kernel', 106 | 'eye_contour.4.depthwiseconv_conv.1.bias': 'conv2d_30/Bias', 107 | 'eye_contour.4.prelu.weight': 'p_re_lu_30/Alpha', 108 | 'eye_contour.5.conv_prelu.0.weight': 'conv2d_31/Kernel', 109 | 'eye_contour.5.conv_prelu.0.bias': 'conv2d_31/Bias', 110 | 'eye_contour.5.conv_prelu.1.weight': 'p_re_lu_31/Alpha', 111 | 'eye_contour.5.depthwiseconv_conv.0.weight': 'depthwise_conv2d_15/Kernel', 112 | 'eye_contour.5.depthwiseconv_conv.0.bias': 'depthwise_conv2d_15/Bias', 113 | 'eye_contour.5.depthwiseconv_conv.1.weight': 'conv2d_32/Kernel', 114 | 'eye_contour.5.depthwiseconv_conv.1.bias': 'conv2d_32/Bias', 115 | 'eye_contour.5.prelu.weight': 'p_re_lu_32/Alpha', 116 | 'eye_contour.6.conv_prelu.0.weight': 'conv2d_33/Kernel', 117 | 'eye_contour.6.conv_prelu.0.bias': 'conv2d_33/Bias', 118 | 'eye_contour.6.conv_prelu.1.weight': 'p_re_lu_33/Alpha', 119 | 'eye_contour.6.depthwiseconv_conv.0.weight': 'depthwise_conv2d_16/Kernel', 120 | 'eye_contour.6.depthwiseconv_conv.0.bias': 'depthwise_conv2d_16/Bias', 121 | 'eye_contour.6.depthwiseconv_conv.1.weight': 'conv2d_34/Kernel', 122 | 'eye_contour.6.depthwiseconv_conv.1.bias': 'conv2d_34/Bias', 123 | 'eye_contour.6.prelu.weight': 'p_re_lu_34/Alpha', 124 | 'eye_contour.7.conv_prelu.0.weight': 'conv2d_35/Kernel', 125 | 'eye_contour.7.conv_prelu.0.bias': 'conv2d_35/Bias', 126 | 'eye_contour.7.conv_prelu.1.weight': 'p_re_lu_35/Alpha', 127 | 'eye_contour.7.depthwiseconv_conv.0.weight': 'depthwise_conv2d_17/Kernel', 128 | 'eye_contour.7.depthwiseconv_conv.0.bias': 'depthwise_conv2d_17/Bias', 129 | 'eye_contour.7.depthwiseconv_conv.1.weight': 'conv2d_36/Kernel', 130 | 'eye_contour.7.depthwiseconv_conv.1.bias': 'conv2d_36/Bias', 131 | 'eye_contour.7.prelu.weight': 'p_re_lu_36/Alpha', 132 | 'eye_contour.8.weight': 'conv_eyes_contours_and_brows/Kernel', 133 | 'eye_contour.8.bias': 'conv_eyes_contours_and_brows/Bias', 134 | 135 | 'iris_contour.0.conv_prelu.0.weight': 'conv2d_37/Kernel', 136 | 'iris_contour.0.conv_prelu.0.bias': 'conv2d_37/Bias', 137 | 'iris_contour.0.conv_prelu.1.weight': 'p_re_lu_37/Alpha', 138 | 'iris_contour.0.depthwiseconv_conv.0.weight': 'depthwise_conv2d_18/Kernel', 139 | 'iris_contour.0.depthwiseconv_conv.0.bias': 'depthwise_conv2d_18/Bias', 140 | 
'iris_contour.0.depthwiseconv_conv.1.weight': 'conv2d_38/Kernel', 141 | 'iris_contour.0.depthwiseconv_conv.1.bias': 'conv2d_38/Bias', 142 | 'iris_contour.0.prelu.weight': 'p_re_lu_38/Alpha', 143 | 'iris_contour.1.conv_prelu.0.weight': 'conv2d_39/Kernel', 144 | 'iris_contour.1.conv_prelu.0.bias': 'conv2d_39/Bias', 145 | 'iris_contour.1.conv_prelu.1.weight': 'p_re_lu_39/Alpha', 146 | 'iris_contour.1.depthwiseconv_conv.0.weight': 'depthwise_conv2d_19/Kernel', 147 | 'iris_contour.1.depthwiseconv_conv.0.bias': 'depthwise_conv2d_19/Bias', 148 | 'iris_contour.1.depthwiseconv_conv.1.weight': 'conv2d_40/Kernel', 149 | 'iris_contour.1.depthwiseconv_conv.1.bias': 'conv2d_40/Bias', 150 | 'iris_contour.1.prelu.weight': 'p_re_lu_40/Alpha', 151 | 'iris_contour.2.conv_prelu.0.weight': 'conv2d_41/Kernel', 152 | 'iris_contour.2.conv_prelu.0.bias': 'conv2d_41/Bias', 153 | 'iris_contour.2.conv_prelu.1.weight': 'p_re_lu_41/Alpha', 154 | 'iris_contour.2.depthwiseconv_conv.0.weight': 'depthwise_conv2d_20/Kernel', 155 | 'iris_contour.2.depthwiseconv_conv.0.bias': 'depthwise_conv2d_20/Bias', 156 | 'iris_contour.2.depthwiseconv_conv.1.weight': 'conv2d_42/Kernel', 157 | 'iris_contour.2.depthwiseconv_conv.1.bias': 'conv2d_42/Bias', 158 | 'iris_contour.2.prelu.weight': 'p_re_lu_42/Alpha', 159 | 'iris_contour.3.conv_prelu.0.weight': 'conv2d_43/Kernel', 160 | 'iris_contour.3.conv_prelu.0.bias': 'conv2d_43/Bias', 161 | 'iris_contour.3.conv_prelu.1.weight': 'p_re_lu_43/Alpha', 162 | 'iris_contour.3.depthwiseconv_conv.0.weight': 'depthwise_conv2d_21/Kernel', 163 | 'iris_contour.3.depthwiseconv_conv.0.bias': 'depthwise_conv2d_21/Bias', 164 | 'iris_contour.3.depthwiseconv_conv.1.weight': 'conv2d_44/Kernel', 165 | 'iris_contour.3.depthwiseconv_conv.1.bias': 'conv2d_44/Bias', 166 | 'iris_contour.3.prelu.weight': 'p_re_lu_44/Alpha', 167 | 'iris_contour.4.conv_prelu.0.weight': 'conv2d_45/Kernel', 168 | 'iris_contour.4.conv_prelu.0.bias': 'conv2d_45/Bias', 169 | 'iris_contour.4.conv_prelu.1.weight': 'p_re_lu_45/Alpha', 170 | 'iris_contour.4.depthwiseconv_conv.0.weight': 'depthwise_conv2d_22/Kernel', 171 | 'iris_contour.4.depthwiseconv_conv.0.bias': 'depthwise_conv2d_22/Bias', 172 | 'iris_contour.4.depthwiseconv_conv.1.weight': 'conv2d_46/Kernel', 173 | 'iris_contour.4.depthwiseconv_conv.1.bias': 'conv2d_46/Bias', 174 | 'iris_contour.4.prelu.weight': 'p_re_lu_46/Alpha', 175 | 'iris_contour.5.conv_prelu.0.weight': 'conv2d_47/Kernel', 176 | 'iris_contour.5.conv_prelu.0.bias': 'conv2d_47/Bias', 177 | 'iris_contour.5.conv_prelu.1.weight': 'p_re_lu_47/Alpha', 178 | 'iris_contour.5.depthwiseconv_conv.0.weight': 'depthwise_conv2d_23/Kernel', 179 | 'iris_contour.5.depthwiseconv_conv.0.bias': 'depthwise_conv2d_23/Bias', 180 | 'iris_contour.5.depthwiseconv_conv.1.weight': 'conv2d_48/Kernel', 181 | 'iris_contour.5.depthwiseconv_conv.1.bias': 'conv2d_48/Bias', 182 | 'iris_contour.5.prelu.weight': 'p_re_lu_48/Alpha', 183 | 'iris_contour.6.conv_prelu.0.weight': 'conv2d_49/Kernel', 184 | 'iris_contour.6.conv_prelu.0.bias': 'conv2d_49/Bias', 185 | 'iris_contour.6.conv_prelu.1.weight': 'p_re_lu_49/Alpha', 186 | 'iris_contour.6.depthwiseconv_conv.0.weight': 'depthwise_conv2d_24/Kernel', 187 | 'iris_contour.6.depthwiseconv_conv.0.bias': 'depthwise_conv2d_24/Bias', 188 | 'iris_contour.6.depthwiseconv_conv.1.weight': 'conv2d_50/Kernel', 189 | 'iris_contour.6.depthwiseconv_conv.1.bias': 'conv2d_50/Bias', 190 | 'iris_contour.6.prelu.weight': 'p_re_lu_50/Alpha', 191 | 'iris_contour.7.conv_prelu.0.weight': 'conv2d_51/Kernel', 192 | 
'iris_contour.7.conv_prelu.0.bias': 'conv2d_51/Bias', 193 | 'iris_contour.7.conv_prelu.1.weight': 'p_re_lu_51/Alpha', 194 | 'iris_contour.7.depthwiseconv_conv.0.weight': 'depthwise_conv2d_25/Kernel', 195 | 'iris_contour.7.depthwiseconv_conv.0.bias': 'depthwise_conv2d_25/Bias', 196 | 'iris_contour.7.depthwiseconv_conv.1.weight': 'conv2d_52/Kernel', 197 | 'iris_contour.7.depthwiseconv_conv.1.bias': 'conv2d_52/Bias', 198 | 'iris_contour.7.prelu.weight': 'p_re_lu_52/Alpha', 199 | 'iris_contour.8.weight': 'conv_iris/Kernel', 200 | 'iris_contour.8.bias': 'conv_iris/Bias' 201 | } --------------------------------------------------------------------------------
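# --- Quick shape check for IrisBlock (a hedged sketch, not part of the original repo) ---
# IrisBlock keeps the spatial resolution when stride=1 and halves it when stride=2, zero-padding
# the extra channels on the residual path. Assumes it is run from the iris folder so that
# `from utils import IrisBlock` resolves.
# import torch
# from utils import IrisBlock
# block = IrisBlock(64, 128, stride=2)
# x = torch.randn(1, 64, 32, 32)
# print(block(x).shape)  # expected: torch.Size([1, 128, 16, 16])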