# Repository layout (reconstructed from a flattened dump):
#   __init__.py
#   README.md
#   tf_cnn_basic.py
#   octConv_resnet.py
#   oct_Resnet_unit.py
#   tf_octConv.py
#
# =============================================================================
# /__init__.py
# =============================================================================
# (empty file)
#
# =============================================================================
# /README.md
# =============================================================================
# # tensorflow_octConv
#
# ### Paper: "Drop an Octave: Reducing Spatial Redundancy in Convolutional
# ### Neural Networks with Octave Convolution"
#
# Implementation of [OctaveConv](https://arxiv.org/abs/1904.05049) in TensorFlow.
#
# NOTE: Results are coming soon.
#
# Code adapted from [terrychenism](https://github.com/terrychenism/OctaveConv).
# Thanks!
#
# =============================================================================
# /tf_cnn_basic.py
# =============================================================================
import tensorflow as tf


def BN(data, bn_momentum=0.9, name=None, training=False):
    """Apply batch normalization; the layer is named ``<name>__bn``.

    NOTE(review): ``tf.layers.batch_normalization`` only updates its moving
    statistics when ``training=True``. The default ``False`` preserves the
    original behavior of this helper; callers that train should pass True.
    """
    return tf.layers.batch_normalization(data, momentum=bn_momentum, training=training,
                                         name=('%s__bn' % name))


def AC(data, name=None):
    """Apply ReLU; the op is named ``<name>__relu``."""
    return tf.nn.relu(data, name=('%s__relu' % name))


def BN_AC(data, momentum=0.9, name=None, training=False):
    """Batch normalization followed by ReLU. ``momentum`` is forwarded to BN."""
    bn = BN(data=data, bn_momentum=momentum, name=name, training=training)
    bn_ac = AC(data=bn, name=name)
    return bn_ac


def Conv(data, num_filter, kernel, stride=(1, 1), pad='valid', name=None, no_bias=True, w=None, b=None, attr=None,
         num_group=1):
    """2-D convolution; the layer is named ``<name>__conv``.

    Args:
        data: input tensor.
        num_filter: number of output channels.
        kernel: kernel size, e.g. ``(3, 3)``.
        stride: convolution strides.
        pad: ``'valid'`` or ``'same'``.
        name: name prefix.
        no_bias: when True the convolution has no bias term. Overridden when
            both ``w`` and ``b`` are given (a bias initializer implies a bias).
        w: optional kernel initializer.
        b: optional bias initializer; only used together with ``w``.
        attr: unused; kept for signature compatibility with the MXNet original.
        num_group: only 1 is supported.

    Raises:
        NotImplementedError: if ``num_group != 1``.
    """
    if num_group != 1:
        # tf.layers.conv2d has no grouped-convolution option; fail loudly
        # instead of silently building an ungrouped convolution.
        raise NotImplementedError('grouped convolution (num_group=%d) is not supported' % num_group)
    kwargs = dict(inputs=data, filters=num_filter, kernel_size=kernel, strides=stride, padding=pad,
                  name=('%s__conv' % name), use_bias=not no_bias)
    if w is not None:
        kwargs['kernel_initializer'] = w
        if b is not None:
            kwargs['use_bias'] = True
            kwargs['bias_initializer'] = b
    return tf.layers.conv2d(**kwargs)


# - - - - - - - - - - - - - - - - - - - - - - -
# Standard Common functions < CVPR >
def Conv_BN(data, num_filter, kernel, pad, stride=(1, 1), name=None, w=None, b=None, no_bias=True, attr=None,
            num_group=1):
    """Convolution followed by batch normalization."""
    cov = Conv(data=data, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=name,
               w=w, b=b, no_bias=no_bias, attr=attr)
    # BN appends '__bn' itself, so the layer ends up as '<name>__bn__bn'; kept
    # as-is so existing variable names/checkpoints stay valid.
    cov_bn = BN(data=cov, name=('%s__bn' % name))
    return cov_bn


