└── gene_expression.py

/gene_expression.py:
--------------------------------------------------------------------------------
# coding=utf-8
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Models for gene expression from DNA."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import contrib
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow.compat.v1 as tf


@registry.register_model
class GeneExpressionConv(t2t_model.T2TModel):
  """Gene expression conv net.

  Based on the "Basenji" model from
  http://www.biorxiv.org/content/early/2017/07/10/161851

  Uses layer_norm instead of batch_norm.

  The model expects that if targets are of length m, inputs are of length
  32*m. The original data expected inputs of length 128*m, but the data has
  been preprocessed to chunk every 4 bases into 1 ID (see
  data_generators/gene_expression.py).

  The magnitude of the length reduction is controlled by the pooling sizes
  (hparams.pooling_windows) at each conv layer (hparams.num_conv_layers).
  """

  def body(self, features):
    inputs = features["inputs"]
    inputs.get_shape().assert_has_rank(4)

    hp = self._hparams

    out = inputs
    # Flatten to 3D by joining the two middle axes.
    out = common_layers.flatten4d3d(out)

    # Conv layers, each pooling along the length dimension.
    assert hp.num_conv_layers == len(hp.pooling_windows)
    for i in range(hp.num_conv_layers):
      out = conv_layer(
          out,
          hp.hidden_size,
          hp.kernel_width,
          hp.stride,
          hp.pooling_windows[i],
          hp.dropout,
          dilation_rate=1,
          name="conv_%d" % (i + 1))

    # Dense dilated conv layers: each layer's output is concatenated onto the
    # running feature map, with exponentially increasing dilation rates.
    for i in range(hp.num_dconv_layers):
      dilation_rate = 2**(i + 1)
      dconv_out = conv_layer(
          out,
          hp.hidden_size,
          hp.kernel_width,
          stride=1,
          pooling_window=0,
          dropout_rate=hp.dropout,
          dilation_rate=dilation_rate,
          name="dconv_%d" % (i + 1))
      out = tf.concat([out, dconv_out], axis=2)

    # Fully connected layer
    out = fc_layer(out, hp.hidden_size, hp.dropout, name="fc")

    out.get_shape().assert_has_rank(3)
    # Restore rank 4 with a size-1 axis at position 2.
    out = tf.expand_dims(out, 2)
    return out


def conv_layer(x,
               hidden_size,
               kernel_size,
               stride,
               pooling_window,
               dropout_rate,
               dilation_rate,
               name="conv"):
  """Single conv layer with relu, optional pooling, and dropout."""
  with tf.variable_scope(name):
    out = x
    out = common_layers.conv1d_block(
        out,
        hidden_size, [(dilation_rate, kernel_size)],
        strides=stride,
        first_relu=False,
        padding="same")
    out = tf.nn.relu(out)
    if pooling_window:
      out = tf.layers.max_pooling1d(
          out, pooling_window, pooling_window, padding="same")
    out = tf.layers.dropout(out, dropout_rate)
    return out


def fc_layer(x, num_out, dropout_rate, name="fc"):
  """Fully connected layer with layer norm, relu, and dropout."""
  with tf.variable_scope(name):
    out = x
    out = tf.layers.dense(out, num_out)
    out = contrib.layers().layer_norm(out)
    out = tf.nn.relu(out)
    out = tf.layers.dropout(out, dropout_rate)
    return out


@registry.register_hparams
def gene_expression_conv_base():
  """Hparams for GeneExpressionConv model."""
  hparams = common_hparams.basic_params1()

  batch_size = 10
  output_length = 2048
  inputs_per_output = 128
  chunk_size = 4
  input_length = output_length * inputs_per_output // chunk_size
  hparams.batch_size = input_length * batch_size

  hparams.dropout = 0.1
  hparams.add_hparam("num_conv_layers", 4)
  hparams.add_hparam("num_dconv_layers", 7)
  # The product of these pooling windows should match
  # input_length / target_length.
  hparams.add_hparam("pooling_windows", [2, 2, 2, 4])

  hparams.hidden_size = 256
  hparams.kernel_width = 20
  hparams.add_hparam("stride", 1)
  return hparams
--------------------------------------------------------------------------------