"""TFSCL: Time–Frequency Self-Contrastive Learning.

Model structure source code for the article published in IEEE Transactions on
Industrial Informatics: "Cross-Domain Compound Fault Diagnosis of Machine-Level
Motors via Time–Frequency Self-Contrastive Learning".

Repository layout:

    TFSCL_model.png   model diagram (written by ``plot_model`` in ``TFSCL``);
                      https://raw.githubusercontent.com/HymHust/TFSCL/HEAD/TFSCL_model.png
    README.md         project description (summarized above)
    TFSCL.py          this module
"""
import keras_metrics as km
import matplotlib.pyplot as plt
from keras.optimizers import adam_v2
import tensorflow as tf
from keras import backend as K
from keras import layers as KL
import numpy as np
import pandas as pd
from tensorflow.keras.callbacks import ModelCheckpoint
from keras.layers import (
    Conv1D, Dense, Dropout, add, Lambda, GlobalAveragePooling1D,
    BatchNormalization, AveragePooling1D, Activation, Flatten, Input,
    concatenate, LayerNormalization,
)
from tensorflow.keras.utils import plot_model
from keras.models import Model
from keras.regularizers import l2

# TF1-compat session configured to use at most 95% of GPU memory.
# NOTE(review): under TF2 eager execution Keras does not use this session unless
# it is registered (e.g. tf.compat.v1.keras.backend.set_session) — confirm the
# memory cap actually takes effect.
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.95
sess = tf.compat.v1.Session(config=config)

# Number of sigmoid output units in the classifier head built by TFSCL().
num_classes = 10


def FFTLayer(x, axis=-1):
    """Return the FFT magnitude spectrum of ``x`` along ``axis``.

    Args:
        x: real-valued tensor of known rank.
        axis: axis to transform. The default ``-1`` (innermost axis) is what
            ``tf.signal.fft`` transforms natively and is the original behavior.

    Returns:
        ``|FFT(x)|`` as a float32 tensor with the same shape as ``x``.

    NOTE(review): for ``Input((5120, 10))`` the innermost axis is the channel
    axis (Conv1D treats axis 1 as steps), so the default transforms across the
    channels rather than along time. Use ``axis=1`` if a spectrum over time is
    intended — confirm against the paper.
    """
    x_comp = tf.cast(x, tf.complex64)
    rank = len(x_comp.shape)
    axis = axis % rank
    if axis == rank - 1:
        return tf.abs(tf.signal.fft(x_comp))
    # tf.signal.fft only transforms the innermost axis: swap `axis` with the last
    # one, transform, and swap back (a two-axis swap is its own inverse).
    perm = list(range(rank))
    perm[axis], perm[-1] = perm[-1], perm[axis]
    x_fft = tf.signal.fft(tf.transpose(x_comp, perm))
    return tf.abs(tf.transpose(x_fft, perm))


def Scos(f1, f2):
    """Return ``1 - mean cosine similarity`` between ``f1`` and ``f2``.

    Both tensors are L2-normalized along axis 1, the cosine is taken along that
    same axis, and the result is averaged over all remaining positions. The
    returned scalar lies in [0, 2]; it is 0 when the vectors are aligned.

    NOTE(review): for the (batch, steps, channels) Conv1D feature maps used in
    ``TFSCL``, axis 1 is the steps axis, so similarity is measured per channel
    across time — confirm this is the intended axis.
    """
    f1 = tf.math.l2_normalize(f1, axis=1)
    f2 = tf.math.l2_normalize(f2, axis=1)
    cos = tf.reduce_mean(tf.reduce_sum(f1 * f2, axis=1))
    return 1 - cos