def Conv_BN_AC(data, num_filter, kernel, pad, stride=(1, 1), name=None, w=None, b=None, no_bias=True, attr=None,
               num_group=1):
    """Convolution, batch normalization, then ReLU."""
    cov_bn = Conv_BN(data=data, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride,
                     name=name, w=w, b=b, no_bias=no_bias, attr=attr)
    cov_ba = AC(data=cov_bn, name=('%s__ac' % name))
    return cov_ba


# - - - - - - - - - - - - - - - - - - - - - - -
# Standard Common functions < ECCV >
def BN_Conv(data, num_filter, kernel, pad, stride=(1, 1), name=None, w=None, b=None, no_bias=True, attr=None,
            num_group=1):
    """Batch normalization followed by convolution (pre-activation order)."""
    bn = BN(data=data, name=('%s__bn' % name))
    bn_cov = Conv(data=bn, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=name,
                  w=w, b=b, no_bias=no_bias, attr=attr)
    return bn_cov


def AC_Conv(data, num_filter, kernel, pad, stride=(1, 1), name=None, w=None, b=None, no_bias=True, attr=None,
            num_group=1):
    """ReLU followed by convolution."""
    ac = AC(data=data, name=('%s__ac' % name))
    ac_cov = Conv(data=ac, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=name,
                  w=w, b=b, no_bias=no_bias, attr=attr)
    return ac_cov


def BN_AC_Conv(data, num_filter, kernel, pad, stride=(1, 1), name=None, w=None, b=None, no_bias=True, attr=None,
               num_group=1):
    """Batch normalization, ReLU, then convolution."""
    bn = BN(data=data, name=('%s__bn' % name))
    ba_cov = AC_Conv(data=bn, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride,
                     name=name, w=w, b=b, no_bias=no_bias, attr=attr)
    return ba_cov


def Pooling(data, pool_type='avg', kernel=(2, 2), pad='valid', stride=(2, 2), name=None):
    """2-D average or max pooling.

    Raises:
        ValueError: if ``pool_type`` is neither ``'avg'`` nor ``'max'``.
    """
    if pool_type == 'avg':
        return tf.layers.average_pooling2d(inputs=data, pool_size=kernel, strides=stride, padding=pad, name=name)
    if pool_type == 'max':
        return tf.layers.max_pooling2d(inputs=data, pool_size=kernel, strides=stride, padding=pad, name=name)
    raise ValueError("pool_type must be 'avg' or 'max', got %r" % (pool_type,))


def ElementWiseSum(x, y, name=None):
    """Element-wise sum of two tensors."""
    return tf.add(x=x, y=y, name=name)


def UpSampling(lf_conv, scale=2, sample_type='nearest', num_args=1, name=None):
    """Upsample spatially by ``scale`` using Keras ``UpSampling2D``.

    ``sample_type`` and ``num_args`` are accepted for signature compatibility
    with the MXNet original but are not used.
    """
    return tf.keras.layers.UpSampling2D(size=(scale, scale), name=name)(lf_conv)


# =============================================================================
# /octConv_resnet.py
# =============================================================================
import numpy as np
import tensorflow as tf
from tf_octConv import *
from tf_cnn_basic import *
from oct_Resnet_unit import *


G = 1
alpha = 0.25
use_fp16 = True
k_sec = {2: 3, 3: 4, 4: 6, 5: 3}

# Spatial size per stage: 224 -> 112 (conv1) -> 56 (pool1) -> 28 (conv3)
# -> 14 (conv4) -> 7 (conv5), which matches the 7x7 pool in get_linear.
DEFAULT_INPUT_SHAPE = (None, 224, 224, 3)


def _octave_stage(hf_data, lf_data, num_blocks, num_in, num_mid, num_out, prefix, first_stride, start=1):
    """Stack octave residual units ``start..num_blocks``; unit 1 projects and uses ``first_stride``."""
    for i in range(start, num_blocks + 1):
        hf_data, lf_data = Residual_Unit(
            hf_data=hf_data,
            lf_data=lf_data,
            alpha=alpha,
            num_in=(num_in if i == 1 else num_out),
            num_mid=num_mid,
            num_out=num_out,
            name=('%s_B%02d' % (prefix, i)),
            first_block=(i == 1),
            stride=(first_stride if i == 1 else (1, 1)))
    return hf_data, lf_data


