├── .gitignore
├── README
├── commands.txt
├── final-draft.pdf
├── pom.xml
└── src
    ├── main
    │   ├── java
    │   │   ├── collabstream
    │   │   │   ├── App.java
    │   │   │   ├── lc
    │   │   │   │   ├── CounterBolt.java
    │   │   │   │   ├── FileReaderSpout.java
    │   │   │   │   ├── LineCountTopology.java
    │   │   │   │   ├── MsgForwarder.java
    │   │   │   │   └── MsgTracker.java
    │   │   │   └── streaming
    │   │   │       ├── BlockPair.java
    │   │   │       ├── Configuration.java
    │   │   │       ├── Master.java
    │   │   │       ├── MatrixSerialization.java
    │   │   │       ├── MatrixStore.java
    │   │   │       ├── MatrixUtils.java
    │   │   │       ├── MsgType.java
    │   │   │       ├── PermutationUtils.java
    │   │   │       ├── RatingsSource.java
    │   │   │       ├── StreamingDSGD.java
    │   │   │       ├── TestPredictions.java
    │   │   │       ├── TrainingExample.java
    │   │   │       ├── Worker.java
    │   │   │       └── WorkingBlock.java
    │   │   └── comparison
    │   │       └── dsgd
    │   │           ├── DSGDMain.java
    │   │           ├── IntMatrixItemPair.java
    │   │           ├── LongPair.java
    │   │           ├── MatrixItem.java
    │   │           ├── MatrixUtils.java
    │   │           ├── mapper
    │   │           │   ├── DSGDFinalMapper.java
    │   │           │   ├── DSGDIntermediateMapper.java
    │   │           │   ├── DSGDOutputFactorsMapper.java
    │   │           │   ├── DSGDPreprocFactorMapper.java
    │   │           │   ├── DSGDPreprocMapper.java
    │   │           │   ├── DSGDPreprocRatingsMapper.java
    │   │           │   └── DSGDRmseMapper.java
    │   │           └── reducer
    │   │               ├── DSGDFinalReducer.java
    │   │               ├── DSGDIntermediateReducer.java
    │   │               ├── DSGDOutputFactorsReducer.java
    │   │               ├── DSGDPreprocFactorReducer.java
    │   │               ├── DSGDPreprocRatingsReducer.java
    │   │               ├── DSGDPreprocReducer.java
    │   │               └── DSGDRmseReducer.java
    │   └── resources
    │       ├── log4j.properties
    │       └── lotr.txt
    └── test
        └── java
            └── collabstream
                └── AppTest.java
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Compiled source #
 2 | ###################
 3 | *.com
 4 | *.class
 5 | *.dll
 6 | *.exe
 7 | *.o
 8 | *.so
 9 | 
10 | # Packages #
11 | ############
12 | # it's better to unpack these files and commit the raw source
13 | # git has its own built in compression methods
14 | *.7z
15 | *.dmg
16 | *.gz
17 | *.iso
18 | *.jar
19 | *.rar
20 | *.tar
21 | *.zip
22 | 
23 | # Logs and databases #
24 | ######################
25 | *.log
26 | *.sql
27 | *.sqlite
28 | 
29 | # OS generated files #
30 | ######################
31 | .DS_Store*
32 | ehthumbs.db
33 | Icon?
34 | Thumbs.db
35 | 
36 | # Directories
37 | classes/
38 | lib/
39 | data/
40 | 
41 | # Eclipse settings
42 | .settings/
43 | .project
44 | .classpath
45 | 
--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
 1 | Authors: Chris Johnson, Alex Tang, Muqeet Ali
 2 | 
 3 | CollabStream implements parallelized matrix factorization for streaming data using stochastic gradient descent. CollabStream utilizes the Storm parallelization framework developed by Nathan Marz, available at https://github.com/nathanmarz/storm. For more information on the algorithm used by CollabStream, please see the full report at https://github.com/MrChrisJohnson/CollabStream.
 4 | 
 5 | To run:
 6 | 1. Go to the directory where you cloned the project.
 7 | 2. mvn compile
 8 | 3. mvn package
 9 | 4. Take the jar with dependencies from the target folder (e.g., name.jar)
10 | 5. Run on a Hadoop cluster using:
11 |    hadoop jar name.jar comparison.dsgd.DSGDMain inputPath outputPath numReducers
12 |    where inputPath refers to the path of the training and test data and outputPath refers to the directory where the RMSE and the factor matrices are written.
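For example, with the pom settings in this repo the assembly plugin should place the jar in target/ as CollabStream-1.0-jar-with-dependencies.jar (the exact name depends on your pom), so a concrete run might look like:

   hadoop jar target/CollabStream-1.0-jar-with-dependencies.jar comparison.dsgd.DSGDMain input/ output/ 10

The input path, output path, and reducer count here are placeholders; substitute your own HDFS locations and cluster size.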
13 | 
14 | 
--------------------------------------------------------------------------------
/commands.txt:
--------------------------------------------------------------------------------
 1 | # Some useful commands
 2 | 
 3 | mvn exec:java -Dexec.mainClass=collabstream.lc.LineCountTopology -Dexec.args='local src/main/resources/lotr.txt' -Dexec.classpathScope=compile
 4 | mvn exec:java -Dexec.mainClass=collabstream.lc.LineCountTopology -Dexec.args='local src/main/resources/lotr.txt' -Dexec.classpathScope=compile | grep '########'
 5 | mvn compile exec:java -Dexec.mainClass=collabstream.lc.LineCountTopology -Dexec.args='local src/main/resources/lotr.txt' -Dexec.classpathScope=compile
 6 | mvn compile exec:java -Dexec.mainClass=collabstream.lc.LineCountTopology -Dexec.args='local src/main/resources/lotr.txt' -Dexec.classpathScope=compile | grep '########'
 7 | 
 8 | mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile
 9 | mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile | grep '########'
10 | mvn compile exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile
11 | mvn compile exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile | grep '########'
12 | 
13 | mvn exec:java -Dexec.mainClass=collabstream.streaming.TestPredictions -Dexec.args='4 5 3 data/input/predtest data/output/predtest.user data/output/predtest.item' -Dexec.classpathScope=compile
14 | 
15 | mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 6041 3953 data/input/MovieLens/ml_tr_rand.txt data/output/MovieLens/ml.user data/output/MovieLens/ml.item' -Dexec.classpathScope=compile
16 | 
17 | mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 6041 3953 data/input/MovieLens/ml_100k data/output/MovieLens/ml_100k.user data/output/MovieLens/ml_100k.item' -Dexec.classpathScope=compile
18 | mvn compile exec:java -Dexec.mainClass=collabstream.streaming.TestPredictions -Dexec.args='6041 3953 10 data/input/MovieLens/ml1m_te_rb.dat data/output/MovieLens/ml_100k.user data/output/MovieLens/ml_100k.item' -Dexec.classpathScope=compile
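# The StreamingDSGD runs above also read optional settings from data/collabstream.properties
# (see StreamingDSGD.java). A sample file, spelling out the defaults hard-coded in the source:
#
#   numLatent=10
#   numUserBlocks=10
#   numItemBlocks=10
#   userPenalty=0.1
#   itemPenalty=0.1
#   initialStepSize=0.1
#   maxTrainingIters=30
#   inputDelay=0
#   debug=false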
--------------------------------------------------------------------------------
/final-draft.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MrChrisJohnson/CollabStream/b54dd4d5f0a098bca63a8de9275417a2ec892b6d/final-draft.pdf
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
 1 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
 2 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
 3 |   <modelVersion>4.0.0</modelVersion>
 4 | 
 5 |   <groupId>collabstream</groupId>
 6 |   <artifactId>CollabStream</artifactId>
 7 |   <version>1.0</version>
 8 |   <packaging>jar</packaging>
 9 | 
10 |   <name>CollabStream</name>
11 |   <url>https://github.com/MrChrisJohnson/CollabStream</url>
12 | 
13 |   <properties>
14 |     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
15 |   </properties>
16 | 
17 |   <repositories>
18 |     <repository>
19 |       <id>clojars.org</id>
20 |       <url>http://clojars.org/repo</url>
21 |     </repository>
22 |   </repositories>
23 | 
24 |   <dependencies>
25 |     <dependency>
26 |       <groupId>junit</groupId>
27 |       <artifactId>junit</artifactId>
28 |       <version>3.8.1</version>
29 |       <scope>compile</scope>
30 |     </dependency>
31 |     <dependency>
32 |       <groupId>storm</groupId>
33 |       <artifactId>storm</artifactId>
34 |       <version>0.5.4</version>
35 |       <scope>provided</scope>
36 |     </dependency>
37 |     <dependency>
38 |       <groupId>log4j</groupId>
39 |       <artifactId>log4j</artifactId>
40 |       <version>1.2.14</version>
41 |     </dependency>
42 |     <dependency>
43 |       <groupId>org.apache.hadoop</groupId>
44 |       <artifactId>hadoop-core</artifactId>
45 |       <version>0.20.2</version>
46 |     </dependency>
47 |   </dependencies>
48 | 
49 |   <build>
50 |     <plugins>
51 |       <plugin>
52 |         <artifactId>maven-assembly-plugin</artifactId>
53 |         <version>2.2.1</version>
54 |         <configuration>
55 |           <descriptorRefs>
56 |             <descriptorRef>jar-with-dependencies</descriptorRef>
57 |           </descriptorRefs>
58 |         </configuration>
59 |         <executions>
60 |           <execution>
61 |             <id>make-assembly</id>
62 |             <phase>package</phase>
63 |             <goals>
64 |               <goal>single</goal>
65 |             </goals>
66 |           </execution>
67 |         </executions>
68 |       </plugin>
69 |     </plugins>
70 |   </build>
71 | </project>
--------------------------------------------------------------------------------
/src/main/java/collabstream/App.java:
--------------------------------------------------------------------------------
 1 | package collabstream;
 2 | 
 3 | /**
 4 |  * Hello world!
 5 |  *
 6 |  */
 7 | public class App
 8 | {
 9 |     public static void main( String[] args )
10 |     {
11 |         System.out.println( "Hello World!" );
12 |     }
13 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/lc/CounterBolt.java:
--------------------------------------------------------------------------------
 1 | package collabstream.lc;
 2 | 
 3 | import java.util.HashSet;
 4 | import java.util.Map;
 5 | 
 6 | import backtype.storm.task.OutputCollector;
 7 | import backtype.storm.task.TopologyContext;
 8 | import backtype.storm.topology.IRichBolt;
 9 | import backtype.storm.topology.OutputFieldsDeclarer;
10 | import backtype.storm.tuple.Fields;
11 | import backtype.storm.tuple.Tuple;
12 | import backtype.storm.tuple.Values;
13 | 
14 | public class CounterBolt implements IRichBolt {
15 |     private OutputCollector collector;
16 |     private HashSet<Integer> lineNumSet = new HashSet<Integer>();
17 | 
18 | 
19 |     public void prepare(Map config, TopologyContext context, OutputCollector collector) {
20 |         this.collector = collector;
21 |     }
22 | 
23 |     public void cleanup() {
24 |         System.out.println("######## CounterBolt.cleanup: counted " + lineNumSet.size() + " lines");
25 |     }
26 | 
27 |     public void execute(Tuple tuple) {
28 |         String line = tuple.getStringByField("line");
29 |         Integer lineNum = tuple.getIntegerByField("lineNum");
30 |         System.out.println("######## CounterBolt.execute: " + lineNum + ". " + line);
31 |         lineNumSet.add(lineNum); // a set, so a re-emitted line is not double-counted
32 |         collector.emit(tuple, new Values(lineNum)); // anchored to the input tuple
33 |         collector.ack(tuple);
34 |     }
35 | 
36 |     public void declareOutputFields(OutputFieldsDeclarer declarer) {
37 |         declarer.declare(new Fields("lineNum"));
38 |     }
39 | }
" + line); 31 | lineNumSet.add(lineNum); 32 | collector.emit(tuple, new Values(lineNum)); 33 | collector.ack(tuple); 34 | } 35 | 36 | public void declareOutputFields(OutputFieldsDeclarer declarer) { 37 | declarer.declare(new Fields("lineNum")); 38 | } 39 | } -------------------------------------------------------------------------------- /src/main/java/collabstream/lc/FileReaderSpout.java: -------------------------------------------------------------------------------- 1 | package collabstream.lc; 2 | 3 | import java.io.FileNotFoundException; 4 | import java.io.FileReader; 5 | import java.io.IOException; 6 | import java.io.LineNumberReader; 7 | import java.util.Map; 8 | 9 | import backtype.storm.spout.SpoutOutputCollector; 10 | import backtype.storm.task.TopologyContext; 11 | import backtype.storm.topology.IRichSpout; 12 | import backtype.storm.topology.OutputFieldsDeclarer; 13 | import backtype.storm.tuple.Fields; 14 | import backtype.storm.tuple.Values; 15 | 16 | public class FileReaderSpout implements IRichSpout { 17 | private SpoutOutputCollector collector; 18 | private String fileName; 19 | private LineNumberReader in; 20 | 21 | public FileReaderSpout(String fileName) { 22 | this.fileName = fileName; 23 | } 24 | 25 | public boolean isDistributed() { 26 | return false; 27 | } 28 | 29 | public void open(Map config, TopologyContext context, SpoutOutputCollector collector) { 30 | this.collector = collector; 31 | try { 32 | in = new LineNumberReader(new FileReader(fileName)); 33 | } catch (FileNotFoundException e){ 34 | System.err.print("######## FileReaderSpout.nextTuple: " + e); 35 | } 36 | } 37 | 38 | public void close() { 39 | if (in == null) return; 40 | try { 41 | in.close(); 42 | } catch (IOException e) { 43 | System.err.print("######## FileReaderSpout.nextTuple: " + e); 44 | } 45 | } 46 | 47 | public void ack(Object msgId) { 48 | System.out.println("######## FileReaderSpout.ack: msgId=" + msgId); 49 | } 50 | 51 | public void fail(Object msgId) { 52 | System.out.println("######## FileReaderSpout.fail: msgId=" + msgId); 53 | } 54 | 55 | public void nextTuple() { 56 | if (in == null) return; 57 | String line = null; 58 | 59 | try { 60 | line = in.readLine(); 61 | } catch (IOException e) { 62 | System.err.print("######## FileReaderSpout.nextTuple: " + e); 63 | } 64 | 65 | if (line != null) { 66 | int lineNum = in.getLineNumber(); 67 | System.out.println("######## FileReaderSpout.nextTuple: " + lineNum + ". " + line); 68 | collector.emit(new Values(lineNum, line), lineNum); 69 | } 70 | } 71 | 72 | public void declareOutputFields(OutputFieldsDeclarer declarer) { 73 | declarer.declare(new Fields("lineNum", "line")); 74 | } 75 | } -------------------------------------------------------------------------------- /src/main/java/collabstream/lc/LineCountTopology.java: -------------------------------------------------------------------------------- 1 | package collabstream.lc; 2 | 3 | import backtype.storm.Config; 4 | import backtype.storm.LocalCluster; 5 | import backtype.storm.StormSubmitter; 6 | import backtype.storm.topology.TopologyBuilder; 7 | import backtype.storm.utils.Utils; 8 | 9 | /** 10 | * Simple line-counting program to test the functionality of Storm. Instantiates four classes: FileReaderSpout, 11 | * CounterBolt, MsgForwarder, and MsgTracker. FileReaderSpout reads lines from a file and emits them. CounterBolt 12 | * counts the lines received from FileReaderSpout and emits their corresponding line numbers. 
MsgForwarder just forwards 13 | * the lines received from FileReaderSpout and the line numbers received from CounterBolt. MsgTracker receives the 14 | * messages forwarded by MsgForwarder and keeps track of which lines were read and which lines were counted. 15 | */ 16 | public class LineCountTopology { 17 | public static void main(String[] args) throws Exception { 18 | if (args.length < 2) { 19 | System.err.println("######## Wrong number of arguments"); 20 | System.err.println("######## required args: local|production fileName"); 21 | return; 22 | } 23 | 24 | Config config = new Config(); 25 | TopologyBuilder builder = new TopologyBuilder(); 26 | builder.setSpout(1, new FileReaderSpout(args[1])); 27 | builder.setBolt(2, new CounterBolt()).shuffleGrouping(1); 28 | builder.setBolt(3, new MsgForwarder(1,2)).shuffleGrouping(1).shuffleGrouping(2); 29 | builder.setBolt(4, new MsgTracker(1,2)).shuffleGrouping(3,1).shuffleGrouping(3,2); 30 | 31 | System.out.println("######## LineCountTopology.main: submitting topology"); 32 | 33 | if ("local".equals(args[0])) { 34 | LocalCluster cluster = new LocalCluster(); 35 | cluster.submitTopology("line-count", config, builder.createTopology()); 36 | System.out.println("######## LineCountTopology.main: sleeping for 10 secs"); 37 | Utils.sleep(10000); 38 | cluster.shutdown(); 39 | } else { 40 | StormSubmitter.submitTopology("line-count", config, builder.createTopology()); 41 | } 42 | } 43 | } -------------------------------------------------------------------------------- /src/main/java/collabstream/lc/MsgForwarder.java: -------------------------------------------------------------------------------- 1 | package collabstream.lc; 2 | 3 | import java.util.Map; 4 | 5 | import backtype.storm.task.OutputCollector; 6 | import backtype.storm.task.TopologyContext; 7 | import backtype.storm.topology.IRichBolt; 8 | import backtype.storm.topology.OutputFieldsDeclarer; 9 | import backtype.storm.tuple.Fields; 10 | import backtype.storm.tuple.Tuple; 11 | 12 | public class MsgForwarder implements IRichBolt { 13 | private OutputCollector collector; 14 | public int fileReaderId, counterId; 15 | 16 | public MsgForwarder(int fileReaderId, int counterId) { 17 | this.fileReaderId = fileReaderId; 18 | this.counterId = counterId; 19 | } 20 | 21 | public void prepare(Map config, TopologyContext context, OutputCollector collector) { 22 | this.collector = collector; 23 | } 24 | 25 | public void cleanup() { 26 | } 27 | 28 | public void execute(Tuple tuple) { 29 | // The ID of the component that emitted the tuple is mapped directly to the ID of the forwarding stream. 
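        // (For example, with the wiring in LineCountTopology, a tuple from FileReaderSpout (component 1) is
        // re-emitted on stream 1, and a tuple from CounterBolt (component 2) on stream 2, matching the
        // streams declared in declareOutputFields below.)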
30 |         collector.emit(tuple.getSourceComponent(), tuple.getValues());
31 |         collector.ack(tuple);
32 |     }
33 | 
34 |     public void declareOutputFields(OutputFieldsDeclarer declarer) {
35 |         declarer.declareStream(fileReaderId, new Fields("lineNum", "line"));
36 |         declarer.declareStream(counterId, new Fields("lineNum"));
37 |     }
38 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/lc/MsgTracker.java:
--------------------------------------------------------------------------------
 1 | package collabstream.lc;
 2 | 
 3 | import java.util.ArrayList;
 4 | import java.util.Collections;
 5 | import java.util.HashMap;
 6 | import java.util.Map;
 7 | 
 8 | import backtype.storm.task.OutputCollector;
 9 | import backtype.storm.task.TopologyContext;
10 | import backtype.storm.topology.IRichBolt;
11 | import backtype.storm.topology.OutputFieldsDeclarer;
12 | import backtype.storm.tuple.Tuple;
13 | 
14 | public class MsgTracker implements IRichBolt {
15 |     private OutputCollector collector;
16 |     private int fileReaderId, counterId;
17 |     private HashMap<Integer, Record> lnToRec = new HashMap<Integer, Record>();
18 | 
19 |     private static class Record {
20 |         boolean read = false;
21 |         boolean counted = false;
22 | 
23 |         Record() {}
24 | 
25 |         public String toString() {
26 |             return "read=" + (read ? 1 : 0) + ", counted=" + (counted ? 1 : 0);
27 |         }
28 |     }
29 | 
30 |     public MsgTracker(int fileReaderId, int counterId) {
31 |         this.fileReaderId = fileReaderId;
32 |         this.counterId = counterId;
33 |     }
34 | 
35 |     public void prepare(Map config, TopologyContext context, OutputCollector collector) {
36 |         this.collector = collector;
37 |     }
38 | 
39 |     public void cleanup() {
40 |         ArrayList<Integer> keys = new ArrayList<Integer>(lnToRec.keySet());
41 |         Collections.sort(keys);
42 |         for (Integer lineNum : keys) {
43 |             System.out.println("######## MsgTracker.cleanup: " + lineNum + ". " + lnToRec.get(lineNum));
44 |         }
45 |     }
" + lnToRec.get(lineNum)); 44 | } 45 | } 46 | 47 | public void execute(Tuple tuple) { 48 | Integer lineNum = tuple.getIntegerByField("lineNum"); 49 | Record rec = lnToRec.get(lineNum); 50 | if (rec == null) { 51 | rec = new Record(); 52 | lnToRec.put(lineNum, rec); 53 | } 54 | 55 | int streamId = tuple.getSourceStreamId(); 56 | if (streamId == fileReaderId) { 57 | rec.read = true; 58 | } else if (streamId == counterId) { 59 | rec.counted = true; 60 | } 61 | collector.ack(tuple); 62 | } 63 | 64 | public void declareOutputFields(OutputFieldsDeclarer declarer) { 65 | } 66 | } -------------------------------------------------------------------------------- /src/main/java/collabstream/streaming/BlockPair.java: -------------------------------------------------------------------------------- 1 | package collabstream.streaming; 2 | 3 | import java.io.DataInputStream; 4 | import java.io.DataOutputStream; 5 | import java.io.IOException; 6 | import java.io.Serializable; 7 | 8 | import backtype.storm.serialization.ISerialization; 9 | 10 | public class BlockPair implements Serializable { 11 | public final int userBlockIdx, itemBlockIdx; 12 | 13 | public BlockPair(int userBlockIdx, int itemBlockIdx) { 14 | this.userBlockIdx = userBlockIdx; 15 | this.itemBlockIdx = itemBlockIdx; 16 | } 17 | 18 | public String toString() { 19 | return "(" + userBlockIdx + "," + itemBlockIdx + ")"; 20 | } 21 | 22 | public boolean equals(Object obj) { 23 | if (obj instanceof BlockPair) { 24 | BlockPair bp = (BlockPair)obj; 25 | return this.userBlockIdx == bp.userBlockIdx && this.itemBlockIdx == bp.itemBlockIdx; 26 | } else { 27 | return false; 28 | } 29 | } 30 | 31 | public int hashCode() { 32 | int userHi = userBlockIdx >>> 16; 33 | int userLo = userBlockIdx & 0xFFFF; 34 | int itemHi = itemBlockIdx >>> 16; 35 | int itemLo = itemBlockIdx & 0xFFFF; 36 | 37 | return ((userLo ^ itemHi) << 16) | (userHi ^ itemLo); 38 | } 39 | 40 | public static class Serialization implements ISerialization { 41 | public boolean accept(Class c) { 42 | return BlockPair.class.equals(c); 43 | } 44 | 45 | public void serialize(BlockPair bp, DataOutputStream out) throws IOException { 46 | out.writeInt(bp.userBlockIdx); 47 | out.writeInt(bp.itemBlockIdx); 48 | } 49 | 50 | public BlockPair deserialize(DataInputStream in) throws IOException { 51 | return new BlockPair(in.readInt(), in.readInt()); 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /src/main/java/collabstream/streaming/Configuration.java: -------------------------------------------------------------------------------- 1 | package collabstream.streaming; 2 | 3 | import java.io.Serializable; 4 | 5 | public class Configuration implements Serializable { 6 | public final int numUsers, numItems, numLatent; 7 | public final int numUserBlocks, numItemBlocks; 8 | public final float userPenalty, itemPenalty, initialStepSize; 9 | public final int maxTrainingIters; 10 | public final String inputFilename, userOutputFilename, itemOutputFilename; 11 | public final long inputDelay; // delay between sending two training examples in milliseconds 12 | public final boolean debug; 13 | 14 | private final int smallUserBlockSize, smallItemBlockSize; 15 | private final int bigUserBlockSize, bigItemBlockSize; 16 | private final int numBigUserBlocks, numBigItemBlocks; 17 | private final int userBlockThreshold, itemBlockThreshold; 18 | 19 | public Configuration(int numUsers, int numItems, int numLatent, int numUserBlocks, int numItemBlocks, 20 | float userPenalty, 
float itemPenalty, float initialStepSize, int maxTrainingIters, 21 | String inputFilename, String userOutputFilename, String itemOutputFilename, 22 | long inputDelay, boolean debug) { 23 | this.numUsers = numUsers; 24 | this.numItems = numItems; 25 | this.numLatent = numLatent; 26 | this.numUserBlocks = numUserBlocks; 27 | this.numItemBlocks = numItemBlocks; 28 | this.userPenalty = userPenalty; 29 | this.itemPenalty = itemPenalty; 30 | this.initialStepSize = initialStepSize; 31 | this.maxTrainingIters = maxTrainingIters; 32 | this.inputFilename = inputFilename; 33 | this.userOutputFilename = userOutputFilename; 34 | this.itemOutputFilename = itemOutputFilename; 35 | this.inputDelay = inputDelay; 36 | this.debug = debug; 37 | 38 | smallUserBlockSize = numUsers / numUserBlocks; 39 | bigUserBlockSize = smallUserBlockSize + 1; 40 | numBigUserBlocks = numUsers % numUserBlocks; 41 | userBlockThreshold = bigUserBlockSize * numBigUserBlocks; 42 | 43 | smallItemBlockSize = numItems / numItemBlocks; 44 | bigItemBlockSize = smallItemBlockSize + 1; 45 | numBigItemBlocks = numItems % numItemBlocks; 46 | itemBlockThreshold = bigItemBlockSize * numBigItemBlocks; 47 | } 48 | 49 | public int getNumWorkers() { 50 | return numUserBlocks; 51 | } 52 | 53 | public int getNumProcesses() { 54 | // numUserBlocks (Worker) + numUserBlocks (MatrixStore) + numItemBlocks (MatrixStore) + 1 (Master) + 1 (RatingsSource) 55 | return 2*numUserBlocks + numItemBlocks + 2; 56 | } 57 | 58 | public int getUserBlockIdx(int userId) { 59 | if (userId < userBlockThreshold) { 60 | return userId / bigUserBlockSize; 61 | } else { 62 | return numBigUserBlocks + (userId - userBlockThreshold) / smallUserBlockSize; 63 | } 64 | } 65 | 66 | public int getItemBlockIdx(int itemId) { 67 | if (itemId < itemBlockThreshold) { 68 | return itemId / bigItemBlockSize; 69 | } else { 70 | return numBigItemBlocks + (itemId - itemBlockThreshold) / smallItemBlockSize; 71 | } 72 | } 73 | 74 | public int getUserBlockLength(int userBlockIdx) { 75 | return (userBlockIdx < numBigUserBlocks) ? bigUserBlockSize : smallUserBlockSize; 76 | } 77 | 78 | public int getItemBlockLength(int itemBlockIdx) { 79 | return (itemBlockIdx < numBigItemBlocks) ? 
bigItemBlockSize : smallItemBlockSize;
80 |     }
81 | 
82 |     public int getUserBlockStart(int userBlockIdx) {
83 |         if (userBlockIdx < numBigUserBlocks) {
84 |             return bigUserBlockSize * userBlockIdx;
85 |         } else {
86 |             return userBlockThreshold + (userBlockIdx - numBigUserBlocks) * smallUserBlockSize;
87 |         }
88 |     }
89 | 
90 |     public int getItemBlockStart(int itemBlockIdx) {
91 |         if (itemBlockIdx < numBigItemBlocks) {
92 |             return bigItemBlockSize * itemBlockIdx;
93 |         } else {
94 |             return itemBlockThreshold + (itemBlockIdx - numBigItemBlocks) * smallItemBlockSize;
95 |         }
96 |     }
97 | }
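The getters above implement the block partitioning: the first numUsers % numUserBlocks blocks are "big" (one row larger than the rest), so all rows are covered even when the division is not exact. A standalone sketch of the arithmetic (a hypothetical check, not a class in this repo; the 17x28 dimensions echo the test.dat commands in commands.txt):

package collabstream.streaming;

public class BlockPartitionCheck {
    public static void main(String[] args) {
        // 17 users in 5 blocks: 17/5 = 3 rows per small block, 17%5 = 2 big blocks of 4 rows,
        // so the threshold between big and small territory is 4*2 = 8.
        Configuration config = new Configuration(17, 28, 10, 5, 7, 0.1f, 0.1f, 0.1f, 30,
                "in.dat", "out.user", "out.item", 0L, false);
        System.out.println(config.getUserBlockIdx(7));    // 1: users 4-7 form the second big block
        System.out.println(config.getUserBlockIdx(8));    // 2: users 8-10 form the first small block
        System.out.println(config.getUserBlockLength(0)); // 4
        System.out.println(config.getUserBlockStart(2));  // 8
    }
}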
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/Master.java:
--------------------------------------------------------------------------------
  1 | package collabstream.streaming;
  2 | 
  3 | import java.io.BufferedWriter;
  4 | import java.io.FileWriter;
  5 | import java.io.IOException;
  6 | import java.io.PrintWriter;
  7 | import java.util.ArrayList;
  8 | import java.util.HashSet;
  9 | import java.util.LinkedList;
 10 | import java.util.Map;
 11 | import java.util.Queue;
 12 | import java.util.Random;
 13 | import java.util.Set;
 14 | 
 15 | import org.apache.commons.lang.time.DurationFormatUtils;
 16 | 
 17 | import backtype.storm.task.OutputCollector;
 18 | import backtype.storm.task.TopologyContext;
 19 | import backtype.storm.topology.IRichBolt;
 20 | import backtype.storm.topology.OutputFieldsDeclarer;
 21 | import backtype.storm.tuple.Fields;
 22 | import backtype.storm.tuple.Tuple;
 23 | import backtype.storm.tuple.Values;
 24 | 
 25 | import static collabstream.streaming.MsgType.*;
 26 | 
 27 | public class Master implements IRichBolt {
 28 |     private OutputCollector collector;
 29 |     private final Configuration config;
 30 |     private PrintWriter userOutput, itemOutput;
 31 |     private BlockPair[][] blockPair;
 32 |     private TrainingExample[][] latestExample;
 33 |     private Set<BlockPair> unfinished = new HashSet<BlockPair>();
 34 |     private Set<BlockPair> freeSet = new HashSet<BlockPair>();
 35 |     private Queue<BlockPair> userBlockQueue = new LinkedList<BlockPair>();
 36 |     private Queue<BlockPair> itemBlockQueue = new LinkedList<BlockPair>();
 37 |     private boolean endOfData = false;
 38 |     private long startTime, outputStartTime = 0;
 39 |     private final Random random = new Random();
 40 | 
 41 |     public Master(Configuration config) {
 42 |         this.config = config;
 43 |         blockPair = new BlockPair[config.numUserBlocks][config.numItemBlocks];
 44 |         latestExample = new TrainingExample[config.numUserBlocks][config.numItemBlocks];
 45 |     }
 46 | 
 47 |     public void prepare(Map stormConfig, TopologyContext context, OutputCollector collector) {
 48 |         this.collector = collector;
 49 |         startTime = System.currentTimeMillis();
 50 |         System.out.printf("######## Training started: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime);
 51 |     }
 52 | 
 53 |     public void cleanup() {
 54 |     }
 55 | 
 56 |     public void execute(Tuple tuple) {
 57 |         MsgType msgType = (MsgType)tuple.getValue(0);
 58 |         if (config.debug && msgType != END_OF_DATA && msgType != PROCESS_BLOCK_FIN) {
 59 |             System.out.println("######## Master.execute: " + msgType + " " + tuple.getValue(1));
 60 |         }
 61 |         TrainingExample ex, latest;
 62 |         BlockPair bp, head;
 63 | 
 64 |         switch (msgType) {
 65 |         case END_OF_DATA:
 66 |             if (config.debug) {
 67 |                 System.out.println("######## Master.execute: " + msgType);
 68 |             }
 69 |             endOfData = true;
 70 |             collector.ack(tuple);
 71 |             distributeWork();
 72 |             break;
 73 |         case TRAINING_EXAMPLE:
 74 |             ex = (TrainingExample)tuple.getValue(1);
 75 |             int userBlockIdx = config.getUserBlockIdx(ex.userId);
 76 |             int itemBlockIdx = config.getItemBlockIdx(ex.itemId);
 77 | 
 78 |             latest = latestExample[userBlockIdx][itemBlockIdx];
 79 |             if (latest == null || latest.timestamp < ex.timestamp) {
 80 |                 latestExample[userBlockIdx][itemBlockIdx] = ex;
 81 |             }
 82 | 
 83 |             bp = blockPair[userBlockIdx][itemBlockIdx];
 84 |             if (bp == null) {
 85 |                 bp = blockPair[userBlockIdx][itemBlockIdx] = new BlockPair(userBlockIdx, itemBlockIdx);
 86 |                 unfinished.add(bp);
 87 |                 freeSet.add(bp);
 88 |             }
 89 | 
 90 |             collector.emit(tuple, new Values(TRAINING_EXAMPLE, null, ex, userBlockIdx));
 91 |             collector.ack(tuple);
 92 |             distributeWork();
 93 |             break;
 94 |         case PROCESS_BLOCK_FIN:
 95 |             bp = (BlockPair)tuple.getValue(1);
 96 |             ex = (TrainingExample)tuple.getValue(2);
 97 |             if (config.debug) {
 98 |                 System.out.println("######## Master.execute: " + msgType + " " + bp + " " + ex);
 99 |             }
100 | 
101 |             latest = latestExample[bp.userBlockIdx][bp.itemBlockIdx];
102 |             if (latest.timestamp == ex.timestamp) {
103 |                 latest.numTrainingIters = ex.numTrainingIters;
104 |                 if (endOfData && latest.numTrainingIters >= config.maxTrainingIters) {
105 |                     unfinished.remove(bp);
106 |                 }
107 |             }
108 | 
109 |             // numTrainingIters must be updated before free() is called to prevent finished blocks
110 |             // from being added to the freeSet
111 |             free(bp.userBlockIdx, bp.itemBlockIdx);
112 | 
113 |             if (unfinished.isEmpty()) {
114 |                 startOutput();
115 |             } else {
116 |                 distributeWork();
117 |             }
118 |             break;
119 |         case USER_BLOCK:
120 |             bp = (BlockPair)tuple.getValue(1);
121 |             float[][] userBlock = (float[][])tuple.getValue(2);
122 |             head = userBlockQueue.remove();
123 |             if (!head.equals(bp)) {
124 |                 throw new RuntimeException("Expected " + head + ", but received " + bp + " for " + USER_BLOCK);
125 |             }
126 |             writeUserBlock(bp.userBlockIdx, userBlock);
127 |             requestNextUserBlock();
128 |             break;
129 |         case ITEM_BLOCK:
130 |             bp = (BlockPair)tuple.getValue(1);
131 |             float[][] itemBlock = (float[][])tuple.getValue(2);
132 |             head = itemBlockQueue.remove();
133 |             if (!head.equals(bp)) {
134 |                 throw new RuntimeException("Expected " + head + ", but received " + bp + " for " + ITEM_BLOCK);
135 |             }
136 |             writeItemBlock(bp.itemBlockIdx, itemBlock);
137 |             requestNextItemBlock();
138 |             break;
139 |         }
140 |     }
141 | 
142 |     public void declareOutputFields(OutputFieldsDeclarer declarer) {
143 |         // Field "userBlockIdx" used solely for grouping
144 |         declarer.declare(new Fields("msgType", "blockPair", "example", "userBlockIdx"));
145 |     }
146 | 
147 |     private void lock(int userBlockIdx, int itemBlockIdx) {
148 |         for (int j = 0; j < config.numItemBlocks; ++j) {
149 |             BlockPair bp = blockPair[userBlockIdx][j];
150 |             if (bp != null) {
151 |                 freeSet.remove(bp);
152 |             }
153 |         }
154 |         for (int i = 0; i < config.numUserBlocks; ++i) {
155 |             BlockPair bp = blockPair[i][itemBlockIdx];
156 |             if (bp != null) {
157 |                 freeSet.remove(bp);
158 |             }
159 |         }
160 |     }
161 | 
162 |     private void free(int userBlockIdx, int itemBlockIdx) {
163 |         for (int j = 0; j < config.numItemBlocks; ++j) {
164 |             BlockPair bp = blockPair[userBlockIdx][j];
165 |             if (bp != null) {
166 |                 if (latestExample[userBlockIdx][j].numTrainingIters < config.maxTrainingIters) {
167 |                     freeSet.add(bp);
168 |                 }
169 |             }
170 |         }
171 |         for (int i = 0; i < config.numUserBlocks; ++i) {
172 |             BlockPair bp = blockPair[i][itemBlockIdx];
173 |             if (bp != null) {
174 |                 if (latestExample[i][itemBlockIdx].numTrainingIters < config.maxTrainingIters) {
175 |                     freeSet.add(bp);
176 |                 }
177 |             }
178 |         }
179 |     }
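    // Note: lock() and free() together enforce DSGD's interchangeability constraint: once a block
    // (i,j) is handed to a worker, every other block sharing user-block row i or item-block column j
    // is removed from freeSet, so no two concurrently trained blocks touch the same user or item
    // factors. A hypothetical illustration with 2x2 blocks: while (0,0) is being trained, only (1,1)
    // remains eligible; (0,1) and (1,0) become free again when (0,0) finishes.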
180 | 
181 |     private void distributeWork() {
182 |         ArrayList<BlockPair> freeList = new ArrayList<BlockPair>(freeSet.size());
183 |         while (!freeSet.isEmpty()) {
184 |             freeList.addAll(freeSet);
185 |             int i = random.nextInt(freeList.size());
186 |             BlockPair bp = freeList.get(i);
187 |             lock(bp.userBlockIdx, bp.itemBlockIdx); // also removes bp itself from freeSet, so the loop terminates
188 |             collector.emit(new Values(PROCESS_BLOCK_REQ, bp, null, bp.userBlockIdx));
189 |             freeList.clear();
190 |         }
191 |     }
192 | 
193 |     private void startOutput() {
194 |         try {
195 |             if (outputStartTime > 0) return;
196 |             outputStartTime = System.currentTimeMillis();
197 |             System.out.printf("######## Training finished: %1$tY-%1$tb-%1$td %1$tT %tZ\n", outputStartTime);
198 |             System.out.println("######## Elapsed training time: "
199 |                 + DurationFormatUtils.formatPeriod(startTime, outputStartTime, "H:m:s") + " (h:m:s)");
200 |             System.out.printf("######## Output started: %1$tY-%1$tb-%1$td %1$tT %tZ\n", outputStartTime);
201 | 
202 |             userOutput = new PrintWriter(new BufferedWriter(new FileWriter(config.userOutputFilename)));
203 |             itemOutput = new PrintWriter(new BufferedWriter(new FileWriter(config.itemOutputFilename)));
204 | 
205 |             for (int i = 0; i < config.numUserBlocks; ++i) {
206 |                 // Add the first block in row i to userBlockQueue
207 |                 for (int j = 0; j < config.numItemBlocks; ++j) {
208 |                     BlockPair bp = blockPair[i][j];
209 |                     if (bp != null) {
210 |                         userBlockQueue.add(bp);
211 |                         break;
212 |                     }
213 |                 }
214 |             }
215 |             for (int j = 0; j < config.numItemBlocks; ++j) {
216 |                 // Add the first block in column j to itemBlockQueue
217 |                 for (int i = 0; i < config.numUserBlocks; ++i) {
218 |                     BlockPair bp = blockPair[i][j];
219 |                     if (bp != null) {
220 |                         itemBlockQueue.add(bp);
221 |                         break;
222 |                     }
223 |                 }
224 |             }
225 | 
226 |             requestNextUserBlock();
227 |             requestNextItemBlock();
228 |         } catch (IOException e) {
229 |             System.err.println("######## Master.startOutput: " + e);
230 |         }
231 |     }
232 | 
233 |     private void writeUserBlock(int userBlockIdx, float[][] userBlock) {
234 |         int userBlockStart = config.getUserBlockStart(userBlockIdx);
235 |         for (int i = 0; i < userBlock.length; ++i) {
236 |             userOutput.print(userBlockStart + i);
237 |             for (int k = 0; k < config.numLatent; ++k) {
238 |                 userOutput.print(' ');
239 |                 userOutput.print(userBlock[i][k]);
240 |             }
241 |             userOutput.println();
242 |         }
243 |     }
244 | 
245 |     private void writeItemBlock(int itemBlockIdx, float[][] itemBlock) {
246 |         int itemBlockStart = config.getItemBlockStart(itemBlockIdx);
247 |         for (int j = 0; j < itemBlock.length; ++j) {
248 |             itemOutput.print(itemBlockStart + j);
249 |             for (int k = 0; k < config.numLatent; ++k) {
250 |                 itemOutput.print(' ');
251 |                 itemOutput.print(itemBlock[j][k]);
252 |             }
253 |             itemOutput.println();
254 |         }
255 |     }
256 | 
257 |     private void endOutput() {
258 |         long endTime = System.currentTimeMillis();
259 |         System.out.printf("######## Output finished: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime);
260 |         System.out.println("######## Elapsed output time: "
261 |             + DurationFormatUtils.formatPeriod(outputStartTime, endTime, "H:m:s") + " (h:m:s)");
262 |         System.out.println("######## Total elapsed time: "
263 |             + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)");
264 |     }
265 | 
266 |     private void requestNextUserBlock() {
267 |         if (userBlockQueue.isEmpty()) {
268 |             if (userOutput != null) {
269 |                 userOutput.close();
270 |             }
271 |             if (itemBlockQueue.isEmpty()) {
272 |                 endOutput();
273 |             }
274 |         } else {
275 |             BlockPair bp = userBlockQueue.peek();
276 |             collector.emit(new Values(USER_BLOCK_REQ, bp, null, bp.userBlockIdx));
277 |         }
278 |     }
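    // Note: output is intentionally sequential. Each *_BLOCK_REQ is issued only after the previous
    // block has arrived and been written, so at most one user block and one item block are in flight
    // at a time while the factor matrices are dumped.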
279 | 
280 |     private void requestNextItemBlock() {
281 |         if (itemBlockQueue.isEmpty()) {
282 |             if (itemOutput != null) {
283 |                 itemOutput.close();
284 |             }
285 |             if (userBlockQueue.isEmpty()) {
286 |                 endOutput();
287 |             }
288 |         } else {
289 |             BlockPair bp = itemBlockQueue.peek();
290 |             collector.emit(new Values(ITEM_BLOCK_REQ, bp, null, bp.userBlockIdx));
291 |         }
292 |     }
293 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/MatrixSerialization.java:
--------------------------------------------------------------------------------
 1 | package collabstream.streaming;
 2 | 
 3 | import java.io.DataInputStream;
 4 | import java.io.DataOutputStream;
 5 | import java.io.IOException;
 6 | import java.nio.ByteBuffer;
 7 | import java.nio.FloatBuffer;
 8 | 
 9 | import backtype.storm.serialization.ISerialization;
10 | 
11 | public class MatrixSerialization implements ISerialization<float[][]> {
12 |     public boolean accept(Class c) {
13 |         return float[][].class.equals(c);
14 |     }
15 | 
16 |     public void serialize(float[][] matrix, DataOutputStream out) throws IOException {
17 |         if (matrix == null) return;
18 |         int numRows = matrix.length;
19 |         int numCols = (numRows == 0) ? 0 : matrix[0].length;
20 | 
21 |         out.writeInt(numRows);
22 |         out.writeInt(numCols);
23 | 
24 |         byte[] arr = new byte[4*numCols];
25 |         ByteBuffer bbuf = ByteBuffer.wrap(arr);
26 |         FloatBuffer fbuf = bbuf.asFloatBuffer();
27 | 
28 |         for (int i = 0; i < numRows; ++i) {
29 |             if (matrix[i].length != numCols) {
30 |                 throw new Error("Rows of matrix have different sizes");
31 |             }
32 |             fbuf.put(matrix[i]);
33 |             out.write(arr);
34 |             fbuf.clear();
35 |         }
36 |     }
37 | 
38 |     public float[][] deserialize(DataInputStream in) throws IOException {
39 |         int numRows = in.readInt();
40 |         int numCols = in.readInt();
41 |         float[][] matrix = new float[numRows][numCols];
42 | 
43 |         byte[] arr = new byte[4*numCols];
44 |         ByteBuffer bbuf = ByteBuffer.wrap(arr);
45 |         FloatBuffer fbuf = bbuf.asFloatBuffer();
46 | 
47 |         for (int i = 0; i < numRows; ++i) {
48 |             in.readFully(arr);
49 |             fbuf.get(matrix[i]);
50 |             fbuf.clear();
51 |         }
52 | 
53 |         return matrix;
54 |     }
55 | }
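A minimal round-trip check for the serializer above (a hypothetical test harness, not part of the repo):

package collabstream.streaming;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

public class MatrixSerializationCheck {
    public static void main(String[] args) throws IOException {
        MatrixSerialization ser = new MatrixSerialization();
        float[][] original = MatrixUtils.generateRandomMatrix(3, 4);

        // Serialize into an in-memory buffer: a two-int header followed by 3 rows of 4 floats each.
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        ser.serialize(original, new DataOutputStream(bytes));

        // Read it back and compare element-wise.
        float[][] copy = ser.deserialize(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(Arrays.deepEquals(original, copy)); // expect: true
    }
}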
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/MatrixStore.java:
--------------------------------------------------------------------------------
 1 | package collabstream.streaming;
 2 | 
 3 | import java.util.Map;
 4 | import java.util.concurrent.ConcurrentHashMap;
 5 | 
 6 | import backtype.storm.task.OutputCollector;
 7 | import backtype.storm.task.TopologyContext;
 8 | import backtype.storm.topology.IRichBolt;
 9 | import backtype.storm.topology.OutputFieldsDeclarer;
10 | import backtype.storm.tuple.Fields;
11 | import backtype.storm.tuple.Tuple;
12 | import backtype.storm.tuple.Values;
13 | 
14 | import static collabstream.streaming.MsgType.*;
15 | 
16 | public class MatrixStore implements IRichBolt {
17 |     private OutputCollector collector;
18 |     private final Configuration config;
19 |     private final Map<Integer, float[][]> userBlockMap = new ConcurrentHashMap<Integer, float[][]>();
20 |     private final Map<Integer, float[][]> itemBlockMap = new ConcurrentHashMap<Integer, float[][]>();
21 | 
22 |     public MatrixStore(Configuration config) {
23 |         this.config = config;
24 |     }
25 | 
26 |     public void prepare(Map stormConfig, TopologyContext context, OutputCollector collector) {
27 |         this.collector = collector;
28 |     }
29 | 
30 |     public void cleanup() {
31 |     }
32 | 
33 |     public void execute(Tuple tuple) {
34 |         MsgType msgType = (MsgType)tuple.getValue(0);
35 |         BlockPair bp = (BlockPair)tuple.getValue(1);
36 |         Integer taskIdObj = (Integer)tuple.getValue(3);
37 |         int taskId = (taskIdObj != null) ? taskIdObj.intValue() : tuple.getSourceTask();
38 |         if (config.debug) {
39 |             System.out.println("######## MatrixStore.execute: " + msgType + " " + bp + " [" + taskId + "]");
40 |         }
41 | 
42 |         switch (msgType) {
43 |         case USER_BLOCK_REQ:
44 |             // In general, this pattern of access is not thread-safe. But since requests with the same userBlockIdx
45 |             // are sent to the same thread, it should be safe in our case.
46 |             float[][] userBlock = userBlockMap.get(bp.userBlockIdx);
47 |             if (userBlock == null) {
48 |                 userBlock = MatrixUtils.generateRandomMatrix(config.getUserBlockLength(bp.userBlockIdx), config.numLatent);
49 |                 userBlockMap.put(bp.userBlockIdx, userBlock);
50 |             }
51 |             collector.emitDirect(taskId, new Values(USER_BLOCK, bp, (Object)userBlock));
52 |             break;
53 |         case ITEM_BLOCK_REQ:
54 |             // In general, this pattern of access is not thread-safe. But since requests with the same itemBlockIdx
55 |             // are sent to the same thread, it should be safe in our case.
56 |             float[][] itemBlock = itemBlockMap.get(bp.itemBlockIdx);
57 |             if (itemBlock == null) {
58 |                 itemBlock = MatrixUtils.generateRandomMatrix(config.getItemBlockLength(bp.itemBlockIdx), config.numLatent);
59 |                 itemBlockMap.put(bp.itemBlockIdx, itemBlock);
60 |             }
61 |             collector.emitDirect(taskId, new Values(ITEM_BLOCK, bp, (Object)itemBlock));
62 |             break;
63 |         case USER_BLOCK:
64 |             userBlockMap.put(bp.userBlockIdx, (float[][])tuple.getValue(2));
65 |             collector.emitDirect(tuple.getSourceTask(), new Values(USER_BLOCK_SAVED, bp, null));
66 |             break;
67 |         case ITEM_BLOCK:
68 |             itemBlockMap.put(bp.itemBlockIdx, (float[][])tuple.getValue(2));
69 |             collector.emitDirect(taskId, new Values(ITEM_BLOCK_SAVED, bp, null));
70 |             break;
71 |         }
72 |     }
73 | 
74 |     public void declareOutputFields(OutputFieldsDeclarer declarer) {
75 |         declarer.declare(true, new Fields("msgType", "blockPair", "block"));
76 |     }
77 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/MatrixUtils.java:
--------------------------------------------------------------------------------
 1 | package collabstream.streaming;
 2 | 
 3 | import java.util.Random;
 4 | 
 5 | public class MatrixUtils {
 6 |     private static final ThreadLocal<Random> localRandom = new ThreadLocal<Random>() {
 7 |         protected Random initialValue() { return new Random(); }
 8 |     };
 9 | 
10 |     public static String toString(float[][] matrix) {
11 |         if (matrix == null) return "";
12 |         int numRows = matrix.length;
13 |         if (numRows == 0) return "[]";
14 |         int numCols = matrix[0].length;
15 | 
16 |         StringBuilder b = new StringBuilder(numRows*(9*numCols + 2) + 1);
17 | 
18 |         b.append('[');
19 |         for (int i = 0; i < numRows-1; ++i) {
20 |             numCols = matrix[i].length;
21 |             for (int j = 0; j < numCols; ++j) {
22 |                 b.append(String.format(" %8.3f", matrix[i][j]));
23 |             }
24 |             b.append("\n ");
25 |         }
26 |         numCols = matrix[numRows-1].length;
27 |         for (int j = 0; j < numCols; ++j) {
28 |             b.append(String.format(" %8.3f", matrix[numRows-1][j]));
29 |         }
30 |         b.append(" ]");
31 | 
32 |         return b.toString();
33 |     }
34 | 
35 |     public static float[][] generateRandomMatrix(int numRows, int numCols) {
36 |         Random random = localRandom.get();
37 |         float[][] matrix = new float[numRows][numCols];
38 | 
39 |         for (int i = 0; i < numRows; ++i) {
40 |             for (int j = 0; j < numCols; ++j) {
41 |                 matrix[i][j] = random.nextFloat();
42 |             }
43 |         }
44 |         return matrix;
45 |     }
46 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/MsgType.java:
--------------------------------------------------------------------------------
 1 | package collabstream.streaming;
 2 | 
 3 | public enum MsgType {
 4 |     END_OF_DATA,
 5 |     TRAINING_EXAMPLE,
 6 |     PROCESS_BLOCK_REQ,
 7 |     PROCESS_BLOCK_FIN,
 8 |     USER_BLOCK_REQ,
 9 |     ITEM_BLOCK_REQ,
10 |     USER_BLOCK,
11 |     ITEM_BLOCK,
12 |     USER_BLOCK_SAVED,
13 |     ITEM_BLOCK_SAVED
14 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/PermutationUtils.java:
--------------------------------------------------------------------------------
 1 | package collabstream.streaming;
 2 | 
 3 | import java.util.ArrayList;
 4 | import java.util.Random;
 5 | 
 6 | import org.apache.commons.lang.time.DurationFormatUtils;
 7 | 
 8 | public class PermutationUtils {
 9 |     private static final ThreadLocal<Random> localRandom = new ThreadLocal<Random>() {
10 |         protected Random initialValue() { return new Random(); }
11 |     };
12 | 
13 |     // Fisher-Yates shuffle, in place
14 |     public static <T> T[] permute(T[] arr) {
15 |         if (arr == null) return null;
16 |         Random random = localRandom.get();
17 |         for (int n = arr.length; n > 1; --n) {
18 |             int i = random.nextInt(n);
19 |             T temp = arr[i];
20 |             arr[i] = arr[n-1];
21 |             arr[n-1] = temp;
22 |         }
23 |         return arr;
24 |     }
25 | 
26 |     public static <T> ArrayList<T> permute(ArrayList<T> list) {
27 |         if (list == null) return null;
28 |         Random random = localRandom.get();
29 |         for (int n = list.size(); n > 1; --n) {
30 |             int i = random.nextInt(n);
31 |             T temp = list.get(i);
32 |             list.set(i, list.get(n-1));
33 |             list.set(n-1, temp);
34 |         }
35 |         return list;
36 |     }
37 | }
--------------------------------------------------------------------------------
/src/main/java/collabstream/streaming/RatingsSource.java:
--------------------------------------------------------------------------------
 1 | package collabstream.streaming;
 2 | 
 3 | import java.io.BufferedReader;
 4 | import java.io.FileReader;
 5 | import java.io.IOException;
 6 | import java.util.Map;
 7 | 
 8 | import org.apache.commons.lang.StringUtils;
 9 | import org.apache.commons.lang.time.DurationFormatUtils;
10 | 
11 | import backtype.storm.spout.SpoutOutputCollector;
12 | import backtype.storm.task.TopologyContext;
13 | import backtype.storm.topology.IRichSpout;
14 | import backtype.storm.topology.OutputFieldsDeclarer;
15 | import backtype.storm.tuple.Fields;
16 | import backtype.storm.tuple.Values;
17 | import backtype.storm.utils.Utils;
18 | 
19 | import static collabstream.streaming.MsgType.*;
20 | 
21 | public class RatingsSource implements IRichSpout {
22 |     protected SpoutOutputCollector collector;
23 |     private final Configuration config;
24 |     private BufferedReader input;
25 |     private int sequenceNum = 0;
26 |     private long inputStartTime;
27 | 
28 |     public RatingsSource(Configuration config) {
29 |         this.config = config;
30 |     }
31 | 
32 |     public boolean isDistributed() {
33 |         return false;
34 |     }
35 | 
36 |     public void open(Map stormConfig, TopologyContext context, SpoutOutputCollector collector) {
37 |         this.collector = collector;
38 |         inputStartTime = System.currentTimeMillis();
39 |         System.out.printf("######## Input started: %1$tY-%1$tb-%1$td %1$tT %tZ\n", inputStartTime);
40 |         try {
41 |             input = new BufferedReader(new FileReader(config.inputFilename));
42 |         } catch (IOException e) {
43 |             throw new RuntimeException(e);
44 |         }
45 |     }
46 | 
47 |     public void close() {
48 |     }
49 | 
50 |     public void ack(Object msgId) {
51 |     }
52 | 
53 |     public void 
fail(Object msgId) { 54 | if (config.debug) { 55 | System.err.println("######## RatingsSource.fail: Resending " + msgId); 56 | } 57 | if (msgId == END_OF_DATA) { 58 | collector.emit(new Values(END_OF_DATA, null), END_OF_DATA); 59 | } else { 60 | TrainingExample ex = (TrainingExample)msgId; 61 | collector.emit(new Values(TRAINING_EXAMPLE, ex), ex); 62 | } 63 | } 64 | 65 | public void nextTuple() { 66 | if (input == null) return; 67 | try { 68 | String line = input.readLine(); 69 | if (line == null) { 70 | long inputEndTime = System.currentTimeMillis(); 71 | System.out.printf("######## Input finished: %1$tY-%1$tb-%1$td %1$tT %tZ\n", inputEndTime); 72 | System.out.println("######## Elapsed input time: " 73 | + DurationFormatUtils.formatPeriod(inputStartTime, inputEndTime, "H:m:s") + " (h:m:s)"); 74 | 75 | input.close(); 76 | input = null; 77 | collector.emit(new Values(END_OF_DATA, null), END_OF_DATA); 78 | } else { 79 | try { 80 | String[] token = StringUtils.split(line, ' '); 81 | int userId = Integer.parseInt(token[0]); 82 | int itemId = Integer.parseInt(token[1]); 83 | float rating = Float.parseFloat(token[2]); 84 | 85 | TrainingExample ex = new TrainingExample(sequenceNum++, userId, itemId, rating); 86 | collector.emit(new Values(TRAINING_EXAMPLE, ex), ex); 87 | 88 | if (config.inputDelay > 0) { 89 | Utils.sleep(config.inputDelay); 90 | } 91 | } catch (Exception e) { 92 | System.err.println("######## RatingsSource.nextTuple: Could not parse line: " + line + "\n" + e); 93 | } 94 | } 95 | } catch (IOException e) { 96 | throw new RuntimeException(e); 97 | } 98 | } 99 | 100 | public void declareOutputFields(OutputFieldsDeclarer declarer) { 101 | declarer.declare(new Fields("msgType", "example")); 102 | } 103 | } -------------------------------------------------------------------------------- /src/main/java/collabstream/streaming/StreamingDSGD.java: -------------------------------------------------------------------------------- 1 | package collabstream.streaming; 2 | 3 | import java.io.File; 4 | import java.io.FileReader; 5 | import java.util.Properties; 6 | 7 | import backtype.storm.Config; 8 | import backtype.storm.LocalCluster; 9 | import backtype.storm.StormSubmitter; 10 | import backtype.storm.topology.TopologyBuilder; 11 | import backtype.storm.tuple.Fields; 12 | 13 | public class StreamingDSGD { 14 | public static void main(String[] args) throws Exception { 15 | if (args.length < 6) { 16 | System.err.println("######## Wrong number of arguments"); 17 | System.err.println("######## required args: local|production numUsers numItems" 18 | + " inputFilename userOutputFilename itemOutputFilename"); 19 | return; 20 | } 21 | 22 | Properties props = new Properties(); 23 | File propFile = new File("data/collabstream.properties"); 24 | if (propFile.exists()) { 25 | FileReader in = new FileReader(propFile); 26 | props.load(in); 27 | in.close(); 28 | } 29 | 30 | int numUsers = Integer.parseInt(args[1]); 31 | int numItems = Integer.parseInt(args[2]); 32 | int numLatent = Integer.parseInt(props.getProperty("numLatent", "10")); 33 | int numUserBlocks = Integer.parseInt(props.getProperty("numUserBlocks", "10")); 34 | int numItemBlocks = Integer.parseInt(props.getProperty("numItemBlocks", "10")); 35 | float userPenalty = Float.parseFloat(props.getProperty("userPenalty", "0.1")); 36 | float itemPenalty = Float.parseFloat(props.getProperty("itemPenalty", "0.1")); 37 | float initialStepSize = Float.parseFloat(props.getProperty("initialStepSize", "0.1")); 38 | int maxTrainingIters = 
Integer.parseInt(props.getProperty("maxTrainingIters", "30")); 39 | String inputFilename = args[3]; 40 | String userOutputFilename = args[4]; 41 | String itemOutputFilename = args[5]; 42 | long inputDelay = Long.parseLong(props.getProperty("inputDelay", "0")); 43 | boolean debug = Boolean.parseBoolean(props.getProperty("debug", "false")); 44 | 45 | Configuration config = new Configuration(numUsers, numItems, numLatent, numUserBlocks, numItemBlocks, 46 | userPenalty, itemPenalty, initialStepSize, maxTrainingIters, 47 | inputFilename, userOutputFilename, itemOutputFilename, 48 | inputDelay, debug); 49 | 50 | Config stormConfig = new Config(); 51 | stormConfig.addSerialization(TrainingExample.Serialization.class); 52 | stormConfig.addSerialization(BlockPair.Serialization.class); 53 | stormConfig.addSerialization(MatrixSerialization.class); 54 | stormConfig.setNumWorkers(config.getNumProcesses()); 55 | stormConfig.setNumAckers(config.getNumWorkers()); // our notion of a worker is different from Storm's 56 | 57 | TopologyBuilder builder = new TopologyBuilder(); 58 | builder.setSpout(1, new RatingsSource(config)); 59 | builder.setBolt(2, new Master(config)) 60 | .globalGrouping(1) 61 | .globalGrouping(3, Worker.TO_MASTER_STREAM_ID) 62 | .directGrouping(4) 63 | .directGrouping(5); 64 | builder.setBolt(3, new Worker(config), config.getNumWorkers()) 65 | .fieldsGrouping(2, new Fields("userBlockIdx")) 66 | .directGrouping(4) 67 | .directGrouping(5); 68 | builder.setBolt(4, new MatrixStore(config), config.numUserBlocks) 69 | .fieldsGrouping(3, Worker.USER_BLOCK_STREAM_ID, new Fields("userBlockIdx")); 70 | builder.setBolt(5, new MatrixStore(config), config.numItemBlocks) 71 | .fieldsGrouping(3, Worker.ITEM_BLOCK_STREAM_ID, new Fields("itemBlockIdx")); 72 | 73 | System.out.println("######## StreamingDSGD.main: submitting topology"); 74 | 75 | if ("local".equals(args[0])) { 76 | LocalCluster cluster = new LocalCluster(); 77 | cluster.submitTopology("StreamingDSGD", stormConfig, builder.createTopology()); 78 | } else { 79 | StormSubmitter.submitTopology("StreamingDSGD", stormConfig, builder.createTopology()); 80 | } 81 | } 82 | } -------------------------------------------------------------------------------- /src/main/java/collabstream/streaming/TestPredictions.java: -------------------------------------------------------------------------------- 1 | package collabstream.streaming; 2 | 3 | import java.io.FileReader; 4 | import java.io.LineNumberReader; 5 | import java.util.HashMap; 6 | import java.util.Map; 7 | 8 | import org.apache.commons.lang.StringUtils; 9 | import org.apache.commons.lang.time.DurationFormatUtils; 10 | 11 | public class TestPredictions { 12 | public static void main(String[] args) throws Exception { 13 | if (args.length < 7) { 14 | System.err.println("######## Wrong number of arguments"); 15 | System.err.println("######## required args: numUsers numItems numLatent" 16 | + " trainingFilename testFilename userFilename itemFilename"); 17 | return; 18 | } 19 | 20 | long testStartTime = System.currentTimeMillis(); 21 | System.out.printf("######## Testing started: %1$tY-%1$tb-%1$td %1$tT %tZ\n", testStartTime); 22 | 23 | int numUsers = Integer.parseInt(args[0]); 24 | int numItems = Integer.parseInt(args[1]); 25 | int numLatent = Integer.parseInt(args[2]); 26 | String trainingFilename = args[3]; 27 | String testFilename = args[4]; 28 | String userFilename = args[5]; 29 | String itemFilename = args[6]; 30 | 31 | float trainingTotal = 0.0f; 32 | int trainingCount = 0; 33 | 34 | Map userCount 
= new HashMap(); 35 | Map itemCount = new HashMap(); 36 | Map userTotal = new HashMap(); 37 | Map itemTotal = new HashMap(); 38 | 39 | long startTime = System.currentTimeMillis(); 40 | System.out.printf("######## Started reading training file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime); 41 | 42 | String line; 43 | LineNumberReader in = new LineNumberReader(new FileReader(trainingFilename)); 44 | while ((line = in.readLine()) != null) { 45 | try { 46 | String[] token = StringUtils.split(line, ' '); 47 | int i = Integer.parseInt(token[0]); 48 | int j = Integer.parseInt(token[1]); 49 | float rating = Float.parseFloat(token[2]); 50 | 51 | trainingTotal += rating; 52 | ++trainingCount; 53 | 54 | if (userCount.containsKey(i)) { 55 | userCount.put(i, userCount.get(i) + 1); 56 | userTotal.put(i, userTotal.get(i) + rating); 57 | } else { 58 | userCount.put(i, 1); 59 | userTotal.put(i, rating); 60 | } 61 | 62 | if (itemCount.containsKey(j)) { 63 | itemCount.put(j, itemCount.get(j) + 1); 64 | itemTotal.put(j, itemTotal.get(j) + rating); 65 | } else { 66 | itemCount.put(j, 1); 67 | itemTotal.put(j, rating); 68 | } 69 | } catch (Exception e) { 70 | System.err.printf("######## Could not parse line %d in %s\n%s\n", in.getLineNumber(), trainingFilename, e); 71 | } 72 | } 73 | in.close(); 74 | 75 | float trainingAvg = trainingTotal / trainingCount; 76 | 77 | long endTime = System.currentTimeMillis(); 78 | System.out.printf("######## Finished reading training file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime); 79 | System.out.println("######## Time elapsed reading training file: " 80 | + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)"); 81 | 82 | float[][] userMatrix = new float[numUsers][numLatent]; 83 | for (int i = 0; i < numUsers; ++i) { 84 | for (int k = 0; k < numLatent; ++k) { 85 | userMatrix[i][k] = 0.0f; 86 | } 87 | } 88 | 89 | float[][] itemMatrix = new float[numItems][numLatent]; 90 | for (int i = 0; i < numItems; ++i) { 91 | for (int k = 0; k < numLatent; ++k) { 92 | itemMatrix[i][k] = 0.0f; 93 | } 94 | } 95 | 96 | startTime = System.currentTimeMillis(); 97 | System.out.printf("######## Started reading user file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime); 98 | 99 | in = new LineNumberReader(new FileReader(userFilename)); 100 | while ((line = in.readLine()) != null) { 101 | try { 102 | String[] token = StringUtils.split(line, ' '); 103 | int i = Integer.parseInt(token[0]); 104 | for (int k = 0; k < numLatent; ++k) { 105 | userMatrix[i][k] = Float.parseFloat(token[k+1]); 106 | } 107 | } catch (Exception e) { 108 | System.err.printf("######## Could not parse line %d in %s\n%s\n", in.getLineNumber(), userFilename, e); 109 | } 110 | } 111 | in.close(); 112 | 113 | endTime = System.currentTimeMillis(); 114 | System.out.printf("######## Finished reading user file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime); 115 | System.out.println("######## Time elapsed reading user file: " 116 | + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)"); 117 | 118 | startTime = System.currentTimeMillis(); 119 | System.out.printf("######## Started reading item file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime); 120 | 121 | in = new LineNumberReader(new FileReader(itemFilename)); 122 | while ((line = in.readLine()) != null) { 123 | try { 124 | String[] token = StringUtils.split(line, ' '); 125 | int j = Integer.parseInt(token[0]); 126 | for (int k = 0; k < numLatent; ++k) { 127 | itemMatrix[j][k] = Float.parseFloat(token[k+1]); 128 | } 129 | } catch (Exception e) { 130 | 
System.err.printf("######## Could not parse line %d in %s\n%s\n", in.getLineNumber(), itemFilename, e);
131 |             }
132 |         }
133 |         in.close();
134 | 
135 |         endTime = System.currentTimeMillis();
136 |         System.out.printf("######## Finished reading item file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime);
137 |         System.out.println("######## Time elapsed reading item file: "
138 |             + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)");
139 | 
140 |         startTime = System.currentTimeMillis();
141 |         System.out.printf("######## Started reading test file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", startTime);
142 | 
143 |         float totalSqErr = 0.0f;
144 |         int numRatings = 0;
145 | 
146 |         in = new LineNumberReader(new FileReader(testFilename));
147 |         while ((line = in.readLine()) != null) {
148 |             try {
149 |                 String[] token = StringUtils.split(line, ' ');
150 |                 int i = Integer.parseInt(token[0]);
151 |                 int j = Integer.parseInt(token[1]);
152 |                 float rating = Float.parseFloat(token[2]);
153 |                 float prediction;
154 | 
155 |                 boolean userKnown = userCount.containsKey(i);
156 |                 boolean itemKnown = itemCount.containsKey(j);
157 | 
158 |                 if (userKnown && itemKnown) {
159 |                     prediction = 0.0f;
160 |                     for (int k = 0; k < numLatent; ++k) {
161 |                         prediction += userMatrix[i][k] * itemMatrix[j][k];
162 |                     }
163 |                 } else if (userKnown) {
164 |                     prediction = userTotal.get(i) / userCount.get(i);
165 |                 } else if (itemKnown) {
166 |                     prediction = itemTotal.get(j) / itemCount.get(j);
167 |                 } else {
168 |                     prediction = trainingAvg;
169 |                 }
170 | 
171 |                 float diff = prediction - rating;
172 |                 totalSqErr += diff*diff;
173 |                 ++numRatings;
174 |             } catch (Exception e) {
175 |                 System.err.printf("######## Could not parse line %d in %s\n%s\n", in.getLineNumber(), testFilename, e);
176 |             }
177 |         }
178 |         in.close();
179 | 
180 |         double rmse = Math.sqrt(totalSqErr / numRatings);
181 | 
182 |         endTime = System.currentTimeMillis();
183 |         System.out.printf("######## Finished reading test file: %1$tY-%1$tb-%1$td %1$tT %tZ\n", endTime);
184 |         System.out.println("######## Time elapsed reading test file: "
185 |             + DurationFormatUtils.formatPeriod(startTime, endTime, "H:m:s") + " (h:m:s)");
186 |         System.out.println("######## Total elapsed testing time: "
187 |             + DurationFormatUtils.formatPeriod(testStartTime, endTime, "H:m:s") + " (h:m:s)");
188 |         System.out.println("######## Number of ratings used: " + numRatings);
189 |         System.out.println("######## RMSE: " + rmse);
190 |     }
191 | }
13 | public final int timestamp; 14 | public final int userId, itemId; 15 | public final float rating; 16 | public int numTrainingIters = 0; 17 | 18 | public TrainingExample(int timestamp, int userId, int itemId, float rating) { 19 | this.timestamp = timestamp; 20 | this.userId = userId; 21 | this.itemId = itemId; 22 | this.rating = rating; 23 | } 24 | 25 | public String toString() { 26 | return "(<" + timestamp + ">," + userId + "," + itemId + "," + rating + "," + numTrainingIters + ")"; 27 | } 28 | 29 | public boolean equals(Object obj) { // identity is (userId, itemId) only, so a Set keeps a single example per matrix cell 30 | if (obj instanceof TrainingExample) { 31 | TrainingExample ex = (TrainingExample)obj; 32 | return this.userId == ex.userId && this.itemId == ex.itemId; 33 | } else { 34 | return false; 35 | } 36 | } 37 | 38 | public int hashCode() { 39 | int userHi = userId >>> 16; 40 | int userLo = userId & 0xFFFF; 41 | int itemHi = itemId >>> 16; 42 | int itemLo = itemId & 0xFFFF; 43 | 44 | return ((userLo ^ itemHi) << 16) | (userHi ^ itemLo); 45 | } 46 | 47 | public static class Serialization implements ISerialization<TrainingExample> { 48 | public boolean accept(Class c) { 49 | return TrainingExample.class.equals(c); 50 | } 51 | 52 | public void serialize(TrainingExample ex, DataOutputStream out) throws IOException { 53 | out.writeInt(ex.timestamp); 54 | out.writeInt(ex.userId); 55 | out.writeInt(ex.itemId); 56 | out.writeFloat(ex.rating); 57 | out.writeInt(ex.numTrainingIters); 58 | } 59 | 60 | public TrainingExample deserialize(DataInputStream in) throws IOException { 61 | TrainingExample ex = new TrainingExample(in.readInt(), in.readInt(), in.readInt(), in.readFloat()); 62 | ex.numTrainingIters = in.readInt(); 63 | return ex; 64 | } 65 | } 66 | } -------------------------------------------------------------------------------- /src/main/java/collabstream/streaming/Worker.java: -------------------------------------------------------------------------------- 1 | package collabstream.streaming; 2 | 3 | import java.util.Map; 4 | import java.util.concurrent.ConcurrentHashMap; 5 | 6 | import backtype.storm.task.OutputCollector; 7 | import backtype.storm.task.TopologyContext; 8 | import backtype.storm.topology.IRichBolt; 9 | import backtype.storm.topology.OutputFieldsDeclarer; 10 | import backtype.storm.tuple.Fields; 11 | import backtype.storm.tuple.Tuple; 12 | import backtype.storm.tuple.Values; 13 | 14 | import static collabstream.streaming.MsgType.*; 15 | 16 | public class Worker implements IRichBolt { 17 | public static final int TO_MASTER_STREAM_ID = 1; 18 | public static final int USER_BLOCK_STREAM_ID = 2; 19 | public static final int ITEM_BLOCK_STREAM_ID = 3; 20 | 21 | private OutputCollector collector; 22 | private final Configuration config; 23 | private final Map<BlockPair, WorkingBlock> workingBlockMap = new ConcurrentHashMap<BlockPair, WorkingBlock>(); 24 | 25 | public Worker(Configuration config) { 26 | this.config = config; 27 | } 28 | 29 | public void prepare(Map stormConfig, TopologyContext context, OutputCollector collector) { 30 | this.collector = collector; 31 | } 32 | 33 | public void cleanup() { 34 | } 35 | 36 | public void execute(Tuple tuple) { 37 | MsgType msgType = (MsgType)tuple.getValue(0); 38 | if (config.debug && msgType != TRAINING_EXAMPLE) { 39 | System.out.println("######## Worker.execute: " + msgType + " " + tuple.getValue(1)); 40 | } 41 | BlockPair bp; 42 | WorkingBlock workingBlock; 43 | 44 | switch (msgType) { 45 | case TRAINING_EXAMPLE: 46 | TrainingExample ex = (TrainingExample)tuple.getValue(2); 47 | if (config.debug) { 48 | System.out.println("######## Worker.execute: " + msgType + " " + ex); 
49 | } 50 | bp = new BlockPair(config.getUserBlockIdx(ex.userId), config.getItemBlockIdx(ex.itemId)); 51 | workingBlock = getWorkingBlock(bp); 52 | workingBlock.examples.add(ex); 53 | collector.ack(tuple); 54 | break; 55 | case PROCESS_BLOCK_REQ: 56 | bp = (BlockPair)tuple.getValue(1); 57 | workingBlock = getWorkingBlock(bp); 58 | workingBlock.waitingForBlocks = true; 59 | collector.emit(USER_BLOCK_STREAM_ID, new Values(USER_BLOCK_REQ, bp, null, null, bp.userBlockIdx)); 60 | collector.emit(ITEM_BLOCK_STREAM_ID, new Values(ITEM_BLOCK_REQ, bp, null, null, bp.itemBlockIdx)); 61 | break; 62 | case USER_BLOCK_REQ: 63 | // Forward request from master 64 | bp = (BlockPair)tuple.getValue(1); 65 | collector.emit(USER_BLOCK_STREAM_ID, new Values(USER_BLOCK_REQ, bp, null, tuple.getSourceTask(), bp.userBlockIdx)); 66 | break; 67 | case ITEM_BLOCK_REQ: 68 | // Forward request from master 69 | bp = (BlockPair)tuple.getValue(1); 70 | collector.emit(ITEM_BLOCK_STREAM_ID, new Values(ITEM_BLOCK_REQ, bp, null, tuple.getSourceTask(), bp.itemBlockIdx)); 71 | break; 72 | case USER_BLOCK: 73 | bp = (BlockPair)tuple.getValue(1); 74 | float[][] userBlock = (float[][])tuple.getValue(2); 75 | workingBlock = getWorkingBlock(bp); 76 | 77 | if (workingBlock.waitingForBlocks) { 78 | workingBlock.userBlock = userBlock; 79 | if (workingBlock.itemBlock != null) { 80 | update(bp, workingBlock); 81 | } 82 | } 83 | break; 84 | case ITEM_BLOCK: 85 | bp = (BlockPair)tuple.getValue(1); 86 | float[][] itemBlock = (float[][])tuple.getValue(2); 87 | workingBlock = getWorkingBlock(bp); 88 | 89 | if (workingBlock.waitingForBlocks) { 90 | workingBlock.itemBlock = itemBlock; 91 | if (workingBlock.userBlock != null) { 92 | update(bp, workingBlock); 93 | } 94 | } 95 | break; 96 | case USER_BLOCK_SAVED: 97 | bp = (BlockPair)tuple.getValue(1); 98 | workingBlock = getWorkingBlock(bp); 99 | 100 | if (workingBlock.waitingForStorage) { 101 | workingBlock.userBlock = null; 102 | if (workingBlock.itemBlock == null) { 103 | workingBlock.waitingForStorage = false; 104 | collector.emit(TO_MASTER_STREAM_ID, new Values(PROCESS_BLOCK_FIN, bp, workingBlock.getLatestExample())); 105 | } 106 | } 107 | break; 108 | case ITEM_BLOCK_SAVED: 109 | bp = (BlockPair)tuple.getValue(1); 110 | workingBlock = getWorkingBlock(bp); 111 | 112 | if (workingBlock.waitingForStorage) { 113 | workingBlock.itemBlock = null; 114 | if (workingBlock.userBlock == null) { 115 | workingBlock.waitingForStorage = false; 116 | collector.emit(TO_MASTER_STREAM_ID, new Values(PROCESS_BLOCK_FIN, bp, workingBlock.getLatestExample())); 117 | } 118 | } 119 | break; 120 | } 121 | } 122 | 123 | public void declareOutputFields(OutputFieldsDeclarer declarer) { 124 | // Fields "userBlockIdx" and "itemBlockIdx" used solely for grouping 125 | declarer.declareStream(TO_MASTER_STREAM_ID, new Fields("msgType", "blockPair", "latestExample")); 126 | declarer.declareStream(USER_BLOCK_STREAM_ID, new Fields("msgType", "blockPair", "block", "taskId", "userBlockIdx")); 127 | declarer.declareStream(ITEM_BLOCK_STREAM_ID, new Fields("msgType", "blockPair", "block", "taskId", "itemBlockIdx")); 128 | } 129 | 130 | private WorkingBlock getWorkingBlock(BlockPair bp) { 131 | // In general, this pattern of access is not thread-safe. But since requests with the same userBlockIdx 132 | // are sent to the same thread, it should be safe in our case. 
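// If that routing assumption ever changes, this check-then-act sequence could be made atomic with ConcurrentMap.putIfAbsent: create a fresh WorkingBlock, call workingBlockMap.putIfAbsent(bp, fresh), and return the previous mapping when one is returned.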
133 | WorkingBlock workingBlock = workingBlockMap.get(bp); 134 | if (workingBlock == null) { 135 | workingBlock = new WorkingBlock(); 136 | workingBlockMap.put(bp, workingBlock); 137 | } 138 | return workingBlock; 139 | } 140 | 141 | private void update(BlockPair bp, WorkingBlock workingBlock) { 142 | int userBlockStart = config.getUserBlockStart(bp.userBlockIdx); 143 | int itemBlockStart = config.getItemBlockStart(bp.itemBlockIdx); 144 | float[][] userBlock = workingBlock.userBlock; 145 | float[][] itemBlock = workingBlock.itemBlock; 146 | 147 | TrainingExample[] examples = workingBlock.examples.toArray(new TrainingExample[workingBlock.examples.size()]); 148 | PermutationUtils.permute(examples); 149 | 150 | for (TrainingExample ex : examples) { 151 | if (ex.numTrainingIters >= config.maxTrainingIters) continue; 152 | int i = ex.userId - userBlockStart; 153 | int j = ex.itemId - itemBlockStart; 154 | 155 | float dotProduct = 0.0f; 156 | for (int k = 0; k < config.numLatent; ++k) { 157 | dotProduct += userBlock[i][k] * itemBlock[j][k]; 158 | } 159 | float ratingDiff = dotProduct - ex.rating; 160 | 161 | ++ex.numTrainingIters; 162 | float stepSize = 2 * config.initialStepSize / ex.numTrainingIters; 163 | 164 | for (int k = 0; k < config.numLatent; ++k) { 165 | float oldUserWeight = userBlock[i][k]; 166 | float oldItemWeight = itemBlock[j][k]; 167 | userBlock[i][k] -= stepSize*(ratingDiff * oldItemWeight + config.userPenalty * oldUserWeight); 168 | itemBlock[j][k] -= stepSize*(ratingDiff * oldUserWeight + config.itemPenalty * oldItemWeight); 169 | } 170 | } 171 | workingBlock.waitingForBlocks = false; 172 | workingBlock.waitingForStorage = true; 173 | 174 | collector.emit(USER_BLOCK_STREAM_ID, new Values(USER_BLOCK, bp, userBlock, null, bp.userBlockIdx)); 175 | collector.emit(ITEM_BLOCK_STREAM_ID, new Values(ITEM_BLOCK, bp, itemBlock, null, bp.itemBlockIdx)); 176 | } 177 | } -------------------------------------------------------------------------------- /src/main/java/collabstream/streaming/WorkingBlock.java: -------------------------------------------------------------------------------- 1 | package collabstream.streaming; 2 | 3 | import java.io.Serializable; 4 | import java.util.HashSet; 5 | import java.util.Set; 6 | 7 | public class WorkingBlock implements Serializable { 8 | public final Set<TrainingExample> examples = new HashSet<TrainingExample>(); 9 | public float[][] userBlock = null; 10 | public float[][] itemBlock = null; 11 | public boolean waitingForBlocks = false; 12 | public boolean waitingForStorage = false; 13 | 14 | public String toString() { 15 | StringBuilder b = new StringBuilder(24*examples.size() + 72); 16 | 17 | b.append("examples={"); 18 | boolean first = true; 19 | for (TrainingExample ex : examples) { 20 | if (first) { 21 | first = false; 22 | } else { 23 | b.append(", "); 24 | } 25 | b.append(ex.toString()); 26 | } 27 | b.append('}'); 28 | 29 | b.append("\nuserBlock=\n").append(MatrixUtils.toString(userBlock)); 30 | b.append("\nitemBlock=\n").append(MatrixUtils.toString(itemBlock)); 31 | b.append("\nwaitingForBlocks=").append(waitingForBlocks); 32 | b.append("\nwaitingForStorage=").append(waitingForStorage); 33 | 34 | return b.toString(); 35 | } 36 | 37 | public TrainingExample getLatestExample() { 38 | TrainingExample latest = null; 39 | for (TrainingExample ex : examples) { 40 | if (latest == null || latest.timestamp < ex.timestamp) { 41 | latest = ex; 42 | } 43 | } 44 | return latest; 45 | } 46 | } --------------------------------------------------------------------------------
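Worker.update above is the heart of the streaming trainer: a plain biased-SGD step on squared error with L2 penalties and a step size that decays as 2*initialStepSize/numTrainingIters. The standalone sketch below replays that exact update on a single rating so the rule can be checked outside Storm; the class name and all constants (10 latent factors, initialStepSize = 0.1f, penalties = 0.01f) are illustrative assumptions, not values taken from Configuration.

import java.util.Random;

public class SgdStepSketch {
    public static void main(String[] args) {
        final int numLatent = 10;            // assumed; Configuration supplies this in the real topology
        final float initialStepSize = 0.1f;  // assumed
        final float userPenalty = 0.01f, itemPenalty = 0.01f; // assumed
        final float rating = 4.0f;

        // Small random init: with all-zero vectors the gradient term vanishes and nothing learns.
        Random rnd = new Random(42);
        float[] user = new float[numLatent], item = new float[numLatent];
        for (int k = 0; k < numLatent; ++k) {
            user[k] = rnd.nextFloat() * 0.1f;
            item[k] = rnd.nextFloat() * 0.1f;
        }

        for (int iter = 1; iter <= 5; ++iter) {
            float dot = 0.0f;
            for (int k = 0; k < numLatent; ++k) dot += user[k] * item[k];
            float ratingDiff = dot - rating;              // prediction error before the update
            float stepSize = 2 * initialStepSize / iter;  // same decay as Worker.update

            for (int k = 0; k < numLatent; ++k) {
                float u = user[k], m = item[k];
                user[k] -= stepSize * (ratingDiff * m + userPenalty * u); // same rule as lines 167-168
                item[k] -= stepSize * (ratingDiff * u + itemPenalty * m);
            }
            System.out.printf("iter %d: squared error %.4f%n", iter, ratingDiff * ratingDiff);
        }
    }
}

Running the sketch prints a shrinking squared error, which is a quick sanity check that the update direction and decay schedule behave as intended.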
/src/main/java/comparison/dsgd/DSGDMain.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | 7 | import org.apache.hadoop.conf.Configuration; 8 | import org.apache.hadoop.conf.Configured; 9 | import org.apache.hadoop.fs.FileSystem; 10 | import org.apache.hadoop.fs.Path; 11 | import org.apache.hadoop.io.DoubleWritable; 12 | import org.apache.hadoop.io.NullWritable; 13 | import org.apache.hadoop.io.Text; 14 | import org.apache.hadoop.mapreduce.Job; 15 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; 16 | import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; 17 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; 18 | import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; 19 | import org.apache.log4j.Logger; 20 | 21 | import comparison.dsgd.mapper.DSGDOutputFactorsMapper; 22 | import comparison.dsgd.mapper.DSGDIntermediateMapper; 23 | import comparison.dsgd.mapper.DSGDPreprocFactorMapper; 24 | import comparison.dsgd.mapper.DSGDPreprocRatingsMapper; 25 | import comparison.dsgd.mapper.DSGDRmseMapper; 26 | import comparison.dsgd.reducer.DSGDOutputFactorsReducer; 27 | import comparison.dsgd.reducer.DSGDIntermediateReducer; 28 | import comparison.dsgd.reducer.DSGDPreprocFactorReducer; 29 | import comparison.dsgd.reducer.DSGDPreprocRatingsReducer; 30 | import comparison.dsgd.reducer.DSGDRmseReducer; 31 | 32 | 33 | public class DSGDMain extends Configured { 34 | 35 | private static final Logger sLogger = Logger.getLogger(DSGDMain.class); 36 | 37 | public static int getBlockRow(int row, int numUsers, int numReducers) { 38 | int usersPerBlock = (numUsers / numReducers) + 1; 39 | return (int)Math.floor(((double)row / (double)usersPerBlock)); 40 | } 41 | 42 | public static int getBlockColumn(int column, int numItems, int numReducers) { 43 | int itemsPerBlock = (numItems / numReducers) + 1; 44 | return (int)Math.floor(((double)column / (double)itemsPerBlock)); 45 | } 46 | 47 | public static void main(String[] args) throws Exception { 48 | if (args.length != 3) { 49 | System.out.println("usage: [input-path] [output-path] [num-reducers]"); 50 | return; 51 | } 52 | long startTime = System.currentTimeMillis(); 53 | long timeElapsed = 0; 54 | 55 | String input = args[0]; 56 | String output = args[1]; 57 | int numReducers = Integer.parseInt(args[2]); 58 | 59 | // Fixed parameters 60 | // Edit these for individual experiment 61 | String numUsers = "573"; 62 | String numItems = "2649430"; 63 | int kValue = 10; 64 | int numIterations = 10; 65 | int numBlocks = numReducers; 66 | double stepSize = 0.1; // \tau (step size for SGD updates) 67 | double lambda = 0.1; 68 | 69 | String trainInput = input + "/train"; 70 | String testInput = input + "/test"; 71 | String tempFactorsInput = output + "/temp/input"; 72 | String tempFactorsOutput = output + "/temp/output/"; 73 | String tempTrainRatings = output + "/temp/train"; 74 | String tempTestRatings = output + "/temp/test"; 75 | String emptyInput = input + "/empty"; 76 | String finalFactorsOutput = output + "/final_factors"; 77 | String rmseOutput = output + "/rmse"; 78 | 79 | /** 80 | * jobPreprocRatings 81 | * write train and test ratings to RatingsItem format 82 | */ 83 | Configuration confPreprocRatings = new Configuration(); 84 | confPreprocRatings.set("numUsers", numUsers); 85 | confPreprocRatings.set("numReducers", Integer.toString(numBlocks)); 86 | 
confPreprocRatings.set("numItems", numItems); 87 | confPreprocRatings.set("kValue", Integer.toString(kValue)); 88 | 89 | Job jobPreprocRatings = new Job(confPreprocRatings, "train_ratings"); 90 | 91 | FileInputFormat.setInputPaths(jobPreprocRatings, new Path(trainInput)); 92 | FileOutputFormat.setOutputPath(jobPreprocRatings, new Path(tempTrainRatings)); 93 | // Delete the output directory if it exists already 94 | FileSystem.get(confPreprocRatings).delete(new Path(tempTrainRatings), true); 95 | 96 | jobPreprocRatings.setOutputFormatClass(SequenceFileOutputFormat.class); 97 | 98 | jobPreprocRatings.setJarByClass(DSGDMain.class); 99 | jobPreprocRatings.setNumReduceTasks(numReducers); 100 | jobPreprocRatings.setOutputKeyClass(MatrixItem.class); 101 | jobPreprocRatings.setOutputValueClass(NullWritable.class); 102 | jobPreprocRatings.setMapOutputKeyClass(MatrixItem.class); 103 | jobPreprocRatings.setMapOutputValueClass(NullWritable.class); 104 | jobPreprocRatings.setMapperClass(DSGDPreprocRatingsMapper.class); 105 | jobPreprocRatings.setReducerClass(DSGDPreprocRatingsReducer.class); 106 | jobPreprocRatings.setCombinerClass(DSGDPreprocRatingsReducer.class); 107 | 108 | // train RatingsItems 109 | jobPreprocRatings.waitForCompletion(true); 110 | 111 | // test RatingsItems 112 | jobPreprocRatings = new Job(confPreprocRatings, "train_ratings"); 113 | 114 | FileInputFormat.setInputPaths(jobPreprocRatings, new Path(testInput)); 115 | FileOutputFormat.setOutputPath(jobPreprocRatings, new Path(tempTestRatings)); 116 | // Delete the output directory if it exists already 117 | FileSystem.get(confPreprocRatings).delete(new Path(tempTestRatings), true); 118 | 119 | jobPreprocRatings.setOutputFormatClass(SequenceFileOutputFormat.class); 120 | jobPreprocRatings.setJarByClass(DSGDMain.class); 121 | jobPreprocRatings.setNumReduceTasks(numReducers); 122 | jobPreprocRatings.setOutputKeyClass(MatrixItem.class); 123 | jobPreprocRatings.setOutputValueClass(NullWritable.class); 124 | jobPreprocRatings.setMapOutputKeyClass(MatrixItem.class); 125 | jobPreprocRatings.setMapOutputValueClass(NullWritable.class); 126 | jobPreprocRatings.setMapperClass(DSGDPreprocRatingsMapper.class); 127 | jobPreprocRatings.setReducerClass(DSGDPreprocRatingsReducer.class); 128 | jobPreprocRatings.setCombinerClass(DSGDPreprocRatingsReducer.class); 129 | 130 | jobPreprocRatings.waitForCompletion(true); 131 | 132 | /** 133 | * jobPreprocFactors 134 | * Create initial U and M and output to FactorItem format 135 | */ 136 | 137 | Configuration confPreprocFactors = new Configuration(); 138 | confPreprocFactors.set("numUsers", numUsers); 139 | confPreprocFactors.set("numItems", numItems); 140 | confPreprocFactors.set("kValue", Integer.toString(kValue)); 141 | 142 | Job jobPreprocFactors = new Job(confPreprocFactors, "factors"); 143 | 144 | FileInputFormat.setInputPaths(jobPreprocFactors, new Path(emptyInput)); 145 | FileOutputFormat.setOutputPath(jobPreprocFactors, new Path(tempFactorsInput)); 146 | // Delete the output directory if it exists already 147 | FileSystem.get(confPreprocRatings).delete(new Path(tempFactorsInput), true); 148 | 149 | jobPreprocFactors.setJarByClass(DSGDMain.class); 150 | jobPreprocFactors.setNumReduceTasks(1); 151 | jobPreprocFactors.setOutputKeyClass(MatrixItem.class); 152 | jobPreprocFactors.setOutputValueClass(NullWritable.class); 153 | jobPreprocFactors.setMapOutputKeyClass(MatrixItem.class); 154 | jobPreprocFactors.setMapOutputValueClass(NullWritable.class); 155 | 
jobPreprocFactors.setMapperClass(DSGDPreprocFactorMapper.class); 156 | jobPreprocFactors.setReducerClass(DSGDPreprocFactorReducer.class); 157 | jobPreprocFactors.setCombinerClass(DSGDPreprocFactorReducer.class); 158 | jobPreprocFactors.setOutputFormatClass(SequenceFileOutputFormat.class); 159 | 160 | jobPreprocFactors.waitForCompletion(true); 161 | 162 | // if(true){ 163 | // return; 164 | // } 165 | 166 | for(int a = 0; a < numIterations; ++a){ 167 | sLogger.info("Running iteration " + a); 168 | // Build stratum 169 | List<Integer> stratum = new ArrayList<Integer>(); 170 | for(int i = 0; i < numBlocks; ++i ){ 171 | stratum.add(i); 172 | } 173 | Collections.shuffle(stratum); //choose random initial stratum 174 | String stratumString; 175 | 176 | for(int i=0; i < numReducers; ++i) { 177 | 178 | sLogger.info("Running stratum " + i + " on iteration " + a ); 179 | // Choose next stratum by rotating the block permutation one position 180 | int current1 = stratum.get(0); 181 | int current2; 182 | for(int j = 1; j < numBlocks; ++j){ 183 | current2 = stratum.get(j); 184 | stratum.set(j, current1); 185 | current1 = current2; 186 | } 187 | stratum.set(0, current1); 188 | 189 | // Build string version of stratum 190 | stratumString = ""; 191 | for(int j : stratum){ 192 | stratumString += " " + j; 193 | } 194 | 195 | /** 196 | * jobIntermediate 197 | * Perform DSGD updates on stratum 198 | */ 199 | Configuration confIntermediate = new Configuration(); 200 | confIntermediate.set("stratum", stratumString); 201 | confIntermediate.set("numUsers", numUsers); 202 | confIntermediate.set("numReducers", Integer.toString(numBlocks)); 203 | confIntermediate.set("numItems", numItems); 204 | confIntermediate.set("kValue", Integer.toString(kValue)); 205 | confIntermediate.set("stepSize", Double.toString(stepSize)); 206 | confIntermediate.set("lambda", Double.toString(lambda)); 207 | FileSystem fs = FileSystem.get(confIntermediate); 208 | 209 | Job jobIntermediate = new Job(confIntermediate, "DSGD: Intermediate"); 210 | 211 | FileInputFormat.setInputPaths(jobIntermediate, tempFactorsInput + "," + tempTrainRatings); 212 | FileOutputFormat.setOutputPath(jobIntermediate, new Path(tempFactorsOutput)); 213 | // Delete the output directory if it exists already 214 | FileSystem.get(confPreprocRatings).delete(new Path(tempFactorsOutput), true); 215 | jobIntermediate.setInputFormatClass(SequenceFileInputFormat.class); 216 | jobIntermediate.setOutputFormatClass(SequenceFileOutputFormat.class); 217 | 218 | jobIntermediate.setJarByClass(DSGDMain.class); 219 | jobIntermediate.setNumReduceTasks(numReducers); 220 | jobIntermediate.setOutputKeyClass(MatrixItem.class); 221 | jobIntermediate.setOutputValueClass(NullWritable.class); 222 | jobIntermediate.setMapOutputKeyClass(IntMatrixItemPair.class); 223 | jobIntermediate.setMapOutputValueClass(NullWritable.class); 224 | 225 | jobIntermediate.setMapperClass(DSGDIntermediateMapper.class); 226 | // job.setCombinerClass(IntermediateReducer.class); 227 | jobIntermediate.setReducerClass(DSGDIntermediateReducer.class); 228 | 229 | jobIntermediate.waitForCompletion(true); 230 | 231 | fs.delete(new Path(tempFactorsInput), true); 232 | fs.rename(new Path(tempFactorsOutput), new Path(tempFactorsInput)); 233 | } 234 | /** 235 | * jobRmse 236 | * Calculate RMSE on test data using current factors 237 | */ 238 | timeElapsed += (System.currentTimeMillis() - startTime); 239 | 240 | Configuration confRmse = new Configuration(); 241 | confRmse.set("numUsers", numUsers); 242 | confRmse.set("numReducers", Integer.toString(numBlocks)); 243 | confRmse.set("numItems", 
numItems); 244 | confRmse.set("timeElapsed", Long.toString(timeElapsed)); 245 | confRmse.set("kValue", Integer.toString(kValue)); 246 | Job jobRmse = new Job(confRmse, "DSGD: RMSE"); 247 | FileInputFormat.setInputPaths(jobRmse, tempFactorsInput + "," + tempTestRatings); 248 | FileOutputFormat.setOutputPath(jobRmse, new Path(rmseOutput + "/iter_" + a)); 249 | // Delete the output directory if it exists already 250 | FileSystem.get(confPreprocRatings).delete(new Path(rmseOutput + "/iter_" + a), true); 251 | jobRmse.setInputFormatClass(SequenceFileInputFormat.class); 252 | 253 | jobRmse.setJarByClass(DSGDMain.class); 254 | jobRmse.setNumReduceTasks(1); 255 | jobRmse.setOutputKeyClass(DoubleWritable.class); 256 | jobRmse.setOutputValueClass(DoubleWritable.class); 257 | jobRmse.setMapOutputKeyClass(MatrixItem.class); 258 | jobRmse.setMapOutputValueClass(NullWritable.class); 259 | jobRmse.setMapperClass(DSGDRmseMapper.class); 260 | jobRmse.setReducerClass(DSGDRmseReducer.class); 261 | 262 | jobRmse.waitForCompletion(true); 263 | 264 | startTime = System.currentTimeMillis(); 265 | } 266 | // Emit final factor matrices 267 | Configuration confFinal = new Configuration(); 268 | confFinal.set("kValue", Integer.toString(kValue)); 269 | confFinal.set("numUsers", numUsers); 270 | confFinal.set("numItems", numItems); 271 | 272 | Job jobFinal = new Job(confFinal, "ImprovedPBFS"); 273 | FileInputFormat.setInputPaths(jobFinal, new Path(tempFactorsInput)); 274 | FileOutputFormat.setOutputPath(jobFinal, new Path(finalFactorsOutput)); 275 | // Delete the output directory if it exists already 276 | FileSystem.get(confPreprocRatings).delete(new Path(finalFactorsOutput), true); 277 | 278 | jobFinal.setInputFormatClass(SequenceFileInputFormat.class); 279 | jobFinal.setJarByClass(DSGDMain.class); 280 | jobFinal.setNumReduceTasks(1); // Must use 1 reducer so that both U and M 281 | // get sent to same reducer 282 | 283 | jobFinal.setOutputKeyClass(Text.class); 284 | jobFinal.setOutputValueClass(NullWritable.class); 285 | 286 | jobFinal.setMapOutputKeyClass(MatrixItem.class); 287 | jobFinal.setMapOutputValueClass(NullWritable.class); 288 | 289 | jobFinal.setMapperClass(DSGDOutputFactorsMapper.class); 290 | // jobFinal.setCombinerClass(FinalReducer.class); 291 | jobFinal.setReducerClass(DSGDOutputFactorsReducer.class); 292 | 293 | jobFinal.waitForCompletion(true); 294 | // fsFinal.delete(tempInputDir, true); 295 | } 296 | 297 | } 298 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/IntMatrixItemPair.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd; 2 | 3 | import java.io.DataInput; 4 | import java.io.DataOutput; 5 | import java.io.IOException; 6 | 7 | import org.apache.hadoop.io.IntWritable; 8 | import org.apache.hadoop.io.WritableComparable; 9 | 10 | /** 11 | * Pair consisting of an IntWritable (the reducer number) and a MatrixItem. Used 12 | * for two-way sorting. 13 | * 14 | * @author christopherjohnson 15 | * 16 | */ 17 | public class IntMatrixItemPair implements WritableComparable<IntMatrixItemPair> { 18 | 19 | private IntWritable reducerNum = new IntWritable(); 20 | private MatrixItem matItem = new MatrixItem(); 21 | 22 | public IntMatrixItemPair() { 23 | 24 | } 25 | 26 | public void set(IntWritable r, MatrixItem m) { 27 | this.reducerNum = new IntWritable(r.get()); 28 | this.matItem.set(m.getRow().get(), m.getColumn().get(), m 29 | .getValue().get(), m.getMatrixType().toString()); 30 | } 31 | 32 | public
 void readFields(DataInput arg0) throws IOException { 33 | reducerNum.readFields(arg0); 34 | matItem.readFields(arg0); 35 | } 36 | 37 | public void write(DataOutput arg0) throws IOException { 38 | reducerNum.write(arg0); 39 | matItem.write(arg0); 40 | } 41 | 42 | public int compareTo(IntMatrixItemPair o) { 43 | int cmp = reducerNum.compareTo(o.reducerNum); 44 | if (cmp != 0) { 45 | return cmp; 46 | } 47 | return matItem.compareTo(o.getMatItem()); 48 | } 49 | 50 | public IntWritable getReducerNum() { 51 | return reducerNum; 52 | } 53 | 54 | public MatrixItem getMatItem() { 55 | return matItem; 56 | } 57 | 58 | public void setReducerNum(IntWritable reducerNum) { 59 | this.reducerNum = reducerNum; 60 | } 61 | 62 | public void setMatItem(MatrixItem matItem) { 63 | this.matItem = matItem; 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/LongPair.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd; 2 | 3 | public class LongPair implements Comparable<LongPair> { 4 | 5 | public long first; 6 | public long second; 7 | 8 | public LongPair(long first, long second) { 9 | this.first = first; 10 | this.second = second; 11 | } 12 | 13 | public int compareTo(LongPair o) { 14 | if (this.first != o.first) { 15 | if (this.first < o.first) { 16 | return -1; 17 | } else { 18 | return 1; 19 | } 20 | } else if (this.second != o.second) { 21 | if (this.second < o.second) { 22 | return -1; 23 | } else { 24 | return 1; 25 | } 26 | } 27 | return 0; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/MatrixItem.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd; 2 | 3 | import java.io.DataInput; 4 | import java.io.DataOutput; 5 | import java.io.IOException; 6 | 7 | import org.apache.hadoop.io.DoubleWritable; 8 | import org.apache.hadoop.io.IntWritable; 9 | import org.apache.hadoop.io.Text; 10 | import org.apache.hadoop.io.WritableComparable; 11 | 12 | public class MatrixItem implements WritableComparable<MatrixItem>, Cloneable { 13 | 14 | private IntWritable row = new IntWritable(); 15 | private IntWritable column = new IntWritable(); 16 | private DoubleWritable value = new DoubleWritable(); 17 | private Text matrixType = new Text(); 18 | 19 | public static final Text U_MATRIX = new Text("U_MATRIX"); 20 | public static final Text M_MATRIX = new Text("M_MATRIX"); 21 | public static final Text R_MATRIX = new Text("R_MATRIX"); 22 | 23 | public boolean isRatingsItem(){ 24 | return matrixType.equals(R_MATRIX); 25 | } 26 | 27 | public boolean isFactorItem(){ 28 | return !matrixType.equals(R_MATRIX); 29 | } 30 | 31 | public void set(int r, int c, double v, String m) { 32 | setRow(new IntWritable(r)); 33 | setColumn(new IntWritable(c)); 34 | setValue(new DoubleWritable(v)); 35 | setMatrixType(new Text(m)); 36 | } 37 | 38 | public void readFields(DataInput arg0) throws IOException { 39 | row.readFields(arg0); 40 | column.readFields(arg0); 41 | value.readFields(arg0); 42 | matrixType.readFields(arg0); 43 | } 44 | 45 | public void write(DataOutput arg0) throws IOException { 46 | row.write(arg0); 47 | column.write(arg0); 48 | value.write(arg0); 49 | matrixType.write(arg0); 50 | } 51 | 52 | public MatrixItem clone() { 53 | MatrixItem matItem = new MatrixItem(); 54 | matItem.set(this.row.get(), this.column.get(), this.value.get(), 55 | new String(this.matrixType.toString())); 56 | 
return matItem; 57 | } 58 | 59 | public int compareTo(MatrixItem o) { 60 | int compare = this.getMatrixType().compareTo(o.getMatrixType()); 61 | if(compare != 0){ 62 | if(this.isRatingsItem()){ 63 | return 1; 64 | } else if(o.isRatingsItem()){ 65 | return -1; 66 | } 67 | return compare; 68 | } 69 | // compare = this.getRow().compareTo(o.getRow()); 70 | // if(compare != 0){ 71 | // return compare; 72 | // } 73 | // compare = this.getColumn().compareTo(o.getColumn()); 74 | // if(compare != 0){ 75 | // return compare; 76 | // } 77 | // return this.getValue().compareTo(o.getValue()); 78 | double rand = Math.random(); // random tie-break shuffles same-type items so ratings reach reducers in random order; note this violates the compareTo contract (it is neither reflexive nor transitive) 79 | if(rand > 0.5){ 80 | return 1; 81 | } else{ 82 | return -1; 83 | } 84 | } 85 | 86 | public IntWritable getRow() { 87 | return row; 88 | } 89 | 90 | public IntWritable getColumn() { 91 | return column; 92 | } 93 | 94 | public DoubleWritable getValue() { 95 | return value; 96 | } 97 | 98 | public void setRow(IntWritable row) { 99 | this.row = row; 100 | } 101 | 102 | public void setColumn(IntWritable column) { 103 | this.column = column; 104 | } 105 | 106 | public void setValue(DoubleWritable value) { 107 | this.value = value; 108 | } 109 | 110 | public Text getMatrixType() { 111 | return matrixType; 112 | } 113 | 114 | public void setMatrixType(Text matrixType) { 115 | this.matrixType = matrixType; 116 | } 117 | 118 | // public IntWritable getRow(); 119 | // public IntWritable getColumn(); 120 | // public MatrixItem clone(); 121 | 122 | } 123 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/MatrixUtils.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd; 2 | 3 | public class MatrixUtils { 4 | 5 | public static double dotProduct(double[] v1, double[] v2){ 6 | if(v1.length != v2.length){ 7 | System.err.println("dotProduct: vector lengths do not match"); 8 | return 0; 9 | } 10 | 11 | double total = 0; 12 | for(int i = 0; i < v1.length; ++i){ 13 | // System.out.println("v1[" + i + "]: " + v1[i]); 14 | // System.out.println("v2[" + i + "]: " + v2[i]); 15 | total += v1[i]*v2[i]; 16 | } 17 | 18 | return total; 19 | } 20 | 21 | public static class MatrixException extends Exception{ 22 | 23 | private static final long serialVersionUID = 6678191521290159097L; 24 | 25 | 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/mapper/DSGDFinalMapper.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.mapper; 2 | 3 | public class DSGDFinalMapper { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/mapper/DSGDIntermediateMapper.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.mapper; 2 | 3 | import java.io.IOException; 4 | import java.util.StringTokenizer; 5 | 6 | import org.apache.hadoop.io.IntWritable; 7 | import org.apache.hadoop.io.NullWritable; 8 | import org.apache.hadoop.mapreduce.Mapper; 9 | 10 | import comparison.dsgd.DSGDMain; 11 | import comparison.dsgd.IntMatrixItemPair; 12 | import comparison.dsgd.MatrixItem; 13 | 14 | public class DSGDIntermediateMapper extends 15 | Mapper<MatrixItem, NullWritable, IntMatrixItemPair, NullWritable> { 16 | 17 | int[] stratumArray; 18 | int numItems; 19 | int numUsers; 20 | int numReducers; 21 | IntMatrixItemPair intMatPair = new IntMatrixItemPair(); 22 | IntWritable reducerNum = new IntWritable(); 23 | MatrixItem matItem
 = new MatrixItem(); 24 | NullWritable nw = NullWritable.get(); 25 | 26 | @Override 27 | protected void setup(Context context) throws IOException, 28 | InterruptedException { 29 | String stratumString = context.getConfiguration().get("stratum"); 30 | StringTokenizer st = new StringTokenizer(stratumString); 31 | numReducers = Integer.parseInt(context.getConfiguration().get( 32 | "numReducers")); 33 | stratumArray = new int[numReducers]; 34 | for (int i = 0; i < numReducers; ++i) { 35 | stratumArray[i] = Integer.parseInt(st.nextToken()); 36 | } 37 | numUsers = Integer.parseInt(context.getConfiguration().get("numUsers")); 38 | numItems = Integer.parseInt(context.getConfiguration().get("numItems")); 39 | super.setup(context); 40 | } 41 | 42 | public void map(MatrixItem key, NullWritable value, Context context) 43 | throws IOException, InterruptedException { 44 | int row = key.getRow().get(); 45 | int column = key.getColumn().get(); 46 | if (key.isRatingsItem()) { 47 | int blockRow = DSGDMain.getBlockRow(row, numUsers, numReducers); 48 | int blockColumn = DSGDMain.getBlockColumn(column, numItems, 49 | numReducers); 50 | // Emit if in current stratum 51 | if (stratumArray[blockColumn] == blockRow) { 52 | reducerNum.set(blockColumn); 53 | matItem = new MatrixItem(); 54 | matItem.set(key.getRow().get(), key.getColumn().get(), key 55 | .getValue().get(), key.getMatrixType().toString()); 56 | intMatPair.set(reducerNum, matItem); 57 | context.write(intMatPair, nw); 58 | } 59 | } else if (key.isFactorItem()) { 60 | if (key.getMatrixType().equals(MatrixItem.U_MATRIX)) { 61 | // System.out.println("UMatrix[" 62 | // + key.getRow() 63 | // + "][" 64 | // + key.getColumn().get() 65 | // + "]: " 66 | // + key.getValue()); 67 | int blockRow = DSGDMain.getBlockRow(row, numUsers, 68 | numReducers); 69 | reducerNum.set(stratumArray[blockRow]); 70 | matItem = new MatrixItem(); 71 | matItem.set(key.getRow().get(), key.getColumn().get(), key 72 | .getValue().get(), key.getMatrixType().toString()); 73 | intMatPair.set(reducerNum, matItem); 74 | context.write(intMatPair, nw); 75 | } else { 76 | // System.out.println("MMatrix[" 77 | // + key.getRow() 78 | // + "][" 79 | // + key.getColumn().get() 80 | // + "]: " 81 | // + key.getValue()); 82 | int blockColumn = DSGDMain.getBlockColumn(row, numItems, 83 | numReducers); 84 | reducerNum.set(blockColumn); 85 | matItem = new MatrixItem(); 86 | matItem.set(key.getRow().get(), key.getColumn().get(), key 87 | .getValue().get(), key.getMatrixType().toString()); 88 | intMatPair.set(reducerNum, matItem); 89 | context.write(intMatPair, nw); 90 | } 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/mapper/DSGDOutputFactorsMapper.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.mapper; 2 | 3 | import java.io.IOException; 4 | 5 | import org.apache.hadoop.io.NullWritable; 6 | import org.apache.hadoop.io.Text; 7 | import org.apache.hadoop.mapreduce.Mapper; 8 | 9 | import comparison.dsgd.MatrixItem; 10 | 11 | public class DSGDOutputFactorsMapper extends Mapper<MatrixItem, NullWritable, MatrixItem, NullWritable>{ 12 | 13 | public void map(MatrixItem key, NullWritable value, Context context) 14 | throws IOException, InterruptedException { 15 | 16 | context.write(key, value); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/mapper/DSGDPreprocFactorMapper.java: 
-------------------------------------------------------------------------------- 1 | package comparison.dsgd.mapper; 2 | 3 | import java.io.IOException; 4 | 5 | import org.apache.hadoop.io.LongWritable; 6 | import org.apache.hadoop.io.NullWritable; 7 | import org.apache.hadoop.io.Text; 8 | import org.apache.hadoop.mapreduce.Mapper; 9 | 10 | import comparison.dsgd.MatrixItem; 11 | 12 | public class DSGDPreprocFactorMapper extends 13 | Mapper<LongWritable, Text, MatrixItem, NullWritable> { 14 | 15 | MatrixItem matItem = new MatrixItem(); // reuse writable objects for efficiency 16 | //RatingsItem ratItem = new RatingsItem(); 17 | NullWritable nw = NullWritable.get(); 18 | 19 | @Override 20 | protected void setup(Context context) 21 | throws IOException, InterruptedException { 22 | context.write(new MatrixItem(), nw); 23 | super.setup(context); 24 | } 25 | 26 | @Override 27 | public void map(LongWritable key, Text value, Context context) 28 | throws IOException, InterruptedException { 29 | } 30 | } 31 | 32 | 33 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/mapper/DSGDPreprocMapper.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.mapper; 2 | 3 | import java.io.IOException; 4 | import java.util.HashMap; 5 | import java.util.Set; 6 | import java.util.StringTokenizer; 7 | 8 | import org.apache.hadoop.io.LongWritable; 9 | import org.apache.hadoop.io.Text; 10 | import org.apache.hadoop.mapreduce.Mapper; 11 | import org.apache.hadoop.mapreduce.Mapper.Context; 12 | 13 | public class DSGDPreprocMapper extends 14 | Mapper<LongWritable, Text, Text, Text> { // unfinished scaffolding; the output type parameters here are a guess 15 | 16 | private HashMap<String, Set<String>> gramMap; // element types assumed; gramMap is never populated 17 | 18 | @Override 19 | public void map(LongWritable key, Text value, Context context) 20 | throws IOException, InterruptedException { 21 | 22 | int numReducers = Integer.parseInt(context.getConfiguration().get( 23 | "numReducers")); 24 | long numUsers = Long.parseLong(context.getConfiguration().get( 25 | "numUsers")); 26 | long numItems = Long.parseLong(context.getConfiguration().get( 27 | "numItems")); 28 | long user = key.get(); 29 | long movie = 0; 30 | String line = ((Text) value).toString(); 31 | StringTokenizer itr = new StringTokenizer(line); 32 | 33 | 34 | // build gramMap 35 | while (itr.hasMoreTokens()) { 36 | 37 | user++; 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/mapper/DSGDPreprocRatingsMapper.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.mapper; 2 | 3 | import java.io.IOException; 4 | import java.util.StringTokenizer; 5 | 6 | import org.apache.hadoop.io.LongWritable; 7 | import org.apache.hadoop.io.NullWritable; 8 | import org.apache.hadoop.io.Text; 9 | import org.apache.hadoop.mapreduce.Mapper; 10 | 11 | import comparison.dsgd.MatrixItem; 12 | 13 | public class DSGDPreprocRatingsMapper extends 14 | Mapper<LongWritable, Text, MatrixItem, NullWritable> { 15 | 16 | // reuse writable objects for efficiency 17 | MatrixItem matItem = new MatrixItem(); 18 | NullWritable nw = NullWritable.get(); 19 | 20 | @Override 21 | public void map(LongWritable key, Text value, Context context) 22 | throws IOException, InterruptedException { 23 | 24 | String line = ((Text) value).toString(); 25 | StringTokenizer itr = new StringTokenizer(line); 26 | double dUser = Double.parseDouble(itr.nextToken()); //use double due to data's floating point format 27 | int user = (int)dUser; 28 | double dItem = Double.parseDouble(itr.nextToken()); 29 | int item = 
(int)dItem; 30 | double rating = Double.parseDouble(itr.nextToken()); 31 | 32 | if(rating != 0){ // don't do anything with un-rated items 33 | matItem.set(user, item, rating, MatrixItem.R_MATRIX.toString()); 34 | context.write(matItem, nw); 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/mapper/DSGDRmseMapper.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.mapper; 2 | 3 | import java.io.IOException; 4 | 5 | import org.apache.hadoop.io.NullWritable; 6 | import org.apache.hadoop.mapreduce.Mapper; 7 | 8 | import comparison.dsgd.MatrixItem; 9 | 10 | public class DSGDRmseMapper extends Mapper<MatrixItem, NullWritable, MatrixItem, NullWritable>{ 11 | 12 | NullWritable nw = NullWritable.get(); 13 | 14 | public void map(MatrixItem key, NullWritable value, Context context) 15 | throws IOException, InterruptedException { 16 | 17 | context.write(key, nw); 18 | 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/reducer/DSGDFinalReducer.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.reducer; 2 | 3 | public class DSGDFinalReducer { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/reducer/DSGDIntermediateReducer.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.reducer; 2 | 3 | import java.io.IOException; 4 | 5 | import org.apache.hadoop.io.NullWritable; 6 | import org.apache.hadoop.mapreduce.Reducer; 7 | 8 | import comparison.dsgd.IntMatrixItemPair; 9 | import comparison.dsgd.MatrixItem; 10 | import comparison.dsgd.MatrixUtils; 11 | 12 | public class DSGDIntermediateReducer extends 13 | Reducer<IntMatrixItemPair, NullWritable, MatrixItem, NullWritable> { 14 | 15 | double[][] UMatrix; 16 | double[][] MMatrix; 17 | boolean[][] UMatrixSet; 18 | boolean[][] MMatrixSet; 19 | MatrixItem matItem; 20 | NullWritable nw = NullWritable.get(); 21 | int kValue; 22 | double tau; 23 | double lambda; 24 | int numUsers; 25 | int numItems; 26 | 27 | @Override 28 | protected void cleanup(Context context) throws IOException, 29 | InterruptedException { 30 | 31 | // Emit UMatrix 32 | for (int i = 0; i < numUsers; ++i) { 33 | for (int j = 0; j < kValue; ++j) { 34 | if (UMatrixSet[i][j]) { 35 | if (Double.isNaN(UMatrix[i][j]) 36 | || Double.isInfinite(UMatrix[i][j])) { 37 | System.err.println("warning: non-finite factor U[" + i + "][" + j + "]"); 38 | } 39 | matItem.set(i, j, UMatrix[i][j], 40 | MatrixItem.U_MATRIX.toString()); 41 | context.write(matItem, nw); 42 | } 43 | } 44 | } 45 | 46 | // Emit MMatrix 47 | for (int i = 0; i < numItems; ++i) { 48 | for (int j = 0; j < kValue; ++j) { 49 | if (MMatrixSet[i][j]) { 50 | if (Double.isNaN(MMatrix[i][j]) 51 | || Double.isInfinite(MMatrix[i][j])) { 52 | System.err.println("warning: non-finite factor M[" + i + "][" + j + "]"); 53 | } 54 | matItem.set(i, j, MMatrix[i][j], 55 | MatrixItem.M_MATRIX.toString()); 56 | context.write(matItem, nw); 57 | } 58 | } 59 | } 60 | super.cleanup(context); 61 | } 62 | 63 | @Override 64 | protected void setup(Context context) throws IOException, 65 | InterruptedException { 66 | kValue = Integer.parseInt(context.getConfiguration().get("kValue")); 67 | tau = Double.parseDouble(context.getConfiguration().get("stepSize")); 68 | lambda = Double.parseDouble(context.getConfiguration().get("lambda")); 69 | numUsers = Integer.parseInt(context.getConfiguration().get("numUsers")); 70 | numItems = 
Integer.parseInt(context.getConfiguration().get("numItems")); 71 | 72 | UMatrix = new double[numUsers][kValue]; 73 | MMatrix = new double[numItems][kValue]; 74 | UMatrixSet = new boolean[numUsers][kValue]; 75 | MMatrixSet = new boolean[numItems][kValue]; 76 | 77 | super.setup(context); 78 | } 79 | 80 | protected void reduce(IntMatrixItemPair key, Iterable<NullWritable> values, 81 | Context context) throws IOException, InterruptedException { 82 | 83 | matItem = key.getMatItem(); 84 | 85 | if (matItem.isFactorItem()) { 86 | if (matItem.getMatrixType().equals(MatrixItem.U_MATRIX)) { 87 | UMatrix[matItem.getRow().get()][matItem.getColumn().get()] = matItem 88 | .getValue().get(); 89 | UMatrixSet[matItem.getRow().get()][matItem.getColumn().get()] = true; 90 | // System.out.println("UMatrix[" 91 | // + matItem.getRow() 92 | // + "][" 93 | // + matItem.getColumn().get() 94 | // + "]: " 95 | // + matItem.getValue().get()); 96 | } else { 97 | MMatrix[matItem.getRow().get()][matItem.getColumn().get()] = matItem 98 | .getValue().get(); 99 | MMatrixSet[matItem.getRow().get()][matItem.getColumn().get()] = true; 100 | // System.out.println("MMatrix[" 101 | // + matItem.getRow() 102 | // + "][" 103 | // + matItem.getColumn().get() 104 | // + "]: " 105 | // + matItem.getValue().get()); 106 | } 107 | } else { // If we are now processing Ratings Items then we must have the 108 | // full Factor Matrices 109 | // Perform update (r' is the current prediction): 110 | // forall l 111 | // U_il <-- U_il - 2*\tau*(M_jl*(r' - r) + \lambda*U_il) 112 | // M_jl <-- M_jl - 2*\tau*(U_il*(r' - r) + \lambda*M_jl) 113 | int i = matItem.getRow().get(); 114 | int j = matItem.getColumn().get(); 115 | double oldUval; 116 | double oldMval; 117 | double newUval; 118 | double newMval; 119 | // Get predicted value 120 | // build U_i vector 121 | double[] UVector = UMatrix[i]; 122 | // System.out.println("UMatrix[" + i + "][" + 0 + "]: " + 123 | // UVector[0]); 124 | 125 | // build M_j vector 126 | double[] MVector = MMatrix[j]; 127 | // System.out.println("MMatrix[" + j + "][" + 0 + "]: " + 128 | // MVector[0]); 129 | 130 | // Compute ratings prediction using dot product of factor vectors 131 | double prediction = MatrixUtils.dotProduct(UVector, MVector); 132 | 133 | // Update Factor Matrices 134 | for (int l = 0; l < kValue; ++l) { 135 | oldUval = UMatrix[i][l]; 136 | oldMval = MMatrix[j][l]; 137 | newUval = oldUval - (2*tau* (oldMval 138 | * (prediction - matItem.getValue().get()) + lambda 139 | * oldUval)); 140 | newMval = oldMval - (2 * tau * (oldUval 141 | * (prediction - matItem.getValue().get()) + lambda 142 | * oldMval)); 143 | UMatrix[i][l] = newUval; 144 | MMatrix[j][l] = newMval; 145 | // if (Double.isNaN(newUval) || Double.isInfinite(newUval) 146 | // || newUval == Double.NEGATIVE_INFINITY) { 147 | // System.out.println("cool"); 148 | // } 149 | // if (Double.isNaN(newMval) || Double.isInfinite(newMval) 150 | // || newMval == Double.NEGATIVE_INFINITY) { 151 | // System.out.println("cool"); 152 | // } 153 | if(Math.abs(newUval) > 100 || Math.abs(newMval) > 100){ 154 | System.err.println("warning: factor values exceed 100 in magnitude; SGD may be diverging"); 155 | } 156 | } 157 | } 158 | } 159 | 160 | } 161 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/reducer/DSGDOutputFactorsReducer.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.reducer; 2 | 3 | import java.io.IOException; 4 | 5 | import org.apache.hadoop.io.NullWritable; 6 | import org.apache.hadoop.io.Text; 7 | import 
org.apache.hadoop.mapreduce.Reducer; 8 | 9 | import comparison.dsgd.MatrixItem; 10 | 11 | /** 12 | * Note: DSGDOutputFactorsReducer requires a single reducer so that both factor matrices reach the same task!! 13 | * 14 | * @author christopherjohnson 15 | * 16 | */ 17 | public class DSGDOutputFactorsReducer extends 18 | Reducer<MatrixItem, NullWritable, Text, NullWritable> { 19 | 20 | double[][] UMatrix; 21 | double[][] MMatrix; 22 | MatrixItem matItem; 23 | NullWritable nw = NullWritable.get(); 24 | int kValue; 25 | double tau; 26 | double lambda; 27 | int numUsers; 28 | int numItems; 29 | Text text = new Text(); 30 | 31 | /** 32 | * Output factor matrices on cleanup 33 | */ 34 | @Override 35 | protected void cleanup(Context context) throws IOException, 36 | InterruptedException { 37 | text.set("U_Matrix"); 38 | context.write(text, nw); 39 | for(double[] UVector : UMatrix){ 40 | String current = ""; 41 | for(double d : UVector){ 42 | current += d + " "; 43 | } 44 | text.set(current); 45 | context.write(text, nw); 46 | } 47 | text.set("M_Matrix"); 48 | context.write(text, nw); 49 | 50 | for(double[] MVector : MMatrix){ 51 | String current = ""; 52 | for(double d : MVector){ 53 | current += d + " "; 54 | } 55 | text.set(current); 56 | context.write(text, nw); 57 | } 58 | super.cleanup(context); 59 | } 60 | 61 | @Override 62 | protected void setup(Context context) throws IOException, 63 | InterruptedException { 64 | kValue = Integer.parseInt(context.getConfiguration().get("kValue")); 65 | numUsers = Integer.parseInt(context.getConfiguration().get("numUsers")); 66 | numItems = Integer.parseInt(context.getConfiguration().get("numItems")); 67 | 68 | UMatrix = new double[numUsers][kValue]; 69 | MMatrix = new double[numItems][kValue]; 70 | 71 | 72 | super.setup(context); 73 | } 74 | 75 | @Override 76 | protected void reduce(MatrixItem key, Iterable<NullWritable> values, Context context) 77 | throws IOException, InterruptedException { 78 | if(key.getMatrixType().equals(MatrixItem.U_MATRIX)){ 79 | UMatrix[key.getRow().get()][key.getColumn().get()] = key.getValue().get(); 80 | } else{ 81 | MMatrix[key.getRow().get()][key.getColumn().get()] = key.getValue().get(); 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/reducer/DSGDPreprocFactorReducer.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.reducer; 2 | 3 | import java.io.IOException; 4 | 5 | import org.apache.hadoop.io.NullWritable; 6 | import org.apache.hadoop.mapreduce.Reducer; 7 | 8 | import comparison.dsgd.MatrixItem; 9 | 10 | public class DSGDPreprocFactorReducer extends Reducer<MatrixItem, NullWritable, MatrixItem, NullWritable>{ 11 | 12 | NullWritable nw = NullWritable.get(); 13 | MatrixItem matItem = new MatrixItem(); 14 | 15 | @Override 16 | protected void setup(Context context) throws IOException, 17 | InterruptedException { 18 | // Emit Factor Items (on the reduce side be sure to ignore duplicates 19 | // that may occur due to multiple mappers) 20 | int numUsers = Integer.parseInt(context.getConfiguration().get( 21 | "numUsers")); 22 | int numItems = Integer.parseInt(context.getConfiguration().get( 23 | "numItems")); 24 | int kValue = Integer.parseInt(context.getConfiguration().get("kValue")); 25 | for (int i = 0; i < numUsers; ++i) { 26 | for (int j = 0; j < kValue; ++j) { 27 | matItem.set(i, j, Math.random(), MatrixItem.U_MATRIX.toString()); 28 | context.write(matItem, nw); 29 | } 30 | } 31 | for (int i = 0; i < numItems; ++i) { 32 | for (int j = 0; j < kValue; ++j) { 33 | matItem.set(i, j, Math.random(), MatrixItem.M_MATRIX.toString()); 34 | 
context.write(matItem, nw); 35 | } 36 | } 37 | super.setup(context); 38 | } 39 | 40 | protected void reduce(MatrixItem key, Iterable<NullWritable> values, Context context) 41 | throws IOException, InterruptedException { 42 | 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/reducer/DSGDPreprocRatingsReducer.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.reducer; 2 | 3 | import java.io.IOException; 4 | 5 | import org.apache.hadoop.io.NullWritable; 6 | import org.apache.hadoop.mapreduce.Reducer; 7 | 8 | import comparison.dsgd.MatrixItem; 9 | 10 | public class DSGDPreprocRatingsReducer extends Reducer<MatrixItem, NullWritable, MatrixItem, NullWritable>{ 11 | 12 | NullWritable nw = NullWritable.get(); 13 | 14 | protected void reduce(MatrixItem key, Iterable<NullWritable> values, Context context) 15 | throws IOException, InterruptedException { 16 | 17 | context.write(key, nw); 18 | 19 | } 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/reducer/DSGDPreprocReducer.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.reducer; 2 | 3 | public class DSGDPreprocReducer { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/comparison/dsgd/reducer/DSGDRmseReducer.java: -------------------------------------------------------------------------------- 1 | package comparison.dsgd.reducer; 2 | 3 | import java.io.IOException; 4 | 5 | import org.apache.hadoop.io.DoubleWritable; 6 | import org.apache.hadoop.io.NullWritable; 7 | import org.apache.hadoop.mapreduce.Reducer; 8 | 9 | import comparison.dsgd.MatrixItem; 10 | import comparison.dsgd.MatrixUtils; 11 | import comparison.dsgd.MatrixUtils.MatrixException; 12 | 13 | public class DSGDRmseReducer extends 14 | Reducer<MatrixItem, NullWritable, DoubleWritable, DoubleWritable> { 15 | 16 | double[][] UMatrix; 17 | double[][] MMatrix; 18 | int kValue; 19 | double timeElapsed; 20 | int numUsers; 21 | int numItems; 22 | double sqError = 0; 23 | long numRatings = 0; 24 | 25 | @Override 26 | protected void cleanup(Context context) throws IOException, 27 | InterruptedException { 28 | // Calculate RMSE (root of mean squared error) and emit along with time elapsed 29 | double rmse = Math.sqrt(sqError / (double)numRatings); 30 | context.write(new DoubleWritable(timeElapsed), new DoubleWritable(rmse)); 31 | 32 | super.cleanup(context); 33 | } 34 | 35 | @Override 36 | protected void setup(Context context) throws IOException, 37 | InterruptedException { 38 | // UMap = new HashMap(); 39 | // MMap = new HashMap(); 40 | kValue = Integer.parseInt(context.getConfiguration().get("kValue")); 41 | numUsers = Integer.parseInt(context.getConfiguration().get("numUsers")); 42 | numItems = Integer.parseInt(context.getConfiguration().get("numItems")); 43 | timeElapsed = Double.parseDouble(context.getConfiguration().get("timeElapsed")); 44 | 45 | UMatrix = new double[numUsers][kValue]; 46 | MMatrix = new double[numItems][kValue]; 47 | 48 | super.setup(context); 49 | } 50 | 51 | @Override 52 | protected void reduce(MatrixItem key, Iterable<NullWritable> values, 53 | Context context) throws IOException, InterruptedException { 54 | 55 | if (key.isFactorItem()) { 56 | if (key.getMatrixType().equals(MatrixItem.U_MATRIX)) { 57 | UMatrix[key.getRow().get()][key.getColumn().get()] = key 58 | .getValue().get(); 59 | } else { 60 | // MMap.put(new LongPair(facItem.getRow().get(), facItem 61 | // .getColumn().get()), 
facItem.getValue().get()); 62 | MMatrix[key.getRow().get()][key.getColumn().get()] = key 63 | .getValue().get(); 64 | } 65 | } else { // If we are now processing RatingsItems then we must have the 66 | // full Factor Matrices 67 | int i = key.getRow().get(); 68 | int j = key.getColumn().get(); 69 | 70 | double[] UVector = UMatrix[i]; 71 | double[] MVector = MMatrix[j]; 72 | 73 | double prediction = MatrixUtils.dotProduct(UVector, MVector); 74 | 75 | sqError += ((prediction - key.getValue().get()) * (prediction - key 76 | .getValue().get())); 77 | numRatings++; 78 | } 79 | } 80 | 81 | } 82 | -------------------------------------------------------------------------------- /src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootLogger=INFO, A1 2 | log4j.appender.A1=org.apache.log4j.ConsoleAppender 3 | log4j.appender.A1.layout=org.apache.log4j.PatternLayout 4 | log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n 5 | 6 | log4j.logger.backtype.storm.daemon=WARN 7 | log4j.logger.backtype.storm.serialization=WARN 8 | log4j.logger.backtype.storm.zookeeper=WARN 9 | log4j.logger.org.apache.zookeeper=ERROR 10 | -------------------------------------------------------------------------------- /src/main/resources/lotr.txt: -------------------------------------------------------------------------------- 1 | Three Rings for the Elven-kings under the sky, 2 | Seven for the Dwarf-lords in their halls of stone, 3 | Nine for Mortal Men doomed to die, 4 | One for the Dark Lord on his dark throne 5 | In the Land of Mordor where the Shadows lie. 6 | One Ring to rule them all, One Ring to find them, 7 | One Ring to bring them all and in the darkness bind them 8 | In the Land of Mordor where the Shadows lie. -------------------------------------------------------------------------------- /src/test/java/collabstream/AppTest.java: -------------------------------------------------------------------------------- 1 | package collabstream; 2 | 3 | import junit.framework.Test; 4 | import junit.framework.TestCase; 5 | import junit.framework.TestSuite; 6 | 7 | /** 8 | * Unit test for simple App. 9 | */ 10 | public class AppTest 11 | extends TestCase 12 | { 13 | /** 14 | * Create the test case 15 | * 16 | * @param testName name of the test case 17 | */ 18 | public AppTest( String testName ) 19 | { 20 | super( testName ); 21 | } 22 | 23 | /** 24 | * @return the suite of tests being tested 25 | */ 26 | public static Test suite() 27 | { 28 | return new TestSuite( AppTest.class ); 29 | } 30 | 31 | /** 32 | * Rigorous Test :-) 33 | */ 34 | public void testApp() 35 | { 36 | assertTrue( true ); 37 | } 38 | } 39 | --------------------------------------------------------------------------------
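AppTest above is the stock Maven archetype placeholder. A test that exercises actual project code is sketched below: it round-trips a TrainingExample through its custom Storm serialization and checks every field, since equals() compares only userId and itemId and would not catch a lost timestamp or iteration count. The test class name and its placement under src/test/java/collabstream/streaming are suggestions, not part of the repository.

package collabstream.streaming;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;

import junit.framework.TestCase;

public class TrainingExampleSerializationTest extends TestCase {
    public void testRoundTrip() throws Exception {
        TrainingExample original = new TrainingExample(7, 42, 99, 3.5f);
        original.numTrainingIters = 2;

        // Serialize into a byte buffer using the topology's own serializer
        TrainingExample.Serialization ser = new TrainingExample.Serialization();
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        ser.serialize(original, new DataOutputStream(bytes));

        // Deserialize and compare field by field; equals() only checks (userId, itemId)
        TrainingExample copy = ser.deserialize(
                new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        assertEquals(original.timestamp, copy.timestamp);
        assertEquals(original.userId, copy.userId);
        assertEquals(original.itemId, copy.itemId);
        assertEquals(original.rating, copy.rating, 0.0f);
        assertEquals(original.numTrainingIters, copy.numTrainingIters);
    }
}

Because the sketch depends only on JUnit 3.8.1 (already in pom.xml) and java.io, it should run under mvn test without additional dependencies.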