├── README.md ├── SSRNET_model.py ├── SSR_module.swift ├── keras2coreml.py ├── ssrnet.h5 ├── ssrnet.mlmodel └── ssrnet_3_3_3_64_1.0_1.0.h5 /README.md: -------------------------------------------------------------------------------- 1 | # Keras-to-coreml-multiple-inputs-example 2 | 3 | ## Introduction 4 | Because Keras is a convenient framework for building deep learning models, we often use it to develop our own networks. 5 | 6 | However, networks with complex custom layers are not directly supported by "coremltools", and cannot easily be converted to a coreml model (an iOS-friendly format). 7 | 8 | + In this project, I share a way to rewrite a ***multiple-input custom Lambda layer in Keras*** as a Swift function so that the coreml model can use it. 9 | + I use SSR-Net as the example. For more information, please see https://github.com/shamangary/SSR-Net 10 | 11 | ## How to run? 12 | ``` 13 | python keras2coreml.py 14 | ``` 15 | This command converts the Keras model and weight file into a coreml model. 16 | 17 | ### Multiple inputs Lambda layer bug? 18 | When I wrote this GitHub repository, coremltools did not support Lambda layers with multiple inputs. 19 | 20 | https://github.com/apple/coremltools/issues/188 21 | 22 | In the issue above, they fixed it by adding a few lines to their code. Please check it out. 23 | 24 | ## Custom swift class 25 | After converting from Keras to coreml, the model is not directly usable in an iOS app. 26 | You need to define your own class for the model's custom layer. 27 | 28 | The guide at http://machinethink.net/blog/coreml-custom-layers/ is very useful. 29 | However, we need a more complex custom Lambda layer than one with 1 input and 1 output. 30 | 31 | Since our Soft Stagewise Regression layer (SSR_module) has 9 inputs and 1 output, 32 | I wrote the Swift class "SSR_module.swift" to implement it.
33 | 34 | ``` 35 | # Target custom layer in Keras SSR-Net model 36 | 37 | def merge_age(x,s1,s2,s3,lambda_local,lambda_d): 38 | a = x[0][:,0]*0 39 | b = x[0][:,0]*0 40 | c = x[0][:,0]*0 41 | V = 101 42 | 43 | for i in range(0,s1): 44 | a = a+(i+lambda_local*x[6][:,i])*x[0][:,i] 45 | a = K.expand_dims(a,-1) 46 | a = a/(s1*(1+lambda_d*x[3])) 47 | 48 | for j in range(0,s2): 49 | b = b+(j+lambda_local*x[7][:,j])*x[1][:,j] 50 | b = K.expand_dims(b,-1) 51 | b = b/(s1*(1+lambda_d*x[3]))/(s2*(1+lambda_d*x[4])) 52 | 53 | for k in range(0,s3): 54 | c = c+(k+lambda_local*x[8][:,k])*x[2][:,k] 55 | c = K.expand_dims(c,-1) 56 | c = c/(s1*(1+lambda_d*x[3]))/(s2*(1+lambda_d*x[4]))/(s3*(1+lambda_d*x[5])) 57 | 58 | 59 | age = (a+b+c)*V 60 | return age 61 | 62 | pred_a = Lambda(merge_age,arguments={'s1':self.stage_num[0],'s2':self.stage_num[1],'s3':self.stage_num[2],'lambda_local':self.lambda_local,'lambda_d':self.lambda_d},output_shape=(1,),name='pred_a')([pred_a_s1,pred_a_s2,pred_a_s3,delta_s1,delta_s2,delta_s3, local_s1, local_s2, local_s3]) 63 | 64 | 65 | ## Swift class version evaluation function 66 | 67 | func evaluate(inputs: [MLMultiArray], outputs: [MLMultiArray]) throws { 68 | print(#function, inputs.count, outputs.count) 69 | var a: Double = 0 70 | for i in 0...2 { 71 | let index_a: [NSNumber] = [0,0,i as NSNumber,0,0] 72 | a = a + (Double(i)+inputs[6][index_a].doubleValue)*inputs[0][index_a].doubleValue 73 | } 74 | let index_0: [NSNumber] = [0,0,0,0,0] 75 | a = a/(3.0*(1.0+inputs[3][index_0].doubleValue)) 76 | 77 | var b: Double = 0 78 | for j in 0...2 { 79 | let index_b: [NSNumber] = [0,0,j as NSNumber,0,0] 80 | b = b + (Double(j)+inputs[7][index_b].doubleValue)*inputs[1][index_b].doubleValue 81 | } 82 | b = b/(3.0*(1.0+inputs[3][index_0].doubleValue))/(3.0*(1.0+inputs[4][index_0].doubleValue)) 83 | 84 | var c: Double = 0 85 | for k in 0...2 { 86 | let index_c: [NSNumber] = [0,0,k as NSNumber,0,0] 87 | c = c + 
(Double(k)+inputs[8][index_c].doubleValue)*inputs[2][index_c].doubleValue 88 | } 89 | c = c/(3.0*(1.0+inputs[3][index_0].doubleValue))/(3.0*(1.0+inputs[4][index_0].doubleValue))/(3.0*(1.0+inputs[5][index_0].doubleValue)) 90 | 91 | let age: Double = (a+b+c)*101.0 92 | print(age) // for xcode console debug 93 | outputs[0][index_0] = NSNumber(value: age) 94 | 95 | } 96 | ``` 97 | 98 | 99 | ## Reference 100 | + http://machinethink.net/blog/coreml-custom-layers/ 101 | + https://github.com/hollance/CoreML-Custom-Layers 102 | + https://github.com/apple/coremltools 103 | -------------------------------------------------------------------------------- /SSRNET_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import numpy as np 4 | from keras.models import Model 5 | from keras.layers import Input, Activation, add, Dense, Flatten, Dropout, Multiply, Embedding, Lambda, Add, Concatenate, Activation 6 | from keras.layers.convolutional import Conv2D, AveragePooling2D, MaxPooling2D 7 | from keras.layers.normalization import BatchNormalization 8 | from keras.regularizers import l2 9 | from keras import backend as K 10 | from keras.optimizers import SGD,Adam 11 | from keras.utils import plot_model 12 | from keras.engine.topology import Layer 13 | from keras import activations, initializers, regularizers, constraints 14 | 15 | sys.setrecursionlimit(2 ** 20) 16 | np.random.seed(2 ** 10) 17 | 18 | 19 | class SSR_net: 20 | def __init__(self, image_size,stage_num,lambda_local,lambda_d): 21 | 22 | if K.image_dim_ordering() == "th": 23 | logging.debug("image_dim_ordering = 'th'") 24 | self._channel_axis = 1 25 | self._input_shape = (3, image_size, image_size) 26 | else: 27 | logging.debug("image_dim_ordering = 'tf'") 28 | self._channel_axis = -1 29 | self._input_shape = (image_size, image_size, 3) 30 | 31 | 32 | self.stage_num = stage_num 33 | self.lambda_local = lambda_local 34 | self.lambda_d = lambda_d 35 |
36 | # def create_model(self): 37 | def __call__(self): 38 | logging.debug("Creating model...") 39 | 40 | 41 | inputs = Input(shape=self._input_shape) 42 | 43 | #------------------------------------------------------------------------------------------------------------------------- 44 | x = Conv2D(32,(3,3))(inputs) 45 | x = BatchNormalization(axis=self._channel_axis)(x) 46 | x = Activation('relu')(x) 47 | x_layer1 = AveragePooling2D(2,2)(x) 48 | x = Conv2D(32,(3,3))(x_layer1) 49 | x = BatchNormalization(axis=self._channel_axis)(x) 50 | x = Activation('relu')(x) 51 | x_layer2 = AveragePooling2D(2,2)(x) 52 | x = Conv2D(32,(3,3))(x_layer2) 53 | x = BatchNormalization(axis=self._channel_axis)(x) 54 | x = Activation('relu')(x) 55 | x_layer3 = AveragePooling2D(2,2)(x) 56 | x = Conv2D(32,(3,3))(x_layer3) 57 | x = BatchNormalization(axis=self._channel_axis)(x) 58 | x = Activation('relu')(x) 59 | #------------------------------------------------------------------------------------------------------------------------- 60 | s = Conv2D(16,(3,3))(inputs) 61 | s = BatchNormalization(axis=self._channel_axis)(s) 62 | s = Activation('tanh')(s) 63 | s_layer1 = MaxPooling2D(2,2)(s) 64 | s = Conv2D(16,(3,3))(s_layer1) 65 | s = BatchNormalization(axis=self._channel_axis)(s) 66 | s = Activation('tanh')(s) 67 | s_layer2 = MaxPooling2D(2,2)(s) 68 | s = Conv2D(16,(3,3))(s_layer2) 69 | s = BatchNormalization(axis=self._channel_axis)(s) 70 | s = Activation('tanh')(s) 71 | s_layer3 = MaxPooling2D(2,2)(s) 72 | s = Conv2D(16,(3,3))(s_layer3) 73 | s = BatchNormalization(axis=self._channel_axis)(s) 74 | s = Activation('tanh')(s) 75 | 76 | 77 | #------------------------------------------------------------------------------------------------------------------------- 78 | # Classifier block 79 | s_layer4 = Conv2D(10,(1,1),activation='relu')(s) 80 | s_layer4 = Flatten()(s_layer4) 81 | s_layer4_mix = Dropout(0.2)(s_layer4) 82 | s_layer4_mix = Dense(units=self.stage_num[0],
activation="relu")(s_layer4_mix) 83 | 84 | x_layer4 = Conv2D(10,(1,1),activation='relu')(x) 85 | x_layer4 = Flatten()(x_layer4) 86 | x_layer4_mix = Dropout(0.2)(x_layer4) 87 | x_layer4_mix = Dense(units=self.stage_num[0], activation="relu")(x_layer4_mix) 88 | 89 | feat_a_s1_pre = Multiply()([s_layer4,x_layer4]) 90 | delta_s1 = Dense(1,activation='tanh',name='delta_s1')(feat_a_s1_pre) 91 | 92 | feat_a_s1 = Multiply()([s_layer4_mix,x_layer4_mix]) 93 | feat_a_s1 = Dense(2*self.stage_num[0],activation='relu')(feat_a_s1) 94 | pred_a_s1 = Dense(units=self.stage_num[0], activation="relu",name='pred_age_stage1')(feat_a_s1) 95 | #feat_local_s1 = Lambda(lambda x: x/10)(feat_a_s1) 96 | #feat_a_s1_local = Dropout(0.2)(pred_a_s1) 97 | local_s1 = Dense(units=self.stage_num[0], activation='tanh', name='local_delta_stage1')(feat_a_s1) 98 | #------------------------------------------------------------------------------------------------------------------------- 99 | s_layer2 = Conv2D(10,(1,1),activation='relu')(s_layer2) 100 | s_layer2 = MaxPooling2D(4,4)(s_layer2) 101 | s_layer2 = Flatten()(s_layer2) 102 | s_layer2_mix = Dropout(0.2)(s_layer2) 103 | s_layer2_mix = Dense(self.stage_num[1],activation='relu')(s_layer2_mix) 104 | 105 | x_layer2 = Conv2D(10,(1,1),activation='relu')(x_layer2) 106 | x_layer2 = AveragePooling2D(4,4)(x_layer2) 107 | x_layer2 = Flatten()(x_layer2) 108 | x_layer2_mix = Dropout(0.2)(x_layer2) 109 | x_layer2_mix = Dense(self.stage_num[1],activation='relu')(x_layer2_mix) 110 | 111 | feat_a_s2_pre = Multiply()([s_layer2,x_layer2]) 112 | delta_s2 = Dense(1,activation='tanh',name='delta_s2')(feat_a_s2_pre) 113 | 114 | feat_a_s2 = Multiply()([s_layer2_mix,x_layer2_mix]) 115 | feat_a_s2 = Dense(2*self.stage_num[1],activation='relu')(feat_a_s2) 116 | pred_a_s2 = Dense(units=self.stage_num[1], activation="relu",name='pred_age_stage2')(feat_a_s2) 117 | #feat_local_s2 = Lambda(lambda x: x/10)(feat_a_s2) 118 | #feat_a_s2_local = Dropout(0.2)(pred_a_s2) 119 | local_s2 =
Dense(units=self.stage_num[1], activation='tanh', name='local_delta_stage2')(feat_a_s2) 120 | #------------------------------------------------------------------------------------------------------------------------- 121 | s_layer1 = Conv2D(10,(1,1),activation='relu')(s_layer1) 122 | s_layer1 = MaxPooling2D(8,8)(s_layer1) 123 | s_layer1 = Flatten()(s_layer1) 124 | s_layer1_mix = Dropout(0.2)(s_layer1) 125 | s_layer1_mix = Dense(self.stage_num[2],activation='relu')(s_layer1_mix) 126 | 127 | x_layer1 = Conv2D(10,(1,1),activation='relu')(x_layer1) 128 | x_layer1 = AveragePooling2D(8,8)(x_layer1) 129 | x_layer1 = Flatten()(x_layer1) 130 | x_layer1_mix = Dropout(0.2)(x_layer1) 131 | x_layer1_mix = Dense(self.stage_num[2],activation='relu')(x_layer1_mix) 132 | 133 | feat_a_s3_pre = Multiply()([s_layer1,x_layer1]) 134 | delta_s3 = Dense(1,activation='tanh',name='delta_s3')(feat_a_s3_pre) 135 | 136 | feat_a_s3 = Multiply()([s_layer1_mix,x_layer1_mix]) 137 | feat_a_s3 = Dense(2*self.stage_num[2],activation='relu')(feat_a_s3) 138 | pred_a_s3 = Dense(units=self.stage_num[2], activation="relu",name='pred_age_stage3')(feat_a_s3) 139 | #feat_local_s3 = Lambda(lambda x: x/10)(feat_a_s3) 140 | #feat_a_s3_local = Dropout(0.2)(pred_a_s3) 141 | local_s3 = Dense(units=self.stage_num[2], activation='tanh', name='local_delta_stage3')(feat_a_s3) 142 | #------------------------------------------------------------------------------------------------------------------------- 143 | 144 | def merge_age(x,s1,s2,s3,lambda_local,lambda_d): # x = [pred_a_s1, pred_a_s2, pred_a_s3, delta_s1, delta_s2, delta_s3, local_s1, local_s2, local_s3] in the order passed by the Lambda call below; returns (a+b+c)*V 145 | a = x[0][:,0]*0 146 | b = x[0][:,0]*0 147 | c = x[0][:,0]*0 148 | A = s1*s2*s3 # NOTE(review): A is never used 149 | V = 101 150 | 151 | for i in range(0,s1): 152 | a = a+(i+lambda_local*x[6][:,i])*x[0][:,i] 153 | a = K.expand_dims(a,-1) 154 | a = a/(s1*(1+lambda_d*x[3])) 155 | 156 | for j in range(0,s2): 157 | b = b+(j+lambda_local*x[7][:,j])*x[1][:,j] 158 | b = K.expand_dims(b,-1) 159 | b = b/(s1*(1+lambda_d*x[3]))/(s2*(1+lambda_d*x[4])) 160 | 161 | for k in range(0,s3): 162 | c =
c+(k+lambda_local*x[8][:,k])*x[2][:,k] 163 | c = K.expand_dims(c,-1) 164 | c = c/(s1*(1+lambda_d*x[3]))/(s2*(1+lambda_d*x[4]))/(s3*(1+lambda_d*x[5])) 165 | 166 | 167 | age = (a+b+c)*V 168 | return age 169 | 170 | pred_a = Lambda(merge_age,arguments={'s1':self.stage_num[0],'s2':self.stage_num[1],'s3':self.stage_num[2],'lambda_local':self.lambda_local,'lambda_d':self.lambda_d},output_shape=(1,),name='pred_a')([pred_a_s1,pred_a_s2,pred_a_s3,delta_s1,delta_s2,delta_s3, local_s1, local_s2, local_s3]) 171 | model = Model(inputs=inputs, outputs=pred_a) 172 | 173 | return model 174 | 175 | 176 | -------------------------------------------------------------------------------- /SSR_module.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import CoreML 3 | import Accelerate 4 | 5 | @objc(SSR_module) class SSR_module: NSObject, MLCustomLayer { // Core ML custom layer standing in for the Keras Lambda layer 'pred_a' (merge_age in SSRNET_model.py); the class name must match params.className in keras2coreml.py 6 | 7 | required init(parameters: [String : Any]) throws { // parameters carries s1, s2, s3, lambda_d, lambda_local from keras2coreml.py; they are only printed, not stored 8 | print(#function, parameters) 9 | super.init() 10 | } 11 | 12 | func outputShapes(forInputShapes inputShapes: [[NSNumber]]) throws -> [[NSNumber]] { 13 | let out_shape: [NSNumber] = [1,1,1,1,1] 14 | return [out_shape] 15 | // The output is declared 5-D to match the layout evaluate() indexes ([0,0,i,0,0]); only element [0,0,0,0,0], the age, is written. 16 | // inputShapes is ignored, so a single fixed-size prediction is always reported. NOTE(review): confirm against the converted model. 17 | // If the converter's layout changes, inspect the input shapes of the converted model and mirror that pattern here.
18 | } 19 | 20 | func setWeightData(_ weights: [Data]) throws { // convert_lambda attaches no weights, so this only prints what it receives 21 | print(#function, weights) 22 | } 23 | 24 | func evaluate(inputs: [MLMultiArray], outputs: [MLMultiArray]) throws { // Port of merge_age from SSRNET_model.py; inputs presumably follow the Lambda input order (0-2 pred_a_s*, 3-5 delta_s*, 6-8 local_s*) -- confirm in the converted model. NOTE(review): the stage count (3) and lambda_local = lambda_d = 1 are hard-coded; the values passed to init(parameters:) are ignored 25 | print(#function, inputs.count, outputs.count) 26 | var a: Double = 0 27 | for i in 0...2 { 28 | let index_a: [NSNumber] = [0,0,i as NSNumber,0,0] 29 | a = a + (Double(i)+inputs[6][index_a].doubleValue)*inputs[0][index_a].doubleValue 30 | } 31 | let index_0: [NSNumber] = [0,0,0,0,0] 32 | a = a/(3.0*(1.0+inputs[3][index_0].doubleValue)) 33 | 34 | var b: Double = 0 35 | for j in 0...2 { 36 | let index_b: [NSNumber] = [0,0,j as NSNumber,0,0] 37 | b = b + (Double(j)+inputs[7][index_b].doubleValue)*inputs[1][index_b].doubleValue 38 | } 39 | b = b/(3.0*(1.0+inputs[3][index_0].doubleValue))/(3.0*(1.0+inputs[4][index_0].doubleValue)) 40 | 41 | var c: Double = 0 42 | for k in 0...2 { 43 | let index_c: [NSNumber] = [0,0,k as NSNumber,0,0] 44 | c = c + (Double(k)+inputs[8][index_c].doubleValue)*inputs[2][index_c].doubleValue 45 | } 46 | c = c/(3.0*(1.0+inputs[3][index_0].doubleValue))/(3.0*(1.0+inputs[4][index_0].doubleValue))/(3.0*(1.0+inputs[5][index_0].doubleValue)) 47 | 48 | let age: Double = (a+b+c)*101.0 // 101.0 mirrors V = 101 in merge_age 49 | print(age) // for xcode console debug 50 | outputs[0][index_0] = NSNumber(value: age) 51 | 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /keras2coreml.py: -------------------------------------------------------------------------------- 1 | from SSRNET_model import SSR_net 2 | import coremltools 3 | from coremltools.proto import NeuralNetwork_pb2 4 | 5 | import sys, os 6 | import numpy as np 7 | 8 | import keras 9 | from keras.models import * 10 | from keras.layers import * 11 | from keras.preprocessing.image import load_img, img_to_array 12 | 13 | 14 | # The conversion function for Lambda layers. 15 | def convert_lambda(layer): 16 | # Only the Lambda layer named 'pred_a' (merge_age) is mapped to the SSR_module custom layer; any other layer gets None.
17 | 18 | if layer.name == 'pred_a': 19 | params = NeuralNetwork_pb2.CustomLayerParams() 20 | 21 | # The name of the Swift or Obj-C class that implements this layer. 22 | params.className = "SSR_module" 23 | 24 | # The description is shown in Xcode's mlmodel viewer. 25 | params.description = "Soft Stagewise Regression" 26 | params.parameters["s1"].doubleValue = layer.arguments['s1'] 27 | params.parameters["s2"].doubleValue = layer.arguments['s2'] 28 | params.parameters["s3"].doubleValue = layer.arguments['s3'] 29 | params.parameters["lambda_d"].doubleValue = layer.arguments['lambda_d'] 30 | params.parameters["lambda_local"].doubleValue = layer.arguments['lambda_local'] 31 | 32 | return params 33 | else: 34 | return None 35 | 36 | def main(): 37 | 38 | weight_file = "./ssrnet_3_3_3_64_1.0_1.0.h5" 39 | 40 | # load model and weights 41 | img_size = 64 42 | stage_num = [3,3,3] # NOTE: SSR_module.swift hard-codes 3 stages and lambda_local = lambda_d = 1; keep these values in sync with it 43 | lambda_local = 1 44 | lambda_d = 1 45 | model = SSR_net(img_size,stage_num, lambda_local, lambda_d)() 46 | model.load_weights(weight_file) 47 | model.save('ssrnet.h5') 48 | model.summary() 49 | 50 | 51 | coreml_model = coremltools.converters.keras.convert( 52 | model, 53 | input_names="image", 54 | image_input_names="image", 55 | output_names="output", 56 | add_custom_layers=True, 57 | custom_conversion_functions={ "Lambda": convert_lambda }) 58 | 59 | 60 | # Look at the layers in the converted Core ML model.
61 | print("\nLayers in the converted model:") 62 | for i, layer in enumerate(coreml_model._spec.neuralNetwork.layers): 63 | if layer.HasField("custom"): 64 | print("Layer %d = %s --> custom layer = %s" % (i, layer.name, layer.custom.className)) 65 | else: 66 | print("Layer %d = %s" % (i, layer.name)) 67 | coreml_model.save('ssrnet.mlmodel') 68 | 69 | 70 | if __name__ == '__main__': 71 | main() -------------------------------------------------------------------------------- /ssrnet.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shamangary/Keras-to-coreml-multiple-inputs-example/a1b1dd4aa6a408a11ab71f014cc31639907337e0/ssrnet.h5 -------------------------------------------------------------------------------- /ssrnet.mlmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shamangary/Keras-to-coreml-multiple-inputs-example/a1b1dd4aa6a408a11ab71f014cc31639907337e0/ssrnet.mlmodel -------------------------------------------------------------------------------- /ssrnet_3_3_3_64_1.0_1.0.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shamangary/Keras-to-coreml-multiple-inputs-example/a1b1dd4aa6a408a11ab71f014cc31639907337e0/ssrnet_3_3_3_64_1.0_1.0.h5 --------------------------------------------------------------------------------