def get_before_pool(data=None):
    """Build the Octave-ResNet trunk (conv1..conv5) and return its last feature map.

    Args:
        data: optional input tensor. When omitted, a float32 placeholder named
            ``"data"`` with shape ``DEFAULT_INPUT_SHAPE`` is created.

    Returns:
        The conv5 output, cast back to float32 when ``use_fp16`` is set.
    """
    if data is None:
        # tf.Variable(name="data") without an initial value raises; a graph
        # input in TF1 is declared with a placeholder.
        data = tf.placeholder(tf.float32, shape=DEFAULT_INPUT_SHAPE, name="data")
    data = tf.cast(x=data, dtype=np.float16) if use_fp16 else data

    # conv1
    conv1 = Conv_BN_AC(data=data, num_filter=64, kernel=(7, 7), name='conv1', pad='same', stride=(2, 2))
    pool1 = Pooling(data=conv1, pool_type="max", kernel=(3, 3), pad='same', stride=(2, 2), name="pool1")

    # conv2: the first unit splits the single input into high/low-frequency branches.
    num_in = 32
    num_mid = 64
    num_out = 256
    hf, lf = Residual_Unit_first(
        data=pool1,
        alpha=alpha,
        num_in=num_in,
        num_mid=num_mid,
        num_out=num_out,
        name='conv2_B01',
        first_block=True,
        stride=(1, 1))
    hf, lf = _octave_stage(hf, lf, k_sec[2], num_in, num_mid, num_out, 'conv2', (1, 1), start=2)

    # conv3, conv4: double the width; the first unit of each stage uses stride 2.
    for stage in (3, 4):
        num_in = num_out
        num_mid = int(num_mid * 2)
        num_out = int(num_out * 2)
        hf, lf = _octave_stage(hf, lf, k_sec[stage], num_in, num_mid, num_out, 'conv%d' % stage, (2, 2))

    # conv5: the first unit merges the two branches back into one tensor,
    # the remaining units are ordinary bottleneck units.
    num_in = num_out
    num_mid = int(num_mid * 2)
    num_out = int(num_out * 2)
    conv5_x = Residual_Unit_last(
        hf_data=hf,
        lf_data=lf,
        alpha=alpha,
        num_in=num_in,
        num_mid=num_mid,
        num_out=num_out,
        name='conv5_B01',
        first_block=True,
        stride=(2, 2))

    for i in range(2, k_sec[5] + 1):
        conv5_x = Residual_Unit_norm(data=conv5_x,
                                     num_in=num_out,
                                     num_mid=num_mid,
                                     num_out=num_out,
                                     name=('conv5_B%02d' % i),
                                     first_block=False,
                                     stride=(1, 1))

    output = tf.cast(x=conv5_x, dtype=np.float32) if use_fp16 else conv5_x
    return output


def get_linear(num_classes=10):
    """Trunk + 7x7 average pool + dense classifier; returns the logits."""
    before_pool = get_before_pool()
    pool5 = Pooling(data=before_pool, pool_type="avg", kernel=(7, 7), stride=(1, 1), name="global-pool")
    flat5 = tf.layers.flatten(inputs=pool5, name='flatten')
    fc6 = tf.layers.dense(inputs=flat5, units=num_classes, name='classifier')
    return fc6


def get_symbol(num_classes=10):
    """Full network returning softmax class probabilities."""
    fc6 = get_linear(num_classes)
    softmax = tf.nn.softmax(logits=fc6, name='softmax')
    sys_out = softmax
    return sys_out


# =============================================================================
# /oct_Resnet_unit.py
# =============================================================================
import tensorflow as tf
from tf_cnn_basic import *
from tf_octConv import *


