├── LICENSE
├── README.md
├── csv-to-tfrecords.py
├── egs-to-csv.py
├── faster-mapping.sh
├── format-mapping.py
├── map-label-no-vote.py
├── phone-to-cluster.py
├── run.sh
├── train_and_eval.py
└── vote.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License

Copyright (c) 2018 Joshua Meyer

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# `kaldi-tf`

A set of scripts for getting data out of Kaldi and into TensorFlow.


## Pipeline

| Step | Code Location |
|---|---|
| 1) Generate Kaldi phoneme-level alignments (`*.ali`) via GMMs | Kaldi source |
| 2) Generate Kaldi nnet3 neural net example files (`egs.*.ark`) from alignments | Kaldi source |
| 3) Convert binary Nnet3Egs ark files to text ark files via `nnet3-copy-egs.cc` | Kaldi source |
| 4) Convert text ark file to csv via `egs-to-csv.py` | this repo |
| 5) Convert csv to tfrecords via `csv-to-tfrecords.py` | this repo |
| 6) Read tfrecords, train, and evaluate with `train_and_eval.py` | this repo |

## Modifying Kaldi egs

Unrelated to TensorFlow, but if you want to open Kaldi egs, make changes, and use those modified egs in training, follow this guide:

1) convert egs.ark to text: `$ nnet3-copy-egs ark:egs.1.ark ark,t:egs.1.ark.txt`
2) make your changes to the new text ark file
3) convert the text ark file back to binary, with a new scp file: `$ nnet3-copy-egs ark,t:egs.1.ark.txt ark,scp:egs.1.ark,egs.scp`
4) fix the paths in the new scp file: they are written relative to wherever you ran `nnet3-copy-egs`, so they will usually point to the wrong location!
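
## Inspecting the tfrecords

The tfrecords written in step 5 hold one `tf.train.Example` per window: a `mfccs` bytes feature (the flattened window, as raw float64 bytes) and an int64 `label`. Here is a minimal sketch for checking one record; it assumes TensorFlow 1.x (which is what these scripts target) and that `all.tfrecords` is whatever `csv-to-tfrecords.py` produced:

```python
import numpy as np
import tensorflow as tf

# iterate over the serialized records in the file (TF 1.x API)
for record in tf.python_io.tf_record_iterator("all.tfrecords"):
    example = tf.train.Example.FromString(record)
    label = example.features.feature["label"].int64_list.value[0]
    raw = example.features.feature["mfccs"].bytes_list.value[0]
    # csv-to-tfrecords.py wrote the feats as raw float64 bytes
    mfccs = np.frombuffer(raw, dtype=np.float64)
    print(label, mfccs.shape)
    break  # just look at the first record
```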
--------------------------------------------------------------------------------
/csv-to-tfrecords.py:
--------------------------------------------------------------------------------
# Josh Meyer // jrmeyer.github.io

import sys
import pandas
import tensorflow as tf
import numpy as np

#
# USAGE: $ python3 csv-to-tfrecords.py data.csv data.tfrecords
#
#

infile=sys.argv[1]
outfile=sys.argv[2]

csv = pandas.read_csv(infile, header=None).values


with tf.python_io.TFRecordWriter(outfile) as writer:
    for row in csv:
        # the file is space-delimited, so pandas (with its default comma sep)
        # reads each row as one single string; strip trailing whitespace and split
        row = row[0].rstrip().split(' ')
        # the first col is the label, all the rest are feats
        label = int(row[0])

        # flatten the feats into raw float64 bytes
        mfccs = np.array([ float(feat) for feat in row[1:] ]).tostring()

        example = tf.train.Example()
        example.features.feature["mfccs"].bytes_list.value.append(mfccs)
        example.features.feature["label"].int64_list.value.append(label)
        writer.write(example.SerializeToString())
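
# Round-trip check (an illustrative sketch, not part of the pipeline): the
# bytes written above can be recovered with numpy, and the dtype must be
# float64, since np.array() over Python floats defaults to float64
# (train_and_eval.py decodes with tf.float64 for the same reason):
#
#   feats = np.frombuffer(mfccs, dtype=np.float64)
#   assert feats[0] == float(row[1])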
--------------------------------------------------------------------------------
/egs-to-csv.py:
--------------------------------------------------------------------------------
# Josh Meyer // jrmeyer.github.io
#
# $ python3 egs-to-csv.py in-egs.txt out-egs.csv
#



import sys
import re


def extract_windows(eg, line, regex, outfile, win_size=62):
    '''
    given a line of labels, and the saved block of feature vectors,
    this function extracts windows of a given size and assigns each one
    to its label, printing one "label flattened_feats" row per window.
    win_size comes from the left and right context you provided to kaldi
    to splice the frames
    '''

    # for each label in the nnet3Egs object
    for i, label in enumerate(regex.findall(line)):

        # cat all feats for that eg/label into a single vector
        catFeats=''
        for row in eg[i:i+win_size]:

            # remove the trailing ] if it exists (this is just a cleaning step)
            row = row[0].replace("]", "")
            # cat all the rows into one flat vector
            catFeats += (row + ' ')

        print(label, catFeats, file=outfile)


def get_eg_dim(arkfile):
    '''
    given a kaldi ark file in txt format, find the dimension of the target
    labels. (not called below, but handy if a downstream step needs the dim)
    '''
    with open(arkfile, "r") as arkf:
        for line in arkf:
            if "dim=" in line:
                egDim = re.search(r'dim=([0-9]*)', line).group(1)
                break
            else:
                pass

    return egDim




def main(arkfile, outfile):
    '''
    arkfile: the input ark file from kaldi (egs.ark), converted to text
    regex: matches the label of each eg based on the "dim=X [ ID ..." pattern
    outfile: where to save the output
    '''

    # this regex matches the label for each eg in a frame of egs
    # (ie the ark file contains groups of egs which make sense for TDNN)
    regex = re.compile(r"dim=[0-9]+ \[ ([0-9]+) ")

    eg=[]
    with open(arkfile,"r") as arkf:
        with open(outfile,"a") as outf:
            for line in arkf:
                # the first line of the eg
                if 'input' in line:
                    eg=[]
                # if we've hit the labels then we're at the end of the data
                elif 'output' in line:
                    extract_windows(eg, line, regex, outf)
                # this should be one frame of data
                else:
                    eg.append([line.strip()])



if __name__ == "__main__":

    # this is a kaldi nnet3Egs ark file which has been converted
    # to txt via nnet3-copy-egs
    arkfile=sys.argv[1]
    outfile=sys.argv[2]

    main(arkfile, outfile)

    print("Extracted egs from " + arkfile + " and printed to " + outfile)
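
# Illustrative sketch of what the label regex in main() matches. The sample
# 'output' line is adapted from the commented example in map-label-no-vote.py:
#
#   >>> regex = re.compile(r"dim=[0-9]+ \[ ([0-9]+) ")
#   >>> regex.findall('rows=2 dim=736 [ 75 1 ] dim=736 [ 625 1 ]')
#   ['75', '625']
#
# ie one pdf-id label per eg in the frame, in order.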
--------------------------------------------------------------------------------
/faster-mapping.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Josh Meyer // jrmeyer.github.io

echo "### SPLIT ARK FOR MULTIPLE JOBS ###"


ARKFILE=$1
MAPPINGS=$2
TMP_DIR=$3



#### MANUAL SHARD ####

# wc -l prints "count filename"; capturing into an array makes
# $num_lines expand to just the count
num_lines=(`wc -l $ARKFILE`)
num_processors=(`nproc`)
segs_per_job=$(( num_lines / num_processors ))

echo "$0: processing $num_lines segments from $ARKFILE"
echo "$0: splitting segments over $num_processors CPUs"
echo "$0: with $segs_per_job segments per job."

# will split into ARK_split00.tmp ARK_split01.tmp ... etc
split -l $segs_per_job --numeric-suffixes --additional-suffix=.tmp $ARKFILE $TMP_DIR/ARK_split

proc_ids=() # make an array for proc ids

# first pass: rewrite "[ OLD " as "[ NEW@ ". The @ marks labels which have
# already been mapped, so a later mapping whose OLD happens to equal an
# earlier NEW cannot remap them a second time.
for i in $TMP_DIR/ARK_split*.tmp; do
    while read mapping; do
        mapArr=($mapping)
        old=${mapArr[0]}
        new=${mapArr[1]}
        sed_command="s/ \[ ${old} / \[ ${new}@ /g"
        parallel --pipepart --block 5000M -a $i -k sed -e \" $sed_command \" > ${i}.mod
        mv ${i}.mod $i
    done <$MAPPINGS &
    proc_ids+=($!)
done
# wait for subprocesses to stop
for proc_id in ${proc_ids[*]}; do wait $proc_id; done;

###########

# second pass: strip the @ markers
proc_ids=()
for i in $TMP_DIR/ARK_split*.tmp; do
    parallel --pipepart --block 500M -a $i -k 'sed "s/@//g"' > ${i}.final
    mv ${i}.final $i &
    proc_ids+=($!)
done

# wait for subprocesses to stop
for proc_id in ${proc_ids[*]}; do wait $proc_id; done;

--------------------------------------------------------------------------------
/format-mapping.py:
--------------------------------------------------------------------------------
# Josh Meyer // jrmeyer.github.io

# This script fixes the problem that occurs if the TF cluster IDs are larger
# than the dimensionality of the original Kaldi eg. The problem only shows up
# when you train on the modified egs in Kaldi, and the new target IDs happen
# to be larger than the dims of the eg (dim=X [ ID 1 ]). The fix: renumber
# the cluster IDs densely, starting from 0, in order of first appearance.

import sys

oldMappings=sys.argv[1]
newMappings=sys.argv[2]

# assign each distinct cluster ID a new, dense ID (0, 1, 2, ...)
i=0
newMap={}
with open(oldMappings) as orgMap:
    for line in orgMap.readlines():
        clusterID = line.split()[1]
        if clusterID in newMap:
            pass
        else:
            newMap[clusterID]=i
            i+=1

# rewrite the mapping file with the dense cluster IDs
with open(newMappings, 'w') as newFile:
    with open(oldMappings) as orgMap:
        for line in orgMap.readlines():
            kaldiID, clusterID = line.split()
            print(str(kaldiID) + ' ' + str(newMap[clusterID]), file=newFile)
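
# Worked example (the IDs are illustrative, not from real data): given an
# old mapping file
#
#   101 4071
#   102 15
#   103 4071
#
# the cluster IDs are renumbered densely in order of first appearance, so
# the new mapping file becomes
#
#   101 0
#   102 1
#   103 0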
--------------------------------------------------------------------------------
/map-label-no-vote.py:
--------------------------------------------------------------------------------
# Josh Meyer // jrmeyer.github.io
#
# $ python3 map-label-no-vote.py path/to/txt/egs.txt new-labels
#



import sys
import re



def replace_labels(curLine, newLabels, outfile):
    '''
    given an 'output' line from a text-format nnet3Egs ark file, replace
    the label of each eg with the corresponding entry from newLabels.
    NB: the newLabels and outfile args are currently placeholders; the
    labels are hardcoded below.
    '''

    newLabels=[1, 2, 3, 4, 5, 6, 7, 8]

    # split the line into everything up to (and including) "rows=N ",
    # and the labels which follow it.
    # NB: 'rows=. ' assumes a single-digit row count
    line= re.split(r'(rows=. )', curLine)

    misc = line[0] + line[1]
    labels = line[2]

    i=0
    newOut=''
    for label in re.split(r'(dim=\d+ \[ \d+ \d+ \])', labels):
        if 'dim' in label:
            newLabel = re.sub(r'\[ \d+', '[ ' + str(newLabels[i]), label)
            newOut+=newLabel
            i+=1
        else:
            newOut+=label # this should only be whitespace


    return misc+newOut





def main(arkfile, new_label_file, outfile):
    '''
    arkfile: the input ark file from kaldi (egs.ark), converted to text
    new_label_file: where the new labels come from (currently unused;
                    replace_labels() hardcodes its labels)
    outfile: where to save the output (currently unused; prints to stdout)
    '''


    with open(arkfile,"r") as arkf:
        for line in arkf:
            if 'output' in line:
                newLine = replace_labels(line, 'foo', 'bar')
                print(newLine, end='') # lines keep their \n, so suppress print's
            else:
                print(line, end='')



if __name__ == "__main__":

    # this is a kaldi nnet3Egs ark file which has been converted
    # to txt via nnet3-copy-egs
    arkfile=sys.argv[1]
    newLabels=sys.argv[2]

    main(arkfile, 'newlabs', 'out')




# line=' output 8 0 0 0 0 1 0 0 2 0 0 3 0 0 4 0 0 5 0 0 6 0 0 7 0 rows=8 dim=736 [ 75 1 ] dim=736 [ 625 1 ] dim=736 [ 692 1 ] dim=736 [ 692 1 ] dim=736 [ 692 1 ] dim=736 [ 692 1 ] dim=736 [ 261 1 ] dim=736 [ 151 1 ]'

# replace_labels(line, 'foo', 'bar')




# split string between rows and first dim
# save first half
# iterate over each dim, and replace with new label, catting to string
# cat first half of org string (up to row) with new labels
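
# For the sample line above, replace_labels() rewrites the first field of
# each "[ ID 1 ]" block in order, so 'dim=736 [ 75 1 ]' becomes
# 'dim=736 [ 1 1 ]', 'dim=736 [ 625 1 ]' becomes 'dim=736 [ 2 1 ]',
# and so on through the hardcoded newLabels list.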
--------------------------------------------------------------------------------
/phone-to-cluster.py:
--------------------------------------------------------------------------------
# Josh Meyer // jrmeyer.github.io
# 2018

#
# this script takes in a two-column file (i.e. mapping.txt)
# where the left column is the triphone id from kaldi and
# the right col is the new cluster id from tensorflow.
# it prints, for each cluster, the counts of the kaldi
# triphones assigned to it.

import sys

mapping=sys.argv[1]



phoneMap={}
with open(mapping, 'r') as f:
    lines = f.readlines()
    for line in lines:

        phone = line.split()[0]
        cluster = line.split()[1]

        # count how many times each phone lands in each cluster
        if cluster in phoneMap:
            if phone in phoneMap[cluster]:
                phoneMap[cluster][phone] = phoneMap[cluster][phone] + 1
            else:
                phoneMap[cluster][phone] = 1

        else:
            phoneMap[cluster] = {phone:1}

for item in phoneMap.items():
    print(item[0], item[1])
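
# A possible follow-up (an illustrative sketch, not called above): the counts
# in phoneMap make it easy to measure cluster purity, ie the fraction of a
# cluster's windows that belong to its single most frequent phone:
#
#   def purity(counts):
#       return max(counts.values()) / sum(counts.values())
#
#   for cluster, counts in phoneMap.items():
#       print(cluster, round(purity(counts), 3))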
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Josh Meyer // jrmeyer.github.io

#
# this script takes the binary egs for an experiment, relabels them with
# TF k-means cluster IDs, and puts modified binary egs back in their place



# given the name of the experiment, this script assumes the file structure
# that I've got set up
exp_name=${1}
tmp_dir="/data/TMP"
path_to_exp="/home/ubuntu/kaldi/egs/multi-task-kaldi/mtk/MTL/exp/${exp_name}/nnet3/egs"

# defined up here because the last stage needs it even when the
# first stage is switched off
KALDI=/home/ubuntu/kaldi/src/nnet3bin/nnet3-copy-egs


# "if [ 1 ]" is always true; change the 1 to an empty string to skip a stage
if [ 1 ]; then
    echo "### CONVERTING BINARY EGS TO TEXT ###"
    echo "LOOKING FOR EGS IN ${path_to_exp}"
    # binary --> text
    $KALDI ark:${path_to_exp}/egs.1.ark ark,t:${tmp_dir}/org-txt-ark
    $KALDI ark:${path_to_exp}/valid_diagnostic.egs ark,t:${tmp_dir}/valid-txt-ark
    $KALDI ark:${path_to_exp}/train_diagnostic.egs ark,t:${tmp_dir}/train-txt-ark
    $KALDI ark:${path_to_exp}/combine.egs ark,t:${tmp_dir}/combine-txt-ark
fi



if [ 1 ]; then
    echo "### CONVERT EGS TO TFRECORDS ###"
    # EGS --> CSV
    python3 egs-to-csv.py ${tmp_dir}/org-txt-ark ${tmp_dir}/ark.csv

    # Split data into train / eval / all. I know this isn't kosher (eval
    # overlaps all), but right now I don't care about eval; it's just a
    # required step in the scripts
    mv ${tmp_dir}/ark.csv ${tmp_dir}/all.csv
    #mv ${tmp_dir}/ark.csv ${tmp_dir}/train.csv
    tail -n100 ${tmp_dir}/all.csv > ${tmp_dir}/eval.csv

    # CSV --> TFRECORDS
    python3 csv-to-tfrecords.py ${tmp_dir}/all.csv ${tmp_dir}/all.tfrecords
    python3 csv-to-tfrecords.py ${tmp_dir}/eval.csv ${tmp_dir}/eval.tfrecords
    #python3 csv-to-tfrecords.py ${tmp_dir}/train.csv ${tmp_dir}/train.tfrecords

    # TRAIN K-MEANS
    echo "### TRAIN AND EVAL MODEL ###"
    echo "# remove old model in /tmp/tf"
    rm -rf /tmp/tf
    time python3 train_and_eval.py $tmp_dir ## returns tf-labels.txt
fi


# VOTE FOR MAPPINGS
if [ 1 ]; then
    cut -d' ' -f1 ${tmp_dir}/all.csv > ${tmp_dir}/kaldi-labels.txt
    paste -d' ' ${tmp_dir}/kaldi-labels.txt ${tmp_dir}/tf-labels.txt > ${tmp_dir}/combined-labels.txt
    python3 vote.py ${tmp_dir}/combined-labels.txt > ${tmp_dir}/mapping.txt
    python3 format-mapping.py ${tmp_dir}/mapping.txt ${tmp_dir}/formatted-mapping.txt
fi



if [ 1 ]; then
    # PERFORM MAPPING
    for egs in org-txt-ark valid-txt-ark train-txt-ark combine-txt-ark; do
        ./faster-mapping.sh $tmp_dir/$egs $tmp_dir/formatted-mapping.txt ${tmp_dir}
        cat ${tmp_dir}/ARK_split* > ${tmp_dir}/${egs}.mod
        rm ${tmp_dir}/ARK_split*
    done
fi


if [ 1 ]; then
    echo "TXT.egs --> BIN.egs ;; RENAME AND MOVE BIN.egs"
    # text --> binary
    $KALDI ark,t:${tmp_dir}/org-txt-ark.mod ark,scp:${tmp_dir}/egs.1.ark,${tmp_dir}/egs.scp
    $KALDI ark,t:${tmp_dir}/valid-txt-ark.mod ark,scp:${tmp_dir}/valid_diagnostic.egs,${tmp_dir}/valid_diagnostic.scp
    $KALDI ark,t:${tmp_dir}/train-txt-ark.mod ark,scp:${tmp_dir}/train_diagnostic.egs,${tmp_dir}/train_diagnostic.scp
    $KALDI ark,t:${tmp_dir}/combine-txt-ark.mod ark,scp:${tmp_dir}/combine.egs,${tmp_dir}/combine.scp

    # fix paths: the scp entries point into ${tmp_dir}, but the arks are
    # about to be moved into the experiment dir
    fix_path="s/\/data\/TMP/MTL\/exp\/${exp_name}\/nnet3\/egs/g"
    sed -Ei $fix_path ${tmp_dir}/*.scp

    # move old egs to a backup dir
    mkdir ${path_to_exp}/org-scp-ark
    mv ${path_to_exp}/*.scp ${path_to_exp}/*.ark ${path_to_exp}/*.egs ${path_to_exp}/org-scp-ark/.

    # move new egs to the original dir
    mv ${tmp_dir}/*.scp ${tmp_dir}/*.ark ${tmp_dir}/*.egs ${path_to_exp}/.

    echo "### OLD ARKS and SCPs moved to ${path_to_exp}/org-scp-ark/"
    echo "### NEW ARKS just renamed to standard names"
fi

# clean up
rm $tmp_dir/*

--------------------------------------------------------------------------------
/train_and_eval.py:
--------------------------------------------------------------------------------
# Josh Meyer // jrmeyer.github.io
#
# $ python3 train_and_eval.py <data_dir>
#
# One CLI arg: the dir which holds the tfrecords files

import tensorflow as tf
import multiprocessing # this gets us the number of CPU cores on the machine
import sys


data_dir=sys.argv[1] # where are the tfrecords files?


def parse_fn(record):
    '''
    this is a parser function. It defines the template for
    interpreting the examples you're feeding in. Basically,
    this function defines what the labels and data look like
    for your labeled data.
    '''
    features={
        'mfccs': tf.FixedLenFeature([], tf.string)
        # 'label': tf.FixedLenFeature([], tf.int64)
    }
    parsed = tf.parse_single_example(record, features)
    # float64 matches the dtype written by csv-to-tfrecords.py
    mfccs= tf.convert_to_tensor(tf.decode_raw(parsed['mfccs'], tf.float64))
    # label= tf.cast(parsed['label'], tf.int32)

    return {'mfccs': mfccs} # for the dnn: return {'mfccs': mfccs}, label



def my_input_fn(tfrecords_path, model):
    '''
    this is an Estimator input function. it defines things
    like datasets and batches, and can perform operations
    such as shuffling
    The dataset and iterator are both defined here.
    '''

    dataset = (
        tf.data.TFRecordDataset(tfrecords_path,
                                num_parallel_reads=multiprocessing.cpu_count())
        .apply(
            tf.contrib.data.map_and_batch(
                map_func=parse_fn,
                batch_size=4096,
                num_parallel_batches=multiprocessing.cpu_count()
            )
        )
        .prefetch(4096)
    )


    iterator = dataset.make_one_shot_iterator()
    batch_mfccs = iterator.get_next()
    # batch_mfccs, batch_labels = iterator.get_next()

    if model == "dnn":
        # NB: the dnn branch needs the label lines above (and in parse_fn)
        # uncommented, otherwise batch_labels is undefined
        output = (batch_mfccs, batch_labels)
    elif model == "kmeans":
        output = batch_mfccs

    return output



def zscore(in_tensor):
    '''
    Some normalization for audio feats. NB: despite the name, this is
    min-max normalization, not a z-score:
    (value − min_value) / (max_value − min_value)
    '''
    out_tensor = tf.div(
        tf.subtract(
            in_tensor,
            tf.reduce_min(in_tensor)
        ),
        tf.subtract(
            tf.reduce_max(in_tensor),
            tf.reduce_min(in_tensor)
        )
    )
    return out_tensor



### Multi-GPU training config ###

# distribution = tf.contrib.distribute.MirroredStrategy()
# run_config = tf.estimator.RunConfig(train_distribute=distribution)
run_config = tf.estimator.RunConfig()


### K-Means ###

train_spec_kmeans = tf.estimator.TrainSpec(input_fn = lambda: my_input_fn( str(data_dir) + '/' + 'all.tfrecords', 'kmeans') , max_steps=2000)
eval_spec_kmeans = tf.estimator.EvalSpec(input_fn = lambda: my_input_fn( str(data_dir) + '/' + 'eval.tfrecords', 'kmeans') )

KMeansEstimator = tf.contrib.factorization.KMeansClustering(
    num_clusters=4096,
    feature_columns = [tf.feature_column.numeric_column(
        key='mfccs',
        dtype=tf.float64,
        shape=(806,), # 806 = 62-frame window (win_size in egs-to-csv.py) x 13 mfccs
        normalizer_fn = lambda x: zscore(x)
    )], # The input features to our model
    model_dir = '/tmp/tf', # Path to where checkpoints etc are stored
    use_mini_batch = True,
    config = run_config)


print("Train and Evaluate K-Means")
tf.estimator.train_and_evaluate(KMeansEstimator, train_spec_kmeans, eval_spec_kmeans)


# map the input points to their clusters
cluster_centers = KMeansEstimator.cluster_centers()
cluster_indices = list(KMeansEstimator.predict_cluster_index(input_fn = lambda: my_input_fn( str(data_dir) + '/' + 'all.tfrecords', 'kmeans')))

with open(str(data_dir) + '/' + "tf-labels.txt", "a") as outfile:
    for i in cluster_indices:
        index = i-1 # kaldi uses zero-based indexes
        print(index, file=outfile)




# # DNN

# # Define train and eval specs
# train_spec_dnn = tf.estimator.TrainSpec(input_fn = lambda: my_input_fn('train.tfrecords', 'dnn') , max_steps=100000)
# eval_spec_dnn = tf.estimator.EvalSpec(input_fn = lambda: my_input_fn('eval.tfrecords', 'dnn') )

# DNNClassifier = tf.estimator.DNNClassifier(
#     feature_columns = [tf.feature_column.numeric_column(key='mfccs', dtype=tf.float64, shape=(806,))], # The input features to our model
#     hidden_units = [256, 256, 256, 256], # Four layers, each with 256 neurons
#     n_classes = 200,
#     optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.0000001),
#     model_dir = '/tmp/tf', # Path to where checkpoints etc are stored
#     config = run_config)



# print("Train and Evaluate DNN")
# tf.estimator.train_and_evaluate(DNNClassifier, train_spec_dnn, eval_spec_dnn)

# predictions = list(DNNClassifier.predict(input_fn = lambda: my_input_fn('/home/ubuntu/eval.tfrecords', 'dnn')))

# for logits in predictions:
#     print(logits['probabilities'])
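
# Illustrative sketch (not part of the pipeline): what zscore() above does,
# in plain numpy -- the tensor is rescaled into [0, 1] by its global min/max:
#
#   >>> import numpy as np
#   >>> x = np.array([2.0, 4.0, 6.0])
#   >>> (x - x.min()) / (x.max() - x.min())
#   array([0. , 0.5, 1. ])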
--------------------------------------------------------------------------------
/vote.py:
--------------------------------------------------------------------------------
# Josh Meyer // jrmeyer.github.io

import sys
import collections

# this script expects as input a single text file with two columns,
# where the first column is a list of phoneme labels from kaldi
# and the second column is a list of cluster labels from tensorflow.
# one row represents one window of audio from a kaldi nnet3Eg


infile=sys.argv[1]

with open(infile) as f:
    content = f.readlines()

# collect every tf label seen for each kaldi label
vote_dict={}
for line in content:
    kaldi_label, tf_label = line.split()

    if kaldi_label in vote_dict:
        vote_dict[kaldi_label].append(tf_label)
    else:
        vote_dict[kaldi_label] = [tf_label]

# each kaldi label maps to the tf label it co-occurred with most often
for kaldi_label, tf_labels in vote_dict.items():
    most_common_tf_label = max(set(tf_labels), key=tf_labels.count)
    # just keep most common
    vote_dict[kaldi_label] = most_common_tf_label


for kaldi, tf in collections.OrderedDict(sorted(vote_dict.items())).items():
    print(kaldi, tf)
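
# An equivalent way to take the majority vote (an illustrative sketch) is
# collections.Counter, which is linear rather than quadratic in the number
# of votes:
#
#   >>> collections.Counter(['5', '5', '9']).most_common(1)[0][0]
#   '5'
--------------------------------------------------------------------------------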