├── data └── .gitignore ├── work └── .gitignore ├── display-ad-java ├── src │ ├── main │ │ └── java │ │ │ └── com │ │ │ └── sigaphi │ │ │ └── kaggle │ │ │ └── displayad │ │ │ ├── Options.java │ │ │ ├── ToRedis.java │ │ │ ├── RawFeature.java │ │ │ └── FeaturesToVw.java │ └── test │ │ └── java │ │ └── com │ │ └── sigaphi │ │ └── kaggle │ │ └── displayad │ │ └── RawFeatureTest.java ├── .gitignore └── pom.xml ├── scripts ├── shuffle.py ├── submit.py └── vw_run.py ├── .gitignore ├── LICENSE ├── run.sh └── README.md /data/.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | -------------------------------------------------------------------------------- /work/.gitignore: -------------------------------------------------------------------------------- 1 | *.model 2 | *.txt 3 | *.gz 4 | -------------------------------------------------------------------------------- /display-ad-java/src/main/java/com/sigaphi/kaggle/displayad/Options.java: -------------------------------------------------------------------------------- 1 | package com.sigaphi.kaggle.displayad; 2 | 3 | /** 4 | * Feature options 5 | * @author Guocong Song 6 | */ 7 | public enum Options { 8 | CAT_BASIC, 9 | CAT_POP_1, 10 | CAT_EXCD, 11 | CAT_TAIL; 12 | } 13 | -------------------------------------------------------------------------------- /display-ad-java/.gitignore: -------------------------------------------------------------------------------- 1 | /target/ 2 | *.class 3 | 4 | # Mobile Tools for Java (J2ME) 5 | .mtj.tmp/ 6 | 7 | # Package Files # 8 | *.jar 9 | *.war 10 | *.ear 11 | 12 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 13 | hs_err_pid* 14 | .project 15 | .settings 16 | .classpath 17 | -------------------------------------------------------------------------------- /scripts/shuffle.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author: Guocong Song 3 | ''' 4 | import heapq 5 | import sys 6 | 7 | heapCap = int(sys.argv[1]) 8 | heap = [] 9 | for line in sys.stdin: 10 | key = hash(line) 11 | if len(heap) < heapCap: 12 | heapq.heappush(heap, (key, line)) 13 | else: 14 | _, out = heapq.heappushpop(heap, (key, line)) 15 | sys.stdout.write(out) 16 | 17 | while len(heap) > 0: 18 | _, out = heapq.heappop(heap) 19 | sys.stdout.write(out) 20 | -------------------------------------------------------------------------------- /scripts/submit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author: Guocong Song 3 | ''' 4 | import pandas as pd 5 | import sys 6 | import numpy as np 7 | import gzip 8 | 9 | 10 | df = pd.read_csv(sys.stdin) 11 | p = 0.55 * df.p1 + 0.15 * df.p2 + 0.15 * df.p3 + 0.15 * df.p4 12 | df['Predicted'] = prob = 1.0 / (1.0 + np.exp(-p)) 13 | 14 | submission = 'submission.cvs.gz' 15 | print('saving to', submission, '...') 16 | with gzip.open(submission, 'wt') as f: 17 | df[['Id', 'Predicted']].to_csv(f, index=False) 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | eggs/ 15 | lib/ 16 | lib64/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | 24 | # PyInstaller 25 | # Usually these files are written by a python script from a template 26 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 27 | *.manifest 28 | *.spec 29 | 30 | # Installer logs 31 | pip-log.txt 32 | pip-delete-this-directory.txt 33 | 34 | # Unit test / coverage reports 35 | htmlcov/ 36 | .tox/ 37 | .coverage 38 | .cache 39 | nosetests.xml 40 | coverage.xml 41 | 42 | # Translations 43 | *.mo 44 | *.pot 45 | 46 | # Django stuff: 47 | *.log 48 | 49 | # Sphinx documentation 50 | docs/_build/ 51 | 52 | # PyBuilder 53 | target/ 54 | 55 | .project 56 | .pydevproject 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Guocong Song 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /display-ad-java/src/main/java/com/sigaphi/kaggle/displayad/ToRedis.java: -------------------------------------------------------------------------------- 1 | package com.sigaphi.kaggle.displayad; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.InputStreamReader; 5 | import java.util.Map; 6 | import java.util.function.Consumer; 7 | 8 | import com.google.common.base.Joiner; 9 | 10 | import redis.clients.jedis.Jedis; 11 | import redis.clients.jedis.Pipeline; 12 | 13 | /** 14 | * Input feature value pairs to Redis 15 | * @author Guocong Song 16 | */ 17 | public class ToRedis { 18 | public static final Joiner JOIN_KV = Joiner.on(":"); 19 | private static final Jedis jedis = new Jedis("localhost", 6379); 20 | 21 | private static Consumer toRedis = (RawFeature raw) -> { 22 | Pipeline pipe = jedis.pipelined(); 23 | Map catMap = raw.getCatFields(null); 24 | catMap.entrySet().stream() 25 | .forEach(e -> pipe.zincrby(JOIN_KV.join("imp", e.getKey()), 1, e.getValue())); 26 | pipe.sync(); 27 | }; 28 | 29 | public static void main(String[] args) { 30 | BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); 31 | reader.lines() 32 | .skip(1) 33 | .map(line -> RawFeature.Builder.of(line)) 34 | .forEach(toRedis); 35 | jedis.close(); 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export VW_BIN=/opt/vw/vowpalwabbit 4 | 5 | TRAIN="cat ../data/train.csv" 6 | TEST="cat ../data/test.csv" 7 | JAVA_BIN="java -Xmx4g -cp ../display-ad-java/target/*:. com.sigaphi.kaggle.displayad" 8 | 9 | echo "import data into redis ..." 10 | $TRAIN | $JAVA_BIN.ToRedis 11 | $TEST | $JAVA_BIN.ToRedis 12 | 13 | echo "making vw input files ..." 14 | $TRAIN | $JAVA_BIN.FeaturesToVw | gzip > train.vw.gz 15 | $TEST | $JAVA_BIN.FeaturesToVw | gzip > test.vw.gz 16 | 17 | echo "training model 1 ..." 18 | python ../scripts/vw_run.py quad_11 3 6000000 19 | python ../scripts/vw_run.py quad_13 3 100000 20 | python ../scripts/vw_run.py quad_12 1 10000 21 | mv prediction_test.txt prediction_test_1.txt 22 | 23 | echo "training model 2 ..." 24 | python ../scripts/vw_run.py poly_1 6 1 25 | mv prediction_test.txt prediction_test_2.txt 26 | 27 | echo "training model 3 ..." 28 | python ../scripts/vw_run.py poly_2 6 10000 29 | mv prediction_test.txt prediction_test_3.txt 30 | 31 | echo "training model 4 ..." 32 | python ../scripts/vw_run.py poly_3 6 100000 33 | mv prediction_test.txt prediction_test_4.txt 34 | 35 | echo "making a submission file" 36 | cat <(echo "Id,p1,p2,p3,p4") <(paste -d"," <(zcat test.vw.gz | cut -f1 | cut -d"," -f1) prediction_test_1.txt prediction_test_2.txt prediction_test_3.txt prediction_test_4.txt) | python ../scripts/submit.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Display Advertising Challenge 2 | ============================= 3 | 4 | Description 5 | ----------- 6 | This is the code was written for the [Kaggle Criteo Competition of CTR prediction](https://www.kaggle.com/c/criteo-display-ad-challenge). 7 | 8 | Since the data are highly sparse, the basic methodology is to use logistic regression with appropriate quadratic/polynomial feature generation and regularization to make sophisticated and over-fitting-tractable models. [Vowpal Wabbit](https://github.com/JohnLangford/vowpal_wabbit) is the major machine learning software used for this project. Since the data size is challenging in terms of my personal workstation (a single quad-core CPU), the techniques of feature selection and model training are selected based on the trade off between performance and CPU/RAM resource limit. 9 | 10 | Dependencies and requirements 11 | ----------------------------- 12 | Please note that the code was written for my personal learning and practice in new features of Java 8 and Python 3.4 in Ubuntu 14.04. The code cannot be run in early versions of these two languages or other OSs. Compatibility is not considered here. 13 | 14 | * Java 8 15 | * Python 3.4 16 | * Maven 3 17 | * Redis 2.8 18 | * Pandas 0.14 19 | * Vowpal Wabbit 7.7 20 | * Java-based open source projects: (Maven will install them automatically) 21 | - guava 17.0 22 | - jedis 2.5.1 23 | - commons-lang3 3.3.2 24 | 25 | 26 | How to run 27 | ---------- 28 | - Copy train and test data file (train.csv, test.csv) to data folder 29 | - Compile the Java code by 30 | ``` 31 | $ cd display-ad-java 32 | $ mvn package # or mvn install 33 | ``` 34 | - Make sure a redis instance running at localhost:6379 35 | - Set the path of binary vw (VW_BIN) in run.sh, such as 36 | ``` 37 | export VW_BIN=/path/to/vw/binary 38 | ``` 39 | - Finally, 40 | ``` 41 | $ cd work 42 | $ ../run.sh 43 | ``` 44 | 45 | 46 | -------------------------------------------------------------------------------- /display-ad-java/src/test/java/com/sigaphi/kaggle/displayad/RawFeatureTest.java: -------------------------------------------------------------------------------- 1 | package com.sigaphi.kaggle.displayad; 2 | 3 | import static org.junit.Assert.*; 4 | 5 | import java.util.Arrays; 6 | import java.util.List; 7 | import java.util.Map; 8 | 9 | 10 | import org.junit.Test; 11 | 12 | import com.google.common.base.Splitter; 13 | 14 | public class RawFeatureTest { 15 | 16 | @Test 17 | public void test() { 18 | String str = ",,2,"; 19 | for (String s : Splitter.on(",").split(str)) { 20 | System.out.print("x:"); 21 | System.out.println(s); 22 | } 23 | for (String s : str.split(",")) { 24 | System.out.print("y:"); 25 | System.out.println(s); 26 | } 27 | List a = Arrays.asList(0,1,2,3,4); 28 | System.out.println(a.subList(1, 4)); 29 | 30 | String line = "55840616,0,0,4,13,13,2815,20,4,21,41,0,4,,14,05db9164,a0e12995,622d2ce8,51c64c6d,25c83c98,,3086a9e9,0b153874,a73ee510,3b08e48b,ebd30041,e9521d94,c7a109eb,1adce6ef,78c64a1d,ab8b968d,d4bb7bd8,1616f155,21ddcdc9,5840adea,ee4fa92e,,32c7478e,d61a7d0a,9b3e8820,b29c74dc"; 31 | String actual = RawFeature.Builder.of(line).toString(); 32 | System.out.println(RawFeature.Builder.of(line).getHeader()); 33 | System.out.println(line); 34 | System.out.println(actual); 35 | System.out.println(RawFeature.Builder.of(line).getCatFields(null)); 36 | System.out.println(RawFeature.Builder.of(line).getNumFields(null)); 37 | assertEquals(line, actual); 38 | line = "66042132,,1,6,5,283,26,81,5,42,,6,,5,05db9164,,43a795a8,be13fbd1,4cf72387,6f6d9be8,f00bddf8,6c97ac79,a73ee510,ca1bb880,55795b33,277cb5a2,39795005,b28479f6,93625cba,b06f79e3,e5ba7672,3987fb8a,21ddcdc9,5840adea,45fdf300,,32c7478e,a6e7d8d3,001f3601,2fede552"; 39 | actual = RawFeature.Builder.of(line).toString(); 40 | System.out.println(RawFeature.Builder.of(line).getHeader()); 41 | System.out.println(line); 42 | System.out.println(actual); 43 | assertEquals(line, actual); 44 | Map numMap = RawFeature.Builder.of(line).getNumFields(null); 45 | System.out.println(numMap); 46 | // System.out.println(FeaturesToVw.sos2("numN", numMap.size())); 47 | Map catMap = RawFeature.Builder.of(line).getCatFields(null); 48 | System.out.println(catMap); 49 | catMap = RawFeature.Builder.of(line).getCatFields(RawFeature.catColsExSet); 50 | System.out.println(catMap); 51 | // System.out.println(FeaturesToVw.numMapToLogString(numMap)); 52 | 53 | 54 | // Map catMap = RawFeature.Builder.of(line).getCatFields(FeaturesToVw.placeFeatures); 55 | // System.out.println(FeaturesToVw.catMapToString(catMap, Options.CAT_POP_1)); 56 | // catMap = RawFeature.Builder.of(line).getCatFields(FeaturesToVw.otherFeatures); 57 | // System.out.println(FeaturesToVw.catMapToString(catMap, Options.CAT_POP_1)); 58 | // System.out.println(FeaturesToVw.catMapToString(catMap, Options.CAT_BASIC)); 59 | } 60 | 61 | } 62 | -------------------------------------------------------------------------------- /display-ad-java/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | com.sigaphi.kaggle.displayad 5 | display-ad-java 6 | jar 7 | 1.0-SNAPSHOT 8 | display-ad-java 9 | http://maven.apache.org 10 | 11 | 12 | junit 13 | junit 14 | 4.11 15 | test 16 | 17 | 18 | com.google.guava 19 | guava 20 | 17.0 21 | 22 | 23 | redis.clients 24 | jedis 25 | 2.5.1 26 | 27 | 28 | org.apache.commons 29 | commons-lang3 30 | 3.3.2 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | org.eclipse.m2e 39 | lifecycle-mapping 40 | 1.0.0 41 | 42 | 43 | 44 | 45 | 46 | org.apache.maven.plugins 47 | maven-dependency-plugin 48 | [2.0,) 49 | 50 | copy-dependencies 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | maven-compiler-plugin 66 | 3.0 67 | 68 | 1.8 69 | 1.8 70 | 71 | 72 | 73 | org.apache.maven.plugins 74 | maven-dependency-plugin 75 | 2.8 76 | 77 | 78 | copy-dependencies 79 | package 80 | 81 | copy-dependencies 82 | 83 | 84 | ${project.build.directory} 85 | false 86 | false 87 | true 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /scripts/vw_run.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author: Guocong Song 3 | ''' 4 | import subprocess 5 | import sys 6 | from enum import Enum 7 | import os 8 | 9 | VW = os.path.join(os.environ['VW_BIN'], 'vw') 10 | 11 | def quadratic(a, excd=[]): 12 | b = [] 13 | def combine(a, i): 14 | if i == len(a): 15 | return 16 | for x in a[i:]: 17 | pair = ''.join([a[i], x]) 18 | if pair in excd: 19 | continue 20 | b.append(' '.join(['-q', pair])) 21 | combine(a, i + 1) 22 | combine(a, 0) 23 | return b 24 | 25 | 26 | class Option(Enum): 27 | poly_1 = 2 28 | poly_2 = 3 29 | quad_11 = 11 30 | quad_12 = 12 31 | quad_13 = 13 32 | poly_3 = 100 33 | 34 | @classmethod 35 | def fromstring(cls, str): 36 | return getattr(cls, str, None) 37 | 38 | @classmethod 39 | def features(cls, val): 40 | if val == Option.quad_11: 41 | return ' '.join(quadratic(list('pnabcdg'), excd=['aa', 'bb', 'cc', 'gg'])) \ 42 | + ' --save_resume --l2 10e-8 --l1 1.3e-8 -l 0.007 -q m:' 43 | elif val == Option.quad_12: 44 | return ' '.join(quadratic(list('pnabcdg'), excd=['aa', 'bb', 'cc', 'gg'])) \ 45 | + ' --save_resume --feature_mask log_bin.model --l2 10e-8 --l1 0e-8 -l 0.005 -q m:' 46 | elif val == Option.quad_13: 47 | return ' '.join(quadratic(list('pnabcdg'), excd=['aa', 'bb', 'cc', 'gg'])) \ 48 | + ' --save_resume --feature_mask log_bin.model --l2 10e-8 --l1 1.3e-8 -l 0.007 -q m:' 49 | elif val == Option.poly_1: 50 | return '--stage_poly --batch_sz 800000 --batch_sz_no_doubling --sched_exponent 1.96' \ 51 | + ' --save_resume --l2 6e-8 --l1 1.2e-8 -l 0.007' 52 | elif val == Option.poly_2: 53 | return '--stage_poly --batch_sz 4000000 --batch_sz_no_doubling --sched_exponent 2.3' \ 54 | + ' --save_resume --l2 6e-8 --l1 1.2e-8 -l 0.007' 55 | elif val == Option.poly_3: 56 | return '--stage_poly --batch_sz 7500000 --batch_sz_no_doubling --sched_exponent 2.5' \ 57 | + ' --save_resume --l2 6e-8 --l1 1.2e-8 -l 0.007' 58 | else: 59 | print('wrong mode:', val) 60 | sys.exit(1) 61 | 62 | 63 | def train(fname, option, passes, shfl_win): 64 | feeder = 'zcat %s | cut -f2 | python ../scripts/shuffle.py %s | ' % (fname, shfl_win) 65 | loss = '--loss_function logistic' 66 | bits = '-b 27' 67 | pass_args = '--passes %s --holdout_off -C -0.4878' % passes 68 | features = Option.features(option) 69 | update = '--adaptive --invariant --power_t 0.5' 70 | io = '-c --compressed -f log_bin.model -k' 71 | command = ' '.join([feeder, VW, loss, bits, pass_args, features, update,io]) 72 | subprocess.call(command, stdout=sys.stdout, shell=True, executable='/bin/bash') 73 | 74 | 75 | def test(fname): 76 | predictions = 'prediction_' + fname.split('.')[0] + '.txt' 77 | predictor = 'zcat %s | cut -f2 | %s -t -i log_bin.model -p %s --quiet' % (fname, VW, predictions) 78 | subprocess.call(predictor, stdout=sys.stdout, shell=True, executable='/bin/bash') 79 | 80 | 81 | if __name__ == '__main__': 82 | option = Option.fromstring(sys.argv[1]) 83 | passes = sys.argv[2] 84 | shfl_win = sys.argv[3] 85 | 86 | train('train.vw.gz', option, passes, shfl_win) 87 | test('test.vw.gz') 88 | -------------------------------------------------------------------------------- /display-ad-java/src/main/java/com/sigaphi/kaggle/displayad/RawFeature.java: -------------------------------------------------------------------------------- 1 | package com.sigaphi.kaggle.displayad; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.HashMap; 6 | import java.util.HashSet; 7 | import java.util.LinkedHashMap; 8 | import java.util.List; 9 | import java.util.Map; 10 | import java.util.Set; 11 | import java.util.function.Function; 12 | import java.util.stream.Collectors; 13 | import java.util.stream.IntStream; 14 | 15 | import org.apache.commons.lang3.StringUtils; 16 | 17 | import com.google.common.base.Joiner; 18 | import com.google.common.base.Splitter; 19 | 20 | /** 21 | * Read csv file to RawFeature objects 22 | * @author Guocong Song 23 | */ 24 | public class RawFeature { 25 | private int id; 26 | private int label; 27 | 28 | public static final int INT_NULL = -100; 29 | public static final String NA = "NA"; 30 | 31 | public static final List numCols = IntStream.range(1, 14) 32 | .mapToObj(i -> "I" + Integer.toString(i)).collect(Collectors.toList()); 33 | public static final List catCols = IntStream.range(1, 27) 34 | .mapToObj(i -> "C" + Integer.toString(i)).collect(Collectors.toList()); 35 | public static final List catColsEx = new ArrayList<>(); 36 | static { 37 | catColsEx.addAll(catCols); 38 | catColsEx.add("reqstA"); 39 | catColsEx.add("reqstB"); 40 | } 41 | public static final Set catColsSet = new HashSet(RawFeature.catCols); 42 | public static final Set catColsExSet = new HashSet(RawFeature.catColsEx); 43 | 44 | private static final List reqstA = Arrays.asList("C3", "C4", "C12", "C16", "C21", "C24"); 45 | private static final List reqstB = Arrays.asList("C19", "C20", "C25", "C26"); 46 | 47 | private final Map map = new HashMap(); 48 | 49 | public RawFeature(int id, int label, List nums, List cats) { 50 | this.id = id; 51 | this.label = label; 52 | for (int i = 0; i < numCols.size(); i++) { 53 | map.put(numCols.get(i), nums.get(i)); 54 | } 55 | for (int i = 0; i < catCols.size(); i++) { 56 | map.put(catCols.get(i), cats.get(i)); 57 | } 58 | String a = reqstA.stream().map(col -> this.getField(col, "")).collect(Collectors.joining("-")); 59 | String b = reqstB.stream().map(col -> this.getField(col, "")).collect(Collectors.joining("-")); 60 | map.put("reqstA", a); 61 | map.put("reqstB", b); 62 | } 63 | 64 | public int getId() { 65 | return id; 66 | } 67 | 68 | public int getLabel() { 69 | return label; 70 | } 71 | 72 | public String getField(String name, String empty) { 73 | String val = map.get(name); 74 | if (StringUtils.isEmpty(val)) { 75 | return empty; 76 | } 77 | return val; 78 | } 79 | 80 | public Double getNumField(String name) { 81 | String val = map.get(name); 82 | if (StringUtils.isEmpty(val)) { 83 | return INT_NULL * 1.0; 84 | } 85 | return Double.parseDouble(val); 86 | } 87 | 88 | public Map getCatFields(Set colSet) { 89 | Set set; 90 | if (colSet == null) { 91 | set = catColsSet; 92 | } else { 93 | set = colSet; 94 | } 95 | return catColsEx.stream() 96 | .filter(col -> set.contains(col)) 97 | .collect(Collectors.toMap(Function.identity(), 98 | col -> this.getField(col, NA), 99 | (x, y) -> y, 100 | LinkedHashMap::new)); 101 | } 102 | 103 | public Map getNumFields(Set colSet) { 104 | Set set; 105 | if (colSet == null) { 106 | set = new HashSet(numCols); 107 | } else { 108 | set = colSet; 109 | } 110 | return numCols.stream() 111 | .filter(col -> set.contains(col)) 112 | .collect(Collectors.toMap(Function.identity(), 113 | col -> this.getNumField(col), 114 | (x, y) -> y, 115 | LinkedHashMap::new)); 116 | } 117 | 118 | public String getSync() { 119 | return Joiner.on(",").join(id, label); 120 | } 121 | 122 | public String getHeader() { 123 | List vals = new ArrayList(); 124 | vals.add("Id"); 125 | if (label != INT_NULL) { 126 | vals.add("Label"); 127 | } 128 | numCols.stream().forEach(col -> vals.add(col)); 129 | catCols.stream().forEach(col -> vals.add(col)); 130 | return Joiner.on(",").join(vals); 131 | } 132 | 133 | public boolean hasLabel() { 134 | return label != INT_NULL; 135 | } 136 | 137 | @Override 138 | public String toString() { 139 | List vals = new ArrayList(); 140 | vals.add(Integer.toString(id)); 141 | if (label != INT_NULL) { 142 | vals.add(Integer.toString(label)); 143 | } 144 | numCols.stream().forEach(col -> vals.add(this.getField(col, ""))); 145 | catCols.stream().forEach(col -> vals.add(this.getField(col, ""))); 146 | return Joiner.on(",").join(vals); 147 | } 148 | 149 | public static class Builder { 150 | public static RawFeature of(String line) { 151 | List els = Splitter.on(",").splitToList(line); 152 | List nums, cats; 153 | int label; 154 | int id = Integer.parseInt(els.get(0)); 155 | switch (els.size()) { 156 | case 41: 157 | label = Integer.parseInt(els.get(1)); 158 | nums = els.subList(2, 15); 159 | cats = els.subList(15, 41); 160 | break; 161 | case 40: 162 | label = RawFeature.INT_NULL; 163 | nums = els.subList(1, 14); 164 | cats = els.subList(14, 40); 165 | break; 166 | default: 167 | throw new ArrayStoreException(); 168 | } 169 | return new RawFeature(id, label, nums, cats); 170 | } 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /display-ad-java/src/main/java/com/sigaphi/kaggle/displayad/FeaturesToVw.java: -------------------------------------------------------------------------------- 1 | package com.sigaphi.kaggle.displayad; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.InputStreamReader; 5 | import java.text.DecimalFormat; 6 | import java.util.Arrays; 7 | import java.util.HashMap; 8 | import java.util.List; 9 | import java.util.Map; 10 | import java.util.Set; 11 | import java.util.HashSet; 12 | import java.util.function.Function; 13 | import java.util.stream.Collectors; 14 | 15 | import redis.clients.jedis.Jedis; 16 | 17 | import com.google.common.base.Joiner; 18 | 19 | /** 20 | * Convert RawFeature to vw format 21 | * @author Guocong Song 22 | */ 23 | public class FeaturesToVw { 24 | public static final Joiner JOIN_KV = Joiner.on(":"); 25 | public static final Joiner JOIN_SPACE = Joiner.on(" "); 26 | public static final Joiner JOIN_CAT = Joiner.on("="); 27 | public static final Joiner JOIN = Joiner.on(""); 28 | public static final DecimalFormat DF = new DecimalFormat("#.####"); 29 | 30 | private static final Jedis jedis = new Jedis("localhost", 6379, 100000); 31 | static final Map> incld = new HashMap>(); 32 | static final Map> pop1 = new HashMap>(); 33 | static final Set featureGroupA; 34 | static final Set featureGroupB; 35 | static final Set featureGroupC; 36 | static final Set featureGroupD; 37 | static final Set featureGroupG; 38 | static final Set featureGroupP; 39 | static final Set featureGroupN; 40 | static final Map numCaps = new HashMap<>(); 41 | static final Set duplicates = new HashSet<>(Arrays.asList("C4=NA", "C12=NA", "C16=NA", "C21=NA", "C24=NA", 42 | "C20=NA", "C25=NA", "C26=NA")); 43 | static { 44 | RawFeature.catCols.stream() 45 | .forEach(col -> incld.put(col, jedis.zrangeByScore(JOIN_KV.join("imp", col), 7d, 1e9))); 46 | RawFeature.catCols.stream() 47 | .forEach(col -> pop1.put(col, jedis.zrangeByScore(JOIN_KV.join("imp", col), 60d, 1e9))); 48 | 49 | featureGroupA = new HashSet(Arrays.asList("C3", "C4", "C12", 50 | "C16", "C21", "C24")); 51 | featureGroupB = new HashSet(Arrays.asList("C2", "C15", "C18")); 52 | featureGroupC = new HashSet(Arrays.asList("C7", "C13", "C11")); 53 | featureGroupD = new HashSet(Arrays.asList("C6", "C14", "C17", "C20", "C22", "C25", "C9", "C23")); 54 | Set catFeatures = new HashSet(RawFeature.catCols); 55 | featureGroupG = catFeatures.stream() 56 | .filter(e -> !featureGroupA.contains(e)) 57 | .filter(e -> !featureGroupB.contains(e)) 58 | .filter(e -> !featureGroupC.contains(e)) 59 | .filter(e -> !featureGroupD.contains(e)) 60 | .collect(Collectors.toSet()); 61 | 62 | featureGroupP = new HashSet(Arrays.asList("I4", "I8", "I13")); 63 | Set numFeatures = new HashSet(RawFeature.numCols); 64 | featureGroupN = numFeatures.stream() 65 | .filter(e -> !featureGroupP.contains(e)) 66 | .collect(Collectors.toSet()); 67 | 68 | numCaps.put("I1", 1090d); 69 | numCaps.put("I2", 22000d); 70 | numCaps.put("I5", 3260000d); 71 | numCaps.put("I6", 162000d); 72 | numCaps.put("I7", 22000d); 73 | numCaps.put("I9", 22000d); 74 | numCaps.put("I12", 1090d); 75 | } 76 | 77 | 78 | public static String sos2(String key, double x) { 79 | String newKey = JOIN.join(key, "_"); 80 | if (x < 0) { 81 | return JOIN.join(newKey, "NA"); 82 | } 83 | if (x == 0.0) { 84 | return JOIN.join(newKey, "0"); 85 | } 86 | double xx = x; 87 | if (numCaps.containsKey(key)) { 88 | double cap = numCaps.get(key); 89 | xx = xx > cap ? cap : xx; 90 | } 91 | double y = Math.log1p(xx) * 1.4; 92 | double low = Math.floor(y); 93 | double high = Math.ceil(y); 94 | String weightLow = DF.format(high - y); 95 | String weightHigh = DF.format(y - low); 96 | return JOIN_SPACE.join(JOIN.join(newKey, (int) low, ":", weightLow), 97 | JOIN.join(newKey, (int) high, ":", weightHigh)); 98 | } 99 | 100 | public static String numMapToLogString(Map map) { 101 | return map.entrySet().stream().map(e -> { 102 | String key = e.getKey(); 103 | double val = e.getValue(); 104 | val = key.equals("I2") ? val + 3.0 : val; 105 | return sos2(key, val); 106 | }).collect(Collectors.joining(" ")); 107 | } 108 | 109 | public static String catMapToString(Map map, Options o) { 110 | List list = null; 111 | if (o == Options.CAT_BASIC) { 112 | list = map.entrySet().stream() 113 | .filter(e -> incld.get(e.getKey()).contains(e.getValue())) 114 | .filter(e -> !pop1.get(e.getKey()).contains(e.getValue())) 115 | .map(e -> JOIN_CAT.join(e.getKey(), e.getValue())) 116 | .collect(Collectors.toList()); 117 | } else if (o == Options.CAT_POP_1) { 118 | list = map.entrySet().stream() 119 | .filter(e -> pop1.get(e.getKey()).contains(e.getValue())) 120 | .map(e -> JOIN_CAT.join(e.getKey(), e.getValue())) 121 | .filter(e -> !duplicates.contains(e)) 122 | .collect(Collectors.toList()); 123 | } else { 124 | throw new NullPointerException(); 125 | } 126 | if (list == null) { 127 | return ""; 128 | } 129 | return JOIN_SPACE.join(list); 130 | } 131 | 132 | public static Function basicTran = (RawFeature raw) -> { 133 | int y = raw.getLabel() == 0 ? -1 : raw.getLabel(); 134 | String numN = numMapToLogString(raw.getNumFields(featureGroupN)); 135 | String numP = numMapToLogString(raw.getNumFields(featureGroupP)); 136 | 137 | String catPopA = catMapToString(raw.getCatFields(featureGroupA), Options.CAT_POP_1); 138 | String catPopB = catMapToString(raw.getCatFields(featureGroupB), Options.CAT_POP_1); 139 | String catPopC = catMapToString(raw.getCatFields(featureGroupC), Options.CAT_POP_1); 140 | String catPopD = catMapToString(raw.getCatFields(featureGroupD), Options.CAT_POP_1); 141 | String catPopG = catMapToString(raw.getCatFields(featureGroupG), Options.CAT_POP_1); 142 | Map catMap = raw.getCatFields(null); 143 | String catBasic = catMapToString(catMap, Options.CAT_BASIC); 144 | long cnt = catMap.entrySet().stream() 145 | .filter(e -> !incld.get(e.getKey()).contains(e.getValue()) 146 | || duplicates.contains(JOIN_CAT.join(e.getKey(), e.getValue()))) 147 | .count(); 148 | double miss = Math.log1p(cnt); 149 | String missStr = miss < 1e-6 ? "" : JOIN_CAT.join("miss", DF.format(miss)); 150 | 151 | String vw = JOIN_SPACE.join(y, "|p", numP, "|n", numN, 152 | "|a", catPopA, "|b", catPopB, "|c", catPopC, 153 | "|d", catPopD, "|g", catPopG, "|z", catBasic, "|m", missStr); 154 | return Joiner.on("\t").join(raw.getSync(), vw); 155 | }; 156 | 157 | public static void main(String[] args) { 158 | BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); 159 | reader.lines() 160 | .skip(1) 161 | .map(line -> RawFeature.Builder.of(line)) 162 | .map(basicTran) 163 | .forEach(System.out::println); 164 | jedis.close(); 165 | } 166 | 167 | } 168 | --------------------------------------------------------------------------------