def Residual_Unit_norm(data, num_in, num_mid, num_out, name, first_block=False, stride=(1, 1), g=1):
    """Plain (single-branch) bottleneck residual unit: 1x1 -> 3x3 -> 1x1 plus shortcut.

    When ``first_block`` is True the shortcut is a strided 1x1 projection,
    otherwise it is the identity.
    """
    conv_m1 = Conv_BN_AC(data=data, num_filter=num_mid, kernel=(1, 1), pad='valid', name=('%s_conv-m1' % name))
    conv_m2 = Conv_BN_AC(data=conv_m1, num_filter=num_mid, kernel=(3, 3), pad='same', name=('%s_conv-m2' % name),
                         stride=stride, num_group=g)
    conv_m3 = Conv_BN(data=conv_m2, num_filter=num_out, kernel=(1, 1), pad='valid', name=('%s_conv-m3' % name))

    if first_block:
        data = Conv_BN(data=data, num_filter=num_out, kernel=(1, 1), pad='valid', name=('%s_conv-w1' % name),
                       stride=stride)

    outputs = ElementWiseSum(data, conv_m3, name=('%s_sum' % name))
    # Named like the other units (previously it came out as 'None__relu').
    return AC(outputs, name=('%s_act' % name))


def Residual_Unit_last(hf_data, lf_data, alpha, num_in, num_mid, num_out, name,
                       first_block=False, stride=(1, 1), g=1):
    """Octave bottleneck unit that merges the hf/lf branches into one output tensor.

    Raises:
        ValueError: if ``first_block`` is False; the two-branch input has no
            identity shortcut, so the projection shortcut is required.
    """
    if not first_block:
        raise ValueError('Residual_Unit_last requires first_block=True: the hf/lf inputs '
                         'must be projected into a single shortcut')
    hf_data_m, lf_data_m = octConv_BN_AC(hf_data=hf_data, lf_data=lf_data, alpha=alpha, num_filter_in=num_in,
                                         num_filter_out=num_mid, kernel=(1, 1), pad='valid',
                                         name=('%s_conv-m1' % name))
    conv_m2 = lastOctConv_BN_AC(hf_data=hf_data_m, lf_data=lf_data_m, alpha=alpha, num_filter_in=num_mid,
                                num_filter_out=num_mid, name=('%s_conv-m2' % name), kernel=(3, 3), pad='same',
                                stride=stride)
    conv_m3 = Conv_BN(data=conv_m2, num_filter=num_out, kernel=(1, 1), pad='valid', name=('%s_conv-m3' % name))

    data = lastOctConv_BN(hf_data=hf_data, lf_data=lf_data, alpha=alpha, num_filter_in=num_in,
                          num_filter_out=num_out, name=('%s_conv-w1' % name), kernel=(1, 1), pad='valid',
                          stride=stride)

    outputs = ElementWiseSum(data, conv_m3, name=('%s_sum' % name))
    outputs = AC(outputs, name=('%s_act' % name))
    return outputs


def Residual_Unit_first(data, alpha, num_in, num_mid, num_out, name, first_block=False, stride=(1, 1), g=1):
    """Octave bottleneck unit that splits a single input into (hf, lf) outputs.

    Raises:
        ValueError: if ``first_block`` is False; a single-branch input cannot
            provide an identity shortcut for the low-frequency branch.
    """
    if not first_block:
        raise ValueError('Residual_Unit_first requires first_block=True: the single input '
                         'must be projected into hf/lf shortcuts')
    hf_data_m, lf_data_m = firstOctConv_BN_AC(data=data, alpha=alpha, num_filter_in=num_in, num_filter_out=num_mid,
                                              kernel=(1, 1), pad='valid', name=('%s_conv-m1' % name))
    hf_data_m, lf_data_m = octConv_BN_AC(hf_data=hf_data_m, lf_data=lf_data_m, alpha=alpha, num_filter_in=num_mid,
                                         num_filter_out=num_mid, kernel=(3, 3), pad='same',
                                         name=('%s_conv-m2' % name), stride=stride, num_group=g)
    hf_data_m, lf_data_m = octConv_BN(hf_data=hf_data_m, lf_data=lf_data_m, alpha=alpha, num_filter_in=num_mid,
                                      num_filter_out=num_out, kernel=(1, 1), pad='valid', name=('%s_conv-m3' % name))

    hf_data, lf_data = firstOctConv_BN(data=data, alpha=alpha, num_filter_in=num_in, num_filter_out=num_out,
                                       kernel=(1, 1), pad='valid', name=('%s_conv-w1' % name), stride=stride)

    hf_outputs = ElementWiseSum(hf_data, hf_data_m, name=('%s_hf_sum' % name))
    lf_outputs = ElementWiseSum(lf_data, lf_data_m, name=('%s_lf_sum' % name))

    hf_outputs = AC(hf_outputs, name=('%s_hf_act' % name))
    lf_outputs = AC(lf_outputs, name=('%s_lf_act' % name))
    return hf_outputs, lf_outputs


