├── .gitignore ├── README ├── converters ├── change_format.py ├── common.py ├── parallelizer-b.py ├── pre-b.py └── subsample.py ├── ffm ├── .classpath ├── .fatjar ├── .project ├── .settings │ ├── org.eclipse.core.resources.prefs │ ├── org.eclipse.core.runtime.prefs │ └── org.eclipse.jdt.core.prefs └── src │ └── ffm │ ├── FFMModel.java │ ├── FFMNode.java │ ├── FFMParameter.java │ ├── FFMProblem.java │ └── LogLossEvalutor.java ├── main_script.sh └── utils └── count.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | 3 | # Mobile Tools for Java (J2ME) 4 | .mtj.tmp/ 5 | 6 | # Package Files # 7 | *.jar 8 | *.war 9 | *.ear 10 | 11 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 12 | hs_err_pid* 13 | 14 | *.pyc 15 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | Field-Aware Factorization Machine Implemented by Java, with An Experiment Using Criteo Dataset. 2 | ================================================================================================ 3 | 4 | 5 | Get The DataSet 6 | ================ 7 | refer to: https://github.com/guestwalk/kaggle-2014-criteo/blob/master/README, 'Get The Dataset' part. 8 | 9 | 10 | Run main_script.sh 11 | =================== 12 | feature_engineer --> split train/validation set --> subsample --> change to libffm format. 
13 | then run libffm: 14 | ----------------------------------------------------------------------------------- 15 | java -Xmx65g -jar ffm.jar <eta> <lambda> <iter> <factor> <norm> <rand> <trset> <vaset> 16 | ----------------------------------------------------------------------------------- 17 | eta: used for learning rate 18 | lambda: used for L2 regularization 19 | iter: max iterations 20 | factor: latent factor num 21 | norm: instance-wise normalization 22 | rand: use random instance order when training 23 | trset: train set 24 | vaset: validation set 25 | 26 | 27 | Experiment Results: 28 | ==================== 29 | norm and rand only affect training speed. 30 | the best eta is about 0.1; a bigger eta hurts validation logloss, a smaller eta gives slow convergence. 31 | ------------------------------------------------------------------------------------------ 32 | when eta=0.1, factor=4, iter=10: 33 | lambda 0.00000 0.00001 0.00010 0.00100 0.01000 0.10000 34 | best_logloss 0.45061 0.44930 0.44951 0.46919 0.58700 0.69321 35 | convergence iter 3 5 10 10 10 10(very slow) 36 | convergence iter 10 means convergence was not observed within 10 iterations. 37 | ------------------------------------------------------------------------------------------- 38 | when lambda=0.0001, eta=0.1, iter=30: 39 | k=4,8,12 doesn't affect best_logloss or convergence iter. 40 | -------------------------------------------------------------------------------- /converters/change_format.py: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/env python 2 | import sys 3 | 4 | input_file = sys.argv[1] 5 | output_file = sys.argv[2] 6 | 7 | output_handle = open(output_file, 'w') 8 | for line in open(input_file): 9 | line = line.strip() 10 | fields = line.split() 11 | output_line = fields[0] 12 | for i, field in enumerate(fields[1:], start=1): 13 | output_line += " " + str(i) + ":" + field + ":1" 14 | output_line += "\n" 15 | output_handle.write(output_line) 16 | output_handle.close() 17 | 18 | -------------------------------------------------------------------------------- /converters/common.py: -------------------------------------------------------------------------------- 1 | import hashlib, csv, math, os, pickle, subprocess 2 | 3 | HEADER="Id,Label,I1,I2,I3,I4,I5,I6,I7,I8,I9,I10,I11,I12,I13,C1,C2,C3,C4,C5,C6,C7,C8,C9,C10,C11,C12,C13,C14,C15,C16,C17,C18,C19,C20,C21,C22,C23,C24,C25,C26" 4 | 5 | def open_with_first_line_skipped(path, skip=True): 6 | f = open(path) 7 | if not skip: 8 | return f 9 | next(f) 10 | return f 11 | 12 | def hashstr(str, nr_bins): 13 | return int(hashlib.md5(str.encode('utf8')).hexdigest(), 16)%(nr_bins-1)+1 14 | 15 | def gen_feats(row): 16 | feats = [] 17 | for j in range(1, 14): 18 | field = 'I{0}'.format(j) 19 | value = row[field] 20 | if value != '': 21 | value = int(value) 22 | if value > 2: 23 | value = int(math.log(float(value))**2) 24 | else: 25 | value = 'SP'+str(value) 26 | key = field + '-' + str(value) 27 | feats.append(key) 28 | for j in range(1, 27): 29 | field = 'C{0}'.format(j) 30 | value = row[field] 31 | key = field + '-' + value 32 | feats.append(key) 33 | return feats 34 | 35 | def read_freqent_feats(threshold=10): 36 | frequent_feats = set() 37 | for row in csv.DictReader(open('fc.trva.t10.txt')): 38 | if int(row['Total']) < threshold: 39 | continue 40 | frequent_feats.add(row['Field']+'-'+row['Value']) 41 | return frequent_feats 42 | 43 | def split(path, nr_thread, has_header): 44 | 45 | def open_with_header_witten(path, idx, header): 46 | f = 
open(path+'.__tmp__.{0}'.format(idx), 'w') 47 | if not has_header: 48 | return f 49 | f.write(header) 50 | return f 51 | 52 | def calc_nr_lines_per_thread(): 53 | nr_lines = int(list(subprocess.Popen('wc -l {0}'.format(path), shell=True, 54 | stdout=subprocess.PIPE).stdout)[0].split()[0]) 55 | if not has_header: 56 | nr_lines += 1 57 | return math.ceil(float(nr_lines)/nr_thread) 58 | 59 | header = open(path).readline() 60 | 61 | nr_lines_per_thread = calc_nr_lines_per_thread() 62 | 63 | idx = 0 64 | f = open_with_header_witten(path, idx, header) 65 | for i, line in enumerate(open_with_first_line_skipped(path, has_header), start=1): 66 | if i%nr_lines_per_thread == 0: 67 | f.close() 68 | idx += 1 69 | f = open_with_header_witten(path, idx, header) 70 | f.write(line) 71 | f.close() 72 | 73 | def parallel_convert(cvt_path, arg_paths, nr_thread): 74 | 75 | workers = [] 76 | for i in range(nr_thread): 77 | cmd = '{0}'.format(os.path.join('.', cvt_path)) 78 | for path in arg_paths: 79 | cmd += ' {0}'.format(path+'.__tmp__.{0}'.format(i)) 80 | worker = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) 81 | workers.append(worker) 82 | for worker in workers: 83 | worker.communicate() 84 | 85 | def cat(path, nr_thread): 86 | 87 | if os.path.exists(path): 88 | os.remove(path) 89 | for i in range(nr_thread): 90 | cmd = 'cat {svm}.__tmp__.{idx} >> {svm}'.format(svm=path, idx=i) 91 | p = subprocess.Popen(cmd, shell=True) 92 | p.communicate() 93 | 94 | def delete(path, nr_thread): 95 | 96 | for i in range(nr_thread): 97 | os.remove('{0}.__tmp__.{1}'.format(path, i)) 98 | 99 | -------------------------------------------------------------------------------- /converters/parallelizer-b.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse, sys 4 | 5 | from common import * 6 | 7 | def parse_args(): 8 | 9 | if len(sys.argv) == 1: 10 | sys.argv.append('-h') 11 | 12 | parser = 
argparse.ArgumentParser() 13 | parser.add_argument('-s', dest='nr_thread', default=12, type=int) 14 | parser.add_argument('cvt_path') 15 | parser.add_argument('src1_path') 16 | # parser.add_argument('src2_path') 17 | parser.add_argument('dst_path') 18 | args = vars(parser.parse_args()) 19 | 20 | return args 21 | 22 | def main(): 23 | 24 | args = parse_args() 25 | 26 | nr_thread = args['nr_thread'] 27 | 28 | split(args['src1_path'], nr_thread, True) 29 | 30 | # split(args['src2_path'], nr_thread, False) 31 | 32 | parallel_convert(args['cvt_path'], [args['src1_path'], args['dst_path']], nr_thread) 33 | 34 | cat(args['dst_path'], nr_thread) 35 | 36 | delete(args['src1_path'], nr_thread) 37 | 38 | # delete(args['src2_path'], nr_thread) 39 | 40 | delete(args['dst_path'], nr_thread) 41 | 42 | main() 43 | -------------------------------------------------------------------------------- /converters/pre-b.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse, csv, sys 4 | 5 | from common import * 6 | 7 | if len(sys.argv) == 1: 8 | sys.argv.append('-h') 9 | 10 | from common import * 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('-n', '--nr_bins', type=int, default=int(1e+6)) 14 | parser.add_argument('-t', '--threshold', type=int, default=int(10)) 15 | parser.add_argument('csv_path', type=str) 16 | # parser.add_argument('gbdt_path', type=str) 17 | parser.add_argument('out_path', type=str) 18 | args = vars(parser.parse_args()) 19 | 20 | def gen_hashed_fm_feats(feats, nr_bins): 21 | feats = [(field, hashstr(feat, nr_bins)) for (field, feat) in feats] 22 | feats.sort() 23 | feats = ['{0}'.format(idx) for (field, idx) in feats] 24 | return feats 25 | 26 | frequent_feats = read_freqent_feats(args['threshold']) 27 | 28 | with open(args['out_path'], 'w') as f: 29 | for row in csv.DictReader(open(args['csv_path'])): 30 | feats = [] 31 | 32 | for feat in gen_feats(row): 33 | field = 
feat.split('-')[0] 34 | type, field = field[0], int(field[1:]) 35 | if type == 'C' and feat not in frequent_feats: 36 | feat = feat.split('-')[0]+'less' 37 | if type == 'C': 38 | field += 13 39 | feats.append((field, feat)) 40 | 41 | feats = gen_hashed_fm_feats(feats, args['nr_bins']) 42 | f.write(row['Label'] + ' ' + ' '.join(feats) + '\n') 43 | -------------------------------------------------------------------------------- /converters/subsample.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | import sys 3 | import random 4 | 5 | input_file = sys.argv[1] 6 | output_file = sys.argv[2] 7 | 8 | output_handle = open(output_file, 'w') 9 | for line in open(input_file): 10 | if random.random() < 0.3: 11 | output_handle.write(line) 12 | output_handle.close() 13 | -------------------------------------------------------------------------------- /ffm/.classpath: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /ffm/.fatjar: -------------------------------------------------------------------------------- 1 | #Fat Jar Configuration File 2 | #Thu May 28 16:53:05 CST 2015 3 | onejar.license.required=true 4 | manifest.classpath= 5 | manifest.removesigners=true 6 | onejar.checkbox=false 7 | jarname=ffm.jar 8 | manifest.mergeall=true 9 | manifest.mainclass=ffm.FFMModel 10 | manifest.file= 11 | jarname.isextern=false 12 | onejar.expand= 13 | excludes= 14 | includes= 15 | -------------------------------------------------------------------------------- /ffm/.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | ffm 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.jdt.core.javabuilder 10 | 11 | 12 | 13 | 14 | 15 | org.eclipse.jdt.core.javanature 16 | 17 | 18 | -------------------------------------------------------------------------------- 
/ffm/.settings/org.eclipse.core.resources.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | encoding/=UTF-8 3 | -------------------------------------------------------------------------------- /ffm/.settings/org.eclipse.core.runtime.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | line.separator=\n 3 | -------------------------------------------------------------------------------- /ffm/.settings/org.eclipse.jdt.core.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled 3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.7 4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve 5 | org.eclipse.jdt.core.compiler.compliance=1.7 6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate 7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate 8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate 9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error 10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error 11 | org.eclipse.jdt.core.compiler.source=1.7 12 | -------------------------------------------------------------------------------- /ffm/src/ffm/FFMModel.java: -------------------------------------------------------------------------------- 1 | package ffm; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.BufferedWriter; 5 | import java.io.File; 6 | import java.io.FileInputStream; 7 | import java.io.FileOutputStream; 8 | import java.io.IOException; 9 | import java.io.InputStreamReader; 10 | import java.io.OutputStreamWriter; 11 | import java.util.Random; 12 | 13 | public class FFMModel { 14 | // max(feature_num) + 1 15 | public int n; 16 | // max(field_num) + 1 17 | public int m; 18 | // latent factor dim 19 | public int k; 20 | // 
length = n * m * k * 2 21 | public float[] W; 22 | public boolean normalization; 23 | 24 | public FFMModel initModel(int n_, int m_, FFMParameter param) { 25 | n = n_; 26 | m = m_; 27 | k = param.k; 28 | normalization = param.normalization; 29 | W = new float[n * m * k * 2]; 30 | 31 | float coef = (float) (0.5 / Math.sqrt(k)); 32 | Random random = new Random(); 33 | 34 | int position = 0; 35 | for (int j = 0; j < n; j++) { 36 | for(int f = 0; f < m; f++) { 37 | for(int d = 0; d < k; d++) { 38 | W[position] = coef * random.nextFloat(); 39 | position += 1; 40 | } 41 | for(int d = this.k; d < 2*this.k; d++) { 42 | W[position] = 1.f; 43 | position += 1; 44 | } 45 | } 46 | } 47 | 48 | return this; 49 | } 50 | 51 | public void saveModel(String path) throws IOException { 52 | BufferedWriter bw = new BufferedWriter(new OutputStreamWriter( 53 | new FileOutputStream(new File(path)), "UTF-8")); 54 | bw.write("n " + n + "\n"); 55 | bw.write("m " + m + "\n"); 56 | bw.write("k " + k + "\n"); 57 | bw.write("normalization " + normalization + "\n"); 58 | int align0 = k * 2; 59 | int align1 = m * k * 2; 60 | for(int j=0; j 1; i--) { 122 | int tmp = order[i-1]; 123 | int index = random.nextInt(i); 124 | order[i-1] = order[index]; 125 | order[index] = tmp; 126 | } 127 | } 128 | return order; 129 | } 130 | 131 | public static float wTx(FFMProblem prob, int i, float r, FFMModel model, 132 | float kappa, float eta, float lambda, boolean do_update) { 133 | // kappa = -y * exp(-y*t) / (1+exp(-y*t)) 134 | int start = prob.P[i]; 135 | int end = prob.P[i+1]; 136 | float t = 0.f; 137 | int align0 = model.k * 2; 138 | int align1 = model.m * model.k * 2; 139 | 140 | for(int N1 = start; N1 < end; N1++) { 141 | int j1 = prob.X[N1].j; 142 | int f1 = prob.X[N1].f; 143 | float v1 = prob.X[N1].v; 144 | if(j1 >= model.n || f1 >= model.m) continue; 145 | 146 | for(int N2 = N1+1; N2 < end; N2++) { 147 | int j2 = prob.X[N2].j; 148 | int f2 = prob.X[N2].f; 149 | float v2 = prob.X[N2].v; 150 | if(j2 >= 
model.n || f2 >= model.m) continue; 151 | 152 | int w1_index = j1 * align1 + f2 * align0; 153 | int w2_index = j2 * align1 + f1 * align0; 154 | float v = 2.f * v1 * v2 * r; 155 | 156 | if(do_update) { 157 | int wg1_index = w1_index + model.k; 158 | int wg2_index = w2_index + model.k; 159 | float kappav = kappa * v; 160 | for(int d = 0; d < model.k; d++) { 161 | float g1 = lambda * model.W[w1_index+d] + kappav * model.W[w2_index+d]; 162 | float g2 = lambda * model.W[w2_index+d] + kappav * model.W[w1_index+d]; 163 | 164 | float wg1 = model.W[wg1_index+d] + g1 * g1; 165 | float wg2 = model.W[wg2_index+d] + g2 * g2; 166 | 167 | model.W[w1_index+d] = model.W[w1_index+d] - eta / (float)(Math.sqrt(wg1)) * g1; 168 | model.W[w2_index+d] = model.W[w2_index+d] - eta / (float)(Math.sqrt(wg2)) * g2; 169 | 170 | model.W[wg1_index+d] = wg1; 171 | model.W[wg2_index+d] = wg2; 172 | } 173 | } else { 174 | for(int d = 0; d < model.k; d++) { 175 | t += model.W[w1_index + d] * model.W[w2_index + d] * v; 176 | } 177 | } 178 | } 179 | } 180 | return t; 181 | } 182 | 183 | public static FFMModel train(FFMProblem tr, FFMProblem va, FFMParameter param) { 184 | FFMModel model = new FFMModel(); 185 | model.initModel(tr.n, tr.m, param); 186 | 187 | float[] R_tr = normalize(tr, param.normalization); 188 | float[] R_va = null; 189 | if(va != null) { 190 | R_va = normalize(va, param.normalization); 191 | } 192 | 193 | for(int iter = 0; iter < param.n_iters; iter++) { 194 | double tr_loss = 0.; 195 | int[] order = randomization(tr.l, param.random); 196 | for(int ii=0; ii " 254 | + " "); 255 | System.out.println("for example:\n" 256 | + "java -jar ffm.jar 0.1 0.01 15 4 true false tr_ va_"); 257 | } 258 | 259 | FFMProblem tr = FFMProblem.readFFMProblem(args[6]); 260 | FFMProblem va = FFMProblem.readFFMProblem(args[7]); 261 | 262 | FFMParameter param = FFMParameter.defaultParameter(); 263 | param.eta = Float.parseFloat(args[0]); 264 | param.lambda = Float.parseFloat(args[1]); 265 | param.n_iters = 
Integer.parseInt(args[2]); 266 | param.k = Integer.parseInt(args[3]); 267 | param.normalization = Boolean.parseBoolean(args[4]); 268 | param.random = Boolean.parseBoolean(args[5]); 269 | 270 | FFMModel.train(tr, va, param); 271 | } 272 | } 273 | -------------------------------------------------------------------------------- /ffm/src/ffm/FFMNode.java: -------------------------------------------------------------------------------- 1 | package ffm; 2 | 3 | /** 4 | * @author chenhuang 5 | * 6 | */ 7 | public class FFMNode { 8 | // field_num 9 | public int f; 10 | // feature_num 11 | public int j; 12 | // value 13 | public float v; 14 | @Override 15 | public String toString() { 16 | return "FFMNode [f=" + f + ", j=" + j + ", v=" + v + "]"; 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /ffm/src/ffm/FFMParameter.java: -------------------------------------------------------------------------------- 1 | package ffm; 2 | 3 | /** 4 | * @author chenhuang 5 | * 6 | */ 7 | public class FFMParameter { 8 | // eta used for per-coordinate learning rate 9 | public float eta; 10 | // used for l2-regularization 11 | public float lambda; 12 | // max iterations 13 | public int n_iters; 14 | // latent factor dim 15 | public int k; 16 | // instance-wise normalization 17 | public boolean normalization; 18 | // randomization training order of samples 19 | public boolean random; 20 | 21 | public static FFMParameter defaultParameter() { 22 | FFMParameter parameter = new FFMParameter(); 23 | parameter.eta = 0.1f; 24 | parameter.lambda = 0; 25 | parameter.n_iters = 15; 26 | parameter.k = 4; 27 | parameter.normalization = true; 28 | parameter.random = true; 29 | return parameter; 30 | } 31 | 32 | @Override 33 | public String toString() { 34 | return "FFMParameter [eta=" + eta + ", lambda=" + lambda + ", n_iters=" 35 | + n_iters + ", k=" + k + ", normalization=" + normalization 36 | + ", random=" + random + "]"; 37 | } 38 | } 39 | 
-------------------------------------------------------------------------------- /ffm/src/ffm/FFMProblem.java: -------------------------------------------------------------------------------- 1 | package ffm; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.FileInputStream; 6 | import java.io.IOException; 7 | import java.io.InputStreamReader; 8 | import java.util.Arrays; 9 | 10 | /** 11 | * @author chenhuang 12 | * 13 | */ 14 | public class FFMProblem { 15 | //data : field_num:feature_num:value 16 | // max(feature_num) + 1 17 | public int n; 18 | // max(field_num) + 1 19 | public int m; 20 | public int l; 21 | // X[ [P[0], P[1]) ], length=nnz 22 | public FFMNode[] X; 23 | // length=l+1 24 | public int[] P; 25 | // Y[0], length=l 26 | public float[] Y; 27 | 28 | public static FFMProblem readFFMProblem(String path) throws IOException { 29 | FFMProblem problem = new FFMProblem(); 30 | int l = 0, nnz = 0; 31 | BufferedReader br = new BufferedReader(new InputStreamReader( 32 | new FileInputStream(new File(path)), "UTF-8")); 33 | String line = null; 34 | while((line = br.readLine()) != null) { 35 | l += 1; 36 | String[] fields = line.split(" |\t"); 37 | for(int i=1; i 0)?1.f:-1.f; 57 | for(int j=1; j= testDataSize) { 23 | position = 0; 24 | enoughData = true; 25 | } 26 | } 27 | 28 | public double getAverageLogLoss() { 29 | if(enoughData) { 30 | return totalLoss / testDataSize; 31 | } else { 32 | return totalLoss / position; 33 | } 34 | } 35 | 36 | /** prob: p(y=1|x;w), y: 1 or 0(-1) */ 37 | public static double calLogLoss(double prob, double y) { 38 | double p = Math.max(Math.min(prob, 1-1e-15), 1e-15); 39 | return y == 1.? -Math.log(p) : -Math.log(1. 
- p); 40 | } 41 | 42 | public static void main(String[] args) { 43 | LogLossEvalutor evalutor = new LogLossEvalutor(4); 44 | double[] losses = {3, 2, 1, 0.7, 0.5, 0.2}; 45 | for(int i=0; i fc.trva.t10.txt 4 | 5 | thread_num=10 6 | ./converters/parallelizer-b.py -s ${thread_num} ./converters/pre-b.py train.csv train.sp 7 | 8 | split -l 6548660 train.sp -d -a 4 tr_ 9 | mv tr_0006 va_ 10 | cat tr_000* > tr_ 11 | rm -rf tr_000* 12 | 13 | ./converters/subsample.py tr_ tr_sample 14 | ./converters/subsample.py va_ va_sample 15 | 16 | ./converters/change_format.py tr_sample tr_std 17 | ./converters/change_format.py va_sample va_std 18 | 19 | echo `date` 20 | 21 | java -Xmx65g -jar ffm.jar 0.1 0.00002 10 4 true false tr_std va_std > java_result 22 | 23 | echo `date` 24 | 25 | -------------------------------------------------------------------------------- /utils/count.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse, csv, sys, collections 4 | 5 | if len(sys.argv) == 1: 6 | sys.argv.append('-h') 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('csv_path', type=str) 10 | args = vars(parser.parse_args()) 11 | 12 | counts = collections.defaultdict(lambda : [0, 0, 0]) 13 | 14 | for i, row in enumerate(csv.DictReader(open(args['csv_path'])), start=1): 15 | label = row['Label'] 16 | for j in range(1, 27): 17 | field = 'C{0}'.format(j) 18 | value = row[field] 19 | if label == '0': 20 | counts[field+','+value][0] += 1 21 | else: 22 | counts[field+','+value][1] += 1 23 | counts[field+','+value][2] += 1 24 | if i % 1000000 == 0: 25 | sys.stderr.write('{0}m\n'.format(int(i/1000000))) 26 | 27 | print('Field,Value,Neg,Pos,Total,Ratio') 28 | for key, (neg, pos, total) in sorted(counts.items(), key=lambda x: x[1][2]): 29 | if total < 10: 30 | continue 31 | ratio = round(float(pos)/total, 5) 32 | print(key+','+str(neg)+','+str(pos)+','+str(total)+','+str(ratio)) 33 | 
--------------------------------------------------------------------------------