├── src ├── main │ ├── java │ │ └── com │ │ │ └── cloudera │ │ │ ├── .DS_Store │ │ │ └── knittingboar │ │ │ ├── App.java │ │ │ ├── yarn │ │ │ └── avro │ │ │ │ └── generated │ │ │ │ ├── WorkerId.java │ │ │ │ ├── KnittingBoarService.java │ │ │ │ ├── ServiceError.java │ │ │ │ ├── ProgressReport.java │ │ │ │ └── FileSplit.java │ │ │ ├── conf │ │ │ └── cmdline │ │ │ │ ├── ModelTesterCmdLineDriver.java │ │ │ │ └── DataConverterCmdLineDriver.java │ │ │ ├── metrics │ │ │ └── POLRMetrics.java │ │ │ ├── records │ │ │ ├── RecordFactory.java │ │ │ ├── RCV1RecordFactory.java │ │ │ └── TwentyNewsgroupsRecordFactory.java │ │ │ ├── sgd │ │ │ ├── iterativereduce │ │ │ │ └── POLRNodeBase.java │ │ │ └── MultinomialLogisticRegressionParameterVectors_deprecated.java │ │ │ ├── messages │ │ │ └── iterativereduce │ │ │ │ ├── ParameterVectorUpdatable.java │ │ │ │ └── ParameterVector.java │ │ │ ├── io │ │ │ └── InputRecordsSplit.java │ │ │ └── utils │ │ │ └── Utils.java │ └── avro │ │ ├── KnittingBoarService.avdl │ │ └── KnittingBoarService.avpr.old └── test │ ├── resources │ ├── KBoar_Sample.model │ ├── kboar-shard-0.txt.gz │ ├── app_unit_test.properties │ └── donut_no_header.csv │ └── java │ └── com │ └── cloudera │ └── knittingboar │ ├── AppTest.java │ ├── records │ ├── TestCSVBasedDatasetRecordFactory.java │ ├── TestRCV1RecordFactory.java │ ├── TestTwentyNewsgroupsCustomRecordParseOLRRun.java │ └── Test20NewsgroupsBookParsing.java │ ├── utils │ ├── TestUtils.java │ ├── TestingUtils.java │ ├── TestRcv1SubsetConversion.java │ ├── TestConvert20NewsTestDataset.java │ ├── TestDatasetConverter.java │ └── DataUtils.java │ ├── io │ ├── TestCsvRecordParsing.java │ ├── TestSplitReset.java │ └── TestSplitCalcs.java │ ├── conf │ └── cmdline │ │ ├── TestJobDriver.java │ │ └── TestDataConverterDriver.java │ ├── sgd │ ├── iterativereduce │ │ ├── TestKnittingBoar_IRUnitSim.java │ │ └── TestPOLRIterativeReduce.java │ ├── TestPOLRMasterNode.java │ ├── TestParallelOnlineLogisticRegression.java │ ├── TestBaseSGD.java │ └── olr │ │ └── TestBaseOLRTest20Newsgroups.java │ ├── messages │ └── TestParameterVector.java │ └── metrics │ ├── TestRCV1ApplyModel.java │ ├── Test20NewsApplyModel.java │ └── TestSaveLoadModel.java ├── .gitignore ├── scripts └── convert_20newsgroups.sh ├── app.sample.properties └── README.md /src/main/java/com/cloudera/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpatanooga/KnittingBoar/HEAD/src/main/java/com/cloudera/.DS_Store -------------------------------------------------------------------------------- /src/test/resources/KBoar_Sample.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpatanooga/KnittingBoar/HEAD/src/test/resources/KBoar_Sample.model -------------------------------------------------------------------------------- /src/test/resources/kboar-shard-0.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpatanooga/KnittingBoar/HEAD/src/test/resources/kboar-shard-0.txt.gz -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.jpage 3 | .classpath 4 | .settings 5 | .project 6 | app.properties 7 | log4j.properties 8 | 9 | # Package Files # 10 | *.jar 11 | *.war 12 | *.ear 13 | 14 | # Generated stuff 15 | target/ 16 | testData/ 17 | build/ 18 | 
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/App.java:
--------------------------------------------------------------------------------
 1 | package com.cloudera.knittingboar;
 2 | 
 3 | /**
 4 |  * Hello world!
 5 |  *
 6 |  */
 7 | public class App 
 8 | {
 9 |     public static void main( String[] args )
10 |     {
11 | 
12 |     }
13 | }
14 | 
--------------------------------------------------------------------------------
/scripts/convert_20newsgroups.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | JAVA_HOME="/usr/lib/jvm/java-1.6.0"
 4 | JAVA_OPTS=""
 5 | JAR="com.cloudera.knittingboar.conf.cmdline.DataConverterCmdLineDriver"
 6 | 
 7 | execJAR="../target/KnittingBoar-1.0-SNAPSHOT-jar-with-dependencies.jar"
 8 | 
 9 | JAVA_CMD="/usr/bin/java"
10 | 
11 | 
12 | 
13 | export HADOOP_NICENESS=0
14 | 
15 | # all command-line arguments are passed straight through to the driver
16 | 
17 | $JAVA_CMD -cp $execJAR:$(echo lib/*.jar | tr ' ' ':') ${JAR} "$@"
18 | 
19 | #echo ../target/lib/*.jar | tr ' ' ':'
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/yarn/avro/generated/WorkerId.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Autogenerated by Avro
 3 |  * 
 4 |  * DO NOT EDIT DIRECTLY
 5 |  */
 6 | package com.cloudera.knittingboar.yarn.avro.generated;
 7 | @SuppressWarnings("all")
 8 | @org.apache.avro.specific.FixedSize(32)
 9 | public class WorkerId extends org.apache.avro.specific.SpecificFixed {
10 |   public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"fixed\",\"name\":\"WorkerId\",\"namespace\":\"com.cloudera.knittingboar.yarn.avro.generated\",\"size\":32}");
11 | 
12 |   /** Creates a new WorkerId */
13 |   public WorkerId() {
14 |     super();
15 |   }
16 |   
17 |   /** Creates a new WorkerId with the given bytes */
18 |   public WorkerId(byte[] bytes) {
19 |     super(bytes);
20 |   }
21 | }
22 | 
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/AppTest.java:
--------------------------------------------------------------------------------
 1 | package com.cloudera.knittingboar;
 2 | 
 3 | import junit.framework.Test;
 4 | import junit.framework.TestCase;
 5 | import junit.framework.TestSuite;
 6 | 
 7 | /**
 8 |  * Unit test for simple App.
 9 |  */
10 | public class AppTest 
11 |     extends TestCase
12 | {
13 |     /**
14 |      * Create the test case
15 |      *
16 |      * @param testName name of the test case
17 |      */
18 |     public AppTest( String testName )
19 |     {
20 |         super( testName );
21 |     }
22 | 
23 |     /**
24 |      * @return the suite of tests being tested
25 |      */
26 |     public static Test suite()
27 |     {
28 |         return new TestSuite( AppTest.class );
29 |     }
30 | 
31 |     /**
32 |      * Rigorous Test :-)
33 |      */
34 |     public void testApp()
35 |     {
36 |         assertTrue( true );
37 |     }
38 | }
39 | 
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/conf/cmdline/ModelTesterCmdLineDriver.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.cloudera.knittingboar.conf.cmdline; 19 | 20 | public class ModelTesterCmdLineDriver { 21 | 22 | } 23 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/records/TestCSVBasedDatasetRecordFactory.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.cloudera.knittingboar.records; 19 | 20 | import junit.framework.TestCase; 21 | 22 | public class TestCSVBasedDatasetRecordFactory extends TestCase { 23 | 24 | } 25 | -------------------------------------------------------------------------------- /app.sample.properties: -------------------------------------------------------------------------------- 1 | # This is the path for the KnittingBoar JAR 2 | iterativereduce.jar.path=iterativereduce-0.1-SNAPSHOT.jar 3 | 4 | # Path to your application (which was compiled against KB!) 
 5 | app.jar.path=KnittingBoar-1.0-SNAPSHOT-jar-with-dependencies.jar
 6 | 
 7 | # Comma-separated list of other JARs required as dependencies
 8 | app.lib.jar.path=avro-1.7.1.jar,avro-ipc-1.7.1.jar
 9 | 
10 | # Input file(s) to process
11 | app.input.path=hdfs:///user/yarn/kboar/input/kboar-shard-0.txt
12 | 
13 | # Output results to
14 | app.output.path=hdfs:///user/yarn/kboar/output
15 | 
16 | # Number of iterations
17 | app.iteration.count=3
18 | 
19 | app.name=IR_SGD_Broski
20 | 
21 | # Requested memory for YARN clients
22 | yarn.memory=512
23 | # The main() class/entry for the AppMaster
24 | yarn.master.main=com.cloudera.knittingboar.sgd.iterativereduce.POLRMasterNode
25 | # Any extra command-line args
26 | yarn.master.args=
27 | 
28 | # The main() class/entry for the AppWorker
29 | yarn.worker.main=com.cloudera.knittingboar.sgd.iterativereduce.POLRWorkerNode
30 | # Any extra command-line args
31 | yarn.worker.args=
32 | 
33 | # Any other configuration params, will be pushed down to clients
34 | com.cloudera.knittingboar.setup.FeatureVectorSize=10000
35 | com.cloudera.knittingboar.setup.numCategories=20
36 | com.cloudera.knittingboar.setup.RecordFactoryClassname=com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory
37 | 
--------------------------------------------------------------------------------
/src/test/resources/app_unit_test.properties:
--------------------------------------------------------------------------------
 1 | # This is the path for the KnittingBoar JAR
 2 | iterativereduce.jar.path=iterativereduce-0.1-SNAPSHOT.jar
 3 | 
 4 | # Path to your application (which was compiled against KB!)
 5 | app.jar.path=KnittingBoar-1.0-SNAPSHOT-jar-with-dependencies.jar
 6 | 
 7 | # Comma-separated list of other JARs required as dependencies
 8 | app.lib.jar.path=avro-1.7.1.jar,avro-ipc-1.7.1.jar
 9 | 
10 | # Input file(s) to process
11 | app.input.path=/Users/jpatterson/Downloads/datasets/20news-kboar/train4/
12 | 
13 | # Output results to
14 | app.output.path=hdfs:///user/yarn/kboar/output
15 | 
16 | # Number of iterations
17 | app.iteration.count=3
18 | 
19 | app.name=IR_SGD_Broski
20 | 
21 | # Requested memory for YARN clients
22 | yarn.memory=512
23 | # The main() class/entry for the AppMaster
24 | yarn.master.main=com.cloudera.knittingboar.sgd.iterativereduce.POLRMasterNode
25 | # Any extra command-line args
26 | yarn.master.args=
27 | 
28 | # The main() class/entry for the AppWorker
29 | yarn.worker.main=com.cloudera.knittingboar.sgd.iterativereduce.POLRWorkerNode
30 | # Any extra command-line args
31 | yarn.worker.args=
32 | 
33 | # Any other configuration params, will be pushed down to clients
34 | com.cloudera.knittingboar.setup.FeatureVectorSize=10000
35 | com.cloudera.knittingboar.setup.numCategories=20
36 | com.cloudera.knittingboar.setup.RecordFactoryClassname=com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory
37 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | KnittingBoar
 2 | ============
 3 | 
 4 | Parallel Iterative Algorithm (SGD) on Hadoop's YARN framework
 5 | 
 6 | * Built on top of the BSP-style computation framework "Iterative Reduce" (Hadoop / YARN)
 7 | * Uses Mahout's implementation of Stochastic Gradient Descent (SGD) as the basis for the worker process
 8 | 
 9 | Slides From Hadoop World 2012:
10 | 
11 | http://www.cloudera.com/content/cloudera/en/resources/library/hadoopworld/strata-hadoop-world-2012-knitting-boar_slide_deck.html
12 | 
13 | Knitting Boar is an experimental machine learning application which parallelizes Mahout's Stochastic Gradient Descent on top of a new YARN-based framework for Hadoop called [Iterative Reduce](https://github.com/emsixteeen/IterativeReduce).
14 | 
15 | * [Intro to Knitting Boar](https://github.com/jpatanooga/KnittingBoar/wiki/Intro-to-Knitting-Boar)
16 | * [Quick Start](https://github.com/jpatanooga/KnittingBoar/wiki/Quick-Start)
17 | * [Frequently Asked Questions](https://github.com/jpatanooga/KnittingBoar/wiki/FAQ)
18 | * [Command Line Usage](https://github.com/jpatanooga/KnittingBoar/wiki/Command-Line-Usage)
19 | * [Knitting Boar Internals](https://github.com/jpatanooga/KnittingBoar/wiki/Knitting-Boar-Internals)
20 | * [Iterative Reduce](https://github.com/jpatanooga/KnittingBoar/wiki/Iterative-Reduce)
21 | * [Outstanding Issues](https://github.com/jpatanooga/KnittingBoar/wiki/Outstanding-Issues)
22 | * [Parallel SGD Resources](https://github.com/jpatanooga/KnittingBoar/wiki/Parallel-SGD-Resources)
23 | 
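
As a concrete starting point, the hedged sketch below drives the bundled dataset converter programmatically; the flags mirror the ones exercised by `TestDataConverterDriver` in this repo, while the input/output paths are placeholders you will need to point at your own copy of 20 Newsgroups.

```java
import java.io.PrintWriter;
import java.io.StringWriter;

import com.cloudera.knittingboar.conf.cmdline.DataConverterCmdLineDriver;

public class ConvertExample {
  public static void main(String[] args) throws Exception {
    StringWriter sw = new StringWriter();
    PrintWriter pw = new PrintWriter(sw, true);
    // Shard the raw 20news-bydate-train files into KnittingBoar input blocks
    DataConverterCmdLineDriver.mainToOutput(new String[] {
        "--input", "/tmp/20news-bydate-train/",   // placeholder input dir
        "--output", "/tmp/20news-kboar/train/",   // placeholder output dir
        "--recordsPerBlock", "2000"
    }, pw);
    System.out.println(sw);  // ends with "Total Records Converted: 11314"
  }
}
```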
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/metrics/POLRMetrics.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.metrics;
19 | 
20 | /**
21 |  * This is the class we'll use to report worker node performance to the master node
22 |  * 
23 |  * @author jpatterson
24 |  *
25 |  */
26 | public class POLRMetrics {
27 | 
28 |   public String WorkerNodeIPAddress = null;
29 |   public String WorkerNodeInputDataSplit = null;
30 |   
31 |   public long AvgBatchProecssingTimeInMS = 0;
32 |   
33 |   public long TotalInputProcessingTimeInMS = 0;
34 |   public long TotalRecordsProcessed = 0;
35 |   
36 |   public double AvgLogLikelihood = 0.0;
37 |   public double AvgCorrect = 0.0;
38 |   
39 | }
40 | 
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/utils/TestUtils.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.utils;
19 | 
20 | import java.util.List;
21 | 
22 | import com.google.common.collect.Lists;
23 | 
24 | import junit.framework.TestCase;
25 | 
26 | public class TestUtils extends TestCase {
27 | 
28 |   public void testStringStuff() {
29 |     
30 |     String s = "x,y,shape,color,k,k0,xx,xy,yy,a,b,c,bias";
31 |     
32 |     String[] ar = s.split(",");
33 |     
34 |     List<String> values = Lists.newArrayList(s.split(","));
35 |     
36 |     System.out.println( ">" + ar.length);
37 |     
38 |     
39 |     System.out.println( ">" + values.size());
40 |     
41 |   }
42 |   
43 | }
44 | 
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/utils/TestRcv1SubsetConversion.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.utils;
19 | 
20 | import java.io.IOException;
21 | 
22 | import junit.framework.TestCase;
23 | 
24 | public class TestRcv1SubsetConversion extends TestCase {
25 | 
26 |   String file = "/Users/jpatterson/Downloads/rcv1/rcv1.train.vw";
27 |   
28 |   public void testConversion() throws IOException {
29 |     
30 |     
31 |     int c = DatasetConverter.ExtractSubsetofRCV1V2ForTraining(file, "/Users/jpatterson/Downloads/rcv1/subset/train-unit-test/", 100000, 10000);
32 |     
33 |     System.out.println( "count: " + c);
34 |     
35 |     //assertEquals( 11314, count );
36 |     
37 |     
38 |   }
39 |   
40 | }
41 | 
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/io/TestCsvRecordParsing.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.io;
19 | 
20 | import java.util.ArrayList;
21 | 
22 | import com.google.common.base.CharMatcher;
23 | import com.google.common.base.Splitter;
24 | import com.google.common.collect.Lists;
25 | 
26 | import junit.framework.TestCase;
27 | 
28 | /**
29 |  * This explodes if you also have the google-collections jar in the classpath.
30 |  * 
31 |  * fix: remove google-collections
32 |  * 
33 |  * @author jpatterson
34 |  *
35 |  */
36 | public class TestCsvRecordParsing extends TestCase {
37 | 
38 |   
39 |   public void testParse() {
40 |     
41 |     String line = "\"a\", \"b\", \"c\"\n";
42 |     
43 |     Splitter COMMA = Splitter.on(',').trimResults(CharMatcher.is('"'));
44 |     ArrayList<String> variableNames = Lists.newArrayList(COMMA.split(line));
45 |     
46 |     
47 |     System.out.println( variableNames );
48 |     
49 |   }
50 |   
51 | }
52 | 
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/utils/TestConvert20NewsTestDataset.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.cloudera.knittingboar.utils; 19 | 20 | import java.io.IOException; 21 | 22 | import junit.framework.TestCase; 23 | 24 | public class TestConvert20NewsTestDataset extends TestCase { 25 | 26 | 27 | public void testNaiveBayesFormatConverter() throws IOException { 28 | 29 | int count = DatasetConverter.ConvertNewsgroupsFromSingleFiles("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/", "/Users/jpatterson/Downloads/datasets/20news-kboar/train-dataset-unit-test/", 21000); 30 | 31 | //int count = DatasetConverter.ConvertNewsgroupsFromSingleFiles("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/", "/Users/jpatterson/Downloads/datasets/20news-kboar/train4/", 2850); 32 | 33 | System.out.println( "Total: " + count ); 34 | 35 | assertEquals( 11314, count ); 36 | 37 | } 38 | 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/utils/TestDatasetConverter.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 |  */
17 | 
18 | package com.cloudera.knittingboar.utils;
19 | 
20 | import java.io.FileNotFoundException;
21 | import java.io.IOException;
22 | 
23 | import junit.framework.TestCase;
24 | 
25 | public class TestDatasetConverter extends TestCase {
26 | 
27 |   public void test20NewsgroupsFormatConverterForNWorkers() throws IOException {
28 |     
29 |     //DatasetConverter.ConvertNewsgroupsFromNaiveBayesFormat("/Users/jpatterson/Downloads/datasets/20news-processed/train/", "/Users/jpatterson/Downloads/datasets/20news-kboar/train2/");
30 |     
31 |     //int count = DatasetConverter.ConvertNewsgroupsFromSingleFiles("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/", "/Users/jpatterson/Downloads/datasets/20news-kboar/train3/", 5657);
32 |     
33 |     int count = DatasetConverter.ConvertNewsgroupsFromSingleFiles("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/", "/Users/jpatterson/Downloads/datasets/20news-kboar/train4/", 2850);
34 |     
35 |     assertEquals( 11314, count );
36 |     
37 |   }
38 | 
39 | }
40 | 
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/conf/cmdline/TestJobDriver.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.conf.cmdline;
19 | 
20 | import java.io.PrintWriter;
21 | import java.io.StringWriter;
22 | 
23 | import org.apache.mahout.classifier.sgd.TrainLogistic;
24 | 
25 | import junit.framework.TestCase;
26 | 
27 | /**
28 |  * This is a test for the code where we planned on having a driver similar to MapReduce's
29 |  * - it hasn't materialized yet
30 |  * 
31 |  */
32 | public class TestJobDriver extends TestCase {
33 | 
34 |   public void testBasics() throws Exception {
35 | 
36 | 
37 |     StringWriter sw = new StringWriter();
38 |     PrintWriter pw = new PrintWriter(sw, true);
39 |     String[] params = new String[]{
40 |         "--input", "donut.csv", 
41 |         "--output", "foo.model", 
42 |         "--features", "20", 
43 |         "--passes", "100", 
44 |         "--rate", "50" 
45 |     };
46 |     
47 |     ModelTrainerCmdLineDriver.mainToOutput(params, pw);
48 |     
49 |     String trainOut = sw.toString();
50 |     assertTrue(trainOut.contains("Parse:correct"));
51 |     
52 |     
53 |   }
54 | 
55 | 
56 | 
57 | }
58 | 
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/records/RecordFactory.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.records;
19 | 
20 | import java.util.List;
21 | import java.util.Map;
22 | import java.util.Set;
23 | 
24 | //import org.apache.mahout.classifier.sgd.RecordFactory;
25 | import org.apache.mahout.math.Vector;
26 | 
27 | /**
28 |  * Base interface for Knitting Boar's vectorization system
29 |  * 
30 |  * @author jpatterson
31 |  *
32 |  */
33 | public interface RecordFactory {
34 |   
35 |   public static String TWENTYNEWSGROUPS_RECORDFACTORY = "com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory";
36 |   public static String RCV1_RECORDFACTORY = "com.cloudera.knittingboar.records.RCV1RecordFactory";
37 |   public static String CSV_RECORDFACTORY = "com.cloudera.knittingboar.records.CSVRecordFactory";
38 |   
39 |   public int processLine(String line, Vector featureVector) throws Exception;
40 |   
41 |   public String GetClassnameByID(int id);
42 |   
43 |   // Map<String, Set<Integer>> getTraceDictionary();
44 |   
45 |   // RecordFactory includeBiasTerm(boolean useBias);
46 |   
47 |   public List<String> getTargetCategories();
48 |   
49 |   // void firstLine(String line);
50 |   
51 | }
52 | 
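
As an orientation aid (this is not a file from the repo): the interface above is small enough to sketch a toy implementation. Everything below is hypothetical -- the class name, label set, and "labelIndex,f0,f1,..." line format are invented for illustration -- but it compiles against the interface exactly as shown.

```java
package com.cloudera.knittingboar.records;

import java.util.Arrays;
import java.util.List;

import org.apache.mahout.math.Vector;

/**
 * Illustrative-only RecordFactory: parses lines of the form
 * "labelIndex,f0,f1,f2,..." into the supplied Mahout feature vector.
 */
public class SimpleCsvRecordFactory implements RecordFactory {

  private final List<String> targets = Arrays.asList("neg", "pos");

  @Override
  public int processLine(String line, Vector featureVector) throws Exception {
    String[] parts = line.split(",");
    int label = Integer.parseInt(parts[0].trim());
    // copy the numeric features into the (pre-sized) vector
    for (int i = 1; i < parts.length && i - 1 < featureVector.size(); i++) {
      featureVector.set(i - 1, Double.parseDouble(parts[i].trim()));
    }
    return label; // the trainer uses the returned index as the target class
  }

  @Override
  public String GetClassnameByID(int id) {
    return targets.get(id);
  }

  @Override
  public List<String> getTargetCategories() {
    return targets;
  }
}
```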
--------------------------------------------------------------------------------
/src/main/avro/KnittingBoarService.avdl:
--------------------------------------------------------------------------------
 1 | @namespace ("com.cloudera.knittingboar.yarn.avro.generated")
 2 | protocol KnittingBoarService {
 3 | 
 4 |   fixed WorkerId (32);
 5 | 
 6 |   record FileSplit {
 7 |     string path;
 8 |     long offset;
 9 |     long length;
10 |   }
11 | 
12 |   record StartupConfiguration {
13 |     FileSplit split;
14 |     int iterations;
15 |     int batchSize;
16 |     union { map<string>, null } other;
17 |   }
18 | 
19 |   record ProgressReport {
20 |     WorkerId workerId;
21 |     map<string> report;
22 |   }
23 | 
24 |   error ServiceError {
25 |     union { null, string } description;
26 |   }
27 | 
28 |   // Worker starts up, sends a startup() call
29 |   // @returns master sends back a startup configuration record
30 |   // @throws ServiceError
31 |   StartupConfiguration startup(WorkerId workerId) throws ServiceError;
32 | 
33 |   // Periodically the worker will send a progress update
34 |   // @returns true
35 |   boolean progress(WorkerId workerId, union { ProgressReport, null } report);
36 | 
37 |   // Send a vector update (right now just a series of bytes)
38 |   // @return true if update has been received, false on error
39 |   boolean update(WorkerId workerId, bytes data);
40 | 
41 |   // Worker is waiting to get an update from the master.
42 |   // @param workerId the workerId
43 |   // @param lastUpdate the last update ID
44 |   // @param waiting length of time worker has been waiting in this loop
45 |   // @return the latest available updateId. If the master is in update mode (e.g. it has received an update from at least one worker), it returns -1
46 |   int waiting(WorkerId workerId, int lastUpdate, long waiting);
47 | 
48 |   // Worker will fetch a new update vector from master
49 |   // @param workerId the workerId
50 |   // @param updateId the updateId worker is looking for
51 |   // @return byte array of new vector
52 |   bytes fetch(WorkerId workerId, int updateId);
53 | 
54 |   // Worker notifies master that work is complete
55 |   // @returns void, no return -- informative message only
56 |   void complete(WorkerId workerId, ProgressReport finalReport) oneway;
57 | 
58 |   // Worker informs master that an error occurred
59 |   // @returns void -- this is an error condition and assumes the worker is no longer active
60 |   void `error`(WorkerId workerId, string message) oneway;
61 | 
62 |   // Used by worker to send metrics to master
63 |   void metricsReport(WorkerId workerId, map<string> metrics) oneway;
64 | }
65 | 
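
To make the RPC flow above concrete, here is a hedged sketch of a worker-side client built with Avro IPC (avro-ipc is already listed in `app.lib.jar.path`). The host, port, update-loop semantics, and the empty "gradient" payload are all placeholder assumptions; the real worker logic lives in POLRWorkerNode, which is not shown in this section.

```java
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;

import org.apache.avro.ipc.NettyTransceiver;
import org.apache.avro.ipc.specific.SpecificRequestor;

import com.cloudera.knittingboar.yarn.avro.generated.KnittingBoarService;
import com.cloudera.knittingboar.yarn.avro.generated.StartupConfiguration;
import com.cloudera.knittingboar.yarn.avro.generated.WorkerId;

public class WorkerClientSketch {
  public static void main(String[] args) throws Exception {
    // Placeholder master endpoint
    NettyTransceiver client = new NettyTransceiver(new InetSocketAddress("localhost", 9876));
    KnittingBoarService master =
        SpecificRequestor.getClient(KnittingBoarService.class, client);

    WorkerId me = new WorkerId(new byte[32]);       // 32-byte fixed id
    StartupConfiguration conf = master.startup(me); // split / iterations / batchSize

    for (int i = 0; i < conf.getIterations(); i++) {
      byte[] gradient = new byte[0];                // placeholder: train a pass,
                                                    // serialize a ParameterVector
      master.update(me, ByteBuffer.wrap(gradient)); // push the local update

      int updateId;
      long waited = 0;
      while ((updateId = master.waiting(me, i, waited)) < 0) {
        Thread.sleep(100);                          // master still merging updates
        waited += 100;
      }
      ByteBuffer merged = master.fetch(me, updateId); // pull the merged vector
    }
    client.close();
  }
}
```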
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/conf/cmdline/TestDataConverterDriver.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.conf.cmdline;
19 | 
20 | import java.io.File;
21 | import java.io.IOException;
22 | import java.io.PrintWriter;
23 | import java.io.StringWriter;
24 | 
25 | import org.apache.commons.io.FileUtils;
26 | 
27 | import junit.framework.TestCase;
28 | 
29 | import com.cloudera.knittingboar.utils.DataUtils;
30 | 
31 | public class TestDataConverterDriver extends TestCase {
32 | 
33 |   public void testBasics() throws Exception {
34 |     File twentyNewsGroupDataDir = DataUtils.getTwentyNewsGroupDir();
35 |     File inputDir = new File(twentyNewsGroupDataDir, "20news-bydate-train");
36 |     File outputDir = new File(twentyNewsGroupDataDir, "20news-bydate-converted");
37 |     FileUtils.deleteQuietly(outputDir);
38 |     if(!(outputDir.isDirectory() || outputDir.mkdir())) {
39 |       throw new IOException("Could not mkdir " + outputDir);
40 |     }
41 |     StringWriter sw = new StringWriter();
42 |     PrintWriter pw = new PrintWriter(sw, true);
43 |     String[] params = new String[]{
44 |         "--input", inputDir.getAbsolutePath() + "/",
45 |         "--output", outputDir.getAbsolutePath() + "/",
46 |         "--recordsPerBlock", "2000"
47 |     };
48 |     DataConverterCmdLineDriver.mainToOutput(params, pw);
49 |     String trainOut = sw.toString();
50 |     assertTrue(trainOut.contains("Total Records Converted: 11314"));
51 |     File[] files = outputDir.listFiles();
52 |     for (int x = 0; x < files.length; x++ ) {
53 |       System.out.println(files[x].toString());
54 |     }
55 |     assertEquals( 6, files.length );
56 |   }
57 | }
58 | 
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/sgd/iterativereduce/POLRNodeBase.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 | */ 17 | 18 | package com.cloudera.knittingboar.sgd.iterativereduce; 19 | 20 | import org.apache.hadoop.conf.Configuration; 21 | 22 | /** 23 | * Base class for IR-KnittingBoar nodes 24 | * 25 | * @author jpatterson 26 | * 27 | */ 28 | public class POLRNodeBase { 29 | 30 | protected Configuration conf = null; 31 | protected int num_categories = 2; 32 | protected int FeatureVectorSize = -1; 33 | // protected int BatchSize = 200; 34 | protected double Lambda = 1.0e-4; 35 | protected double LearningRate = 10; 36 | 37 | String LocalInputSplitPath = ""; 38 | String PredictorLabelNames = ""; 39 | String PredictorVariableTypes = ""; 40 | protected String TargetVariableName = ""; 41 | protected String ColumnHeaderNames = ""; 42 | protected int NumberIterations = 1; 43 | 44 | protected int LocalBatchCountForIteration = 0; 45 | protected int GlobalBatchCountForIteration = 0; 46 | 47 | protected String RecordFactoryClassname = ""; 48 | 49 | protected String LoadStringConfVarOrException(String ConfVarName, 50 | String ExcepMsg) throws Exception { 51 | 52 | if (null == this.conf.get(ConfVarName)) { 53 | throw new Exception(ExcepMsg); 54 | } else { 55 | return this.conf.get(ConfVarName); 56 | } 57 | 58 | } 59 | 60 | protected int LoadIntConfVarOrException(String ConfVarName, String ExcepMsg) 61 | throws Exception { 62 | 63 | if (null == this.conf.get(ConfVarName)) { 64 | throw new Exception(ExcepMsg); 65 | } else { 66 | return this.conf.getInt(ConfVarName, 0); 67 | } 68 | 69 | } 70 | 71 | public Configuration getConf() { 72 | return this.conf; 73 | } 74 | 75 | } 76 | -------------------------------------------------------------------------------- /src/main/avro/KnittingBoarService.avpr.old: -------------------------------------------------------------------------------- 1 | { "namespace": "com.cloudera.wovenwabbit.yarn.avro.generated", 2 | "protocol": "KnittingBoarService", 3 | 4 | "types": [ 5 | 6 | // Worker specific 7 | 8 | { "name": "WorkerState", 9 | "type": "enum", 10 | "symbols": [ 11 | "READY", 12 | "RUNNING", 13 | "WAITING", 14 | "UPDATE", 15 | "TERMINATING", 16 | "COMPLETED", 17 | "SHUTDOWN" 18 | ] 19 | }, 20 | { "name": "WorkerId", 21 | "type": "fixed", 22 | "size": 32 23 | }, 24 | { "name": "WorkerPayload", 25 | "type": "record", 26 | "fields": [ 27 | { "name": "PayloadType", "type": { 28 | "name": "WorkerPayloadTypes", 29 | "type": "enum", 30 | "symbols": [ 31 | "VECTOR", 32 | "OTHER" 33 | ] 34 | } 35 | }, 36 | { "name": "Payload", "type": "bytes" } 37 | ] 38 | }, 39 | { "name": "WorkerMessage", 40 | "type": "record", 41 | "fields": [ 42 | { "name": "WorkerId", "type": "WorkerId" }, 43 | { "name": "State", "type": "WorkerState" }, 44 | { "name": "Iteration", "type": [ "null", "int" ] }, 45 | { "name": "Batch", "type": [ "null", "int" ] }, 46 | { "name": "Payload", "type": [ "null", "WorkerPayload" ] } 47 | ] 48 | }, 49 | 50 | // Master specific 51 | 52 | { "name": "MasterMessageType", 53 | "type": "enum", 54 | "symbols": [ 55 | "START", 56 | "CONTINUE", 57 | "UPDATE", 58 | "WAIT", 59 | "TERMINATE", 60 | "SHUTDOWN" 61 | ] 62 | }, 63 | { "name": "MasterStartupPayload", 64 | "type": "record", 65 | "fields": [ 66 | { "name": "FileSplit", "type": "string" }, 67 | { "name": "Iterations", "type": "int" }, 68 | { "name": "BatchSize", "type": "int" }, 69 | { "name": "Other", "type": [ "null", { "type": "map", "values": "string" } ] } 70 | ] 71 | }, 72 | { "name": "MasterUpdatePayload", 73 | "type": "record", 74 | "fields": [ 75 | { "name": "PayloadType", "type": { 76 | "name": 
"MasterPayloadTypes", 77 | "type": "enum", 78 | "symbols": [ 79 | "VECTOR", 80 | "OTHER" 81 | ] 82 | } 83 | }, 84 | { "name": "Payload", "type": "bytes" } 85 | ] 86 | }, 87 | { "name": "MasterMessage", 88 | "type": "record", 89 | "fields": [ 90 | { "name": "WorkerId", "type": "WorkerId" }, 91 | { "name": "MessageType", "type": "MasterMessageType" }, 92 | { "name": "Payload", "type": [ "null", "MasterStartupPayload", "MasterUpdatePayload"] }, 93 | { "name": "Additional", "type": ["null", "string"] } 94 | ] 95 | } 96 | ], 97 | 98 | // RPC Messages 99 | "messages": { 100 | "startup": { 101 | "request": [{ "name": "workerId", "type": "WorkerId" }], 102 | "response": "boolean"; 103 | "one-way": true 104 | }, 105 | 106 | "hearbeat": { 107 | "request": [{ "name": "message", "type": "WorkerMessage" }], 108 | "response": "MasterMessage" 109 | } 110 | } 111 | } -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/utils/DataUtils.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 |  */
17 | package com.cloudera.knittingboar.utils;
18 | 
19 | import java.io.BufferedReader;
20 | import java.io.File;
21 | import java.io.IOException;
22 | import java.io.InputStreamReader;
23 | import java.net.URL;
24 | 
25 | import org.apache.commons.io.FileUtils;
26 | 
27 | public class DataUtils {
28 | 
29 |   private static File twentyNewsGroups;
30 |   private static final String TWENTY_NEWS_GROUP_LOCAL_DIR = "knittingboar-20news";
31 |   private static final String TWENTY_NEWS_GROUP_TAR_URL = "http://people.csail.mit.edu/jrennie/20Newsgroups/20news-bydate.tar.gz";
32 |   private static final String TWENTY_NEWS_GROUP_TAR_FILE_NAME = "20news-bydate.tar.gz";
33 |   
34 |   public static String get20NewsgroupsLocalDataLocation() {
35 |     
36 |     File tmpDir = new File("/tmp");
37 |     if(!tmpDir.isDirectory()) {
38 |       tmpDir = new File(System.getProperty("java.io.tmpdir"));
39 |     }
40 |     File baseDir = new File(tmpDir, TWENTY_NEWS_GROUP_LOCAL_DIR);
41 |     
42 |     
43 |     return baseDir.toString();
44 |     
45 |   }
46 |   
47 |   public static synchronized File getTwentyNewsGroupDir() throws IOException {
48 |     if(twentyNewsGroups != null) {
49 |       return twentyNewsGroups;
50 |     }
51 |     // mac gives a unique tmp dir each run and we want to persist
52 |     // this data across restarts
53 |     File tmpDir = new File("/tmp");
54 |     if(!tmpDir.isDirectory()) {
55 |       tmpDir = new File(System.getProperty("java.io.tmpdir"));
56 |     }
57 |     File baseDir = new File(tmpDir, TWENTY_NEWS_GROUP_LOCAL_DIR);
58 |     if(!(baseDir.isDirectory() || baseDir.mkdir())) {
59 |       throw new IOException("Could not mkdir " + baseDir);
60 |     }
61 |     File tarFile = new File(baseDir, TWENTY_NEWS_GROUP_TAR_FILE_NAME);
62 |     
63 |     if(!tarFile.isFile()) {
64 |       FileUtils.copyURLToFile(new URL(TWENTY_NEWS_GROUP_TAR_URL), tarFile);
65 |     }
66 |     
67 |     Process p = Runtime.getRuntime().exec(String.format("tar -C %s -xvf %s", 
68 |         baseDir.getAbsolutePath(), tarFile.getAbsolutePath()));
69 |     BufferedReader stdError = new BufferedReader(new 
70 |         InputStreamReader(p.getErrorStream()));
71 |     System.out.println("Here is the standard error of the command (if any):\n");
72 |     String s;
73 |     while ((s = stdError.readLine()) != null) {
74 |       System.out.println(s);
75 |     }
76 |     stdError.close();
77 |     twentyNewsGroups = baseDir;
78 |     return twentyNewsGroups;
79 |   }
80 |   
81 |   
82 |   
83 | }
84 | 
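
The next file wires KnittingBoar's ParameterVector into IterativeReduce's Updateable contract. As a quick orientation, the hedged sketch below (not a repo file) shows the serialize/deserialize round trip that wrapper has to support; the field names and methods are taken from ParameterVectorUpdatable below and from TestParameterVector later in this section.

```java
import java.nio.ByteBuffer;

import org.apache.mahout.math.DenseMatrix;

import com.cloudera.knittingboar.messages.iterativereduce.ParameterVector;
import com.cloudera.knittingboar.messages.iterativereduce.ParameterVectorUpdatable;

public class RoundTripSketch {
  public static void main(String[] args) throws Exception {
    ParameterVector p = new ParameterVector();
    p.parameter_vector = new DenseMatrix(20, 10000); // categories x features
    p.SrcWorkerPassCount = 1;

    // Wrap, serialize to bytes, and rehydrate -- the same path the
    // master/worker RPC layer uses for update() and fetch()
    ParameterVectorUpdatable u = new ParameterVectorUpdatable(p);
    ByteBuffer wire = u.toBytes();

    ParameterVectorUpdatable back = new ParameterVectorUpdatable();
    back.fromBytes(wire);
    System.out.println(back.get().numFeatures()); // expect 10000
  }
}
```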
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/messages/iterativereduce/ParameterVectorUpdatable.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.messages.iterativereduce;
19 | 
20 | import java.io.IOException;
21 | import java.nio.ByteBuffer;
22 | 
23 | import org.apache.mahout.math.Matrix;
24 | 
25 | //import com.cloudera.knittingboar.sgd.GradientBuffer;
26 | import com.cloudera.iterativereduce.Updateable;
27 | 
28 | public class ParameterVectorUpdatable implements
29 |     Updateable<ParameterVector> {
30 |   
31 |   
32 |   ParameterVector param_msg = null;
33 |   
34 |   public ParameterVectorUpdatable() {}
35 |   
36 |   public ParameterVectorUpdatable(ParameterVector g) {
37 |     this.param_msg = g;
38 |   }
39 |   
40 |   @Override
41 |   public void fromBytes(ByteBuffer b) {
42 |     
43 |     b.rewind();
44 |     
45 |     // System.out.println( " > ParameterVectorGradient::fromBytes > b: " +
46 |     // b.array().length + ", remaining: " + b.remaining() );
47 |     
48 |     try {
49 |       this.param_msg = new ParameterVector();
50 |       this.param_msg.Deserialize(b.array());
51 |     } catch (IOException e) {
52 |       // TODO Auto-generated catch block
53 |       e.printStackTrace();
54 |     }
55 |   }
56 |   
57 |   @Override
58 |   public ParameterVector get() {
59 |     // TODO Auto-generated method stub
60 |     return this.param_msg;
61 |   }
62 |   
63 |   @Override
64 |   public void set(ParameterVector t) {
65 |     // TODO Auto-generated method stub
66 |     this.param_msg = t;
67 |   }
68 |   
69 |   @Override
70 |   public ByteBuffer toBytes() {
71 |     // TODO Auto-generated method stub
72 |     byte[] bytes = null;
73 |     try {
74 |       bytes = this.param_msg.Serialize();
75 |     } catch (IOException e) {
76 |       // TODO Auto-generated catch block
77 |       e.printStackTrace();
78 |     }
79 |     
80 |     // ByteBuffer buf = ByteBuffer.allocate(bytes.length);
81 |     // buf.put(bytes);
82 |     ByteBuffer buf = ByteBuffer.wrap(bytes);
83 |     
84 |     return buf;
85 |   }
86 |   
87 |   @Override
88 |   public void fromString(String s) {
89 |     // TODO Auto-generated method stub
90 |     
91 |   }
92 |   /*
93 |   @Override
94 |   public int getGlobalBatchNumber() {
95 |     // TODO Auto-generated method stub
96 |     return 0;
97 |   }
98 | 
99 |   @Override
100 |   public int getGlobalIterationNumber() {
101 |     // TODO Auto-generated method stub
102 |     return 0;
103 |   }
104 | 
105 |   @Override
106 |   public void setIterationState(int arg0, int arg1) {
107 |     // TODO Auto-generated method stub
108 |     
109 |   }
110 |   */
111 | }
112 | 
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/sgd/iterativereduce/TestKnittingBoar_IRUnitSim.java:
--------------------------------------------------------------------------------
 1 | package com.cloudera.knittingboar.sgd.iterativereduce;
 2 | 
 3 | import java.io.File;
 4 | import java.io.IOException;
 5 | import java.util.ArrayList;
 6 | 
 7 | import junit.framework.TestCase;
 8 | 
 9 | import org.apache.hadoop.conf.Configuration;
10 | import org.apache.hadoop.fs.FSDataOutputStream;
11 | import org.apache.hadoop.fs.FileSystem;
12 | import org.apache.hadoop.fs.Path;
13 | import org.apache.hadoop.mapred.FileInputFormat;
14 | import org.apache.hadoop.mapred.InputSplit;
15 | import org.apache.hadoop.mapred.JobConf;
16 | import org.apache.hadoop.mapred.TextInputFormat;
17 | 
18 | import com.cloudera.iterativereduce.io.TextRecordParser;
19 | import com.cloudera.iterativereduce.irunit.IRUnitDriver;
20 | import com.cloudera.knittingboar.sgd.iterativereduce.POLRMasterNode;
21 | import com.cloudera.knittingboar.sgd.iterativereduce.POLRWorkerNode;
22 | import com.cloudera.knittingboar.utils.DataUtils;
23 | import com.cloudera.knittingboar.utils.DatasetConverter;
24 | 
25 | 
26 | public class TestKnittingBoar_IRUnitSim 
extends TestCase { 27 | 28 | 29 | private static JobConf defaultConf = new JobConf(); 30 | private static FileSystem localFs = null; 31 | static { 32 | try { 33 | defaultConf.set("fs.defaultFS", "file:///"); 34 | localFs = FileSystem.getLocal(defaultConf); 35 | } catch (IOException e) { 36 | throw new RuntimeException("init failure", e); 37 | } 38 | } 39 | 40 | //private static Path workDir = new Path(System.getProperty("test.build.data", "/Users/jpatterson/Downloads/datasets/20news-kboar/train4/")); 41 | private static Path workDir20NewsLocal = new Path(new Path("/tmp"), "Dataset20Newsgroups"); 42 | private static File unzipDir = new File( workDir20NewsLocal + "/20news-bydate"); 43 | private static String strKBoarTrainDirInput = "" + unzipDir.toString() + "/KBoar-train/"; 44 | 45 | 46 | public void setupResources() throws IOException { 47 | 48 | File file20News = DataUtils.getTwentyNewsGroupDir(); 49 | 50 | DatasetConverter.ConvertNewsgroupsFromSingleFiles( DataUtils.get20NewsgroupsLocalDataLocation() + "/20news-bydate-train/", strKBoarTrainDirInput, 6000); 51 | 52 | 53 | } 54 | 55 | public void testIRUnit_POLR() throws IOException { 56 | 57 | System.out.println( "Starting: testIRUnit_POLR" ); 58 | 59 | setupResources(); 60 | 61 | String[] props = { 62 | "app.iteration.count", 63 | "com.cloudera.knittingboar.setup.FeatureVectorSize", 64 | "com.cloudera.knittingboar.setup.numCategories", 65 | "com.cloudera.knittingboar.setup.RecordFactoryClassname" 66 | }; 67 | 68 | IRUnitDriver polr_driver = new IRUnitDriver("src/test/resources/app_unit_test.properties", props ); 69 | 70 | polr_driver.SetProperty("app.input.path", strKBoarTrainDirInput); 71 | 72 | polr_driver.Setup(); 73 | polr_driver.SimulateRun(); 74 | 75 | System.out.println("\n\nComplete..."); 76 | 77 | POLRMasterNode IR_Master = (POLRMasterNode)polr_driver.getMaster(); 78 | 79 | Path out = new Path("/tmp/IR_Model_0.model"); 80 | FileSystem fs = out.getFileSystem(defaultConf); 81 | FSDataOutputStream fos = fs.create(out); 82 | 83 | //LOG.info("Writing master results to " + out.toString()); 84 | IR_Master.complete(fos); 85 | 86 | fos.flush(); 87 | fos.close(); 88 | 89 | System.out.println("\n\nModel Saved: /tmp/IR_Model_0.model" ); 90 | 91 | } 92 | 93 | 94 | } 95 | -------------------------------------------------------------------------------- /src/main/java/com/cloudera/knittingboar/io/InputRecordsSplit.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 |  */
17 | 
18 | package com.cloudera.knittingboar.io;
19 | 
20 | import java.io.IOException;
21 | 
22 | import org.apache.hadoop.io.LongWritable;
23 | import org.apache.hadoop.io.Text;
24 | import org.apache.hadoop.mapred.InputSplit;
25 | import org.apache.hadoop.mapred.JobConf;
26 | import org.apache.hadoop.mapred.RecordReader;
27 | import org.apache.hadoop.mapred.Reporter;
28 | import org.apache.hadoop.mapred.TextInputFormat;
29 | 
30 | /**
31 |  * Encapsulates functionality from:
32 |  * 
33 |  * - FileInputFormat::getSplits(...) - this info should be calculated by the
34 |  * main job controlling process [MOVE]
35 |  * 
36 |  * - TextInputFormat::readSplit(...)
37 |  * 
38 |  * 
39 |  * Notes - currently hard-coded to read CSV "record per line" non-compressed
40 |  * records from disk
41 |  * 
42 |  * @author jpatterson
43 |  * 
44 |  */
45 | public class InputRecordsSplit {
46 |   
47 |   TextInputFormat input_format = null;
48 |   InputSplit split = null;
49 |   JobConf jobConf = null;
50 |   
51 |   RecordReader<LongWritable, Text> reader = null;
52 |   LongWritable key = null;
53 |   
54 |   final Reporter voidReporter = Reporter.NULL;
55 |   
56 |   public InputRecordsSplit(JobConf jobConf, InputSplit split)
57 |       throws IOException {
58 |     
59 |     this.jobConf = jobConf;
60 |     this.split = split;
61 |     this.input_format = new TextInputFormat();
62 |     
63 |     // RecordReader<LongWritable, Text> reader =
64 |     // format.getRecordReader(splits[x], job, reporter);
65 |     this.reader = input_format.getRecordReader(this.split, this.jobConf,
66 |         voidReporter);
67 |     this.key = reader.createKey();
68 |     // Text value = reader.createValue();
69 |     
70 |   }
71 |   
72 |   /**
73 |    * 
74 |    * just a dead simple way to do this
75 |    * 
76 |    * - functionality from TestTextInputFormat::readSplit()
77 |    * 
78 |    * If this returns true, then csv_line contains the next line. If it returns
79 |    * false, then there is no next record.
80 |    * 
81 |    * Will terminate when it hits the end of the split based on the information
82 |    * provided in the split class to the constructor and the TextInputFormat
83 |    * 
84 |    * @param csv_line
85 |    * @throws IOException
86 |    */
87 |   public boolean next(Text csv_line) throws IOException {
88 |     
89 |     return reader.next(key, csv_line);
90 |     
91 |   }
92 |   
93 |   public void ResetToStartOfSplit() throws IOException {
94 |     
95 |     // I'mma cheatin here. sue me.
96 |     this.reader = input_format.getRecordReader(this.split, this.jobConf,
97 |         voidReporter);
98 |     
99 |   }
100 |   
101 | }
102 | 
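
The Javadoc above describes the intended call pattern; here is a hedged end-to-end sketch (the local input path and split count are made up) that generates splits with the old mapred API and replays one of them through InputRecordsSplit, the same way the split-oriented tests in this repo drive it.

```java
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.TextInputFormat;

import com.cloudera.knittingboar.io.InputRecordsSplit;

public class SplitReadSketch {
  public static void main(String[] args) throws Exception {
    JobConf job = new JobConf();
    job.set("fs.defaultFS", "file:///");            // stay on the local FS
    FileInputFormat.setInputPaths(job, new Path("/tmp/kboar-shard-0.txt")); // placeholder

    TextInputFormat format = new TextInputFormat();
    format.configure(job);
    InputSplit[] splits = format.getSplits(job, 2); // ask for two splits

    InputRecordsSplit records = new InputRecordsSplit(job, splits[0]);
    Text line = new Text();
    while (records.next(line)) {                    // false at end of split
      System.out.println(line);
    }
    records.ResetToStartOfSplit();                  // rewind for another pass
  }
}
```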
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/messages/TestParameterVector.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License.  You may obtain a copy of the License at
 8 |  *
 9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.messages;
19 | 
20 | import java.io.DataInput;
21 | import java.io.DataInputStream;
22 | import java.io.DataOutput;
23 | import java.io.DataOutputStream;
24 | import java.io.FileInputStream;
25 | import java.io.FileOutputStream;
26 | import java.io.IOException;
27 | import java.io.OutputStream;
28 | 
29 | import org.apache.mahout.math.DenseMatrix;
30 | import org.apache.mahout.math.Matrix;
31 | 
32 | import junit.framework.TestCase;
33 | 
34 | import com.cloudera.knittingboar.messages.iterativereduce.ParameterVector;
35 | 
36 | 
37 | 
38 | public class TestParameterVector extends TestCase {
39 | 
40 |   
41 |   
42 |   public static String msg_file = "/tmp/TestGradientUpdateMessageSerde.msg";
43 |   
44 |   //public static String ip = "255.255.255.1";
45 |   
46 |   public static int pass_count = 8;
47 |   
48 |   public void testSerde() throws IOException {
49 |     
50 |     int classes = 20;
51 |     int features = 10000;
52 |     
53 |     //GradientBuffer g = new GradientBuffer( classes, features );
54 |     Matrix m = new DenseMatrix(classes, features);
55 |     //m.set(0, 0, 0.1);
56 |     //m.set(0, 1, 0.3);
57 |     
58 |     
59 |     //g.numFeatures();
60 |     
61 |     for (int c = 0; c < classes - 1; c++) {
62 |       
63 |       for (int f = 0; f < features; f++ ) {
64 |         
65 |         m.set(c, f, (double)((double)f / 10.0f) );
66 |         
67 |       }
68 |       
69 |     }
70 |     
71 |     System.out.println( "matrix created..." );
72 |     
73 |     
74 |     ParameterVector vec_gradient = new ParameterVector();
75 |     vec_gradient.SrcWorkerPassCount = pass_count;
76 |     vec_gradient.parameter_vector = m;
77 |     vec_gradient.AvgLogLikelihood = -1.368f;
78 |     vec_gradient.PercentCorrect = 72.68f;
79 |     vec_gradient.TrainedRecords = 2500;
80 |     
81 |     
82 |     assertEquals( 10000, vec_gradient.numFeatures() );
83 |     assertEquals( 10000, vec_gradient.parameter_vector.columnSize() );
84 |     
85 |     assertEquals( 20, vec_gradient.numCategories() );
86 |     assertEquals( 20, vec_gradient.parameter_vector.rowSize() );
87 |     
88 |     byte[] buf = vec_gradient.Serialize();
89 |     
90 |     
91 |     ParameterVector vec_gradient_deserialized = new ParameterVector();
92 |     vec_gradient_deserialized.Deserialize(buf);
93 |     
94 |     assertEquals( pass_count, vec_gradient_deserialized.SrcWorkerPassCount );
95 |     assertEquals( 0.1, vec_gradient_deserialized.parameter_vector.get(0, 1) );
96 |     assertEquals( 0.2, vec_gradient_deserialized.parameter_vector.get(0, 2) );
97 |     assertEquals( 0.3, vec_gradient_deserialized.parameter_vector.get(0, 3) );
98 |     assertEquals( 0.4, vec_gradient_deserialized.parameter_vector.get(0, 4) );
99 |     assertEquals( 0.5, vec_gradient_deserialized.parameter_vector.get(0, 5) );
100 |     
101 |     assertEquals( -1.368f, vec_gradient_deserialized.AvgLogLikelihood );
102 |     assertEquals( 72.68f, vec_gradient_deserialized.PercentCorrect );
103 |     assertEquals( 2500, vec_gradient_deserialized.TrainedRecords );
104 |     
105 |   }
106 |   
107 |   
108 |   
109 |   
110 | }
111 | 
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/sgd/TestPOLRMasterNode.java:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
 3 |  * contributor license agreements.  See the NOTICE file distributed with
 4 |  * this work for additional information regarding copyright ownership.
 5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
 6 |  * (the "License"); you may not use this file except in compliance with
 7 |  * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.cloudera.knittingboar.sgd; 19 | 20 | import org.apache.hadoop.conf.Configuration; 21 | 22 | import com.cloudera.knittingboar.records.RecordFactory; 23 | import com.cloudera.knittingboar.sgd.iterativereduce.POLRMasterNode; 24 | 25 | import junit.framework.TestCase; 26 | 27 | /** 28 | * - Test messages coming in and merging into the master gradient buffer 29 | * 30 | * - Test trigger message generation after gradient update 31 | * 32 | * @author jpatterson 33 | * 34 | */ 35 | public class TestPOLRMasterNode extends TestCase { 36 | 37 | public Configuration generateDebugConfigurationObject() { 38 | 39 | Configuration c = new Configuration(); 40 | 41 | // feature vector size 42 | c.setInt( "com.cloudera.knittingboar.setup.FeatureVectorSize", 10 ); 43 | 44 | c.set( "com.cloudera.knittingboar.setup.RecordFactoryClassname", RecordFactory.CSV_RECORDFACTORY); 45 | 46 | 47 | // predictor label names 48 | c.set( "com.cloudera.knittingboar.setup.PredictorLabelNames", "x,y" ); 49 | 50 | // predictor var types 51 | c.set( "com.cloudera.knittingboar.setup.PredictorVariableTypes", "numeric,numeric" ); 52 | 53 | // target variables 54 | c.set( "com.cloudera.knittingboar.setup.TargetVariableName", "color" ); 55 | 56 | // column header names 57 | c.set( "com.cloudera.knittingboar.setup.ColumnHeaderNames", "x,y,shape,color,k,k0,xx,xy,yy,a,b,c,bias" ); 58 | 59 | return c; 60 | 61 | } 62 | 63 | public void testMasterConfiguration() { 64 | 65 | POLRMasterNode master = new POLRMasterNode(); 66 | 67 | // ------------------ 68 | // generate the debug conf ---- normally setup by YARN stuff 69 | master.setup(this.generateDebugConfigurationObject()); 70 | // now load the conf stuff into locally used vars 71 | /* try { 72 | master.LoadConfigVarsLocally(); 73 | } catch (Exception e) { 74 | // TODO Auto-generated catch block 75 | e.printStackTrace(); 76 | System.out.println( "Conf load fail: shutting down." 
); 77 | assertEquals( 0, 1 ); 78 | } 79 | // now construct any needed machine learning data structures based on config 80 | master.Setup(); 81 | */ 82 | // ------------------ 83 | 84 | // test the base conf stuff ------------ 85 | 86 | assertEquals( master.getConf().getInt("com.cloudera.knittingboar.setup.FeatureVectorSize", 0), 10 ); 87 | // assertEquals( master.getConf().get("com.cloudera.knittingboar.setup.LocalInputSplitPath"), "hdfs://127.0.0.1/input/0" ); 88 | assertEquals( master.getConf().get("com.cloudera.knittingboar.setup.PredictorLabelNames"), "x,y" ); 89 | assertEquals( master.getConf().get("com.cloudera.knittingboar.setup.PredictorVariableTypes"), "numeric,numeric" ); 90 | assertEquals( master.getConf().get("com.cloudera.knittingboar.setup.TargetVariableName"), "color" ); 91 | assertEquals( master.getConf().get("com.cloudera.knittingboar.setup.ColumnHeaderNames"), "x,y,shape,color,k,k0,xx,xy,yy,a,b,c,bias" ); 92 | 93 | // now test the parsed stuff ------------ 94 | 95 | 96 | 97 | 98 | } 99 | 100 | 101 | 102 | } 103 | -------------------------------------------------------------------------------- /src/main/java/com/cloudera/knittingboar/yarn/avro/generated/KnittingBoarService.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Autogenerated by Avro 3 | * 4 | * DO NOT EDIT DIRECTLY 5 | */ 6 | package com.cloudera.knittingboar.yarn.avro.generated; 7 | 8 | @SuppressWarnings("all") 9 | public interface KnittingBoarService { 10 | public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"KnittingBoarService\",\"namespace\":\"com.cloudera.knittingboar.yarn.avro.generated\",\"types\":[{\"type\":\"fixed\",\"name\":\"WorkerId\",\"size\":32},{\"type\":\"record\",\"name\":\"FileSplit\",\"fields\":[{\"name\":\"path\",\"type\":\"string\"},{\"name\":\"offset\",\"type\":\"long\"},{\"name\":\"length\",\"type\":\"long\"}]},{\"type\":\"record\",\"name\":\"StartupConfiguration\",\"fields\":[{\"name\":\"split\",\"type\":\"FileSplit\"},{\"name\":\"iterations\",\"type\":\"int\"},{\"name\":\"batchSize\",\"type\":\"int\"},{\"name\":\"other\",\"type\":[{\"type\":\"map\",\"values\":\"string\"},\"null\"]}]},{\"type\":\"record\",\"name\":\"ProgressReport\",\"fields\":[{\"name\":\"workerId\",\"type\":\"WorkerId\"},{\"name\":\"report\",\"type\":{\"type\":\"map\",\"values\":\"string\"}}]},{\"type\":\"error\",\"name\":\"ServiceError\",\"fields\":[{\"name\":\"description\",\"type\":[\"null\",\"string\"]}]}],\"messages\":{\"startup\":{\"request\":[{\"name\":\"workerId\",\"type\":\"WorkerId\"}],\"response\":\"StartupConfiguration\",\"errors\":[\"ServiceError\"]},\"progress\":{\"request\":[{\"name\":\"workerId\",\"type\":\"WorkerId\"},{\"name\":\"report\",\"type\":[\"ProgressReport\",\"null\"]}],\"response\":\"boolean\"},\"update\":{\"request\":[{\"name\":\"workerId\",\"type\":\"WorkerId\"},{\"name\":\"data\",\"type\":\"bytes\"}],\"response\":\"boolean\"},\"waiting\":{\"request\":[{\"name\":\"workerId\",\"type\":\"WorkerId\"},{\"name\":\"lastUpdate\",\"type\":\"int\"},{\"name\":\"waiting\",\"type\":\"long\"}],\"response\":\"int\"},\"fetch\":{\"request\":[{\"name\":\"workerId\",\"type\":\"WorkerId\"},{\"name\":\"updateId\",\"type\":\"int\"}],\"response\":\"bytes\"},\"complete\":{\"request\":[{\"name\":\"workerId\",\"type\":\"WorkerId\"},{\"name\":\"finalReport\",\"type\":\"ProgressReport\"}],\"response\":\"null\",\"one-way\":true},\"error\":{\"request\":[{\"name\":\"workerId\",\"type\":\"WorkerId\"},{\"name\":\"message\"
,\"type\":\"string\"}],\"response\":\"null\",\"one-way\":true},\"metricsReport\":{\"request\":[{\"name\":\"workerId\",\"type\":\"WorkerId\"},{\"name\":\"metrics\",\"type\":{\"type\":\"map\",\"values\":\"long\"}}],\"response\":\"null\",\"one-way\":true}}}"); 11 | com.cloudera.knittingboar.yarn.avro.generated.StartupConfiguration startup(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId) throws org.apache.avro.AvroRemoteException, com.cloudera.knittingboar.yarn.avro.generated.ServiceError; 12 | boolean progress(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, com.cloudera.knittingboar.yarn.avro.generated.ProgressReport report) throws org.apache.avro.AvroRemoteException; 13 | boolean update(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, java.nio.ByteBuffer data) throws org.apache.avro.AvroRemoteException; 14 | int waiting(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, int lastUpdate, long waiting) throws org.apache.avro.AvroRemoteException; 15 | java.nio.ByteBuffer fetch(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, int updateId) throws org.apache.avro.AvroRemoteException; 16 | void complete(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, com.cloudera.knittingboar.yarn.avro.generated.ProgressReport finalReport); 17 | void error(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, java.lang.CharSequence message); 18 | void metricsReport(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, java.util.Map metrics); 19 | 20 | @SuppressWarnings("all") 21 | public interface Callback extends KnittingBoarService { 22 | public static final org.apache.avro.Protocol PROTOCOL = com.cloudera.knittingboar.yarn.avro.generated.KnittingBoarService.PROTOCOL; 23 | void startup(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, org.apache.avro.ipc.Callback callback) throws java.io.IOException; 24 | void progress(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, com.cloudera.knittingboar.yarn.avro.generated.ProgressReport report, org.apache.avro.ipc.Callback callback) throws java.io.IOException; 25 | void update(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, java.nio.ByteBuffer data, org.apache.avro.ipc.Callback callback) throws java.io.IOException; 26 | void waiting(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, int lastUpdate, long waiting, org.apache.avro.ipc.Callback callback) throws java.io.IOException; 27 | void fetch(com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId, int updateId, org.apache.avro.ipc.Callback callback) throws java.io.IOException; 28 | } 29 | } -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/io/TestSplitReset.java: -------------------------------------------------------------------------------- 1 | package com.cloudera.knittingboar.io; 2 | 3 | import java.io.IOException; 4 | import java.io.OutputStreamWriter; 5 | import java.io.Writer; 6 | import java.util.ArrayList; 7 | 8 | import org.apache.hadoop.conf.Configuration; 9 | import org.apache.hadoop.fs.FileSystem; 10 | import org.apache.hadoop.fs.Path; 11 | import org.apache.hadoop.io.Text; 12 | import org.apache.hadoop.mapred.FileInputFormat; 13 | import org.apache.hadoop.mapred.InputSplit; 14 | import org.apache.hadoop.mapred.JobConf; 15 | import org.apache.hadoop.mapred.TextInputFormat; 16 | 17 | import com.cloudera.iterativereduce.io.TextRecordParser; 18 | import 
com.cloudera.knittingboar.messages.iterativereduce.ParameterVectorUpdatable; 19 | import com.cloudera.knittingboar.sgd.iterativereduce.POLRWorkerNode; 20 | 21 | import junit.framework.TestCase; 22 | 23 | 24 | public class TestSplitReset extends TestCase { 25 | 26 | 27 | private static JobConf defaultConf = new JobConf(); 28 | private static FileSystem localFs = null; 29 | static { 30 | try { 31 | defaultConf.set("fs.defaultFS", "file:///"); 32 | localFs = FileSystem.getLocal(defaultConf); 33 | } catch (IOException e) { 34 | throw new RuntimeException("init failure", e); 35 | } 36 | } 37 | 38 | 39 | private static Path workDir = new Path(System.getProperty("test.build.data", "/tmp/TestSplitReset/")); 40 | 41 | public Configuration generateDebugConfigurationObject() { 42 | 43 | Configuration c = new Configuration(); 44 | 45 | // feature vector size 46 | c.setInt( "com.cloudera.knittingboar.setup.FeatureVectorSize", 10000 ); 47 | 48 | c.setInt( "com.cloudera.knittingboar.setup.numCategories", 20); 49 | 50 | c.setInt("com.cloudera.knittingboar.setup.NumberPasses", 2); 51 | 52 | 53 | // setup 20newsgroups 54 | c.set( "com.cloudera.knittingboar.setup.RecordFactoryClassname", "com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory"); 55 | 56 | return c; 57 | 58 | } 59 | 60 | public InputSplit[] generateDebugSplits( Path input_path, JobConf job ) { 61 | 62 | long block_size = localFs.getDefaultBlockSize(); 63 | // localFs. 64 | 65 | System.out.println("default block size: " + (block_size / 1024 / 1024) + "MB"); 66 | 67 | 68 | // ---- set where we'll read the input files from ------------- 69 | FileInputFormat.setInputPaths(job, input_path); 70 | 71 | 72 | // try splitting the file in a variety of sizes 73 | TextInputFormat format = new TextInputFormat(); 74 | format.configure(job); 75 | 76 | int numSplits = 1; 77 | 78 | InputSplit[] splits = null; 79 | 80 | try { 81 | splits = format.getSplits(job, numSplits); 82 | } catch (IOException e) { 83 | // TODO Auto-generated catch block 84 | e.printStackTrace(); 85 | } 86 | 87 | 88 | return splits; 89 | 90 | 91 | } 92 | 93 | public void testReset() throws IOException { 94 | 95 | //TextRecordParser lineParser = null; 96 | 97 | // ---- this all needs to be done in 98 | JobConf job = new JobConf(defaultConf); 99 | 100 | 101 | Path file = new Path(workDir, "testGetSplits.txt"); 102 | 103 | int tmp_file_size = 200000; 104 | 105 | long block_size = localFs.getDefaultBlockSize(); 106 | 107 | System.out.println("default block size: " + (block_size / 1024 / 1024) + "MB"); 108 | 109 | Writer writer = new OutputStreamWriter(localFs.create(file)); 110 | try { 111 | for (int i = 0; i < tmp_file_size; i++) { 112 | writer.write("a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, 1, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, 99"); 113 | writer.write("\n"); 114 | } 115 | } finally { 116 | writer.close(); 117 | } 118 | 119 | System.out.println( "file write complete" ); 120 | 121 | 122 | 123 | // TODO: work on this, splits are generating for everything in dir 124 | InputSplit[] splits = generateDebugSplits(workDir, job); 125 | 126 | System.out.println( "split count: " + splits.length ); 127 | 128 | TextRecordParser txt_reader = new TextRecordParser(); 129 | 130 | long len = Integer.parseInt( splits[0].toString().split(":")[2].split("\\+")[1] ); 131 | 
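// Note: FileSplit.toString() renders as "&lt;path&gt;:&lt;offset&gt;+&lt;length&gt;" -- e.g.
// something like "file:/tmp/TestSplitReset/testGetSplits.txt:0+63600000" --
// so the split(":")[2].split("\\+")[1] above recovers the split's byte length,
// and split(":")[1] below recovers its path. Brittle if the path ever
// contains a ':' of its own, but good enough for this unit test.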
132 | txt_reader.setFile(splits[0].toString().split(":")[1], 0, len); 133 | 134 | Text csv_line = new Text(); 135 | int x = 0; 136 | while (txt_reader.hasMoreRecords()) { 137 | 138 | 139 | txt_reader.next(csv_line); 140 | x++; 141 | } 142 | 143 | System.out.println( "read recs: " + x ); 144 | 145 | txt_reader.reset(); 146 | //txt_reader.setFile(splits[0].toString().split(":")[1], 0, len); 147 | x = 0; 148 | while (txt_reader.hasMoreRecords()) { 149 | 150 | 151 | txt_reader.next(csv_line); 152 | x++; 153 | } 154 | 155 | System.out.println( "[after reset] read recs: " + x ); 156 | 157 | } 158 | 159 | } 160 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/metrics/TestRCV1ApplyModel.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.cloudera.knittingboar.metrics; 19 | 20 | import java.io.IOException; 21 | import java.util.ArrayList; 22 | 23 | import org.apache.hadoop.conf.Configuration; 24 | import org.apache.hadoop.fs.FileSystem; 25 | import org.apache.hadoop.fs.Path; 26 | import org.apache.hadoop.mapred.FileInputFormat; 27 | import org.apache.hadoop.mapred.InputSplit; 28 | import org.apache.hadoop.mapred.JobConf; 29 | import org.apache.hadoop.mapred.TextInputFormat; 30 | 31 | import com.cloudera.knittingboar.io.InputRecordsSplit; 32 | import com.cloudera.knittingboar.records.RecordFactory; 33 | //import com.cloudera.knittingboar.sgd.POLRWorkerDriver; 34 | 35 | import junit.framework.TestCase; 36 | 37 | /** 38 | * This class will test applying the model against the RCV1 dataset [TODO] 39 | * 40 | * @author jpatterson 41 | * 42 | */ 43 | public class TestRCV1ApplyModel extends TestCase { 44 | /* 45 | private static JobConf defaultConf = new JobConf(); 46 | private static FileSystem localFs = null; 47 | static { 48 | try { 49 | defaultConf.set("fs.defaultFS", "file:///"); 50 | localFs = FileSystem.getLocal(defaultConf); 51 | } catch (IOException e) { 52 | throw new RuntimeException("init failure", e); 53 | } 54 | } 55 | 56 | private static Path inputDir = new Path(System.getProperty("test.build.data", "/Users/jpatterson/Downloads/rcv1/subset/train/")); 57 | 58 | private static Path fullRCV1Dir = new Path("/Users/jpatterson/Downloads/rcv1/rcv1.train.vw"); 59 | 60 | private static Path fullRCV1TestFile = new Path("/Users/jpatterson/Downloads/rcv1/rcv1.test.vw"); 61 | 62 | public Configuration generateDebugConfigurationObject() { 63 | 64 | Configuration c = new Configuration(); 65 | 66 | // feature vector size 67 | c.setInt( "com.cloudera.knittingboar.setup.FeatureVectorSize", 10000 ); 68 | 69 | c.setInt( 
"com.cloudera.knittingboar.setup.numCategories", 2); 70 | 71 | c.setInt("com.cloudera.knittingboar.setup.BatchSize", 200); 72 | 73 | // local input split path 74 | c.set( "com.cloudera.knittingboar.setup.LocalInputSplitPath", "hdfs://127.0.0.1/input/0" ); 75 | 76 | // setup 20newsgroups 77 | c.set( "com.cloudera.knittingboar.setup.RecordFactoryClassname", RecordFactory.RCV1_RECORDFACTORY); 78 | 79 | return c; 80 | 81 | } 82 | 83 | public InputSplit[] generateDebugSplits( Path input_path, JobConf job ) { 84 | 85 | long block_size = localFs.getDefaultBlockSize(); 86 | 87 | System.out.println("default block size: " + (block_size / 1024 / 1024) + "MB"); 88 | 89 | FileInputFormat.setInputPaths(job, input_path); 90 | 91 | 92 | // try splitting the file in a variety of sizes 93 | TextInputFormat format = new TextInputFormat(); 94 | format.configure(job); 95 | 96 | int numSplits = 1; 97 | 98 | InputSplit[] splits = null; 99 | 100 | try { 101 | splits = format.getSplits(job, numSplits); 102 | } catch (IOException e) { 103 | // TODO Auto-generated catch block 104 | e.printStackTrace(); 105 | } 106 | 107 | 108 | return splits; 109 | 110 | 111 | } 112 | 113 | public void testApplyModel() throws Exception { 114 | 115 | POLRModelTester tester = new POLRModelTester(); 116 | // ------------------ 117 | // generate the debug conf ---- normally setup by YARN stuff 118 | tester.setConf(this.generateDebugConfigurationObject()); 119 | // now load the conf stuff into locally used vars 120 | try { 121 | tester.LoadConfigVarsLocally(); 122 | } catch (Exception e) { 123 | // TODO Auto-generated catch block 124 | e.printStackTrace(); 125 | System.out.println( "Conf load fail: shutting down." ); 126 | assertEquals( 0, 1 ); 127 | } 128 | // now construct any needed machine learning data structures based on config 129 | tester.Setup(); 130 | tester.Load("/tmp/master_sgd.model"); 131 | // ------------------ 132 | 133 | 134 | // ---- this all needs to be done in 135 | JobConf job = new JobConf(defaultConf); 136 | 137 | InputSplit[] splits = generateDebugSplits(fullRCV1TestFile, job); 138 | 139 | System.out.println( "split count: " + splits.length ); 140 | 141 | InputRecordsSplit custom_reader_0 = new InputRecordsSplit(job, splits[0]); 142 | // TODO: set this up to run through the conf pathways 143 | tester.setupInputSplit(custom_reader_0); 144 | tester.RunThroughTestRecords(); 145 | 146 | } 147 | */ 148 | } 149 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/io/TestSplitCalcs.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.cloudera.knittingboar.io; 19 | 20 | import java.io.IOException; 21 | import java.io.OutputStreamWriter; 22 | import java.io.Writer; 23 | import java.util.BitSet; 24 | import java.util.Random; 25 | 26 | import org.apache.commons.logging.Log; 27 | import org.apache.commons.logging.LogFactory; 28 | import org.apache.hadoop.fs.FileSystem; 29 | import org.apache.hadoop.fs.Path; 30 | import org.apache.hadoop.io.LongWritable; 31 | import org.apache.hadoop.io.Text; 32 | import org.apache.hadoop.mapred.FileInputFormat; 33 | import org.apache.hadoop.mapred.InputSplit; 34 | import org.apache.hadoop.mapred.JobConf; 35 | import org.apache.hadoop.mapred.RecordReader; 36 | import org.apache.hadoop.mapred.Reporter; 37 | import org.apache.hadoop.mapred.TextInputFormat; 38 | 39 | import junit.framework.TestCase; 40 | 41 | public class TestSplitCalcs extends TestCase { 42 | 43 | private static final Log LOG = LogFactory.getLog(TestSplitCalcs.class.getName()); 44 | 45 | private static int MAX_LENGTH = 1000; 46 | 47 | private static JobConf defaultConf = new JobConf(); 48 | private static FileSystem localFs = null; 49 | static { 50 | try { 51 | defaultConf.set("fs.defaultFS", "file:///"); 52 | localFs = FileSystem.getLocal(defaultConf); 53 | } catch (IOException e) { 54 | throw new RuntimeException("init failure", e); 55 | } 56 | } 57 | 58 | private static Path workDir = new Path(new Path(System.getProperty("test.build.data", "/tmp")), "TestSplitCalcs").makeQualified(localFs); 59 | 60 | 61 | 62 | /** 63 | * 64 | * - use the TextInputFormat.getSplits() to test pulling split info 65 | * @throws IOException 66 | * 67 | */ 68 | public void testGetSplits() throws IOException { 69 | 70 | TextInputFormat input = new TextInputFormat(); 71 | 72 | JobConf job = new JobConf(defaultConf); 73 | Path file = new Path(workDir, "testGetSplits.txt"); 74 | 75 | int tmp_file_size = 200000; 76 | 77 | long block_size = localFs.getDefaultBlockSize(); 78 | 79 | System.out.println("default block size: " + (block_size / 1024 / 1024) + "MB"); 80 | 81 | Writer writer = new OutputStreamWriter(localFs.create(file)); 82 | try { 83 | for (int i = 0; i < tmp_file_size; i++) { 84 | writer.write("a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, 1, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, 99"); 85 | writer.write("\n"); 86 | } 87 | } finally { 88 | writer.close(); 89 | } 90 | 91 | System.out.println( "file write complete" ); 92 | 93 | 94 | // A reporter that does nothing 95 | Reporter reporter = Reporter.NULL; 96 | 97 | // localFs.delete(workDir, true); 98 | FileInputFormat.setInputPaths(job, file); 99 | 100 | 101 | // try splitting the file in a variety of sizes 102 | TextInputFormat format = new TextInputFormat(); 103 | format.configure(job); 104 | LongWritable key = new LongWritable(); 105 | Text value = new Text(); 106 | 107 | int numSplits = 1; 108 | 109 | InputSplit[] splits = format.getSplits(job, numSplits); 110 | 111 | LOG.info("requested " + numSplits + " splits, splitting: got = " + splits.length); 112 | 113 | assertEquals( 2, splits.length ); 114 | 115 | System.out.println( "---- debug splits --------- " ); 116 | 117 | for (int x = 0; x < splits.length; x++) { 118 | 119 | System.out.println( "> Split [" + x + "]: " + splits[x].getLength() + ", " + splits[x].toString() + ", " + 
splits[x].getLocations()[0] ); 120 | 121 | 122 | RecordReader reader = format.getRecordReader(splits[x], job, reporter); 123 | try { 124 | int count = 0; 125 | while (reader.next(key, value)) { 126 | 127 | if (count == 0) { 128 | System.out.println( "first: " + value.toString() ); 129 | assertTrue( value.toString().contains("a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p") ); 130 | } 131 | 132 | count++; 133 | } 134 | 135 | System.out.println( "last: " + value.toString() ); 136 | 137 | assertTrue( value.toString().contains("a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p") ); 138 | 139 | } finally { 140 | reader.close(); 141 | } 142 | 143 | } // for each split 144 | 145 | 146 | } 147 | 148 | 149 | } 150 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/metrics/Test20NewsApplyModel.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
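For context on the split-count assertion above: in the old mapred API the number of splits follows from FileInputFormat's sizing rule, paraphrased here (stock Hadoop logic, not project code):

// numSplits is only a hint:
long goalSize  = totalSize / Math.max(1, numSplits);
long splitSize = Math.max(minSize, Math.min(goalSize, blockSize));
// a tail shorter than 1.1 * splitSize (SPLIT_SLOP) folds into the last split

Assuming the stock 32 MB local block size (the test prints the actual value), the roughly 64 MB of generated rows yield one 32 MB split plus a ~31 MB tail split within the 1.1x slop, i.e. the two splits asserted.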
16 | */ 17 | 18 | package com.cloudera.knittingboar.metrics; 19 | 20 | import java.io.File; 21 | import java.io.FileNotFoundException; 22 | import java.io.IOException; 23 | import java.net.MalformedURLException; 24 | import java.net.URL; 25 | 26 | import org.apache.commons.compress.archivers.ArchiveException; 27 | import org.apache.hadoop.conf.Configuration; 28 | import org.apache.hadoop.fs.FileSystem; 29 | import org.apache.hadoop.fs.Path; 30 | import org.apache.hadoop.mapred.FileInputFormat; 31 | import org.apache.hadoop.mapred.InputSplit; 32 | import org.apache.hadoop.mapred.JobConf; 33 | import org.apache.hadoop.mapred.TextInputFormat; 34 | 35 | import com.cloudera.knittingboar.io.InputRecordsSplit; 36 | import com.cloudera.knittingboar.records.RecordFactory; 37 | import com.cloudera.knittingboar.utils.DataUtils; 38 | import com.cloudera.knittingboar.utils.DatasetConverter; 39 | import com.cloudera.knittingboar.utils.Utils; 40 | 41 | import junit.framework.TestCase; 42 | 43 | /** 44 | * Test applying a model to the test data in the 20newsgroups dataset 45 | * 46 | * @author jpatterson 47 | * 48 | */ 49 | public class Test20NewsApplyModel extends TestCase { 50 | 51 | private static JobConf defaultConf = new JobConf(); 52 | private static FileSystem localFs = null; 53 | static { 54 | try { 55 | defaultConf.set("fs.defaultFS", "file:///"); 56 | localFs = FileSystem.getLocal(defaultConf); 57 | } catch (IOException e) { 58 | throw new RuntimeException("init failure", e); 59 | } 60 | } 61 | 62 | private static Path workDir20NewsLocal = new Path(new Path("/tmp"), "Dataset20Newsgroups"); 63 | private static File unzipDir = new File( workDir20NewsLocal + "/20news-bydate"); 64 | private static String strKBoarTestDirInput = "" + unzipDir.toString() + "/KBoar-test/"; 65 | 66 | 67 | 68 | public Configuration generateDebugConfigurationObject() { 69 | 70 | Configuration c = new Configuration(); 71 | 72 | // feature vector size 73 | c.setInt( "com.cloudera.knittingboar.setup.FeatureVectorSize", 10000 ); 74 | 75 | c.setInt( "com.cloudera.knittingboar.setup.numCategories", 20); 76 | 77 | // c.setInt("com.cloudera.knittingboar.setup.BatchSize", 500); 78 | 79 | // setup 20newsgroups 80 | c.set( "com.cloudera.knittingboar.setup.RecordFactoryClassname", RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY); 81 | 82 | return c; 83 | 84 | } 85 | 86 | public InputSplit[] generateDebugSplits( Path input_path, JobConf job ) { 87 | 88 | long block_size = localFs.getDefaultBlockSize(); 89 | 90 | System.out.println("default block size: " + (block_size / 1024 / 1024) + "MB"); 91 | 92 | // ---- set where we'll read the input files from ------------- 93 | FileInputFormat.setInputPaths(job, input_path); 94 | 95 | // try splitting the file in a variety of sizes 96 | TextInputFormat format = new TextInputFormat(); 97 | format.configure(job); 98 | 99 | int numSplits = 1; 100 | 101 | InputSplit[] splits = null; 102 | 103 | try { 104 | splits = format.getSplits(job, numSplits); 105 | } catch (IOException e) { 106 | // TODO Auto-generated catch block 107 | e.printStackTrace(); 108 | } 109 | 110 | return splits; 111 | 112 | } 113 | 114 | 115 | public void testLoad20NewsModel() throws Exception { 116 | 117 | 118 | File file20News = DataUtils.getTwentyNewsGroupDir(); 119 | 120 | DatasetConverter.ConvertNewsgroupsFromSingleFiles( DataUtils.get20NewsgroupsLocalDataLocation() + "/20news-bydate-test/", strKBoarTestDirInput, 12000); 121 | 122 | POLRModelTester tester = new POLRModelTester(); 123 | 124 | // ------------------ 125 | // 
generate the debug conf ---- normally setup by YARN stuff 126 | tester.setConf(this.generateDebugConfigurationObject()); 127 | // now load the conf stuff into locally used vars 128 | try { 129 | tester.LoadConfigVarsLocally(); 130 | } catch (Exception e) { 131 | // TODO Auto-generated catch block 132 | e.printStackTrace(); 133 | System.out.println( "Conf load fail: shutting down." ); 134 | assertEquals( 0, 1 ); 135 | } 136 | // now construct any needed machine learning data structures based on config 137 | tester.Setup(); 138 | tester.Load( "src/test/resources/KBoar_Sample.model" ); 139 | 140 | // ------------------ 141 | 142 | Path testData20News = new Path(strKBoarTestDirInput); 143 | 144 | // ---- this all needs to be done in 145 | JobConf job = new JobConf(defaultConf); 146 | 147 | InputSplit[] splits = generateDebugSplits(testData20News, job); 148 | 149 | System.out.println( "split count: " + splits.length ); 150 | 151 | InputRecordsSplit custom_reader_0 = new InputRecordsSplit(job, splits[0]); 152 | tester.setupInputSplit(custom_reader_0); 153 | 154 | tester.RunThroughTestRecords(); 155 | 156 | } 157 | 158 | 159 | } 160 | -------------------------------------------------------------------------------- /src/main/java/com/cloudera/knittingboar/messages/iterativereduce/ParameterVector.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.cloudera.knittingboar.messages.iterativereduce; 19 | 20 | import java.io.ByteArrayInputStream; 21 | import java.io.ByteArrayOutputStream; 22 | import java.io.DataInput; 23 | import java.io.DataInputStream; 24 | import java.io.DataOutput; 25 | import java.io.DataOutputStream; 26 | import java.io.FileOutputStream; 27 | import java.io.IOException; 28 | import java.io.OutputStream; 29 | 30 | import org.apache.mahout.math.Matrix; 31 | import org.apache.mahout.math.MatrixWritable; 32 | 33 | 34 | 35 | public class ParameterVector { 36 | 37 | // worker stuff to send out 38 | public int SrcWorkerPassCount = 0; 39 | 40 | public Matrix parameter_vector = null; 41 | public int GlobalPassCount = 0; // what pass should the worker dealing with? 
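// Together these fields form the worker-to-master update message: the worker
// ships its entire parameter matrix along with pass/iteration counters and
// training stats. On the master side the matrices from all workers can be
// merged with the AccumulateParameterVector()/AverageParameterVectors()
// methods defined further down, e.g. (illustrative sketch only; "updates" is
// a hypothetical list of received worker messages):
//
//   ParameterVector merged = updates.get(0);
//   for (int i = 1; i < updates.size(); i++) {
//     merged.AccumulateParameterVector(updates.get(i).parameter_vector);
//   }
//   merged.AverageParameterVectors(updates.size());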
42 | 43 | public int IterationComplete = 0; // 0 = no, 1 = yes 44 | public int CurrentIteration = 0; 45 | 46 | public int TrainedRecords = 0; 47 | public float AvgLogLikelihood = 0; 48 | public float PercentCorrect = 0; 49 | 50 | public byte[] Serialize() throws IOException { 51 | 52 | // DataOutput d 53 | 54 | ByteArrayOutputStream out = new ByteArrayOutputStream(); 55 | DataOutput d = new DataOutputStream(out); 56 | 57 | // d.writeUTF(src_host); 58 | d.writeInt(this.SrcWorkerPassCount); 59 | d.writeInt(this.GlobalPassCount); 60 | 61 | d.writeInt(this.IterationComplete); 62 | d.writeInt(this.CurrentIteration); 63 | 64 | d.writeInt(this.TrainedRecords); 65 | d.writeFloat(this.AvgLogLikelihood); 66 | d.writeFloat(this.PercentCorrect); 67 | // buf.write 68 | // MatrixWritable.writeMatrix(d, this.worker_gradient.getMatrix()); 69 | MatrixWritable.writeMatrix(d, this.parameter_vector); 70 | // MatrixWritable. 71 | 72 | return out.toByteArray(); 73 | } 74 | 75 | public void Deserialize(byte[] bytes) throws IOException { 76 | // DataInput in) throws IOException { 77 | 78 | ByteArrayInputStream b = new ByteArrayInputStream(bytes); 79 | DataInput in = new DataInputStream(b); 80 | // this.src_host = in.readUTF(); 81 | this.SrcWorkerPassCount = in.readInt(); 82 | this.GlobalPassCount = in.readInt(); 83 | 84 | this.IterationComplete = in.readInt(); 85 | this.CurrentIteration = in.readInt(); 86 | 87 | this.TrainedRecords = in.readInt(); // d.writeInt(this.TrainedRecords); 88 | this.AvgLogLikelihood = in.readFloat(); // d.writeFloat(this.AvgLogLikelihood); 89 | this.PercentCorrect = in.readFloat(); // d.writeFloat(this.PercentCorrect); 90 | 91 | this.parameter_vector = MatrixWritable.readMatrix(in); 92 | 93 | } 94 | 95 | public int numFeatures() { 96 | return this.parameter_vector.numCols(); 97 | } 98 | 99 | public int numCategories() { 100 | return this.parameter_vector.numRows(); 101 | } 102 | 103 | 104 | 105 | 106 | 107 | 108 | /** 109 | * TODO: fix loop 110 | * 111 | * @param other_gamma 112 | */ 113 | public void AccumulateParameterVector(Matrix other_gamma) { 114 | 115 | // this.gamma.plus(arg0) 116 | 117 | for (int row = 0; row < this.parameter_vector.rowSize(); row++) { 118 | 119 | for (int col = 0; col < this.parameter_vector.columnSize(); col++) { 120 | 121 | double old_this_val = this.parameter_vector.get(row, col); 122 | double other_val = other_gamma.get(row, col); 123 | 124 | // System.out.println( "Accumulate: " + old_this_val + ", " + other_val 125 | // ); 126 | 127 | this.parameter_vector.set(row, col, old_this_val + other_val); 128 | 129 | // System.out.println( "new value: " + this.gamma.get(row, col) ); 130 | 131 | } 132 | 133 | } 134 | 135 | // this.AccumulatedGradientsCount++; 136 | 137 | } 138 | /* 139 | public void Accumulate(GradientBuffer other_gamma) { 140 | 141 | for (int row = 0; row < this.gamma.rowSize(); row++) { 142 | 143 | for (int col = 0; col < this.gamma.columnSize(); col++) { 144 | 145 | double old_this_val = this.gamma.get(row, col); 146 | double other_val = other_gamma.getCell(row, col); 147 | 148 | this.gamma.set(row, col, old_this_val + other_val); 149 | 150 | } 151 | 152 | } 153 | 154 | this.AccumulatedGradientsCount++; 155 | 156 | } 157 | */ 158 | 159 | /** 160 | * TODO: Need to take a look at built in matrix ops here 161 | * 162 | */ 163 | public void AverageParameterVectors(int denominator) { 164 | 165 | for (int row = 0; row < this.parameter_vector.rowSize(); row++) { 166 | 167 | for (int col = 0; col < this.parameter_vector.columnSize(); col++) { 168 | 169 
| double old_this_val = this.parameter_vector.get(row, col); 170 | // double other_val = other_gamma.getCell(row, col); 171 | this.parameter_vector.set(row, col, old_this_val / denominator); 172 | 173 | } 174 | 175 | } 176 | 177 | } 178 | 179 | 180 | 181 | } 182 | -------------------------------------------------------------------------------- /src/test/resources/donut_no_header.csv: -------------------------------------------------------------------------------- 1 | 0.923307513352484,0.0135197141207755,21,2,4,8,0.852496764213146,0.0124828536260896,0.000182782669907495,0.923406490600458,0.0778750292332978,0.644866125183976,1 2 | 0.711011884035543,0.909141522599384,22,2,3,9,0.505537899239772,0.64641042683833,0.826538308114327,1.15415605849213,0.953966686673604,0.46035073663368,1 3 | 0.75118898646906,0.836567111080512,23,2,3,9,0.564284893392414,0.62842000028592,0.699844531341594,1.12433510339845,0.872783737128441,0.419968245447719,1 4 | 0.308209649519995,0.418023289414123,24,1,5,1,0.094993188057238,0.128838811521522,0.174743470492603,0.519361780024138,0.808280495564412,0.208575453051705,1 5 | 0.849057961953804,0.500220163026825,25,1,5,2,0.720899422757147,0.424715912147755,0.250220211498583,0.985454024425153,0.52249756970547,0.349058031386046,1 6 | 0.0738831346388906,0.486534863477573,21,2,6,1,0.00545871758406844,0.0359467208248278,0.236716173379140,0.492112681164801,1.04613986717142,0.42632955896436,1 7 | 0.612888508243486,0.0204555552918464,22,2,4,10,0.375632323536926,0.0125369747681119,0.000418429742297785,0.613229772009826,0.387651566219268,0.492652707029903,1 8 | 0.207169560948387,0.932857288978994,23,2,1,4,0.0429192269835473,0.193259634985281,0.870222721601238,0.955584610897845,1.22425602987611,0.522604151014326,1 9 | 0.309267645236105,0.506309477845207,24,1,5,1,0.0956464763898851,0.156585139973909,0.256349287355886,0.593292308854389,0.856423069092351,0.190836685845410,1 10 | 0.78758287569508,0.171928803203627,25,2,4,10,0.620286786088131,0.135408181241926,0.0295595133710317,0.806130448165285,0.273277419610556,0.436273561610666,1 11 | 0.930236018029973,0.0790199618786573,21,2,4,8,0.86533904924026,0.0735072146828825,0.00624415437530446,0.93358620577618,0.105409523078414,0.601936228937031,1 12 | 0.238834470743313,0.623727766098455,22,1,5,1,0.0570419044152386,0.148967690904034,0.389036326202168,0.667890882268509,0.984077887735915,0.288991338582386,1 13 | 0.83537525916472,0.802311758277938,23,2,3,7,0.697851823624524,0.670231393002335,0.643704157471036,1.15825557675997,0.819027144096042,0.451518508649315,1 14 | 0.656760312616825,0.320640653371811,24,1,5,3,0.43133410822855,0.210584055746134,0.102810428594702,0.730851925374252,0.469706197095164,0.238209090579297,1 15 | 0.180789119331166,0.114329558331519,25,2,2,5,0.0326847056685386,0.0206695401642766,0.0130712479082803,0.213906413126907,0.82715035810576,0.500636870310341,1 16 | 0.990028728265315,0.061085847672075,21,2,4,8,0.980156882790638,0.0604767440857932,0.00373148078581595,0.991911469626425,0.06189432159595,0.657855445853466,1 17 | 0.751934139290825,0.972332585137337,22,2,3,9,0.565404949831033,0.731130065509666,0.945430656119858,1.22916052895905,1.00347761677540,0.535321288127727,1 18 | 0.136412925552577,0.552212274167687,23,2,6,1,0.0186084862578129,0.0753288918452558,0.304938395741448,0.5688118159807,1.02504684326820,0.3673168690368,1 19 | 0.5729476721026,0.0981996888294816,24,2,4,10,0.328269034967789,0.0562632831160512,0.0096431788862070,0.581302170866406,0.43819729534628,0.408368525870829,1 20 | 
0.446335297077894,0.339370004367083,25,1,5,3,0.199215197417612,0.151472811718508,0.115171999864114,0.560702414192882,0.649397107420365,0.169357302283512,1 21 | 0.922843366628513,0.912627586396411,21,2,3,7,0.851639879330248,0.842212314308118,0.832889111451739,1.29789405992245,0.915883320912091,0.590811338548155,1 22 | 0.166969822719693,0.398156099021435,22,2,6,1,0.0278789216990458,0.0664800532683736,0.158528279187967,0.431749002184154,0.923291695753637,0.348254618269284,1 23 | 0.350683249300346,0.84422400011681,23,2,1,6,0.122978741339848,0.296055215498298,0.712714162373228,0.914162405545687,1.06504760696993,0.375214144584023,1 24 | 0.47748578293249,0.792779305484146,24,1,5,6,0.227992672902653,0.378540847371773,0.628499027203925,0.9254683679665,0.949484141121692,0.29364368150863,1 25 | 0.384564548265189,0.153326370986179,25,2,2,5,0.147889891782409,0.0589638865954405,0.0235089760397912,0.414003463538894,0.634247405427742,0.365387395199715,1 26 | 0.563622857443988,0.467359990812838,21,1,5,3,0.317670725433326,0.263414773476928,0.218425361012576,0.73218582781006,0.639414084578942,0.071506910079209,1 27 | 0.343304847599939,0.854578266385943,22,2,1,6,0.117858218385617,0.293380861503846,0.730304013379203,0.920957236664559,1.07775346743350,0.387658506651072,1 28 | 0.666085948701948,0.710089378990233,23,1,5,2,0.443670491058174,0.472980557667886,0.504226926154735,0.973600234805286,0.784681795257806,0.267809801016930,1 29 | 0.190568120684475,0.0772022884339094,24,2,2,5,0.0363162086212125,0.0147122950193909,0.00596019333943254,0.205612261211838,0.813105258002736,0.523933195018469,1 30 | 0.353534662164748,0.427994541125372,25,1,5,1,0.124986757351942,0.151310905505115,0.183179327233118,0.555127088678854,0.775304301713569,0.163208092002022,1 31 | 0.127048352966085,0.927507144864649,21,2,1,4,0.0161412839913949,0.117838255119330,0.860269503774972,0.936168140755905,1.27370093893119,0.567322915045421,1 32 | 0.960906301159412,0.891004979610443,22,2,3,7,0.923340919607862,0.856172299272088,0.793889873690606,1.31043152942016,0.891862204031343,0.604416671286136,1 33 | 0.306814440060407,0.902291874401271,23,2,1,6,0.094135100629581,0.276836176215481,0.81413062661056,0.953029761990747,1.13782109627099,0.446272800849954,1 34 | 0.087350245565176,0.671402548439801,24,2,6,4,0.00763006540029655,0.0586471774793016,0.450781382051459,0.677060889028273,1.13300968942079,0.446831795474291,1 35 | 0.27015240653418,0.371201378758997,25,1,5,1,0.0729823227562089,0.100280945780549,0.137790463592580,0.459099974241765,0.81882108746687,0.263474858488646,1 36 | 0.871842501685023,0.569787061074749,21,2,3,2,0.7601093477444,0.496764576755166,0.324657294968199,1.04152131169391,0.584021951079369,0.378334613738721,1 37 | 0.686449621338397,0.169308491749689,22,2,4,10,0.471213082635629,0.116221750050949,0.0286653653785545,0.707020825728764,0.356341416814533,0.379631841296403,1 38 | 0.67132937326096,0.571220482233912,23,1,5,2,0.450683127402953,0.383477088331915,0.326292839323543,0.881462402332905,0.659027480614106,0.185542747720368,1 39 | 0.548616112209857,0.405350996181369,24,1,5,3,0.300979638576258,0.222382087605415,0.164309430105228,0.682121007359754,0.606676886210257,0.106404700508298,1 40 | 0.677980388281867,0.993355110753328,25,2,3,9,0.459657406894831,0.673475283690318,0.986754376059756,1.20266860895036,1.04424662144096,0.524477152905055,1 41 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/sgd/TestParallelOnlineLogisticRegression.java: 
-------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.cloudera.knittingboar.sgd; 19 | 20 | import java.util.ArrayList; 21 | 22 | import junit.framework.TestCase; 23 | 24 | import org.apache.mahout.classifier.sgd.L1; 25 | import org.apache.mahout.math.RandomAccessSparseVector; 26 | import org.apache.mahout.math.Vector; 27 | 28 | import com.cloudera.knittingboar.utils.Utils; 29 | 30 | /** 31 | * Mostly temporary tests used to debug components as we developed the system 32 | * 33 | * @author jpatterson 34 | * 35 | */ 36 | public class TestParallelOnlineLogisticRegression extends TestCase { 37 | 38 | 39 | public void testCreateLR() { 40 | 41 | int categories = 2; 42 | int numFeatures = 5; 43 | double lambda = 1.0e-4; 44 | double learning_rate = 50; 45 | 46 | 47 | ParallelOnlineLogisticRegression plr = new ParallelOnlineLogisticRegression( categories, numFeatures, new L1()) 48 | .lambda(lambda) 49 | .learningRate(learning_rate) 50 | .alpha(1 - 1.0e-3); 51 | 52 | assertEquals( plr.getLambda(), 1.0e-4 ); 53 | } 54 | 55 | public void testTrainMechanics() { 56 | 57 | 58 | int categories = 2; 59 | int numFeatures = 5; 60 | double lambda = 1.0e-4; 61 | double learning_rate = 10; 62 | 63 | ParallelOnlineLogisticRegression plr = new ParallelOnlineLogisticRegression( categories, numFeatures, new L1()) 64 | .lambda(lambda) 65 | .learningRate(learning_rate) 66 | .alpha(1 - 1.0e-3); 67 | 68 | 69 | Vector input = new RandomAccessSparseVector(numFeatures); 70 | 71 | for ( int x = 0; x < numFeatures; x++ ) { 72 | 73 | input.set(x, x); 74 | 75 | } 76 | 77 | plr.train(0, input); 78 | 79 | plr.train(0, input); 80 | 81 | plr.train(0, input); 82 | 83 | 84 | } 85 | 86 | 87 | public void testPOLRInternalBuffers() { 88 | 89 | System.out.println( "testPOLRInternalBuffers --------------" ); 90 | 91 | int categories = 2; 92 | int numFeatures = 5; 93 | double lambda = 1.0e-4; 94 | double learning_rate = 10; 95 | 96 | ArrayList<Vector> trainingSet_0 = new ArrayList<Vector>(); 97 | 98 | for ( int s = 0; s < 1; s++ ) { 99 | 100 | 101 | Vector input = new RandomAccessSparseVector(numFeatures); 102 | 103 | for ( int x = 0; x < numFeatures; x++ ) { 104 | 105 | input.set(x, x); 106 | 107 | } 108 | 109 | trainingSet_0.add(input); 110 | 111 | } // for 112 | 113 | ParallelOnlineLogisticRegression plr_agent_0 = new ParallelOnlineLogisticRegression( categories, numFeatures, new L1()) 114 | .lambda(lambda) 115 | .learningRate(learning_rate) 116 | .alpha(1 - 1.0e-3); 117 | 118 | System.out.println( "Beta: " ); 119 | //Utils.PrintVectorNonZero(plr_agent_0.getBeta().getRow(0)); 120 | Utils.PrintVectorNonZero(plr_agent_0.getBeta().viewRow(0)); 121 | 122 | 123 | // 
System.out.println( "\nGamma: " ); 124 | //Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().getRow(0)); 125 | // Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0)); 126 | 127 | plr_agent_0.train(0, trainingSet_0.get(0) ); 128 | 129 | System.out.println( "Beta: " ); 130 | //Utils.PrintVectorNonZero(plr_agent_0.noReallyGetBeta().getRow(0)); 131 | Utils.PrintVectorNonZero(plr_agent_0.noReallyGetBeta().viewRow(0)); 132 | 133 | 134 | // System.out.println( "\nGamma: " ); 135 | //Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().getRow(0)); 136 | // Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0)); 137 | 138 | } 139 | 140 | public void testLocalGradientFlush() { 141 | 142 | 143 | System.out.println( "\n\n\ntestLocalGradientFlush --------------" ); 144 | 145 | int categories = 2; 146 | int numFeatures = 5; 147 | double lambda = 1.0e-4; 148 | double learning_rate = 10; 149 | 150 | ArrayList trainingSet_0 = new ArrayList(); 151 | 152 | for ( int s = 0; s < 1; s++ ) { 153 | 154 | 155 | Vector input = new RandomAccessSparseVector(numFeatures); 156 | 157 | for ( int x = 0; x < numFeatures; x++ ) { 158 | 159 | input.set(x, x); 160 | 161 | } 162 | 163 | trainingSet_0.add(input); 164 | 165 | } // for 166 | 167 | ParallelOnlineLogisticRegression plr_agent_0 = new ParallelOnlineLogisticRegression( categories, numFeatures, new L1()) 168 | .lambda(lambda) 169 | .learningRate(learning_rate) 170 | .alpha(1 - 1.0e-3); 171 | 172 | 173 | plr_agent_0.train(0, trainingSet_0.get(0) ); 174 | 175 | // System.out.println( "\nGamma: " ); 176 | // Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0)); 177 | 178 | 179 | // plr_agent_0.FlushGamma(); 180 | /* 181 | System.out.println( "Flushing Gamma ...... " ); 182 | 183 | System.out.println( "\nGamma: " ); 184 | Utils.PrintVector(plr_agent_0.gamma.getMatrix().viewRow(0)); 185 | 186 | for ( int x = 0; x < numFeatures; x++ ) { 187 | 188 | assertEquals( plr_agent_0.gamma.getMatrix().get(0, x), 0.0 ); 189 | 190 | } 191 | */ 192 | } 193 | 194 | 195 | 196 | } 197 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/metrics/TestSaveLoadModel.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.cloudera.knittingboar.metrics; 19 | 20 | import java.io.File; 21 | import java.io.IOException; 22 | import java.util.ArrayList; 23 | 24 | import org.apache.hadoop.conf.Configuration; 25 | import org.apache.hadoop.fs.FSDataOutputStream; 26 | import org.apache.hadoop.fs.FileSystem; 27 | import org.apache.hadoop.fs.Path; 28 | import org.apache.hadoop.mapred.FileInputFormat; 29 | import org.apache.hadoop.mapred.InputSplit; 30 | import org.apache.hadoop.mapred.JobConf; 31 | import org.apache.hadoop.mapred.TextInputFormat; 32 | 33 | import com.cloudera.iterativereduce.irunit.IRUnitDriver; 34 | import com.cloudera.knittingboar.io.InputRecordsSplit; 35 | import com.cloudera.knittingboar.sgd.iterativereduce.POLRMasterNode; 36 | import com.cloudera.knittingboar.utils.DataUtils; 37 | import com.cloudera.knittingboar.utils.DatasetConverter; 38 | 39 | import junit.framework.TestCase; 40 | 41 | /** 42 | * This unit test tests running a 4 worker simulated parallel SGD process, saving the model, 43 | * loading the model, and then checking to see that the parameters of the model deserialized correctly 44 | * 45 | * @author jpatterson 46 | * 47 | */ 48 | public class TestSaveLoadModel extends TestCase { 49 | 50 | private static JobConf defaultConf = new JobConf(); 51 | private static FileSystem localFs = null; 52 | static { 53 | try { 54 | defaultConf.set("fs.defaultFS", "file:///"); 55 | localFs = FileSystem.getLocal(defaultConf); 56 | } catch (IOException e) { 57 | throw new RuntimeException("init failure", e); 58 | } 59 | } 60 | 61 | private static Path workDir20NewsLocal = new Path(new Path("/tmp"), "TestSaveLoadModel"); 62 | private static File unzipDir = new File( workDir20NewsLocal + "/20news-bydate"); 63 | private static String strKBoarTrainDirInput = "" + unzipDir.toString() + "/KBoar-train/"; 64 | //private static String strKBoarTestDirInput = "" + unzipDir.toString() + "/KBoar-test/"; 65 | 66 | 67 | // location of N splits of KBoar converted data --- 68 | private static Path workDir = new Path( strKBoarTrainDirInput ); //DataUtils.get20NewsgroupsLocalDataLocation() + "/20news-bydate-train/" ); 69 | 70 | public Configuration generateDebugConfigurationObject() { 71 | 72 | Configuration c = new Configuration(); 73 | 74 | // feature vector size 75 | c.setInt( "com.cloudera.knittingboar.setup.FeatureVectorSize", 10000 ); 76 | 77 | c.setInt( "com.cloudera.knittingboar.setup.numCategories", 20); 78 | 79 | // setup 20newsgroups 80 | c.set( "com.cloudera.knittingboar.setup.RecordFactoryClassname", "com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory"); 81 | 82 | return c; 83 | 84 | } 85 | 86 | 87 | public void testRunMasterAndFourWorkers() throws Exception { 88 | 89 | DataUtils.getTwentyNewsGroupDir(); 90 | 91 | // convert the training data into 4 shards 92 | DatasetConverter.ConvertNewsgroupsFromSingleFiles( DataUtils.get20NewsgroupsLocalDataLocation() + "/20news-bydate-train/", strKBoarTrainDirInput, 3000); 93 | 94 | // convert the test data into 1 shard 95 | // DatasetConverter.ConvertNewsgroupsFromSingleFiles( DataUtils.get20NewsgroupsLocalDataLocation() + "/20news-bydate-test/", strKBoarTestDirInput, 12000); 96 | 97 | 98 | int num_passes = 15; 99 | 100 | 101 | String[] props = { 102 | "app.iteration.count", 103 | "com.cloudera.knittingboar.setup.FeatureVectorSize", 104 | "com.cloudera.knittingboar.setup.numCategories", 105 | "com.cloudera.knittingboar.setup.RecordFactoryClassname" 106 | }; 107 | 108 | IRUnitDriver polr_driver = new 
IRUnitDriver("src/test/resources/app_unit_test.properties", props ); 109 | 110 | polr_driver.SetProperty("app.input.path", strKBoarTrainDirInput); 111 | 112 | polr_driver.Setup(); 113 | polr_driver.SimulateRun(); 114 | 115 | System.out.println("\n\nComplete..."); 116 | 117 | POLRMasterNode IR_Master = (POLRMasterNode)polr_driver.getMaster(); 118 | 119 | Path out = new Path("/tmp/TestSaveLoadModel.model"); 120 | FileSystem fs = out.getFileSystem(defaultConf); 121 | FSDataOutputStream fos = fs.create(out); 122 | 123 | IR_Master.complete(fos); 124 | 125 | fos.flush(); 126 | fos.close(); 127 | 128 | System.out.println("\n\nModel Saved: /tmp/TestSaveLoadModel.model" ); 129 | 130 | 131 | System.out.println( "\n\n> Loading Model for tests..." ); 132 | 133 | 134 | POLRModelTester tester = new POLRModelTester(); 135 | 136 | 137 | 138 | // ------------------ 139 | // generate the debug conf ---- normally setup by YARN stuff 140 | tester.setConf(this.generateDebugConfigurationObject()); 141 | // now load the conf stuff into locally used vars 142 | try { 143 | tester.LoadConfigVarsLocally(); 144 | } catch (Exception e) { 145 | // TODO Auto-generated catch block 146 | e.printStackTrace(); 147 | System.out.println( "Conf load fail: shutting down." ); 148 | assertEquals( 0, 1 ); 149 | } 150 | // now construct any needed machine learning data structures based on config 151 | tester.Setup(); 152 | tester.Load( "/tmp/TestSaveLoadModel.model" ); 153 | 154 | assertEquals( 1.0e-4, tester.polr.getLambda() ); 155 | 156 | assertEquals( 20, tester.polr.numCategories() ); 157 | 158 | } 159 | 160 | 161 | 162 | } 163 | -------------------------------------------------------------------------------- /src/main/java/com/cloudera/knittingboar/yarn/avro/generated/ServiceError.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Autogenerated by Avro 3 | * 4 | * DO NOT EDIT DIRECTLY 5 | */ 6 | package com.cloudera.knittingboar.yarn.avro.generated; 7 | @SuppressWarnings("all") 8 | public class ServiceError extends org.apache.avro.specific.SpecificExceptionBase implements org.apache.avro.specific.SpecificRecord { 9 | public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"error\",\"name\":\"ServiceError\",\"namespace\":\"com.cloudera.knittingboar.yarn.avro.generated\",\"fields\":[{\"name\":\"description\",\"type\":[\"null\",\"string\"]}]}"); 10 | @Deprecated public java.lang.CharSequence description; 11 | 12 | public ServiceError() { 13 | super(); 14 | } 15 | 16 | public ServiceError(Object value) { 17 | super(value); 18 | } 19 | 20 | public ServiceError(Throwable cause) { 21 | super(cause); 22 | } 23 | 24 | public ServiceError(Object value, Throwable cause) { 25 | super(value, cause); 26 | } 27 | 28 | public org.apache.avro.Schema getSchema() { return SCHEMA$; } 29 | // Used by DatumWriter. Applications should not call. 30 | public java.lang.Object get(int field$) { 31 | switch (field$) { 32 | case 0: return description; 33 | default: throw new org.apache.avro.AvroRuntimeException("Bad index"); 34 | } 35 | } 36 | // Used by DatumReader. Applications should not call. 37 | @SuppressWarnings(value="unchecked") 38 | public void put(int field$, java.lang.Object value$) { 39 | switch (field$) { 40 | case 0: description = (java.lang.CharSequence)value$; break; 41 | default: throw new org.apache.avro.AvroRuntimeException("Bad index"); 42 | } 43 | } 44 | 45 | /** 46 | * Gets the value of the 'description' field. 
47 | */ 48 | public java.lang.CharSequence getDescription() { 49 | return description; 50 | } 51 | 52 | /** 53 | * Sets the value of the 'description' field. 54 | * @param value the value to set. 55 | */ 56 | public void setDescription(java.lang.CharSequence value) { 57 | this.description = value; 58 | } 59 | 60 | /** Creates a new ServiceError RecordBuilder */ 61 | public static com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder newBuilder() { 62 | return new com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder(); 63 | } 64 | 65 | /** Creates a new ServiceError RecordBuilder by copying an existing Builder */ 66 | public static com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder newBuilder(com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder other) { 67 | return new com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder(other); 68 | } 69 | 70 | /** Creates a new ServiceError RecordBuilder by copying an existing ServiceError instance */ 71 | public static com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder newBuilder(com.cloudera.knittingboar.yarn.avro.generated.ServiceError other) { 72 | return new com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder(other); 73 | } 74 | 75 | /** 76 | * RecordBuilder for ServiceError instances. 77 | */ 78 | public static class Builder extends org.apache.avro.specific.SpecificErrorBuilderBase 79 | implements org.apache.avro.data.ErrorBuilder { 80 | 81 | private java.lang.CharSequence description; 82 | 83 | /** Creates a new Builder */ 84 | private Builder() { 85 | super(com.cloudera.knittingboar.yarn.avro.generated.ServiceError.SCHEMA$); 86 | } 87 | 88 | /** Creates a Builder by copying an existing Builder */ 89 | private Builder(com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder other) { 90 | super(other); 91 | } 92 | 93 | /** Creates a Builder by copying an existing ServiceError instance */ 94 | private Builder(com.cloudera.knittingboar.yarn.avro.generated.ServiceError other) { 95 | super(other); 96 | if (isValidValue(fields()[0], other.description)) { 97 | this.description = (java.lang.CharSequence) data().deepCopy(fields()[0].schema(), other.description); 98 | fieldSetFlags()[0] = true; 99 | } 100 | } 101 | 102 | @Override 103 | public com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder setValue(Object value) { 104 | super.setValue(value); 105 | return this; 106 | } 107 | 108 | @Override 109 | public com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder clearValue() { 110 | super.clearValue(); 111 | return this; 112 | } 113 | 114 | @Override 115 | public com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder setCause(Throwable cause) { 116 | super.setCause(cause); 117 | return this; 118 | } 119 | 120 | @Override 121 | public com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder clearCause() { 122 | super.clearCause(); 123 | return this; 124 | } 125 | 126 | /** Gets the value of the 'description' field */ 127 | public java.lang.CharSequence getDescription() { 128 | return description; 129 | } 130 | 131 | /** Sets the value of the 'description' field */ 132 | public com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder setDescription(java.lang.CharSequence value) { 133 | validate(fields()[0], value); 134 | this.description = value; 135 | fieldSetFlags()[0] = true; 136 | return this; 137 | } 138 | 139 | /** Checks whether the 'description' field has been set */ 140 | public 
boolean hasDescription() { 141 | return fieldSetFlags()[0]; 142 | } 143 | 144 | /** Clears the value of the 'description' field */ 145 | public com.cloudera.knittingboar.yarn.avro.generated.ServiceError.Builder clearDescription() { 146 | description = null; 147 | fieldSetFlags()[0] = false; 148 | return this; 149 | } 150 | 151 | @Override 152 | public ServiceError build() { 153 | try { 154 | ServiceError record = new ServiceError(getValue(), getCause()); 155 | record.description = fieldSetFlags()[0] ? this.description : (java.lang.CharSequence) defaultValue(fields()[0]); 156 | return record; 157 | } catch (Exception e) { 158 | throw new org.apache.avro.AvroRuntimeException(e); 159 | } 160 | } 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/records/TestRCV1RecordFactory.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.cloudera.knittingboar.records; 19 | 20 | import java.io.IOException; 21 | 22 | import org.apache.mahout.math.RandomAccessSparseVector; 23 | import org.apache.mahout.math.Vector; 24 | 25 | import junit.framework.TestCase; 26 | 27 | public class TestRCV1RecordFactory extends TestCase { 28 | 29 | String training_rec_0 = "0 |f 7:4.3696374e-02 8:1.0872085e-01 19:2.2659289e-02 20:1.6952585e-02 50:3.3265986e-02 52:2.8914521e-02 99:6.2935837e-02 111:3.6749814e-02 124:4.5141779e-02 147:2.7024418e-02 153:5.3756956e-02 169:2.8062440e-02 182:4.7379807e-02 183:2.7668567e-02 188:1.4508039e-02 269:2.3687121e-02 271:2.0829555e-02 297:2.8352227e-02 311:3.3546336e-02 319:2.8875276e-02 332:4.7258154e-02 337:3.1720489e-02 360:6.8111412e-02 368:4.4445150e-02 411:4.4164777e-02 488:8.4059432e-02 586:2.9122708e-02 591:9.5403686e-02 664:3.6937956e-02 702:2.8176809e-02 737:1.6336726e-01 739:5.1228814e-02 757:3.4760747e-02 764:3.6367100e-02 768:6.1244022e-02 791:2.1772176e-01 817:7.4271448e-02 848:4.0480603e-02 895:5.1346138e-02 933:4.1986264e-02 979:1.1311502e-01 1003:4.5158323e-02 1005:4.0224314e-02 1021:4.6525169e-02 1071:2.9869374e-02 1127:2.1704819e-02 1133:4.5880664e-02 1162:3.8132094e-02 1178:5.2212182e-02 1180:1.0740499e-01 1338:4.9277205e-02 1360:4.6650354e-02 1498:5.9916675e-02 1511:7.6297082e-02 1577:5.0769087e-02 1659:5.0992116e-02 1666:2.4987224e-02 1674:2.9845037e-02 1810:4.6527624e-02 1966:4.3204561e-02 2018:4.3157250e-02 2066:1.3678090e-01 2074:1.0599699e-01 2117:9.8577492e-02 2183:1.4329165e-01 2248:1.2792459e-01 2276:7.9498030e-02 2316:4.9681831e-02 2340:5.8379412e-02 2762:5.1772792e-02 2771:4.9624689e-02 3077:2.1542890e-01 3227:8.3143584e-02 3246:5.2039523e-02 3282:5.2630566e-02 3369:7.0463479e-02 3615:5.6905646e-02 3620:6.6913836e-02 3962:6.1502680e-02 4132:2.1751978e-01 4143:2.6172629e-01 4144:9.1886766e-02 4499:1.1314832e-01 5031:7.9870239e-02 5055:8.6920090e-02 5401:5.4840717e-02 5423:9.5343769e-02 5860:8.9788958e-02 6065:8.6977042e-02 6668:7.6055169e-02 6697:6.8251781e-02 7139:6.4996362e-02 7426:1.2097790e-01 7606:1.9588335e-01 8870:1.4963643e-01 9804:9.4143294e-02 12121:7.4564628e-02 13942:1.6451047e-01 14595:1.0607405e-01 15422:8.9860193e-02 15652:1.0834268e-01 16223:9.6487328e-02 16859:1.0539885e-01 17424:8.1960648e-02 19529:9.3970060e-02 23299:1.8965191e-01 24377:1.0888006e-01 24448:9.8843329e-02 24454:2.8149781e-01 24455:2.1925208e-01 26622:1.0557952e-01 31771:1.3358803e-01 41700:1.2038895e-0"; 30 | String training_rec_1 = "1 |f 9:3.3307336e-02 13:2.6428020e-02 55:4.5726493e-02 69:3.0852484e-02 93:7.4375033e-02 111:4.7884714e-02 140:5.6255672e-02 151:6.5337561e-02 153:7.0044883e-02 161:5.7628881e-02 175:4.5645405e-02 180:4.6431489e-02 187:5.3116236e-02 193:3.8840283e-02 209:6.7031987e-02 217:3.3030130e-02 229:4.9895555e-02 233:2.7318209e-02 236:2.9892704e-02 252:5.6756295e-02 258:4.4865504e-02 260:6.3265145e-02 263:5.3964965e-02 269:5.2257512e-02 271:4.5953277e-02 276:3.2793090e-02 286:3.6571421e-02 288:2.9139040e-02 319:3.7624255e-02 334:7.6396912e-02 338:3.6081653e-02 362:7.6015718e-02 389:5.4903280e-02 417:2.6063753e-02 426:5.4687556e-02 438:7.3853023e-02 453:3.9429404e-02 477:4.9223945e-02 480:5.3083062e-02 488:5.2191041e-02 506:3.5930801e-02 558:8.1321917e-02 561:5.6125600e-02 594:5.5980112e-02 617:8.4778033e-02 623:4.6125464e-02 642:4.1558836e-02 644:1.0274204e-01 755:6.0711723e-02 803:5.0099224e-02 836:5.0887167e-02 837:5.0023027e-02 951:6.7326568e-02 1012:1.0415152e-01 1037:7.9722285e-02 1059:2.8764643e-02 1061:5.8410108e-02 1077:6.1814863e-02 
1089:7.1079604e-02 1129:7.9089224e-02 1133:5.9782140e-02 1196:1.0324168e-01 1212:9.1613702e-02 1218:3.9104007e-02 1221:6.0955465e-02 1237:5.1370349e-02 1240:7.8351930e-02 1241:8.0820285e-02 1287:7.0892565e-02 1342:4.7291242e-02 1356:6.4580373e-02 1486:5.8335997e-02 1492:5.5702407e-02 1499:8.0403641e-02 1546:7.6899387e-02 1575:5.2044731e-02 1626:7.6970406e-02 1659:1.1249663e-01 1823:1.1926711e-01 1839:8.0284260e-02 1976:9.9477187e-02 1985:7.4262738e-02 2008:7.3236965e-02 2061:1.1390504e-01 2153:4.4327311e-02 2190:9.7444594e-02 2212:5.6166001e-02 2234:4.9261238e-02 2446:6.5645538e-02 2455:7.3392190e-02 3088:9.1792777e-02 3230:5.9966434e-02 3247:1.3415127e-01 3261:8.2616769e-02 3306:2.2782215e-01 3394:6.7915484e-02 3443:8.2661413e-02 3669:2.0983912e-01 3690:1.0559268e-01 3899:7.4440584e-02 4367:2.0285535e-01 4369:9.3478583e-02 4455:2.1870497e-01 4500:8.8423654e-02 4965:7.5678401e-02 5223:7.7687882e-02 5539:1.2643857e-01 5874:7.3007621e-02 5945:8.4572427e-02 5995:8.5160516e-02 6025:9.9534206e-02 6051:1.3297138e-01 6510:1.3806941e-01 7456:1.9141673e-01 9117:9.7984366e-02 9257:1.1816826e-01 9461:1.6047940e-01 10985:1.1505207e-01 13963:2.0298573e-01 14669:1.3871375e-01 14770:1.0384714e-01 16122:1.2237830e-01 23461:1.8466952e-01 23701:1.1287191e-01 25082:1.4810963e-01 40790:1.3507748e-01"; 31 | String training_rec_2 = "1 |f 5:5.2558247e-02 63:1.2409918e-01 70:6.0630817e-02 188:4.9340606e-02 193:1.0137629e-01 286:9.5454372e-02 571:3.5901710e-01 811:8.9977279e-02 942:1.6419628e-01 1193:9.6185014e-02 1380:1.2333614e-01 1702:3.7394261e-01 1766:4.7051618e-01 2244:3.2123861e-01 2298:4.9728591e-02 2554:3.1551713e-01 2565:3.2508451e-01 3400:1.7101182e-01 13750:2.5799808e-01"; 32 | 33 | 34 | public void testParse() throws Exception { 35 | 36 | RCV1RecordFactory factory = new RCV1RecordFactory(); 37 | 38 | Vector v = new RandomAccessSparseVector(RCV1RecordFactory.FEATURES); 39 | 40 | int actual = factory.processLine(training_rec_0, v); 41 | 42 | assertEquals( 0, actual); 43 | assertEquals( .043696374, v.get(7)); 44 | 45 | 46 | 47 | 48 | Vector v2 = new RandomAccessSparseVector(RCV1RecordFactory.FEATURES); 49 | 50 | int actual2 = factory.processLine(training_rec_1, v2); 51 | 52 | assertEquals( 1, actual2); 53 | assertEquals( .030852484, v2.get(69)); 54 | 55 | 56 | 57 | 58 | 59 | 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /src/main/java/com/cloudera/knittingboar/sgd/MultinomialLogisticRegressionParameterVectors_deprecated.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
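 */

// ---------------------------------------------------------------------------
// [editor's note -- a hedged sketch, not project code] The class below carries
// TODOs about replacing its cell-by-cell loops with Mahout's built-in matrix
// ops. Assuming the org.apache.mahout.math.Matrix API this file already
// imports, the loops could likely collapse to:
//
//   gamma = gamma.plus(otherGamma);     // AccumulateParameterVector(...)
//   gamma = gamma.divide(denominator);  // AverageParameterVectors(...)
//   gamma.assign(0);                    // Reset()
//
// Trade-off: plus() and divide() allocate a new Matrix rather than mutating
// in place, which may matter for large parameter vectors.
// ---------------------------------------------------------------------------

/*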
16 | */ 17 | 18 | package com.cloudera.knittingboar.sgd; 19 | 20 | import java.util.ArrayList; 21 | 22 | import org.apache.mahout.math.DenseMatrix; 23 | import org.apache.mahout.math.Matrix; 24 | 25 | import com.cloudera.knittingboar.utils.Utils; 26 | 27 | /** 28 | * An array of parameter vectors in the form of a Matrix object: 29 | * 30 | * a wrapper around another Matrix that holds the worker's parameter vectors 31 | * 32 | * @author jpatterson 33 | * 34 | */ 35 | public class MultinomialLogisticRegressionParameterVectors_deprecated { 36 | 37 | protected Matrix gamma; // this is the saved updated gradient we merge at the 38 | // super step 39 | 40 | private int AccumulatedGradientsCount = 0; 41 | private int numCategories = 2; // default 42 | 43 | public MultinomialLogisticRegressionParameterVectors_deprecated(int numCategories, int numFeatures) { 44 | 45 | this.numCategories = numCategories; 46 | this.gamma = new DenseMatrix(numCategories - 1, numFeatures); 47 | 48 | } 49 | 50 | public void setMatrix(Matrix m) { 51 | 52 | this.gamma = m; 53 | 54 | } 55 | 56 | public Matrix getMatrix() { 57 | // close(); 58 | return this.gamma; 59 | } 60 | 61 | public void setCell(int i, int j, double gammaIJ) { 62 | this.gamma.set(i, j, gammaIJ); 63 | } 64 | 65 | public double getCell(int row, int col) { 66 | return this.gamma.get(row, col); 67 | } 68 | 69 | public int numFeatures() { 70 | return this.gamma.numCols(); 71 | } 72 | 73 | public int numCategories() { 74 | return this.numCategories; 75 | } 76 | 77 | private void ClearGradientBuffer() { 78 | 79 | this.gamma = this.gamma.like(); 80 | 81 | } 82 | 83 | public void SetupMerge() { 84 | 85 | this.AccumulatedGradientsCount = 0; 86 | this.ClearGradientBuffer(); 87 | 88 | } 89 | 90 | /** 91 | * TODO: fix loop -- Matrix.plus() could do this in one call (see note above) 92 | * 93 | * @param other_gamma 94 | */ 95 | public void AccumulateParameterVector(Matrix other_gamma) { 96 | 97 | // this.gamma.plus(arg0) 98 | 99 | for (int row = 0; row < this.gamma.rowSize(); row++) { 100 | 101 | for (int col = 0; col < this.gamma.columnSize(); col++) { 102 | 103 | double old_this_val = this.gamma.get(row, col); 104 | double other_val = other_gamma.get(row, col); 105 | 106 | // System.out.println( "Accumulate: " + old_this_val + ", " + other_val 107 | // ); 108 | 109 | this.gamma.set(row, col, old_this_val + other_val); 110 | 111 | // System.out.println( "new value: " + this.gamma.get(row, col) ); 112 | 113 | } 114 | 115 | } 116 | 117 | this.AccumulatedGradientsCount++; 118 | 119 | } 120 | /* 121 | public void Accumulate(GradientBuffer other_gamma) { 122 | 123 | for (int row = 0; row < this.gamma.rowSize(); row++) { 124 | 125 | for (int col = 0; col < this.gamma.columnSize(); col++) { 126 | 127 | double old_this_val = this.gamma.get(row, col); 128 | double other_val = other_gamma.getCell(row, col); 129 | 130 | this.gamma.set(row, col, old_this_val + other_val); 131 | 132 | } 133 | 134 | } 135 | 136 | this.AccumulatedGradientsCount++; 137 | 138 | } 139 | */ 140 | 141 | /** 142 | * TODO: need to take a look at built-in matrix ops here -- Matrix.divide() (see note above) 143 | * 144 | */ 145 | public void AverageParameterVectors(int denominator) { 146 | 147 | for (int row = 0; row < this.gamma.rowSize(); row++) { 148 | 149 | for (int col = 0; col < this.gamma.columnSize(); col++) { 150 | 151 | double old_this_val = this.gamma.get(row, col); 152 | // double other_val = other_gamma.getCell(row, col); 153 | this.gamma.set(row, col, old_this_val / denominator); 154 | 155 | } 156 | 157 | } 158 | 159 | } 160 | 161 | /** 162 | * Copies another GradientBuffer 163 | * 
164 | * @param other 165 | */ 166 | /* public void Copy(MultinomialLogisticRegressionParameterVectors other) { 167 | 168 | this.Accumulate(other); 169 | 170 | } 171 | 172 | public static GradientBuffer Merge(ArrayList gammas) { 173 | 174 | int numFeatures = gammas.get(0).numFeatures(); 175 | int numCategories = gammas.get(0).numCategories(); 176 | 177 | GradientBuffer merged = new GradientBuffer(numCategories, numFeatures); 178 | 179 | // accumulate all the gradients into buffers 180 | 181 | for (int x = 0; x < gammas.size(); x++) { 182 | 183 | merged.Accumulate(gammas.get(x)); 184 | 185 | } 186 | 187 | // calc average 188 | 189 | return merged; 190 | 191 | } 192 | */ 193 | /** 194 | * Clears all values in matrix back to 0s 195 | */ 196 | public void Reset() { 197 | 198 | for (int row = 0; row < this.gamma.rowSize(); row++) { 199 | 200 | for (int col = 0; col < this.gamma.columnSize(); col++) { 201 | 202 | this.gamma.set(row, col, 0); 203 | 204 | } 205 | 206 | } 207 | 208 | this.AccumulatedGradientsCount = 0; 209 | 210 | } 211 | 212 | public void Debug() { 213 | 214 | System.out.println("\nGamma: "); 215 | 216 | for (int x = 0; x < this.gamma.rowSize(); x++) { 217 | 218 | Utils.PrintVectorSectionNonZero(this.gamma.viewRow(x), 3); 219 | 220 | } 221 | 222 | } 223 | 224 | public void DebugRowInGamma(int row) { 225 | 226 | System.out.println("\nGamma: "); 227 | Utils.PrintVectorSectionNonZero(this.gamma.viewRow(row), 3); 228 | 229 | } 230 | 231 | } 232 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/records/TestTwentyNewsgroupsCustomRecordParseOLRRun.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
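 */

// ---------------------------------------------------------------------------
// [editor's note -- hedged] The running statistics in the test below use an
// incremental mean:
//
//   mu        = Math.min(k + 1, 200);
//   averageLL = averageLL + (ll - averageLL) / mu;
//
// For the first 200 records this is the exact running mean; once mu saturates
// at 200 it behaves like an exponentially weighted moving average with an
// effective window of roughly 200 records, so early noise decays away.
// ---------------------------------------------------------------------------

/*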
16 | */ 17 | 18 | package com.cloudera.knittingboar.records; 19 | 20 | import java.io.File; 21 | 22 | import org.apache.commons.io.FileUtils; 23 | import org.apache.commons.logging.Log; 24 | import org.apache.commons.logging.LogFactory; 25 | import org.apache.hadoop.fs.FileSystem; 26 | import org.apache.hadoop.fs.Path; 27 | import org.apache.hadoop.io.Text; 28 | import org.apache.hadoop.mapred.FileInputFormat; 29 | import org.apache.hadoop.mapred.InputSplit; 30 | import org.apache.hadoop.mapred.JobConf; 31 | import org.apache.hadoop.mapred.TextInputFormat; 32 | import org.apache.mahout.classifier.sgd.L1; 33 | import org.apache.mahout.classifier.sgd.OnlineLogisticRegression; 34 | import org.apache.mahout.math.DenseVector; 35 | import org.apache.mahout.math.RandomAccessSparseVector; 36 | import org.apache.mahout.math.Vector; 37 | import org.junit.After; 38 | import org.junit.Before; 39 | import org.junit.Test; 40 | 41 | import com.cloudera.knittingboar.io.InputRecordsSplit; 42 | import com.cloudera.knittingboar.utils.TestingUtils; 43 | import com.google.common.io.Files; 44 | 45 | public class TestTwentyNewsgroupsCustomRecordParseOLRRun { 46 | private static final Log LOG = LogFactory 47 | .getLog(TestTwentyNewsgroupsCustomRecordParseOLRRun.class.getName()); 48 | 49 | private static final int FEATURES = 10000; 50 | 51 | private JobConf defaultConf; 52 | private FileSystem localFs; 53 | 54 | private File baseDir; 55 | private Path workDir; 56 | private String inputFileName; 57 | 58 | @Before 59 | public void setup() throws Exception { 60 | defaultConf = new JobConf(); 61 | defaultConf.set("fs.defaultFS", "file:///"); 62 | localFs = FileSystem.getLocal(defaultConf); 63 | inputFileName = "kboar-shard-0.txt"; 64 | baseDir = Files.createTempDir(); 65 | File inputFile = new File(baseDir, inputFileName); 66 | TestingUtils.copyDecompressed(inputFileName + ".gz", inputFile); 67 | workDir = new Path(baseDir.getAbsolutePath()); 68 | } 69 | 70 | @After 71 | public void teardown() throws Exception { 72 | FileUtils.deleteQuietly(baseDir); 73 | } 74 | 75 | @Test 76 | public void testRecordFactoryOnDatasetShard() throws Exception { 77 | // TODO a test with assertions is not a test 78 | // p.270 ----- metrics to track lucene's parsing mechanics, progress, 79 | // performance of OLR ------------ 80 | double averageLL = 0.0; 81 | double averageCorrect = 0.0; 82 | int k = 0; 83 | double step = 0.0; 84 | int[] bumps = new int[] { 1, 2, 5 }; 85 | 86 | TwentyNewsgroupsRecordFactory rec_factory = new TwentyNewsgroupsRecordFactory( 87 | "\t"); 88 | // rec_factory.setClassSplitString("\t"); 89 | 90 | JobConf job = new JobConf(defaultConf); 91 | 92 | long block_size = localFs.getDefaultBlockSize(workDir); 93 | 94 | LOG.info("default block size: " + (block_size / 1024 / 1024) + "MB"); 95 | 96 | // matches the OLR setup on p.269 --------------- 97 | // stepOffset, decay, and alpha --- describe how the learning rate decreases 98 | // lambda: amount of regularization 99 | // learningRate: amount of initial learning rate 100 | @SuppressWarnings("resource") 101 | OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression( 102 | 20, FEATURES, new L1()).alpha(1).stepOffset(1000).decayExponent(0.9) 103 | .lambda(3.0e-5).learningRate(20); 104 | 105 | FileInputFormat.setInputPaths(job, workDir); 106 | 107 | // try splitting the file in a variety of sizes 108 | TextInputFormat format = new TextInputFormat(); 109 | format.configure(job); 110 | Text value = new Text(); 111 | 112 | int numSplits = 1; 113 | 114 | 
InputSplit[] splits = format.getSplits(job, numSplits); 115 | 116 | LOG.info("requested " + numSplits + " splits, splitting: got = " 117 | + splits.length); 118 | LOG.info("---- debug splits --------- "); 119 | rec_factory.Debug(); 120 | int total_read = 0; 121 | 122 | for (int x = 0; x < splits.length; x++) { 123 | 124 | LOG.info("> Split [" + x + "]: " + splits[x].getLength()); 125 | 126 | int count = 0; 127 | InputRecordsSplit custom_reader = new InputRecordsSplit(job, splits[x]); 128 | while (custom_reader.next(value)) { 129 | Vector v = new RandomAccessSparseVector( 130 | TwentyNewsgroupsRecordFactory.FEATURES); 131 | int actual = rec_factory.processLine(value.toString(), v); 132 | 133 | String ng = rec_factory.GetNewsgroupNameByID(actual); 134 | 135 | // calc stats --------- 136 | 137 | double mu = Math.min(k + 1, 200); 138 | double ll = learningAlgorithm.logLikelihood(actual, v); 139 | averageLL = averageLL + (ll - averageLL) / mu; 140 | 141 | Vector p = new DenseVector(20); 142 | learningAlgorithm.classifyFull(p, v); 143 | int estimated = p.maxValueIndex(); 144 | 145 | int correct = (estimated == actual ? 1 : 0); 146 | averageCorrect = averageCorrect + (correct - averageCorrect) / mu; 147 | learningAlgorithm.train(actual, v); 148 | k++; 149 | int bump = bumps[(int) Math.floor(step) % bumps.length]; 150 | int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); 151 | 152 | if (k % (bump * scale) == 0) { 153 | step += 0.25; 154 | LOG.info(String.format("%10d %10.3f %10.3f %10.2f %s %s", k, ll, 155 | averageLL, averageCorrect * 100, ng, 156 | rec_factory.GetNewsgroupNameByID(estimated))); 157 | } 158 | 159 | count++; 160 | } 161 | 162 | LOG.info("read: " + count + " records for split " + x); 163 | total_read += count; 164 | } // for each split 165 | 166 | // close() finalizes the model, so it belongs here, after all training 167 | // passes -- not inside the record loop, where it previously sat 168 | learningAlgorithm.close(); 169 | 170 | LOG.info("total read across all splits: " + total_read); 171 | rec_factory.Debug(); 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /src/main/java/com/cloudera/knittingboar/conf/cmdline/DataConverterCmdLineDriver.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
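 */

// ---------------------------------------------------------------------------
// [editor's note -- hedged example; paths are illustrative] The long options
// registered below map to an invocation like the one the wrapper script in
// scripts/convert_20newsgroups.sh forwards its arguments to:
//
//   ./scripts/convert_20newsgroups.sh \
//       --input  /path/to/20news-bydate-train/ \
//       --output /path/to/20news-kboar/train/ \
//       --recordsPerBlock 20000 \
//       --datasetType 20Newsgroups
// ---------------------------------------------------------------------------

/*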
16 | */ 17 | 18 | package com.cloudera.knittingboar.conf.cmdline; 19 | 20 | import java.io.IOException; 21 | import java.io.PrintWriter; 22 | 23 | import org.apache.commons.cli2.CommandLine; 24 | import org.apache.commons.cli2.Group; 25 | import org.apache.commons.cli2.Option; 26 | import org.apache.commons.cli2.builder.ArgumentBuilder; 27 | import org.apache.commons.cli2.builder.DefaultOptionBuilder; 28 | import org.apache.commons.cli2.builder.GroupBuilder; 29 | import org.apache.commons.cli2.commandline.Parser; 30 | import org.apache.commons.cli2.util.HelpFormatter; 31 | 32 | import com.cloudera.knittingboar.utils.DatasetConverter; 33 | 34 | public class DataConverterCmdLineDriver { 35 | 36 | private static String strInputFile; 37 | private static String strOutputFile; 38 | private static String strrecordsPerBlock; 39 | 40 | public static void main(String[] args) throws Exception { 41 | mainToOutput(args, new PrintWriter(System.out, true)); 42 | } 43 | 44 | static void mainToOutput(String[] args, PrintWriter output) throws Exception { 45 | if (!parseArgs(args)) { 46 | 47 | output.write("Parse:Incorrect"); 48 | return; 49 | } // if 50 | 51 | output.write("Parse:correct"); 52 | 53 | int shard_rec_count = Integer.parseInt(strrecordsPerBlock); 54 | 55 | System.out.println("Converting "); 56 | System.out.println("From: " + strInputFile); 57 | System.out.println("To: " + strOutputFile); 58 | System.out.println("File shard size (record count/file): " 59 | + shard_rec_count); 60 | 61 | int count = DatasetConverter.ConvertNewsgroupsFromSingleFiles(strInputFile, 62 | strOutputFile, shard_rec_count); 63 | 64 | output.write("Total Records Converted: " + count); 65 | 66 | } // mainToOutput 67 | 68 | private static boolean parseArgs(String[] args) throws IOException { 69 | DefaultOptionBuilder builder = new DefaultOptionBuilder(); 70 | 71 | Option help = builder.withLongName("help").withDescription( 72 | "print this list").create(); 73 | 74 | // Option quiet = 75 | // builder.withLongName("quiet").withDescription("be extra quiet").create(); 76 | // Option scores = 77 | // builder.withLongName("scores").withDescription("output score diagnostics during training").create(); 78 | 79 | ArgumentBuilder argumentBuilder = new ArgumentBuilder(); 80 | Option inputFileOption = builder 81 | .withLongName("input") 82 | .withRequired(true) 83 | .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) 84 | .withDescription("where to get input data").create(); 85 | 86 | Option outputFileOption = builder.withLongName("output").withRequired(true) 87 | .withArgument( 88 | argumentBuilder.withName("output").withMaximum(1).create()) 89 | .withDescription("where to write output data").create(); 90 | 91 | Option recordsPerBlockOption = builder.withLongName("recordsPerBlock") 92 | .withArgument( 93 | argumentBuilder.withName("recordsPerBlock").withDefault("20000") 94 | .withMaximum(1).create()).withDescription( 95 | "the number of records per output file shard to write").create(); 96 | 97 | // optionally can be { 20Newsgroups, rcv1 } 98 | Option RecordFactoryType = builder.withLongName("datasetType") 99 | .withArgument( 100 | argumentBuilder.withName("recordFactoryType").withDefault( 101 | "20Newsgroups").withMaximum(1).create()).withDescription( 102 | "the type of dataset to convert").create(); 103 | 104 | /* 105 | * Option passes = builder.withLongName("passes") .withArgument( 106 | * argumentBuilder.withName("passes") .withDefault("2") 107 | * .withMaximum(1).create()) 108 | * .withDescription("the number 
of times to pass over the input data") 109 | * .create(); 110 | * 111 | * Option lambda = builder.withLongName("lambda") 112 | * .withArgument(argumentBuilder 113 | * .withName("lambda").withDefault("1e-4").withMaximum(1).create()) 114 | * .withDescription("the amount of coefficient decay to use") .create(); 115 | * 116 | * Option rate = builder.withLongName("rate") 117 | * .withArgument(argumentBuilder.withName 118 | * ("learningRate").withDefault("1e-3").withMaximum(1).create()) 119 | * .withDescription("the learning rate") .create(); 120 | * 121 | * Option noBias = builder.withLongName("noBias") 122 | * .withDescription("don't include a bias term") .create(); 123 | */ 124 | 125 | Group normalArgs = new GroupBuilder().withOption(help).withOption( 126 | inputFileOption).withOption(outputFileOption).withOption( 127 | recordsPerBlockOption).withOption(RecordFactoryType).create(); 128 | 129 | Parser parser = new Parser(); 130 | parser.setHelpOption(help); 131 | parser.setHelpTrigger("--help"); 132 | parser.setGroup(normalArgs); 133 | parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); 134 | CommandLine cmdLine = parser.parseAndHelp(args); 135 | 136 | if (cmdLine == null) { 137 | 138 | System.out.println("Unable to parse command line; run with --help for usage."); 139 | return false; 140 | } 141 | 142 | // "/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/" 143 | strInputFile = getStringArgument(cmdLine, inputFileOption); 144 | 145 | // "/Users/jpatterson/Downloads/datasets/20news-kboar/train4/" 146 | strOutputFile = getStringArgument(cmdLine, outputFileOption); 147 | 148 | strrecordsPerBlock = getStringArgument(cmdLine, recordsPerBlockOption); 149 | 150 | return true; 151 | } 152 | 153 | private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { 154 | return cmdLine.hasOption(option); 155 | } 156 | 157 | private static String getStringArgument(CommandLine cmdLine, Option inputFile) { 158 | return (String) cmdLine.getValue(inputFile); 159 | } 160 | 161 | } -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/sgd/iterativereduce/TestPOLRIterativeReduce.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
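 */

// ---------------------------------------------------------------------------
// [editor's note -- hedged summary of the disabled code below] As written,
// setUp() would start one ApplicationMasterService on localhost:9999 and
// three ApplicationWorkerService instances on a thread pool, each handed the
// same gzipped 20-newsgroups shard via an Avro FileSplit, and the test would
// then block on the workers' and master's Futures. Because the whole body is
// commented out, testWorkerService() currently passes vacuously.
// ---------------------------------------------------------------------------

/*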
16 | */ 17 | 18 | package com.cloudera.knittingboar.sgd.iterativereduce; 19 | 20 | import java.io.File; 21 | import java.net.InetSocketAddress; 22 | import java.util.ArrayList; 23 | import java.util.HashMap; 24 | import java.util.concurrent.Callable; 25 | import java.util.concurrent.ExecutorService; 26 | import java.util.concurrent.Executors; 27 | import java.util.concurrent.Future; 28 | 29 | import org.apache.commons.io.FileUtils; 30 | import org.apache.hadoop.conf.Configuration; 31 | import org.junit.After; 32 | import org.junit.Before; 33 | import org.junit.Test; 34 | 35 | import com.cloudera.iterativereduce.ComputableMaster; 36 | import com.cloudera.iterativereduce.ComputableWorker; 37 | import com.cloudera.iterativereduce.Utils; 38 | import com.cloudera.iterativereduce.io.TextRecordParser; 39 | import com.cloudera.iterativereduce.yarn.appmaster.ApplicationMasterService; 40 | import com.cloudera.iterativereduce.yarn.appworker.ApplicationWorkerService; 41 | import com.cloudera.iterativereduce.yarn.avro.generated.FileSplit; 42 | import com.cloudera.iterativereduce.yarn.avro.generated.StartupConfiguration; 43 | import com.cloudera.iterativereduce.yarn.avro.generated.WorkerId; 44 | 45 | import com.cloudera.knittingboar.utils.TestingUtils; 46 | import com.google.common.io.Files; 47 | //import com.cloudera.knittingboar.yarn.AvroUtils; 48 | 49 | 50 | public class TestPOLRIterativeReduce { 51 | /* 52 | private InetSocketAddress masterAddress; 53 | private ExecutorService pool; 54 | private ApplicationMasterService<ParameterVectorGradientUpdatable> masterService; 55 | private Future<Integer> master; 56 | private ComputableMaster<ParameterVectorGradientUpdatable> computableMaster; 57 | private ArrayList<ApplicationWorkerService<ParameterVectorGradientUpdatable>> workerServices; 58 | private ArrayList<Future<Integer>> workers; 59 | private ArrayList<ComputableWorker<ParameterVectorGradientUpdatable>> computableWorkers; 60 | private File baseDir; 61 | private Configuration configuration; 62 | */ 63 | /* 64 | @Before 65 | public void setUp() throws Exception { 66 | masterAddress = new InetSocketAddress(9999); 67 | pool = Executors.newFixedThreadPool(4); 68 | baseDir = Files.createTempDir(); 69 | configuration = new Configuration(); 70 | configuration.setInt( "com.cloudera.knittingboar.setup.FeatureVectorSize", 10000 ); 71 | configuration.setInt( "com.cloudera.knittingboar.setup.numCategories", 20); 72 | configuration.setInt("com.cloudera.knittingboar.setup.BatchSize", 200); 73 | configuration.setInt("com.cloudera.knittingboar.setup.NumberPasses", 1); 74 | // local input split path 75 | configuration.set( "com.cloudera.knittingboar.setup.LocalInputSplitPath", "hdfs://127.0.0.1/input/0" ); 76 | // setup 20newsgroups 77 | configuration.set( "com.cloudera.knittingboar.setup.RecordFactoryClassname", "com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory"); 78 | workerServices = new ArrayList<ApplicationWorkerService<ParameterVectorGradientUpdatable>>(); 79 | workers = new ArrayList<Future<Integer>>(); 80 | computableWorkers = new ArrayList<ComputableWorker<ParameterVectorGradientUpdatable>>(); 81 | 82 | setUpMaster(); 83 | 84 | setUpWorker("worker1"); 85 | setUpWorker("worker2"); 86 | setUpWorker("worker3"); 87 | } 88 | 89 | @After 90 | public void teardown() throws Exception { 91 | FileUtils.deleteDirectory(baseDir); 92 | } 93 | */ 94 | /** 95 | * TODO: give the system multiple files and create the right number of splits 96 | * 97 | * TODO: StartupConfiguration needs to be fed from the Configuration object somehow 98 | * 99 | * TODO: what event do I have for when all the work is done? 100 | * - maybe a "completion()" method in ComputableMaster ? 
101 | * 102 | * @throws Exception 103 | */ 104 | /* public void setUpMaster() throws Exception { 105 | 106 | System.out.println( "start-ms:" + System.currentTimeMillis() ); 107 | 108 | File inputFile = new File(baseDir, "kboar-shard-0.txt"); 109 | TestingUtils.copyDecompressed("kboar-shard-0.txt.gz", inputFile); 110 | FileSplit split = FileSplit.newBuilder() 111 | .setPath(inputFile.getAbsolutePath()).setOffset(0).setLength(8348890) 112 | .build(); 113 | 114 | // hey MK, how do I set multiple splits or splits of multiple files? 115 | StartupConfiguration conf = StartupConfiguration.newBuilder() 116 | .setSplit(split).setBatchSize(200).setIterations(1).setOther(null) 117 | .build(); 118 | 119 | HashMap workers = new HashMap(); 120 | workers.put(Utils.createWorkerId("worker1"), conf); 121 | workers.put(Utils.createWorkerId("worker2"), conf); 122 | workers.put(Utils.createWorkerId("worker3"), conf); 123 | 124 | computableMaster = new POLRMasterNode(); 125 | masterService = new ApplicationMasterService(masterAddress, 126 | workers, computableMaster, ParameterVectorGradientUpdatable.class, null, configuration ); 127 | 128 | master = pool.submit(masterService); 129 | } 130 | 131 | private void setUpWorker(String name) { 132 | //HDFSLineParser parser = new HDFSLineParser(); 133 | 134 | TextRecordParser parser = new TextRecordParser(); 135 | 136 | ComputableWorker computableWorker = new POLRWorkerNode(); 137 | final ApplicationWorkerService workerService = new ApplicationWorkerService( 138 | name, masterAddress, parser, computableWorker, ParameterVectorGradientUpdatable.class, configuration); 139 | 140 | Future worker = pool.submit(new Callable() { 141 | public Integer call() { 142 | return workerService.run(); 143 | } 144 | }); 145 | 146 | computableWorkers.add(computableWorker); 147 | workerServices.add(workerService); 148 | workers.add(worker); 149 | } 150 | */ 151 | @Test 152 | public void testWorkerService() throws Exception { 153 | // TODO tests without assertions are not tests 154 | /* workers.get(0).get(); 155 | workers.get(1).get(); 156 | workers.get(2).get(); 157 | master.get(); 158 | 159 | // Bozo numbers 160 | assertEquals(Integer.valueOf(12100), computableWorkers.get(0).getResults().get()); 161 | assertEquals(Integer.valueOf(12100), computableWorkers.get(1).getResults().get()); 162 | assertEquals(Integer.valueOf(12100), computableWorkers.get(2).getResults().get()); 163 | assertEquals(Integer.valueOf(51570), computableMaster.getResults().get()); 164 | */ 165 | } 166 | } -------------------------------------------------------------------------------- /src/main/java/com/cloudera/knittingboar/records/RCV1RecordFactory.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
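 */

// ---------------------------------------------------------------------------
// [editor's note -- hedged] The factory below consumes Vowpal-Wabbit-style
// lines like the ones embedded in TestRCV1RecordFactory:
//
//   0 |f 7:4.3696374e-02 8:1.0872085e-01 ...
//
// i.e. a class label, a namespace marker ("|f"), then sparse index:value
// pairs. processLine() reduces each index modulo FEATURES (10,000) and sets
// the value directly in the output vector -- a plain modulo hash rather than
// Mahout's feature encoders.
// ---------------------------------------------------------------------------

/*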
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.cloudera.knittingboar.records; 19 | 20 | import java.io.BufferedReader; 21 | import java.io.FileNotFoundException; 22 | import java.io.FileReader; 23 | import java.io.IOException; 24 | import java.io.StringReader; 25 | import java.util.ArrayList; 26 | import java.util.List; 27 | import java.util.Map; 28 | import java.util.Set; 29 | import java.util.TreeMap; 30 | 31 | import org.apache.lucene.analysis.TokenStream; 32 | import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; 33 | import org.apache.mahout.math.Vector; 34 | import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder; 35 | import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder; 36 | import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder; 37 | 38 | import org.apache.mahout.math.RandomAccessSparseVector; 39 | 40 | import com.cloudera.knittingboar.utils.Utils; 41 | import com.google.common.collect.ConcurrentHashMultiset; 42 | import com.google.common.collect.Lists; 43 | import com.google.common.collect.Multiset; 44 | 45 | /** 46 | * RecordFactory for 47 | * https://github.com/JohnLangford/vowpal_wabbit/wiki/Rcv1-example 48 | * 49 | * @author jpatterson 50 | * 51 | */ 52 | public class RCV1RecordFactory implements RecordFactory { 53 | 54 | public static final int FEATURES = 10000; 55 | ConstantValueEncoder encoder = null; 56 | 57 | public RCV1RecordFactory() { 58 | 59 | this.encoder = new ConstantValueEncoder("body_values"); 60 | 61 | } 62 | 63 | public static void ScanFile(String file, int debug_break_cnt) 64 | throws IOException { 65 | 66 | ConstantValueEncoder encoder_test = new ConstantValueEncoder("test"); 67 | 68 | BufferedReader reader = null; 69 | // Collection<String> words 70 | int line_count = 0; 71 | 72 | Multiset<String> class_count = ConcurrentHashMultiset.create(); 73 | Multiset<String> namespaces = ConcurrentHashMultiset.create(); 74 | 75 | try { 76 | // System.out.println( newsgroup ); 77 | reader = new BufferedReader(new FileReader(file)); 78 | 79 | String line = reader.readLine(); 80 | 81 | while (line != null && line.length() > 0) { 82 | 83 | // shard_writer.write(line + "\n"); 84 | // out += line; 85 | 86 | String[] parts = line.split(" "); 87 | 88 | // System.out.println( "Class: " + parts[0] ); 89 | 90 | class_count.add(parts[0]); 91 | namespaces.add(parts[1]); 92 | 93 | line = reader.readLine(); 94 | line_count++; 95 | 96 | Vector v = new RandomAccessSparseVector(FEATURES); 97 | 98 | for (int x = 2; x < parts.length; x++) { 99 | // encoder_test.addToVector(parts[x], v); 100 | // System.out.println( parts[x] ); 101 | String[] feature = parts[x].split(":"); 102 | int index = Integer.parseInt(feature[0]) % FEATURES; 103 | double val = Double.parseDouble(feature[1]); 104 | 105 | // System.out.println( feature[1] + " = " + val ); 106 | 107 | if (index < FEATURES) { 108 | v.set(index, val); 109 | } else { 110 | // unreachable: index was already reduced modulo FEATURES above 111 | System.out.println("Could not hash: " + index + " to " 112 | + (index % FEATURES)); 113 | 114 | } 115 | 116 | } 117 | 118 | Utils.PrintVectorSectionNonZero(v, 10); 119 | System.out.println("###"); 120 | 121 | if (line_count > debug_break_cnt) { 122 | break; 123 | } 124 | 125 | } 126 | 127 | System.out.println("Total Rec Count: " + line_count); 128 | 129 | System.out.println("-------------------- "); 130 | 131 | System.out.println("Classes"); 132 | for (String word : class_count.elementSet()) { 133 | System.out.println("Class " + word + ": " + 
class_count.count(word) 134 | + " "); 135 | } 136 | 137 | System.out.println("-------------------- "); 138 | 139 | System.out.println("NameSpaces:"); 140 | for (String word : namespaces.elementSet()) { 141 | System.out.println("Namespace " + word + ": " + namespaces.count(word) 142 | + " "); 143 | } 144 | 145 | /* 146 | * TokenStream ts = analyzer.tokenStream("text", reader); 147 | * ts.addAttribute(CharTermAttribute.class); 148 | * 149 | * // for each word in the stream, minus non-word stuff, add word to 150 | * collection while (ts.incrementToken()) { String s = 151 | * ts.getAttribute(CharTermAttribute.class).toString(); 152 | * //System.out.print( " " + s ); //words.add(s); out += s + " "; } 153 | */ 154 | 155 | } finally { 156 | reader.close(); 157 | } 158 | 159 | // return out + "\n"; 160 | 161 | } 162 | 163 | // doesn't really do anything in a two-class dataset 164 | @Override 165 | public String GetClassnameByID(int id) { 166 | return String.valueOf(id); // this.newsGroups.values().get(id); 167 | } 168 | 169 | /** 170 | * Processes a single line of input into: - target variable - feature vector 171 | * 172 | * Right now our hash function is simply "modulo" 173 | * 174 | * @throws Exception 175 | */ 176 | public int processLine(String line, Vector v) throws Exception { 177 | 178 | // p.269 --------------------------------------------------------- 179 | // Map<String, Set<Integer>> traceDictionary = 180 | // new TreeMap<String, Set<Integer>>(); 181 | 182 | int actual = 0; 183 | 184 | String[] parts = line.split(" "); 185 | 186 | actual = Integer.parseInt(parts[0]); 187 | 188 | // don't know what to do with the namespace marker "f" 189 | 190 | for (int x = 2; x < parts.length; x++) { 191 | 192 | String[] feature = parts[x].split(":"); 193 | int index = Integer.parseInt(feature[0]) % FEATURES; 194 | double val = Double.parseDouble(feature[1]); 195 | 196 | if (index < FEATURES) { 197 | v.set(index, val); 198 | } else { 199 | // unreachable: index was already reduced modulo FEATURES above 200 | System.out 201 | .println("Could not hash: " + index + " to " + (index % FEATURES)); 202 | 203 | } 204 | 205 | } 206 | 207 | // System.out.println("\nEOL\n"); 208 | 209 | return actual; 210 | } 211 | 212 | @Override 213 | public List<String> getTargetCategories() { 214 | 215 | List<String> out = new ArrayList<String>(); 216 | 217 | // for ( int x = 0; x < this.newsGroups.size(); x++ ) { 218 | 219 | // System.out.println( x + "" + this.newsGroups.values().get(x) ); 220 | out.add("0"); 221 | out.add("1"); 222 | 223 | // } 224 | 225 | return out; 226 | 227 | } 228 | 229 | } 230 | -------------------------------------------------------------------------------- /src/main/java/com/cloudera/knittingboar/yarn/avro/generated/ProgressReport.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Autogenerated by Avro 3 | * 4 | * DO NOT EDIT DIRECTLY 5 | */ 6 | package com.cloudera.knittingboar.yarn.avro.generated; 7 | @SuppressWarnings("all") 8 | public class ProgressReport extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { 9 | public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ProgressReport\",\"namespace\":\"com.cloudera.knittingboar.yarn.avro.generated\",\"fields\":[{\"name\":\"workerId\",\"type\":{\"type\":\"fixed\",\"name\":\"WorkerId\",\"size\":32}},{\"name\":\"report\",\"type\":{\"type\":\"map\",\"values\":\"string\"}}]}"); 10 | @Deprecated public com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId; 11 | @Deprecated public java.util.Map report; 12 | public 
org.apache.avro.Schema getSchema() { return SCHEMA$; } 13 | // Used by DatumWriter. Applications should not call. 14 | public java.lang.Object get(int field$) { 15 | switch (field$) { 16 | case 0: return workerId; 17 | case 1: return report; 18 | default: throw new org.apache.avro.AvroRuntimeException("Bad index"); 19 | } 20 | } 21 | // Used by DatumReader. Applications should not call. 22 | @SuppressWarnings(value="unchecked") 23 | public void put(int field$, java.lang.Object value$) { 24 | switch (field$) { 25 | case 0: workerId = (com.cloudera.knittingboar.yarn.avro.generated.WorkerId)value$; break; 26 | case 1: report = (java.util.Map)value$; break; 27 | default: throw new org.apache.avro.AvroRuntimeException("Bad index"); 28 | } 29 | } 30 | 31 | /** 32 | * Gets the value of the 'workerId' field. 33 | */ 34 | public com.cloudera.knittingboar.yarn.avro.generated.WorkerId getWorkerId() { 35 | return workerId; 36 | } 37 | 38 | /** 39 | * Sets the value of the 'workerId' field. 40 | * @param value the value to set. 41 | */ 42 | public void setWorkerId(com.cloudera.knittingboar.yarn.avro.generated.WorkerId value) { 43 | this.workerId = value; 44 | } 45 | 46 | /** 47 | * Gets the value of the 'report' field. 48 | */ 49 | public java.util.Map getReport() { 50 | return report; 51 | } 52 | 53 | /** 54 | * Sets the value of the 'report' field. 55 | * @param value the value to set. 56 | */ 57 | public void setReport(java.util.Map value) { 58 | this.report = value; 59 | } 60 | 61 | /** Creates a new ProgressReport RecordBuilder */ 62 | public static com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder newBuilder() { 63 | return new com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder(); 64 | } 65 | 66 | /** Creates a new ProgressReport RecordBuilder by copying an existing Builder */ 67 | public static com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder newBuilder(com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder other) { 68 | return new com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder(other); 69 | } 70 | 71 | /** Creates a new ProgressReport RecordBuilder by copying an existing ProgressReport instance */ 72 | public static com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder newBuilder(com.cloudera.knittingboar.yarn.avro.generated.ProgressReport other) { 73 | return new com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder(other); 74 | } 75 | 76 | /** 77 | * RecordBuilder for ProgressReport instances. 
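 *
 * A typical construction (editor's illustrative example -- the workerId and
 * progressMap variables are hypothetical; per the Avro map<string> schema,
 * the map's keys and values are CharSequence):
 *
 *   ProgressReport report = ProgressReport.newBuilder()
 *       .setWorkerId(workerId)
 *       .setReport(progressMap)
 *       .build();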
78 | */ 79 | public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase 80 | implements org.apache.avro.data.RecordBuilder { 81 | 82 | private com.cloudera.knittingboar.yarn.avro.generated.WorkerId workerId; 83 | private java.util.Map report; 84 | 85 | /** Creates a new Builder */ 86 | private Builder() { 87 | super(com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.SCHEMA$); 88 | } 89 | 90 | /** Creates a Builder by copying an existing Builder */ 91 | private Builder(com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder other) { 92 | super(other); 93 | } 94 | 95 | /** Creates a Builder by copying an existing ProgressReport instance */ 96 | private Builder(com.cloudera.knittingboar.yarn.avro.generated.ProgressReport other) { 97 | super(com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.SCHEMA$); 98 | if (isValidValue(fields()[0], other.workerId)) { 99 | this.workerId = (com.cloudera.knittingboar.yarn.avro.generated.WorkerId) data().deepCopy(fields()[0].schema(), other.workerId); 100 | fieldSetFlags()[0] = true; 101 | } 102 | if (isValidValue(fields()[1], other.report)) { 103 | this.report = (java.util.Map) data().deepCopy(fields()[1].schema(), other.report); 104 | fieldSetFlags()[1] = true; 105 | } 106 | } 107 | 108 | /** Gets the value of the 'workerId' field */ 109 | public com.cloudera.knittingboar.yarn.avro.generated.WorkerId getWorkerId() { 110 | return workerId; 111 | } 112 | 113 | /** Sets the value of the 'workerId' field */ 114 | public com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder setWorkerId(com.cloudera.knittingboar.yarn.avro.generated.WorkerId value) { 115 | validate(fields()[0], value); 116 | this.workerId = value; 117 | fieldSetFlags()[0] = true; 118 | return this; 119 | } 120 | 121 | /** Checks whether the 'workerId' field has been set */ 122 | public boolean hasWorkerId() { 123 | return fieldSetFlags()[0]; 124 | } 125 | 126 | /** Clears the value of the 'workerId' field */ 127 | public com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder clearWorkerId() { 128 | workerId = null; 129 | fieldSetFlags()[0] = false; 130 | return this; 131 | } 132 | 133 | /** Gets the value of the 'report' field */ 134 | public java.util.Map getReport() { 135 | return report; 136 | } 137 | 138 | /** Sets the value of the 'report' field */ 139 | public com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder setReport(java.util.Map value) { 140 | validate(fields()[1], value); 141 | this.report = value; 142 | fieldSetFlags()[1] = true; 143 | return this; 144 | } 145 | 146 | /** Checks whether the 'report' field has been set */ 147 | public boolean hasReport() { 148 | return fieldSetFlags()[1]; 149 | } 150 | 151 | /** Clears the value of the 'report' field */ 152 | public com.cloudera.knittingboar.yarn.avro.generated.ProgressReport.Builder clearReport() { 153 | report = null; 154 | fieldSetFlags()[1] = false; 155 | return this; 156 | } 157 | 158 | @Override 159 | public ProgressReport build() { 160 | try { 161 | ProgressReport record = new ProgressReport(); 162 | record.workerId = fieldSetFlags()[0] ? this.workerId : (com.cloudera.knittingboar.yarn.avro.generated.WorkerId) defaultValue(fields()[0]); 163 | record.report = fieldSetFlags()[1] ? 
this.report : (java.util.Map) defaultValue(fields()[1]); 164 | return record; 165 | } catch (Exception e) { 166 | throw new org.apache.avro.AvroRuntimeException(e); 167 | } 168 | } 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/sgd/TestBaseSGD.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.cloudera.knittingboar.sgd; 19 | 20 | import java.io.BufferedReader; 21 | import java.io.File; 22 | import java.io.FileInputStream; 23 | import java.io.IOException; 24 | import java.io.InputStream; 25 | import java.io.InputStreamReader; 26 | 27 | import org.apache.hadoop.conf.Configuration; 28 | import org.apache.hadoop.fs.FileSystem; 29 | import org.apache.hadoop.fs.Path; 30 | import org.apache.hadoop.mapred.FileInputFormat; 31 | import org.apache.hadoop.mapred.InputSplit; 32 | import org.apache.hadoop.mapred.JobConf; 33 | import org.apache.hadoop.mapred.TextInputFormat; 34 | 35 | import junit.framework.TestCase; 36 | 37 | import com.cloudera.iterativereduce.io.TextRecordParser; 38 | import com.cloudera.knittingboar.io.InputRecordsSplit; 39 | import com.cloudera.knittingboar.records.RecordFactory; 40 | import com.cloudera.knittingboar.sgd.iterativereduce.POLRWorkerNode; 41 | import com.google.common.base.Charsets; 42 | import com.google.common.collect.Sets; 43 | import com.google.common.io.Resources; 44 | 45 | public class TestBaseSGD extends TestCase { 46 | 47 | 48 | 49 | private static JobConf defaultConf = new JobConf(); 50 | private static FileSystem localFs = null; 51 | static { 52 | try { 53 | defaultConf.set("fs.defaultFS", "file:///"); 54 | localFs = FileSystem.getLocal(defaultConf); 55 | } catch (IOException e) { 56 | throw new RuntimeException("init failure", e); 57 | } 58 | } 59 | 60 | //private static Path workDir = new Path(System.getProperty("/Users/jpatterson/Documents/workspace/WovenWabbit/data/donut_no_header.csv")); 61 | private static Path workDir = new Path(System.getProperty("test.build.data", "src/test/resources/donut_no_header.csv")); 62 | 63 | 64 | public Configuration generateDebugConfigurationObject() { 65 | 66 | Configuration c = new Configuration(); 67 | 68 | // feature vector size 69 | c.setInt( "com.cloudera.knittingboar.setup.FeatureVectorSize", 20 ); 70 | 71 | c.setInt( "com.cloudera.knittingboar.setup.numCategories", 2); 72 | 73 | //c.setInt("com.cloudera.knittingboar.setup.BatchSize", 10); 74 | 75 | 76 | 77 | c.set( "com.cloudera.knittingboar.setup.RecordFactoryClassname", RecordFactory.CSV_RECORDFACTORY); 78 | 79 | 80 | // local input split path 81 | //c.set( 
"com.cloudera.knittingboar.setup.LocalInputSplitPath", "hdfs://127.0.0.1/input/0" ); 82 | 83 | // predictor label names 84 | c.set( "com.cloudera.knittingboar.setup.PredictorLabelNames", "x,y" ); 85 | 86 | // predictor var types 87 | c.set( "com.cloudera.knittingboar.setup.PredictorVariableTypes", "numeric,numeric" ); 88 | 89 | // target variables 90 | c.set( "com.cloudera.knittingboar.setup.TargetVariableName", "color" ); 91 | 92 | // column header names 93 | c.set( "com.cloudera.knittingboar.setup.ColumnHeaderNames", "x,y,shape,color,k,k0,xx,xy,yy,a,b,c,bias" ); 94 | //c.set( "com.cloudera.knittingboar.setup.ColumnHeaderNames", "\"x\",\"y\",\"shape\",\"color\",\"k\",\"k0\",\"xx\",\"xy\",\"yy\",\"a\",\"b\",\"c\",\"bias\"\n" ); 95 | 96 | return c; 97 | 98 | } 99 | 100 | 101 | 102 | 103 | public InputSplit[] generateDebugSplits( Path input_path, JobConf job ) { 104 | 105 | 106 | long block_size = localFs.getDefaultBlockSize(); 107 | 108 | System.out.println("default block size: " + (block_size / 1024 / 1024) + "MB"); 109 | 110 | // ---- set where we'll read the input files from ------------- 111 | FileInputFormat.setInputPaths(job, input_path); 112 | 113 | 114 | // try splitting the file in a variety of sizes 115 | TextInputFormat format = new TextInputFormat(); 116 | format.configure(job); 117 | 118 | int numSplits = 1; 119 | 120 | InputSplit[] splits = null; 121 | 122 | try { 123 | splits = format.getSplits(job, numSplits); 124 | } catch (IOException e) { 125 | // TODO Auto-generated catch block 126 | e.printStackTrace(); 127 | } 128 | 129 | 130 | return splits; 131 | 132 | 133 | } 134 | 135 | public static BufferedReader open(String inputFile) throws IOException { 136 | InputStream in; 137 | try { 138 | in = Resources.getResource(inputFile).openStream(); 139 | } catch (IllegalArgumentException e) { 140 | in = new FileInputStream(new File(inputFile)); 141 | } 142 | return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8)); 143 | } 144 | 145 | public void testConfigStuff() { 146 | 147 | String lambda = "1.0e-4"; 148 | double parsed_lambda = Double.parseDouble(lambda); 149 | 150 | System.out.println( "Parsed Double: " + parsed_lambda ); 151 | 152 | 153 | 154 | 155 | } 156 | 157 | public void testTrainer() throws Exception { 158 | 159 | 160 | //POLRWorkerDriver olr_run = new POLRWorkerDriver(); 161 | 162 | POLRWorkerNode olr_run = new POLRWorkerNode(); 163 | olr_run.setup(this.generateDebugConfigurationObject()); 164 | 165 | // generate the debug conf ---- normally setup by YARN stuff 166 | //olr_run.setConf(this.generateDebugConfigurationObject()); 167 | 168 | // ---- this all needs to be done in 169 | JobConf job = new JobConf(defaultConf); 170 | 171 | InputSplit[] splits = generateDebugSplits(workDir, job); 172 | 173 | InputRecordsSplit custom_reader = new InputRecordsSplit(job, splits[0]); 174 | 175 | // TODO: set this up to run through the conf pathways 176 | //olr_run.setupInputSplit(custom_reader); 177 | TextRecordParser txt_reader = new TextRecordParser(); 178 | 179 | long len = Integer.parseInt(splits[0].toString().split(":")[2] 180 | .split("\\+")[1]); 181 | 182 | txt_reader.setFile(splits[0].toString().split(":")[1], 0, len); 183 | 184 | olr_run.setRecordParser(txt_reader); 185 | 186 | //olr_run.s 187 | 188 | //olr_run.LoadConfigVarsLocally(); 189 | 190 | //olr_run.Setup(); 191 | 192 | 193 | for ( int x = 0; x < 5; x++) { 194 | 195 | olr_run.compute(); 196 | olr_run.IncrementIteration(); 197 | System.out.println( "---------- cycle " + x + " done ------------- " ); 198 | 
199 | } // for 200 | 201 | //olr_run.PrintModelStats(); 202 | 203 | 204 | 205 | //LogisticModelParameters lmp = model_builder.lmp;//TrainLogistic.getParameters(); 206 | assertEquals(1.0e-4, olr_run.polr_modelparams.getLambda(), 1.0e-9); 207 | assertEquals(20, olr_run.polr_modelparams.getNumFeatures()); 208 | assertTrue(olr_run.polr_modelparams.useBias()); 209 | assertEquals("color", olr_run.polr_modelparams.getTargetVariable()); 210 | //CsvRecordFactory csv = model_builder.lmp.getCsvRecordFactory(); 211 | // assertEquals("[1, 2]", Sets.newTreeSet(olr_run.csvVectorFactory.getTargetCategories()).toString()); 212 | // assertEquals("[Intercept Term, x, y]", Sets.newTreeSet(olr_run.csvVectorFactory.getPredictors()).toString()); 213 | 214 | 215 | 216 | System.out.println("done!"); 217 | 218 | 219 | 220 | 221 | } 222 | 223 | } 224 | -------------------------------------------------------------------------------- /src/main/java/com/cloudera/knittingboar/yarn/avro/generated/FileSplit.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Autogenerated by Avro 3 | * 4 | * DO NOT EDIT DIRECTLY 5 | */ 6 | package com.cloudera.knittingboar.yarn.avro.generated; 7 | @SuppressWarnings("all") 8 | public class FileSplit extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { 9 | public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"FileSplit\",\"namespace\":\"com.cloudera.knittingboar.yarn.avro.generated\",\"fields\":[{\"name\":\"path\",\"type\":\"string\"},{\"name\":\"offset\",\"type\":\"long\"},{\"name\":\"length\",\"type\":\"long\"}]}"); 10 | @Deprecated public java.lang.CharSequence path; 11 | @Deprecated public long offset; 12 | @Deprecated public long length; 13 | public org.apache.avro.Schema getSchema() { return SCHEMA$; } 14 | // Used by DatumWriter. Applications should not call. 15 | public java.lang.Object get(int field$) { 16 | switch (field$) { 17 | case 0: return path; 18 | case 1: return offset; 19 | case 2: return length; 20 | default: throw new org.apache.avro.AvroRuntimeException("Bad index"); 21 | } 22 | } 23 | // Used by DatumReader. Applications should not call. 24 | @SuppressWarnings(value="unchecked") 25 | public void put(int field$, java.lang.Object value$) { 26 | switch (field$) { 27 | case 0: path = (java.lang.CharSequence)value$; break; 28 | case 1: offset = (java.lang.Long)value$; break; 29 | case 2: length = (java.lang.Long)value$; break; 30 | default: throw new org.apache.avro.AvroRuntimeException("Bad index"); 31 | } 32 | } 33 | 34 | /** 35 | * Gets the value of the 'path' field. 36 | */ 37 | public java.lang.CharSequence getPath() { 38 | return path; 39 | } 40 | 41 | /** 42 | * Sets the value of the 'path' field. 43 | * @param value the value to set. 44 | */ 45 | public void setPath(java.lang.CharSequence value) { 46 | this.path = value; 47 | } 48 | 49 | /** 50 | * Gets the value of the 'offset' field. 51 | */ 52 | public java.lang.Long getOffset() { 53 | return offset; 54 | } 55 | 56 | /** 57 | * Sets the value of the 'offset' field. 58 | * @param value the value to set. 59 | */ 60 | public void setOffset(java.lang.Long value) { 61 | this.offset = value; 62 | } 63 | 64 | /** 65 | * Gets the value of the 'length' field. 66 | */ 67 | public java.lang.Long getLength() { 68 | return length; 69 | } 70 | 71 | /** 72 | * Sets the value of the 'length' field. 73 | * @param value the value to set. 
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/yarn/avro/generated/FileSplit.java:
--------------------------------------------------------------------------------
1 | /**
2 |  * Autogenerated by Avro
3 |  * 
4 |  * DO NOT EDIT DIRECTLY
5 |  */
6 | package com.cloudera.knittingboar.yarn.avro.generated;
7 | @SuppressWarnings("all")
8 | public class FileSplit extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord {
9 |   public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"FileSplit\",\"namespace\":\"com.cloudera.knittingboar.yarn.avro.generated\",\"fields\":[{\"name\":\"path\",\"type\":\"string\"},{\"name\":\"offset\",\"type\":\"long\"},{\"name\":\"length\",\"type\":\"long\"}]}");
10 |   @Deprecated public java.lang.CharSequence path;
11 |   @Deprecated public long offset;
12 |   @Deprecated public long length;
13 |   public org.apache.avro.Schema getSchema() { return SCHEMA$; }
14 |   // Used by DatumWriter.  Applications should not call.
15 |   public java.lang.Object get(int field$) {
16 |     switch (field$) {
17 |     case 0: return path;
18 |     case 1: return offset;
19 |     case 2: return length;
20 |     default: throw new org.apache.avro.AvroRuntimeException("Bad index");
21 |     }
22 |   }
23 |   // Used by DatumReader.  Applications should not call.
24 |   @SuppressWarnings(value="unchecked")
25 |   public void put(int field$, java.lang.Object value$) {
26 |     switch (field$) {
27 |     case 0: path = (java.lang.CharSequence)value$; break;
28 |     case 1: offset = (java.lang.Long)value$; break;
29 |     case 2: length = (java.lang.Long)value$; break;
30 |     default: throw new org.apache.avro.AvroRuntimeException("Bad index");
31 |     }
32 |   }
33 | 
34 |   /**
35 |    * Gets the value of the 'path' field.
36 |    */
37 |   public java.lang.CharSequence getPath() {
38 |     return path;
39 |   }
40 | 
41 |   /**
42 |    * Sets the value of the 'path' field.
43 |    * @param value the value to set.
44 |    */
45 |   public void setPath(java.lang.CharSequence value) {
46 |     this.path = value;
47 |   }
48 | 
49 |   /**
50 |    * Gets the value of the 'offset' field.
51 |    */
52 |   public java.lang.Long getOffset() {
53 |     return offset;
54 |   }
55 | 
56 |   /**
57 |    * Sets the value of the 'offset' field.
58 |    * @param value the value to set.
59 |    */
60 |   public void setOffset(java.lang.Long value) {
61 |     this.offset = value;
62 |   }
63 | 
64 |   /**
65 |    * Gets the value of the 'length' field.
66 |    */
67 |   public java.lang.Long getLength() {
68 |     return length;
69 |   }
70 | 
71 |   /**
72 |    * Sets the value of the 'length' field.
73 |    * @param value the value to set.
74 |    */
75 |   public void setLength(java.lang.Long value) {
76 |     this.length = value;
77 |   }
78 | 
79 |   /** Creates a new FileSplit RecordBuilder */
80 |   public static com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder newBuilder() {
81 |     return new com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder();
82 |   }
83 | 
84 |   /** Creates a new FileSplit RecordBuilder by copying an existing Builder */
85 |   public static com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder newBuilder(com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder other) {
86 |     return new com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder(other);
87 |   }
88 | 
89 |   /** Creates a new FileSplit RecordBuilder by copying an existing FileSplit instance */
90 |   public static com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder newBuilder(com.cloudera.knittingboar.yarn.avro.generated.FileSplit other) {
91 |     return new com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder(other);
92 |   }
93 | 
94 |   /**
95 |    * RecordBuilder for FileSplit instances.
96 |    */
97 |   public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase<FileSplit>
98 |     implements org.apache.avro.data.RecordBuilder<FileSplit> {
99 | 
100 |     private java.lang.CharSequence path;
101 |     private long offset;
102 |     private long length;
103 | 
104 |     /** Creates a new Builder */
105 |     private Builder() {
106 |       super(com.cloudera.knittingboar.yarn.avro.generated.FileSplit.SCHEMA$);
107 |     }
108 | 
109 |     /** Creates a Builder by copying an existing Builder */
110 |     private Builder(com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder other) {
111 |       super(other);
112 |     }
113 | 
114 |     /** Creates a Builder by copying an existing FileSplit instance */
115 |     private Builder(com.cloudera.knittingboar.yarn.avro.generated.FileSplit other) {
116 |       super(com.cloudera.knittingboar.yarn.avro.generated.FileSplit.SCHEMA$);
117 |       if (isValidValue(fields()[0], other.path)) {
118 |         this.path = (java.lang.CharSequence) data().deepCopy(fields()[0].schema(), other.path);
119 |         fieldSetFlags()[0] = true;
120 |       }
121 |       if (isValidValue(fields()[1], other.offset)) {
122 |         this.offset = (java.lang.Long) data().deepCopy(fields()[1].schema(), other.offset);
123 |         fieldSetFlags()[1] = true;
124 |       }
125 |       if (isValidValue(fields()[2], other.length)) {
126 |         this.length = (java.lang.Long) data().deepCopy(fields()[2].schema(), other.length);
127 |         fieldSetFlags()[2] = true;
128 |       }
129 |     }
130 | 
131 |     /** Gets the value of the 'path' field */
132 |     public java.lang.CharSequence getPath() {
133 |       return path;
134 |     }
135 | 
136 |     /** Sets the value of the 'path' field */
137 |     public com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder setPath(java.lang.CharSequence value) {
138 |       validate(fields()[0], value);
139 |       this.path = value;
140 |       fieldSetFlags()[0] = true;
141 |       return this;
142 |     }
143 | 
144 |     /** Checks whether the 'path' field has been set */
145 |     public boolean hasPath() {
146 |       return fieldSetFlags()[0];
147 |     }
148 | 
149 |     /** Clears the value of the 'path' field */
150 |     public com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder clearPath() {
151 |       path = null;
152 |       fieldSetFlags()[0] = false;
153 |       return this;
154 |     }
155 | 
156 |     /** Gets the value of the 'offset' field */
157 |     public java.lang.Long getOffset() {
158 |       return offset;
159 |     }
160 | 
161 |     /** Sets the value of the 'offset' field */
162 |     public com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder setOffset(long value) {
163 |       validate(fields()[1], value);
164 |       this.offset = value;
165 |       fieldSetFlags()[1] = true;
166 |       return this;
167 |     }
168 | 
169 |     /** Checks whether the 'offset' field has been set */
170 |     public boolean hasOffset() {
171 |       return fieldSetFlags()[1];
172 |     }
173 | 
174 |     /** Clears the value of the 'offset' field */
175 |     public com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder clearOffset() {
176 |       fieldSetFlags()[1] = false;
177 |       return this;
178 |     }
179 | 
180 |     /** Gets the value of the 'length' field */
181 |     public java.lang.Long getLength() {
182 |       return length;
183 |     }
184 | 
185 |     /** Sets the value of the 'length' field */
186 |     public com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder setLength(long value) {
187 |       validate(fields()[2], value);
188 |       this.length = value;
189 |       fieldSetFlags()[2] = true;
190 |       return this;
191 |     }
192 | 
193 |     /** Checks whether the 'length' field has been set */
194 |     public boolean hasLength() {
195 |       return fieldSetFlags()[2];
196 |     }
197 | 
198 |     /** Clears the value of the 'length' field */
199 |     public com.cloudera.knittingboar.yarn.avro.generated.FileSplit.Builder clearLength() {
200 |       fieldSetFlags()[2] = false;
201 |       return this;
202 |     }
203 | 
204 |     @Override
205 |     public FileSplit build() {
206 |       try {
207 |         FileSplit record = new FileSplit();
208 |         record.path = fieldSetFlags()[0] ? this.path : (java.lang.CharSequence) defaultValue(fields()[0]);
209 |         record.offset = fieldSetFlags()[1] ? this.offset : (java.lang.Long) defaultValue(fields()[1]);
210 |         record.length = fieldSetFlags()[2] ? this.length : (java.lang.Long) defaultValue(fields()[2]);
211 |         return record;
212 |       } catch (Exception e) {
213 |         throw new org.apache.avro.AvroRuntimeException(e);
214 |       }
215 |     }
216 |   }
217 | }
218 | 
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/utils/Utils.java:
--------------------------------------------------------------------------------
1 | /**
2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
3 |  * contributor license agreements.  See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright ownership.
5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
6 |  * (the "License"); you may not use this file except in compliance with
7 |  * the License.  You may obtain a copy of the License at
8 |  *
9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.utils;
19 | 
20 | import java.io.File;
21 | import java.io.FileInputStream;
22 | import java.io.FileNotFoundException;
23 | import java.io.FileOutputStream;
24 | import java.io.IOException;
25 | import java.io.InputStream;
26 | import java.io.OutputStream;
27 | import java.util.Iterator;
28 | import java.util.LinkedList;
29 | import java.util.List;
30 | import java.util.zip.GZIPInputStream;
31 | 
32 | import org.apache.commons.compress.archivers.ArchiveException;
33 | import org.apache.commons.compress.archivers.ArchiveStreamFactory;
34 | import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
35 | import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
36 | import org.apache.commons.compress.utils.IOUtils;
37 | import org.apache.mahout.math.Vector;
38 | 
39 | public class Utils {
40 | 
41 |   public static void UnTarAndZipGZFile(final File inputFile,
42 |       final File outputDir) throws FileNotFoundException, IOException,
43 |       ArchiveException {
44 | 
45 |     System.out.println( "Path: " + inputFile.getParent() );
46 | 
47 |     // unGzip(inputFile, new File(inputFile.getParent()));
48 | 
49 |     String new_path = inputFile.getPath().replaceFirst("\\.gz$", "");
50 | 
51 |     System.out.println("Tar File: " + new_path);
52 | 
53 |     unTar(new File(new_path), outputDir);
54 | 
55 |     // now cleanup tmp .tar file
56 | 
57 |   }
58 | 
59 |   /**
60 |    * Untar an input file into an output directory.
61 |    *
62 |    * The output files are created in the output folder, having the same names as
63 |    * the entries in the input file, minus the '.tar' extension.
64 |    *
65 |    * @param inputFile
66 |    *          the input .tar file
67 |    * @param outputDir
68 |    *          the output directory file.
69 |    * @throws IOException
70 |    * @throws FileNotFoundException
71 |    *
72 |    * @return The {@link List} of {@link File}s with the untarred content.
73 |    * @throws ArchiveException
74 |    */
75 |   private static List<File> unTar(final File inputFile, final File outputDir)
76 |       throws FileNotFoundException, IOException, ArchiveException {
77 | 
78 |     System.out.println(String.format("Untarring %s to dir %s.", inputFile
79 |         .getAbsolutePath(), outputDir.getAbsolutePath()));
80 | 
81 |     final List<File> untaredFiles = new LinkedList<File>();
82 |     final InputStream is = new FileInputStream(inputFile);
83 |     final TarArchiveInputStream debInputStream = (TarArchiveInputStream) new ArchiveStreamFactory()
84 |         .createArchiveInputStream("tar", is);
85 |     TarArchiveEntry entry = null;
86 |     while ((entry = (TarArchiveEntry) debInputStream.getNextEntry()) != null) {
87 |       final File outputFile = new File(outputDir, entry.getName());
88 |       if (entry.isDirectory()) {
89 |         System.out.println(String.format(
90 |             "Attempting to write output directory %s.", outputFile
91 |                 .getAbsolutePath()));
92 |         if (!outputFile.exists()) {
93 |           System.out.println(String.format(
94 |               "Attempting to create output directory %s.", outputFile
95 |                   .getAbsolutePath()));
96 |           if (!outputFile.mkdirs()) {
97 |             throw new IllegalStateException(String.format(
98 |                 "Couldn't create directory %s.", outputFile.getAbsolutePath()));
99 |           }
100 |         }
101 |       } else {
102 |         System.out.println(String.format("Creating output file %s.", outputFile
103 |             .getAbsolutePath()));
104 |         final OutputStream outputFileStream = new FileOutputStream(outputFile);
105 |         IOUtils.copy(debInputStream, outputFileStream);
106 |         outputFileStream.close();
107 |       }
108 |       untaredFiles.add(outputFile);
109 |     }
110 |     debInputStream.close();
111 | 
112 |     return untaredFiles;
113 |   }
114 | 
115 |   /**
116 |    * Ungzip an input file into an output file.
117 |    *
118 |    * The output file is created in the output folder, having the same name as
119 |    * the input file, minus the '.gz' extension.
120 |    *
121 |    * @param inputFile
122 |    *          the input .gz file
123 |    * @param outputDir
124 |    *          the output directory file.
125 |    * @throws IOException
126 |    * @throws FileNotFoundException
127 |    *
128 |    * @return The {@link File} with the ungzipped content.
129 |    */
130 |   private static File unGzip(final File inputFile, final File outputDir)
131 |       throws FileNotFoundException, IOException {
132 | 
133 |     System.out.println(String.format("Ungzipping %s to dir %s.", inputFile
134 |         .getAbsolutePath(), outputDir.getAbsolutePath()));
135 | 
136 |     final File outputFile = new File(outputDir, inputFile.getName().substring(
137 |         0, inputFile.getName().length() - 3));
138 | 
139 |     final GZIPInputStream in = new GZIPInputStream(new FileInputStream(
140 |         inputFile));
141 |     final FileOutputStream out = new FileOutputStream(outputFile);
142 | 
143 |     for (int c = in.read(); c != -1; c = in.read()) {
144 |       out.write(c);
145 |     }
146 | 
147 |     in.close();
148 |     out.close();
149 | 
150 |     return outputFile;
151 |   }
152 | 
153 |   public static void PrintVector(Vector v) {
154 | 
155 |     boolean first = true;
156 |     Iterator<Vector.Element> nonZeros = v.iterator();
157 |     while (nonZeros.hasNext()) {
158 |       Vector.Element vec_loc = nonZeros.next();
159 | 
160 |       if (!first) {
161 |         System.out.print(",");
162 |       } else {
163 |         first = false;
164 |       }
165 | 
166 |       System.out.print(" " + vec_loc.get());
167 | 
168 |     }
169 | 
170 |     System.out.println("");
171 | 
172 |   }
173 | 
174 |   public static void PrintVectorSection(Vector v, int num) {
175 | 
176 |     boolean first = true;
177 |     Iterator<Vector.Element> nonZeros = v.iterator();
178 |     int cnt = 0;
179 | 
180 |     while (nonZeros.hasNext()) {
181 |       Vector.Element vec_loc = nonZeros.next();
182 | 
183 |       if (!first) {
184 |         System.out.print(",");
185 |       } else {
186 |         first = false;
187 |       }
188 | 
189 |       System.out.print(" " + vec_loc.get());
190 |       if (cnt > num) {
191 |         break;
192 |       }
193 |       cnt++;
194 |     }
195 | 
196 |     System.out.println(" ######## ");
197 | 
198 |   }
199 | 
200 |   public static void PrintVectorNonZero(Vector v) {
201 | 
202 |     boolean first = true;
203 |     Iterator<Vector.Element> nonZeros = v.iterateNonZero();
204 |     while (nonZeros.hasNext()) {
205 |       Vector.Element vec_loc = nonZeros.next();
206 | 
207 |       if (!first) {
208 |         System.out.print(",");
209 |       } else {
210 |         first = false;
211 |       }
212 |       System.out.print(" " + vec_loc.get());
213 | 
214 |     }
215 | 
216 |     System.out.println("");
217 | 
218 |   }
219 | 
220 |   public static void PrintVectorSectionNonZero(Vector v, int size) {
221 | 
222 |     boolean first = true;
223 |     Iterator<Vector.Element> nonZeros = v.iterateNonZero();
224 | 
225 |     int cnt = 0;
226 | 
227 |     while (nonZeros.hasNext()) {
228 |       Vector.Element vec_loc = nonZeros.next();
229 | 
230 |       if (!first) {
231 |         System.out.print(",");
232 |       } else {
233 |         first = false;
234 |       }
235 |       System.out.print(" " + vec_loc.get());
236 | 
237 |       if (cnt > size) {
238 |         break;
239 |       }
240 |       cnt++;
241 |     }
242 | 
243 |     System.out.println("");
244 | 
245 |   }
246 | 
247 | }
248 | 
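For reference, a minimal driver for the helpers above. The archive and output paths are hypothetical, and note that because the unGzip() call inside UnTarAndZipGZFile() is commented out, the caller must already have the gunzipped .tar sitting next to the .gz:

    import java.io.File;

    import com.cloudera.knittingboar.utils.Utils;

    public class UtilsUsageSketch {
      public static void main(String[] args) throws Exception {
        File archive = new File("/tmp/20news-bydate.tar.gz"); // hypothetical path
        File outputDir = new File("/tmp/20news-extracted");   // hypothetical path
        outputDir.mkdirs();
        // Strips ".gz" from the name and untars /tmp/20news-bydate.tar, which
        // must already exist since the unGzip() step inside is disabled.
        Utils.UnTarAndZipGZFile(archive, outputDir);
      }
    }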
--------------------------------------------------------------------------------
/src/main/java/com/cloudera/knittingboar/records/TwentyNewsgroupsRecordFactory.java:
--------------------------------------------------------------------------------
1 | /**
2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
3 |  * contributor license agreements.  See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright ownership.
5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
6 |  * (the "License"); you may not use this file except in compliance with
7 |  * the License.  You may obtain a copy of the License at
8 |  *
9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.records;
19 | 
20 | import java.io.BufferedReader;
21 | import java.io.File;
22 | import java.io.FileReader;
23 | import java.io.IOException;
24 | import java.io.Reader;
25 | import java.io.StringReader;
26 | import java.util.ArrayList;
27 | import java.util.Arrays;
28 | import java.util.Collection;
29 | import java.util.Collections;
30 | import java.util.List;
31 | import java.util.Map;
32 | import java.util.Set;
33 | import java.util.TreeMap;
34 | 
35 | import org.apache.lucene.analysis.Analyzer;
36 | import org.apache.lucene.analysis.TokenStream;
37 | import org.apache.lucene.analysis.standard.StandardAnalyzer;
38 | import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
39 | import org.apache.lucene.util.Version;
40 | 
41 | import org.apache.mahout.math.RandomAccessSparseVector;
42 | import org.apache.mahout.math.Vector;
43 | import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
44 | import org.apache.mahout.vectorizer.encoders.Dictionary;
45 | import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
46 | import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
47 | 
48 | import com.google.common.base.Splitter;
49 | import com.google.common.collect.ConcurrentHashMultiset;
50 | import com.google.common.collect.HashMultiset;
51 | import com.google.common.collect.Iterables;
52 | import com.google.common.collect.Multiset;
53 | 
54 | /**
55 |  * Adapted from:
56 |  *
57 |  * https://github.com/tdunning/MiA/blob/master/src/main/java/mia/classifier/ch14
58 |  * /TrainNewsGroups.java
59 |  *
60 |  *
61 |  * The newsgroup class IDs are hardcoded in this record factory, since they
62 |  * do not change for this dataset.
63 |  *
64 |  * @author jpatterson
65 |  *
66 |  */
67 | public class TwentyNewsgroupsRecordFactory implements RecordFactory {
68 | 
69 | 
70 | 
71 |   public static final int FEATURES = 10000;
72 | 
73 |   Dictionary newsGroups = null;
74 | 
75 |   Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);
76 | 
77 |   String class_id_split_string = " ";
78 | 
79 |   public TwentyNewsgroupsRecordFactory(String strClassSeperator) {
80 | 
81 |     this.newsGroups = new Dictionary();
82 | 
83 |     newsGroups.intern("alt.atheism");
84 |     newsGroups.intern("comp.graphics");
85 |     newsGroups.intern("comp.os.ms-windows.misc");
86 |     newsGroups.intern("comp.sys.ibm.pc.hardware");
87 |     newsGroups.intern("comp.sys.mac.hardware");
88 |     newsGroups.intern("comp.windows.x");
89 |     newsGroups.intern("misc.forsale");
90 |     newsGroups.intern("rec.autos");
91 |     newsGroups.intern("rec.motorcycles");
92 |     newsGroups.intern("rec.sport.baseball");
93 | 
94 |     newsGroups.intern("rec.sport.hockey");
95 |     newsGroups.intern("sci.crypt");
96 |     newsGroups.intern("sci.electronics");
97 |     newsGroups.intern("sci.med");
98 |     newsGroups.intern("sci.space");
99 |     newsGroups.intern("soc.religion.christian");
newsGroups.intern("soc.religion.christian"); 100 | newsGroups.intern("talk.politics.guns"); 101 | newsGroups.intern("talk.politics.mideast"); 102 | newsGroups.intern("talk.politics.misc"); 103 | newsGroups.intern("talk.religion.misc"); 104 | 105 | this.class_id_split_string = strClassSeperator; 106 | 107 | } 108 | 109 | @Override 110 | public List getTargetCategories() { 111 | 112 | List out = new ArrayList(); 113 | 114 | for (int x = 0; x < this.newsGroups.size(); x++) { 115 | 116 | // System.out.println( x + "" + this.newsGroups.values().get(x) ); 117 | out.add(this.newsGroups.values().get(x)); 118 | 119 | } 120 | 121 | return out; 122 | 123 | } 124 | 125 | public int LookupIDForNewsgroupName(String name) { 126 | 127 | return this.newsGroups.values().indexOf(name); 128 | 129 | } 130 | 131 | public boolean ContainsIDForNewsgroupName(String name) { 132 | 133 | return this.newsGroups.values().contains(name); 134 | 135 | } 136 | 137 | public String GetNewsgroupNameByID(int id) { 138 | 139 | return this.newsGroups.values().get(id); 140 | 141 | } 142 | 143 | @Override 144 | public String GetClassnameByID(int id) { 145 | return this.newsGroups.values().get(id); 146 | } 147 | 148 | private static void countWords(Analyzer analyzer, Collection words, 149 | Reader in) throws IOException { 150 | 151 | // use the provided analyzer to tokenize the input stream 152 | TokenStream ts = analyzer.tokenStream("text", in); 153 | ts.addAttribute(CharTermAttribute.class); 154 | 155 | // for each word in the stream, minus non-word stuff, add word to collection 156 | while (ts.incrementToken()) { 157 | String s = ts.getAttribute(CharTermAttribute.class).toString(); 158 | words.add(s); 159 | } 160 | 161 | } 162 | 163 | /** 164 | * Processes single line of input into: - target variable - Feature vector 165 | * 166 | * @throws Exception 167 | */ 168 | public int processLine(String line, Vector v) throws Exception { 169 | 170 | String[] parts = line.split(this.class_id_split_string); 171 | if (parts.length < 2) { 172 | throw new Exception("wtf: line not formed well."); 173 | } 174 | 175 | String newsgroup_name = parts[0]; 176 | String msg = parts[1]; 177 | 178 | // p.269 --------------------------------------------------------- 179 | Map> traceDictionary = new TreeMap>(); 180 | 181 | // encodes the text content in both the subject and the body of the email 182 | FeatureVectorEncoder encoder = new StaticWordValueEncoder("body"); 183 | encoder.setProbes(2); 184 | encoder.setTraceDictionary(traceDictionary); 185 | 186 | // provides a constant offset that the model can use to encode the average 187 | // frequency 188 | // of each class 189 | FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); 190 | bias.setTraceDictionary(traceDictionary); 191 | 192 | int actual = newsGroups.intern(newsgroup_name); 193 | // newsGroups.values().contains(arg0) 194 | 195 | // System.out.println( "> newsgroup name: " + newsgroup_name ); 196 | // System.out.println( "> newsgroup id: " + actual ); 197 | 198 | Multiset words = ConcurrentHashMultiset.create(); 199 | /* 200 | * // System.out.println("record: "); for ( int x = 1; x < parts.length; x++ 201 | * ) { //String s = ts.getAttribute(CharTermAttribute.class).toString(); // 202 | * System.out.print( " " + parts[x] ); String foo = parts[x].trim(); 203 | * System.out.print( " " + foo ); words.add( foo ); 204 | * 205 | * } // System.out.println("\nEOR"); System.out.println( "\nwords found: " + 206 | * (parts.length - 1) ); System.out.println( "words in set: " + words.size() 207 
| * + ", " + words.toString() ); 208 | */ 209 | 210 | StringReader in = new StringReader(msg); 211 | 212 | countWords(analyzer, words, in); 213 | 214 | // ----- p.271 ----------- 215 | // Vector v = new RandomAccessSparseVector(FEATURES); 216 | 217 | // original value does nothing in a ContantValueEncoder 218 | bias.addToVector("", 1, v); 219 | 220 | // original value does nothing in a ContantValueEncoder 221 | // lines.addToVector("", lineCount / 30, v); 222 | 223 | // original value does nothing in a ContantValueEncoder 224 | // logLines.addToVector("", Math.log(lineCount + 1), v); 225 | 226 | // now scan through all the words and add them 227 | // System.out.println( "############### " + words.toArray().length); 228 | for (String word : words.elementSet()) { 229 | encoder.addToVector(word, Math.log(1 + words.count(word)), v); 230 | // System.out.print( words.count(word) + " " ); 231 | } 232 | 233 | // System.out.println("\nEOL\n"); 234 | 235 | return actual; 236 | } 237 | 238 | public void Debug() { 239 | 240 | System.out.println("DictionarySize: " + this.newsGroups.values().size()); 241 | 242 | } 243 | 244 | public void DebugDictionary() { 245 | 246 | for (int x = 0; x < this.newsGroups.size(); x++) { 247 | 248 | System.out.println(x + "" + this.newsGroups.values().get(x)); 249 | 250 | } 251 | 252 | } 253 | 254 | } 255 | -------------------------------------------------------------------------------- /src/test/java/com/cloudera/knittingboar/records/Test20NewsgroupsBookParsing.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
--------------------------------------------------------------------------------
/src/test/java/com/cloudera/knittingboar/records/Test20NewsgroupsBookParsing.java:
--------------------------------------------------------------------------------
1 | /**
2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
3 |  * contributor license agreements.  See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright ownership.
5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
6 |  * (the "License"); you may not use this file except in compliance with
7 |  * the License.  You may obtain a copy of the License at
8 |  *
9 |  *     http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.cloudera.knittingboar.records;
19 | 
20 | import java.io.BufferedReader;
21 | import java.io.File;
22 | import java.io.FileReader;
23 | import java.io.IOException;
24 | import java.io.Reader;
25 | import java.io.StringReader;
26 | import java.util.ArrayList;
27 | import java.util.Arrays;
28 | import java.util.Collection;
29 | import java.util.Collections;
30 | import java.util.List;
31 | import java.util.Map;
32 | import java.util.Set;
33 | import java.util.TreeMap;
34 | 
35 | import org.apache.hadoop.fs.Path;
36 | import org.apache.lucene.analysis.Analyzer;
37 | import org.apache.lucene.analysis.TokenStream;
38 | import org.apache.lucene.analysis.standard.StandardAnalyzer;
39 | import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
40 | import org.apache.lucene.util.Version;
41 | import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
42 | import org.apache.mahout.vectorizer.encoders.Dictionary;
43 | import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
44 | import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
45 | 
46 | import com.cloudera.knittingboar.utils.DataUtils;
47 | import com.google.common.base.Splitter;
48 | import com.google.common.collect.ConcurrentHashMultiset;
49 | import com.google.common.collect.HashMultiset;
50 | import com.google.common.collect.Iterables;
51 | import com.google.common.collect.Multiset;
52 | 
53 | import junit.framework.TestCase;
54 | 
55 | public class Test20NewsgroupsBookParsing extends TestCase {
56 | 
57 | 
58 | 
59 |   private static Path workDir20NewsLocal = new Path(new Path("/tmp"), "Dataset20Newsgroups");
60 |   private static File unzipDir = new File( workDir20NewsLocal + "/20news-bydate");
61 |   private static String strKBoarTestDirInput = "" + unzipDir.toString() + "/KBoar-test/";
62 | 
63 | 
64 | 
65 |   private static final int FEATURES = 10000;
66 |   private static Multiset<String> overallCounts;
67 | 
68 | 
69 |   /**
70 |    *
71 |    * Counts words
72 |    *
73 |    * @param analyzer
74 |    * @param words
75 |    * @param in
76 |    * @throws IOException
77 |    */
78 |   private static void countWords(Analyzer analyzer, Collection<String> words, Reader in) throws IOException {
79 | 
80 |     //System.out.println( "> ----- countWords ------" );
81 | 
82 |     // use the provided analyzer to tokenize the input stream
83 |     TokenStream ts = analyzer.tokenStream("text", in);
84 |     ts.addAttribute(CharTermAttribute.class);
85 | 
86 |     // for each word in the stream, minus non-word stuff, add word to collection
87 |     while (ts.incrementToken()) {
88 |       String s = ts.getAttribute(CharTermAttribute.class).toString();
89 |       //System.out.print( " " + s );
90 |       words.add(s);
91 |     }
92 | 
93 |     //System.out.println( "\n<" );
94 | 
95 |     /*overallCounts.addAll(words);*/
96 |   }
97 | 
98 | 
99 |   public void test20NewsgroupsFileScan() throws IOException {
100 | 
101 |     // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------
102 |     double averageLL = 0.0;
103 |     double averageCorrect = 0.0;
104 |     double averageLineCount = 0.0;
105 |     int k = 0;
106 |     double step = 0.0;
107 |     int[] bumps = new int[]{1, 2, 5};
108 |     double lineCount = 0;
109 | 
110 | 
111 |     Splitter onColon = Splitter.on(":").trimResults();
112 |     // last line on p.269
113 |     Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);
114 | 
115 |     File file20News = DataUtils.getTwentyNewsGroupDir();
116 | 
117 |     System.out.println( "written to: " + DataUtils.get20NewsgroupsLocalDataLocation() );
118 | 
119 |     File base = new File(DataUtils.get20NewsgroupsLocalDataLocation() + "/20news-bydate-train");
120 |     overallCounts = HashMultiset.create();
121 | 
122 | 
123 | 
124 |     // p.269 ---------------------------------------------------------
125 |     Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();
126 | 
127 |     // encodes the text content in both the subject and the body of the email
128 |     FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
129 |     encoder.setProbes(2);
130 |     encoder.setTraceDictionary(traceDictionary);
131 | 
132 |     // provides a constant offset that the model can use to encode the average frequency
133 |     // of each class
134 |     FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
135 |     bias.setTraceDictionary(traceDictionary);
136 | 
137 |     // used to encode the number of lines in a message
138 |     FeatureVectorEncoder lines = new ConstantValueEncoder("Lines");
139 |     lines.setTraceDictionary(traceDictionary);
140 |     Dictionary newsGroups = new Dictionary();
141 | 
142 | 
143 |     // bottom of p.269 ------------------------------
144 |     // because OLR expects to get integer class IDs for the target variable during training
145 |     // we need a dictionary to convert the target variable (the newsgroup name)
146 |     // to an integer, which is the newsGroup object
147 |     List<File> files = new ArrayList<File>();
148 |     for (File newsgroup : base.listFiles()) {
149 |       newsGroups.intern(newsgroup.getName());
150 |       System.out.println( ">> " + newsgroup.getName() );
151 |       files.addAll(Arrays.asList(newsgroup.listFiles()));
152 |     }
153 | 
154 |     // mix up the files, helps training in OLR
155 |     Collections.shuffle(files);
156 |     System.out.printf("%d training files\n", files.size());
157 | 
158 | 
159 | 
160 |     // ----- p.270 ------------ "reading and tokenizing the data" ---------
161 |     for (File file : files) {
162 |       BufferedReader reader = new BufferedReader(new FileReader(file));
163 | 
164 |       // identify newsgroup ----------------
165 |       // convert newsgroup name to unique id
166 |       // -----------------------------------
167 |       String ng = file.getParentFile().getName();
168 |       int actual = newsGroups.intern(ng);
169 |       Multiset<String> words = ConcurrentHashMultiset.create();
170 | 
171 |       // check for line count header -------
172 |       String line = reader.readLine();
173 |       while (line != null && line.length() > 0) {
174 | 
175 |         // if this is a line that has a line count, let's pull that value out ------
176 |         if (line.startsWith("Lines:")) {
177 |           String count = Iterables.get(onColon.split(line), 1);
178 |           try {
179 |             lineCount = Integer.parseInt(count);
180 |             averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000);
181 |           } catch (NumberFormatException e) {
182 |             // if anything goes wrong in parse: just use the avg count
183 |             lineCount = averageLineCount;
184 |           }
185 |         }
186 | 
187 | 
188 |         // which header words to actually count
189 |         boolean countHeader = (
190 |             line.startsWith("From:") || line.startsWith("Subject:")||
191 |             line.startsWith("Keywords:")|| line.startsWith("Summary:"));
192 | 
193 | 
194 |         // we're still looking at the header at this point
195 |         // loop through the lines in the file, while the line starts with: " "
196 |         do {
197 | 
198 |           // get a reader for this specific string ------
199 |           StringReader in = new StringReader(line);
200 | 
201 | 
202 | 
203 |           // ---- count words in header ---------
204 |           if (countHeader) {
205 |             //System.out.println( "#### countHeader ################*************" );
206 |             countWords(analyzer, words, in);
207 |           }
208 | 
209 |           // iterate to the next string ----
210 |           line = reader.readLine();
211 | 
212 | 
213 | 
214 |         } while (line != null && line.startsWith(" "));
215 | 
216 |         //System.out.println("[break]");
217 | 
218 |       }
219 | 
220 |       // now we're done with the header
221 | 
222 |       //System.out.println("[break-header]");
223 | 
224 |       // -------- count words in body ----------
225 |       countWords(analyzer, words, reader);
226 |       reader.close();
227 | 
228 |       /*
229 |       for (String word : words.elementSet()) {
230 |         //encoder.addToVector(word, Math.log(1 + words.count(word)), v);
231 |         System.out.println( "> " + word + ", " + words.count(word) );
232 |       }
233 |       */
234 |     }
235 | 
236 | 
237 |   }
238 | 
239 | }
240 | 
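The countWords()/header-scanning logic in the test above boils down to the standard Lucene 3.x tokenize-and-count loop. A self-contained sketch of just that loop (the sample text is made up, and StandardAnalyzer drops English stopwords, so only content words are counted):

    import java.io.StringReader;

    import org.apache.lucene.analysis.Analyzer;
    import org.apache.lucene.analysis.TokenStream;
    import org.apache.lucene.analysis.standard.StandardAnalyzer;
    import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
    import org.apache.lucene.util.Version;

    import com.google.common.collect.ConcurrentHashMultiset;
    import com.google.common.collect.Multiset;

    public class TokenCountSketch {
      public static void main(String[] args) throws Exception {
        Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);
        Multiset<String> words = ConcurrentHashMultiset.create();
        TokenStream ts = analyzer.tokenStream("text",
            new StringReader("the quick brown fox jumps over the lazy dog"));
        CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
        while (ts.incrementToken()) {   // one iteration per token in the stream
          words.add(term.toString());
        }
        System.out.println(words);      // word -> occurrence count
      }
    }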
"/Users/jpatterson/Downloads/datasets/20news-kboar/models/model_10_31pm.model" ); 61 | private static Path model20News = new Path( "/tmp/olr-news-group.model" ); 62 | 63 | //private static Path testData20News = new Path(System.getProperty("test.build.data", "/Users/jpatterson/Downloads/datasets/20news-kboar/test/")); 64 | 65 | private static final int FEATURES = 10000; 66 | 67 | private static JobConf defaultConf = new JobConf(); 68 | private static FileSystem localFs = null; 69 | static { 70 | try { 71 | defaultConf.set("fs.defaultFS", "file:///"); 72 | localFs = FileSystem.getLocal(defaultConf); 73 | } catch (IOException e) { 74 | throw new RuntimeException("init failure", e); 75 | } 76 | } 77 | 78 | POLRMetrics metrics = new POLRMetrics(); 79 | 80 | //double averageLL = 0.0; 81 | //double averageCorrect = 0.0; 82 | double averageLineCount = 0.0; 83 | int k = 0; 84 | double step = 0.0; 85 | int[] bumps = new int[]{1, 2, 5}; 86 | double lineCount = 0; 87 | 88 | private static Path workDir20NewsLocal = new Path(new Path("/tmp"), "Dataset20Newsgroups"); 89 | private static File unzipDir = new File( workDir20NewsLocal + "/20news-bydate"); 90 | private static String strKBoarTestDirInput = "" + unzipDir.toString() + "/KBoar-test/"; 91 | 92 | public Configuration generateDebugConfigurationObject() { 93 | 94 | Configuration c = new Configuration(); 95 | 96 | // feature vector size 97 | c.setInt( "com.cloudera.knittingboar.setup.FeatureVectorSize", 10000 ); 98 | 99 | c.setInt( "com.cloudera.knittingboar.setup.numCategories", 20); 100 | 101 | // setup 20newsgroups 102 | c.set( "com.cloudera.knittingboar.setup.RecordFactoryClassname", RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY); 103 | 104 | return c; 105 | 106 | } 107 | 108 | public InputSplit[] generateDebugSplits( Path input_path, JobConf job ) { 109 | 110 | 111 | long block_size = localFs.getDefaultBlockSize(); 112 | 113 | System.out.println("default block size: " + (block_size / 1024 / 1024) + "MB"); 114 | 115 | // ---- set where we'll read the input files from ------------- 116 | //FileInputFormat.setInputPaths(job, workDir); 117 | FileInputFormat.setInputPaths(job, input_path); 118 | 119 | 120 | // try splitting the file in a variety of sizes 121 | TextInputFormat format = new TextInputFormat(); 122 | format.configure(job); 123 | //LongWritable key = new LongWritable(); 124 | //Text value = new Text(); 125 | 126 | int numSplits = 1; 127 | 128 | InputSplit[] splits = null; 129 | 130 | try { 131 | splits = format.getSplits(job, numSplits); 132 | } catch (IOException e) { 133 | // TODO Auto-generated catch block 134 | e.printStackTrace(); 135 | } 136 | 137 | 138 | return splits; 139 | 140 | 141 | } 142 | 143 | 144 | 145 | 146 | 147 | 148 | public void testResults() throws Exception { 149 | 150 | File file20News = DataUtils.getTwentyNewsGroupDir(); 151 | DatasetConverter.ConvertNewsgroupsFromSingleFiles( DataUtils.get20NewsgroupsLocalDataLocation() + "/20news-bydate-test/", strKBoarTestDirInput, 6000); 152 | 153 | // File base = new File( file20News + "/20news-bydate-train/" ); 154 | 155 | System.out.println( "Testing on: " + strKBoarTestDirInput ); 156 | 157 | 158 | OnlineLogisticRegression classifier = ModelSerializer.readBinary(new FileInputStream(model20News.toString()), OnlineLogisticRegression.class); 159 | 160 | 161 | Text value = new Text(); 162 | long batch_vec_factory_time = 0; 163 | int k = 0; 164 | int num_correct = 0; 165 | 166 | 167 | 168 | // ---- this all needs to be done in 169 | JobConf job = new JobConf(defaultConf); 170 | 
171 |     // TODO: work on this, splits are generating for everything in dir
172 |     // InputSplit[] splits = generateDebugSplits(inputDir, job);
173 | 
174 |     Path strKBoarTestDirInputPath = new Path( strKBoarTestDirInput );
175 | 
176 |     //fullRCV1Dir
177 |     InputSplit[] splits = generateDebugSplits(strKBoarTestDirInputPath, job);
178 | 
179 |     System.out.println( "split count: " + splits.length );
180 | 
181 | 
182 | 
183 |     InputRecordsSplit custom_reader_0 = new InputRecordsSplit(job, splits[0]);
184 | 
185 |     TwentyNewsgroupsRecordFactory VectorFactory = new TwentyNewsgroupsRecordFactory("\t");
186 | 
187 | 
188 |     for (int x = 0; x < 8000; x++ ) {
189 | 
190 |       if ( custom_reader_0.next(value)) {
191 | 
192 |         long startTime = System.currentTimeMillis();
193 | 
194 |         Vector v = new RandomAccessSparseVector(FEATURES);
195 |         int actual = VectorFactory.processLine(value.toString(), v);
196 | 
197 |         long endTime = System.currentTimeMillis();
198 | 
199 |         //System.out.println("That took " + (endTime - startTime) + " milliseconds");
200 |         batch_vec_factory_time += (endTime - startTime);
201 | 
202 | 
203 |         String ng = VectorFactory.GetClassnameByID(actual); //.GetNewsgroupNameByID( actual );
204 | 
205 |         // calc stats ---------
206 | 
207 |         double mu = Math.min(k + 1, 200);
208 |         double ll = classifier.logLikelihood(actual, v);
209 |         //averageLL = averageLL + (ll - averageLL) / mu;
210 |         metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu;
211 | 
212 |         Vector p = new DenseVector(20);
213 |         classifier.classifyFull(p, v);
214 |         int estimated = p.maxValueIndex();
215 | 
216 |         int correct = (estimated == actual? 1 : 0);
217 |         if (estimated == actual) {
218 |           num_correct++;
219 |         }
220 |         //averageCorrect = averageCorrect + (correct - averageCorrect) / mu;
221 |         metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu;
222 | 
223 |         //this.polr.train(actual, v);
224 | 
225 | 
226 |         k++;
227 |         // if (x == this.BatchSize - 1) {
228 |         int bump = bumps[(int) Math.floor(step) % bumps.length];
229 |         int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
230 | 
231 |         if (k % (bump * scale) == 0) {
232 |           step += 0.25;
233 | 
234 |           System.out.printf("Worker %s:\t Tested Recs: %10d, numCorrect: %d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n",
235 |               "OLR-standard-test", k, num_correct, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100, batch_vec_factory_time);
236 | 
237 |         }
238 | 
239 | 
240 | 
241 |       } else {
242 | 
243 |         // nothing else to process in split!
244 |         break;
245 | 
246 |       } // if
247 | 
248 | 
249 |     } // for the number of passes in the run
250 | 
251 |     classifier.close();
252 | 
253 |   }
254 | 
255 | 
256 | }
257 | 
--------------------------------------------------------------------------------
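Finally, the scoring pattern used by testResults() above, reduced to a single record: deserialize the OLR model, encode one line, and take the arg-max over the 20 class scores. The model path and input line are hypothetical; the Mahout calls (ModelSerializer.readBinary, logLikelihood, classifyFull, maxValueIndex, close) are the same ones the test itself uses:

    import java.io.FileInputStream;

    import org.apache.mahout.classifier.sgd.ModelSerializer;
    import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.RandomAccessSparseVector;
    import org.apache.mahout.math.Vector;

    import com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory;

    public class ScoreOneRecordSketch {
      public static void main(String[] args) throws Exception {
        OnlineLogisticRegression model = ModelSerializer.readBinary(
            new FileInputStream("/tmp/olr-news-group.model"), OnlineLogisticRegression.class);

        TwentyNewsgroupsRecordFactory factory = new TwentyNewsgroupsRecordFactory("\t");
        Vector v = new RandomAccessSparseVector(10000); // same FEATURES size used in training
        int actual = factory.processLine("sci.space\tthe shuttle launch was scrubbed again", v);

        Vector scores = new DenseVector(20);  // one slot per newsgroup
        model.classifyFull(scores, v);
        int estimated = scores.maxValueIndex();

        System.out.println("actual=" + factory.GetClassnameByID(actual)
            + ", estimated=" + factory.GetClassnameByID(estimated));
        model.close();
      }
    }

Note that close() is called exactly once, after scoring is finished; closing the model inside the record loop, as the original testResults() did, would finalize the model's internal state on the first record.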