def Residual_Unit(hf_data, lf_data, alpha, num_in, num_mid, num_out, name, first_block=False, stride=(1, 1), g=1):
    """Octave bottleneck unit with (hf, lf) inputs and outputs.

    When ``first_block`` is True both shortcuts are 1x1 octave projections
    (with ``stride``), otherwise they are identities.
    """
    hf_data_m, lf_data_m = octConv_BN_AC(hf_data=hf_data, lf_data=lf_data, alpha=alpha, num_filter_in=num_in,
                                         num_filter_out=num_mid, kernel=(1, 1), pad='valid',
                                         name=('%s_conv-m1' % name))
    hf_data_m, lf_data_m = octConv_BN_AC(hf_data=hf_data_m, lf_data=lf_data_m, alpha=alpha, num_filter_in=num_mid,
                                         num_filter_out=num_mid, kernel=(3, 3), pad='same',
                                         name=('%s_conv-m2' % name), stride=stride, num_group=g)
    hf_data_m, lf_data_m = octConv_BN(hf_data=hf_data_m, lf_data=lf_data_m, alpha=alpha, num_filter_in=num_mid,
                                      num_filter_out=num_out, kernel=(1, 1), pad='valid', name=('%s_conv-m3' % name))

    if first_block:
        hf_data, lf_data = octConv_BN(hf_data=hf_data, lf_data=lf_data, alpha=alpha, num_filter_in=num_in,
                                      num_filter_out=num_out, kernel=(1, 1), pad='valid', name=('%s_conv-w1' % name),
                                      stride=stride)

    hf_outputs = ElementWiseSum(hf_data, hf_data_m, name=('%s_hf_sum' % name))
    lf_outputs = ElementWiseSum(lf_data, lf_data_m, name=('%s_lf_sum' % name))

    hf_outputs = AC(hf_outputs, name=('%s_hf_act' % name))
    lf_outputs = AC(lf_outputs, name=('%s_lf_act' % name))
    return hf_outputs, lf_outputs


# =============================================================================
# /tf_octConv.py
# =============================================================================
import tensorflow as tf
from tf_cnn_basic import *


def firstOctConv(data, settings, ch_in, ch_out, name, kernel=(1, 1), pad='valid', stride=(1, 1)):
    """Split one input into high-frequency and half-resolution low-frequency outputs.

    ``settings`` is ``(alpha_in, alpha_out)``; a fraction ``alpha_out`` of
    ``ch_out`` channels goes to the low-frequency output.
    """
    _, alpha_out = settings
    hf_ch_out = int(ch_out * (1 - alpha_out))
    lf_ch_out = ch_out - hf_ch_out

    hf_data = data
    if stride == (2, 2):
        hf_data = Pooling(data=hf_data, pool_type='avg', kernel=(2, 2), stride=(2, 2), name=('%s_hf_down' % name))
    hf_conv = Conv(data=hf_data, num_filter=hf_ch_out, kernel=kernel, pad=pad, stride=(1, 1),
                   name=('%s_hf_conv' % name))
    # The low-frequency branch works on a 2x average-pooled copy of the input.
    hf_pool = Pooling(data=hf_data, pool_type='avg', kernel=(2, 2), stride=(2, 2), name=('%s_hf_pool' % name))
    hf_pool_conv = Conv(data=hf_pool, num_filter=lf_ch_out, kernel=kernel, pad=pad, stride=(1, 1),
                        name=('%s_hf_pool_conv' % name))

    out_h = hf_conv
    out_l = hf_pool_conv
    return out_h, out_l