class CL_Loss(KL.Layer):
    """Layer that registers the self-contrastive loss between two feature maps.

    Called on ``[f1, f2]``. It adds ``Scos(f1, f2)`` to the model losses, logs it
    as the ``CL_loss`` metric, and returns the scalar loss tensor.
    """

    def __init__(self, **kwargs):
        super(CL_Loss, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        f1, f2 = inputs
        # Scos already reduces to a scalar; K.mean keeps the original formulation.
        loss = K.mean(Scos(f1, f2))
        # The old `inputs=True` keyword is deprecated and ignored by Layer.add_loss.
        self.add_loss(loss)
        self.add_metric(loss, aggregation="mean", name="CL_loss")
        return loss


def cnn_model(filters, kernerl_size, strides, conv_padding, dil_rate, inputs):
    """Apply two stacked GELU Conv1D layers to ``inputs``.

    The first conv is fixed: 16 filters, kernel 3, stride 1, 'same' padding.
    The second is configured by ``filters``, ``kernerl_size``, ``strides``,
    ``conv_padding`` and ``dil_rate`` (dilation rate). Both layers use L2(1e-4)
    kernel regularization.
    """
    x = Conv1D(filters=16, kernel_size=3, strides=1, padding='same',
               kernel_regularizer=l2(1e-4), activation=tf.nn.gelu)(inputs)
    x = Conv1D(filters=filters, kernel_size=kernerl_size, strides=strides,
               padding=conv_padding, dilation_rate=dil_rate,
               kernel_regularizer=l2(1e-4), activation=tf.nn.gelu)(x)
    return x


def time_brach(inputs, BatchNormal=True):
    """Build the time-domain branch on the raw ``inputs``.

    Three ``cnn_model`` towers with dilation rates 1, 2 and 3 run in parallel.
    Their outputs are concatenated and fed through ReLU Conv1D layers with 32 and
    then 64 filters. Each of those two layers is followed by BatchNormalization
    when ``BatchNormal`` is true. All convs use stride 1 and 'same' padding.
    """
    modelA = cnn_model(filters=16, kernerl_size=3, strides=1, conv_padding='same', dil_rate=1, inputs=inputs)
    modelB = cnn_model(filters=16, kernerl_size=3, strides=1, conv_padding='same', dil_rate=2, inputs=inputs)
    modelC = cnn_model(filters=16, kernerl_size=3, strides=1, conv_padding='same', dil_rate=3, inputs=inputs)
    combined = concatenate([modelA, modelB, modelC])
    x = Conv1D(filters=32, kernel_size=3, strides=1, padding='same',
               kernel_regularizer=l2(1e-4), activation='relu')(combined)
    if BatchNormal:
        x = BatchNormalization()(x)
    x = Conv1D(filters=64, kernel_size=3, strides=1, padding='same',
               kernel_regularizer=l2(1e-4), activation='relu')(x)
    if BatchNormal:
        x = BatchNormalization()(x)
    return x


def freq_brach(inputs):
    """Build the frequency branch.

    The ``FFTLayer`` magnitude of ``inputs`` is followed by four GELU Conv1D
    layers with 16, 32, 32 and 64 filters (kernel 3, stride 1, 'same' padding,
    L2(1e-4)).
    """
    x = Lambda(FFTLayer)(inputs)
    for n_filters in (16, 32, 32, 64):
        x = Conv1D(filters=n_filters, kernel_size=3, strides=1, padding='same',
                   kernel_regularizer=l2(1e-4), activation=tf.nn.gelu)(x)
    return x


def TFSCL(x_shape=(5120, 10)):
    """Build, compile and return the TFSCL classifier.

    Args:
        x_shape: per-sample input shape passed to ``Input``, read by Conv1D as
            (steps, channels).

    Returns:
        A ``Model`` with ``num_classes`` independent sigmoid outputs, compiled
        with Adam(lr=1e-3), binary cross-entropy and the ``acc`` metric. The
        contrastive loss from ``CL_Loss`` is added on top.

    Side effects: prints the model summary and writes ``TFSCL_model.png``.
    """
    inputs_x = Input(x_shape, name='x_train')

    fea1 = freq_brach(inputs_x)
    fea2 = time_brach(inputs_x)

    # 1x1 projections of the raw input (res_t) and of its FFT magnitude (res_f)
    # to 64 channels, used as residual shortcuts.
    res_t = Conv1D(filters=64, kernel_size=1, strides=1, padding='same',
                   kernel_regularizer=l2(1e-4), activation=tf.nn.gelu)(inputs_x)
    res_f = Lambda(FFTLayer)(inputs_x)
    res_f = Conv1D(filters=64, kernel_size=1, strides=1, padding='same',
                   kernel_regularizer=l2(1e-4), activation=tf.nn.gelu)(res_f)

    # NOTE(review): the time residual is added to the frequency branch (fea1)
    # and the frequency residual to the time branch (fea2). Confirm this
    # cross-wiring is intentional rather than swapped.
    fea1 = add([fea1, res_t])
    fea2 = add([fea2, res_f])
    loss = CL_Loss()([fea1, fea2])
    # Adding the scalar loss presumably keeps CL_Loss on the input->output path,
    # so the functional Model tracks its add_loss/add_metric. Note that it also
    # shifts every fused feature by the loss value.
    con = concatenate([fea1, fea2]) + loss
    con = GlobalAveragePooling1D()(con)

    y_pred = Dense(units=num_classes, activation='sigmoid',
                   kernel_regularizer=l2(1e-4), name='output')(con)

    model = Model(inputs=inputs_x, outputs=y_pred)
    adam = adam_v2.Adam(learning_rate=0.001)
    model.compile(optimizer=adam, loss='binary_crossentropy', metrics=['acc'])

    model.summary()
    plot_model(model=model, to_file='TFSCL_model.png', show_shapes=True)
    return model


def Train_Eval(x_train=None, y_train=None, x_test=None, y_test=None,
               x_valid=None, y_valid=None, batch_size=64, epochs=20,
               checkpoint_path='propose.hdf5'):
    """Train TFSCL, reload the best checkpoint and evaluate it on the test set.

    Load the data with your own reader and pass it in. ``x_*`` are arrays of
    shape (n_samples, *x_shape) and ``y_*`` are the matching label arrays.
    ``x_valid``/``y_valid`` are optional; when both are given they are passed to
    ``fit`` as validation data.

    The checkpoint keeps the weights with the lowest training ``loss``.

    Returns:
        Dict of test metrics from ``model.evaluate`` (``loss``, ``acc``, ``CL_loss``).

    Raises:
        ValueError: if any of the train/test arrays is missing.
    """
    if any(a is None for a in (x_train, y_train, x_test, y_test)):
        raise ValueError("Train_Eval needs x_train, y_train, x_test and y_test; "
                         "load them with your own reader and pass them in.")

    model = TFSCL(x_shape=x_train.shape[1:])  # TFSCL already prints the summary.
    callback_list = [ModelCheckpoint(filepath=checkpoint_path, verbose=1,
                                     save_best_only=True, save_weights_only=True,
                                     monitor="loss")]
    validation_data = None
    if x_valid is not None and y_valid is not None:
        validation_data = (x_valid, y_valid)
    # Labels go in as targets, not as a second model input.
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
              shuffle=True, verbose=1, validation_data=validation_data,
              callbacks=callback_list)
    model.load_weights(checkpoint_path)

    # return_dict: evaluate() also reports the CL_loss metric added by CL_Loss,
    # so unpacking a fixed two-element list would fail.
    results = model.evaluate(x_test, y_test, return_dict=True)
    print("test loss:", results["loss"])
    print("test accuracy", results["acc"])
    return results


if __name__ == "__main__":
    # Lowercase name: assigning to `Model` would shadow the keras Model class.
    model = TFSCL()