├── LICENSE ├── README.md ├── ctc.py ├── edit_distance.py ├── ext_param_info.py ├── main_timit.py ├── main_toy_dataset.py ├── timit.py └── toy_dataset.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | CTC-LSTM 2 | ========================================= 3 | 4 | This repository contains an implementation of the CTC cost function (Graves et al., 2006). 5 | 6 | - CTC cost is implemented in pure [Theano](https://github.com/Theano/Theano). 7 | 8 | - Supports mini-batches. 9 | 10 | To avoid numerical underflow, two solutions are implemented: 11 | 12 | - Normalization of the alphas at each timestep 13 | - Calculations in the logarithmic domain 14 | 15 | This repository also contains sample code for applying CTC to two datasets: a simple toy dataset of artificial data, and the TIMIT dataset. The models are implemented using [Blocks](https://github.com/mila-udem/blocks). Both datasets are implemented using [Fuel](https://github.com/mila-udem/fuel). 16 | 17 | The TIMIT model reaches up to 50% phoneme accuracy with no hand-crafted processing of the signal, using instead an end-to-end model composed of convolutions, LSTMs, and the CTC cost function. 18 | 19 | 20 | Reference 21 | ========= 22 | Graves, Alex, et al. *Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks.* Proceedings of the 23rd international conference on Machine learning. ACM, 2006.
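Example
=======
A minimal sketch of how the CTC brick is wired into a cost (the `y_hat` and `y_hat_mask` placeholders below stand in for the softmax output of the network and its mask; `main_toy_dataset.py` and `main_timit.py` contain the full training scripts):

```python
from theano import tensor
from ctc import CTC

# Placeholders for this sketch; in practice y_hat is the T x B x C+1 softmax
# output of the network (class C being the blank) and y_hat_mask its T x B mask.
y_hat      = tensor.tensor3('y_hat')
y_hat_mask = tensor.matrix('y_hat_mask')

y      = tensor.lmatrix('output').T      # L x B  target label sequences
y_mask = tensor.matrix('output_mask').T  # L x B
y_len  = y_mask.sum(axis=0)              # B      target sequence lengths

# Negative log-likelihood of the labellings, averaged over the batch
cost = CTC().apply_log_domain(y, y_hat, y_len, y_hat_mask).mean()
```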
23 | 24 | 25 | Credits 26 | ======= 27 | [Alex Auvolat](https://github.com/Alexis211) 28 | 29 | [Thomas Mesnard](https://github.com/thomasmesnard) 30 | 31 | 32 | Special thanks to 33 | ================= 34 | [Mohammad Pezeshki](https://github.com/mohammadpz/) 35 | -------------------------------------------------------------------------------- /ctc.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | import theano 4 | from theano import tensor, scan 5 | 6 | from blocks.bricks import Brick 7 | 8 | # T: INPUT_SEQUENCE_LENGTH 9 | # B: BATCH_SIZE 10 | # L: OUTPUT_SEQUENCE_LENGTH 11 | # C: NUM_CLASSES 12 | class CTC(Brick): 13 | def apply(self, l, probs, l_len=None, probs_mask=None): 14 | """ 15 | Numeration: 16 | Characters 0 to C-1 are true characters 17 | Character C is the blank character 18 | Inputs: 19 | l : L x B : the sequence labelling 20 | probs : T x B x C+1 : the probabilities output by the RNN 21 | l_len : B : the length of each labelling sequence 22 | probs_mask : T x B 23 | Output: the B probabilities of the labelling sequences 24 | Steps: 25 | - Calculate y' the labelling sequence with blanks 26 | - Calculate the recurrence relationship for the alphas 27 | - Calculate the sequence of the alphas 28 | - Return the probability found at the end of that sequence 29 | """ 30 | T = probs.shape[0] 31 | B = probs.shape[1] 32 | C = probs.shape[2]-1 33 | L = l.shape[0] 34 | S = 2*L+1 35 | 36 | # l_blk = l with interleaved blanks 37 | l_blk = C * tensor.ones((S, B), dtype='int32') 38 | l_blk = tensor.set_subtensor(l_blk[1::2,:], l) 39 | l_blk = l_blk.T # now l_blk is B x S 40 | 41 | # dimension of alpha (corresponds to alpha hat in the paper) : 42 | # T x B x S 43 | # dimension of c : 44 | # T x B 45 | # first value of alpha (size B x S) 46 | alpha0 = tensor.concatenate([ tensor.ones((B, 1)), 47 | tensor.zeros((B, S-1)) 48 | ], axis=1) 49 | c0 = tensor.ones((B,)) 50 | 51 | # recursion 52 | l_blk_2 = tensor.concatenate([-tensor.ones((B,2)), l_blk[:,:-2]], axis=1) 53 | l_case2 = tensor.neq(l_blk, C) * tensor.neq(l_blk, l_blk_2) 54 | # l_case2 is B x S 55 | 56 | def recursion(p, p_mask, prev_alpha, prev_c): 57 | # p is B x C+1 58 | # prev_alpha is B x S 59 | prev_alpha_1 = tensor.concatenate([tensor.zeros((B,1)),prev_alpha[:,:-1]], axis=1) 60 | prev_alpha_2 = tensor.concatenate([tensor.zeros((B,2)),prev_alpha[:,:-2]], axis=1) 61 | 62 | alpha_bar = prev_alpha + prev_alpha_1 63 | alpha_bar = tensor.switch(l_case2, alpha_bar + prev_alpha_2, alpha_bar) 64 | next_alpha = alpha_bar * p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S)) 65 | next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha) 66 | next_alpha = next_alpha * tensor.lt(tensor.arange(S)[None,:], (2*l_len+1)[:, None]) 67 | next_c = next_alpha.sum(axis=1) 68 | 69 | return next_alpha / next_c[:, None], next_c 70 | 71 | # apply the recursion with scan 72 | [alpha, c], _ = scan(fn=recursion, 73 | sequences=[probs, probs_mask], 74 | outputs_info=[alpha0, c0]) 75 | 76 | # c = theano.printing.Print('c')(c) 77 | last_alpha = alpha[-1] 78 | # last_alpha = theano.printing.Print('a-1')(last_alpha) 79 | 80 | prob = tensor.log(c).sum(axis=0) + tensor.log(last_alpha[tensor.arange(B), 2*l_len.astype('int32')-1] 81 | + last_alpha[tensor.arange(B), 2*l_len.astype('int32')] 82 | + 1e-30) 83 | 84 | # return the log probability of the labellings 85 | return -prob 86 | 87 | def apply_log_domain(self, l, probs, l_len=None, probs_mask=None): 88 | # Does the 
same computation as apply, but alpha is in the log domain 89 | # This avoids numerical underflow issues that were not corrected in the previous version. 90 | 91 | def _log(a): 92 | return tensor.log(tensor.clip(a, 1e-12, 1e12)) 93 | 94 | def _log_add(a, b): 95 | maximum = tensor.maximum(a, b) 96 | return (maximum + tensor.log1p(tensor.exp(a + b - 2 * maximum))) 97 | 98 | def _log_mul(a, b): 99 | return a + b 100 | 101 | # See comments above 102 | B = probs.shape[1] 103 | C = probs.shape[2]-1 104 | L = l.shape[0] 105 | S = 2*L+1 106 | 107 | l_blk = C * tensor.ones((S, B), dtype='int32') 108 | l_blk = tensor.set_subtensor(l_blk[1::2,:], l) 109 | l_blk = l_blk.T # now l_blk is B x S 110 | 111 | alpha0 = tensor.concatenate([ tensor.ones((B, 1)), 112 | tensor.zeros((B, S-1)) 113 | ], axis=1) 114 | alpha0 = _log(alpha0) 115 | 116 | l_blk_2 = tensor.concatenate([-tensor.ones((B,2)), l_blk[:,:-2]], axis=1) 117 | l_case2 = tensor.neq(l_blk, C) * tensor.neq(l_blk, l_blk_2) 118 | 119 | def recursion(p, p_mask, prev_alpha): 120 | prev_alpha_1 = tensor.concatenate([tensor.zeros((B,1)),prev_alpha[:,:-1]], axis=1) 121 | prev_alpha_2 = tensor.concatenate([tensor.zeros((B,2)),prev_alpha[:,:-2]], axis=1) 122 | 123 | alpha_bar1 = tensor.set_subtensor(prev_alpha[:,1:], _log_add(prev_alpha[:,1:],prev_alpha[:,:-1])) 124 | alpha_bar2 = tensor.set_subtensor(alpha_bar1[:,2:], _log_add(alpha_bar1[:,2:],prev_alpha[:,:-2])) 125 | 126 | alpha_bar = tensor.switch(l_case2, alpha_bar2, alpha_bar1) 127 | 128 | probs = _log(p[tensor.arange(B)[:,None].repeat(S,axis=1).flatten(), l_blk.flatten()].reshape((B,S))) 129 | next_alpha = _log_mul(alpha_bar, probs) 130 | next_alpha = tensor.switch(p_mask[:,None], next_alpha, prev_alpha) 131 | 132 | return next_alpha 133 | 134 | alpha, _ = scan(fn=recursion, 135 | sequences=[probs, probs_mask], 136 | outputs_info=[alpha0]) 137 | 138 | last_alpha = alpha[-1] 139 | # last_alpha = theano.printing.Print('a-1')(last_alpha) 140 | 141 | prob = _log_add(last_alpha[tensor.arange(B), 2*l_len.astype('int32')-1], 142 | last_alpha[tensor.arange(B), 2*l_len.astype('int32')]) 143 | 144 | # return the negative log probability of the labellings 145 | return -prob 146 | 147 | 148 | def best_path_decoding(self, probs, probs_mask=None): 149 | # probs is T x B x C+1 150 | T = probs.shape[0] 151 | B = probs.shape[1] 152 | C = probs.shape[2]-1 153 | 154 | maxprob = probs.argmax(axis=2) 155 | is_double = tensor.eq(maxprob[:-1], maxprob[1:]) 156 | maxprob = tensor.switch(tensor.concatenate([tensor.zeros((1,B)), is_double]), 157 | C*tensor.ones_like(maxprob), maxprob) 158 | # maxprob = theano.printing.Print('maxprob')(maxprob.T).T 159 | 160 | # returns two values : 161 | # label : (T x) T x B 162 | # label_length : (T x) B 163 | def recursion(maxp, p_mask, label_length, label): 164 | nonzero = p_mask * tensor.neq(maxp, C) 165 | nonzero_id = nonzero.nonzero()[0] 166 | 167 | new_label = tensor.set_subtensor(label[label_length[nonzero_id], nonzero_id], maxp[nonzero_id]) 168 | new_label_length = tensor.switch(nonzero, label_length + numpy.int32(1), label_length) 169 | 170 | return new_label_length, new_label 171 | 172 | [label_length, label], _ = scan(fn=recursion, 173 | sequences=[maxprob, probs_mask], 174 | outputs_info=[tensor.zeros((B,),dtype='int32'),-tensor.ones((T,B))]) 175 | 176 | return label[-1], label_length[-1] 177 | 178 | def prefix_search(self, probs, probs_mask=None): 179 | # Hard one... 
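        # (Note, not implemented.) Prefix search decoding, as described in
        # Graves et al. (2006), would maintain for each candidate labelling
        # prefix two probabilities -- paths ending in a blank and paths ending
        # in a non-blank -- and repeatedly extend the most probable prefix,
        # stopping once no extension can beat the best complete labelling
        # found so far.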
180 | pass 181 | 182 | 183 | 184 | # vim: set sts=4 ts=4 sw=4 sw=4 tw=0 et: 185 | -------------------------------------------------------------------------------- /edit_distance.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import theano 3 | from theano import tensor 4 | 5 | @theano.compile.ops.as_op(itypes=[tensor.imatrix, tensor.ivector, tensor.imatrix, tensor.ivector], 6 | otypes=[tensor.ivector]) 7 | def batch_edit_distance(a, a_len, b, b_len): 8 | B = a.shape[0] 9 | assert b.shape[0] == B 10 | 11 | for i in range(B): 12 | print "A:", a[i, :a_len[i]] 13 | print "B:", b[i, :b_len[i]] 14 | 15 | q = max(a.shape[1], b.shape[1]) * numpy.ones((B, a.shape[1]+1, b.shape[1]+1), dtype='int32') 16 | q[:, 0, 0] = 0 17 | 18 | for i in range(a.shape[1]+1): 19 | for j in range(b.shape[1]+1): 20 | if i > 0: 21 | q[:, i, j] = numpy.minimum(q[:, i, j], q[:, i-1, j]+1) 22 | if j > 0: 23 | q[:, i, j] = numpy.minimum(q[:, i, j], q[:, i, j-1]+1) 24 | if i > 0 and j > 0: 25 | q[:, i, j] = numpy.minimum(q[:, i, j], q[:, i-1, j-1]+numpy.not_equal(a[:, i-1], b[:, j-1])) 26 | return q[numpy.arange(B), a_len, b_len] 27 | 28 | # vim: set sts=4 ts=4 sw=4 tw=0 et : 29 | -------------------------------------------------------------------------------- /ext_param_info.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy 4 | 5 | import cPickle 6 | 7 | from blocks.extensions import SimpleExtension 8 | 9 | logging.basicConfig(level='INFO') 10 | logger = logging.getLogger('extensions.ParamInfo') 11 | 12 | class ParamInfo(SimpleExtension): 13 | def __init__(self, model, **kwargs): 14 | super(ParamInfo, self).__init__(**kwargs) 15 | 16 | self.model = model 17 | 18 | def do(self, which_callback, *args): 19 | print("---- PARAMETER INFO ----") 20 | print("\tmin\tmax\tmean\tvar\tdim\t\tname") 21 | for k, v in self.model.get_parameter_values().iteritems(): 22 | print("\t%.4f\t%.4f\t%.4f\t%.4f\t%13s\t%s"% 23 | (v.min(), v.max(), v.mean(), ((v-v.mean())**2).mean(), 'x'.join([repr(x) for x in v.shape]), k)) 24 | 25 | -------------------------------------------------------------------------------- /main_timit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import theano 4 | import numpy 5 | from theano import tensor 6 | 7 | from blocks.bricks import Linear, Tanh, Rectifier 8 | from blocks.bricks.conv import Convolutional, MaxPooling 9 | from blocks.bricks.lookup import LookupTable 10 | from blocks.bricks.recurrent import SimpleRecurrent, LSTM 11 | from blocks.initialization import IsotropicGaussian, Constant 12 | 13 | from blocks.algorithms import (GradientDescent, Scale, AdaDelta, RemoveNotFinite, RMSProp, BasicMomentum, 14 | StepClipping, CompositeRule, Momentum) 15 | from blocks.graph import ComputationGraph 16 | from blocks.model import Model 17 | from blocks.main_loop import MainLoop 18 | 19 | from blocks.filter import VariableFilter 20 | from blocks.roles import WEIGHT, BIAS 21 | from blocks.graph import ComputationGraph, apply_dropout, apply_noise 22 | 23 | from blocks.extensions import ProgressBar 24 | from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring 25 | from blocks.extensions import FinishAfter, Printing 26 | 27 | from blocks.extras.extensions.plot import Plot 28 | 29 | 30 | from ctc import CTC 31 | from timit import setup_datastream 32 | 33 | from edit_distance import batch_edit_distance 
34 | from ext_param_info import ParamInfo 35 | 36 | 37 | # ========================================================================================== 38 | # THE HYPERPARAMETERS 39 | # ========================================================================================== 40 | 41 | # Stop after this many epochs 42 | n_epochs = 10000 43 | # How often (number of batches) to print / plot 44 | monitor_freq = 50 45 | 46 | sort_batch_count = 50 47 | batch_size = 100 48 | 49 | # The convolutionnal layers. Parameters: 50 | # nfilter the number of filters 51 | # filter_size the size of the filters (number of timesteps) 52 | # stride the stride on which to apply the filter (non-1 stride are not optimized, runs very slowly with current Theano) 53 | # pool_stride the block size for max pooling 54 | # normalize do we normalize the values before applying the activation function? 55 | # activation a brick for the activation function 56 | # dropout dropout applied after activation function 57 | # skip do we introduce skip connections from previous layer(s) to next layer(s) ? 58 | convs = [ 59 | {'nfilter': 20, 60 | 'filter_size': 200, 61 | 'stride': 1, 62 | 'pool_stride': 10, 63 | 'normalize': True, 64 | 'activation': Rectifier(name='a0'), 65 | 'dropout': 0.0, 66 | 'skip': ['min', 'max', 'subsample']}, 67 | {'nfilter': 20, 68 | 'filter_size': 200, 69 | 'stride': 1, 70 | 'pool_stride': 10, 71 | 'normalize': True, 72 | 'activation': Rectifier(name='a1'), 73 | 'dropout': 0.0, 74 | 'skip': ['max']}, 75 | {'nfilter': 20, 76 | 'filter_size': 30, 77 | 'stride': 1, 78 | 'pool_stride': 2, 79 | 'normalize': True, 80 | 'activation': Rectifier(name='a2'), 81 | 'dropout': 0.0, 82 | 'skip': ['max']}, 83 | {'nfilter': 100, 84 | 'filter_size': 20, 85 | 'stride': 1, 86 | 'pool_stride': 2, 87 | 'normalize': True, 88 | 'activation': Rectifier(name='a3'), 89 | 'dropout': 0.0, 90 | 'skip': []}, 91 | ] 92 | 93 | # recurrent layers. Parameters: 94 | # type type of the layer (simple, lstm, blstm) 95 | # dim size of the state 96 | # normalize do we normalize the values after the RNN ? 97 | # dropout dropout after the RNN 98 | # skip do we introduce skip connections from previous layer(s) to next layer(s) ? 99 | recs = [ 100 | {'type': 'blstm', 101 | 'dim': 50, 102 | 'normalize': False, 103 | 'dropout': 0.0, 104 | 'skip': True}, 105 | {'type': 'blstm', 106 | 'dim': 50, 107 | 'normalize': False, 108 | 'dropout': 0.0, 109 | 'skip': True}, 110 | ] 111 | 112 | # do we normalize the activations just before the softmax layer ? 113 | normalize_out = True 114 | 115 | # regularization : noise on the weights 116 | weight_noise = 0.01 117 | 118 | # regularization : L2 penalization 119 | l2_output_bias = 0. 120 | l2_output_weight = 0. 121 | l2_all_bias = 0.0 122 | l2_all_weight = 0. 
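# Back-of-the-envelope note (sketch, assuming the raw 16 kHz TIMIT waveform as input):
# each conv layer above shortens the time axis to roughly
#   (T - filter_size + 1) // (stride * pool_stride)
# so the stack downsamples the signal by about 10 * 10 * 2 * 2 = 400, i.e. the
# recurrent layers and the CTC cost see roughly one frame per 25 ms of audio.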
123 | 124 | # number of phonemes in timit, a constant 125 | num_output_classes = 61 126 | 127 | 128 | # the step rule (uncomment your favorite choice) 129 | step_rule = CompositeRule([AdaDelta(), RemoveNotFinite()]) 130 | #step_rule = CompositeRule([Momentum(learning_rate=0.00001, momentum=0.99), RemoveNotFinite()]) 131 | #step_rule = CompositeRule([Momentum(learning_rate=0.001, momentum=0.9), RemoveNotFinite()]) 132 | #step_rule = CompositeRule([AdaDelta(), Scale(0.01), RemoveNotFinite()]) 133 | #step_rule = CompositeRule([RMSProp(learning_rate=0.1, decay_rate=0.95), 134 | # RemoveNotFinite()]) 135 | #step_rule = CompositeRule([RMSProp(learning_rate=0.0001, decay_rate=0.95), 136 | # BasicMomentum(momentum=0.9), 137 | # RemoveNotFinite()]) 138 | 139 | # How the weights are initialized 140 | weights_init = IsotropicGaussian(0.01) 141 | biases_init = Constant(0.001) 142 | 143 | 144 | # ========================================================================================== 145 | # THE MODEL 146 | # ========================================================================================== 147 | 148 | print('Building model ...') 149 | 150 | 151 | # THEANO INPUT VARIABLES 152 | inputt = tensor.matrix('input') 153 | input_mask = tensor.matrix('input_mask') 154 | y = tensor.lmatrix('output').T 155 | y_mask = tensor.matrix('output_mask').T 156 | y_len = y_mask.sum(axis=0) 157 | L = y.shape[0] 158 | B = y.shape[1] 159 | # inputt : B x T 160 | # input_mask : B x T 161 | # y : L x B 162 | # y_mask : L x B 163 | 164 | # NORMALIZE THE INPUTS 165 | inputt = inputt / (inputt**2).mean() 166 | 167 | dropout_locs = [] 168 | 169 | # CONVOLUTION LAYERS 170 | conv_in = inputt[:, None, :, None] 171 | conv_in_channels = 1 172 | conv_in_mask = input_mask 173 | 174 | cb = [] 175 | for i, p in enumerate(convs): 176 | # Convolution bricks 177 | conv = Convolutional(filter_size=(p['filter_size'],1), 178 | # step=(p['stride'],1), 179 | num_filters=p['nfilter'], 180 | num_channels=conv_in_channels, 181 | batch_size=batch_size, 182 | border_mode='valid', 183 | tied_biases=True, 184 | name='conv%d'%i) 185 | cb.append(conv) 186 | maxpool = MaxPooling(pooling_size=(p['pool_stride'], 1), name='mp%d'%i) 187 | 188 | conv_out = conv.apply(conv_in)[:, :, ::p['stride'], :] 189 | conv_out = maxpool.apply(conv_out) 190 | if p['normalize']: 191 | conv_out_mean = conv_out.mean(axis=2).mean(axis=0) 192 | conv_out_var = ((conv_out - conv_out_mean[None, :, None, :])**2).mean(axis=2).mean(axis=0).sqrt() 193 | conv_out = (conv_out - conv_out_mean[None, :, None, :]) / conv_out_var[None, :, None, :] 194 | if p['activation'] is not None: 195 | conv_out = p['activation'].apply(conv_out) 196 | if p['dropout'] > 0: 197 | b = [p['activation'] if p['activation'] is not None else conv] 198 | dropout_locs.append((VariableFilter(bricks=b, name='output'), p['dropout'])) 199 | if p['skip'] is not None and len(p['skip'])>0: 200 | maxpooladd = MaxPooling(pooling_size=(p['stride']*p['pool_stride'], 1), name='Mp%d'%i) 201 | skip = [] 202 | if 'max' in p['skip']: 203 | skip.append(maxpooladd.apply(conv_in)[:, :, :conv_out.shape[2], :]) 204 | if 'min' in p['skip']: 205 | skip.append(maxpooladd.apply(-conv_in)[:, :, :conv_out.shape[2], :]) 206 | if 'subsample' in p['skip']: 207 | skip.append(conv_in[:, :, ::(p['stride']*p['pool_stride']), :][:, :, :conv_out.shape[2], :]) 208 | conv_out = tensor.concatenate([conv_out] + skip, axis=1) 209 | conv_out_channels = p['nfilter'] + len(p['skip']) * conv_in_channels 210 | else: 211 | conv_out_channels = 
p['nfilter'] 212 | conv_out_mask = conv_in_mask[:, ::(p['stride']*p['pool_stride'])][:, :conv_out.shape[2]] 213 | 214 | conv_in = conv_out 215 | conv_in_channels = conv_out_channels 216 | conv_in_mask = conv_out_mask 217 | 218 | # RECURRENT LAYERS 219 | rec_mask = conv_out_mask.dimshuffle(1, 0) 220 | rec_in = conv_out[:, :, :, 0].dimshuffle(2, 0, 1) 221 | rec_in_dim = conv_out_channels 222 | 223 | rb = [] 224 | for i, p in enumerate(recs): 225 | # RNN bricks 226 | if p['type'] == 'lstm': 227 | pre_rec = Linear(input_dim=rec_in_dim, output_dim=4*p['dim'], name='rnn_linear%d'%i) 228 | rec = LSTM(activation=Tanh(), dim=p['dim'], name="rnn%d"%i) 229 | rb = rb + [pre_rec, rec] 230 | 231 | rnn_in = pre_rec.apply(rec_in) 232 | 233 | rec_out, _ = rec.apply(inputs=rnn_in, mask=rec_mask) 234 | dropout_b = [rec] 235 | rec_out_dim = p['dim'] 236 | elif p['type'] == 'simple': 237 | pre_rec = Linear(input_dim=rec_in_dim, output_dim=p['dim'], name='rnn_linear%d'%i) 238 | rec = SimpleRecurrent(activation=Tanh(), dim=p['dim'], name="rnn%d"%i) 239 | rb = rb + [pre_rec, rec] 240 | 241 | rnn_in = pre_rec.apply(rec_in) 242 | 243 | rec_out = rec.apply(inputs=rnn_in, mask=rec_mask) 244 | dropout_b = [rec] 245 | rec_out_dim = p['dim'] 246 | elif p['type'] == 'blstm': 247 | pre_frec = Linear(input_dim=rec_in_dim, output_dim=4*p['dim'], name='frnn_linear%d'%i) 248 | pre_brec = Linear(input_dim=rec_in_dim, output_dim=4*p['dim'], name='brnn_linear%d'%i) 249 | frec = LSTM(activation=Tanh(), dim=p['dim'], name="frnn%d"%i) 250 | brec = LSTM(activation=Tanh(), dim=p['dim'], name="brnn%d"%i) 251 | rb = rb + [pre_frec, pre_brec, frec, brec] 252 | 253 | frnn_in = pre_frec.apply(rec_in) 254 | frnn_out, _ = frec.apply(inputs=frnn_in, mask=rec_mask) 255 | brnn_in = pre_brec.apply(rec_in) 256 | brnn_out, _ = brec.apply(inputs=brnn_in, mask=rec_mask) 257 | 258 | rec_out = tensor.concatenate([frnn_out, brnn_out], axis=2) 259 | dropout_b = [frec, brec] 260 | rec_out_dim = 2*p['dim'] 261 | else: 262 | assert False 263 | 264 | if p['normalize']: 265 | rec_out_mean = rec_out.mean(axis=1).mean(axis=0) 266 | rec_out_var = ((rec_out - rec_out_mean[None, None, :])**2).mean(axis=1).mean(axis=0).sqrt() 267 | rec_out = (rec_out - rec_out_mean[None, None, :]) / rec_out_var[None, None, :] 268 | if p['dropout'] > 0: 269 | dropout_locs.append((VariableFilter(bricks=dropout_b, name='output'), p['dropout'])) 270 | 271 | if p['skip']: 272 | rec_out = tensor.concatenate([rec_in, rec_out], axis=2) 273 | rec_out_dim = rec_in_dim + rec_out_dim 274 | 275 | rec_in = rec_out 276 | rec_in_dim = rec_out_dim 277 | 278 | # LINEAR FOR THE OUTPUT 279 | rec_to_o = Linear(name='rec_to_o', 280 | input_dim=rec_out_dim, 281 | output_dim=num_output_classes + 1) 282 | y_hat_pre = rec_to_o.apply(rec_out) 283 | # y_hat_pre : T x B x C+1 284 | 285 | if normalize_out: 286 | y_hat_pre_mean = y_hat_pre.mean(axis=1).mean(axis=0) 287 | y_hat_pre_var = ((y_hat_pre - y_hat_pre_mean[None, None, :])**2).mean(axis=1).mean(axis=0).sqrt() 288 | y_hat_pre = (y_hat_pre - y_hat_pre_mean[None, None, :]) / y_hat_pre_var[None, None, :] 289 | 290 | # y_hat : T x B x C+1 291 | y_hat = tensor.nnet.softmax( 292 | y_hat_pre.reshape((-1, num_output_classes + 1)) 293 | ).reshape((y_hat_pre.shape[0], y_hat_pre.shape[1], -1)) 294 | y_hat.name = 'y_hat' 295 | 296 | y_hat_mask = rec_mask 297 | 298 | # CTC COST AND ERROR MEASURE 299 | cost = CTC().apply_log_domain(y, y_hat, y_len, y_hat_mask).mean() 300 | cost.name = 'CTC' 301 | 302 | dl, dl_length = CTC().best_path_decoding(y_hat, y_hat_mask) 303 | 
dl = dl[:L, :] 304 | dl_length = tensor.minimum(dl_length, L) 305 | 306 | edit_distances = batch_edit_distance(dl.T.astype('int32'), dl_length.astype('int32'), 307 | y.T.astype('int32'), y_len.astype('int32')) 308 | edit_distance = edit_distances.mean() 309 | edit_distance.name = 'edit_distance' 310 | errors_per_char = (edit_distances / y_len).mean() 311 | errors_per_char.name = 'errors_per_char' 312 | 313 | is_error = tensor.neq(dl, y) * tensor.lt(tensor.arange(L)[:,None], y_len[None,:]) 314 | is_error = tensor.switch(is_error.sum(axis=0), tensor.ones((B,)), tensor.neq(y_len, dl_length)) 315 | 316 | error_rate = is_error.mean() 317 | error_rate.name = 'error_rate' 318 | 319 | # REGULARIZATION 320 | cg = ComputationGraph([cost, error_rate]) 321 | if weight_noise > 0: 322 | noise_vars = VariableFilter(roles=[WEIGHT])(cg) 323 | cg = apply_noise(cg, noise_vars, weight_noise) 324 | for vfilter, p in dropout_locs: 325 | cg = apply_dropout(cg, vfilter(cg), p) 326 | [cost_reg, error_rate_reg] = cg.outputs 327 | 328 | ctc_reg = cost_reg + 1e-24 329 | ctc_reg.name = 'CTC' 330 | 331 | if l2_output_bias > 0: 332 | cost_reg += l2_output_bias * sum(x.norm(2) for x in VariableFilter(roles=[BIAS], bricks=[rec_to_o])(cg)) 333 | if l2_output_weight > 0: 334 | cost_reg += l2_output_weight * sum(x.norm(2) for x in VariableFilter(roles=[WEIGHT], bricks=[rec_to_o])(cg)) 335 | if l2_all_bias > 0: 336 | cost_reg += l2_all_bias * sum(x.norm(2) for x in VariableFilter(roles=[BIAS])(cg)) 337 | if l2_all_weight > 0: 338 | cost_reg += l2_all_weight * sum(x.norm(2) for x in VariableFilter(roles=[WEIGHT])(cg)) 339 | cost_reg.name = 'cost' 340 | 341 | 342 | # INITIALIZATION 343 | for brick in [rec_to_o] + cb + rb: 344 | brick.weights_init = weights_init 345 | brick.biases_init = biases_init 346 | brick.initialize() 347 | 348 | 349 | # ========================================================================================== 350 | # THE INFRASTRUCTURE 351 | # ========================================================================================== 352 | 353 | # SET UP THE DATASTREAM 354 | 355 | print('Bulding DataStream ...') 356 | ds, stream = setup_datastream('/home/lx.nobackup/datasets/timit/readable', 357 | batch_size=batch_size, 358 | sort_batch_count=sort_batch_count) 359 | valid_ds, valid_stream = setup_datastream('/home/lx.nobackup/datasets/timit/readable', 360 | batch_size=batch_size, 361 | sort_batch_count=sort_batch_count, 362 | valid=True) 363 | 364 | 365 | # SET UP THE BLOCKS ALGORITHM WITH EXTENSIONS 366 | 367 | print('Bulding training process...') 368 | algorithm = GradientDescent(cost=cost_reg, 369 | parameters=ComputationGraph(cost).parameters, 370 | step_rule=step_rule) 371 | 372 | monitor_cost = TrainingDataMonitoring([ctc_reg, cost_reg, error_rate_reg], 373 | prefix="train", 374 | every_n_batches=monitor_freq, 375 | after_epoch=False) 376 | 377 | monitor_valid = DataStreamMonitoring([cost, error_rate, edit_distance, errors_per_char], 378 | data_stream=valid_stream, 379 | prefix="valid", 380 | after_epoch=True) 381 | 382 | plot = Plot(document='CTC_timit_%s%s%s%s_%s'% 383 | (repr([p['nfilter'] for p in convs]), 384 | repr([p['filter_size'] for p in convs]), 385 | repr([p['stride'] for p in convs]), 386 | repr([p['pool_stride'] for p in convs]), 387 | repr([p['dim'] for p in recs])), 388 | channels=[['train_cost', 'train_CTC', 'valid_CTC'], 389 | ['train_error_rate', 'valid_error_rate'], 390 | ['valid_edit_distance'], 391 | ['valid_errors_per_char']], 392 | every_n_batches=monitor_freq, 393 | 
after_epoch=True) 394 | 395 | model = Model(cost) 396 | main_loop = MainLoop(data_stream=stream, algorithm=algorithm, 397 | extensions=[ 398 | ProgressBar(), 399 | 400 | monitor_cost, monitor_valid, 401 | 402 | plot, 403 | Printing(every_n_batches=monitor_freq, after_epoch=True), 404 | ParamInfo(Model([cost]), every_n_batches=monitor_freq), 405 | 406 | FinishAfter(after_n_epochs=n_epochs), 407 | ], 408 | model=model) 409 | 410 | 411 | # NOW WE FINALLY CAN TRAIN OUR MODEL 412 | 413 | print('Starting training ...') 414 | main_loop.run() 415 | 416 | 417 | # vim: set sts=4 ts=4 sw=4 tw=0 et: 418 | -------------------------------------------------------------------------------- /main_toy_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import theano 4 | import numpy 5 | from theano import tensor 6 | from blocks.model import Model 7 | from blocks.bricks import Linear, Tanh 8 | from blocks.bricks.lookup import LookupTable 9 | from ctc import CTC 10 | from blocks.initialization import IsotropicGaussian, Constant 11 | from fuel.datasets import IterableDataset 12 | from fuel.streams import DataStream 13 | from blocks.algorithms import (GradientDescent, Scale, AdaDelta, RemoveNotFinite, 14 | StepClipping, CompositeRule) 15 | from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring 16 | from blocks.main_loop import MainLoop 17 | from blocks.extensions import FinishAfter, Printing 18 | from blocks.bricks.recurrent import SimpleRecurrent, LSTM 19 | from blocks.graph import ComputationGraph 20 | 21 | from toy_dataset import setup_datastream 22 | 23 | from edit_distance import batch_edit_distance 24 | from blocks.extras.extensions.plot import Plot 25 | 26 | floatX = theano.config.floatX 27 | 28 | 29 | n_epochs = 10000 30 | num_input_classes = 5 31 | h_dim = 40 32 | rec_dim = 40 33 | num_output_classes = 4 34 | 35 | 36 | print('Building model ...') # ----------- THE MODEL -------------------------- 37 | 38 | inputt = tensor.lmatrix('input').T 39 | input_mask = tensor.matrix('input_mask').T 40 | y = tensor.lmatrix('output').T 41 | y_mask = tensor.matrix('output_mask').T 42 | y_len = y_mask.sum(axis=0) 43 | # inputt : T x B 44 | # input_mask : T x B 45 | # y : L x B 46 | # y_mask : L x B 47 | 48 | # Linear bricks in 49 | input_to_h = LookupTable(num_input_classes, h_dim, name='lookup') 50 | h = input_to_h.apply(inputt) 51 | # h : T x B x h_dim 52 | 53 | # RNN bricks 54 | pre_lstm = Linear(input_dim=h_dim, output_dim=4*rec_dim, name='LSTM_linear') 55 | lstm = LSTM(activation=Tanh(), 56 | dim=rec_dim, name="rnn") 57 | rnn_out, _ = lstm.apply(pre_lstm.apply(h), mask=input_mask) 58 | 59 | # Linear bricks out 60 | rec_to_o = Linear(name='rec_to_o', 61 | input_dim=rec_dim, 62 | output_dim=num_output_classes + 1) 63 | y_hat_pre = rec_to_o.apply(rnn_out) 64 | # y_hat_pre : T x B x C+1 65 | 66 | # y_hat : T x B x C+1 67 | y_hat = tensor.nnet.softmax( 68 | y_hat_pre.reshape((-1, num_output_classes + 1)) 69 | ).reshape((y_hat_pre.shape[0], y_hat_pre.shape[1], -1)) 70 | y_hat.name = 'y_hat' 71 | 72 | y_hat_mask = input_mask 73 | 74 | # Cost 75 | cost = CTC().apply_log_domain(y, y_hat, y_len, y_hat_mask).mean() 76 | cost.name = 'CTC' 77 | 78 | dl, dl_length = CTC().best_path_decoding(y_hat, y_hat_mask) 79 | 80 | edit_distances = batch_edit_distance(dl.T.astype('int32'), dl_length, y.T.astype('int32'), 81 | y_len.astype('int32')) 82 | edit_distance = edit_distances.mean() 83 | edit_distance.name = 'edit_distance' 84 
| errors_per_char = (edit_distances / y_len).mean() 85 | errors_per_char.name = 'errors_per_char' 86 | 87 | L = y.shape[0] 88 | B = y.shape[1] 89 | dl = dl[:L, :] 90 | is_error = tensor.neq(dl, y) * tensor.lt(tensor.arange(L)[:,None], y_len[None,:]) 91 | is_error = tensor.switch(is_error.sum(axis=0), tensor.ones((B,)), tensor.neq(y_len, dl_length)) 92 | 93 | error_rate = is_error.mean() 94 | error_rate.name = 'error_rate' 95 | 96 | 97 | # Initialization 98 | for brick in [input_to_h, pre_lstm, lstm, rec_to_o]: 99 | brick.weights_init = IsotropicGaussian(0.01) 100 | brick.biases_init = Constant(0) 101 | brick.initialize() 102 | 103 | print('Bulding DataStream ...') # --------------------------------------------------- 104 | ds, stream = setup_datastream(batch_size=100, 105 | nb_examples=10000, rng_seed=123, 106 | min_out_len=5, max_out_len=20) 107 | valid_ds, valid_stream = setup_datastream(batch_size=100, 108 | nb_examples=1000, rng_seed=456, 109 | min_out_len=5, max_out_len=20) 110 | 111 | print('Bulding training process...') # ---------------------------------------------- 112 | algorithm = GradientDescent(cost=cost, 113 | parameters=ComputationGraph(cost).parameters, 114 | step_rule=CompositeRule([RemoveNotFinite(), AdaDelta()])) 115 | # CompositeRule([StepClipping(10.0), Scale(0.02)])) 116 | monitor_cost = TrainingDataMonitoring([cost, error_rate], 117 | prefix="train", 118 | after_epoch=True) 119 | 120 | monitor_valid = DataStreamMonitoring([cost, error_rate, edit_distance, errors_per_char], 121 | data_stream=valid_stream, 122 | prefix="valid", 123 | after_epoch=True) 124 | 125 | plot = Plot(document='CTC_toy_dataset_%d_%d'%(h_dim, rec_dim), 126 | channels=[['train_CTC', 'valid_CTC'], 127 | ['train_error_rate', 'valid_error_rate'], 128 | ['valid_edit_distance'], 129 | ['valid_errors_per_char']], 130 | after_epoch=True) 131 | 132 | model = Model(cost) 133 | main_loop = MainLoop(data_stream=stream, algorithm=algorithm, 134 | extensions=[monitor_cost, monitor_valid, plot, 135 | FinishAfter(after_n_epochs=n_epochs), 136 | Printing()], 137 | model=model) 138 | 139 | print('Starting training ...') # --------------------------------------------------- 140 | main_loop.run() 141 | 142 | 143 | # vim: set sts=4 ts=4 sw=4 tw=0 et: 144 | -------------------------------------------------------------------------------- /timit.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import numpy 4 | 5 | import cPickle 6 | 7 | from fuel.datasets import Dataset, IndexableDataset 8 | from fuel.streams import DataStream 9 | from fuel.schemes import IterationScheme, ConstantScheme, SequentialExampleScheme, ShuffledExampleScheme 10 | from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer 11 | 12 | import sys 13 | import os 14 | 15 | logging.basicConfig(level='INFO') 16 | logger = logging.getLogger(__name__) 17 | 18 | class _balanced_batch_helper(object): 19 | def __init__(self, key): 20 | self.key = key 21 | def __call__(self, data): 22 | return data[self.key].shape[0] 23 | 24 | def setup_datastream(path, batch_size, sort_batch_count, valid=False): 25 | A = numpy.load(os.path.join(path, ('valid_x_raw.npy' if valid else 'train_x_raw.npy'))) 26 | B = numpy.load(os.path.join(path, ('valid_phn.npy' if valid else 'train_phn.npy'))) 27 | C = numpy.load(os.path.join(path, ('valid_seq_to_phn.npy' if valid else 'train_seq_to_phn.npy'))) 28 | 29 | D = [B[x[0]:x[1], 2] for x in C] 30 | 31 | ds = 
IndexableDataset({'input': A, 'output': D}) 32 | stream = DataStream(ds, iteration_scheme=ShuffledExampleScheme(len(A))) 33 | 34 | stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * sort_batch_count)) 35 | comparison = _balanced_batch_helper(stream.sources.index('input')) 36 | stream = Mapping(stream, SortMapping(comparison)) 37 | stream = Unpack(stream) 38 | 39 | stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size, num_examples=len(A))) 40 | stream = Padding(stream, mask_sources=['input', 'output']) 41 | 42 | return ds, stream 43 | 44 | if __name__ == "__main__": 45 | ds, stream = setup_datastream(batch_size=2, 46 | path='/home/lx.nobackup/datasets/timit/readable') 47 | 48 | for i, d in enumerate(stream.get_epoch_iterator()): 49 | print '--' 50 | print d 51 | 52 | 53 | if i > 2: break 54 | 55 | # vim: set sts=4 ts=4 sw=4 tw=0 et : 56 | -------------------------------------------------------------------------------- /toy_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import numpy 4 | 5 | import cPickle 6 | 7 | from picklable_itertools import iter_ 8 | 9 | from fuel.datasets import Dataset 10 | from fuel.streams import DataStream 11 | from fuel.schemes import IterationScheme, ConstantScheme, SequentialExampleScheme 12 | from fuel.transformers import Batch, Mapping, SortMapping, Unpack, Padding, Transformer 13 | 14 | import sys 15 | import os 16 | 17 | logging.basicConfig(level='INFO') 18 | logger = logging.getLogger(__name__) 19 | 20 | class ToyDataset(Dataset): 21 | def __init__(self, nb_examples, rng_seed, min_out_len, max_out_len, **kwargs): 22 | self.provides_sources = ('input', 'output') 23 | 24 | random.seed(rng_seed) 25 | 26 | table = [ 27 | [0, 1, 2, 3, 4], 28 | [0, 1, 2, 1, 0], 29 | [4, 3, 2, 3, 4], 30 | [4, 3, 2, 1, 0] 31 | ] 32 | prob0 = 0.7 33 | prob = 0.2 34 | 35 | self.data = [] 36 | for n in range(nb_examples): 37 | o = [] 38 | i = [] 39 | l = random.randrange(min_out_len, max_out_len) 40 | for p in range(l): 41 | o.append(random.randrange(len(table))) 42 | for x in table[o[-1]]: 43 | q = 0 44 | if random.uniform(0, 1) < prob0: 45 | i.append(x) 46 | while random.uniform(0, 1) < prob: 47 | i.append(x) 48 | self.data.append((i, o)) 49 | 50 | super(ToyDataset, self).__init__(**kwargs) 51 | 52 | 53 | def get_data(self, state=None, request=None): 54 | if request is None: 55 | raise ValueError("Request required") 56 | 57 | return self.data[request] 58 | 59 | # -------------- DATASTREAM SETUP -------------------- 60 | 61 | def setup_datastream(batch_size, **kwargs): 62 | ds = ToyDataset(**kwargs) 63 | stream = DataStream(ds, iteration_scheme=SequentialExampleScheme(kwargs['nb_examples'])) 64 | 65 | stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size)) 66 | stream = Padding(stream, mask_sources=['input', 'output']) 67 | 68 | return ds, stream 69 | 70 | if __name__ == "__main__": 71 | 72 | ds, stream = setup_datastream(nb_examples=5, 73 | rng_seed=123, 74 | min_out_len=3, 75 | max_out_len=6) 76 | 77 | for i, d in enumerate(stream.get_epoch_iterator()): 78 | print '--' 79 | print d 80 | 81 | 82 | if i > 2: break 83 | 84 | # vim: set sts=4 ts=4 sw=4 tw=0 et : 85 | --------------------------------------------------------------------------------