def lastOctConv(hf_data, lf_data, settings, ch_in, ch_out, name, kernel=(1, 1), pad='valid', stride=(1, 1)):
    """Merge high- and low-frequency inputs into a single output tensor.

    With stride (2, 2) the high-frequency input is downsampled to match the
    low-frequency one; otherwise the low-frequency conv is upsampled 2x so
    both terms of the sum have the same spatial size.
    """
    _, alpha_out = settings
    hf_ch_out = int(ch_out * (1 - alpha_out))

    if stride == (2, 2):
        hf_data = Pooling(data=hf_data, pool_type='avg', kernel=(2, 2), stride=(2, 2), name=('%s_hf_down' % name))
    hf_conv = Conv(data=hf_data, num_filter=hf_ch_out, kernel=kernel, pad=pad, stride=(1, 1),
                   name=('%s_hf_conv' % name))

    lf_conv = Conv(data=lf_data, num_filter=hf_ch_out, kernel=kernel, pad=pad, stride=(1, 1),
                   name=('%s_lf_conv' % name))
    if stride != (2, 2):
        # The low-frequency branch is at half the high-frequency resolution
        # (see firstOctConv); bring it up before the sum, as OctConv does.
        lf_conv = UpSampling(lf_conv, scale=2, sample_type='nearest', num_args=1, name='%s_lf_upsample' % name)
    out_h = hf_conv + lf_conv

    return out_h


def OctConv(hf_data, lf_data, settings, ch_in, ch_out, name, kernel=(1, 1), pad='valid', stride=(1, 1)):
    """Octave convolution: exchange information between the hf and lf branches.

    Returns ``(out_h, out_l)`` where
    ``out_h = conv(hf) + upsample(conv(lf))`` and
    ``out_l = conv(pool(hf)) + conv(lf)``; with stride (2, 2) both branches
    are downsampled first and the upsample is skipped.
    """
    _, alpha_out = settings
    hf_ch_out = int(ch_out * (1 - alpha_out))
    lf_ch_out = ch_out - hf_ch_out

    if stride == (2, 2):
        hf_data = Pooling(data=hf_data, pool_type='avg', kernel=(2, 2), stride=(2, 2), name=('%s_hf_down' % name))
    hf_conv = Conv(data=hf_data, num_filter=hf_ch_out, kernel=kernel, pad=pad, stride=(1, 1),
                   name=('%s_hf_conv' % name))
    hf_pool = Pooling(data=hf_data, pool_type='avg', kernel=(2, 2), stride=(2, 2), name=('%s_hf_pool' % name))
    hf_pool_conv = Conv(data=hf_pool, num_filter=lf_ch_out, kernel=kernel, pad=pad, stride=(1, 1),
                        name=('%s_hf_pool_conv' % name))

    lf_conv = Conv(data=lf_data, num_filter=hf_ch_out, kernel=kernel, pad=pad, stride=(1, 1),
                   name=('%s_lf_conv' % name))
    if stride == (2, 2):
        lf_upsample = lf_conv
        lf_down = Pooling(data=lf_data, pool_type='avg', kernel=(2, 2), stride=(2, 2), name=('%s_lf_down' % name))
    else:
        lf_upsample = UpSampling(lf_conv, scale=2, sample_type='nearest', num_args=1, name='%s_lf_upsample' % name)
        lf_down = lf_data
    lf_down_conv = Conv(data=lf_down, num_filter=lf_ch_out, kernel=kernel, pad=pad, stride=(1, 1),
                        name=('%s_lf_down_conv' % name))

    out_h = hf_conv + lf_upsample
    out_l = hf_pool_conv + lf_down_conv

    return out_h, out_l


def firstOctConv_BN_AC(data, alpha, num_filter_in, num_filter_out, kernel, pad, stride=(1, 1), name=None, w=None,
                       b=None, no_bias=True, attr=None, num_group=1):
    """firstOctConv followed by BN+ReLU on each output branch.

    ``w``, ``b``, ``no_bias``, ``attr`` and ``num_group`` are not forwarded.
    """
    hf_data, lf_data = firstOctConv(data=data, settings=(0, alpha), ch_in=num_filter_in, ch_out=num_filter_out,
                                    name=name, kernel=kernel, pad=pad, stride=stride)
    out_hf = BN_AC(data=hf_data, name=('%s_hf') % name)
    out_lf = BN_AC(data=lf_data, name=('%s_lf') % name)
    return out_hf, out_lf


def lastOctConv_BN_AC(hf_data, lf_data, alpha, num_filter_in, num_filter_out, kernel, pad, stride=(1, 1), name=None,
                      w=None, b=None, no_bias=True, attr=None, num_group=1):
    """lastOctConv followed by BN+ReLU.

    ``w``, ``b``, ``no_bias``, ``attr`` and ``num_group`` are not forwarded.
    """
    conv = lastOctConv(hf_data=hf_data, lf_data=lf_data, settings=(alpha, 0), ch_in=num_filter_in,
                       ch_out=num_filter_out, name=name, kernel=kernel, pad=pad, stride=stride)
    out = BN_AC(data=conv, name=name)
    return out


def octConv_BN_AC(hf_data, lf_data, alpha, num_filter_in, num_filter_out, kernel, pad, stride=(1, 1), name=None,
                  w=None, b=None, no_bias=True, attr=None, num_group=1):
    """OctConv followed by BN+ReLU on each output branch.

    ``w``, ``b``, ``no_bias``, ``attr`` and ``num_group`` are not forwarded.
    """
    hf_data, lf_data = OctConv(hf_data=hf_data, lf_data=lf_data, settings=(alpha, alpha), ch_in=num_filter_in,
                               ch_out=num_filter_out, name=name, kernel=kernel, pad=pad, stride=stride)
    out_hf = BN_AC(data=hf_data, name=('%s_hf') % name)
    out_lf = BN_AC(data=lf_data, name=('%s_lf') % name)
    return out_hf, out_lf


def firstOctConv_BN(data, alpha, num_filter_in, num_filter_out, kernel, pad, stride=(1, 1), name=None, w=None,
                    b=None, no_bias=True, attr=None, num_group=1):
    """firstOctConv followed by BN (no activation) on each output branch."""
    hf_data, lf_data = firstOctConv(data=data, settings=(0, alpha), ch_in=num_filter_in, ch_out=num_filter_out,
                                    name=name, kernel=kernel, pad=pad, stride=stride)
    out_hf = BN(data=hf_data, name=('%s_hf') % name)
    out_lf = BN(data=lf_data, name=('%s_lf') % name)
    return out_hf, out_lf


def lastOctConv_BN(hf_data, lf_data, alpha, num_filter_in, num_filter_out, kernel, pad, stride=(1, 1), name=None,
                   w=None, b=None, no_bias=True, attr=None, num_group=1):
    """lastOctConv followed by BN (no activation)."""
    conv = lastOctConv(hf_data=hf_data, lf_data=lf_data, settings=(alpha, 0), ch_in=num_filter_in,
                       ch_out=num_filter_out, name=name, kernel=kernel, pad=pad, stride=stride)
    out = BN(data=conv, name=name)
    return out


def octConv_BN(hf_data, lf_data, alpha, num_filter_in, num_filter_out, kernel, pad, stride=(1, 1), name=None,
               w=None, b=None, no_bias=True, attr=None, num_group=1):
    """OctConv followed by BN (no activation) on each output branch."""
    hf_data, lf_data = OctConv(hf_data=hf_data, lf_data=lf_data, settings=(alpha, alpha), ch_in=num_filter_in,
                               ch_out=num_filter_out, name=name, kernel=kernel, pad=pad, stride=stride)
    out_hf = BN(data=hf_data, name=('%s_hf') % name)
    out_lf = BN(data=lf_data, name=('%s_lf') % name)
    return out_hf, out_lf